summaryrefslogtreecommitdiffstats
path: root/g4f/providers/retry_provider.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/providers/retry_provider.py186
1 files changed, 146 insertions, 40 deletions
diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py
index d64e8471..e2520437 100644
--- a/g4f/providers/retry_provider.py
+++ b/g4f/providers/retry_provider.py
@@ -3,18 +3,16 @@ from __future__ import annotations
import asyncio
import random
-from ..typing import Type, List, CreateResult, Messages, Iterator
+from ..typing import Type, List, CreateResult, Messages, Iterator, AsyncResult
from .types import BaseProvider, BaseRetryProvider
from .. import debug
from ..errors import RetryProviderError, RetryNoProviderError
-class RetryProvider(BaseRetryProvider):
+class NewBaseRetryProvider(BaseRetryProvider):
def __init__(
self,
providers: List[Type[BaseProvider]],
- shuffle: bool = True,
- single_provider_retry: bool = False,
- max_retries: int = 3,
+ shuffle: bool = True
) -> None:
"""
Initialize the BaseRetryProvider.
@@ -26,8 +24,6 @@ class RetryProvider(BaseRetryProvider):
"""
self.providers = providers
self.shuffle = shuffle
- self.single_provider_retry = single_provider_retry
- self.max_retries = max_retries
self.working = True
self.last_provider: Type[BaseProvider] = None
@@ -56,7 +52,146 @@ class RetryProvider(BaseRetryProvider):
exceptions = {}
started: bool = False
+ for provider in providers:
+ self.last_provider = provider
+ try:
+ if debug.logging:
+ print(f"Using {provider.__name__} provider")
+ for token in provider.create_completion(model, messages, stream, **kwargs):
+ yield token
+ started = True
+ if started:
+ return
+ except Exception as e:
+ exceptions[provider.__name__] = e
+ if debug.logging:
+ print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+ if started:
+ raise e
+
+ raise_exceptions(exceptions)
+
+ async def create_async(
+ self,
+ model: str,
+ messages: Messages,
+ **kwargs,
+ ) -> str:
+ """
+ Asynchronously create a completion using available providers.
+ Args:
+ model (str): The model to be used for completion.
+ messages (Messages): The messages to be used for generating completion.
+ Returns:
+ str: The result of the asynchronous completion.
+ Raises:
+ Exception: Any exception encountered during the asynchronous completion process.
+ """
+ providers = self.providers
+ if self.shuffle:
+ random.shuffle(providers)
+
+ exceptions = {}
+
+ for provider in providers:
+ self.last_provider = provider
+ try:
+ if debug.logging:
+ print(f"Using {provider.__name__} provider")
+ return await asyncio.wait_for(
+ provider.create_async(model, messages, **kwargs),
+ timeout=kwargs.get("timeout", 60),
+ )
+ except Exception as e:
+ exceptions[provider.__name__] = e
+ if debug.logging:
+ print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+
+ raise_exceptions(exceptions)
+
+ def get_providers(self, stream: bool):
+ providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers
+ if self.shuffle:
+ random.shuffle(providers)
+ return providers
+
+ async def create_async_generator(
+ self,
+ model: str,
+ messages: Messages,
+ stream: bool = True,
+ **kwargs
+ ) -> AsyncResult:
+ exceptions = {}
+ started: bool = False
+
+ for provider in self.get_providers(stream):
+ self.last_provider = provider
+ try:
+ if debug.logging:
+ print(f"Using {provider.__name__} provider")
+ if not stream:
+ yield await provider.create_async(model, messages, **kwargs)
+ elif hasattr(provider, "create_async_generator"):
+ async for token in provider.create_async_generator(model, messages, stream, **kwargs):
+ yield token
+ else:
+ for token in provider.create_completion(model, messages, stream, **kwargs):
+ yield token
+ started = True
+ if started:
+ return
+ except Exception as e:
+ exceptions[provider.__name__] = e
+ if debug.logging:
+ print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+ if started:
+ raise e
+
+ raise_exceptions(exceptions)
+
+class RetryProvider(NewBaseRetryProvider):
+ def __init__(
+ self,
+ providers: List[Type[BaseProvider]],
+ shuffle: bool = True,
+ single_provider_retry: bool = False,
+ max_retries: int = 3,
+ ) -> None:
+ """
+ Initialize the BaseRetryProvider.
+ Args:
+ providers (List[Type[BaseProvider]]): List of providers to use.
+ shuffle (bool): Whether to shuffle the providers list.
+ single_provider_retry (bool): Whether to retry a single provider if it fails.
+ max_retries (int): Maximum number of retries for a single provider.
+ """
+ super().__init__(providers, shuffle)
+ self.single_provider_retry = single_provider_retry
+ self.max_retries = max_retries
+
+ def create_completion(
+ self,
+ model: str,
+ messages: Messages,
+ stream: bool = False,
+ **kwargs,
+ ) -> CreateResult:
+ """
+ Create a completion using available providers, with an option to stream the response.
+ Args:
+ model (str): The model to be used for completion.
+ messages (Messages): The messages to be used for generating completion.
+ stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
+ Yields:
+ CreateResult: Tokens or results from the completion.
+ Raises:
+ Exception: Any exception encountered during the completion process.
+ """
+ providers = self.get_providers(stream)
if self.single_provider_retry and len(providers) == 1:
+ exceptions = {}
+ started: bool = False
provider = providers[0]
self.last_provider = provider
for attempt in range(self.max_retries):
@@ -74,25 +209,9 @@ class RetryProvider(BaseRetryProvider):
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
if started:
raise e
+ raise_exceptions(exceptions)
else:
- for provider in providers:
- self.last_provider = provider
- try:
- if debug.logging:
- print(f"Using {provider.__name__} provider")
- for token in provider.create_completion(model, messages, stream, **kwargs):
- yield token
- started = True
- if started:
- return
- except Exception as e:
- exceptions[provider.__name__] = e
- if debug.logging:
- print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
- if started:
- raise e
-
- raise_exceptions(exceptions)
+ yield from super().create_completion(model, messages, stream, **kwargs)
async def create_async(
self,
@@ -131,22 +250,9 @@ class RetryProvider(BaseRetryProvider):
exceptions[provider.__name__] = e
if debug.logging:
print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+ raise_exceptions(exceptions)
else:
- for provider in providers:
- self.last_provider = provider
- try:
- if debug.logging:
- print(f"Using {provider.__name__} provider")
- return await asyncio.wait_for(
- provider.create_async(model, messages, **kwargs),
- timeout=kwargs.get("timeout", 60),
- )
- except Exception as e:
- exceptions[provider.__name__] = e
- if debug.logging:
- print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
-
- raise_exceptions(exceptions)
+ return await super().create_async(model, messages, **kwargs)
class IterProvider(BaseRetryProvider):
__name__ = "IterProvider"