summaryrefslogtreecommitdiffstats
path: root/g4f
diff options
context:
space:
mode:
Diffstat (limited to 'g4f')
-rw-r--r--g4f/Provider/BingCreateImages.py1
-rw-r--r--g4f/Provider/Ecosia.py8
-rw-r--r--g4f/Provider/MetaAI.py206
-rw-r--r--g4f/Provider/__init__.py1
-rw-r--r--g4f/Provider/base_provider.py2
-rw-r--r--g4f/Provider/needs_auth/Groq.py2
-rw-r--r--g4f/Provider/needs_auth/OpenaiAccount.py1
-rw-r--r--g4f/api/__init__.py75
-rw-r--r--g4f/api/run.py4
-rw-r--r--g4f/cli.py28
-rw-r--r--g4f/client/service.py1
-rw-r--r--g4f/gui/client/static/css/style.css2
-rw-r--r--g4f/gui/client/static/js/chat.v1.js5
-rw-r--r--g4f/gui/server/api.py2
-rw-r--r--g4f/providers/types.py9
-rw-r--r--g4f/requests/curl_cffi.py16
-rw-r--r--g4f/webdriver.py2
17 files changed, 294 insertions, 71 deletions
diff --git a/g4f/Provider/BingCreateImages.py b/g4f/Provider/BingCreateImages.py
index a7bd54e4..f29126b6 100644
--- a/g4f/Provider/BingCreateImages.py
+++ b/g4f/Provider/BingCreateImages.py
@@ -15,6 +15,7 @@ class BingCreateImages(AsyncGeneratorProvider, ProviderModelMixin):
label = "Microsoft Designer"
url = "https://www.bing.com/images/create"
working = True
+ needs_auth = True
def __init__(self, cookies: Cookies = None, proxy: str = None) -> None:
self.cookies: Cookies = cookies
diff --git a/g4f/Provider/Ecosia.py b/g4f/Provider/Ecosia.py
index 1cae3560..231412aa 100644
--- a/g4f/Provider/Ecosia.py
+++ b/g4f/Provider/Ecosia.py
@@ -15,7 +15,8 @@ class Ecosia(AsyncGeneratorProvider, ProviderModelMixin):
working = True
supports_gpt_35_turbo = True
default_model = "gpt-3.5-turbo-0125"
- model_aliases = {"gpt-3.5-turbo": "gpt-3.5-turbo-0125"}
+ models = [default_model, "green"]
+ model_aliases = {"gpt-3.5-turbo": default_model}
@classmethod
async def create_async_generator(
@@ -23,11 +24,10 @@ class Ecosia(AsyncGeneratorProvider, ProviderModelMixin):
model: str,
messages: Messages,
connector: BaseConnector = None,
- green: bool = False,
proxy: str = None,
**kwargs
) -> AsyncResult:
- cls.get_model(model)
+ model = cls.get_model(model)
headers = {
"authority": "api.ecosia.org",
"accept": "*/*",
@@ -39,7 +39,7 @@ class Ecosia(AsyncGeneratorProvider, ProviderModelMixin):
data = {
"messages": base64.b64encode(json.dumps(messages).encode()).decode()
}
- api_url = f"https://api.ecosia.org/v2/chat/?sp={'eco' if green else 'productivity'}"
+ api_url = f"https://api.ecosia.org/v2/chat/?sp={'eco' if model == 'green' else 'productivity'}"
async with session.post(api_url, json=data) as response:
await raise_for_status(response)
async for chunk in response.content.iter_any():
diff --git a/g4f/Provider/MetaAI.py b/g4f/Provider/MetaAI.py
new file mode 100644
index 00000000..e64a96d5
--- /dev/null
+++ b/g4f/Provider/MetaAI.py
@@ -0,0 +1,206 @@
+import json
+import uuid
+import random
+import time
+from typing import Dict, List
+
+from aiohttp import ClientSession, BaseConnector
+
+from ..typing import AsyncResult, Messages, Cookies
+from ..requests import raise_for_status, DEFAULT_HEADERS
+from .base_provider import AsyncGeneratorProvider
+from .helper import format_prompt, get_connector
+
+class Sources():
+ def __init__(self, link_list: List[Dict[str, str]]) -> None:
+ self.link = link_list
+
+ def __str__(self) -> str:
+ return "\n\n" + ("\n".join([f"[{link['title']}]({link['link']})" for link in self.list]))
+
+class AbraGeoBlockedError(Exception):
+ pass
+
+class MetaAI(AsyncGeneratorProvider):
+ url = "https://www.meta.ai"
+ working = True
+
+ def __init__(self, proxy: str = None, connector: BaseConnector = None):
+ self.session = ClientSession(connector=get_connector(connector, proxy), headers=DEFAULT_HEADERS)
+ self.cookies: Cookies = None
+ self.access_token: str = None
+
+ @classmethod
+ async def create_async_generator(
+ cls,
+ model: str,
+ messages: Messages,
+ proxy: str = None,
+ **kwargs
+ ) -> AsyncResult:
+ #cookies = get_cookies(".meta.ai", False, True)
+ async for chunk in cls(proxy).prompt(format_prompt(messages)):
+ yield chunk
+
+ async def get_access_token(self, birthday: str = "1999-01-01") -> str:
+ url = "https://www.meta.ai/api/graphql/"
+
+ payload = {
+ "lsd": self.lsd,
+ "fb_api_caller_class": "RelayModern",
+ "fb_api_req_friendly_name": "useAbraAcceptTOSForTempUserMutation",
+ "variables": json.dumps({
+ "dob": birthday,
+ "icebreaker_type": "TEXT",
+ "__relay_internal__pv__WebPixelRatiorelayprovider": 1,
+ }),
+ "doc_id": "7604648749596940",
+ }
+ headers = {
+ "x-fb-friendly-name": "useAbraAcceptTOSForTempUserMutation",
+ "x-fb-lsd": self.lsd,
+ "x-asbd-id": "129477",
+ "alt-used": "www.meta.ai",
+ "sec-fetch-site": "same-origin"
+ }
+ async with self.session.post(url, headers=headers, cookies=self.cookies, data=payload) as response:
+ await raise_for_status(response, "Fetch access_token failed")
+ auth_json = await response.json(content_type=None)
+ access_token = auth_json["data"]["xab_abra_accept_terms_of_service"]["new_temp_user_auth"]["access_token"]
+ return access_token
+
+ async def prompt(self, message: str, cookies: Cookies = None) -> AsyncResult:
+ if cookies is not None:
+ self.cookies = cookies
+ self.access_token = None
+ if self.cookies is None:
+ self.cookies = await self.get_cookies()
+ if self.access_token is None:
+ self.access_token = await self.get_access_token()
+
+ url = "https://graph.meta.ai/graphql?locale=user"
+ #url = "https://www.meta.ai/api/graphql/"
+ payload = {
+ "access_token": self.access_token,
+ #"lsd": cookies["lsd"],
+ "fb_api_caller_class": "RelayModern",
+ "fb_api_req_friendly_name": "useAbraSendMessageMutation",
+ "variables": json.dumps({
+ "message": {"sensitive_string_value": message},
+ "externalConversationId": str(uuid.uuid4()),
+ "offlineThreadingId": generate_offline_threading_id(),
+ "suggestedPromptIndex": None,
+ "flashVideoRecapInput": {"images": []},
+ "flashPreviewInput": None,
+ "promptPrefix": None,
+ "entrypoint": "ABRA__CHAT__TEXT",
+ "icebreaker_type": "TEXT",
+ "__relay_internal__pv__AbraDebugDevOnlyrelayprovider": False,
+ "__relay_internal__pv__WebPixelRatiorelayprovider": 1,
+ }),
+ "server_timestamps": "true",
+ "doc_id": "7783822248314888",
+ }
+ headers = {
+ "x-asbd-id": "129477",
+ "x-fb-friendly-name": "useAbraSendMessageMutation",
+ #"x-fb-lsd": cookies["lsd"],
+ }
+ async with self.session.post(url, headers=headers, cookies=self.cookies, data=payload) as response:
+ await raise_for_status(response, "Fetch response failed")
+ last_snippet_len = 0
+ fetch_id = None
+ async for line in response.content:
+ try:
+ json_line = json.loads(line)
+ except json.JSONDecodeError:
+ continue
+ bot_response_message = json_line.get("data", {}).get("node", {}).get("bot_response_message", {})
+ streaming_state = bot_response_message.get("streaming_state")
+ fetch_id = bot_response_message.get("fetch_id") or fetch_id
+ if streaming_state in ("STREAMING", "OVERALL_DONE"):
+ #imagine_card = bot_response_message["imagine_card"]
+ snippet = bot_response_message["snippet"]
+ new_snippet_len = len(snippet)
+ if new_snippet_len > last_snippet_len:
+ yield snippet[last_snippet_len:]
+ last_snippet_len = new_snippet_len
+ #if last_streamed_response is None:
+ # if attempts > 3:
+ # raise Exception("MetaAI is having issues and was not able to respond (Server Error)")
+ # access_token = await self.get_access_token()
+ # return await self.prompt(message=message, attempts=attempts + 1)
+ if fetch_id is not None:
+ sources = await self.fetch_sources(fetch_id)
+ if sources is not None:
+ yield sources
+
+ async def get_cookies(self, cookies: Cookies = None) -> Cookies:
+ async with self.session.get("https://www.meta.ai/", cookies=cookies) as response:
+ await raise_for_status(response, "Fetch home failed")
+ text = await response.text()
+ if "AbraGeoBlockedError" in text:
+ raise AbraGeoBlockedError("Meta AI isn't available yet in your country")
+ if cookies is None:
+ cookies = {
+ "_js_datr": self.extract_value(text, "_js_datr"),
+ "abra_csrf": self.extract_value(text, "abra_csrf"),
+ "datr": self.extract_value(text, "datr"),
+ }
+ self.lsd = self.extract_value(text, start_str='"LSD",[],{"token":"', end_str='"}')
+ return cookies
+
+ async def fetch_sources(self, fetch_id: str) -> Sources:
+ url = "https://graph.meta.ai/graphql?locale=user"
+ payload = {
+ "access_token": self.access_token,
+ "fb_api_caller_class": "RelayModern",
+ "fb_api_req_friendly_name": "AbraSearchPluginDialogQuery",
+ "variables": json.dumps({"abraMessageFetchID": fetch_id}),
+ "server_timestamps": "true",
+ "doc_id": "6946734308765963",
+ }
+ headers = {
+ "authority": "graph.meta.ai",
+ "x-fb-friendly-name": "AbraSearchPluginDialogQuery",
+ }
+ async with self.session.post(url, headers=headers, cookies=self.cookies, data=payload) as response:
+ await raise_for_status(response)
+ response_json = await response.json()
+ try:
+ message = response_json["data"]["message"]
+ if message is not None:
+ searchResults = message["searchResults"]
+ if searchResults is not None:
+ return Sources(searchResults["references"])
+ except (KeyError, TypeError):
+ raise RuntimeError(f"Response: {response_json}")
+
+ @staticmethod
+ def extract_value(text: str, key: str = None, start_str = None, end_str = '",') -> str:
+ if start_str is None:
+ start_str = f'{key}":{{"value":"'
+ start = text.find(start_str)
+ if start >= 0:
+ start+= len(start_str)
+ end = text.find(end_str, start)
+ if end >= 0:
+ return text[start:end]
+
+def generate_offline_threading_id() -> str:
+ """
+ Generates an offline threading ID.
+
+ Returns:
+ str: The generated offline threading ID.
+ """
+ # Generate a random 64-bit integer
+ random_value = random.getrandbits(64)
+
+ # Get the current timestamp in milliseconds
+ timestamp = int(time.time() * 1000)
+
+ # Combine timestamp and random value
+ threading_id = (timestamp << 22) | (random_value & ((1 << 22) - 1))
+
+ return str(threading_id) \ No newline at end of file
diff --git a/g4f/Provider/__init__.py b/g4f/Provider/__init__.py
index f761df5b..10249aa2 100644
--- a/g4f/Provider/__init__.py
+++ b/g4f/Provider/__init__.py
@@ -42,6 +42,7 @@ from .Koala import Koala
from .Liaobots import Liaobots
from .Llama import Llama
from .Local import Local
+from .MetaAI import MetaAI
from .PerplexityLabs import PerplexityLabs
from .Pi import Pi
from .ReplicateImage import ReplicateImage
diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py
index 4c0157f3..8f368747 100644
--- a/g4f/Provider/base_provider.py
+++ b/g4f/Provider/base_provider.py
@@ -1,3 +1,3 @@
from ..providers.base_provider import *
-from ..providers.types import FinishReason
+from ..providers.types import FinishReason, Streaming
from .helper import get_cookies, format_prompt \ No newline at end of file
diff --git a/g4f/Provider/needs_auth/Groq.py b/g4f/Provider/needs_auth/Groq.py
index 922b2dd2..d11f6a82 100644
--- a/g4f/Provider/needs_auth/Groq.py
+++ b/g4f/Provider/needs_auth/Groq.py
@@ -4,7 +4,7 @@ from .Openai import Openai
from ...typing import AsyncResult, Messages
class Groq(Openai):
- lebel = "Groq"
+ label = "Groq"
url = "https://console.groq.com/playground"
working = True
default_model = "mixtral-8x7b-32768"
diff --git a/g4f/Provider/needs_auth/OpenaiAccount.py b/g4f/Provider/needs_auth/OpenaiAccount.py
index 5c90b1de..7be60c86 100644
--- a/g4f/Provider/needs_auth/OpenaiAccount.py
+++ b/g4f/Provider/needs_auth/OpenaiAccount.py
@@ -3,5 +3,4 @@ from __future__ import annotations
from .OpenaiChat import OpenaiChat
class OpenaiAccount(OpenaiChat):
- label = "OpenAI ChatGPT with Account"
needs_auth = True \ No newline at end of file
diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py
index 8151881e..ed39fc58 100644
--- a/g4f/api/__init__.py
+++ b/g4f/api/__init__.py
@@ -15,6 +15,8 @@ import g4f.debug
from g4f.client import AsyncClient
from g4f.typing import Messages
+app = FastAPI()
+
class ChatCompletionsConfig(BaseModel):
messages: Messages
model: str
@@ -25,53 +27,44 @@ class ChatCompletionsConfig(BaseModel):
stop: Union[list[str], str, None] = None
api_key: Optional[str] = None
web_search: Optional[bool] = None
+ proxy: Optional[str] = None
class Api:
- def __init__(self, engine: g4f, debug: bool = True, sentry: bool = False,
- list_ignored_providers: List[str] = None) -> None:
- self.engine = engine
- self.debug = debug
- self.sentry = sentry
+ def __init__(self, list_ignored_providers: List[str] = None) -> None:
self.list_ignored_providers = list_ignored_providers
-
- if debug:
- g4f.debug.logging = True
self.client = AsyncClient()
- self.app = FastAPI()
-
- self.routes()
- self.register_validation_exception_handler()
+
+ def set_list_ignored_providers(self, list: list):
+ self.list_ignored_providers = list
def register_validation_exception_handler(self):
- @self.app.exception_handler(RequestValidationError)
+ @app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
details = exc.errors()
modified_details = []
for error in details:
- modified_details.append(
- {
- "loc": error["loc"],
- "message": error["msg"],
- "type": error["type"],
- }
- )
+ modified_details.append({
+ "loc": error["loc"],
+ "message": error["msg"],
+ "type": error["type"],
+ })
return JSONResponse(
status_code=HTTP_422_UNPROCESSABLE_ENTITY,
content=jsonable_encoder({"detail": modified_details}),
)
- def routes(self):
- @self.app.get("/")
+ def register_routes(self):
+ @app.get("/")
async def read_root():
return RedirectResponse("/v1", 302)
- @self.app.get("/v1")
+ @app.get("/v1")
async def read_root_v1():
return HTMLResponse('g4f API: Go to '
'<a href="/v1/chat/completions">chat/completions</a> '
'or <a href="/v1/models">models</a>.')
- @self.app.get("/v1/models")
+ @app.get("/v1/models")
async def models():
model_list = dict(
(model, g4f.models.ModelUtils.convert[model])
@@ -85,7 +78,7 @@ class Api:
} for model_id, model in model_list.items()]
return JSONResponse(model_list)
- @self.app.get("/v1/models/{model_name}")
+ @app.get("/v1/models/{model_name}")
async def model_info(model_name: str):
try:
model_info = g4f.models.ModelUtils.convert[model_name]
@@ -98,8 +91,8 @@ class Api:
except:
return JSONResponse({"error": "The model does not exist."})
- @self.app.post("/v1/chat/completions")
- async def chat_completions(config: ChatCompletionsConfig = None, request: Request = None, provider: str = None):
+ @app.post("/v1/chat/completions")
+ async def chat_completions(config: ChatCompletionsConfig, request: Request = None, provider: str = None):
try:
config.provider = provider if config.provider is None else config.provider
if config.api_key is None and request is not None:
@@ -132,13 +125,13 @@ class Api:
return StreamingResponse(streaming(), media_type="text/event-stream")
- @self.app.post("/v1/completions")
+ @app.post("/v1/completions")
async def completions():
return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json")
- def run(self, ip, use_colors : bool = False):
- split_ip = ip.split(":")
- uvicorn.run(app=self.app, host=split_ip[0], port=int(split_ip[1]), use_colors=use_colors)
+api = Api()
+api.register_routes()
+api.register_validation_exception_handler()
def format_exception(e: Exception, config: ChatCompletionsConfig) -> str:
last_provider = g4f.get_last_provider(True)
@@ -148,7 +141,19 @@ def format_exception(e: Exception, config: ChatCompletionsConfig) -> str:
"provider": last_provider.get("name") if last_provider else config.provider
})
-def run_api(host: str = '0.0.0.0', port: int = 1337, debug: bool = False, use_colors=True) -> None:
- print(f'Starting server... [g4f v-{g4f.version.utils.current_version}]')
- app = Api(engine=g4f, debug=debug)
- app.run(f"{host}:{port}", use_colors=use_colors) \ No newline at end of file
+def run_api(
+ host: str = '0.0.0.0',
+ port: int = 1337,
+ bind: str = None,
+ debug: bool = False,
+ workers: int = None,
+ use_colors: bool = None
+) -> None:
+ print(f'Starting server... [g4f v-{g4f.version.utils.current_version}]' + (" (debug)" if debug else ""))
+ if use_colors is None:
+ use_colors = debug
+ if bind is not None:
+ host, port = bind.split(":")
+ if debug:
+ g4f.debug.logging = True
+ uvicorn.run("g4f.api:app", host=host, port=int(port), workers=workers, use_colors=use_colors)# \ No newline at end of file
diff --git a/g4f/api/run.py b/g4f/api/run.py
index 4dcf5613..bc1cbf92 100644
--- a/g4f/api/run.py
+++ b/g4f/api/run.py
@@ -1,6 +1,4 @@
-import g4f
import g4f.api
if __name__ == "__main__":
- print(f'Starting server... [g4f v-{g4f.version.utils.current_version}]')
- g4f.api.Api(engine = g4f, debug = True).run(ip = "0.0.0.0:10000")
+ g4f.api.run_api(debug=True)
diff --git a/g4f/cli.py b/g4f/cli.py
index 64d63fd3..6b39091d 100644
--- a/g4f/cli.py
+++ b/g4f/cli.py
@@ -1,30 +1,32 @@
import argparse
-from enum import Enum
-import g4f
from g4f import Provider
-
from g4f.gui.run import gui_parser, run_gui_args
-def run_gui(args):
- print("Running GUI...")
-
def main():
- IgnoredProviders = Enum("ignore_providers", {key: key for key in Provider.__all__})
parser = argparse.ArgumentParser(description="Run gpt4free")
subparsers = parser.add_subparsers(dest="mode", help="Mode to run the g4f in.")
- api_parser=subparsers.add_parser("api")
+ api_parser = subparsers.add_parser("api")
api_parser.add_argument("--bind", default="0.0.0.0:1337", help="The bind string.")
- api_parser.add_argument("--debug", type=bool, default=False, help="Enable verbose logging")
- api_parser.add_argument("--ignored-providers", nargs="+", choices=[provider.name for provider in IgnoredProviders],
+ api_parser.add_argument("--debug", action="store_true", help="Enable verbose logging.")
+ api_parser.add_argument("--workers", type=int, default=None, help="Number of workers.")
+ api_parser.add_argument("--disable_colors", action="store_true", help="Don't use colors.")
+ api_parser.add_argument("--ignored-providers", nargs="+", choices=[provider for provider in Provider.__map__],
default=[], help="List of providers to ignore when processing request.")
subparsers.add_parser("gui", parents=[gui_parser()], add_help=False)
args = parser.parse_args()
if args.mode == "api":
- from g4f.api import Api
- controller=Api(engine=g4f, debug=args.debug, list_ignored_providers=args.ignored_providers)
- controller.run(args.bind)
+ import g4f.api
+ g4f.api.api.set_list_ignored_providers(
+ args.ignored_providers
+ )
+ g4f.api.run_api(
+ bind=args.bind,
+ debug=args.debug,
+ workers=args.workers,
+ use_colors=not args.disable_colors
+ )
elif args.mode == "gui":
run_gui_args(args)
else:
diff --git a/g4f/client/service.py b/g4f/client/service.py
index d25c923d..dd6bf4b6 100644
--- a/g4f/client/service.py
+++ b/g4f/client/service.py
@@ -111,5 +111,6 @@ def get_last_provider(as_dict: bool = False) -> Union[ProviderType, dict[str, st
"name": last.__name__,
"url": last.url,
"model": debug.last_model,
+ "label": last.label if hasattr(last, "label") else None
}
return last \ No newline at end of file
diff --git a/g4f/gui/client/static/css/style.css b/g4f/gui/client/static/css/style.css
index c0279bc2..a28c9cd6 100644
--- a/g4f/gui/client/static/css/style.css
+++ b/g4f/gui/client/static/css/style.css
@@ -890,7 +890,7 @@ a:-webkit-any-link {
resize: vertical;
max-height: 200px;
- min-height: 80px;
+ min-height: 100px;
}
/* style for hljs copy */
diff --git a/g4f/gui/client/static/js/chat.v1.js b/g4f/gui/client/static/js/chat.v1.js
index 39027260..a043cb25 100644
--- a/g4f/gui/client/static/js/chat.v1.js
+++ b/g4f/gui/client/static/js/chat.v1.js
@@ -302,7 +302,7 @@ async function add_message_chunk(message) {
window.provider_result = message.provider;
content.querySelector('.provider').innerHTML = `
<a href="${message.provider.url}" target="_blank">
- ${message.provider.name}
+ ${message.provider.label ? message.provider.label : message.provider.name}
</a>
${message.provider.model ? ' with ' + message.provider.model : ''}
`
@@ -545,7 +545,8 @@ const load_conversation = async (conversation_id, scroll=true) => {
last_model = item.provider?.model;
let next_i = parseInt(i) + 1;
let next_provider = item.provider ? item.provider : (messages.length > next_i ? messages[next_i].provider : null);
- let provider_link = item.provider?.name ? `<a href="${item.provider.url}" target="_blank">${item.provider.name}</a>` : "";
+ let provider_label = item.provider?.label ? item.provider?.label : item.provider?.name;
+ let provider_link = item.provider?.name ? `<a href="${item.provider.url}" target="_blank">${provider_label}</a>` : "";
let provider = provider_link ? `
<div class="provider">
${provider_link}
diff --git a/g4f/gui/server/api.py b/g4f/gui/server/api.py
index e3244c84..211d40c6 100644
--- a/g4f/gui/server/api.py
+++ b/g4f/gui/server/api.py
@@ -99,7 +99,7 @@ class Api():
if api_key is not None:
kwargs["api_key"] = api_key
if json_data.get('web_search'):
- if provider == "Bing":
+ if provider in ("Bing", "HuggingChat"):
kwargs['web_search'] = True
else:
from .internet import get_search_message
diff --git a/g4f/providers/types.py b/g4f/providers/types.py
index f33ea969..50c14431 100644
--- a/g4f/providers/types.py
+++ b/g4f/providers/types.py
@@ -102,4 +102,11 @@ ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
class FinishReason():
def __init__(self, reason: str):
- self.reason = reason \ No newline at end of file
+ self.reason = reason
+
+class Streaming():
+ def __init__(self, data: str) -> None:
+ self.data = data
+
+ def __str__(self) -> str:
+ return self.data \ No newline at end of file
diff --git a/g4f/requests/curl_cffi.py b/g4f/requests/curl_cffi.py
index d0d44ba7..1464cb32 100644
--- a/g4f/requests/curl_cffi.py
+++ b/g4f/requests/curl_cffi.py
@@ -79,10 +79,10 @@ class StreamSession(AsyncSession):
return StreamResponse(super().request(method, url, stream=True, **kwargs))
def ws_connect(self, url, *args, **kwargs):
- return WebSocket(self, url)
+ return WebSocket(self, url, **kwargs)
- def _ws_connect(self, url):
- return super().ws_connect(url)
+ def _ws_connect(self, url, **kwargs):
+ return super().ws_connect(url, **kwargs)
# Defining HTTP methods as partial methods of the request method.
head = partialmethod(request, "HEAD")
@@ -102,20 +102,22 @@ else:
raise RuntimeError("CurlMimi in curl_cffi is missing | pip install -U g4f[curl_cffi]")
class WebSocket():
- def __init__(self, session, url) -> None:
+ def __init__(self, session, url, **kwargs) -> None:
if not has_curl_ws:
raise RuntimeError("CurlWsFlag in curl_cffi is missing | pip install -U g4f[curl_cffi]")
self.session: StreamSession = session
self.url: str = url
+ del kwargs["autoping"]
+ self.options: dict = kwargs
async def __aenter__(self):
- self.inner = await self.session._ws_connect(self.url)
+ self.inner = await self.session._ws_connect(self.url, **self.options)
return self
async def __aexit__(self, *args):
- self.inner.aclose()
+ await self.inner.aclose()
- async def receive_str(self) -> str:
+ async def receive_str(self, **kwargs) -> str:
bytes, _ = await self.inner.arecv()
return bytes.decode(errors="ignore")
diff --git a/g4f/webdriver.py b/g4f/webdriver.py
index f392cacc..022e7a9f 100644
--- a/g4f/webdriver.py
+++ b/g4f/webdriver.py
@@ -65,7 +65,7 @@ def get_browser(
WebDriver: An instance of WebDriver configured with the specified options.
"""
if not has_requirements:
- raise MissingRequirementsError('Webdriver packages are not installed | pip install -U g4f[webdriver]')
+ raise MissingRequirementsError('Install Webdriver packages | pip install -U g4f[webdriver]')
browser = find_chrome_executable()
if browser is None:
raise MissingRequirementsError('Install "Google Chrome" browser')