diff options
Diffstat (limited to 'g4f/client')
-rw-r--r-- | g4f/client/async_client.py | 90 | ||||
-rw-r--r-- | g4f/client/service.py | 6 | ||||
-rw-r--r-- | g4f/client/stubs.py | 23 |
3 files changed, 96 insertions, 23 deletions
diff --git a/g4f/client/async_client.py b/g4f/client/async_client.py index 07ad3357..1508e566 100644 --- a/g4f/client/async_client.py +++ b/g4f/client/async_client.py @@ -3,6 +3,9 @@ from __future__ import annotations import time import random import string +import asyncio +import base64 +from aiohttp import ClientSession, BaseConnector from .types import Client as BaseClient from .types import ProviderType, FinishReason @@ -11,9 +14,11 @@ from .types import AsyncIterResponse, ImageProvider from .image_models import ImageModels from .helper import filter_json, find_stop, filter_none, cast_iter_async from .service import get_last_provider, get_model_and_provider +from ..Provider import ProviderUtils from ..typing import Union, Messages, AsyncIterator, ImageType -from ..errors import NoImageResponseError -from ..image import ImageResponse as ImageProviderResponse +from ..errors import NoImageResponseError, ProviderNotFoundError +from ..requests.aiohttp import get_connector +from ..image import ImageResponse as ImageProviderResponse, ImageDataResponse try: anext @@ -156,12 +161,28 @@ class Chat(): def __init__(self, client: AsyncClient, provider: ProviderType = None): self.completions = Completions(client, provider) -async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]: +async def iter_image_response( + response: AsyncIterator, + response_format: str = None, + connector: BaseConnector = None, + proxy: str = None +) -> Union[ImagesResponse, None]: async for chunk in response: if isinstance(chunk, ImageProviderResponse): - return ImagesResponse([Image(image) for image in chunk.get_list()]) + if response_format == "b64_json": + async with ClientSession( + connector=get_connector(connector, proxy) + ) as session: + async def fetch_image(image): + async with session.get(image) as response: + return base64.b64encode(await response.content.read()).decode() + images = await asyncio.gather(*[fetch_image(image) for image in chunk.get_list()]) + return ImagesResponse([Image(None, image, chunk.alt) for image in images], int(time.time())) + return ImagesResponse([Image(image, None, chunk.alt) for image in chunk.get_list()], int(time.time())) + elif isinstance(chunk, ImageDataResponse): + return ImagesResponse([Image(None, image, chunk.alt) for image in chunk.get_list()], int(time.time())) -def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator: +def create_image(provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator: prompt = f"create a image with: {prompt}" if provider.__name__ == "You": kwargs["chat_mode"] = "create" @@ -169,7 +190,6 @@ def create_image(client: AsyncClient, provider: ProviderType, prompt: str, model model, [{"role": "user", "content": prompt}], stream=True, - proxy=client.get_proxy(), **kwargs ) @@ -179,31 +199,71 @@ class Images(): self.provider: ImageProvider = provider self.models: ImageModels = ImageModels(client) - async def generate(self, prompt, model: str = "", **kwargs) -> ImagesResponse: - provider = self.models.get(model, self.provider) + def get_provider(self, model: str, provider: ProviderType = None): + if isinstance(provider, str): + if provider in ProviderUtils.convert: + provider = ProviderUtils.convert[provider] + else: + raise ProviderNotFoundError(f'Provider not found: {provider}') + else: + provider = self.models.get(model, self.provider) + return provider + + async def generate( + self, + prompt, + model: str = "", + provider: ProviderType = None, + response_format: str = None, + connector: BaseConnector = None, + proxy: str = None, + **kwargs + ) -> ImagesResponse: + provider = self.get_provider(model, provider) if hasattr(provider, "create_async_generator"): - response = create_image(self.client, provider, prompt, **kwargs) + response = create_image( + provider, + prompt, + **filter_none( + response_format=response_format, + connector=connector, + proxy=self.client.get_proxy() if proxy is None else proxy, + ), + **kwargs + ) else: response = await provider.create_async(prompt) return ImagesResponse([Image(image) for image in response.get_list()]) - image = await iter_image_response(response) + image = await iter_image_response(response, response_format, connector, proxy) if image is None: raise NoImageResponseError() return image - async def create_variation(self, image: ImageType, model: str = None, **kwargs): - provider = self.models.get(model, self.provider) + async def create_variation( + self, + image: ImageType, + model: str = None, + response_format: str = None, + connector: BaseConnector = None, + proxy: str = None, + **kwargs + ): + provider = self.get_provider(model, provider) result = None if hasattr(provider, "create_async_generator"): response = provider.create_async_generator( "", [{"role": "user", "content": "create a image like this"}], - True, + stream=True, image=image, - proxy=self.client.get_proxy(), + **filter_none( + response_format=response_format, + connector=connector, + proxy=self.client.get_proxy() if proxy is None else proxy, + ), **kwargs ) - result = iter_image_response(response) + result = iter_image_response(response, response_format, connector, proxy) if result is None: raise NoImageResponseError() return result diff --git a/g4f/client/service.py b/g4f/client/service.py index dd6bf4b6..5fdb150c 100644 --- a/g4f/client/service.py +++ b/g4f/client/service.py @@ -4,7 +4,7 @@ from typing import Union from .. import debug, version from ..errors import ProviderNotFoundError, ModelNotFoundError, ProviderNotWorkingError, StreamNotSupportedError -from ..models import Model, ModelUtils +from ..models import Model, ModelUtils, default from ..Provider import ProviderUtils from ..providers.types import BaseRetryProvider, ProviderType from ..providers.retry_provider import IterProvider @@ -60,7 +60,9 @@ def get_model_and_provider(model : Union[Model, str], model = ModelUtils.convert[model] if not provider: - if isinstance(model, str): + if not model: + model = default + elif isinstance(model, str): raise ModelNotFoundError(f'Model not found: {model}') provider = model.best_provider diff --git a/g4f/client/stubs.py b/g4f/client/stubs.py index ceb51b83..8cf2bcba 100644 --- a/g4f/client/stubs.py +++ b/g4f/client/stubs.py @@ -96,13 +96,24 @@ class ChatCompletionDeltaChoice(Model): } class Image(Model): - url: str + def __init__(self, url: str = None, b64_json: str = None, revised_prompt: str = None) -> None: + if url is not None: + self.url = url + if b64_json is not None: + self.b64_json = b64_json + if revised_prompt is not None: + self.revised_prompt = revised_prompt - def __init__(self, url: str) -> None: - self.url = url + def to_json(self): + return self.__dict__ class ImagesResponse(Model): - data: list[Image] - - def __init__(self, data: list) -> None: + def __init__(self, data: list[Image], created: int = 0) -> None: self.data = data + self.created = created + + def to_json(self): + return { + **self.__dict__, + "data": [image.to_json() for image in self.data] + }
\ No newline at end of file |