From e9f96ced9c534f313fd2d3b82b2464cd8424281a Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Thu, 21 Sep 2023 20:10:59 +0200 Subject: Add RetryProvider --- g4f/Provider/__init__.py | 6 +++- g4f/Provider/retry_provider.py | 81 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 g4f/Provider/retry_provider.py (limited to 'g4f/Provider') diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py index 0ca22533..b9ee2544 100644 --- a/g4f/Provider/__init__.py +++ b/g4f/Provider/__init__.py @@ -38,10 +38,14 @@ from .FastGpt import FastGpt from .V50 import V50 from .Wuguokai import Wuguokai -from .base_provider import BaseProvider, AsyncProvider, AsyncGeneratorProvider +from .base_provider import BaseProvider, AsyncProvider, AsyncGeneratorProvider +from .retry_provider import RetryProvider __all__ = [ 'BaseProvider', + 'AsyncProvider', + 'AsyncGeneratorProvider', + 'RetryProvider', 'Acytoo', 'Aichat', 'Ails', diff --git a/g4f/Provider/retry_provider.py b/g4f/Provider/retry_provider.py new file mode 100644 index 00000000..e1a9cd1f --- /dev/null +++ b/g4f/Provider/retry_provider.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import random + +from ..typing import CreateResult +from .base_provider import BaseProvider, AsyncProvider + + +class RetryProvider(AsyncProvider): + __name__ = "RetryProvider" + working = True + needs_auth = False + supports_stream = True + supports_gpt_35_turbo = False + supports_gpt_4 = False + + def __init__( + self, + providers: list[type[BaseProvider]], + shuffle: bool = True + ) -> None: + self.providers = providers + self.shuffle = shuffle + + + def create_completion( + self, + model: str, + messages: list[dict[str, str]], + stream: bool = False, + **kwargs + ) -> CreateResult: + if stream: + providers = [provider for provider in self.providers if provider.supports_stream] + else: + providers = self.providers + if self.shuffle: + random.shuffle(providers) + + self.exceptions = {} + started = False + for provider in providers: + try: + for token in provider.create_completion(model, messages, stream, **kwargs): + yield token + started = True + if started: + return + except Exception as e: + self.exceptions[provider.__name__] = e + if started: + break + + self.raise_exceptions() + + async def create_async( + self, + model: str, + messages: list[dict[str, str]], + **kwargs + ) -> str: + providers = [provider for provider in self.providers if issubclass(provider, AsyncProvider)] + if self.shuffle: + random.shuffle(providers) + + self.exceptions = {} + for provider in providers: + try: + return await provider.create_async(model, messages, **kwargs) + except Exception as e: + self.exceptions[provider.__name__] = e + + self.raise_exceptions() + + def raise_exceptions(self): + if self.exceptions: + raise RuntimeError("\n".join(["All providers failed:"] + [ + f"{p}: {self.exceptions[p].__class__.__name__}: {self.exceptions[p]}" for p in self.exceptions + ])) + + raise RuntimeError("No provider found") \ No newline at end of file -- cgit v1.2.3