diff options
author | Heiner Lohaus <heiner@lohaus.eu> | 2023-10-07 19:00:45 +0200 |
---|---|---|
committer | Heiner Lohaus <heiner@lohaus.eu> | 2023-10-07 19:00:45 +0200 |
commit | dfdb759639479da640701fe0db716d4455b7ae38 (patch) | |
tree | 607297dd731568653d11421038b23861b5a9a4fa /g4f/Provider | |
parent | Improve code by AI (diff) | |
download | gpt4free-dfdb759639479da640701fe0db716d4455b7ae38.tar gpt4free-dfdb759639479da640701fe0db716d4455b7ae38.tar.gz gpt4free-dfdb759639479da640701fe0db716d4455b7ae38.tar.bz2 gpt4free-dfdb759639479da640701fe0db716d4455b7ae38.tar.lz gpt4free-dfdb759639479da640701fe0db716d4455b7ae38.tar.xz gpt4free-dfdb759639479da640701fe0db716d4455b7ae38.tar.zst gpt4free-dfdb759639479da640701fe0db716d4455b7ae38.zip |
Diffstat (limited to 'g4f/Provider')
-rw-r--r-- | g4f/Provider/helper.py | 25 |
1 files changed, 16 insertions, 9 deletions
diff --git a/g4f/Provider/helper.py b/g4f/Provider/helper.py index c127f241..5a9a9329 100644 --- a/g4f/Provider/helper.py +++ b/g4f/Provider/helper.py @@ -1,8 +1,10 @@ from __future__ import annotations -import asyncio, sys +import asyncio +import sys from asyncio import AbstractEventLoop from os import path +from typing import Dict, List import browser_cookie3 # Change event loop policy on windows @@ -13,7 +15,7 @@ if sys.platform == 'win32': asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) # Local Cookie Storage -_cookies: dict[str, dict[str, str]] = {} +_cookies: Dict[str, Dict[str, str]] = {} # If event loop is already running, handle nested event loops # If "nest_asyncio" is installed, patch the event loop. @@ -34,11 +36,13 @@ def get_event_loop() -> AbstractEventLoop: return event_loop except ImportError: raise RuntimeError( - 'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.') + 'Use "create_async" instead of "create" function in a running event loop. Or install the "nest_asyncio" package.' + ) + -# Load cookies for a domain from all supported browser. -# Cache the results in the "_cookies" variable -def get_cookies(cookie_domain: str) -> dict: +# Load cookies for a domain from all supported browsers. +# Cache the results in the "_cookies" variable. +def get_cookies(cookie_domain: str) -> Dict[str, str]: if cookie_domain not in _cookies: _cookies[cookie_domain] = {} try: @@ -49,15 +53,18 @@ def get_cookies(cookie_domain: str) -> dict: return _cookies[cookie_domain] -def format_prompt(messages: list[dict[str, str]], add_special_tokens=False): +def format_prompt(messages: List[Dict[str, str]], add_special_tokens=False) -> str: if add_special_tokens or len(messages) > 1: formatted = "\n".join( - ["%s: %s" % ((message["role"]).capitalize(), message["content"]) for message in messages] + [ + "%s: %s" % ((message["role"]).capitalize(), message["content"]) + for message in messages + ] ) return f"{formatted}\nAssistant:" else: return messages[0]["content"] - + def get_browser(user_data_dir: str = None): from undetected_chromedriver import Chrome |