diff options
author | H Lohaus <hlohaus@users.noreply.github.com> | 2024-11-17 18:32:51 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-17 18:32:51 +0100 |
commit | 275574d71ece22975de7df0e226d466a2056605b (patch) | |
tree | 3f113ea8beb7c43920871019512aeb8de9d1b4f7 /g4f/Provider/needs_auth | |
parent | Fix api streaming, fix AsyncClient (#2357) (diff) | |
parent | Add nodriver to Gemini provider, (diff) | |
download | gpt4free-275574d71ece22975de7df0e226d466a2056605b.tar gpt4free-275574d71ece22975de7df0e226d466a2056605b.tar.gz gpt4free-275574d71ece22975de7df0e226d466a2056605b.tar.bz2 gpt4free-275574d71ece22975de7df0e226d466a2056605b.tar.lz gpt4free-275574d71ece22975de7df0e226d466a2056605b.tar.xz gpt4free-275574d71ece22975de7df0e226d466a2056605b.tar.zst gpt4free-275574d71ece22975de7df0e226d466a2056605b.zip |
Diffstat (limited to 'g4f/Provider/needs_auth')
-rw-r--r-- | g4f/Provider/needs_auth/Gemini.py | 80 | ||||
-rw-r--r-- | g4f/Provider/needs_auth/GeminiPro.py | 4 | ||||
-rw-r--r-- | g4f/Provider/needs_auth/HuggingFace.py | 29 | ||||
-rw-r--r-- | g4f/Provider/needs_auth/MetaAI.py | 3 | ||||
-rw-r--r-- | g4f/Provider/needs_auth/MetaAIAccount.py | 2 | ||||
-rw-r--r-- | g4f/Provider/needs_auth/__init__.py | 1 |
6 files changed, 40 insertions, 79 deletions
diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py index dad54c84..781aa410 100644 --- a/g4f/Provider/needs_auth/Gemini.py +++ b/g4f/Provider/needs_auth/Gemini.py @@ -6,24 +6,20 @@ import random import re from aiohttp import ClientSession, BaseConnector - -from ..helper import get_connector - try: - from selenium.webdriver.common.by import By - from selenium.webdriver.support.ui import WebDriverWait - from selenium.webdriver.support import expected_conditions as EC + import nodriver + has_nodriver = True except ImportError: - pass + has_nodriver = False from ... import debug from ...typing import Messages, Cookies, ImageType, AsyncResult, AsyncIterator from ..base_provider import AsyncGeneratorProvider, BaseConversation from ..helper import format_prompt, get_cookies from ...requests.raise_for_status import raise_for_status -from ...errors import MissingAuthError, MissingRequirementsError +from ...requests.aiohttp import get_connector +from ...errors import MissingAuthError from ...image import ImageResponse, to_bytes -from ...webdriver import get_browser, get_driver_cookies REQUEST_HEADERS = { "authority": "gemini.google.com", @@ -64,9 +60,9 @@ class Gemini(AsyncGeneratorProvider): @classmethod async def nodriver_login(cls, proxy: str = None) -> AsyncIterator[str]: - try: - import nodriver as uc - except ImportError: + if not has_nodriver: + if debug.logging: + print("Skip nodriver login in Gemini provider") return try: from platformdirs import user_config_dir @@ -75,7 +71,7 @@ class Gemini(AsyncGeneratorProvider): user_data_dir = None if debug.logging: print(f"Open nodriver with user_dir: {user_data_dir}") - browser = await uc.start( + browser = await nodriver.start( user_data_dir=user_data_dir, browser_args=None if proxy is None else [f"--proxy-server={proxy}"], ) @@ -92,30 +88,6 @@ class Gemini(AsyncGeneratorProvider): cls._cookies = cookies @classmethod - async def webdriver_login(cls, proxy: str) -> AsyncIterator[str]: - driver = None - try: - driver = get_browser(proxy=proxy) - try: - driver.get(f"{cls.url}/app") - WebDriverWait(driver, 5).until( - EC.visibility_of_element_located((By.CSS_SELECTOR, "div.ql-editor.textarea")) - ) - except: - login_url = os.environ.get("G4F_LOGIN_URL") - if login_url: - yield f"Please login: [Google Gemini]({login_url})\n\n" - WebDriverWait(driver, 240).until( - EC.visibility_of_element_located((By.CSS_SELECTOR, "div.ql-editor.textarea")) - ) - cls._cookies = get_driver_cookies(driver) - except MissingRequirementsError: - pass - finally: - if driver: - driver.close() - - @classmethod async def create_async_generator( cls, model: str, @@ -143,9 +115,6 @@ class Gemini(AsyncGeneratorProvider): if not cls._snlm0e: async for chunk in cls.nodriver_login(proxy): yield chunk - if cls._cookies is None: - async for chunk in cls.webdriver_login(proxy): - yield chunk if not cls._snlm0e: if cls._cookies is None or "__Secure-1PSID" not in cls._cookies: raise MissingAuthError('Missing "__Secure-1PSID" cookie') @@ -211,20 +180,23 @@ class Gemini(AsyncGeneratorProvider): yield content[last_content_len:] last_content_len = len(content) if image_prompt: - images = [image[0][3][3] for image in response_part[4][0][12][7][0]] - if response_format == "b64_json": - yield ImageResponse(images, image_prompt, {"cookies": cls._cookies}) - else: - resolved_images = [] - preview = [] - for image in images: - async with client.get(image, allow_redirects=False) as fetch: - image = fetch.headers["location"] - async with client.get(image, allow_redirects=False) as fetch: - image = fetch.headers["location"] - resolved_images.append(image) - preview.append(image.replace('=s512', '=s200')) - yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview}) + try: + images = [image[0][3][3] for image in response_part[4][0][12][7][0]] + if response_format == "b64_json": + yield ImageResponse(images, image_prompt, {"cookies": cls._cookies}) + else: + resolved_images = [] + preview = [] + for image in images: + async with client.get(image, allow_redirects=False) as fetch: + image = fetch.headers["location"] + async with client.get(image, allow_redirects=False) as fetch: + image = fetch.headers["location"] + resolved_images.append(image) + preview.append(image.replace('=s512', '=s200')) + yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview}) + except TypeError: + pass def build_request( prompt: str, diff --git a/g4f/Provider/needs_auth/GeminiPro.py b/g4f/Provider/needs_auth/GeminiPro.py index 7e52a194..a7f1e0aa 100644 --- a/g4f/Provider/needs_auth/GeminiPro.py +++ b/g4f/Provider/needs_auth/GeminiPro.py @@ -16,9 +16,9 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): working = True supports_message_history = True needs_auth = True - default_model = "gemini-1.5-pro-latest" + default_model = "gemini-1.5-pro" default_vision_model = default_model - models = [default_model, "gemini-pro", "gemini-pro-vision", "gemini-1.5-flash"] + models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"] @classmethod async def create_async_generator( diff --git a/g4f/Provider/needs_auth/HuggingFace.py b/g4f/Provider/needs_auth/HuggingFace.py index ecc75d1c..35270e60 100644 --- a/g4f/Provider/needs_auth/HuggingFace.py +++ b/g4f/Provider/needs_auth/HuggingFace.py @@ -1,13 +1,11 @@ from __future__ import annotations import json -from aiohttp import ClientSession, BaseConnector from ...typing import AsyncResult, Messages from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin -from ..helper import get_connector -from ...errors import RateLimitError, ModelNotFoundError -from ...requests.raise_for_status import raise_for_status +from ...errors import ModelNotFoundError +from ...requests import StreamSession, raise_for_status from ..HuggingChat import HuggingChat @@ -21,22 +19,12 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): model_aliases = HuggingChat.model_aliases @classmethod - def get_model(cls, model: str) -> str: - if model in cls.models: - return model - elif model in cls.model_aliases: - return cls.model_aliases[model] - else: - return cls.default_model - - @classmethod async def create_async_generator( cls, model: str, messages: Messages, stream: bool = True, proxy: str = None, - connector: BaseConnector = None, api_base: str = "https://api-inference.huggingface.co", api_key: str = None, max_new_tokens: int = 1024, @@ -62,7 +50,6 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): } if api_key is not None: headers["Authorization"] = f"Bearer {api_key}" - params = { "return_full_text": False, "max_new_tokens": max_new_tokens, @@ -70,10 +57,9 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): **kwargs } payload = {"inputs": format_prompt(messages), "parameters": params, "stream": stream} - - async with ClientSession( + async with StreamSession( headers=headers, - connector=get_connector(connector, proxy) + proxy=proxy ) as session: async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response: if response.status == 404: @@ -81,7 +67,7 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): await raise_for_status(response) if stream: first = True - async for line in response.content: + async for line in response.iter_lines(): if line.startswith(b"data:"): data = json.loads(line[5:]) if not data["token"]["special"]: @@ -89,7 +75,8 @@ class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin): if first: first = False chunk = chunk.lstrip() - yield chunk + if chunk: + yield chunk else: yield (await response.json())[0]["generated_text"].strip() @@ -101,4 +88,4 @@ def format_prompt(messages: Messages) -> str: for idx, message in enumerate(messages) if message["role"] == "assistant" ]) - return f"{history}<s>[INST] {question} [/INST]" + return f"{history}<s>[INST] {question} [/INST]"
\ No newline at end of file diff --git a/g4f/Provider/needs_auth/MetaAI.py b/g4f/Provider/needs_auth/MetaAI.py index 4b730abd..568de701 100644 --- a/g4f/Provider/needs_auth/MetaAI.py +++ b/g4f/Provider/needs_auth/MetaAI.py @@ -79,7 +79,6 @@ class MetaAI(AsyncGeneratorProvider, ProviderModelMixin): self.access_token = None if self.access_token is None and cookies is None: await self.update_access_token() - if self.access_token is None: url = "https://www.meta.ai/api/graphql/" payload = {"lsd": self.lsd, 'fb_dtsg': self.dtsg} @@ -128,6 +127,8 @@ class MetaAI(AsyncGeneratorProvider, ProviderModelMixin): json_line = json.loads(line) except json.JSONDecodeError: continue + if json_line.get("errors"): + raise RuntimeError("\n".join([error.get("message") for error in json_line.get("errors")])) bot_response_message = json_line.get("data", {}).get("node", {}).get("bot_response_message", {}) streaming_state = bot_response_message.get("streaming_state") fetch_id = bot_response_message.get("fetch_id") or fetch_id diff --git a/g4f/Provider/needs_auth/MetaAIAccount.py b/g4f/Provider/needs_auth/MetaAIAccount.py index 2d54f3e0..0a586006 100644 --- a/g4f/Provider/needs_auth/MetaAIAccount.py +++ b/g4f/Provider/needs_auth/MetaAIAccount.py @@ -2,7 +2,7 @@ from __future__ import annotations from ...typing import AsyncResult, Messages, Cookies from ..helper import format_prompt, get_cookies -from ..MetaAI import MetaAI +from .MetaAI import MetaAI class MetaAIAccount(MetaAI): needs_auth = True diff --git a/g4f/Provider/needs_auth/__init__.py b/g4f/Provider/needs_auth/__init__.py index 26c50c0a..ace53876 100644 --- a/g4f/Provider/needs_auth/__init__.py +++ b/g4f/Provider/needs_auth/__init__.py @@ -11,6 +11,7 @@ from .GeminiPro import GeminiPro from .Groq import Groq from .HuggingFace import HuggingFace from .MetaAI import MetaAI +from .MetaAIAccount import MetaAIAccount from .OpenaiAPI import OpenaiAPI from .OpenaiChat import OpenaiChat from .PerplexityApi import PerplexityApi |