summaryrefslogtreecommitdiffstats
path: root/g4f/providers
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/providers')
-rw-r--r--g4f/providers/retry_provider.py114
1 files changed, 72 insertions, 42 deletions
diff --git a/g4f/providers/retry_provider.py b/g4f/providers/retry_provider.py
index 52f473e9..d64e8471 100644
--- a/g4f/providers/retry_provider.py
+++ b/g4f/providers/retry_provider.py
@@ -12,46 +12,40 @@ class RetryProvider(BaseRetryProvider):
def __init__(
self,
providers: List[Type[BaseProvider]],
- shuffle: bool = True
+ shuffle: bool = True,
+ single_provider_retry: bool = False,
+ max_retries: int = 3,
) -> None:
"""
Initialize the BaseRetryProvider.
-
Args:
providers (List[Type[BaseProvider]]): List of providers to use.
shuffle (bool): Whether to shuffle the providers list.
+ single_provider_retry (bool): Whether to retry a single provider if it fails.
+ max_retries (int): Maximum number of retries for a single provider.
"""
self.providers = providers
self.shuffle = shuffle
+ self.single_provider_retry = single_provider_retry
+ self.max_retries = max_retries
self.working = True
self.last_provider: Type[BaseProvider] = None
- """
- A provider class to handle retries for creating completions with different providers.
-
- Attributes:
- providers (list): A list of provider instances.
- shuffle (bool): A flag indicating whether to shuffle providers before use.
- last_provider (BaseProvider): The last provider that was used.
- """
def create_completion(
self,
model: str,
messages: Messages,
stream: bool = False,
- **kwargs
+ **kwargs,
) -> CreateResult:
"""
Create a completion using available providers, with an option to stream the response.
-
Args:
model (str): The model to be used for completion.
messages (Messages): The messages to be used for generating completion.
stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
-
Yields:
CreateResult: Tokens or results from the completion.
-
Raises:
Exception: Any exception encountered during the completion process.
"""
@@ -61,22 +55,42 @@ class RetryProvider(BaseRetryProvider):
exceptions = {}
started: bool = False
- for provider in providers:
+
+ if self.single_provider_retry and len(providers) == 1:
+ provider = providers[0]
self.last_provider = provider
- try:
- if debug.logging:
- print(f"Using {provider.__name__} provider")
- for token in provider.create_completion(model, messages, stream, **kwargs):
- yield token
+ for attempt in range(self.max_retries):
+ try:
+ if debug.logging:
+ print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
+ for token in provider.create_completion(model, messages, stream, **kwargs):
+ yield token
started = True
- if started:
- return
- except Exception as e:
- exceptions[provider.__name__] = e
- if debug.logging:
- print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
- if started:
- raise e
+ if started:
+ return
+ except Exception as e:
+ exceptions[provider.__name__] = e
+ if debug.logging:
+ print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+ if started:
+ raise e
+ else:
+ for provider in providers:
+ self.last_provider = provider
+ try:
+ if debug.logging:
+ print(f"Using {provider.__name__} provider")
+ for token in provider.create_completion(model, messages, stream, **kwargs):
+ yield token
+ started = True
+ if started:
+ return
+ except Exception as e:
+ exceptions[provider.__name__] = e
+ if debug.logging:
+ print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+ if started:
+ raise e
raise_exceptions(exceptions)
@@ -84,18 +98,15 @@ class RetryProvider(BaseRetryProvider):
self,
model: str,
messages: Messages,
- **kwargs
+ **kwargs,
) -> str:
"""
Asynchronously create a completion using available providers.
-
Args:
model (str): The model to be used for completion.
messages (Messages): The messages to be used for generating completion.
-
Returns:
str: The result of the asynchronous completion.
-
Raises:
Exception: Any exception encountered during the asynchronous completion process.
"""
@@ -104,17 +115,36 @@ class RetryProvider(BaseRetryProvider):
random.shuffle(providers)
exceptions = {}
- for provider in providers:
+
+ if self.single_provider_retry and len(providers) == 1:
+ provider = providers[0]
self.last_provider = provider
- try:
- return await asyncio.wait_for(
- provider.create_async(model, messages, **kwargs),
- timeout=kwargs.get("timeout", 60)
- )
- except Exception as e:
- exceptions[provider.__name__] = e
- if debug.logging:
- print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+ for attempt in range(self.max_retries):
+ try:
+ if debug.logging:
+ print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
+ return await asyncio.wait_for(
+ provider.create_async(model, messages, **kwargs),
+ timeout=kwargs.get("timeout", 60),
+ )
+ except Exception as e:
+ exceptions[provider.__name__] = e
+ if debug.logging:
+ print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
+ else:
+ for provider in providers:
+ self.last_provider = provider
+ try:
+ if debug.logging:
+ print(f"Using {provider.__name__} provider")
+ return await asyncio.wait_for(
+ provider.create_async(model, messages, **kwargs),
+ timeout=kwargs.get("timeout", 60),
+ )
+ except Exception as e:
+ exceptions[provider.__name__] = e
+ if debug.logging:
+ print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
raise_exceptions(exceptions)