summaryrefslogtreecommitdiffstats
path: root/g4f
diff options
context:
space:
mode:
authorH Lohaus <hlohaus@users.noreply.github.com>2024-03-28 17:17:59 +0100
committerGitHub <noreply@github.com>2024-03-28 17:17:59 +0100
commit64e07b7fbf810176d66506786a946a3122ea7fc4 (patch)
tree1cf10ab4f117583fdb4a98712c18052e5a42cdf2 /g4f
parentMerge pull request #1758 from Zero6992/main (diff)
parentFix history support for OpenaiChat (diff)
downloadgpt4free-0.2.7.2.tar
gpt4free-0.2.7.2.tar.gz
gpt4free-0.2.7.2.tar.bz2
gpt4free-0.2.7.2.tar.lz
gpt4free-0.2.7.2.tar.xz
gpt4free-0.2.7.2.tar.zst
gpt4free-0.2.7.2.zip
Diffstat (limited to '')
-rw-r--r--g4f/Provider/needs_auth/OpenaiChat.py11
-rw-r--r--g4f/gui/server/api.py15
-rw-r--r--g4f/gui/server/backend.py2
3 files changed, 15 insertions, 13 deletions
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index 72f9f224..396d73dd 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -389,19 +389,17 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
print(f"{e.__class__.__name__}: {e}")
model = cls.get_model(model).replace("gpt-3.5-turbo", "text-davinci-002-render-sha")
- fields = Conversation() if conversation is None else copy(conversation)
+ fields = Conversation(conversation_id, parent_id) if conversation is None else copy(conversation)
fields.finish_reason = None
while fields.finish_reason is None:
- conversation_id = conversation_id if fields.conversation_id is None else fields.conversation_id
- parent_id = parent_id if fields.message_id is None else fields.message_id
websocket_request_id = str(uuid.uuid4())
data = {
"action": action,
"conversation_mode": {"kind": "primary_assistant"},
"force_paragen": False,
"force_rate_limit": False,
- "conversation_id": conversation_id,
- "parent_message_id": parent_id,
+ "conversation_id": fields.conversation_id,
+ "parent_message_id": fields.message_id,
"model": model,
"history_and_training_disabled": history_disabled and not auto_continue and not return_conversation,
"websocket_request_id": websocket_request_id
@@ -425,6 +423,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
await raise_for_status(response)
async for chunk in cls.iter_messages_chunk(response.iter_lines(), session, fields):
if return_conversation:
+ history_disabled = False
return_conversation = False
yield fields
yield chunk
@@ -432,7 +431,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
break
action = "continue"
await asyncio.sleep(5)
- if history_disabled and auto_continue and not return_conversation:
+ if history_disabled and auto_continue:
await cls.delete_conversation(session, cls._headers, fields.conversation_id)
@staticmethod
diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py
index da934d57..b4e2b3d4 100644
--- a/g4f/gui/server/api.py
+++ b/g4f/gui/server/api.py
@@ -41,7 +41,7 @@ from g4f.providers.base_provider import ProviderModelMixin
from g4f.Provider.bing.create_images import patch_provider
from g4f.providers.conversation import BaseConversation
-conversations: dict[str, BaseConversation] = {}
+conversations: dict[dict[str, BaseConversation]] = {}
class Api():
@@ -106,7 +106,8 @@ class Api():
kwargs["image"] = open(self.image, "rb")
for message in self._create_response_stream(
self._prepare_conversation_kwargs(options, kwargs),
- options.get("conversation_id")
+ options.get("conversation_id"),
+ options.get('provider')
):
if not window.evaluate_js(f"if (!this.abort) this.add_message_chunk({json.dumps(message)}); !this.abort && !this.error;"):
break
@@ -193,8 +194,8 @@ class Api():
messages[-1]["content"] = get_search_message(messages[-1]["content"])
conversation_id = json_data.get("conversation_id")
- if conversation_id and conversation_id in conversations:
- kwargs["conversation"] = conversations[conversation_id]
+ if conversation_id and provider in conversations and conversation_id in conversations[provider]:
+ kwargs["conversation"] = conversations[provider][conversation_id]
model = json_data.get('model')
model = model if model else models.default
@@ -211,7 +212,7 @@ class Api():
**kwargs
}
- def _create_response_stream(self, kwargs, conversation_id: str) -> Iterator:
+ def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str) -> Iterator:
"""
Creates and returns a streaming response for the conversation.
@@ -231,7 +232,9 @@ class Api():
first = False
yield self._format_json("provider", get_last_provider(True))
if isinstance(chunk, BaseConversation):
- conversations[conversation_id] = chunk
+ if provider not in conversations:
+ conversations[provider] = {}
+ conversations[provider][conversation_id] = chunk
yield self._format_json("conversation", conversation_id)
elif isinstance(chunk, Exception):
logging.exception(chunk)
diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py
index fb8404d4..d30b97d9 100644
--- a/g4f/gui/server/backend.py
+++ b/g4f/gui/server/backend.py
@@ -85,7 +85,7 @@ class Backend_Api(Api):
kwargs = self._prepare_conversation_kwargs(json_data, kwargs)
return self.app.response_class(
- self._create_response_stream(kwargs, json_data.get("conversation_id")),
+ self._create_response_stream(kwargs, json_data.get("conversation_id"), json_data.get("provider")),
mimetype='text/event-stream'
)