diff options
Diffstat (limited to 'g4f/requests.py')
-rw-r--r-- | g4f/requests.py | 31 |
1 files changed, 17 insertions, 14 deletions
diff --git a/g4f/requests.py b/g4f/requests.py index 467ea371..1a13dec9 100644 --- a/g4f/requests.py +++ b/g4f/requests.py @@ -11,12 +11,6 @@ from .webdriver import WebDriver, WebDriverSession, bypass_cloudflare, get_drive class StreamResponse: def __init__(self, inner: Response) -> None: self.inner: Response = inner - self.request = inner.request - self.status_code: int = inner.status_code - self.reason: str = inner.reason - self.ok: bool = inner.ok - self.headers = inner.headers - self.cookies = inner.cookies async def text(self) -> str: return await self.inner.atext() @@ -34,17 +28,26 @@ class StreamResponse: async def iter_content(self) -> AsyncGenerator[bytes, None]: async for chunk in self.inner.aiter_content(): yield chunk + + async def __aenter__(self): + inner: Response = await self.inner + self.inner = inner + self.request = inner.request + self.status_code: int = inner.status_code + self.reason: str = inner.reason + self.ok: bool = inner.ok + self.headers = inner.headers + self.cookies = inner.cookies + return self + + async def __aexit__(self, *args): + await self.inner.aclose() class StreamSession(AsyncSession): - @asynccontextmanager - async def request( + def request( self, method: str, url: str, **kwargs - ) -> AsyncGenerator[StreamResponse]: - response = await super().request(method, url, stream=True, **kwargs) - try: - yield StreamResponse(response) - finally: - await response.aclose() + ) -> StreamResponse: + return StreamResponse(super().request(method, url, stream=True, **kwargs)) head = partialmethod(request, "HEAD") get = partialmethod(request, "GET") |