From cfa45e701645335aa6fe27e11aa208ac208c01ec Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Fri, 8 Mar 2024 11:01:38 +0100 Subject: Expire cache, Fix multiple websocket conversations in OpenaiChat Map system messages to user messages in GeminiPro --- g4f/Provider/GeminiPro.py | 18 ++++++++---------- g4f/Provider/needs_auth/OpenaiChat.py | 17 +++++++++++------ 2 files changed, 19 insertions(+), 16 deletions(-) (limited to 'g4f/Provider') diff --git a/g4f/Provider/GeminiPro.py b/g4f/Provider/GeminiPro.py index 1c5487b1..a22304d5 100644 --- a/g4f/Provider/GeminiPro.py +++ b/g4f/Provider/GeminiPro.py @@ -26,38 +26,35 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): stream: bool = False, proxy: str = None, api_key: str = None, - api_base: str = None, - use_auth_header: bool = True, + api_base: str = "https://generativelanguage.googleapis.com/v1beta", + use_auth_header: bool = False, image: ImageType = None, connector: BaseConnector = None, **kwargs ) -> AsyncResult: - model = "gemini-pro-vision" if not model and image else model + model = "gemini-pro-vision" if model is None and image is not None else model model = cls.get_model(model) if not api_key: raise MissingAuthError('Missing "api_key"') headers = params = None - if api_base and use_auth_header: + if use_auth_header: headers = {"Authorization": f"Bearer {api_key}"} else: params = {"key": api_key} - if not api_base: - api_base = f"https://generativelanguage.googleapis.com/v1beta" - method = "streamGenerateContent" if stream else "generateContent" url = f"{api_base.rstrip('/')}/models/{model}:{method}" async with ClientSession(headers=headers, connector=get_connector(connector, proxy)) as session: contents = [ { - "role": "model" if message["role"] == "assistant" else message["role"], + "role": "model" if message["role"] == "assistant" else "user", "parts": [{"text": message["content"]}] } for message in messages ] - if image: + if image is not None: image = to_bytes(image) contents[-1]["parts"].append({ "inline_data": { @@ -87,7 +84,8 @@ class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin): lines = [b"{\n"] elif chunk == b",\r\n" or chunk == b"]": try: - data = json.loads(b"".join(lines)) + data = b"".join(lines) + data = json.loads(data) yield data["candidates"][0]["content"]["parts"][0]["text"] except: data = data.decode() if isinstance(data, bytes) else data diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py index e507404b..1a6fd947 100644 --- a/g4f/Provider/needs_auth/OpenaiChat.py +++ b/g4f/Provider/needs_auth/OpenaiChat.py @@ -5,6 +5,7 @@ import uuid import json import os import base64 +import time from aiohttp import ClientWebSocketResponse try: @@ -47,7 +48,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): _api_key: str = None _headers: dict = None _cookies: Cookies = None - _last_message: int = 0 + _expires: int = None @classmethod async def create( @@ -348,7 +349,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): timeout=timeout ) as session: # Read api_key and cookies from cache / browser config - if cls._headers is None: + if cls._headers is None or time.time() > cls._expires: if api_key is None: # Read api_key from cookies cookies = get_cookies("chat.openai.com", False) if cookies is None else cookies @@ -437,17 +438,20 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): await cls.delete_conversation(session, cls._headers, fields.conversation_id) @staticmethod - async def iter_messages_ws(ws: ClientWebSocketResponse) -> AsyncIterator: + async def iter_messages_ws(ws: ClientWebSocketResponse, conversation_id: str) -> AsyncIterator: while True: - yield base64.b64decode((await ws.receive_json())["body"]) + message = await ws.receive_json() + if message["conversation_id"] == conversation_id: + yield base64.b64decode(message["body"]) @classmethod async def iter_messages_chunk(cls, messages: AsyncIterator, session: StreamSession, fields: ResponseFields) -> AsyncIterator: last_message: int = 0 async for message in messages: if message.startswith(b'{"wss_url":'): - async with session.ws_connect(json.loads(message)["wss_url"]) as ws: - async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws), session, fields): + message = json.loads(message) + async with session.ws_connect(message["wss_url"]) as ws: + async for chunk in cls.iter_messages_chunk(cls.iter_messages_ws(ws, message["conversation_id"]), session, fields): yield chunk break async for chunk in cls.iter_messages_line(session, message, fields): @@ -589,6 +593,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin): @classmethod def _set_api_key(cls, api_key: str): cls._api_key = api_key + cls._expires = int(time.time()) + 60 * 60 * 4 cls._headers["Authorization"] = f"Bearer {api_key}" @classmethod -- cgit v1.2.3