diff options
Diffstat (limited to 'g4f/image.py')
-rw-r--r-- | g4f/image.py | 48 |
1 files changed, 44 insertions, 4 deletions
diff --git a/g4f/image.py b/g4f/image.py index 8a3d7a74..114dcc13 100644 --- a/g4f/image.py +++ b/g4f/image.py @@ -1,9 +1,13 @@ from __future__ import annotations +import os import re +import time +import uuid from io import BytesIO import base64 -from .typing import ImageType, Union, Image +import asyncio +from aiohttp import ClientSession try: from PIL.Image import open as open_image, new as new_image @@ -12,7 +16,10 @@ try: except ImportError: has_requirements = False +from .typing import ImageType, Union, Image, Optional, Cookies from .errors import MissingRequirementsError +from .providers.response import ResponseType +from .requests.aiohttp import get_connector ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'} @@ -23,10 +30,12 @@ EXTENSIONS_MAP: dict[str, str] = { "image/webp": "webp", } +# Define the directory for generated images +images_dir = "./generated_images" + def fix_url(url:str) -> str: """ replace ' ' by '+' (to be markdown compliant)""" return url.replace(" ","+") - def to_image(image: ImageType, is_svg: bool = False) -> Image: """ @@ -223,7 +232,6 @@ def format_images_markdown(images: Union[str, list], alt: str, preview: Union[st preview = [preview.replace('{image}', image) if preview else image for image in images] result = "\n".join( f"[![#{idx+1} {alt}]({fix_url(preview[idx])})]({fix_url(image)})" - #f'[<img src="{preview[idx]}" width="200" alt="#{idx+1} {alt}">]({image})' for idx, image in enumerate(images) ) start_flag = "<!-- generated images start -->\n" @@ -260,7 +268,39 @@ def to_data_uri(image: ImageType) -> str: return f"data:{is_accepted_format(data)};base64,{data_base64}" return image -class ImageResponse: +# Function to ensure the images directory exists +def ensure_images_dir(): + if not os.path.exists(images_dir): + os.makedirs(images_dir) + +async def copy_images(images: list[str], cookies: Optional[Cookies] = None, proxy: Optional[str] = None): + ensure_images_dir() + async with ClientSession( + connector=get_connector( + proxy=os.environ.get("G4F_PROXY") if proxy is None else proxy + ), + cookies=cookies + ) as session: + async def copy_image(image: str) -> str: + target = os.path.join(images_dir, f"{int(time.time())}_{str(uuid.uuid4())}") + if image.startswith("data:"): + with open(target, "wb") as f: + f.write(extract_data_uri(image)) + else: + async with session.get(image) as response: + with open(target, "wb") as f: + async for chunk in response.content.iter_chunked(4096): + f.write(chunk) + with open(target, "rb") as f: + extension = is_accepted_format(f.read(12)).split("/")[-1] + extension = "jpg" if extension == "jpeg" else extension + new_target = f"{target}.{extension}" + os.rename(target, new_target) + return f"/images/{os.path.basename(new_target)}" + + return await asyncio.gather(*[copy_image(image) for image in images]) + +class ImageResponse(ResponseType): def __init__( self, images: Union[str, list], |