path: root/g4f/Provider/Llama2.py
Diffstat (limited to 'g4f/Provider/Llama2.py')
-rw-r--r--  g4f/Provider/Llama2.py  16
1 file changed, 12 insertions, 4 deletions
diff --git a/g4f/Provider/Llama2.py b/g4f/Provider/Llama2.py
index d1f8e194..6a94eea1 100644
--- a/g4f/Provider/Llama2.py
+++ b/g4f/Provider/Llama2.py
@@ -28,6 +28,10 @@ class Llama2(AsyncGeneratorProvider, ProviderModelMixin):
model: str,
messages: Messages,
proxy: str = None,
+ system_message: str = "You are a helpful assistant.",
+ temperature: float = 0.75,
+ top_p: float = 0.9,
+ max_tokens: int = 8000,
**kwargs
) -> AsyncResult:
headers = {
@@ -47,14 +51,18 @@ class Llama2(AsyncGeneratorProvider, ProviderModelMixin):
"TE": "trailers"
}
async with ClientSession(headers=headers) as session:
+ system_messages = [message["content"] for message in messages if message["role"] == "system"]
+ if system_messages:
+ system_message = "\n".join(system_messages)
+ messages = [message for message in messages if message["role"] != "system"]
prompt = format_prompt(messages)
data = {
"prompt": prompt,
"model": cls.get_model(model),
- "systemPrompt": kwargs.get("system_message", "You are a helpful assistant."),
- "temperature": kwargs.get("temperature", 0.75),
- "topP": kwargs.get("top_p", 0.9),
- "maxTokens": kwargs.get("max_tokens", 8000),
+ "systemPrompt": system_message,
+ "temperature": temperature,
+ "topP": top_p,
+ "maxTokens": max_tokens,
"image": None
}
started = False
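
The commit makes two related changes: the tuning options move from kwargs.get(...) lookups into the signature of create_async_generator, so the supported parameters and their defaults are visible at the call site, and any "system" messages in the conversation are folded into the systemPrompt field instead of being passed through format_prompt. Below is a minimal standalone sketch of that extraction logic, assuming OpenAI-style message dicts; the helper name split_system_prompt and the example conversation are illustrative, not part of the patch.

    # Sketch of the system-message handling introduced by this commit.
    # The helper name is hypothetical; only the filtering/joining logic
    # mirrors the second hunk of the diff above.

    def split_system_prompt(messages, default="You are a helpful assistant."):
        """Join all system messages into one prompt string and return it
        together with the remaining (non-system) messages."""
        system_messages = [m["content"] for m in messages if m["role"] == "system"]
        # Fall back to the default only when the conversation carries no
        # system messages, matching the new keyword default in the diff.
        system_prompt = "\n".join(system_messages) if system_messages else default
        chat_messages = [m for m in messages if m["role"] != "system"]
        return system_prompt, chat_messages

    if __name__ == "__main__":
        messages = [
            {"role": "system", "content": "Answer in French."},
            {"role": "user", "content": "Hello!"},
        ]
        system_prompt, chat = split_system_prompt(messages)
        print(system_prompt)  # -> "Answer in French."
        print(chat)           # -> only the user message remains

Promoting the options to named parameters keeps the old calling convention intact (extra arguments still land in **kwargs) while making the defaults discoverable from the function signature rather than buried in the request-body construction.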