summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/base_provider.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/base_provider.py22
1 files changed, 13 insertions, 9 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index 9d45aa44..1e2d4c64 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -47,9 +47,11 @@ class AsyncProvider(BaseProvider):
stream: bool = False,
**kwargs
) -> CreateResult:
- check_running_loop()
-
- yield asyncio.run(cls.create_async(model, messages, **kwargs))
+ loop = create_event_loop()
+ try:
+ yield loop.run_until_complete(cls.create_async(model, messages, **kwargs))
+ finally:
+ loop.close()
@staticmethod
@abstractmethod
@@ -70,10 +72,7 @@ class AsyncGeneratorProvider(AsyncProvider):
stream: bool = True,
**kwargs
) -> CreateResult:
- check_running_loop()
-
- # Force use selector event loop on windows
- loop = asyncio.SelectorEventLoop()
+ loop = get_new_event_loop()
try:
generator = cls.create_async_generator(
model,
@@ -108,12 +107,17 @@ class AsyncGeneratorProvider(AsyncProvider):
) -> AsyncGenerator:
raise NotImplementedError()
-# Don't create a new loop in a running loop
-def check_running_loop():
+
+def create_event_loop():
+ # Don't create a new loop in a running loop
if asyncio.events._get_running_loop() is not None:
raise RuntimeError(
'Use "create_async" instead of "create" function in a async loop.')
+ # Force use selector event loop on windows
+ return asyncio.SelectorEventLoop()
+
+
_cookies = {}
def get_cookies(cookie_domain: str) -> dict: