diff options
author | Luneye <73485421+Luneye@users.noreply.github.com> | 2023-08-28 16:55:36 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-28 16:55:36 +0200 |
commit | 01294db6995511de37e9078e03ce32e54dbdad52 (patch) | |
tree | d6c4c14f4e6a3a81660ddb75272bc5da81cacecc /g4f/Provider/base_provider.py | |
parent | Update Bing.py (diff) | |
parent | ~ | code styling (diff) | |
download | gpt4free-01294db6995511de37e9078e03ce32e54dbdad52.tar gpt4free-01294db6995511de37e9078e03ce32e54dbdad52.tar.gz gpt4free-01294db6995511de37e9078e03ce32e54dbdad52.tar.bz2 gpt4free-01294db6995511de37e9078e03ce32e54dbdad52.tar.lz gpt4free-01294db6995511de37e9078e03ce32e54dbdad52.tar.xz gpt4free-01294db6995511de37e9078e03ce32e54dbdad52.tar.zst gpt4free-01294db6995511de37e9078e03ce32e54dbdad52.zip |
Diffstat (limited to 'g4f/Provider/base_provider.py')
-rw-r--r-- | g4f/Provider/base_provider.py | 40 |
1 files changed, 18 insertions, 22 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 56d79ee6..d5f23931 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -9,20 +9,19 @@ import math class BaseProvider(ABC): url: str - working = False - needs_auth = False - supports_stream = False + working = False + needs_auth = False + supports_stream = False supports_gpt_35_turbo = False - supports_gpt_4 = False + supports_gpt_4 = False @staticmethod @abstractmethod def create_completion( model: str, messages: list[dict[str, str]], - stream: bool, - **kwargs: Any, - ) -> CreateResult: + stream: bool, **kwargs: Any) -> CreateResult: + raise NotImplementedError() @classmethod @@ -42,8 +41,10 @@ _cookies = {} def get_cookies(cookie_domain: str) -> dict: if cookie_domain not in _cookies: _cookies[cookie_domain] = {} + for cookie in browser_cookie3.load(cookie_domain): _cookies[cookie_domain][cookie.name] = cookie.value + return _cookies[cookie_domain] @@ -53,18 +54,15 @@ class AsyncProvider(BaseProvider): cls, model: str, messages: list[dict[str, str]], - stream: bool = False, - **kwargs: Any - ) -> CreateResult: + stream: bool = False, **kwargs: Any) -> CreateResult: + yield asyncio.run(cls.create_async(model, messages, **kwargs)) @staticmethod @abstractmethod async def create_async( model: str, - messages: list[dict[str, str]], - **kwargs: Any, - ) -> str: + messages: list[dict[str, str]], **kwargs: Any) -> str: raise NotImplementedError() @@ -74,9 +72,8 @@ class AsyncGeneratorProvider(AsyncProvider): cls, model: str, messages: list[dict[str, str]], - stream: bool = True, - **kwargs: Any - ) -> CreateResult: + stream: bool = True, **kwargs: Any) -> CreateResult: + if stream: yield from run_generator(cls.create_async_generator(model, messages, **kwargs)) else: @@ -86,9 +83,8 @@ class AsyncGeneratorProvider(AsyncProvider): async def create_async( cls, model: str, - messages: list[dict[str, str]], - **kwargs: Any, - ) -> str: + messages: list[dict[str, str]], **kwargs: Any) -> str: + chunks = [chunk async for chunk in cls.create_async_generator(model, messages, **kwargs)] if chunks: return "".join(chunks) @@ -97,14 +93,14 @@ class AsyncGeneratorProvider(AsyncProvider): @abstractmethod def create_async_generator( model: str, - messages: list[dict[str, str]], - ) -> AsyncGenerator: + messages: list[dict[str, str]]) -> AsyncGenerator: + raise NotImplementedError() def run_generator(generator: AsyncGenerator[Union[Any, str], Any]): loop = asyncio.new_event_loop() - gen = generator.__aiter__() + gen = generator.__aiter__() while True: try: |