summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHeiner Lohaus <hlohaus@users.noreply.github.com>2024-05-15 02:27:51 +0200
committerHeiner Lohaus <hlohaus@users.noreply.github.com>2024-05-15 02:27:51 +0200
commit59fcf9d2d3be66c5988731f8e8ffa707d01c6539 (patch)
tree5329c60685eca0a15d86ff1d3193fee25eb864ca
parentMerge pull request #1934 from krishna2206/main (diff)
downloadgpt4free-59fcf9d2d3be66c5988731f8e8ffa707d01c6539.tar
gpt4free-59fcf9d2d3be66c5988731f8e8ffa707d01c6539.tar.gz
gpt4free-59fcf9d2d3be66c5988731f8e8ffa707d01c6539.tar.bz2
gpt4free-59fcf9d2d3be66c5988731f8e8ffa707d01c6539.tar.lz
gpt4free-59fcf9d2d3be66c5988731f8e8ffa707d01c6539.tar.xz
gpt4free-59fcf9d2d3be66c5988731f8e8ffa707d01c6539.tar.zst
gpt4free-59fcf9d2d3be66c5988731f8e8ffa707d01c6539.zip
-rw-r--r--g4f/Provider/needs_auth/Gemini.py25
-rw-r--r--g4f/Provider/needs_auth/OpenaiChat.py15
-rw-r--r--g4f/Provider/openai/har_file.py2
-rw-r--r--g4f/Provider/openai/proofofwork.py7
-rw-r--r--g4f/client/async_client.py11
-rw-r--r--g4f/gui/client/index.html6
-rw-r--r--g4f/gui/client/static/css/dracula.min.css7
-rw-r--r--g4f/gui/client/static/css/style.css3
-rw-r--r--g4f/gui/client/static/js/chat.v1.js46
-rw-r--r--g4f/providers/retry_provider.py186
10 files changed, 222 insertions, 86 deletions
diff --git a/g4f/Provider/needs_auth/Gemini.py b/g4f/Provider/needs_auth/Gemini.py
index 75cdd199..25ad1c6e 100644
--- a/g4f/Provider/needs_auth/Gemini.py
+++ b/g4f/Provider/needs_auth/Gemini.py
@@ -17,12 +17,12 @@ except ImportError:
pass
from ... import debug
-from ...typing import Messages, Cookies, ImageType, AsyncResult
+from ...typing import Messages, Cookies, ImageType, AsyncResult, AsyncIterator
from ..base_provider import AsyncGeneratorProvider
from ..helper import format_prompt, get_cookies
from ...requests.raise_for_status import raise_for_status
from ...errors import MissingAuthError, MissingRequirementsError
-from ...image import to_bytes, ImageResponse
+from ...image import to_bytes, to_data_uri, ImageResponse
from ...webdriver import get_browser, get_driver_cookies
REQUEST_HEADERS = {
@@ -59,7 +59,7 @@ class Gemini(AsyncGeneratorProvider):
_cookies: Cookies = None
@classmethod
- async def nodriver_login(cls) -> Cookies:
+ async def nodriver_login(cls) -> AsyncIterator[str]:
try:
import nodriver as uc
except ImportError:
@@ -72,6 +72,9 @@ class Gemini(AsyncGeneratorProvider):
if debug.logging:
print(f"Open nodriver with user_dir: {user_data_dir}")
browser = await uc.start(user_data_dir=user_data_dir)
+ login_url = os.environ.get("G4F_LOGIN_URL")
+ if login_url:
+ yield f"Please login: [Google Gemini]({login_url})\n\n"
page = await browser.get(f"{cls.url}/app")
await page.select("div.ql-editor.textarea", 240)
cookies = {}
@@ -79,10 +82,10 @@ class Gemini(AsyncGeneratorProvider):
if c.domain.endswith(".google.com"):
cookies[c.name] = c.value
await page.close()
- return cookies
+ cls._cookies = cookies
@classmethod
- async def webdriver_login(cls, proxy: str):
+ async def webdriver_login(cls, proxy: str) -> AsyncIterator[str]:
driver = None
try:
driver = get_browser(proxy=proxy)
@@ -131,13 +134,14 @@ class Gemini(AsyncGeneratorProvider):
) as session:
snlm0e = await cls.fetch_snlm0e(session, cls._cookies) if cls._cookies else None
if not snlm0e:
- cls._cookies = await cls.nodriver_login();
+ async for chunk in cls.nodriver_login():
+ yield chunk
if cls._cookies is None:
async for chunk in cls.webdriver_login(proxy):
yield chunk
if not snlm0e:
- if "__Secure-1PSID" not in cls._cookies:
+ if cls._cookies is None or "__Secure-1PSID" not in cls._cookies:
raise MissingAuthError('Missing "__Secure-1PSID" cookie')
snlm0e = await cls.fetch_snlm0e(session, cls._cookies)
if not snlm0e:
@@ -193,6 +197,13 @@ class Gemini(AsyncGeneratorProvider):
image = fetch.headers["location"]
resolved_images.append(image)
preview.append(image.replace('=s512', '=s200'))
+ # preview_url = image.replace('=s512', '=s200')
+ # async with client.get(preview_url) as fetch:
+ # preview_data = to_data_uri(await fetch.content.read())
+ # async with client.get(image) as fetch:
+ # data = to_data_uri(await fetch.content.read())
+ # preview.append(preview_data)
+ # resolved_images.append(data)
yield ImageResponse(resolved_images, image_prompt, {"orginal_links": images, "preview": preview})
def build_request(
diff --git a/g4f/Provider/needs_auth/OpenaiChat.py b/g4f/Provider/needs_auth/OpenaiChat.py
index 056a3702..03ea4539 100644
--- a/g4f/Provider/needs_auth/OpenaiChat.py
+++ b/g4f/Provider/needs_auth/OpenaiChat.py
@@ -38,7 +38,7 @@ DEFAULT_HEADERS = {
"accept": "*/*",
"accept-encoding": "gzip, deflate, br, zstd",
"accept-language": "en-US,en;q=0.5",
- "referer": "https://chat.openai.com/",
+ "referer": "https://chatgpt.com/",
"sec-ch-ua": "\"Brave\";v=\"123\", \"Not:A-Brand\";v=\"8\", \"Chromium\";v=\"123\"",
"sec-ch-ua-mobile": "?0",
"sec-ch-ua-platform": "\"Windows\"",
@@ -53,15 +53,15 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
"""A class for creating and managing conversations with OpenAI chat service"""
label = "OpenAI ChatGPT"
- url = "https://chat.openai.com"
+ url = "https://chatgpt.com"
working = True
supports_gpt_35_turbo = True
supports_gpt_4 = True
supports_message_history = True
supports_system_message = True
default_model = None
- default_vision_model = "gpt-4-vision"
- models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo"]
+ default_vision_model = "gpt-4o"
+ models = ["gpt-3.5-turbo", "gpt-4", "gpt-4-gizmo", "gpt-4o"]
model_aliases = {
"text-davinci-002-render-sha": "gpt-3.5-turbo",
"": "gpt-3.5-turbo",
@@ -442,6 +442,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
try:
image_request = await cls.upload_image(session, cls._headers, image, image_name) if image else None
except Exception as e:
+ image_request = None
if debug.logging:
print("OpenaiChat: Upload image failed")
print(f"{e.__class__.__name__}: {e}")
@@ -601,7 +602,7 @@ class OpenaiChat(AsyncGeneratorProvider, ProviderModelMixin):
this._fetch = this.fetch;
this.fetch = async (url, options) => {
const response = await this._fetch(url, options);
- if (url == "https://chat.openai.com/backend-api/conversation") {
+ if (url == "https://chatgpt.com/backend-api/conversation") {
this._headers = options.headers;
return response;
}
@@ -637,7 +638,7 @@ this.fetch = async (url, options) => {
if debug.logging:
print(f"Open nodriver with user_dir: {user_data_dir}")
browser = await uc.start(user_data_dir=user_data_dir)
- page = await browser.get("https://chat.openai.com/")
+ page = await browser.get("https://chatgpt.com/")
await page.select("[id^=headlessui-menu-button-]", 240)
api_key = await page.evaluate(
"(async () => {"
@@ -652,7 +653,7 @@ this.fetch = async (url, options) => {
)
cookies = {}
for c in await page.browser.cookies.get_all():
- if c.domain.endswith("chat.openai.com"):
+ if c.domain.endswith("chatgpt.com"):
cookies[c.name] = c.value
user_agent = await page.evaluate("window.navigator.userAgent")
await page.close()
diff --git a/g4f/Provider/openai/har_file.py b/g4f/Provider/openai/har_file.py
index 6a34c97a..220c20bf 100644
--- a/g4f/Provider/openai/har_file.py
+++ b/g4f/Provider/openai/har_file.py
@@ -26,7 +26,7 @@ class arkReq:
self.userAgent = userAgent
arkPreURL = "https://tcr9i.chat.openai.com/fc/gt2/public_key/35536E1E-65B4-4D96-9D97-6ADB7EFF8147"
-sessionUrl = "https://chat.openai.com/api/auth/session"
+sessionUrl = "https://chatgpt.com/api/auth/session"
chatArk: arkReq = None
accessToken: str = None
cookies: dict = None
diff --git a/g4f/Provider/openai/proofofwork.py b/g4f/Provider/openai/proofofwork.py
index e44ef6f7..51d96bc4 100644
--- a/g4f/Provider/openai/proofofwork.py
+++ b/g4f/Provider/openai/proofofwork.py
@@ -16,12 +16,9 @@ def generate_proof_token(required: bool, seed: str, difficulty: str, user_agent:
# Get current UTC time
now_utc = datetime.now(timezone.utc)
- # Convert UTC time to Eastern Time
- now_et = now_utc.astimezone(timezone(timedelta(hours=-5)))
+ parse_time = now_utc.strftime('%a, %d %b %Y %H:%M:%S GMT')
- parse_time = now_et.strftime('%a, %d %b %Y %H:%M:%S GMT')
-
- config = [core + screen, parse_time, 4294705152, 0, user_agent]
+ config = [core + screen, parse_time, None, 0, user_agent, "https://tcr9i.chat.openai.com/v2/35536E1E-65B4-4D96-9D97-6ADB7EFF8147/api.js","dpl=53d243de46ff04dadd88d293f088c2dd728f126f","en","en-US",442,"pluginsāˆ’[object PluginArray]","","alert"]
diff_len = len(difficulty) // 2
diff --git a/g4f/client/async_client.py b/g4f/client/async_client.py
index 8e1ee33c..07ad3357 100644
--- a/g4f/client/async_client.py
+++ b/g4f/client/async_client.py
@@ -11,10 +11,9 @@ from .types import AsyncIterResponse, ImageProvider
from .image_models import ImageModels
from .helper import filter_json, find_stop, filter_none, cast_iter_async
from .service import get_last_provider, get_model_and_provider
-from ..typing import Union, Iterator, Messages, AsyncIterator, ImageType
+from ..typing import Union, Messages, AsyncIterator, ImageType
from ..errors import NoImageResponseError
from ..image import ImageResponse as ImageProviderResponse
-from ..providers.base_provider import AsyncGeneratorProvider
try:
anext
@@ -88,7 +87,7 @@ def create_response(
api_key: str = None,
**kwargs
):
- has_asnyc = isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider)
+ has_asnyc = hasattr(provider, "create_async_generator")
if has_asnyc:
create = provider.create_async_generator
else:
@@ -157,7 +156,7 @@ class Chat():
def __init__(self, client: AsyncClient, provider: ProviderType = None):
self.completions = Completions(client, provider)
-async def iter_image_response(response: Iterator) -> Union[ImagesResponse, None]:
+async def iter_image_response(response: AsyncIterator) -> Union[ImagesResponse, None]:
async for chunk in response:
if isinstance(chunk, ImageProviderResponse):
return ImagesResponse([Image(image) for image in chunk.get_list()])
@@ -182,7 +181,7 @@ class Images():
async def generate(self, prompt, model: str = "", **kwargs) -> ImagesResponse:
provider = self.models.get(model, self.provider)
- if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
+ if hasattr(provider, "create_async_generator"):
response = create_image(self.client, provider, prompt, **kwargs)
else:
response = await provider.create_async(prompt)
@@ -195,7 +194,7 @@ class Images():
async def create_variation(self, image: ImageType, model: str = None, **kwargs):
provider = self.models.get(model, self.provider)
result = None
- if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
+ if hasattr(provider, "create_async_generator"):
response = provider.create_async_generator(
"",
[{"role": "user", "content": "create a image like this"}],
diff --git a/g4f/gui/client/index.html b/g4f/gui/client/index.html
index 66bcaaab..064e4594 100644
--- a/g4f/gui/client/index.html
+++ b/g4f/gui/client/index.html
@@ -19,8 +19,7 @@
<script src="/static/js/highlightjs-copy.min.js"></script>
<script src="/static/js/chat.v1.js" defer></script>
<script src="https://cdn.jsdelivr.net/npm/markdown-it@13.0.1/dist/markdown-it.min.js"></script>
- <link rel="stylesheet"
- href="//cdn.jsdelivr.net/gh/highlightjs/cdn-release@11.7.0/build/styles/base16/dracula.min.css">
+ <link rel="stylesheet" href="/static/css/dracula.min.css">
<script>
MathJax = {
chtml: {
@@ -244,8 +243,5 @@
<div class="mobile-sidebar">
<i class="fa-solid fa-bars"></i>
</div>
- <script>
- </script>
</body>
-
</html>
diff --git a/g4f/gui/client/static/css/dracula.min.css b/g4f/gui/client/static/css/dracula.min.css
new file mode 100644
index 00000000..729bbbfb
--- /dev/null
+++ b/g4f/gui/client/static/css/dracula.min.css
@@ -0,0 +1,7 @@
+/*!
+ Theme: Dracula
+ Author: Mike Barkmin (http://github.com/mikebarkmin) based on Dracula Theme (http://github.com/dracula)
+ License: ~ MIT (or more permissive) [via base16-schemes-source]
+ Maintainer: @highlightjs/core-team
+ Version: 2021.09.0
+*/pre code.hljs{display:block;overflow-x:auto;padding:1em}code.hljs{padding:3px 5px}.hljs{color:#e9e9f4;background:#282936}.hljs ::selection,.hljs::selection{background-color:#4d4f68;color:#e9e9f4}.hljs-comment{color:#626483}.hljs-tag{color:#62d6e8}.hljs-operator,.hljs-punctuation,.hljs-subst{color:#e9e9f4}.hljs-operator{opacity:.7}.hljs-bullet,.hljs-deletion,.hljs-name,.hljs-selector-tag,.hljs-template-variable,.hljs-variable{color:#ea51b2}.hljs-attr,.hljs-link,.hljs-literal,.hljs-number,.hljs-symbol,.hljs-variable.constant_{color:#b45bcf}.hljs-class .hljs-title,.hljs-title,.hljs-title.class_{color:#00f769}.hljs-strong{font-weight:700;color:#00f769}.hljs-addition,.hljs-code,.hljs-string,.hljs-title.class_.inherited__{color:#ebff87}.hljs-built_in,.hljs-doctag,.hljs-keyword.hljs-atrule,.hljs-quote,.hljs-regexp{color:#a1efe4}.hljs-attribute,.hljs-function .hljs-title,.hljs-section,.hljs-title.function_,.ruby .hljs-property{color:#62d6e8}.diff .hljs-meta,.hljs-keyword,.hljs-template-tag,.hljs-type{color:#b45bcf}.hljs-emphasis{color:#b45bcf;font-style:italic}.hljs-meta,.hljs-meta .hljs-keyword,.hljs-meta .hljs-string{color:#00f769}.hljs-meta .hljs-keyword,.hljs-meta-keyword{font-weight:700} \ No newline at end of file
diff --git a/g4f/gui/client/static/css/style.css b/g4f/gui/client/static/css/style.css
index 979f9f96..01bc17fa 100644
--- a/g4f/gui/client/static/css/style.css
+++ b/g4f/gui/client/static/css/style.css
@@ -381,7 +381,8 @@ body {
}
.message .count .fa-clipboard,
-.message .count .fa-volume-high {
+.message .count .fa-volume-high,
+.message .count .fa-rotate {
z-index: 1000;
cursor: pointer;
}
diff --git a/g4f/gui/client/static/js/chat.v1.js b/g4f/gui/client/static/js/chat.v1.js
index 23605ed4..a0178e63 100644
--- a/g4f/gui/client/static/js/chat.v1.js
+++ b/g4f/gui/client/static/js/chat.v1.js
@@ -109,8 +109,9 @@ const register_message_buttons = async () => {
let playlist = [];
function play_next() {
const next = playlist.shift();
- if (next)
+ if (next && el.dataset.do_play) {
next.play();
+ }
}
if (el.dataset.stopped) {
el.classList.remove("blink")
@@ -179,6 +180,20 @@ const register_message_buttons = async () => {
});
}
});
+ document.querySelectorAll(".message .fa-rotate").forEach(async (el) => {
+ if (!("click" in el.dataset)) {
+ el.dataset.click = "true";
+ el.addEventListener("click", async () => {
+ const message_el = el.parentElement.parentElement.parentElement;
+ el.classList.add("clicked");
+ setTimeout(() => el.classList.remove("clicked"), 1000);
+ prompt_lock = true;
+ await hide_message(window.conversation_id, message_el.dataset.index);
+ window.token = message_id();
+ await ask_gpt(message_el.dataset.index);
+ })
+ }
+ });
}
const delete_conversations = async () => {
@@ -257,9 +272,9 @@ const remove_cancel_button = async () => {
}, 300);
};
-const prepare_messages = (messages, filter_last_message=true) => {
+const prepare_messages = (messages, message_index = -1) => {
// Removes none user messages at end
- if (filter_last_message) {
+ if (message_index == -1) {
let last_message;
while (last_message = messages.pop()) {
if (last_message["role"] == "user") {
@@ -267,14 +282,16 @@ const prepare_messages = (messages, filter_last_message=true) => {
break;
}
}
+ } else if (message_index >= 0) {
+ messages = messages.filter((_, index) => message_index >= index);
}
// Remove history, if it's selected
if (document.getElementById('history')?.checked) {
- if (filter_last_message) {
- messages = [messages.pop()];
- } else {
+ if (message_index == null) {
messages = [messages.pop(), messages.pop()];
+ } else {
+ messages = [messages.pop()];
}
}
@@ -361,11 +378,11 @@ imageInput?.addEventListener("click", (e) => {
}
});
-const ask_gpt = async () => {
+const ask_gpt = async (message_index = -1) => {
regenerate.classList.add(`regenerate-hidden`);
messages = await get_messages(window.conversation_id);
total_messages = messages.length;
- messages = prepare_messages(messages);
+ messages = prepare_messages(messages, message_index);
stop_generating.classList.remove(`stop_generating-hidden`);
@@ -528,6 +545,7 @@ const hide_option = async (conversation_id) => {
const span_el = document.createElement("span");
span_el.innerText = input_el.value;
span_el.classList.add("convo-title");
+ span_el.onclick = () => set_conversation(conversation_id);
left_el.removeChild(input_el);
left_el.appendChild(span_el);
}
@@ -616,7 +634,7 @@ const load_conversation = async (conversation_id, scroll=true) => {
}
if (window.GPTTokenizer_cl100k_base) {
- const filtered = prepare_messages(messages, false);
+ const filtered = prepare_messages(messages, null);
if (filtered.length > 0) {
last_model = last_model?.startsWith("gpt-4") ? "gpt-4" : "gpt-3.5-turbo"
let count_total = GPTTokenizer_cl100k_base?.encodeChat(filtered, last_model).length
@@ -683,15 +701,15 @@ async function save_system_message() {
await save_conversation(window.conversation_id, conversation);
}
}
-
-const hide_last_message = async (conversation_id) => {
+const hide_message = async (conversation_id, message_index =- 1) => {
const conversation = await get_conversation(conversation_id)
- const last_message = conversation.items.pop();
+ message_index = message_index == -1 ? conversation.items.length - 1 : message_index
+ const last_message = message_index in conversation.items ? conversation.items[message_index] : null;
if (last_message !== null) {
if (last_message["role"] == "assistant") {
last_message["regenerate"] = true;
}
- conversation.items.push(last_message);
+ conversation.items[message_index] = last_message;
}
await save_conversation(conversation_id, conversation);
};
@@ -790,7 +808,7 @@ document.getElementById("cancelButton").addEventListener("click", async () => {
document.getElementById("regenerateButton").addEventListener("click", async () => {
prompt_lock = true;
- await hide_last_message(window.conversation_id);
+ await hide_message(window.conversation_id);
window.token = message_id();
await ask_gpt();
});
diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py
index d64e8471..e2520437 100644
--- a/g4f/providers/retry_provider.py
+++ b/g4f/providers/retry_provider.py
@@ -3,18 +3,16 @@ from __future__ import annotations
import asyncio
import random
-from ..typing import Type, List, CreateResult, Messages, Iterator
+from ..typing import Type, List, CreateResult, Messages, Iterator, AsyncResult
from .types import BaseProvider, BaseRetryProvider
from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError
-class RetryProvider(BaseRetryProvider):
+class NewBaseRetryProvider(BaseRetryProvider):
def __init__(
self,
providers: List[Type[BaseProvider]],
- shuffle: bool = True,
- single_provider_retry: bool = False,
- max_retries: int = 3,
+ shuffle: bool = True
) -> None:
"""
Initialize the BaseRetryProvider.
@@ -26,8 +24,6 @@ class RetryProvider(BaseRetryProvider):
"""
self.providers = providers
self.shuffle = shuffle
- self.single_provider_retry = single_provider_retry
- self.max_retries = max_retries
self.working = True
self.last_provider: Type[BaseProvider] = None
@@ -56,7 +52,146 @@ class RetryProvider(BaseRetryProvider):
exceptions = {}
started: bool = False
+ for provider in providers:
+ self.last_provider = provider
+ try:
+ if debug.logging:
+ print(f"Using {provider.__name__} provider")
+ for token in provider.create_completion(model, messages, stream, **kwargs):
+ yield token
+ started = True
+ if started:
+ return
+ except Exception as e:
+ exceptions[provider.__name__] = e
+ if debug.logging:
+ print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+ if started:
+ raise e
+
+ raise_exceptions(exceptions)
+
+ async def create_async(
+ self,
+ model: str,
+ messages: Messages,
+ **kwargs,
+ ) -> str:
+ """
+ Asynchronously create a completion using available providers.
+ Args:
+ model (str): The model to be used for completion.
+ messages (Messages): The messages to be used for generating completion.
+ Returns:
+ str: The result of the asynchronous completion.
+ Raises:
+ Exception: Any exception encountered during the asynchronous completion process.
+ """
+ providers = self.providers
+ if self.shuffle:
+ random.shuffle(providers)
+
+ exceptions = {}
+
+ for provider in providers:
+ self.last_provider = provider
+ try:
+ if debug.logging:
+ print(f"Using {provider.__name__} provider")
+ return await asyncio.wait_for(
+ provider.create_async(model, messages, **kwargs),
+ timeout=kwargs.get("timeout", 60),
+ )
+ except Exception as e:
+ exceptions[provider.__name__] = e
+ if debug.logging:
+ print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+
+ raise_exceptions(exceptions)
+
+ def get_providers(self, stream: bool):
+ providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers
+ if self.shuffle:
+ random.shuffle(providers)
+ return providers
+
+ async def create_async_generator(
+ self,
+ model: str,
+ messages: Messages,
+ stream: bool = True,
+ **kwargs
+ ) -> AsyncResult:
+ exceptions = {}
+ started: bool = False
+
+ for provider in self.get_providers(stream):
+ self.last_provider = provider
+ try:
+ if debug.logging:
+ print(f"Using {provider.__name__} provider")
+ if not stream:
+ yield await provider.create_async(model, messages, **kwargs)
+ elif hasattr(provider, "create_async_generator"):
+ async for token in provider.create_async_generator(model, messages, stream, **kwargs):
+ yield token
+ else:
+ for token in provider.create_completion(model, messages, stream, **kwargs):
+ yield token
+ started = True
+ if started:
+ return
+ except Exception as e:
+ exceptions[provider.__name__] = e
+ if debug.logging:
+ print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+ if started:
+ raise e
+
+ raise_exceptions(exceptions)
+
+class RetryProvider(NewBaseRetryProvider):
+ def __init__(
+ self,
+ providers: List[Type[BaseProvider]],
+ shuffle: bool = True,
+ single_provider_retry: bool = False,
+ max_retries: int = 3,
+ ) -> None:
+ """
+ Initialize the BaseRetryProvider.
+ Args:
+ providers (List[Type[BaseProvider]]): List of providers to use.
+ shuffle (bool): Whether to shuffle the providers list.
+ single_provider_retry (bool): Whether to retry a single provider if it fails.
+ max_retries (int): Maximum number of retries for a single provider.
+ """
+ super().__init__(providers, shuffle)
+ self.single_provider_retry = single_provider_retry
+ self.max_retries = max_retries
+
+ def create_completion(
+ self,
+ model: str,
+ messages: Messages,
+ stream: bool = False,
+ **kwargs,
+ ) -> CreateResult:
+ """
+ Create a completion using available providers, with an option to stream the response.
+ Args:
+ model (str): The model to be used for completion.
+ messages (Messages): The messages to be used for generating completion.
+ stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
+ Yields:
+ CreateResult: Tokens or results from the completion.
+ Raises:
+ Exception: Any exception encountered during the completion process.
+ """
+ providers = self.get_providers(stream)
if self.single_provider_retry and len(providers) == 1:
+ exceptions = {}
+ started: bool = False
provider = providers[0]
self.last_provider = provider
for attempt in range(self.max_retries):
@@ -74,25 +209,9 @@ class RetryProvider(BaseRetryProvider):
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
if started:
raise e
+ raise_exceptions(exceptions)
else:
- for provider in providers:
- self.last_provider = provider
- try:
- if debug.logging:
- print(f"Using {provider.__name__} provider")
- for token in provider.create_completion(model, messages, stream, **kwargs):
- yield token
- started = True
- if started:
- return
- except Exception as e:
- exceptions[provider.__name__] = e
- if debug.logging:
- print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
- if started:
- raise e
-
- raise_exceptions(exceptions)
+ yield from super().create_completion(model, messages, stream, **kwargs)
async def create_async(
self,
@@ -131,22 +250,9 @@ class RetryProvider(BaseRetryProvider):
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+ raise_exceptions(exceptions)
else:
- for provider in providers:
- self.last_provider = provider
- try:
- if debug.logging:
- print(f"Using {provider.__name__} provider")
- return await asyncio.wait_for(
- provider.create_async(model, messages, **kwargs),
- timeout=kwargs.get("timeout", 60),
- )
- except Exception as e:
- exceptions[provider.__name__] = e
- if debug.logging:
- print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
-
- raise_exceptions(exceptions)
+ return await super().create_async(model, messages, **kwargs)
class IterProvider(BaseRetryProvider):
__name__ = "IterProvider"