summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth/Openai.py
diff options
context:
space:
mode:
authorH Lohaus <hlohaus@users.noreply.github.com>2024-04-07 11:27:26 +0200
committerGitHub <noreply@github.com>2024-04-07 11:27:26 +0200
commitd327afc60620913f5d2b0a9985b03a7934468ad4 (patch)
tree395de9142af3e6b9c0e5e3968ee7f8234b8b25e2 /g4f/Provider/needs_auth/Openai.py
parentUpdate Gemini.py (diff)
parentUpdate provider.py (diff)
downloadgpt4free-d327afc60620913f5d2b0a9985b03a7934468ad4.tar
gpt4free-d327afc60620913f5d2b0a9985b03a7934468ad4.tar.gz
gpt4free-d327afc60620913f5d2b0a9985b03a7934468ad4.tar.bz2
gpt4free-d327afc60620913f5d2b0a9985b03a7934468ad4.tar.lz
gpt4free-d327afc60620913f5d2b0a9985b03a7934468ad4.tar.xz
gpt4free-d327afc60620913f5d2b0a9985b03a7934468ad4.tar.zst
gpt4free-d327afc60620913f5d2b0a9985b03a7934468ad4.zip
Diffstat (limited to 'g4f/Provider/needs_auth/Openai.py')
-rw-r--r--g4f/Provider/needs_auth/Openai.py96
1 files changed, 65 insertions, 31 deletions
diff --git a/g4f/Provider/needs_auth/Openai.py b/g4f/Provider/needs_auth/Openai.py
index b876cd0b..6cd2cf86 100644
--- a/g4f/Provider/needs_auth/Openai.py
+++ b/g4f/Provider/needs_auth/Openai.py
@@ -3,10 +3,10 @@ from __future__ import annotations
import json
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
-from ...typing import AsyncResult, Messages
+from ...typing import Union, Optional, AsyncResult, Messages
from ...requests.raise_for_status import raise_for_status
from ...requests import StreamSession
-from ...errors import MissingAuthError
+from ...errors import MissingAuthError, ResponseError
class Openai(AsyncGeneratorProvider, ProviderModelMixin):
url = "https://openai.com"
@@ -27,48 +27,82 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin):
temperature: float = None,
max_tokens: int = None,
top_p: float = None,
- stop: str = None,
+ stop: Union[str, list[str]] = None,
stream: bool = False,
+ headers: dict = None,
+ extra_data: dict = {},
**kwargs
) -> AsyncResult:
- if api_key is None:
+ if cls.needs_auth and api_key is None:
raise MissingAuthError('Add a "api_key"')
async with StreamSession(
proxies={"all": proxy},
- headers=cls.get_headers(api_key),
+ headers=cls.get_headers(stream, api_key, headers),
timeout=timeout
) as session:
- data = {
- "messages": messages,
- "model": cls.get_model(model),
- "temperature": temperature,
- "max_tokens": max_tokens,
- "top_p": top_p,
- "stop": stop,
- "stream": stream,
- }
+ data = filter_none(
+ messages=messages,
+ model=cls.get_model(model),
+ temperature=temperature,
+ max_tokens=max_tokens,
+ top_p=top_p,
+ stop=stop,
+ stream=stream,
+ **extra_data
+ )
async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
await raise_for_status(response)
- async for line in response.iter_lines():
- if line.startswith(b"data: ") or not stream:
- async for chunk in cls.read_line(line[6:] if stream else line, stream):
- yield chunk
+ if not stream:
+ data = await response.json()
+ choice = data["choices"][0]
+ if "content" in choice["message"]:
+ yield choice["message"]["content"].strip()
+ finish = cls.read_finish_reason(choice)
+ if finish is not None:
+ yield finish
+ else:
+ first = True
+ async for line in response.iter_lines():
+ if line.startswith(b"data: "):
+ chunk = line[6:]
+ if chunk == b"[DONE]":
+ break
+ data = json.loads(chunk)
+ if "error_message" in data:
+ raise ResponseError(data["error_message"])
+ choice = data["choices"][0]
+ if "content" in choice["delta"] and choice["delta"]["content"]:
+ delta = choice["delta"]["content"]
+ if first:
+ delta = delta.lstrip()
+ if delta:
+ first = False
+ yield delta
+ finish = cls.read_finish_reason(choice)
+ if finish is not None:
+ yield finish
@staticmethod
- async def read_line(line: str, stream: bool):
- if line == b"[DONE]":
- return
- choice = json.loads(line)["choices"][0]
- if stream and "content" in choice["delta"] and choice["delta"]["content"]:
- yield choice["delta"]["content"]
- elif not stream and "content" in choice["message"]:
- yield choice["message"]["content"]
+ def read_finish_reason(choice: dict) -> Optional[FinishReason]:
if "finish_reason" in choice and choice["finish_reason"] is not None:
- yield FinishReason(choice["finish_reason"])
+ return FinishReason(choice["finish_reason"])
- @staticmethod
- def get_headers(api_key: str) -> dict:
+ @classmethod
+ def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
return {
- "Authorization": f"Bearer {api_key}",
+ "Accept": "text/event-stream" if stream else "application/json",
"Content-Type": "application/json",
- } \ No newline at end of file
+ **(
+ {"Authorization": f"Bearer {api_key}"}
+ if cls.needs_auth and api_key is not None
+ else {}
+ ),
+ **({} if headers is None else headers)
+ }
+
+def filter_none(**kwargs) -> dict:
+ return {
+ key: value
+ for key, value in kwargs.items()
+ if value is not None
+ } \ No newline at end of file