summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth/OpenaiChat.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/needs_auth/OpenaiChat.py')
-rw-r--r--g4f/Provider/needs_auth/OpenaiChat.py103
1 files changed, 55 insertions, 48 deletions
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index e507404b..3d19e003 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -5,15 +5,15 @@ import uuid
import json
import os
import base64
+import time
from aiohttp import ClientWebSocketResponse
try:
from py_arkose_generator.arkose import get_values_for_request
- from async_property import async_cached_property
- has_requirements = True
+ has_arkose_generator = True
except ImportError:
- async_cached_property = property
- has_requirements = False
+ has_arkose_generator = False
+
try:
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
@@ -33,7 +33,7 @@ from ... import debug
class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
"""A class for creating and managing conversations with OpenAI chat service"""
-
+
url = "https://chat.openai.com"
working = True
needs_auth = True
@@ -47,7 +47,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
_api_key: str = None
_headers: dict = None
_cookies: Cookies = None
- _last_message: int = 0
+ _expires: int = None
@classmethod
async def create(
@@ -80,7 +80,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
A Response object that contains the generator, action, messages, and options
"""
# Add the user input to the messages list
- if prompt:
+ if prompt is not None:
messages.append({
"role": "user",
"content": prompt
@@ -102,7 +102,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
messages,
kwargs
)
-
+
@classmethod
async def upload_image(
cls,
@@ -162,7 +162,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
response.raise_for_status()
image_data["download_url"] = (await response.json())["download_url"]
return ImageRequest(image_data)
-
+
@classmethod
async def get_default_model(cls, session: StreamSession, headers: dict):
"""
@@ -185,7 +185,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
return cls.default_model
raise RuntimeError(f"Response: {data}")
return cls.default_model
-
+
@classmethod
def create_messages(cls, messages: Messages, image_request: ImageRequest = None):
"""
@@ -334,9 +334,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
Raises:
RuntimeError: If an error occurs during processing.
"""
- if not has_requirements:
- raise MissingRequirementsError('Install "py-arkose-generator" and "async_property" package')
- if not parent_id:
+ if parent_id is None:
parent_id = str(uuid.uuid4())
# Read api_key from arguments
@@ -348,7 +346,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
timeout=timeout
) as session:
# Read api_key and cookies from cache / browser config
- if cls._headers is None:
+ if cls._headers is None or cls._expires is None or time.time() > cls._expires:
if api_key is None:
# Read api_key from cookies
cookies = get_cookies("chat.openai.com", False) if cookies is None else cookies
@@ -357,8 +355,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
else:
api_key = cls._api_key if api_key is None else api_key
# Read api_key with session cookies
- if api_key is None and cookies:
- api_key = await cls.fetch_access_token(session, cls._headers)
+ #if api_key is None and cookies:
+ # api_key = await cls.fetch_access_token(session, cls._headers)
# Load default model
if cls.default_model is None and api_key is not None:
try:
@@ -384,6 +382,19 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
else:
cls._set_api_key(api_key)
+ async with session.post(
+ f"{cls.url}/backend-api/sentinel/chat-requirements",
+ json={"conversation_mode_kind": "primary_assistant"},
+ headers=cls._headers
+ ) as response:
+ response.raise_for_status()
+ data = await response.json()
+ need_arkose = data["arkose"]["required"]
+ chat_token = data["token"]
+
+ if need_arkose and not has_arkose_generator:
+ raise MissingRequirementsError('Install "py-arkose-generator" package')
+
try:
image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None
except Exception as e:
@@ -394,12 +405,10 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
model = cls.get_model(model).replace("gpt-3.5-turbo", "text-davinci-002-render-sha")
fields = ResponseFields()
while fields.finish_reason is None:
- arkose_token = await cls.get_arkose_token(session)
conversation_id = conversation_id if fields.conversation_id is None else fields.conversation_id
parent_id = parent_id if fields.message_id is None else fields.message_id
data = {
"action": action,
- "arkose_token": arkose_token,
"conversation_mode": {"kind": "primary_assistant"},
"force_paragen": False,
"force_rate_limit": False,
@@ -417,7 +426,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
json=data,
headers={
"Accept": "text/event-stream",
- "OpenAI-Sentinel-Arkose-Token": arkose_token,
+ **({"OpenAI-Sentinel-Arkose-Token": await cls.get_arkose_token(session)} if need_arkose else {}),
+ "OpenAI-Sentinel-Chat-Requirements-Token": chat_token,
**cls._headers
}
) as response:
@@ -437,17 +447,20 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
await cls.delete_conversation(session, cls._headers, fields.conversation_id)
@staticmethod
- async def iter_messages_ws(ws: ClientWebSocketResponse) -> AsyncIterator:
+ async def iter_messages_ws(ws: ClientWebSocketResponse, conversation_id: str) -> AsyncIterator:
while True:
- yield base64.b64decode((await ws.receive_json())["body"])
+ message = await ws.receive_json()
+ if message["conversation_id"] == conversation_id:
+ yield base64.b64decode(message["body"])
@classmethod
async def iter_messages_chunk(cls, messages: AsyncIterator, session: StreamSession, fields: ResponseFields) -> AsyncIterator:
last_message: int = 0
async for message in messages:
if message.startswith(b'{"wss_url":'):
- async with session.ws_connect(json.loads(message)["wss_url"]) as ws:
- async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws), session, fields):
+ message = json.loads(message)
+ async with session.ws_connect(message["wss_url"]) as ws:
+ async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws, message["conversation_id"]), session, fields):
yield chunk
break
async for chunk in cls.iter_messages_line(session, message, fields):
@@ -467,6 +480,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
if not line.startswith(b"data: "):
return
elif line.startswith(b"data: [DONE]"):
+ if fields.finish_reason is None:
+ fields.finish_reason = "error"
return
try:
line = json.loads(line[6:])
@@ -589,22 +604,13 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
@classmethod
def _set_api_key(cls, api_key: str):
cls._api_key = api_key
+ cls._expires = int(time.time()) + 60 * 60 * 4
cls._headers["Authorization"] = f"Bearer {api_key}"
@classmethod
def _update_cookie_header(cls):
cls._headers["Cookie"] = cls._format_cookies(cls._cookies)
-class EndTurn:
- """
- Class to represent the end of a conversation turn.
- """
- def __init__(self):
- self.is_end = False
-
- def end(self):
- self.is_end = True
-
class ResponseFields:
"""
Class to encapsulate response fields.
@@ -633,8 +639,8 @@ class Response():
self._options = options
self._fields = None
- async def generator(self):
- if self._generator:
+ async def generator(self) -> AsyncIterator:
+ if self._generator is not None:
self._generator = None
chunks = []
async for chunk in self._generator:
@@ -644,27 +650,29 @@ class Response():
yield chunk
chunks.append(str(chunk))
self._message = "".join(chunks)
- if not self._fields:
+ if self._fields is None:
raise RuntimeError("Missing response fields")
- self.is_end = self._fields.end_turn
+ self.is_end = self._fields.finish_reason == "stop"
def __aiter__(self):
return self.generator()
- @async_cached_property
- async def message(self) -> str:
+ async def get_message(self) -> str:
await self.generator()
return self._message
- async def get_fields(self):
+ async def get_fields(self) -> dict:
await self.generator()
- return {"conversation_id": self._fields.conversation_id, "parent_id": self._fields.message_id}
+ return {
+ "conversation_id": self._fields.conversation_id,
+ "parent_id": self._fields.message_id
+ }
- async def next(self, prompt: str, **kwargs) -> Response:
+ async def create_next(self, prompt: str, **kwargs) -> Response:
return await OpenaiChat.create(
**self._options,
prompt=prompt,
- messages=await self.messages,
+ messages=await self.get_messages(),
action="next",
**await self.get_fields(),
**kwargs
@@ -676,13 +684,13 @@ class Response():
raise RuntimeError("Can't continue message. Message already finished.")
return await OpenaiChat.create(
**self._options,
- messages=await self.messages,
+ messages=await self.get_messages(),
action="continue",
**fields,
**kwargs
)
- async def variant(self, **kwargs) -> Response:
+ async def create_variant(self, **kwargs) -> Response:
if self.action != "next":
raise RuntimeError("Can't create variant from continue or variant request.")
return await OpenaiChat.create(
@@ -693,8 +701,7 @@ class Response():
**kwargs
)
- @async_cached_property
- async def messages(self):
+ async def get_messages(self) -> list:
messages = self._messages
- messages.append({"role": "assistant", "content": await self.message})
+ messages.append({"role": "assistant", "content": await self.message()})
return messages \ No newline at end of file