summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/needs_auth/Openai.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/needs_auth/Openai.py21
1 files changed, 17 insertions, 4 deletions
diff --git a/g4f/Provider/needs_auth/Openai.py b/g4f/Provider/needs_auth/Openai.py
index f73c1011..9da6bad8 100644
--- a/g4f/Provider/needs_auth/Openai.py
+++ b/g4f/Provider/needs_auth/Openai.py
@@ -4,9 +4,10 @@ import json
from ..helper import filter_none
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
-from ...typing import Union, Optional, AsyncResult, Messages
+from ...typing import Union, Optional, AsyncResult, Messages, ImageType
from ...requests import StreamSession, raise_for_status
from ...errors import MissingAuthError, ResponseError
+from ...image import to_data_uri
class Openai(AsyncGeneratorProvider, ProviderModelMixin):
label = "OpenAI API"
@@ -23,6 +24,7 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin):
messages: Messages,
proxy: str = None,
timeout: int = 120,
+ image: ImageType = None,
api_key: str = None,
api_base: str = "https://api.openai.com/v1",
temperature: float = None,
@@ -36,6 +38,19 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin):
) -> AsyncResult:
if cls.needs_auth and api_key is None:
raise MissingAuthError('Add a "api_key"')
+ if image is not None:
+ if not model and hasattr(cls, "default_vision_model"):
+ model = cls.default_vision_model
+ messages[-1]["content"] = [
+ {
+ "type": "image_url",
+ "image_url": {"url": to_data_uri(image)}
+ },
+ {
+ "type": "text",
+ "text": messages[-1]["content"]
+ }
+ ]
async with StreamSession(
proxies={"all": proxy},
headers=cls.get_headers(stream, api_key, headers),
@@ -51,7 +66,6 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin):
stream=stream,
**extra_data
)
-
async with session.post(f"{api_base.rstrip('/')}/chat/completions", json=data) as response:
await raise_for_status(response)
if not stream:
@@ -103,8 +117,7 @@ class Openai(AsyncGeneratorProvider, ProviderModelMixin):
"Content-Type": "application/json",
**(
{"Authorization": f"Bearer {api_key}"}
- if cls.needs_auth and api_key is not None
- else {}
+ if api_key is not None else {}
),
**({} if headers is None else headers)
} \ No newline at end of file