From 8e2723938a280c7b525bac1d847fe80a5c2022ef Mon Sep 17 00:00:00 2001 From: kqlio67 Date: Sun, 17 Nov 2024 15:33:18 +0200 Subject: Refactor Image Processing and Error Handling in g4f Client Module --- g4f/client/__init__.py | 73 +++++++++++++++++++++++--------------------------- 1 file changed, 34 insertions(+), 39 deletions(-) (limited to 'g4f/client/__init__.py') diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py index 5ffe9288..3adb18ef 100644 --- a/g4f/client/__init__.py +++ b/g4f/client/__init__.py @@ -247,7 +247,7 @@ class Images: """ Synchronous generate method that runs the async_generate method in an event loop. """ - return asyncio.run(self.async_generate(prompt, model, provider, response_format=response_format, proxy=proxy **kwargs)) + return asyncio.run(self.async_generate(prompt, model, provider, response_format=response_format, proxy=proxy, **kwargs)) async def async_generate(self, prompt: str, model: str = None, provider: ProviderType = None, response_format: str = "url", proxy: str = None, **kwargs) -> ImagesResponse: if provider is None: @@ -261,7 +261,7 @@ class Images: if isinstance(provider_handler, IterListProvider): if provider_handler.providers: - provider_handler = provider.providers[0] + provider_handler = provider_handler.providers[0] else: raise ValueError(f"IterListProvider for model {model} has no providers") @@ -287,44 +287,39 @@ class Images: raise NoImageResponseError(f"Unexpected response type: {type(response)}") async def _process_image_response(self, response: ImageResponse, response_format: str, proxy: str = None, model: str = None, provider: str = None) -> ImagesResponse: - async def process_image_item(session: aiohttp.ClientSession, image_data: str): - if image_data.startswith('http://') or image_data.startswith('https://'): - if response_format == "url": - return Image(url=image_data, revised_prompt=response.alt) - elif response_format == "b64_json": - # Fetch the image data and convert it to base64 - image_content = await self._fetch_image(session, image_data) - file_name = self._save_image(image_data_bytes) - b64_json = base64.b64encode(image_content).decode('utf-8') - return Image(b64_json=b64_json, url=file_name, revised_prompt=response.alt) - else: - # Assume image_data is base64 data or binary - if response_format == "url": - if image_data.startswith('data:image'): - # Remove the data URL scheme and get the base64 data - base64_data = image_data.split(',', 1)[-1] - else: - base64_data = image_data - # Decode the base64 data - image_data_bytes = base64.b64decode(base64_data) - # Convert bytes to an image + async def process_image_item(session: aiohttp.ClientSession, image_data: str): + image_data_bytes = None + if image_data.startswith("http://") or image_data.startswith("https://"): + if response_format == "url": + return Image(url=image_data, revised_prompt=response.alt) + elif response_format == "b64_json": + # Fetch the image data and convert it to base64 + image_data_bytes = await self._fetch_image(session, image_data) + b64_json = base64.b64encode(image_data_bytes).decode("utf-8") + return Image(b64_json=b64_json, url=image_data, revised_prompt=response.alt) + else: + # Assume image_data is base64 data or binary + if response_format == "url": + if image_data.startswith("data:image"): + # Remove the data URL scheme and get the base64 data + base64_data = image_data.split(",", 1)[-1] + else: + base64_data = image_data + # Decode the base64 data + image_data_bytes = base64.b64decode(base64_data) + if image_data_bytes: file_name = self._save_image(image_data_bytes) return Image(url=file_name, revised_prompt=response.alt) - elif response_format == "b64_json": - if isinstance(image_data, bytes): - file_name = self._save_image(image_data_bytes) - b64_json = base64.b64encode(image_data).decode('utf-8') - else: - b64_json = image_data # If already base64-encoded string - return Image(b64_json=b64_json, url=file_name, revised_prompt=response.alt) - - last_provider = get_last_provider(True) - async with aiohttp.ClientSession(cookies=response.get("cookies"), connector=get_connector(proxy=proxy)) as session: - return ImagesResponse( - await asyncio.gather(*[process_image_item(session, image_data) for image_data in response.get_list()]), - model=last_provider.get("model") if model is None else model, - provider=last_provider.get("name") if provider is None else provider - ) + else: + raise ValueError("Unable to process image data") + + last_provider = get_last_provider(True) + async with aiohttp.ClientSession(cookies=response.get("cookies"), connector=get_connector(proxy=proxy)) as session: + return ImagesResponse( + await asyncio.gather(*[process_image_item(session, image_data) for image_data in response.get_list()]), + model=last_provider.get("model") if model is None else model, + provider=last_provider.get("name") if provider is None else provider + ) async def _fetch_image(self, session: aiohttp.ClientSession, url: str) -> bytes: # Asynchronously fetch image data from the URL @@ -465,4 +460,4 @@ class AsyncImages(Images): async def create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse: return await self.async_create_variation( image, model, provider, response_format, **kwargs - ) \ No newline at end of file + ) -- cgit v1.2.3