diff options
Diffstat (limited to 'g4f/Provider/needs_auth')
-rw-r--r-- | g4f/Provider/needs_auth/OpenaiChat.py | 103 |
1 files changed, 55 insertions, 48 deletions
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index e507404b..3d19e003 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -5,15 +5,15 @@ import uuid import json import os import base64 +import time from aiohttp import ClientWebSocketResponse try: from py_arkose_generator.arkose import get_values_for_request - from async_property import async_cached_property - has_requirements = True + has_arkose_generator = True except ImportError: - async_cached_property = property - has_requirements = False + has_arkose_generator = False + try: from selenium.webdriver.common.by import By from selenium.webdriver.support.ui import WebDriverWait @@ -33,7 +33,7 @@ from ... import debug class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): """A class for creating and managing conversations with OpenAI chat service""" - + url = "https://chat.openai.com" working = True needs_auth = True @@ -47,7 +47,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): _api_key: str = None _headers: dict = None _cookies: Cookies = None - _last_message: int = 0 + _expires: int = None @classmethod async def create( @@ -80,7 +80,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): A Response object that contains the generator, action, messages, and options """ # Add the user input to the messages list - if prompt: + if prompt is not None: messages.append({ "role": "user", "content": prompt @@ -102,7 +102,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): messages, kwargs ) - + @classmethod async def upload_image( cls, @@ -162,7 +162,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): response.raise_for_status() image_data["download_url"] = (await response.json())["download_url"] return ImageRequest(image_data) - + @classmethod async def get_default_model(cls, session: StreamSession, headers: dict): """ @@ -185,7 +185,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): return cls.default_model raise RuntimeError(f"Response: {data}") return cls.default_model - + @classmethod def create_messages(cls, messages: Messages, image_request: ImageRequest = None): """ @@ -334,9 +334,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): Raises: RuntimeError: If an error occurs during processing. """ - if not has_requirements: - raise MissingRequirementsError('Install "py-arkose-generator" and "async_property" package') - if not parent_id: + if parent_id is None: parent_id = str(uuid.uuid4()) # Read api_key from arguments @@ -348,7 +346,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): timeout=timeout ) as session: # Read api_key and cookies from cache / browser config - if cls._headers is None: + if cls._headers is None or cls._expires is None or time.time() > cls._expires: if api_key is None: # Read api_key from cookies cookies = get_cookies("chat.openai.com", False) if cookies is None else cookies @@ -357,8 +355,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): else: api_key = cls._api_key if api_key is None else api_key # Read api_key with session cookies - if api_key is None and cookies: - api_key = await cls.fetch_access_token(session, cls._headers) + #if api_key is None and cookies: + # api_key = await cls.fetch_access_token(session, cls._headers) # Load default model if cls.default_model is None and api_key is not None: try: @@ -384,6 +382,19 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): else: cls._set_api_key(api_key) + async with session.post( + f"{cls.url}/backend-api/sentinel/chat-requirements", + json={"conversation_mode_kind": "primary_assistant"}, + headers=cls._headers + ) as response: + response.raise_for_status() + data = await response.json() + need_arkose = data["arkose"]["required"] + chat_token = data["token"] + + if need_arkose and not has_arkose_generator: + raise MissingRequirementsError('Install "py-arkose-generator" package') + try: image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None except Exception as e: @@ -394,12 +405,10 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): model = cls.get_model(model).replace("gpt-3.5-turbo", "text-davinci-002-render-sha") fields = ResponseFields() while fields.finish_reason is None: - arkose_token = await cls.get_arkose_token(session) conversation_id = conversation_id if fields.conversation_id is None else fields.conversation_id parent_id = parent_id if fields.message_id is None else fields.message_id data = { "action": action, - "arkose_token": arkose_token, "conversation_mode": {"kind": "primary_assistant"}, "force_paragen": False, "force_rate_limit": False, @@ -417,7 +426,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): json=data, headers={ "Accept": "text/event-stream", - "OpenAI-Sentinel-Arkose-Token": arkose_token, + **({"OpenAI-Sentinel-Arkose-Token": await cls.get_arkose_token(session)} if need_arkose else {}), + "OpenAI-Sentinel-Chat-Requirements-Token": chat_token, **cls._headers } ) as response: @@ -437,17 +447,20 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): await cls.delete_conversation(session, cls._headers, fields.conversation_id) @staticmethod - async def iter_messages_ws(ws: ClientWebSocketResponse) -> AsyncIterator: + async def iter_messages_ws(ws: ClientWebSocketResponse, conversation_id: str) -> AsyncIterator: while True: - yield base64.b64decode((await ws.receive_json())["body"]) + message = await ws.receive_json() + if message["conversation_id"] == conversation_id: + yield base64.b64decode(message["body"]) @classmethod async def iter_messages_chunk(cls, messages: AsyncIterator, session: StreamSession, fields: ResponseFields) -> AsyncIterator: last_message: int = 0 async for message in messages: if message.startswith(b'{"wss_url":'): - async with session.ws_connect(json.loads(message)["wss_url"]) as ws: - async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws), session, fields): + message = json.loads(message) + async with session.ws_connect(message["wss_url"]) as ws: + async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws, message["conversation_id"]), session, fields): yield chunk break async for chunk in cls.iter_messages_line(session, message, fields): @@ -467,6 +480,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): if not line.startswith(b"data: "): return elif line.startswith(b"data: [DONE]"): + if fields.finish_reason is None: + fields.finish_reason = "error" return try: line = json.loads(line[6:]) @@ -589,22 +604,13 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): @classmethod def _set_api_key(cls, api_key: str): cls._api_key = api_key + cls._expires = int(time.time()) + 60 * 60 * 4 cls._headers["Authorization"] = f"Bearer {api_key}" @classmethod def _update_cookie_header(cls): cls._headers["Cookie"] = cls._format_cookies(cls._cookies) -class EndTurn: - """ - Class to represent the end of a conversation turn. - """ - def __init__(self): - self.is_end = False - - def end(self): - self.is_end = True - class ResponseFields: """ Class to encapsulate response fields. @@ -633,8 +639,8 @@ class Response(): self._options = options self._fields = None - async def generator(self): - if self._generator: + async def generator(self) -> AsyncIterator: + if self._generator is not None: self._generator = None chunks = [] async for chunk in self._generator: @@ -644,27 +650,29 @@ class Response(): yield chunk chunks.append(str(chunk)) self._message = "".join(chunks) - if not self._fields: + if self._fields is None: raise RuntimeError("Missing response fields") - self.is_end = self._fields.end_turn + self.is_end = self._fields.finish_reason == "stop" def __aiter__(self): return self.generator() - @async_cached_property - async def message(self) -> str: + async def get_message(self) -> str: await self.generator() return self._message - async def get_fields(self): + async def get_fields(self) -> dict: await self.generator() - return {"conversation_id": self._fields.conversation_id, "parent_id": self._fields.message_id} + return { + "conversation_id": self._fields.conversation_id, + "parent_id": self._fields.message_id + } - async def next(self, prompt: str, **kwargs) -> Response: + async def create_next(self, prompt: str, **kwargs) -> Response: return await OpenaiChat.create( **self._options, prompt=prompt, - messages=await self.messages, + messages=await self.get_messages(), action="next", **await self.get_fields(), **kwargs @@ -676,13 +684,13 @@ class Response(): raise RuntimeError("Can't continue message. Message already finished.") return await OpenaiChat.create( **self._options, - messages=await self.messages, + messages=await self.get_messages(), action="continue", **fields, **kwargs ) - async def variant(self, **kwargs) -> Response: + async def create_variant(self, **kwargs) -> Response: if self.action != "next": raise RuntimeError("Can't create variant from continue or variant request.") return await OpenaiChat.create( @@ -693,8 +701,7 @@ class Response(): **kwargs ) - @async_cached_property - async def messages(self): + async def get_messages(self) -> list: messages = self._messages - messages.append({"role": "assistant", "content": await self.message}) + messages.append({"role": "assistant", "content": await self.message()}) return messages
\ No newline at end of file |