diff options
author | H Lohaus <hlohaus@users.noreply.github.com> | 2024-01-02 01:10:31 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-02 01:10:31 +0100 |
commit | b1b8ed40a4e8c7c3490b1c6b7cf6b55d0776f366 (patch) | |
tree | 6cd09fb2eb4c144e28a82759a2a9a2fa7f30d311 /g4f/Provider/retry_provider.py | |
parent | Merge pull request #1414 from hlohaus/lia (diff) | |
parent | Fix markdown replace (diff) | |
download | gpt4free-b1b8ed40a4e8c7c3490b1c6b7cf6b55d0776f366.tar gpt4free-b1b8ed40a4e8c7c3490b1c6b7cf6b55d0776f366.tar.gz gpt4free-b1b8ed40a4e8c7c3490b1c6b7cf6b55d0776f366.tar.bz2 gpt4free-b1b8ed40a4e8c7c3490b1c6b7cf6b55d0776f366.tar.lz gpt4free-b1b8ed40a4e8c7c3490b1c6b7cf6b55d0776f366.tar.xz gpt4free-b1b8ed40a4e8c7c3490b1c6b7cf6b55d0776f366.tar.zst gpt4free-b1b8ed40a4e8c7c3490b1c6b7cf6b55d0776f366.zip |
Diffstat (limited to 'g4f/Provider/retry_provider.py')
-rw-r--r-- | g4f/Provider/retry_provider.py | 26 |
1 files changed, 6 insertions, 20 deletions
diff --git a/g4f/Provider/retry_provider.py b/g4f/Provider/retry_provider.py index e49b6da6..4d3e77ac 100644 --- a/g4f/Provider/retry_provider.py +++ b/g4f/Provider/retry_provider.py @@ -2,26 +2,13 @@ from __future__ import annotations import asyncio import random -from typing import List, Type, Dict from ..typing import CreateResult, Messages -from .base_provider import BaseProvider, AsyncProvider +from ..base_provider import BaseRetryProvider from .. import debug from ..errors import RetryProviderError, RetryNoProviderError -class RetryProvider(AsyncProvider): - __name__: str = "RetryProvider" - supports_stream: bool = True - - def __init__( - self, - providers: List[Type[BaseProvider]], - shuffle: bool = True - ) -> None: - self.providers: List[Type[BaseProvider]] = providers - self.shuffle: bool = shuffle - self.working = True - +class RetryProvider(BaseRetryProvider): def create_completion( self, model: str, @@ -36,20 +23,18 @@ class RetryProvider(AsyncProvider): if self.shuffle: random.shuffle(providers) - self.exceptions: Dict[str, Exception] = {} + self.exceptions = {} started: bool = False for provider in providers: + self.last_provider = provider try: if debug.logging: print(f"Using {provider.__name__} provider") - for token in provider.create_completion(model, messages, stream, **kwargs): yield token started = True - if started: return - except Exception as e: self.exceptions[provider.__name__] = e if debug.logging: @@ -69,8 +54,9 @@ class RetryProvider(AsyncProvider): if self.shuffle: random.shuffle(providers) - self.exceptions: Dict[str, Exception] = {} + self.exceptions = {} for provider in providers: + self.last_provider = provider try: return await asyncio.wait_for( provider.create_async(model, messages, **kwargs), |