From c2e3107cb8bfbdeba78b70b3da3b64a82345fbab Mon Sep 17 00:00:00 2001 From: kqlio67 Date: Tue, 22 Oct 2024 12:16:27 +0300 Subject: feat(g4f/gui/server/api.py): improve image handling and response streaming --- g4f/gui/server/api.py | 124 +++++++++++++++++++++++++------------------------- 1 file changed, 63 insertions(+), 61 deletions(-) (limited to 'g4f/gui/server') diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py index 64b84767..57f3eaa1 100644 --- a/g4f/gui/server/api.py +++ b/g4f/gui/server/api.py @@ -23,8 +23,8 @@ from g4f.providers.conversation import BaseConversation conversations: dict[dict[str, BaseConversation]] = {} images_dir = "./generated_images" -class Api(): +class Api: @staticmethod def get_models() -> list[str]: """ @@ -42,9 +42,11 @@ class Api(): if provider in __map__: provider: ProviderType = __map__[provider] if issubclass(provider, ProviderModelMixin): - return [{"model": model, "default": model == provider.default_model} for model in provider.get_models()] - else: - return [] + return [ + {"model": model, "default": model == provider.default_model} + for model in provider.get_models() + ] + return [] @staticmethod def get_image_models() -> list[dict]: @@ -66,7 +68,7 @@ class Api(): "image_model": model, "vision_model": parent.default_vision_model if hasattr(parent, "default_vision_model") else None }) - index.append(parent.__name__) + index.append(parent.__name__) elif hasattr(provider, "default_vision_model") and provider.__name__ not in index: image_models.append({ "provider": provider.__name__, @@ -84,15 +86,13 @@ class Api(): Return a list of all working providers. """ return { - provider.__name__: (provider.label - if hasattr(provider, "label") - else provider.__name__) + - (" (WebDriver)" - if "webdriver" in provider.get_parameters() - else "") + - (" (Auth)" - if provider.needs_auth - else "") + provider.__name__: ( + provider.label if hasattr(provider, "label") else provider.__name__ + ) + ( + " (WebDriver)" if "webdriver" in provider.get_parameters() else "" + ) + ( + " (Auth)" if provider.needs_auth else "" + ) for provider in __providers__ if provider.working } @@ -126,7 +126,7 @@ class Api(): Returns: dict: Arguments prepared for chat completion. - """ + """ model = json_data.get('model') or models.default provider = json_data.get('provider') messages = json_data['messages'] @@ -155,61 +155,62 @@ class Api(): } def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str) -> Iterator: - """ - Creates and returns a streaming response for the conversation. - - Args: - kwargs (dict): Arguments for creating the chat completion. - - Yields: - str: JSON formatted response chunks for the stream. - - Raises: - Exception: If an error occurs during the streaming process. - """ try: + result = ChatCompletion.create(**kwargs) first = True - for chunk in ChatCompletion.create(**kwargs): + if isinstance(result, ImageResponse): + # Якщо результат є ImageResponse, обробляємо його як одиночний елемент if first: first = False yield self._format_json("provider", get_last_provider(True)) - if isinstance(chunk, BaseConversation): - if provider not in conversations: - conversations[provider] = {} - conversations[provider][conversation_id] = chunk - yield self._format_json("conversation", conversation_id) - elif isinstance(chunk, Exception): - logging.exception(chunk) - yield self._format_json("message", get_error_message(chunk)) - elif isinstance(chunk, ImagePreview): - yield self._format_json("preview", chunk.to_string()) - elif isinstance(chunk, ImageResponse): - async def copy_images(images: list[str], cookies: Optional[Cookies] = None): - async with ClientSession( - connector=get_connector(None, os.environ.get("G4F_PROXY")), - cookies=cookies - ) as session: - async def copy_image(image): - async with session.get(image) as response: - target = os.path.join(images_dir, f"{int(time.time())}_{str(uuid.uuid4())}") - with open(target, "wb") as f: - async for chunk in response.content.iter_any(): - f.write(chunk) - with open(target, "rb") as f: - extension = is_accepted_format(f.read(12)).split("/")[-1] - extension = "jpg" if extension == "jpeg" else extension - new_target = f"{target}.{extension}" - os.rename(target, new_target) - return f"/images/{os.path.basename(new_target)}" - return await asyncio.gather(*[copy_image(image) for image in images]) - images = asyncio.run(copy_images(chunk.get_list(), chunk.options.get("cookies"))) - yield self._format_json("content", str(ImageResponse(images, chunk.alt))) - elif not isinstance(chunk, FinishReason): - yield self._format_json("content", str(chunk)) + yield self._format_json("content", str(result)) + else: + # Якщо результат є ітерабельним, обробляємо його як раніше + for chunk in result: + if first: + first = False + yield self._format_json("provider", get_last_provider(True)) + if isinstance(chunk, BaseConversation): + if provider not in conversations: + conversations[provider] = {} + conversations[provider][conversation_id] = chunk + yield self._format_json("conversation", conversation_id) + elif isinstance(chunk, Exception): + logging.exception(chunk) + yield self._format_json("message", get_error_message(chunk)) + elif isinstance(chunk, ImagePreview): + yield self._format_json("preview", chunk.to_string()) + elif isinstance(chunk, ImageResponse): + # Обробка ImageResponse + images = asyncio.run(self._copy_images(chunk.get_list(), chunk.options.get("cookies"))) + yield self._format_json("content", str(ImageResponse(images, chunk.alt))) + elif not isinstance(chunk, FinishReason): + yield self._format_json("content", str(chunk)) except Exception as e: logging.exception(e) yield self._format_json('error', get_error_message(e)) + # Додайте цей метод до класу Api + async def _copy_images(self, images: list[str], cookies: Optional[Cookies] = None): + async with ClientSession( + connector=get_connector(None, os.environ.get("G4F_PROXY")), + cookies=cookies + ) as session: + async def copy_image(image): + async with session.get(image) as response: + target = os.path.join(images_dir, f"{int(time.time())}_{str(uuid.uuid4())}") + with open(target, "wb") as f: + async for chunk in response.content.iter_any(): + f.write(chunk) + with open(target, "rb") as f: + extension = is_accepted_format(f.read(12)).split("/")[-1] + extension = "jpg" if extension == "jpeg" else extension + new_target = f"{target}.{extension}" + os.rename(target, new_target) + return f"/images/{os.path.basename(new_target)}" + + return await asyncio.gather(*[copy_image(image) for image in images]) + def _format_json(self, response_type: str, content): """ Formats and returns a JSON response. @@ -226,6 +227,7 @@ class Api(): response_type: content } + def get_error_message(exception: Exception) -> str: """ Generates a formatted error message from an exception. -- cgit v1.2.3