diff options
Diffstat (limited to '')
-rw-r--r-- | g4f/providers/base_provider.py | 35 |
1 files changed, 22 insertions, 13 deletions
diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py index 5d48f2e0..128fb5a0 100644 --- a/g4f/providers/base_provider.py +++ b/g4f/providers/base_provider.py @@ -2,11 +2,13 @@ from __future__ import annotations import sys import asyncio + from asyncio import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor from abc import abstractmethod from inspect import signature, Parameter from typing import Callable, Union + from ..typing import CreateResult, AsyncResult, Messages from .types import BaseProvider, FinishReason from ..errors import NestAsyncioError, ModelNotSupportedError @@ -17,6 +19,17 @@ if sys.version_info < (3, 10): else: from types import NoneType +try: + import nest_asyncio + has_nest_asyncio = True +except ImportError: + has_nest_asyncio = False +try: + import uvloop + has_uvloop = True +except ImportError: + has_uvloop = False + # Set Windows event loop policy for better compatibility with asyncio and curl_cffi if sys.platform == 'win32': try: @@ -31,18 +44,14 @@ def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]: try: loop = asyncio.get_running_loop() # Do not patch uvloop loop because its incompatible. - try: - import uvloop + if has_uvloop: if isinstance(loop, uvloop.Loop): - return loop - except (ImportError, ModuleNotFoundError): - pass - if check_nested and not hasattr(loop.__class__, "_nest_patched"): - try: - import nest_asyncio + return loop + if not hasattr(loop.__class__, "_nest_patched"): + if has_nest_asyncio: nest_asyncio.apply(loop) - except ImportError: - raise NestAsyncioError('Install "nest_asyncio" package') + elif check_nested: + raise NestAsyncioError('Install "nest_asyncio" package | pip install -U nest_asyncio') return loop except RuntimeError: pass @@ -154,7 +163,7 @@ class AsyncProvider(AbstractProvider): Returns: CreateResult: The result of the completion creation. """ - get_running_loop(check_nested=True) + get_running_loop(check_nested=False) yield asyncio.run(cls.create_async(model, messages, **kwargs)) @staticmethod @@ -208,7 +217,7 @@ class AsyncGeneratorProvider(AsyncProvider): Returns: CreateResult: The result of the streaming completion creation. """ - loop = get_running_loop(check_nested=True) + loop = get_running_loop(check_nested=False) new_loop = False if loop is None: loop = asyncio.new_event_loop() @@ -222,7 +231,7 @@ class AsyncGeneratorProvider(AsyncProvider): while True: yield loop.run_until_complete(await_callback(gen.__anext__)) except StopAsyncIteration: - ... + pass finally: if new_loop: loop.close() |