diff options
author | Tekky <98614666+xtekky@users.noreply.github.com> | 2023-09-26 15:06:06 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-26 15:06:06 +0200 |
commit | 41af8aff6ce09be4131d50e0e24b8622628441fc (patch) | |
tree | 1519231f3513b57acf2548b77f1fc89e092fc4cb /g4f/Provider/base_provider.py | |
parent | ~ (diff) | |
parent | Add ChatgptDuo and Aibn Provider (diff) | |
download | gpt4free-41af8aff6ce09be4131d50e0e24b8622628441fc.tar gpt4free-41af8aff6ce09be4131d50e0e24b8622628441fc.tar.gz gpt4free-41af8aff6ce09be4131d50e0e24b8622628441fc.tar.bz2 gpt4free-41af8aff6ce09be4131d50e0e24b8622628441fc.tar.lz gpt4free-41af8aff6ce09be4131d50e0e24b8622628441fc.tar.xz gpt4free-41af8aff6ce09be4131d50e0e24b8622628441fc.tar.zst gpt4free-41af8aff6ce09be4131d50e0e24b8622628441fc.zip |
Diffstat (limited to 'g4f/Provider/base_provider.py')
-rw-r--r-- | g4f/Provider/base_provider.py | 102 |
1 files changed, 29 insertions, 73 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index e8a54f78..a21dc871 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -1,13 +1,10 @@ from __future__ import annotations -import asyncio -import functools -from asyncio import SelectorEventLoop, AbstractEventLoop +from asyncio import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor from abc import ABC, abstractmethod -import browser_cookie3 - +from .helper import get_event_loop, get_cookies, format_prompt from ..typing import AsyncGenerator, CreateResult @@ -40,20 +37,18 @@ class BaseProvider(ABC): **kwargs ) -> str: if not loop: - loop = asyncio.get_event_loop() - - partial_func = functools.partial( - cls.create_completion, - model, - messages, - False, - **kwargs - ) - response = await loop.run_in_executor( + loop = get_event_loop() + def create_func(): + return "".join(cls.create_completion( + model, + messages, + False, + **kwargs + )) + return await loop.run_in_executor( executor, - partial_func + create_func ) - return "".join(response) @classmethod @property @@ -76,11 +71,9 @@ class AsyncProvider(BaseProvider): stream: bool = False, **kwargs ) -> CreateResult: - loop = create_event_loop() - try: - yield loop.run_until_complete(cls.create_async(model, messages, **kwargs)) - finally: - loop.close() + loop = get_event_loop() + coro = cls.create_async(model, messages, **kwargs) + yield loop.run_until_complete(coro) @staticmethod @abstractmethod @@ -103,22 +96,19 @@ class AsyncGeneratorProvider(AsyncProvider): stream: bool = True, **kwargs ) -> CreateResult: - loop = create_event_loop() - try: - generator = cls.create_async_generator( - model, - messages, - stream=stream, - **kwargs - ) - gen = generator.__aiter__() - while True: - try: - yield loop.run_until_complete(gen.__anext__()) - except StopAsyncIteration: - break - finally: - loop.close() + loop = get_event_loop() + generator = cls.create_async_generator( + model, + messages, + stream=stream, + **kwargs + ) + gen = generator.__aiter__() + while True: + try: + yield loop.run_until_complete(gen.__anext__()) + except StopAsyncIteration: + break @classmethod async def create_async( @@ -143,38 +133,4 @@ class AsyncGeneratorProvider(AsyncProvider): messages: list[dict[str, str]], **kwargs ) -> AsyncGenerator: - raise NotImplementedError() - - -# Don't create a new event loop in a running async loop. -# Force use selector event loop on windows and linux use it anyway. -def create_event_loop() -> SelectorEventLoop: - try: - asyncio.get_running_loop() - except RuntimeError: - return SelectorEventLoop() - raise RuntimeError( - 'Use "create_async" instead of "create" function in a running event loop.') - - -_cookies = {} - -def get_cookies(cookie_domain: str) -> dict: - if cookie_domain not in _cookies: - _cookies[cookie_domain] = {} - try: - for cookie in browser_cookie3.load(cookie_domain): - _cookies[cookie_domain][cookie.name] = cookie.value - except: - pass - return _cookies[cookie_domain] - - -def format_prompt(messages: list[dict[str, str]], add_special_tokens=False): - if add_special_tokens or len(messages) > 1: - formatted = "\n".join( - ["%s: %s" % ((message["role"]).capitalize(), message["content"]) for message in messages] - ) - return f"{formatted}\nAssistant:" - else: - return messages[0]["content"]
\ No newline at end of file + raise NotImplementedError()
\ No newline at end of file |