diff options
author | H Lohaus <hlohaus@users.noreply.github.com> | 2024-11-16 13:19:51 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-11-16 13:19:51 +0100 |
commit | 6ce493d4dfc2884832ff5b5be4479a55818b2fe7 (patch) | |
tree | 92e9efce62f7832ebe56969c120d8e92e75881a3 /g4f/providers/base_provider.py | |
parent | Update internet.py (diff) | |
download | gpt4free-6ce493d4dfc2884832ff5b5be4479a55818b2fe7.tar gpt4free-6ce493d4dfc2884832ff5b5be4479a55818b2fe7.tar.gz gpt4free-6ce493d4dfc2884832ff5b5be4479a55818b2fe7.tar.bz2 gpt4free-6ce493d4dfc2884832ff5b5be4479a55818b2fe7.tar.lz gpt4free-6ce493d4dfc2884832ff5b5be4479a55818b2fe7.tar.xz gpt4free-6ce493d4dfc2884832ff5b5be4479a55818b2fe7.tar.zst gpt4free-6ce493d4dfc2884832ff5b5be4479a55818b2fe7.zip |
Diffstat (limited to 'g4f/providers/base_provider.py')
-rw-r--r-- | g4f/providers/base_provider.py | 35 |
1 files changed, 22 insertions, 13 deletions
diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py index 5d48f2e0..128fb5a0 100644 --- a/g4f/providers/base_provider.py +++ b/g4f/providers/base_provider.py @@ -2,11 +2,13 @@ from __future__ import annotations import sys import asyncio + from asyncio import AbstractEventLoop from concurrent.futures import ThreadPoolExecutor from abc import abstractmethod from inspect import signature, Parameter from typing import Callable, Union + from ..typing import CreateResult, AsyncResult, Messages from .types import BaseProvider, FinishReason from ..errors import NestAsyncioError, ModelNotSupportedError @@ -17,6 +19,17 @@ if sys.version_info < (3, 10): else: from types import NoneType +try: + import nest_asyncio + has_nest_asyncio = True +except ImportError: + has_nest_asyncio = False +try: + import uvloop + has_uvloop = True +except ImportError: + has_uvloop = False + # Set Windows event loop policy for better compatibility with asyncio and curl_cffi if sys.platform == 'win32': try: @@ -31,18 +44,14 @@ def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]: try: loop = asyncio.get_running_loop() # Do not patch uvloop loop because its incompatible. - try: - import uvloop + if has_uvloop: if isinstance(loop, uvloop.Loop): - return loop - except (ImportError, ModuleNotFoundError): - pass - if check_nested and not hasattr(loop.__class__, "_nest_patched"): - try: - import nest_asyncio + return loop + if not hasattr(loop.__class__, "_nest_patched"): + if has_nest_asyncio: nest_asyncio.apply(loop) - except ImportError: - raise NestAsyncioError('Install "nest_asyncio" package') + elif check_nested: + raise NestAsyncioError('Install "nest_asyncio" package | pip install -U nest_asyncio') return loop except RuntimeError: pass @@ -154,7 +163,7 @@ class AsyncProvider(AbstractProvider): Returns: CreateResult: The result of the completion creation. """ - get_running_loop(check_nested=True) + get_running_loop(check_nested=False) yield asyncio.run(cls.create_async(model, messages, **kwargs)) @staticmethod @@ -208,7 +217,7 @@ class AsyncGeneratorProvider(AsyncProvider): Returns: CreateResult: The result of the streaming completion creation. """ - loop = get_running_loop(check_nested=True) + loop = get_running_loop(check_nested=False) new_loop = False if loop is None: loop = asyncio.new_event_loop() @@ -222,7 +231,7 @@ class AsyncGeneratorProvider(AsyncProvider): while True: yield loop.run_until_complete(await_callback(gen.__anext__)) except StopAsyncIteration: - ... + pass finally: if new_loop: loop.close() |