From 10a38324582d34a2514ae0be64bc3a03774bfd77 Mon Sep 17 00:00:00 2001 From: abc <98614666+xtekky@users.noreply.github.com> Date: Fri, 24 Nov 2023 14:16:00 +0000 Subject: ~ fix DeepInfra --- g4f/Provider/DeepInfra.py | 112 +++++++++++++++++++++--------------------- g4f/Provider/base_provider.py | 3 +- 2 files changed, 56 insertions(+), 59 deletions(-) (limited to 'g4f') diff --git a/g4f/Provider/DeepInfra.py b/g4f/Provider/DeepInfra.py index da6333ad..754439c1 100644 --- a/g4f/Provider/DeepInfra.py +++ b/g4f/Provider/DeepInfra.py @@ -1,64 +1,62 @@ from __future__ import annotations -import json -from aiohttp import ClientSession +import requests, json +from ..typing import CreateResult, Messages +from .base_provider import BaseProvider -from ..typing import AsyncResult, Messages -from .base_provider import AsyncGeneratorProvider +class DeepInfra(BaseProvider): + url: str = "https://deepinfra.com" + working: bool = True + supports_stream: bool = True + supports_message_history: bool = True - -class DeepInfra(AsyncGeneratorProvider): - url = "https://deepinfra.com" - supports_message_history = True - working = True - - @classmethod - async def create_async_generator( - cls, - model: str, - messages: Messages, - proxy: str = None, - **kwargs - ) -> AsyncResult: - if not model: - model = "meta-llama/Llama-2-70b-chat-hf" + @staticmethod + def create_completion(model: str, + messages: Messages, + stream: bool, + **kwargs) -> CreateResult: + headers = { - "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/118.0", - "Accept": "text/event-stream", - "Accept-Language": "de,en-US;q=0.7,en;q=0.3", - "Accept-Encoding": "gzip, deflate, br", - "Referer": f"{cls.url}/", - "Content-Type": "application/json", - "X-Deepinfra-Source": "web-page", - "Origin": cls.url, - "Connection": "keep-alive", - "Sec-Fetch-Dest": "empty", - "Sec-Fetch-Mode": "cors", - "Sec-Fetch-Site": "same-site", - "Pragma": "no-cache", - "Cache-Control": "no-cache", + 'Accept-Language': 'en,fr-FR;q=0.9,fr;q=0.8,es-ES;q=0.7,es;q=0.6,en-US;q=0.5,am;q=0.4,de;q=0.3', + 'Cache-Control': 'no-cache', + 'Connection': 'keep-alive', + 'Content-Type': 'application/json', + 'Origin': 'https://deepinfra.com', + 'Pragma': 'no-cache', + 'Referer': 'https://deepinfra.com/', + 'Sec-Fetch-Dest': 'empty', + 'Sec-Fetch-Mode': 'cors', + 'Sec-Fetch-Site': 'same-site', + 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36', + 'X-Deepinfra-Source': 'web-embed', + 'accept': 'text/event-stream', + 'sec-ch-ua': '"Google Chrome";v="119", "Chromium";v="119", "Not?A_Brand";v="24"', + 'sec-ch-ua-mobile': '?0', + 'sec-ch-ua-platform': '"macOS"', } - async with ClientSession(headers=headers) as session: - data = { - "model": model, - "messages": messages, - "stream": True, - } - async with session.post( - "https://api.deepinfra.com/v1/openai/chat/completions", - json=data, - proxy=proxy - ) as response: - response.raise_for_status() - first = True - async for line in response.content: - if line.startswith(b"data: [DONE]"): - break - elif line.startswith(b"data: "): - chunk = json.loads(line[6:])["choices"][0]["delta"].get("content") + + json_data = json.dumps({ + 'model' : 'meta-llama/Llama-2-70b-chat-hf', + 'messages': messages, + 'stream' : True}, separators=(',', ':')) + + response = requests.post('https://api.deepinfra.com/v1/openai/chat/completions', + headers=headers, data=json_data, stream=True) + + response.raise_for_status() + first = True + + for line in response.iter_content(chunk_size=1024): + if line.startswith(b"data: [DONE]"): + break + + elif line.startswith(b"data: "): + chunk = json.loads(line[6:])["choices"][0]["delta"].get("content") + + if chunk: + if first: + chunk = chunk.lstrip() if chunk: - if first: - chunk = chunk.lstrip() - if chunk: - first = False - yield chunk \ No newline at end of file + first = False + + yield (chunk) \ No newline at end of file diff --git a/g4f/Provider/base_provider.py b/g4f/Provider/base_provider.py index 1b0771ff..f3959634 100644 --- a/g4f/Provider/base_provider.py +++ b/g4f/Provider/base_provider.py @@ -8,7 +8,6 @@ from inspect import signature, Parameter from .helper import get_event_loop, get_cookies, format_prompt from ..typing import CreateResult, AsyncResult, Messages - if sys.version_info < (3, 10): NoneType = type(None) else: @@ -76,7 +75,7 @@ class BaseProvider(ABC): annotation = "None" return str(annotation) - args = ""; + args = "" for name, param in sig.parameters.items(): if name in ("self", "kwargs"): continue -- cgit v1.2.3