summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/base_provider.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/Provider/base_provider.py36
1 files changed, 23 insertions, 13 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index d5f23931..def2cd6d 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -4,8 +4,7 @@ from ..typing import Any, CreateResult, AsyncGenerator, Union
import browser_cookie3
import asyncio
-from time import time
-import math
+
class BaseProvider(ABC):
url: str
@@ -48,6 +47,17 @@ def get_cookies(cookie_domain: str) -> dict:
return _cookies[cookie_domain]
+def format_prompt(messages: list[dict[str, str]], add_special_tokens=False):
+ if add_special_tokens or len(messages) > 1:
+ formatted = "\n".join(
+ ["%s: %s" % ((message["role"]).capitalize(), message["content"]) for message in messages]
+ )
+ return f"{formatted}\nAssistant:"
+ else:
+ return messages.pop()["content"]
+
+
+
class AsyncProvider(BaseProvider):
@classmethod
def create_completion(
@@ -72,20 +82,19 @@ class AsyncGeneratorProvider(AsyncProvider):
cls,
model: str,
messages: list[dict[str, str]],
- stream: bool = True, **kwargs: Any) -> CreateResult:
-
- if stream:
- yield from run_generator(cls.create_async_generator(model, messages, **kwargs))
- else:
- yield from AsyncProvider.create_completion(cls=cls, model=model, messages=messages, **kwargs)
+ stream: bool = True,
+ **kwargs
+ ) -> CreateResult:
+ yield from run_generator(cls.create_async_generator(model, messages, stream=stream, **kwargs))
@classmethod
async def create_async(
cls,
model: str,
- messages: list[dict[str, str]], **kwargs: Any) -> str:
-
- chunks = [chunk async for chunk in cls.create_async_generator(model, messages, **kwargs)]
+ messages: list[dict[str, str]],
+ **kwargs
+ ) -> str:
+ chunks = [chunk async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)]
if chunks:
return "".join(chunks)
@@ -93,8 +102,9 @@ class AsyncGeneratorProvider(AsyncProvider):
@abstractmethod
def create_async_generator(
model: str,
- messages: list[dict[str, str]]) -> AsyncGenerator:
-
+ messages: list[dict[str, str]],
+ **kwargs
+ ) -> AsyncGenerator:
raise NotImplementedError()