summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth/OpenaiChat.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/needs_auth/OpenaiChat.py120
1 files changed, 75 insertions, 45 deletions
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index 556c3d9b..8c2668ab 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -20,9 +20,9 @@ except ImportError:
pass
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ..helper import format_prompt, get_cookies
-from ...webdriver import get_browser, get_driver_cookies
-from ...typing import AsyncResult, Messages, Cookies, ImageType
+from ..helper import get_cookies
+from ...webdriver import get_browser
+from ...typing import AsyncResult, Messages, Cookies, ImageType, Union
from ...requests import get_args_from_browser
from ...requests.aiohttp import StreamSession
from ...image import to_image, to_bytes, ImageResponse, ImageRequest
@@ -37,6 +37,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
needs_auth = True
supports_gpt_35_turbo = True
supports_gpt_4 = True
+ supports_message_history = True
default_model = None
models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"]
model_aliases = {"text-davinci-002-render-sha": "gpt-3.5-turbo"}
@@ -170,6 +171,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
"""
if not cls.default_model:
async with session.get(f"{cls.url}/backend-api/models", headers=headers) as response:
+ cls._update_request_args(session)
response.raise_for_status()
data = await response.json()
if "categories" in data:
@@ -179,7 +181,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
return cls.default_model
@classmethod
- def create_messages(cls, prompt: str, image_request: ImageRequest = None):
+ def create_messages(cls, messages: Messages, image_request: ImageRequest = None):
"""
Create a list of messages for the user input
@@ -190,31 +192,27 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
Returns:
A list of messages with the user input and the image, if any
"""
+ # Create a message object with the user role and the content
+ messages = [{
+ "id": str(uuid.uuid4()),
+ "author": {"role": message["role"]},
+ "content": {"content_type": "text", "parts": [message["content"]]},
+ } for message in messages]
+
# Check if there is an image response
- if not image_request:
- # Create a content object with the text type and the prompt
- content = {"content_type": "text", "parts": [prompt]}
- else:
- # Create a content object with the multimodal text type and the image and the prompt
- content = {
+ if image_request:
+ # Change content in last user message
+ messages[-1]["content"] = {
"content_type": "multimodal_text",
"parts": [{
"asset_pointer": f"file-service://{image_request.get('file_id')}",
"height": image_request.get("height"),
"size_bytes": image_request.get("file_size"),
"width": image_request.get("width"),
- }, prompt]
+ }, messages[-1]["content"]["parts"][0]]
}
- # Create a message object with the user role and the content
- messages = [{
- "id": str(uuid.uuid4()),
- "author": {"role": "user"},
- "content": content,
- }]
- # Check if there is an image response
- if image_request:
# Add the metadata object with the attachments
- messages[0]["metadata"] = {
+ messages[-1]["metadata"] = {
"attachments": [{
"height": image_request.get("height"),
"id": image_request.get("file_id"),
@@ -225,7 +223,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
}]
}
return messages
-
+
@classmethod
async def get_generated_image(cls, session: StreamSession, headers: dict, line: dict) -> ImageResponse:
"""
@@ -333,30 +331,33 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
raise MissingRequirementsError('Install "py-arkose-generator" and "async_property" package')
if not parent_id:
parent_id = str(uuid.uuid4())
- if cls._args is None and cookies is None:
- cookies = get_cookies("chat.openai.com", False)
+
+ # Read api_key from args
api_key = kwargs["access_token"] if "access_token" in kwargs else api_key
- if api_key is None and cookies is not None:
- api_key = cookies["access_token"] if "access_token" in cookies else api_key
if cls._args is None:
- cls._args = {
- "headers": {"Cookie": "; ".join(f"{k}={v}" for k, v in cookies.items() if k != "access_token")},
- "cookies": {} if cookies is None else cookies
- }
- if api_key is not None:
- cls._args["headers"]["Authorization"] = f"Bearer {api_key}"
+ if api_key is None:
+ # Read api_key from cookies
+ cookies = get_cookies("chat.openai.com", False) if cookies is None else cookies
+ api_key = cookies["access_token"] if "access_token" in cookies else api_key
+ cls._args = cls._create_request_args(cookies)
+
async with StreamSession(
proxies={"https": proxy},
impersonate="chrome",
- timeout=timeout,
- headers=cls._args["headers"]
+ timeout=timeout
) as session:
+ if api_key is None and cookies:
+ # Read api_key from session
+ api_key = await cls.fetch_access_token(session, cls._args["headers"])
+
if api_key is not None:
+ cls._args["headers"]["Authorization"] = f"Bearer {api_key}"
try:
cls.default_model = await cls.get_default_model(session, cls._args["headers"])
except Exception as e:
if debug.logging:
print(f"{e.__class__.__name__}: {e}")
+
if cls.default_model is None:
login_url = os.environ.get("G4F_LOGIN_URL")
if login_url:
@@ -366,12 +367,17 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
except MissingRequirementsError:
raise MissingAuthError(f'Missing or invalid "access_token". Add a new "api_key" please')
cls.default_model = await cls.get_default_model(session, cls._args["headers"])
+
try:
- image_response = None
- if image:
- image_response = await cls.upload_image(session, cls._args["headers"], image, kwargs.get("image_name"))
+ image_response = await cls.upload_image(
+ session,
+ cls._args["headers"],
+ image,
+ kwargs.get("image_name")
+ ) if image else None
except Exception as e:
yield e
+
end_turn = EndTurn()
model = cls.get_model(model)
model = "text-davinci-002-render-sha" if model == "gpt-3.5-turbo" else model
@@ -389,13 +395,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
"history_and_training_disabled": history_disabled and not auto_continue,
}
if action != "continue":
- prompt = format_prompt(messages) if not conversation_id else messages[-1]["content"]
- data["messages"] = cls.create_messages(prompt, image_response)
-
- # Update cookies before next request
- for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar:
- cls._args["cookies"][c.name if hasattr(c, "name") else c.key] = c.value
- cls._args["headers"]["Cookie"] = "; ".join(f"{k}={v}" for k, v in cls._args["cookies"].items())
+ messages = messages if not conversation_id else [messages[-1]]
+ data["messages"] = cls.create_messages(messages, image_response)
async with session.post(
f"{cls.url}/backend-api/conversation",
@@ -406,6 +407,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
**cls._args["headers"]
}
) as response:
+ cls._update_request_args(session)
if not response.ok:
raise RuntimeError(f"Response {response.status}: {await response.text()}")
last_message: int = 0
@@ -475,13 +477,13 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
"let session = await fetch('/api/auth/session');"
"let data = await session.json();"
"let accessToken = data['accessToken'];"
- "let expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 4);"
+ "let expires = new Date(); expires.setTime(expires.getTime() + 60 * 60 * 4 * 1000);"
"document.cookie = 'access_token=' + accessToken + ';expires=' + expires.toUTCString() + ';path=/';"
"return accessToken;"
)
args = get_args_from_browser(f"{cls.url}/", driver, do_bypass_cloudflare=False)
args["headers"]["Authorization"] = f"Bearer {access_token}"
- args["headers"]["Cookie"] = "; ".join(f"{k}={v}" for k, v in args["cookies"].items() if k != "access_token")
+ args["headers"]["Cookie"] = cls._format_cookies(args["cookies"])
return args
finally:
driver.close()
@@ -516,6 +518,34 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
return decoded_json["token"]
raise RuntimeError(f"Response: {decoded_json}")
+ @classmethod
+ async def fetch_access_token(cls, session: StreamSession, headers: dict):
+ async with session.get(
+ f"{cls.url}/api/auth/session",
+ headers=headers
+ ) as response:
+ if response.ok:
+ data = await response.json()
+ if "accessToken" in data:
+ return data["accessToken"]
+
+ @staticmethod
+ def _format_cookies(cookies: Cookies):
+ return "; ".join(f"{k}={v}" for k, v in cookies.items() if k != "access_token")
+
+ @classmethod
+ def _create_request_args(cls, cookies: Union[Cookies, None]):
+ return {
+ "headers": {} if cookies is None else {"Cookie": cls._format_cookies(cookies)},
+ "cookies": {} if cookies is None else cookies
+ }
+
+ @classmethod
+ def _update_request_args(cls, session: StreamSession):
+ for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar:
+ cls._args["cookies"][c.name if hasattr(c, "name") else c.key] = c.value
+ cls._args["headers"]["Cookie"] = cls._format_cookies(cls._args["cookies"])
+
class EndTurn:
"""
Class to represent the end of a conversation turn.