diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/GigaChat.py | 23 |
1 files changed, 9 insertions, 14 deletions
diff --git a/g4f/Provider/GigaChat.py b/g4f/Provider/GigaChat.py index 699353b1..c1ec7f5e 100644 --- a/g4f/Provider/GigaChat.py +++ b/g4f/Provider/GigaChat.py @@ -1,35 +1,28 @@ from __future__ import annotations -import base64 import os import ssl import time import uuid import json -from aiohttp import ClientSession, BaseConnector, TCPConnector +from aiohttp import ClientSession, TCPConnector, BaseConnector from g4f.requests import raise_for_status -from ..typing import AsyncResult, Messages, ImageType +from ..typing import AsyncResult, Messages from .base_provider import AsyncGeneratorProvider, ProviderModelMixin -from ..image import to_bytes, is_accepted_format from ..errors import MissingAuthError from .helper import get_connector -access_token = '' +access_token = "" token_expires_at = 0 -ssl_ctx = ssl.create_default_context( - cafile=os.path.dirname(__file__) + '/gigachat_crt/russian_trusted_root_ca_pem.crt') - - class GigaChat(AsyncGeneratorProvider, ProviderModelMixin): url = "https://developers.sber.ru/gigachat" working = True supports_message_history = True supports_system_message = True supports_stream = True - needs_auth = True default_model = "GigaChat:latest" models = ["GigaChat:latest", "GigaChat-Plus", "GigaChat-Pro"] @@ -42,18 +35,20 @@ class GigaChat(AsyncGeneratorProvider, ProviderModelMixin): stream: bool = True, proxy: str = None, api_key: str = None, + coonector: BaseConnector = None, scope: str = "GIGACHAT_API_PERS", update_interval: float = 0, **kwargs ) -> AsyncResult: global access_token, token_expires_at model = cls.get_model(model) - if not api_key: raise MissingAuthError('Missing "api_key"') - - connector = TCPConnector(ssl_context=ssl_ctx) - + + cafile = os.path.join(os.path.dirname(__file__), "gigachat_crt/russian_trusted_root_ca_pem.crt") + ssl_context = ssl.create_default_context(cafile=cafile) if os.path.exists(cafile) else None + if connector is None and ssl_context is not None: + connector = TCPConnector(ssl_context=ssl_context) async with ClientSession(connector=get_connector(connector, proxy)) as session: if token_expires_at - int(time.time() * 1000) < 60000: async with session.post(url="https://ngw.devices.sberbank.ru:9443/api/v2/oauth", |