summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/RubiksAI.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/Provider/RubiksAI.py')
-rw-r--r--g4f/Provider/RubiksAI.py124
1 files changed, 47 insertions, 77 deletions
diff --git a/g4f/Provider/RubiksAI.py b/g4f/Provider/RubiksAI.py
index 7e76d558..c06e6c3d 100644
--- a/g4f/Provider/RubiksAI.py
+++ b/g4f/Provider/RubiksAI.py
@@ -1,7 +1,6 @@
+
from __future__ import annotations
-import asyncio
-import aiohttp
import random
import string
import json
@@ -11,34 +10,24 @@ from aiohttp import ClientSession
from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
-from .helper import format_prompt
-
+from ..requests.raise_for_status import raise_for_status
class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
label = "Rubiks AI"
url = "https://rubiks.ai"
- api_endpoint = "https://rubiks.ai/search/api.php"
+ api_endpoint = "https://rubiks.ai/search/api/"
working = True
supports_stream = True
supports_system_message = True
supports_message_history = True
- default_model = 'llama-3.1-70b-versatile'
- models = [default_model, 'gpt-4o-mini']
+ default_model = 'gpt-4o-mini'
+ models = [default_model, 'gpt-4o', 'o1-mini', 'claude-3.5-sonnet', 'grok-beta', 'gemini-1.5-pro', 'nova-pro']
model_aliases = {
"llama-3.1-70b": "llama-3.1-70b-versatile",
}
- @classmethod
- def get_model(cls, model: str) -> str:
- if model in cls.models:
- return model
- elif model in cls.model_aliases:
- return cls.model_aliases[model]
- else:
- return cls.default_model
-
@staticmethod
def generate_mid() -> str:
"""
@@ -70,7 +59,8 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
model: str,
messages: Messages,
proxy: str = None,
- websearch: bool = False,
+ web_search: bool = False,
+ temperature: float = 0.6,
**kwargs
) -> AsyncResult:
"""
@@ -80,20 +70,18 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
- model (str): The model to use in the request.
- messages (Messages): The messages to send as a prompt.
- proxy (str, optional): Proxy URL, if needed.
- - websearch (bool, optional): Indicates whether to include search sources in the response. Defaults to False.
+ - web_search (bool, optional): Indicates whether to include search sources in the response. Defaults to False.
"""
model = cls.get_model(model)
- prompt = format_prompt(messages)
- q_value = prompt
mid_value = cls.generate_mid()
- referer = cls.create_referer(q=q_value, mid=mid_value, model=model)
-
- url = cls.api_endpoint
- params = {
- 'q': q_value,
- 'model': model,
- 'id': '',
- 'mid': mid_value
+ referer = cls.create_referer(q=messages[-1]["content"], mid=mid_value, model=model)
+
+ data = {
+ "messages": messages,
+ "model": model,
+ "search": web_search,
+ "stream": True,
+ "temperature": temperature
}
headers = {
@@ -111,52 +99,34 @@ class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
'sec-ch-ua-mobile': '?0',
'sec-ch-ua-platform': '"Linux"'
}
-
- try:
- timeout = aiohttp.ClientTimeout(total=None)
- async with ClientSession(timeout=timeout) as session:
- async with session.get(url, headers=headers, params=params, proxy=proxy) as response:
- if response.status != 200:
- yield f"Request ended with status code {response.status}"
- return
-
- assistant_text = ''
- sources = []
-
- async for line in response.content:
- decoded_line = line.decode('utf-8').strip()
- if not decoded_line.startswith('data: '):
- continue
- data = decoded_line[6:]
- if data in ('[DONE]', '{"done": ""}'):
- break
- try:
- json_data = json.loads(data)
- except json.JSONDecodeError:
- continue
-
- if 'url' in json_data and 'title' in json_data:
- if websearch:
- sources.append({'title': json_data['title'], 'url': json_data['url']})
-
- elif 'choices' in json_data:
- for choice in json_data['choices']:
- delta = choice.get('delta', {})
- content = delta.get('content', '')
- role = delta.get('role', '')
- if role == 'assistant':
- continue
- assistant_text += content
-
- if websearch and sources:
- sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)])
- assistant_text += f"\n\n**Source:**\n{sources_text}"
-
- yield assistant_text
-
- except asyncio.CancelledError:
- yield "The request was cancelled."
- except aiohttp.ClientError as e:
- yield f"An error occurred during the request: {e}"
- except Exception as e:
- yield f"An unexpected error occurred: {e}"
+ async with ClientSession() as session:
+ async with session.post(cls.api_endpoint, headers=headers, json=data, proxy=proxy) as response:
+ await raise_for_status(response)
+
+ sources = []
+ async for line in response.content:
+ decoded_line = line.decode('utf-8').strip()
+ if not decoded_line.startswith('data: '):
+ continue
+ data = decoded_line[6:]
+ if data in ('[DONE]', '{"done": ""}'):
+ break
+ try:
+ json_data = json.loads(data)
+ except json.JSONDecodeError:
+ continue
+
+ if 'url' in json_data and 'title' in json_data:
+ if web_search:
+ sources.append({'title': json_data['title'], 'url': json_data['url']})
+
+ elif 'choices' in json_data:
+ for choice in json_data['choices']:
+ delta = choice.get('delta', {})
+ content = delta.get('content', '')
+ if content:
+ yield content
+
+ if web_search and sources:
+ sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)])
+ yield f"\n\n**Source:**\n{sources_text}" \ No newline at end of file