diff options
Diffstat (limited to 'g4f/Provider/base_provider.py')
-rw-r--r-- | g4f/Provider/base_provider.py | 40 |
1 files changed, 18 insertions, 22 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 56d79ee6..d5f23931 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -9,20 +9,19 @@ import math class BaseProvider(ABC): url: str - working = False - needs_auth = False - supports_stream = False + working = False + needs_auth = False + supports_stream = False supports_gpt_35_turbo = False - supports_gpt_4 = False + supports_gpt_4 = False @staticmethod @abstractmethod def create_completion( model: str, messages: list[dict[str, str]], - stream: bool, - **kwargs: Any, - ) -> CreateResult: + stream: bool, **kwargs: Any) -> CreateResult: + raise NotImplementedError() @classmethod @@ -42,8 +41,10 @@ _cookies = {} def get_cookies(cookie_domain: str) -> dict: if cookie_domain not in _cookies: _cookies[cookie_domain] = {} + for cookie in browser_cookie3.load(cookie_domain): _cookies[cookie_domain][cookie.name] = cookie.value + return _cookies[cookie_domain] @@ -53,18 +54,15 @@ class AsyncProvider(BaseProvider): cls, model: str, messages: list[dict[str, str]], - stream: bool = False, - **kwargs: Any - ) -> CreateResult: + stream: bool = False, **kwargs: Any) -> CreateResult: + yield asyncio.run(cls.create_async(model, messages, **kwargs)) @staticmethod @abstractmethod async def create_async( model: str, - messages: list[dict[str, str]], - **kwargs: Any, - ) -> str: + messages: list[dict[str, str]], **kwargs: Any) -> str: raise NotImplementedError() @@ -74,9 +72,8 @@ class AsyncGeneratorProvider(AsyncProvider): cls, model: str, messages: list[dict[str, str]], - stream: bool = True, - **kwargs: Any - ) -> CreateResult: + stream: bool = True, **kwargs: Any) -> CreateResult: + if stream: yield from run_generator(cls.create_async_generator(model, messages, **kwargs)) else: @@ -86,9 +83,8 @@ class AsyncGeneratorProvider(AsyncProvider): async def create_async( cls, model: str, - messages: list[dict[str, str]], - **kwargs: Any, - ) -> str: + messages: list[dict[str, str]], **kwargs: Any) -> str: + chunks = [chunk async for chunk in cls.create_async_generator(model, messages, **kwargs)] if chunks: return "".join(chunks) @@ -97,14 +93,14 @@ class AsyncGeneratorProvider(AsyncProvider): @abstractmethod def create_async_generator( model: str, - messages: list[dict[str, str]], - ) -> AsyncGenerator: + messages: list[dict[str, str]]) -> AsyncGenerator: + raise NotImplementedError() def run_generator(generator: AsyncGenerator[Union[Any, str], Any]): loop = asyncio.new_event_loop() - gen = generator.__aiter__() + gen = generator.__aiter__() while True: try: |