diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/nexra/NexraChatGptWeb.py | 75 |
1 files changed, 35 insertions, 40 deletions
diff --git a/g4f/Provider/nexra/NexraChatGptWeb.py b/g4f/Provider/nexra/NexraChatGptWeb.py index d14a2162..f82694d4 100644 --- a/g4f/Provider/nexra/NexraChatGptWeb.py +++ b/g4f/Provider/nexra/NexraChatGptWeb.py @@ -1,29 +1,21 @@ from __future__ import annotations -from aiohttp import ClientSession, ContentTypeError import json +import requests -from ...typing import AsyncResult, Messages -from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin +from ...typing import CreateResult, Messages +from ..base_provider import ProviderModelMixin, AbstractProvider from ..helper import format_prompt - -class NexraChatGptWeb(AsyncGeneratorProvider, ProviderModelMixin): +class NexraChatGptWeb(AbstractProvider, ProviderModelMixin): label = "Nexra ChatGPT Web" url = "https://nexra.aryahcr.cc/documentation/chatgpt/en" - api_endpoint = "https://nexra.aryahcr.cc/api/chat/{}" working = True - supports_gpt_35_turbo = True - supports_gpt_4 = True - supports_stream = True - default_model = 'gptweb' + default_model = "gptweb" models = [default_model] - - model_aliases = { - "gpt-4": "gptweb", - } - + model_aliases = {"gpt-4": "gptweb"} + api_endpoints = {"gptweb": "https://nexra.aryahcr.cc/api/chat/gptweb"} @classmethod def get_model(cls, model: str) -> str: @@ -33,37 +25,40 @@ class NexraChatGptWeb(AsyncGeneratorProvider, ProviderModelMixin): return cls.model_aliases[model] else: return cls.default_model - + @classmethod - async def create_async_generator( + def create_completion( cls, model: str, messages: Messages, proxy: str = None, markdown: bool = False, **kwargs - ) -> AsyncResult: + ) -> CreateResult: + model = cls.get_model(model) + api_endpoint = cls.api_endpoints.get(model, cls.api_endpoints[cls.default_model]) + headers = { - "Content-Type": "application/json" + 'Content-Type': 'application/json' } - async with ClientSession(headers=headers) as session: - prompt = format_prompt(messages) - data = { - "prompt": prompt, - "markdown": markdown - } - model = cls.get_model(model) - endpoint = cls.api_endpoint.format(model) - async with session.post(endpoint, json=data, proxy=proxy) as response: - response.raise_for_status() - response_text = await response.text() - - # Remove leading underscore if present - if response_text.startswith('_'): - response_text = response_text[1:] - - try: - response_data = json.loads(response_text) - yield response_data.get('gpt', response_text) - except json.JSONDecodeError: - yield response_text + + data = { + "prompt": format_prompt(messages), + "markdown": markdown + } + + response = requests.post(api_endpoint, headers=headers, json=data) + + return cls.process_response(response) + + @classmethod + def process_response(cls, response): + if response.status_code == 200: + try: + content = response.text.lstrip('_') + json_response = json.loads(content) + return json_response.get('gpt', '') + except json.JSONDecodeError: + return "Error: Unable to decode JSON response" + else: + return f"Error: {response.status_code}" |