summaryrefslogtreecommitdiffstats
path: root/g4f/requests.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/requests.py')
-rw-r--r--g4f/requests.py31
1 files changed, 17 insertions, 14 deletions
diff --git a/g4f/requests.py b/g4f/requests.py
index 467ea371..1a13dec9 100644
--- a/g4f/requests.py
+++ b/g4f/requests.py
@@ -11,12 +11,6 @@ from .webdriver import WebDriver, WebDriverSession, bypass_cloudflare, get_drive
class StreamResponse:
def __init__(self, inner: Response) -> None:
self.inner: Response = inner
- self.request = inner.request
- self.status_code: int = inner.status_code
- self.reason: str = inner.reason
- self.ok: bool = inner.ok
- self.headers = inner.headers
- self.cookies = inner.cookies
async def text(self) -> str:
return await self.inner.atext()
@@ -34,17 +28,26 @@ class StreamResponse:
async def iter_content(self) -> AsyncGenerator[bytes, None]:
async for chunk in self.inner.aiter_content():
yield chunk
+
+ async def __aenter__(self):
+ inner: Response = await self.inner
+ self.inner = inner
+ self.request = inner.request
+ self.status_code: int = inner.status_code
+ self.reason: str = inner.reason
+ self.ok: bool = inner.ok
+ self.headers = inner.headers
+ self.cookies = inner.cookies
+ return self
+
+ async def __aexit__(self, *args):
+ await self.inner.aclose()
class StreamSession(AsyncSession):
- @asynccontextmanager
- async def request(
+ def request(
self, method: str, url: str, **kwargs
- ) -> AsyncGenerator[StreamResponse]:
- response = await super().request(method, url, stream=True, **kwargs)
- try:
- yield StreamResponse(response)
- finally:
- await response.aclose()
+ ) -> StreamResponse:
+ return StreamResponse(super().request(method, url, stream=True, **kwargs))
head = partialmethod(request, "HEAD")
get = partialmethod(request, "GET")