diff options
Diffstat (limited to 'g4f/Provider/create_images.py')
-rw-r--r-- | g4f/Provider/create_images.py | 84 |
1 files changed, 84 insertions, 0 deletions
diff --git a/g4f/Provider/create_images.py b/g4f/Provider/create_images.py new file mode 100644 index 00000000..9c76a742 --- /dev/null +++ b/g4f/Provider/create_images.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import re +import asyncio +from ..typing import CreateResult, Messages +from ..base_provider import BaseProvider, ProviderType + +system_message = """ +You can generate custom images with the DALL-E 3 image generator. +To generate a image with a prompt, do this: +<img data-prompt=\"keywords for the image\"> +Don't use images with data uri. It is important to use a prompt instead. +<img data-prompt=\"image caption\"> +""" + +class CreateImagesProvider(BaseProvider): + def __init__( + self, + provider: ProviderType, + create_images: callable, + create_async: callable, + system_message: str = system_message, + include_placeholder: bool = True + ) -> None: + self.provider = provider + self.create_images = create_images + self.create_images_async = create_async + self.system_message = system_message + self.__name__ = provider.__name__ + self.working = provider.working + self.supports_stream = provider.supports_stream + self.include_placeholder = include_placeholder + if hasattr(provider, "url"): + self.url = provider.url + + def create_completion( + self, + model: str, + messages: Messages, + stream: bool = False, + **kwargs + ) -> CreateResult: + messages.insert(0, {"role": "system", "content": self.system_message}) + buffer = "" + for chunk in self.provider.create_completion(model, messages, stream, **kwargs): + if buffer or "<" in chunk: + buffer += chunk + if ">" in buffer: + match = re.search(r'<img data-prompt="(.*?)">', buffer) + if match: + placeholder, prompt = match.group(0), match.group(1) + start, append = buffer.split(placeholder, 1) + if start: + yield start + if self.include_placeholder: + yield placeholder + yield from self.create_images(prompt) + if append: + yield append + else: + yield buffer + buffer = "" + else: + yield chunk + + async def create_async( + self, + model: str, + messages: Messages, + **kwargs + ) -> str: + messages.insert(0, {"role": "system", "content": self.system_message}) + response = await self.provider.create_async(model, messages, **kwargs) + matches = re.findall(r'(<img data-prompt="(.*?)">)', result) + results = [] + for _, prompt in matches: + results.append(self.create_images_async(prompt)) + results = await asyncio.gather(*results) + for idx, result in enumerate(results): + placeholder = matches[idx][0] + if self.include_placeholder: + result = placeholder + result + response = response.replace(placeholder, result) + return result
\ No newline at end of file |