summaryrefslogtreecommitdiffstats
path: root/g4f/locals/provider.py
diff options
context:
space:
mode:
authorH Lohaus <hlohaus@users.noreply.github.com>2024-04-07 11:27:26 +0200
committerGitHub <noreply@github.com>2024-04-07 11:27:26 +0200
commitd327afc60620913f5d2b0a9985b03a7934468ad4 (patch)
tree395de9142af3e6b9c0e5e3968ee7f8234b8b25e2 /g4f/locals/provider.py
parentUpdate Gemini.py (diff)
parentUpdate provider.py (diff)
downloadgpt4free-d327afc60620913f5d2b0a9985b03a7934468ad4.tar
gpt4free-d327afc60620913f5d2b0a9985b03a7934468ad4.tar.gz
gpt4free-d327afc60620913f5d2b0a9985b03a7934468ad4.tar.bz2
gpt4free-d327afc60620913f5d2b0a9985b03a7934468ad4.tar.lz
gpt4free-d327afc60620913f5d2b0a9985b03a7934468ad4.tar.xz
gpt4free-d327afc60620913f5d2b0a9985b03a7934468ad4.tar.zst
gpt4free-d327afc60620913f5d2b0a9985b03a7934468ad4.zip
Diffstat (limited to '')
-rw-r--r--g4f/locals/provider.py74
1 files changed, 74 insertions, 0 deletions
diff --git a/g4f/locals/provider.py b/g4f/locals/provider.py
new file mode 100644
index 00000000..45041539
--- /dev/null
+++ b/g4f/locals/provider.py
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+import os
+
+from gpt4all import GPT4All
+from .models import get_models
+from ..typing import Messages
+
# Lazy module-level cache of the available local models; stays None until
# the first call to LocalProvider.create_completion fills it via get_models().
MODEL_LIST: dict[str, dict] | None = None
+
def find_model_dir(model_file: str) -> str:
    """Locate the directory that contains *model_file*.

    Search order:
      1. the project-level "models" directory (two levels above this module),
      2. the legacy "models" directory next to this module,
      3. a recursive walk of the current working directory.

    Falls back to the project-level directory even when the file is absent,
    so callers get a sensible target directory for a fresh download.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    project_root = os.path.dirname(os.path.dirname(here))

    preferred_dir = os.path.join(project_root, "models")
    if os.path.isfile(os.path.join(preferred_dir, model_file)):
        return preferred_dir

    legacy_dir = os.path.join(here, "models")
    if os.path.isfile(os.path.join(legacy_dir, model_file)):
        return legacy_dir

    # Last resort: scan the working directory tree for the file.
    for root, _dirs, files in os.walk("./"):
        if model_file in files:
            return root

    return preferred_dir
+
class LocalProvider:
    """Completion provider backed by a locally stored GPT4All model."""

    @staticmethod
    def create_completion(model: str, messages: Messages, stream: bool = False, **kwargs):
        """Generate a completion from a local GPT4All model.

        Args:
            model: Key into MODEL_LIST identifying which model to load.
            messages: Chat history; "system" entries are merged into the
                session's system prompt, the rest form the conversation.
            stream: When True, yield tokens one by one as they are produced;
                otherwise yield the full response as a single string.
            **kwargs: Ignored; accepted for interface compatibility.

        Raises:
            ValueError: If the model name is unknown, or its weights file is
                missing and the user declines the interactive download.
        """
        global MODEL_LIST
        if MODEL_LIST is None:
            # Populate the module-level cache on first use.
            MODEL_LIST = get_models()
        if model not in MODEL_LIST:
            raise ValueError(f'Model "{model}" not found / not yet implemented')

        model = MODEL_LIST[model]
        model_file = model["path"]
        model_dir = find_model_dir(model_file)
        if not os.path.isfile(os.path.join(model_dir, model_file)):
            # Interactive fallback: offer to fetch the weights before failing.
            print(f'Model file "models/{model_file}" not found.')
            download = input(f"Do you want to download {model_file}? [y/n]: ")
            if download in ["y", "Y"]:
                GPT4All.download_model(model_file, model_dir)
            else:
                raise ValueError(f'Model "{model_file}" not found.')

        model = GPT4All(model_name=model_file,
                        #n_threads=8,
                        verbose=False,
                        allow_download=False,
                        model_path=model_dir)

        system_message = "\n".join(message["content"] for message in messages if message["role"] == "system")
        if not system_message:
            # Bug fix: the original condition was inverted — it overwrote a
            # user-supplied system prompt with the generic default and left
            # the prompt empty when none was given. Apply the default only
            # when no system message is present.
            system_message = "A chat between a curious user and an artificial intelligence assistant."

        prompt_template = "USER: {0}\nASSISTANT: "
        conversation = "\n".join(
            f"{message['role'].upper()}: {message['content']}"
            for message in messages
            if message["role"] != "system"
        ) + "\nASSISTANT: "

        with model.chat_session(system_message, prompt_template):
            if stream:
                for token in model.generate(conversation, streaming=True):
                    yield token
            else:
                yield model.generate(conversation)