diff options
author | Heiner Lohaus <heiner@lohaus.eu> | 2023-10-07 10:17:43 +0200 |
---|---|---|
committer | Heiner Lohaus <heiner@lohaus.eu> | 2023-10-07 10:17:43 +0200 |
commit | f7bb30036e5e5482611627a040f54254ac162f72 (patch) | |
tree | 32afa79238a6acc49d8890d26bbdfa38829ce316 /g4f/requests.py | |
parent | Add GptGod Provider (diff) | |
download | gpt4free-f7bb30036e5e5482611627a040f54254ac162f72.tar gpt4free-f7bb30036e5e5482611627a040f54254ac162f72.tar.gz gpt4free-f7bb30036e5e5482611627a040f54254ac162f72.tar.bz2 gpt4free-f7bb30036e5e5482611627a040f54254ac162f72.tar.lz gpt4free-f7bb30036e5e5482611627a040f54254ac162f72.tar.xz gpt4free-f7bb30036e5e5482611627a040f54254ac162f72.tar.zst gpt4free-f7bb30036e5e5482611627a040f54254ac162f72.zip |
Diffstat (limited to '')
-rw-r--r-- | g4f/requests.py | 81 |
1 files changed, 40 insertions, 41 deletions
diff --git a/g4f/requests.py b/g4f/requests.py index c51d9804..3a4a3f54 100644 --- a/g4f/requests.py +++ b/g4f/requests.py @@ -1,47 +1,44 @@ from __future__ import annotations -import warnings, json, asyncio - +import warnings +import json +import asyncio from functools import partialmethod from asyncio import Future, Queue from typing import AsyncGenerator from curl_cffi.requests import AsyncSession, Response - import curl_cffi -is_newer_0_5_8 = hasattr(AsyncSession, "_set_cookies") or hasattr(curl_cffi.requests.Cookies, "get_cookies_for_curl") -is_newer_0_5_9 = hasattr(curl_cffi.AsyncCurl, "remove_handle") -is_newer_0_5_10 = hasattr(AsyncSession, "release_curl") +is_newer_0_5_8: bool = hasattr(AsyncSession, "_set_cookies") or hasattr(curl_cffi.requests.Cookies, "get_cookies_for_curl") +is_newer_0_5_9: bool = hasattr(curl_cffi.AsyncCurl, "remove_handle") +is_newer_0_5_10: bool = hasattr(AsyncSession, "release_curl") + class StreamResponse: - def __init__(self, inner: Response, queue: Queue): - self.inner = inner - self.queue = queue + def __init__(self, inner: Response, queue: Queue[bytes]) -> None: + self.inner: Response = inner + self.queue: Queue[bytes] = queue self.request = inner.request - self.status_code = inner.status_code - self.reason = inner.reason - self.ok = inner.ok + self.status_code: int = inner.status_code + self.reason: str = inner.reason + self.ok: bool = inner.ok self.headers = inner.headers self.cookies = inner.cookies async def text(self) -> str: - content = await self.read() + content: bytes = await self.read() return content.decode() - def raise_for_status(self): + def raise_for_status(self) -> None: if not self.ok: raise RuntimeError(f"HTTP Error {self.status_code}: {self.reason}") - async def json(self, **kwargs): + async def json(self, **kwargs) -> dict: return json.loads(await self.read(), **kwargs) - - async def iter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None) -> AsyncGenerator[bytes]: - """ - Copied from: https://requests.readthedocs.io/en/latest/_modules/requests/models/ - which is under the License: Apache 2.0 - """ - pending = None + + async def iter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None) -> AsyncGenerator[bytes, None]: + pending: bytes = None async for chunk in self.iter_content( chunk_size=chunk_size, decode_unicode=decode_unicode @@ -63,7 +60,7 @@ class StreamResponse: if pending is not None: yield pending - async def iter_content(self, chunk_size=None, decode_unicode=False) -> As: + async def iter_content(self, chunk_size=None, decode_unicode=False) -> AsyncGenerator[bytes, None]: if chunk_size: warnings.warn("chunk_size is ignored, there is no way to tell curl that.") if decode_unicode: @@ -77,22 +74,23 @@ class StreamResponse: async def read(self) -> bytes: return b"".join([chunk async for chunk in self.iter_content()]) + class StreamRequest: - def __init__(self, session: AsyncSession, method: str, url: str, **kwargs): - self.session = session - self.loop = session.loop if session.loop else asyncio.get_running_loop() - self.queue = Queue() - self.method = method - self.url = url - self.options = kwargs - self.handle = None - - def _on_content(self, data): + def __init__(self, session: AsyncSession, method: str, url: str, **kwargs) -> None: + self.session: AsyncSession = session + self.loop: asyncio.AbstractEventLoop = session.loop if session.loop else asyncio.get_running_loop() + self.queue: Queue[bytes] = Queue() + self.method: str = method + self.url: str = url + self.options: dict = kwargs + self.handle: curl_cffi.AsyncCurl = None + + def _on_content(self, data: bytes) -> None: if not self.enter.done(): self.enter.set_result(None) self.queue.put_nowait(data) - def _on_done(self, task: Future): + def _on_done(self, task: Future) -> None: if not self.enter.done(): self.enter.set_result(None) self.queue.put_nowait(None) @@ -102,8 +100,8 @@ class StreamRequest: async def fetch(self) -> StreamResponse: if self.handle: raise RuntimeError("Request already started") - self.curl = await self.session.pop_curl() - self.enter = self.loop.create_future() + self.curl: curl_cffi.AsyncCurl = await self.session.pop_curl() + self.enter: asyncio.Future = self.loop.create_future() if is_newer_0_5_10: request, _, header_buffer, _, _ = self.session._set_curl_options( self.curl, @@ -121,7 +119,7 @@ class StreamRequest: **self.options ) if is_newer_0_5_9: - self.handle = self.session.acurl.add_handle(self.curl) + self.handle = self.session.acurl.add_handle(self.curl) else: await self.session.acurl.add_handle(self.curl, False) self.handle = self.session.acurl._curl2future[self.curl] @@ -140,14 +138,14 @@ class StreamRequest: response, self.queue ) - + async def __aenter__(self) -> StreamResponse: return await self.fetch() - async def __aexit__(self, *args): + async def __aexit__(self, *args) -> None: self.release_curl() - def release_curl(self): + def release_curl(self) -> None: if is_newer_0_5_10: self.session.release_curl(self.curl) return @@ -162,6 +160,7 @@ class StreamRequest: self.session.push_curl(self.curl) self.curl = None + class StreamSession(AsyncSession): def request( self, @@ -170,7 +169,7 @@ class StreamSession(AsyncSession): **kwargs ) -> StreamRequest: return StreamRequest(self, method, url, **kwargs) - + head = partialmethod(request, "HEAD") get = partialmethod(request, "GET") post = partialmethod(request, "POST") |