summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth/OpenaiChat.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/needs_auth/OpenaiChat.py')
-rw-r--r--g4f/Provider/needs_auth/OpenaiChat.py48
1 files changed, 20 insertions, 28 deletions
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index 0aff99a7..b83025e4 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -3,10 +3,10 @@ from __future__ import annotations
import asyncio
import uuid
import json
-import os
import base64
import time
from aiohttp import ClientWebSocketResponse
+from copy import copy
try:
import webview
@@ -22,13 +22,13 @@ except ImportError:
pass
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from ..helper import get_cookies
from ...webdriver import get_browser
from ...typing import AsyncResult, Messages, Cookies, ImageType, Union, AsyncIterator
from ...requests import get_args_from_browser, raise_for_status
from ...requests.aiohttp import StreamSession
from ...image import to_image, to_bytes, ImageResponse, ImageRequest
-from ...errors import MissingRequirementsError, MissingAuthError, ProviderNotWorkingError
+from ...errors import MissingAuthError
+from ...providers.conversation import BaseConversation
from ..openai.har_file import getArkoseAndAccessToken
from ... import debug
@@ -56,11 +56,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
prompt: str = None,
model: str = "",
messages: Messages = [],
- history_disabled: bool = False,
- action: str = "next",
- conversation_id: str = None,
- parent_id: str = None,
- image: ImageType = None,
**kwargs
) -> Response:
"""
@@ -89,12 +84,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
generator = cls.create_async_generator(
model,
messages,
- history_disabled=history_disabled,
- action=action,
- conversation_id=conversation_id,
- parent_id=parent_id,
- image=image,
- response_fields=True,
+ return_conversation=True,
**kwargs
)
return Response(
@@ -209,7 +199,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
} for message in messages]
# Check if there is an image response
- if image_request:
+ if image_request is not None:
# Change content in last user message
messages[-1]["content"] = {
"content_type": "multimodal_text",
@@ -308,10 +298,11 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
history_disabled: bool = True,
action: str = "next",
conversation_id: str = None,
+ conversation: Conversation = None,
parent_id: str = None,
image: ImageType = None,
image_name: str = None,
- response_fields: bool = False,
+ return_conversation: bool = False,
**kwargs
) -> AsyncResult:
"""
@@ -330,7 +321,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
conversation_id (str): ID of the conversation.
parent_id (str): ID of the parent message.
image (ImageType): Image to include in the conversation.
- response_fields (bool): Flag to include response fields in the output.
+ return_conversation (bool): Flag to include response fields in the output.
**kwargs: Additional keyword arguments.
Yields:
@@ -387,6 +378,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
arkose_token, api_key, cookies = await getArkoseAndAccessToken(proxy)
cls._create_request_args(cookies)
cls._set_api_key(api_key)
+ if arkose_token is None:
+ raise MissingAuthError("No arkose token found in .har file")
try:
image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None
@@ -396,7 +389,8 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
print(f"{e.__class__.__name__}: {e}")
model = cls.get_model(model).replace("gpt-3.5-turbo", "text-davinci-002-render-sha")
- fields = ResponseFields()
+ fields = Conversation() if conversation is None else copy(conversation)
+ fields.finish_reason = None
while fields.finish_reason is None:
conversation_id = conversation_id if fields.conversation_id is None else fields.conversation_id
parent_id = parent_id if fields.message_id is None else fields.message_id
@@ -422,8 +416,6 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
}
if need_arkose:
headers["OpenAI-Sentinel-Arkose-Token"] = arkose_token
- headers["OpenAI-Sentinel-Chat-Requirements-Token"] = chat_token
-
async with session.post(
f"{cls.url}/backend-api/conversation",
json=data,
@@ -432,15 +424,15 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
cls._update_request_args(session)
await raise_for_status(response)
async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, fields):
- if response_fields:
- response_fields = False
+ if return_conversation:
+ return_conversation = False
yield fields
yield chunk
if not auto_continue:
break
action = "continue"
await asyncio.sleep(5)
- if history_disabled and auto_continue:
+ if history_disabled and auto_continue and not return_conversation:
await cls.delete_conversation(session, cls._headers, fields.conversation_id)
@staticmethod
@@ -458,7 +450,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
cls,
messages: AsyncIterator,
session: StreamSession,
- fields: ResponseFields
+ fields: Conversation
) -> AsyncIterator:
last_message: int = 0
async for message in messages:
@@ -487,7 +479,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
break
@classmethod
- async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: ResponseFields) -> AsyncIterator:
+ async def iter_messages_line(cls, session: StreamSession, line: bytes, fields: Conversation) -> AsyncIterator:
if not line.startswith(b"data: "):
return
elif line.startswith(b"data: [DONE]"):
@@ -618,7 +610,7 @@ this.fetch = async (url, options) => {
@classmethod
def _update_request_args(cls, session: StreamSession):
for c in session.cookie_jar if hasattr(session, "cookie_jar") else session.cookies.jar:
- cls._cookies[c.name if hasattr(c, "name") else c.key] = c.value
+ cls._cookies[c.key if hasattr(c, "key") else c.name] = c.value
cls._update_cookie_header()
@classmethod
@@ -631,7 +623,7 @@ this.fetch = async (url, options) => {
def _update_cookie_header(cls):
cls._headers["Cookie"] = cls._format_cookies(cls._cookies)
-class ResponseFields:
+class Conversation(BaseConversation):
"""
Class to encapsulate response fields.
"""
@@ -664,7 +656,7 @@ class Response():
self._generator = None
chunks = []
async for chunk in self._generator:
- if isinstance(chunk, ResponseFields):
+ if isinstance(chunk, Conversation):
self._fields = chunk
else:
yield chunk