diff options
Diffstat (limited to 'g4f/Provider/base_provider.py')
-rw-r--r-- | g4f/Provider/base_provider.py | 23 |
1 files changed, 21 insertions, 2 deletions
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index bc47a1fa..e1dcd24d 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -8,7 +8,7 @@ from inspect import signature, Parameter from .helper import get_cookies, format_prompt from ..typing import CreateResult, AsyncResult, Messages, Union from ..base_provider import BaseProvider -from ..errors import NestAsyncioError +from ..errors import NestAsyncioError, ModelNotSupportedError if sys.version_info < (3, 10): NoneType = type(None) @@ -251,4 +251,23 @@ class AsyncGeneratorProvider(AsyncProvider): Returns: AsyncResult: An asynchronous generator yielding results. """ - raise NotImplementedError()
\ No newline at end of file + raise NotImplementedError() + +class ProviderModelMixin: + default_model: str + models: list[str] = [] + model_aliases: dict[str, str] = {} + + @classmethod + def get_models(cls) -> list[str]: + return cls.models + + @classmethod + def get_model(cls, model: str) -> str: + if not model: + return cls.default_model + elif model in cls.model_aliases: + return cls.model_aliases[model] + elif model not in cls.get_models(): + raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}") + return model
\ No newline at end of file |