summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth/ThebApi.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/needs_auth/ThebApi.py')
-rw-r--r--g4f/Provider/needs_auth/ThebApi.py57
1 files changed, 18 insertions, 39 deletions
diff --git a/g4f/Provider/needs_auth/ThebApi.py b/g4f/Provider/needs_auth/ThebApi.py
index 1c7baf8d..48879bcb 100644
--- a/g4f/Provider/needs_auth/ThebApi.py
+++ b/g4f/Provider/needs_auth/ThebApi.py
@@ -1,10 +1,7 @@
from __future__ import annotations
-import requests
-
-from ...typing import Any, CreateResult, Messages
-from ..base_provider import AbstractProvider, ProviderModelMixin
-from ...errors import MissingAuthError
+from ...typing import CreateResult, Messages
+from .Openai import Openai
models = {
"theb-ai": "TheB.AI",
@@ -30,7 +27,7 @@ models = {
"qwen-7b-chat": "Qwen 7B"
}
-class ThebApi(AbstractProvider, ProviderModelMixin):
+class ThebApi(Openai):
url = "https://theb.ai"
working = True
needs_auth = True
@@ -38,44 +35,26 @@ class ThebApi(AbstractProvider, ProviderModelMixin):
models = list(models)
@classmethod
- def create_completion(
+ def create_async_generator(
cls,
model: str,
messages: Messages,
- stream: bool,
- auth: str = None,
- proxy: str = None,
+ api_base: str = "https://api.theb.ai/v1",
+ temperature: float = 1,
+ top_p: float = 1,
**kwargs
) -> CreateResult:
- if not auth:
- raise MissingAuthError("Missing auth")
- headers = {
- 'accept': 'application/json',
- 'authorization': f'Bearer {auth}',
- 'content-type': 'application/json',
- }
- # response = requests.get("https://api.baizhi.ai/v1/models", headers=headers).json()["data"]
- # models = dict([(m["id"], m["name"]) for m in response])
- # print(json.dumps(models, indent=4))
- data: dict[str, Any] = {
- "model": cls.get_model(model),
- "messages": messages,
- "stream": False,
+ if "auth" in kwargs:
+ kwargs["api_key"] = kwargs["auth"]
+ system_message = "\n".join([message["content"] for message in messages if message["role"] == "system"])
+ if not system_message:
+ system_message = "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-3.5 architecture."
+ messages = [message for message in messages if message["role"] != "system"]
+ data = {
"model_params": {
- "system_prompt": kwargs.get("system_message", "You are ChatGPT, a large language model trained by OpenAI, based on the GPT-3.5 architecture."),
- "temperature": 1,
- "top_p": 1,
- **kwargs
+ "system_prompt": system_message,
+ "temperature": temperature,
+ "top_p": top_p,
}
}
- response = requests.post(
- "https://api.theb.ai/v1/chat/completions",
- headers=headers,
- json=data,
- proxies={"https": proxy}
- )
- try:
- response.raise_for_status()
- yield response.json()["choices"][0]["message"]["content"]
- except:
- raise RuntimeError(f"Response: {next(response.iter_lines()).decode()}") \ No newline at end of file
+ return super().create_async_generator(model, messages, api_base=api_base, extra_data=data, **kwargs) \ No newline at end of file