summaryrefslogtreecommitdiffstats
path: root/g4f/client/async_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/client/async_client.py')
-rw-r--r--g4f/client/async_client.py11
1 files changed, 5 insertions, 6 deletions
diff --git a/g4f/client/async_client.py b/g4f/client/async_client.py
index 8e1ee33c..07ad3357 100644
--- a/g4f/client/async_client.py
+++ b/g4f/client/async_client.py
@@ -11,10 +11,9 @@ from .types import AsyncIterResponse, ImageProvider
from .image_models import ImageModels
from .helper import filter_json, find_stop, filter_none, cast_iter_async
from .service import get_last_provider, get_model_and_provider
-from ..typing import Union, Iterator, Messages, AsyncIterator, ImageType
+from ..typing import Union, Messages, AsyncIterator, ImageType
from ..errors import NoImageResponseError
from ..image import ImageResponse as ImageProviderResponse
-from ..providers.base_provider import AsyncGeneratorProvider
try:
anext
@@ -88,7 +87,7 @@ def create_response(
api_key: str = None,
**kwargs
):
- has_asnyc = isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider)
+ has_asnyc = hasattr(provider, "create_async_generator")
if has_asnyc:
create = provider.create_async_generator
else:
@@ -157,7 +156,7 @@ class Chat():
def __init__(self, client: AsyncClient, provider: ProviderType = None):
self.completions = Completions(client, provider)
-async def iter_image_response(response: Iterator) -> Union[ImagesResponse, None]:
+async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]:
async for chunk in response:
if isinstance(chunk, ImageProviderResponse):
return ImagesResponse([Image(image) for image in chunk.get_list()])
@@ -182,7 +181,7 @@ class Images():
async def generate(self, prompt, model: str = "", **kwargs) -> ImagesResponse:
provider = self.models.get(model, self.provider)
- if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
+ if hasattr(provider, "create_async_generator"):
response = create_image(self.client, provider, prompt, **kwargs)
else:
response = await provider.create_async(prompt)
@@ -195,7 +194,7 @@ class Images():
async def create_variation(self, image: ImageType, model: str = None, **kwargs):
provider = self.models.get(model, self.provider)
result = None
- if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
+ if hasattr(provider, "create_async_generator"):
response = provider.create_async_generator(
"",
[{"role": "user", "content": "create a image like this"}],