summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/OpenaiChat.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/OpenaiChat.py41
1 files changed, 24 insertions, 17 deletions
diff --git a/g4f/Provider/OpenaiChat.py b/g4f/Provider/OpenaiChat.py
index c023c898..9ca0cd58 100644
--- a/g4f/Provider/OpenaiChat.py
+++ b/g4f/Provider/OpenaiChat.py
@@ -4,8 +4,11 @@ try:
except ImportError:
has_module = False
-from .base_provider import AsyncGeneratorProvider, get_cookies
-from ..typing import AsyncGenerator
+from .base_provider import AsyncGeneratorProvider, get_cookies, format_prompt
+from ..typing import AsyncGenerator
+from httpx import AsyncClient
+import json
+
class OpenaiChat(AsyncGeneratorProvider):
url = "https://chat.openai.com"
@@ -14,6 +17,7 @@ class OpenaiChat(AsyncGeneratorProvider):
supports_gpt_35_turbo = True
supports_gpt_4 = True
supports_stream = True
+ _access_token = None
@classmethod
async def create_async_generator(
@@ -21,9 +25,9 @@ class OpenaiChat(AsyncGeneratorProvider):
model: str,
messages: list[dict[str, str]],
proxy: str = None,
- access_token: str = None,
+ access_token: str = _access_token,
cookies: dict = None,
- **kwargs
+ **kwargs: dict
) -> AsyncGenerator:
config = {"access_token": access_token, "model": model}
@@ -37,21 +41,12 @@ class OpenaiChat(AsyncGeneratorProvider):
)
if not access_token:
- cookies = cookies if cookies else get_cookies("chat.openai.com")
- response = await bot.session.get("https://chat.openai.com/api/auth/session", cookies=cookies)
- access_token = response.json()["accessToken"]
- bot.set_access_token(access_token)
-
- if len(messages) > 1:
- formatted = "\n".join(
- ["%s: %s" % ((message["role"]).capitalize(), message["content"]) for message in messages]
- )
- prompt = f"{formatted}\nAssistant:"
- else:
- prompt = messages.pop()["content"]
+ cookies = cookies if cookies else get_cookies("chat.openai.com")
+ cls._access_token = await get_access_token(bot.session, cookies)
+ bot.set_access_token(cls._access_token)
returned = None
- async for message in bot.ask(prompt):
+ async for message in bot.ask(format_prompt(messages)):
message = message["message"]
if returned:
if message.startswith(returned):
@@ -61,6 +56,9 @@ class OpenaiChat(AsyncGeneratorProvider):
else:
yield message
returned = message
+
+ await bot.delete_conversation(bot.conversation_id)
+
@classmethod
@property
@@ -73,3 +71,12 @@ class OpenaiChat(AsyncGeneratorProvider):
]
param = ", ".join([": ".join(p) for p in params])
return f"g4f.provider.{cls.__name__} supports: ({param})"
+
+
+async def get_access_token(session: AsyncClient, cookies: dict):
+ response = await session.get("https://chat.openai.com/api/auth/session", cookies=cookies)
+ response.raise_for_status()
+ try:
+ return response.json()["accessToken"]
+ except json.decoder.JSONDecodeError:
+ raise RuntimeError(f"Response: {response.text}") \ No newline at end of file