diff options
Diffstat (limited to 'g4f/Provider/ReplicateHome.py')
-rw-r--r-- | g4f/Provider/ReplicateHome.py | 46 |
1 files changed, 19 insertions, 27 deletions
diff --git a/g4f/Provider/ReplicateHome.py b/g4f/Provider/ReplicateHome.py index a7fc9b54..00de09e0 100644 --- a/g4f/Provider/ReplicateHome.py +++ b/g4f/Provider/ReplicateHome.py @@ -6,6 +6,8 @@ from aiohttp import ClientSession, ContentTypeError from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider, ProviderModelMixin +from ..requests.aiohttp import get_connector +from ..requests.raise_for_status import raise_for_status from .helper import format_prompt from ..image import ImageResponse @@ -32,10 +34,8 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): 'yorickvp/llava-13b', ] - - models = text_models + image_models - + model_aliases = { # image_models "sd-3": "stability-ai/stable-diffusion-3", @@ -56,23 +56,14 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): # text_models "google-deepmind/gemma-2b-it": "dff94eaf770e1fc211e425a50b51baa8e4cac6c39ef074681f9e39d778773626", "yorickvp/llava-13b": "80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb", - } @classmethod - def get_model(cls, model: str) -> str: - if model in cls.models: - return model - elif model in cls.model_aliases: - return cls.model_aliases[model] - else: - return cls.default_model - - @classmethod async def create_async_generator( cls, model: str, messages: Messages, + prompt: str = None, proxy: str = None, **kwargs ) -> AsyncResult: @@ -96,29 +87,30 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36" } - async with ClientSession(headers=headers) as session: - if model in cls.image_models: - prompt = messages[-1]['content'] if messages else "" - else: - prompt = format_prompt(messages) - + async with ClientSession(headers=headers, connector=get_connector(proxy=proxy)) as session: + if prompt is None: + if model in cls.image_models: + prompt = messages[-1]['content'] + else: + prompt = format_prompt(messages) + data = { "model": model, "version": cls.model_versions[model], "input": {"prompt": prompt}, } - - async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response: - response.raise_for_status() + + async with session.post(cls.api_endpoint, json=data) as response: + await raise_for_status(response) result = await response.json() prediction_id = result['id'] - + poll_url = f"https://homepage.replicate.com/api/poll?id={prediction_id}" max_attempts = 30 delay = 5 for _ in range(max_attempts): - async with session.get(poll_url, proxy=proxy) as response: - response.raise_for_status() + async with session.get(poll_url) as response: + await raise_for_status(response) try: result = await response.json() except ContentTypeError: @@ -131,7 +123,7 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): if result['status'] == 'succeeded': if model in cls.image_models: image_url = result['output'][0] - yield ImageResponse(image_url, "Generated image") + yield ImageResponse(image_url, prompt) return else: for chunk in result['output']: @@ -140,6 +132,6 @@ class ReplicateHome(AsyncGeneratorProvider, ProviderModelMixin): elif result['status'] == 'failed': raise Exception(f"Prediction failed: {result.get('error')}") await asyncio.sleep(delay) - + if result['status'] != 'succeeded': raise Exception("Prediction timed out") |