summaryrefslogtreecommitdiffstats
path: root/g4f
diff options
context:
space:
mode:
Diffstat (limited to 'g4f')
-rw-r--r--g4f/client/__init__.py71
1 files changed, 33 insertions, 38 deletions
diff --git a/g4f/client/__init__.py b/g4f/client/__init__.py
index 6f679e4a..1f3cdab1 100644
--- a/g4f/client/__init__.py
+++ b/g4f/client/__init__.py
@@ -261,7 +261,7 @@ class Images:
if isinstance(provider_handler, IterListProvider):
if provider_handler.providers:
- provider_handler = provider.providers[0]
+ provider_handler = provider_handler.providers[0]
else:
raise ValueError(f"IterListProvider for model {model} has no providers")
@@ -287,44 +287,39 @@ class Images:
raise NoImageResponseError(f"Unexpected response type: {type(response)}")
async def _process_image_response(self, response: ImageResponse, response_format: str, proxy: str = None, model: str = None, provider: str = None) -> ImagesResponse:
- async def process_image_item(session: aiohttp.ClientSession, image_data: str):
- if image_data.startswith('http://') or image_data.startswith('https://'):
- if response_format == "url":
- return Image(url=image_data, revised_prompt=response.alt)
- elif response_format == "b64_json":
- # Fetch the image data and convert it to base64
- image_content = await self._fetch_image(session, image_data)
- file_name = self._save_image(image_data_bytes)
- b64_json = base64.b64encode(image_content).decode('utf-8')
- return Image(b64_json=b64_json, url=file_name, revised_prompt=response.alt)
- else:
- # Assume image_data is base64 data or binary
- if response_format == "url":
- if image_data.startswith('data:image'):
- # Remove the data URL scheme and get the base64 data
- base64_data = image_data.split(',', 1)[-1]
- else:
- base64_data = image_data
- # Decode the base64 data
- image_data_bytes = base64.b64decode(base64_data)
- # Convert bytes to an image
+ async def process_image_item(session: aiohttp.ClientSession, image_data: str):
+ image_data_bytes = None
+ if image_data.startswith("http://") or image_data.startswith("https://"):
+ if response_format == "url":
+ return Image(url=image_data, revised_prompt=response.alt)
+ elif response_format == "b64_json":
+ # Fetch the image data and convert it to base64
+ image_data_bytes = await self._fetch_image(session, image_data)
+ b64_json = base64.b64encode(image_data_bytes).decode("utf-8")
+ return Image(b64_json=b64_json, url=image_data, revised_prompt=response.alt)
+ else:
+ # Assume image_data is base64 data or binary
+ if response_format == "url":
+ if image_data.startswith("data:image"):
+ # Remove the data URL scheme and get the base64 data
+ base64_data = image_data.split(",", 1)[-1]
+ else:
+ base64_data = image_data
+ # Decode the base64 data
+ image_data_bytes = base64.b64decode(base64_data)
+ if image_data_bytes:
file_name = self._save_image(image_data_bytes)
return Image(url=file_name, revised_prompt=response.alt)
- elif response_format == "b64_json":
- if isinstance(image_data, bytes):
- file_name = self._save_image(image_data_bytes)
- b64_json = base64.b64encode(image_data).decode('utf-8')
- else:
- b64_json = image_data # If already base64-encoded string
- return Image(b64_json=b64_json, url=file_name, revised_prompt=response.alt)
-
- last_provider = get_last_provider(True)
- async with aiohttp.ClientSession(cookies=response.get("cookies"), connector=get_connector(proxy=proxy)) as session:
- return ImagesResponse(
- await asyncio.gather(*[process_image_item(session, image_data) for image_data in response.get_list()]),
- model=last_provider.get("model") if model is None else model,
- provider=last_provider.get("name") if provider is None else provider
- )
+ else:
+ raise ValueError("Unable to process image data")
+
+ last_provider = get_last_provider(True)
+ async with aiohttp.ClientSession(cookies=response.get("cookies"), connector=get_connector(proxy=proxy)) as session:
+ return ImagesResponse(
+ await asyncio.gather(*[process_image_item(session, image_data) for image_data in response.get_list()]),
+ model=last_provider.get("model") if model is None else model,
+ provider=last_provider.get("name") if provider is None else provider
+ )
async def _fetch_image(self, session: aiohttp.ClientSession, url: str) -> bytes:
# Asynchronously fetch image data from the URL
@@ -464,4 +459,4 @@ class AsyncImages(Images):
async def create_variation(self, image: Union[str, bytes], model: str = None, provider: ProviderType = None, response_format: str = "url", **kwargs) -> ImagesResponse:
return await self.async_create_variation(
image, model, provider, response_format, **kwargs
- ) \ No newline at end of file
+ )