From 3c2755bc72efa0d8e5d8b2883443530ba67ecad4 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Tue, 26 Sep 2023 10:03:37 +0200 Subject: Add ChatgptDuo and Aibn Provider Add support for "nest_asyncio", Reuse event_loops with event_loop_policy Support for "create_async" with synchron provider --- g4f/Provider/base_provider.py | 102 ++++++++++++------------------------------ 1 file changed, 29 insertions(+), 73 deletions(-) (limited to 'g4f/Provider/base_provider.py') diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index e8a54f78..a21dc871 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -1,13 +1,10 @@ from __future__ import annotations -import asyncio -import functools -from asyncio import SelectorEventLoop, AbstractEventLoop +from asyncio import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor from abc import ABC, abstractmethod -import browser_cookie3 - +from .helper import get_event_loop, get_cookies, format_prompt from ..typing import AsyncGenerator, CreateResult @@ -40,20 +37,18 @@ class BaseProvider(ABC): **kwargs ) -> str: if not loop: - loop = asyncio.get_event_loop() - - partial_func = functools.partial( - cls.create_completion, - model, - messages, - False, - **kwargs - ) - response = await loop.run_in_executor( + loop = get_event_loop() + def create_func(): + return "".join(cls.create_completion( + model, + messages, + False, + **kwargs + )) + return await loop.run_in_executor( executor, - partial_func + create_func ) - return "".join(response) @classmethod @property @@ -76,11 +71,9 @@ class AsyncProvider(BaseProvider): stream: bool = False, **kwargs ) -> CreateResult: - loop = create_event_loop() - try: - yield loop.run_until_complete(cls.create_async(model, messages, **kwargs)) - finally: - loop.close() + loop = get_event_loop() + coro = cls.create_async(model, messages, **kwargs) + yield loop.run_until_complete(coro) @staticmethod @abstractmethod @@ -103,22 +96,19 @@ class AsyncGeneratorProvider(AsyncProvider): stream: bool = True, **kwargs ) -> CreateResult: - loop = create_event_loop() - try: - generator = cls.create_async_generator( - model, - messages, - stream=stream, - **kwargs - ) - gen = generator.__aiter__() - while True: - try: - yield loop.run_until_complete(gen.__anext__()) - except StopAsyncIteration: - break - finally: - loop.close() + loop = get_event_loop() + generator = cls.create_async_generator( + model, + messages, + stream=stream, + **kwargs + ) + gen = generator.__aiter__() + while True: + try: + yield loop.run_until_complete(gen.__anext__()) + except StopAsyncIteration: + break @classmethod async def create_async( @@ -143,38 +133,4 @@ class AsyncGeneratorProvider(AsyncProvider): messages: list[dict[str, str]], **kwargs ) -> AsyncGenerator: - raise NotImplementedError() - - -# Don't create a new event loop in a running async loop. -# Force use selector event loop on windows and linux use it anyway. -def create_event_loop() -> SelectorEventLoop: - try: - asyncio.get_running_loop() - except RuntimeError: - return SelectorEventLoop() - raise RuntimeError( - 'Use "create_async" instead of "create" function in a running event loop.') - - -_cookies = {} - -def get_cookies(cookie_domain: str) -> dict: - if cookie_domain not in _cookies: - _cookies[cookie_domain] = {} - try: - for cookie in browser_cookie3.load(cookie_domain): - _cookies[cookie_domain][cookie.name] = cookie.value - except: - pass - return _cookies[cookie_domain] - - -def format_prompt(messages: list[dict[str, str]], add_special_tokens=False): - if add_special_tokens or len(messages) > 1: - formatted = "\n".join( - ["%s: %s" % ((message["role"]).capitalize(), message["content"]) for message in messages] - ) - return f"{formatted}\nAssistant:" - else: - return messages[0]["content"] \ No newline at end of file + raise NotImplementedError() \ No newline at end of file -- cgit v1.2.3