diff options
author | Heiner Lohaus <hlohaus@users.noreply.github.com> | 2024-03-16 18:07:53 +0100 |
---|---|---|
committer | Heiner Lohaus <hlohaus@users.noreply.github.com> | 2024-03-16 18:07:53 +0100 |
commit | 4778356064a005f0dec78a8fef40a26289217d7c (patch) | |
tree | d1835ab078e6a028c94b212d8f7530f1a78cc66f /g4f/gui/server/api.py | |
parent | Add copilot conversation mode (diff) | |
download | gpt4free-4778356064a005f0dec78a8fef40a26289217d7c.tar gpt4free-4778356064a005f0dec78a8fef40a26289217d7c.tar.gz gpt4free-4778356064a005f0dec78a8fef40a26289217d7c.tar.bz2 gpt4free-4778356064a005f0dec78a8fef40a26289217d7c.tar.lz gpt4free-4778356064a005f0dec78a8fef40a26289217d7c.tar.xz gpt4free-4778356064a005f0dec78a8fef40a26289217d7c.tar.zst gpt4free-4778356064a005f0dec78a8fef40a26289217d7c.zip |
Diffstat (limited to 'g4f/gui/server/api.py')
-rw-r--r-- | g4f/gui/server/api.py | 27 |
1 files changed, 20 insertions, 7 deletions
diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py index df7b487d..966319e4 100644 --- a/g4f/gui/server/api.py +++ b/g4f/gui/server/api.py @@ -13,8 +13,12 @@ from g4f.errors import VersionNotFoundError from g4f.Provider import ProviderType, __providers__, __map__ from g4f.providers.base_provider import ProviderModelMixin from g4f.Provider.bing.create_images import patch_provider +from g4f.Provider.Bing import Conversation + +conversations: dict[str, Conversation] = {} + +class Api(): -class Api(): def get_models(self) -> list[str]: """ Return a list of all models. @@ -73,7 +77,8 @@ class Api(): def get_conversation(self, options: dict, **kwargs) -> Iterator: window = webview.active_window() for message in self._create_response_stream( - self._prepare_conversation_kwargs(options, kwargs) + self._prepare_conversation_kwargs(options, kwargs), + options.get("conversation_id") ): window.evaluate_js(f"this.add_message_chunk({json.dumps(message)})") @@ -101,6 +106,10 @@ class Api(): from .internet import get_search_message messages[-1]["content"] = get_search_message(messages[-1]["content"]) + conversation_id = json_data.get("conversation_id") + if conversation_id and conversation_id in conversations: + kwargs["conversation"] = conversations[conversation_id] + model = json_data.get('model') model = model if model else models.default patch = patch_provider if json_data.get('patch_provider') else None @@ -112,10 +121,11 @@ class Api(): "stream": True, "ignore_stream": True, "patch_provider": patch, + "return_conversation": True, **kwargs } - def _create_response_stream(self, kwargs) -> Iterator: + def _create_response_stream(self, kwargs, conversation_id: str) -> Iterator: """ Creates and returns a streaming response for the conversation. @@ -133,12 +143,15 @@ class Api(): for chunk in ChatCompletion.create(**kwargs): if first: first = False - yield self._format_json('provider', get_last_provider(True)) - if isinstance(chunk, Exception): + yield self._format_json("provider", get_last_provider(True)) + if isinstance(chunk, Conversation): + conversations[conversation_id] = chunk + yield self._format_json("conversation", conversation_id) + elif isinstance(chunk, Exception): logging.exception(chunk) - yield self._format_json('message', get_error_message(chunk)) + yield self._format_json("message", get_error_message(chunk)) else: - yield self._format_json('content', chunk) + yield self._format_json("content", chunk) except Exception as e: logging.exception(e) yield self._format_json('error', get_error_message(e)) |