diff options
Diffstat (limited to 'g4f/Provider/needs_auth/OpenaiChat.py')
-rw-r--r-- | g4f/Provider/needs_auth/OpenaiChat.py | 52 |
1 files changed, 32 insertions, 20 deletions
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index 85866272..b07bd49b 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -1,21 +1,32 @@ from __future__ import annotations + import asyncio import uuid import json import os -from py_arkose_generator.arkose import get_values_for_request -from async_property import async_cached_property -from selenium.webdriver.common.by import By -from selenium.webdriver.support.ui import WebDriverWait -from selenium.webdriver.support import expected_conditions as EC +try: + from py_arkose_generator.arkose import get_values_for_request + from async_property import async_cached_property + has_requirements = True +except ImportError: + async_cached_property = property + has_requirements = False +try: + from selenium.webdriver.common.by import By + from selenium.webdriver.support.ui import WebDriverWait + from selenium.webdriver.support import expected_conditions as EC + has_webdriver = True +except ImportError: + has_webdriver = False from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin from ..helper import format_prompt, get_cookies from ...webdriver import get_browser, get_driver_cookies -from ...typing import AsyncResult, Messages +from ...typing import AsyncResult, Messages, Cookies, ImageType from ...requests import StreamSession -from ...image import to_image, to_bytes, ImageType, ImageResponse +from ...image import to_image, to_bytes, ImageResponse, ImageRequest +from ...errors import MissingRequirementsError, MissingAccessToken class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): @@ -27,12 +38,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): supports_gpt_35_turbo = True supports_gpt_4 = True default_model = None - models = ["text-davinci-002-render-sha", "gpt-4", "gpt-4-gizmo"] - model_aliases = { - "gpt-3.5-turbo": "text-davinci-002-render-sha", - } + models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"] _cookies: dict = {} - _default_model: str = None @classmethod async def create( @@ -94,7 +101,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): session: StreamSession, headers: dict, image: ImageType - ) -> ImageResponse: + ) -> ImageRequest: """ Upload an image to the service and get the download URL @@ -104,7 +111,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): image: The image to upload, either a PIL Image object or a bytes object Returns: - An ImageResponse object that contains the download URL, file name, and other data + An ImageRequest object that contains the download URL, file name, and other data """ # Convert the image to a PIL Image object and get the extension image = to_image(image) @@ -145,7 +152,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): ) as response: response.raise_for_status() download_url = (await response.json())["download_url"] - return ImageResponse(download_url, image_data["file_name"], image_data) + return ImageRequest(download_url, image_data["file_name"], image_data) @classmethod async def get_default_model(cls, session: StreamSession, headers: dict): @@ -169,7 +176,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): return cls.default_model @classmethod - def create_messages(cls, prompt: str, image_response: ImageResponse = None): + def create_messages(cls, prompt: str, image_response: ImageRequest = None): """ Create a list of messages for the user input @@ -282,7 +289,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): proxy: str = None, timeout: int = 120, access_token: str = None, - cookies: dict = None, + cookies: Cookies = None, auto_continue: bool = False, history_disabled: bool = True, action: str = "next", @@ -317,12 +324,16 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): Raises: RuntimeError: If an error occurs during processing. """ + if not has_requirements: + raise MissingRequirementsError('Install "py-arkose-generator" and "async_property" package') if not parent_id: parent_id = str(uuid.uuid4()) if not cookies: - cookies = cls._cookies or get_cookies("chat.openai.com") + cookies = cls._cookies or get_cookies("chat.openai.com", False) if not access_token and "access_token" in cookies: access_token = cookies["access_token"] + if not access_token and not has_webdriver: + raise MissingAccessToken(f'Missing "access_token"') if not access_token: login_url = os.environ.get("G4F_LOGIN_URL") if login_url: @@ -331,7 +342,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): cls._cookies = cookies headers = {"Authorization": f"Bearer {access_token}"} - async with StreamSession( proxies={"https": proxy}, impersonate="chrome110", @@ -346,13 +356,15 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): except Exception as e: yield e end_turn = EndTurn() + model = cls.get_model(model or await cls.get_default_model(session, headers)) + model = "text-davinci-002-render-sha" if model == "gpt-3.5-turbo" else model while not end_turn.is_end: data = { "action": action, "arkose_token": await cls.get_arkose_token(session), "conversation_id": conversation_id, "parent_message_id": parent_id, - "model": cls.get_model(model or await cls.get_default_model(session, headers)), + "model": model, "history_and_training_disabled": history_disabled and not auto_continue, } if action != "continue": |