diff options
Diffstat (limited to 'g4f/Provider/needs_auth/ThebApi.py')
-rw-r--r-- | g4f/Provider/needs_auth/ThebApi.py | 57 |
1 files changed, 18 insertions, 39 deletions
diff --git a/g4f/Provider/needs_auth/ThebApi.py b/g4f/Provider/needs_auth/ThebApi.py index 1c7baf8d..48879bcb 100644 --- a/g4f/Provider/needs_auth/ThebApi.py +++ b/g4f/Provider/needs_auth/ThebApi.py @@ -1,10 +1,7 @@ from __future__ import annotations -import requests - -from ...typing import Any, CreateResult, Messages -from ..base_provider import AbstractProvider, ProviderModelMixin -from ...errors import MissingAuthError +from ...typing import CreateResult, Messages +from .Openai import Openai models = { "theb-ai": "TheB.AI", @@ -30,7 +27,7 @@ models = { "qwen-7b-chat": "Qwen 7B" } -class ThebApi(AbstractProvider, ProviderModelMixin): +class ThebApi(Openai): url = "https://theb.ai" working = True needs_auth = True @@ -38,44 +35,26 @@ class ThebApi(AbstractProvider, ProviderModelMixin): models = list(models) @classmethod - def create_completion( + def create_async_generator( cls, model: str, messages: Messages, - stream: bool, - auth: str = None, - proxy: str = None, + api_base: str = "https://api.theb.ai/v1", + temperature: float = 1, + top_p: float = 1, **kwargs ) -> CreateResult: - if not auth: - raise MissingAuthError("Missing auth") - headers = { - 'accept': 'application/json', - 'authorization': f'Bearer {auth}', - 'content-type': 'application/json', - } - # response = requests.get("https://api.baizhi.ai/v1/models", headers=headers).json()["data"] - # models = dict([(m["id"], m["name"]) for m in response]) - # print(json.dumps(models, indent=4)) - data: dict[str, Any] = { - "model": cls.get_model(model), - "messages": messages, - "stream": False, + if "auth" in kwargs: + kwargs["api_key"] = kwargs["auth"] + system_message = "\n".join([message["content"] for message in messages if message["role"] == "system"]) + if not system_message: + system_message = "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-3.5 architecture." + messages = [message for message in messages if message["role"] != "system"] + data = { "model_params": { - "system_prompt": kwargs.get("system_message", "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-3.5 architecture."), - "temperature": 1, - "top_p": 1, - **kwargs + "system_prompt": system_message, + "temperature": temperature, + "top_p": top_p, } } - response = requests.post( - "https://api.theb.ai/v1/chat/completions", - headers=headers, - json=data, - proxies={"https": proxy} - ) - try: - response.raise_for_status() - yield response.json()["choices"][0]["message"]["content"] - except: - raise RuntimeError(f"Response: {next(response.iter_lines()).decode()}")
\ No newline at end of file + return super().create_async_generator(model, messages, api_base=api_base, extra_data=data, **kwargs)
\ No newline at end of file |