From 5a79d8cbd7d99510c9f7f504e876f5197a64927b Mon Sep 17 00:00:00 2001 From: kqlio67 Date: Tue, 22 Oct 2024 15:27:01 +0300 Subject: Restored provider (g4f/Provider/nexra/NexraSDLora.py) --- g4f/Provider/nexra/NexraSDLora.py | 81 ++++++++++++++++++++------------------- g4f/models.py | 11 +++++- 2 files changed, 51 insertions(+), 41 deletions(-) (limited to 'g4f') diff --git a/g4f/Provider/nexra/NexraSDLora.py b/g4f/Provider/nexra/NexraSDLora.py index a33afa04..a12bff1a 100644 --- a/g4f/Provider/nexra/NexraSDLora.py +++ b/g4f/Provider/nexra/NexraSDLora.py @@ -1,28 +1,26 @@ from __future__ import annotations -from aiohttp import ClientSession import json - -from ...typing import AsyncResult, Messages -from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin +import requests +from ...typing import CreateResult, Messages +from ..base_provider import ProviderModelMixin, AbstractProvider from ...image import ImageResponse - -class NexraSDLora(AsyncGeneratorProvider, ProviderModelMixin): +class NexraSDLora(AbstractProvider, ProviderModelMixin): label = "Nexra Stable Diffusion Lora" url = "https://nexra.aryahcr.cc/documentation/stable-diffusion/en" api_endpoint = "https://nexra.aryahcr.cc/api/image/complements" - working = False + working = True - default_model = 'sdxl-lora' + default_model = "sdxl-lora" models = [default_model] @classmethod def get_model(cls, model: str) -> str: return cls.default_model - + @classmethod - async def create_async_generator( + def create_completion( cls, model: str, messages: Messages, @@ -31,38 +29,41 @@ class NexraSDLora(AsyncGeneratorProvider, ProviderModelMixin): guidance: str = 0.3, # Min: 0, Max: 5 steps: str = 2, # Min: 2, Max: 10 **kwargs - ) -> AsyncResult: + ) -> CreateResult: model = cls.get_model(model) - + headers = { - "Content-Type": "application/json" + 'Content-Type': 'application/json' } - async with ClientSession(headers=headers) as session: - prompt = messages[0]['content'] - data = { - "prompt": prompt, - "model": model, - "response": response, - "data": { - "guidance": guidance, - "steps": steps - } + + data = { + "prompt": messages[-1]["content"], + "model": model, + "response": response, + "data": { + "guidance": guidance, + "steps": steps } - async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response: - text_data = await response.text() - - if response.status == 200: - try: - json_start = text_data.find('{') - json_data = text_data[json_start:] - - data = json.loads(json_data) - if 'images' in data and len(data['images']) > 0: - image_url = data['images'][-1] - yield ImageResponse(image_url, prompt) - else: - yield ImageResponse("No images found in the response.", prompt) - except json.JSONDecodeError: - yield ImageResponse("Failed to parse JSON. Response might not be in JSON format.", prompt) + } + + response = requests.post(cls.api_endpoint, headers=headers, json=data) + + result = cls.process_response(response) + yield result + + @classmethod + def process_response(cls, response): + if response.status_code == 200: + try: + content = response.text.strip() + content = content.lstrip('_') + data = json.loads(content) + if data.get('status') and data.get('images'): + image_url = data['images'][0] + return ImageResponse(images=[image_url], alt="Generated Image") else: - yield ImageResponse(f"Request failed with status: {response.status}", prompt) + return "Error: No image URL found in the response" + except json.JSONDecodeError as e: + return f"Error: Unable to decode JSON response. Details: {str(e)}" + else: + return f"Error: {response.status_code}, Response: {response.text}" diff --git a/g4f/models.py b/g4f/models.py index 542967f2..bfc68096 100644 --- a/g4f/models.py +++ b/g4f/models.py @@ -53,6 +53,7 @@ from .Provider import ( NexraMidjourney, NexraQwen, NexraSD15, + NexraSDLora, NexraSDTurbo, OpenaiChat, PerplexityLabs, @@ -742,10 +743,17 @@ sdxl_turbo = Model( ) +sdxl_lora = Model( + name = 'sdxl-lora', + base_provider = 'Stability AI', + best_provider = NexraSDLora + +) + sdxl = Model( name = 'sdxl', base_provider = 'Stability AI', - best_provider = IterListProvider([ReplicateHome, DeepInfraImage, sdxl_turbo.best_provider]) + best_provider = IterListProvider([ReplicateHome, DeepInfraImage, sdxl_turbo.best_provider, sdxl_lora.best_provider]) ) @@ -1111,6 +1119,7 @@ class ModelUtils: ### Stability AI ### 'sdxl': sdxl, +'sdxl-lora': sdxl_lora, 'sdxl-turbo': sdxl_turbo, 'sd-1.5': sd_1_5, 'sd-3': sd_3, -- cgit v1.2.3