diff options
Diffstat (limited to 'g4f/Provider/Llama2.py')
-rw-r--r-- | g4f/Provider/Llama2.py | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/g4f/Provider/Llama2.py b/g4f/Provider/Llama2.py index d1f8e194..6a94eea1 100644 --- a/g4f/Provider/Llama2.py +++ b/g4f/Provider/Llama2.py @@ -28,6 +28,10 @@ class Llama2(AsyncGeneratorProvider, ProviderModelMixin): model: str, messages: Messages, proxy: str = None, + system_message: str = "You are a helpful assistant.", + temperature: float = 0.75, + top_p: float = 0.9, + max_tokens: int = 8000, **kwargs ) -> AsyncResult: headers = { @@ -47,14 +51,18 @@ class Llama2(AsyncGeneratorProvider, ProviderModelMixin): "TE": "trailers" } async with ClientSession(headers=headers) as session: + system_messages = [message["content"] for message in messages if message["role"] == "system"] + if system_messages: + system_message = "\n".join(system_messages) + messages = [message for message in messages if message["role"] != "system"] prompt = format_prompt(messages) data = { "prompt": prompt, "model": cls.get_model(model), - "systemPrompt": kwargs.get("system_message", "You are a helpful assistant."), - "temperature": kwargs.get("temperature", 0.75), - "topP": kwargs.get("top_p", 0.9), - "maxTokens": kwargs.get("max_tokens", 8000), + "systemPrompt": system_message, + "temperature": temperature, + "topP": top_p, + "maxTokens": max_tokens, "image": None } started = False |