From 91feb34054f529c37e10d98d2471c8c0c6780147 Mon Sep 17 00:00:00 2001 From: Heiner Lohaus Date: Tue, 23 Jan 2024 19:44:48 +0100 Subject: Add ProviderModelMixin for model selection --- g4f/Provider/Llama2.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) (limited to 'g4f/Provider/Llama2.py') diff --git a/g4f/Provider/Llama2.py b/g4f/Provider/Llama2.py index 17969621..d1f8e194 100644 --- a/g4f/Provider/Llama2.py +++ b/g4f/Provider/Llama2.py @@ -3,18 +3,24 @@ from __future__ import annotations from aiohttp import ClientSession from ..typing import AsyncResult, Messages -from .base_provider import AsyncGeneratorProvider +from .base_provider import AsyncGeneratorProvider, ProviderModelMixin -models = { - "meta-llama/Llama-2-7b-chat-hf": "meta/llama-2-7b-chat", - "meta-llama/Llama-2-13b-chat-hf": "meta/llama-2-13b-chat", - "meta-llama/Llama-2-70b-chat-hf": "meta/llama-2-70b-chat", -} -class Llama2(AsyncGeneratorProvider): +class Llama2(AsyncGeneratorProvider, ProviderModelMixin): url = "https://www.llama2.ai" working = True supports_message_history = True + default_model = "meta/llama-2-70b-chat" + models = [ + "meta/llama-2-7b-chat", + "meta/llama-2-13b-chat", + "meta/llama-2-70b-chat", + ] + model_aliases = { + "meta-llama/Llama-2-7b-chat-hf": "meta/llama-2-7b-chat", + "meta-llama/Llama-2-13b-chat-hf": "meta/llama-2-13b-chat", + "meta-llama/Llama-2-70b-chat-hf": "meta/llama-2-70b-chat", + } @classmethod async def create_async_generator( @@ -24,10 +30,6 @@ class Llama2(AsyncGeneratorProvider): proxy: str = None, **kwargs ) -> AsyncResult: - if not model: - model = "meta/llama-2-70b-chat" - elif model in models: - model = models[model] headers = { "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/118.0", "Accept": "*/*", @@ -48,7 +50,7 @@ class Llama2(AsyncGeneratorProvider): prompt = format_prompt(messages) data = { "prompt": prompt, - "model": model, + "model": cls.get_model(model), "systemPrompt": kwargs.get("system_message", "You are a helpful assistant."), "temperature": kwargs.get("temperature", 0.75), "topP": kwargs.get("top_p", 0.9), -- cgit v1.2.3