diff options
author | H Lohaus <hlohaus@users.noreply.github.com> | 2024-04-06 23:07:40 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-06 23:07:40 +0200 |
commit | 8229b62ce331e070772d809ebc554b5cb187612b (patch) | |
tree | 1f78833f11b63f1fd19cb3066362e4c06c749684 | |
parent | Update image_models.py (diff) | |
download | gpt4free-8229b62ce331e070772d809ebc554b5cb187612b.tar gpt4free-8229b62ce331e070772d809ebc554b5cb187612b.tar.gz gpt4free-8229b62ce331e070772d809ebc554b5cb187612b.tar.bz2 gpt4free-8229b62ce331e070772d809ebc554b5cb187612b.tar.lz gpt4free-8229b62ce331e070772d809ebc554b5cb187612b.tar.xz gpt4free-8229b62ce331e070772d809ebc554b5cb187612b.tar.zst gpt4free-8229b62ce331e070772d809ebc554b5cb187612b.zip |
Diffstat (limited to '')
-rw-r--r-- | g4f/client/async.py | 67 |
1 files changed, 44 insertions, 23 deletions
diff --git a/g4f/client/async.py b/g4f/client/async.py index 1ac738fd..76e410fc 100644 --- a/g4f/client/async.py +++ b/g4f/client/async.py @@ -14,9 +14,9 @@ from ..image import ImageResponse as ImageProviderResponse from ..errors import NoImageResponseError, RateLimitError, MissingAuthError from .. import get_model_and_provider, get_last_provider from .helper import read_json, find_stop, filter_none - +รค async def iter_response( - response: AsyncIerator[str], + response: AsyncIterator[str], stream: bool, response_format: dict = None, max_tokens: int = None, @@ -67,6 +67,39 @@ class Client(BaseClient): self.chat: Chat = Chat(self, provider) self.images: Images = Images(self, image_provider) +async def cast_iter_async(iter): + for chunk in iter: + yield chunk + +def create_response( + messages: Messages, + model: str, + provider: ProviderType = None, + stream: bool = False, + response_format: dict = None, + max_tokens: int = None, + stop: Union[list[str], str] = None, + api_key: str = None, + **kwargs +): + if hasattr(provider, "create_async_generator): + create = provider.create_async_generator + else: + create = provider.create_completion + response = create( + model, messages, stream, + **filter_none( + proxy=self.client.get_proxy(), + max_tokens=max_tokens, + stop=stop, + api_key=self.client.api_key if api_key is None else api_key + ), + **kwargs + ) + if not hasattr(provider, "create_async_generator") + response = cast_iter_async(response) + return response + class Completions(): def __init__(self, client: Client, provider: ProviderType = None): self.client: Client = client @@ -79,9 +112,6 @@ class Completions(): provider: ProviderType = None, stream: bool = False, response_format: dict = None, - max_tokens: int = None, - stop: Union[list[str], str] = None, - api_key: str = None, ignored : list[str] = None, ignore_working: bool = False, ignore_stream: bool = False, @@ -97,16 +127,7 @@ class Completions(): **kwargs ) stop = [stop] if isinstance(stop, str) else stop - response = provider.create_async( - model, messages, stream, - **filter_none( - proxy=self.client.get_proxy(), - max_tokens=max_tokens, - stop=stop, - api_key=self.client.api_key if api_key is None else api_key - ), - **kwargs - ) + response = create_response(messages, model, provider, stream, **kwargs) response = iter_response(response, stream, response_format, max_tokens, stop) response = iter_append_model_and_provider(response) return response if stream else anext(response) @@ -117,14 +138,14 @@ class Chat(): def __init__(self, client: Client, provider: ProviderType = None): self.completions = Completions(client, provider) -def iter_image_response(response: Iterator) -> Union[ImagesResponse, None]: - for chunk in list(response): +async def iter_image_response(response: Iterator) -> Union[ImagesResponse, None]: + async for chunk in list(response): if isinstance(chunk, ImageProviderResponse): return ImagesResponse([Image(image) for image in chunk.get_list()]) -def create_image(client: Client, provider: ProviderType, prompt: str, model: str = "", **kwargs) -> Iterator: +def create_image(client: Client, provider: ProviderType, prompt: str, model: str = "", **kwargs) -> AsyncIterator: prompt = f"create a image with: {prompt}" - return provider.create_completion( + return provider.create_async_generator( model, [{"role": "user", "content": prompt}], True, @@ -138,7 +159,7 @@ class Images(): self.provider: ImageProvider = provider self.models: ImageModels = ImageModels(client) - def generate(self, prompt, model: str = None, **kwargs) -> ImagesResponse: + async def generate(self, prompt, model: str = None, **kwargs) -> ImagesResponse: provider = self.models.get(model, self.provider) if isinstance(provider, type) and issubclass(provider, BaseProvider): response = create_image(self.client, provider, prompt, **kwargs) @@ -156,11 +177,11 @@ class Images(): raise NoImageResponseError() return image - def create_variation(self, image: ImageType, model: str = None, **kwargs): + async def create_variation(self, image: ImageType, model: str = None, **kwargs): provider = self.models.get(model, self.provider) result = None if isinstance(provider, type) and issubclass(provider, BaseProvider): - response = provider.create_completion( + response = provider.create_async_generator( "", [{"role": "user", "content": "create a image like this"}], True, @@ -168,7 +189,7 @@ class Images(): proxy=self.client.get_proxy(), **kwargs ) - for chunk in response: + async for chunk in response: if isinstance(chunk, ImageProviderResponse): result = ([chunk.images] if isinstance(chunk.images, str) else chunk.images) result = ImagesResponse([Image(image)for image in result]) |