summaryrefslogtreecommitdiffstats
path: root/g4f/providers/base_provider.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/providers/base_provider.py35
1 files changed, 22 insertions, 13 deletions
diff --git a/g4f/providers/base_provider.py b/g4f/providers/base_provider.py
index 5d48f2e0..128fb5a0 100644
--- a/g4f/providers/base_provider.py
+++ b/g4f/providers/base_provider.py
@@ -2,11 +2,13 @@ from __future__ import annotations
import sys
import asyncio
+
from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
from abc import abstractmethod
from inspect import signature, Parameter
from typing import Callable, Union
+
from ..typing import CreateResult, AsyncResult, Messages
from .types import BaseProvider, FinishReason
from ..errors import NestAsyncioError, ModelNotSupportedError
@@ -17,6 +19,17 @@ if sys.version_info < (3, 10):
else:
from types import NoneType
+try:
+ import nest_asyncio
+ has_nest_asyncio = True
+except ImportError:
+ has_nest_asyncio = False
+try:
+ import uvloop
+ has_uvloop = True
+except ImportError:
+ has_uvloop = False
+
# Set Windows event loop policy for better compatibility with asyncio and curl_cffi
if sys.platform == 'win32':
try:
@@ -31,18 +44,14 @@ def get_running_loop(check_nested: bool) -> Union[AbstractEventLoop, None]:
try:
loop = asyncio.get_running_loop()
# Do not patch uvloop loop because its incompatible.
- try:
- import uvloop
+ if has_uvloop:
if isinstance(loop, uvloop.Loop):
- return loop
- except (ImportError, ModuleNotFoundError):
- pass
- if check_nested and not hasattr(loop.__class__, "_nest_patched"):
- try:
- import nest_asyncio
+ return loop
+ if not hasattr(loop.__class__, "_nest_patched"):
+ if has_nest_asyncio:
nest_asyncio.apply(loop)
- except ImportError:
- raise NestAsyncioError('Install "nest_asyncio" package')
+ elif check_nested:
+ raise NestAsyncioError('Install "nest_asyncio" package | pip install -U nest_asyncio')
return loop
except RuntimeError:
pass
@@ -154,7 +163,7 @@ class AsyncProvider(AbstractProvider):
Returns:
CreateResult: The result of the completion creation.
"""
- get_running_loop(check_nested=True)
+ get_running_loop(check_nested=False)
yield asyncio.run(cls.create_async(model, messages, **kwargs))
@staticmethod
@@ -208,7 +217,7 @@ class AsyncGeneratorProvider(AsyncProvider):
Returns:
CreateResult: The result of the streaming completion creation.
"""
- loop = get_running_loop(check_nested=True)
+ loop = get_running_loop(check_nested=False)
new_loop = False
if loop is None:
loop = asyncio.new_event_loop()
@@ -222,7 +231,7 @@ class AsyncGeneratorProvider(AsyncProvider):
while True:
yield loop.run_until_complete(await_callback(gen.__anext__))
except StopAsyncIteration:
- ...
+ pass
finally:
if new_loop:
loop.close()