diff options
Diffstat (limited to 'g4f')
-rw-r--r-- | g4f/client/__init__.py | 71 |
1 files changed, 33 insertions, 38 deletions
diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py index 6f679e4a..1f3cdab1 100644 --- a/g4f/client/__init__.py +++ b/g4f/client/__init__.py @@ -261,7 +261,7 @@ class Images: if isinstance(provider_handler, IterListProvider): if provider_handler.providers: - provider_handler = provider.providers[0] + provider_handler = provider_handler.providers[0] else: raise ValueError(f"IterListProvider for model {model} has no providers") @@ -287,44 +287,39 @@ class Images: raise NoImageResponseError(f"Unexpected response type: {type(response)}") async def _process_image_response(self, response: ImageResponse, response_format: str, proxy: str = None, model: str = None, provider: str = None) -> ImagesResponse: - async def process_image_item(session: aiohttp.ClientSession, image_data: str): - if image_data.startswith('http://') or image_data.startswith('https://'): - if response_format == "url": - return Image(url=image_data, revised_prompt=response.alt) - elif response_format == "b64_json": - # Fetch the image data and convert it to base64 - image_content = await self._fetch_image(session, image_data) - file_name = self._save_image(image_data_bytes) - b64_json = base64.b64encode(image_content).decode('utf-8') - return Image(b64_json=b64_json, url=file_name, revised_prompt=response.alt) - else: - # Assume image_data is base64 data or binary - if response_format == "url": - if image_data.startswith('data:image'): - # Remove the data URL scheme and get the base64 data - base64_data = image_data.split(',', 1)[-1] - else: - base64_data = image_data - # Decode the base64 data - image_data_bytes = base64.b64decode(base64_data) - # Convert bytes to an image + async def process_image_item(session: aiohttp.ClientSession, image_data: str): + image_data_bytes = None + if image_data.startswith("http://") or image_data.startswith("https://"): + if response_format == "url": + return Image(url=image_data, revised_prompt=response.alt) + elif response_format == "b64_json": + # Fetch the image data and convert it to base64 + image_data_bytes = await self._fetch_image(session, image_data) + b64_json = base64.b64encode(image_data_bytes).decode("utf-8") + return Image(b64_json=b64_json, url=image_data, revised_prompt=response.alt) + else: + # Assume image_data is base64 data or binary + if response_format == "url": + if image_data.startswith("data:image"): + # Remove the data URL scheme and get the base64 data + base64_data = image_data.split(",", 1)[-1] + else: + base64_data = image_data + # Decode the base64 data + image_data_bytes = base64.b64decode(base64_data) + if image_data_bytes: file_name = self._save_image(image_data_bytes) return Image(url=file_name, revised_prompt=response.alt) - elif response_format == "b64_json": - if isinstance(image_data, bytes): - file_name = self._save_image(image_data_bytes) - b64_json = base64.b64encode(image_data).decode('utf-8') - else: - b64_json = image_data # If already base64-encoded string - return Image(b64_json=b64_json, url=file_name, revised_prompt=response.alt) - - last_provider = get_last_provider(True) - async with aiohttp.ClientSession(cookies=response.get("cookies"), connector=get_connector(proxy=proxy)) as session: - return ImagesResponse( - await asyncio.gather(*[process_image_item(session, image_data) for image_data in response.get_list()]), - model=last_provider.get("model") if model is None else model, - provider=last_provider.get("name") if provider is None else provider - ) + else: + raise ValueError("Unable to process image data") + + last_provider = get_last_provider(True) + async with aiohttp.ClientSession(cookies=response.get("cookies"), connector=get_connector(proxy=proxy)) as session: + return ImagesResponse( + await asyncio.gather(*[process_image_item(session, image_data) for image_data in response.get_list()]), + model=last_provider.get("model") if model is None else model, + provider=last_provider.get("name") if provider is None else provider + ) async def _fetch_image(self, session: aiohttp.ClientSession, url: str) -> bytes: # Asynchronously fetch image data from the URL @@ -464,4 +459,4 @@ class AsyncImages(Images): async def create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse: return await self.async_create_variation( image, model, provider, response_format, **kwargs - )
\ No newline at end of file + ) |