diff options
Diffstat (limited to 'g4f')
-rw-r--r-- | g4f/Provider/Bing.py | 31 | ||||
-rw-r--r-- | g4f/gui/client/css/style.css | 4 | ||||
-rw-r--r-- | g4f/gui/client/html/index.html | 4 | ||||
-rw-r--r-- | g4f/gui/client/js/chat.v1.js | 100 | ||||
-rw-r--r-- | g4f/gui/server/backend.py | 25 |
5 files changed, 118 insertions, 46 deletions
diff --git a/g4f/Provider/Bing.py b/g4f/Provider/Bing.py index 786fec49..e3e47af9 100644 --- a/g4f/Provider/Bing.py +++ b/g4f/Provider/Bing.py @@ -12,7 +12,7 @@ from aiohttp import ClientSession, ClientTimeout, BaseConnector, WSMsgType from ..typing import AsyncResult, Messages, ImageType, Cookies from ..image import ImageRequest from ..errors import ResponseStatusError -from .base_provider import AsyncGeneratorProvider +from .base_provider import AsyncGeneratorProvider, ProviderModelMixin from .helper import get_connector, get_random_hex from .bing.upload_image import upload_image from .bing.conversation import Conversation, create_conversation, delete_conversation @@ -27,7 +27,7 @@ class Tones: balanced = "Balanced" precise = "Precise" -class Bing(AsyncGeneratorProvider): +class Bing(AsyncGeneratorProvider, ProviderModelMixin): """ Bing provider for generating responses using the Bing API. """ @@ -35,16 +35,21 @@ class Bing(AsyncGeneratorProvider): working = True supports_message_history = True supports_gpt_4 = True + default_model = Tones.balanced + models = [ + getattr(Tones, key) for key in dir(Tones) if not key.startswith("__") + ] - @staticmethod + @classmethod def create_async_generator( + cls, model: str, messages: Messages, proxy: str = None, timeout: int = 900, cookies: Cookies = None, connector: BaseConnector = None, - tone: str = Tones.balanced, + tone: str = None, image: ImageType = None, web_search: bool = False, **kwargs @@ -62,13 +67,11 @@ class Bing(AsyncGeneratorProvider): :param web_search: Flag to enable or disable web search. :return: An asynchronous result object. """ - if len(messages) < 2: - prompt = messages[0]["content"] - context = None - else: - prompt = messages[-1]["content"] - context = create_context(messages[:-1]) - + prompt = messages[-1]["content"] + context = create_context(messages[:-1]) if len(messages) > 1 else None + if tone is None: + tone = tone if model.startswith("gpt-4") else model + tone = cls.get_model(tone) gpt4_turbo = True if model.startswith("gpt-4-turbo") else False return stream_generate( @@ -86,7 +89,9 @@ def create_context(messages: Messages) -> str: :return: A string representing the context created from the messages. """ return "".join( - f"[{message['role']}]" + ("(#message)" if message['role'] != "system" else "(#additional_instructions)") + f"\n{message['content']}" + f"[{message['role']}]" + ("(#message)" + if message['role'] != "system" + else "(#additional_instructions)") + f"\n{message['content']}" for message in messages ) + "\n\n" @@ -403,7 +408,7 @@ async def stream_generate( do_read = False if response_txt.startswith(returned_text): new = response_txt[len(returned_text):] - if new != "\n": + if new not in ("", "\n"): yield new returned_text = response_txt if image_response: diff --git a/g4f/gui/client/css/style.css b/g4f/gui/client/css/style.css index 17f3e4b3..045eae99 100644 --- a/g4f/gui/client/css/style.css +++ b/g4f/gui/client/css/style.css @@ -106,6 +106,10 @@ body { border: 1px solid var(--blur-border); } +.hidden { + display: none; +} + .conversations { max-width: 260px; padding: var(--section-gap); diff --git a/g4f/gui/client/html/index.html b/g4f/gui/client/html/index.html index 46a9c541..e0091c8c 100644 --- a/g4f/gui/client/html/index.html +++ b/g4f/gui/client/html/index.html @@ -163,6 +163,10 @@ </select> </div> <div class="field"> + <select name="model2" id="model2" class="hidden"> + </select> + </div> + <div class="field"> <select name="jailbreak" id="jailbreak" style="display: none;"> <option value="default" selected>Set Jailbreak</option> <option value="gpt-math-1.0">math 1.0</option> diff --git a/g4f/gui/client/js/chat.v1.js b/g4f/gui/client/js/chat.v1.js index 8774fbf1..4e01593d 100644 --- a/g4f/gui/client/js/chat.v1.js +++ b/g4f/gui/client/js/chat.v1.js @@ -12,7 +12,9 @@ const imageInput = document.getElementById("image"); const cameraInput = document.getElementById("camera"); const fileInput = document.getElementById("file"); const inputCount = document.getElementById("input-count") +const providerSelect = document.getElementById("provider"); const modelSelect = document.getElementById("model"); +const modelProvider = document.getElementById("model2"); const systemPrompt = document.getElementById("systemPrompt") let prompt_lock = false; @@ -44,17 +46,21 @@ const markdown_render = (content) => { } let typesetPromise = Promise.resolve(); +let timeoutHighlightId; const highlight = (container) => { - container.querySelectorAll('code:not(.hljs').forEach((el) => { - if (el.className != "hljs") { - hljs.highlightElement(el); - } - }); - typesetPromise = typesetPromise.then( - () => MathJax.typesetPromise([container]) - ).catch( - (err) => console.log('Typeset failed: ' + err.message) - ); + if (timeoutHighlightId) clearTimeout(timeoutHighlightId); + timeoutHighlightId = setTimeout(() => { + container.querySelectorAll('code:not(.hljs').forEach((el) => { + if (el.className != "hljs") { + hljs.highlightElement(el); + } + }); + typesetPromise = typesetPromise.then( + () => MathJax.typesetPromise([container]) + ).catch( + (err) => console.log('Typeset failed: ' + err.message) + ); + }, 100); } const register_remove_message = async () => { @@ -108,7 +114,6 @@ const handle_ask = async () => { if (input.files.length > 0) imageInput.dataset.src = URL.createObjectURL(input.files[0]); else delete imageInput.dataset.src - model = modelSelect.options[modelSelect.selectedIndex].value message_box.innerHTML += ` <div class="message" data-index="${message_index}"> <div class="user"> @@ -124,7 +129,7 @@ const handle_ask = async () => { : '' } </div> - <div class="count">${count_words_and_tokens(message, model)}</div> + <div class="count">${count_words_and_tokens(message, get_selected_model())}</div> </div> </div> `; @@ -204,7 +209,6 @@ const ask_gpt = async () => { window.controller = new AbortController(); jailbreak = document.getElementById("jailbreak"); - provider = document.getElementById("provider"); window.text = ''; stop_generating.classList.remove(`stop_generating-hidden`); @@ -241,10 +245,10 @@ const ask_gpt = async () => { let body = JSON.stringify({ id: window.token, conversation_id: window.conversation_id, - model: modelSelect.options[modelSelect.selectedIndex].value, + model: get_selected_model(), jailbreak: jailbreak.options[jailbreak.selectedIndex].value, web_search: document.getElementById(`switch`).checked, - provider: provider.options[provider.selectedIndex].value, + provider: providerSelect.options[providerSelect.selectedIndex].value, patch_provider: document.getElementById('patch')?.checked, messages: messages }); @@ -666,11 +670,13 @@ sidebar_button.addEventListener("click", (event) => { window.scrollTo(0, 0); }); +const options = ["switch", "model", "model2", "jailbreak", "patch", "provider", "history"]; + const register_settings_localstorage = async () => { - for (id of ["switch", "model", "jailbreak", "patch", "provider", "history"]) { + options.forEach((id) => { element = document.getElementById(id); if (!element) { - continue; + return; } element.addEventListener('change', async (event) => { switch (event.target.type) { @@ -684,14 +690,14 @@ const register_settings_localstorage = async () => { console.warn("Unresolved element type"); } }); - } + }); } const load_settings_localstorage = async () => { - for (id of ["switch", "model", "jailbreak", "patch", "provider", "history"]) { + options.forEach((id) => { element = document.getElementById(id); if (!element || !(value = appStorage.getItem(element.id))) { - continue; + return; } if (value) { switch (element.type) { @@ -705,7 +711,7 @@ const load_settings_localstorage = async () => { console.warn("Unresolved element type"); } } - } + }); } const say_hello = async () => { @@ -780,13 +786,16 @@ function count_words_and_tokens(text, model) { } let countFocus = messageInput; +let timeoutId; const count_input = async () => { - if (countFocus.value) { - model = modelSelect.options[modelSelect.selectedIndex].value; - inputCount.innerText = count_words_and_tokens(countFocus.value, model); - } else { - inputCount.innerHTML = " " - } + if (timeoutId) clearTimeout(timeoutId); + timeoutId = setTimeout(() => { + if (countFocus.value) { + inputCount.innerText = count_words_and_tokens(countFocus.value, get_selected_model()); + } else { + inputCount.innerHTML = " " + } + }, 100); }; messageInput.addEventListener("keyup", count_input); systemPrompt.addEventListener("keyup", count_input); @@ -850,11 +859,13 @@ window.onload = async () => { providers = await response.json() select = document.getElementById('provider'); - for (provider of providers) { + providers.forEach((provider) => { let option = document.createElement('option'); option.value = option.text = provider; select.appendChild(option); - } + }) + + await load_provider_models(); await load_settings_localstorage() })(); @@ -914,4 +925,33 @@ fileInput.addEventListener('change', async (event) => { systemPrompt?.addEventListener("blur", async () => { await save_system_message(); -});
\ No newline at end of file +}); + +function get_selected_model() { + if (modelProvider.selectedIndex >= 0) { + return modelProvider.options[modelProvider.selectedIndex].value; + } else if (modelSelect.selectedIndex >= 0) { + return modelSelect.options[modelSelect.selectedIndex].value; + } +} + +async function load_provider_models() { + provider = providerSelect.options[providerSelect.selectedIndex].value; + response = await fetch('/backend-api/v2/models/' + provider); + models = await response.json(); + if (models.length > 0) { + modelSelect.classList.add("hidden"); + modelProvider.classList.remove("hidden"); + modelProvider.innerHTML = ''; + models.forEach((model) => { + let option = document.createElement('option'); + option.value = option.text = model.model; + option.selected = model.default; + modelProvider.appendChild(option); + }); + } else { + modelProvider.classList.add("hidden"); + modelSelect.classList.remove("hidden"); + } +}; +providerSelect.addEventListener("change", load_provider_models)
\ No newline at end of file diff --git a/g4f/gui/server/backend.py b/g4f/gui/server/backend.py index 454ed1c6..e9617c07 100644 --- a/g4f/gui/server/backend.py +++ b/g4f/gui/server/backend.py @@ -6,10 +6,11 @@ from g4f import version, models from g4f import get_last_provider, ChatCompletion from g4f.image import is_allowed_extension, to_image from g4f.errors import VersionNotFoundError -from g4f.Provider import __providers__ +from g4f.Provider import ProviderType, __providers__, __map__ +from g4f.providers.base_provider import ProviderModelMixin from g4f.Provider.bing.create_images import patch_provider -class Backend_Api: +class Backend_Api: """ Handles various endpoints in a Flask application for backend operations. @@ -33,6 +34,10 @@ class Backend_Api: 'function': self.get_models, 'methods': ['GET'] }, + '/backend-api/v2/models/<provider>': { + 'function': self.get_provider_models, + 'methods': ['GET'] + }, '/backend-api/v2/providers': { 'function': self.get_providers, 'methods': ['GET'] @@ -75,7 +80,21 @@ class Backend_Api: List[str]: A list of model names. """ return models._all_models - + + def get_provider_models(self, provider: str): + if provider in __map__: + provider: ProviderType = __map__[provider] + if issubclass(provider, ProviderModelMixin): + return [{"model": model, "default": model == provider.default_model} for model in provider.get_models()] + elif provider.supports_gpt_35_turbo or provider.supports_gpt_4: + return [ + *([{"model": "gpt-3.5-turbo", "default": not provider.supports_gpt_4}] if provider.supports_gpt_35_turbo else []), + *([{"model": "gpt-4", "default": not provider.supports_gpt_4}] if provider.supports_gpt_4 else []) + ] + else: + return []; + return 404, "Provider not found" + def get_providers(self): """ Return a list of all working providers. |