summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/base_provider.py
diff options
context:
space:
mode:
authorHeiner Lohaus <hlohaus@users.noreply.github.com>2024-01-01 17:48:57 +0100
committerHeiner Lohaus <hlohaus@users.noreply.github.com>2024-01-01 17:48:57 +0100
commitc617b18d12c2f9d82ce7c73aae46d353b83f625a (patch)
tree898f5090865a8aea64fb87e56f9ebfc979a6b706 /g4f/Provider/base_provider.py
parentPatch event loop on win, Check event loop closed (diff)
downloadgpt4free-c617b18d12c2f9d82ce7c73aae46d353b83f625a.tar
gpt4free-c617b18d12c2f9d82ce7c73aae46d353b83f625a.tar.gz
gpt4free-c617b18d12c2f9d82ce7c73aae46d353b83f625a.tar.bz2
gpt4free-c617b18d12c2f9d82ce7c73aae46d353b83f625a.tar.lz
gpt4free-c617b18d12c2f9d82ce7c73aae46d353b83f625a.tar.xz
gpt4free-c617b18d12c2f9d82ce7c73aae46d353b83f625a.tar.zst
gpt4free-c617b18d12c2f9d82ce7c73aae46d353b83f625a.zip
Diffstat (limited to '')
-rw-r--r--g4f/Provider/base_provider.py48
1 files changed, 21 insertions, 27 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index 62029f5d..6da7f6c6 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -1,12 +1,14 @@
from __future__ import annotations
import sys
+import asyncio
from asyncio import AbstractEventLoop
from concurrent.futures import ThreadPoolExecutor
-from abc import ABC, abstractmethod
+from abc import abstractmethod
from inspect import signature, Parameter
from .helper import get_event_loop, get_cookies, format_prompt
-from ..typing import CreateResult, AsyncResult, Messages
+from ..typing import CreateResult, AsyncResult, Messages, Union
+from ..base_provider import BaseProvider
if sys.version_info < (3, 10):
NoneType = type(None)
@@ -20,25 +22,7 @@ if sys.platform == 'win32':
):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
-class BaseProvider(ABC):
- url: str
- working: bool = False
- needs_auth: bool = False
- supports_stream: bool = False
- supports_gpt_35_turbo: bool = False
- supports_gpt_4: bool = False
- supports_message_history: bool = False
-
- @staticmethod
- @abstractmethod
- def create_completion(
- model: str,
- messages: Messages,
- stream: bool,
- **kwargs
- ) -> CreateResult:
- raise NotImplementedError()
-
+class AbstractProvider(BaseProvider):
@classmethod
async def create_async(
cls,
@@ -60,9 +44,12 @@ class BaseProvider(ABC):
**kwargs
))
- return await loop.run_in_executor(
- executor,
- create_func
+ return await asyncio.wait_for(
+ loop.run_in_executor(
+ executor,
+ create_func
+ ),
+ timeout=kwargs.get("timeout", 0)
)
@classmethod
@@ -102,16 +89,19 @@ class BaseProvider(ABC):
return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
-class AsyncProvider(BaseProvider):
+class AsyncProvider(AbstractProvider):
@classmethod
def create_completion(
cls,
model: str,
messages: Messages,
stream: bool = False,
+ *,
+ loop: AbstractEventLoop = None,
**kwargs
) -> CreateResult:
- loop = get_event_loop()
+ if not loop:
+ loop = get_event_loop()
coro = cls.create_async(model, messages, **kwargs)
yield loop.run_until_complete(coro)
@@ -134,9 +124,12 @@ class AsyncGeneratorProvider(AsyncProvider):
model: str,
messages: Messages,
stream: bool = True,
+ *,
+ loop: AbstractEventLoop = None,
**kwargs
) -> CreateResult:
- loop = get_event_loop()
+ if not loop:
+ loop = get_event_loop()
generator = cls.create_async_generator(
model,
messages,
@@ -171,6 +164,7 @@ class AsyncGeneratorProvider(AsyncProvider):
def create_async_generator(
model: str,
messages: Messages,
+ stream: bool = True,
**kwargs
) -> AsyncResult:
raise NotImplementedError()