diff options
author | H Lohaus <hlohaus@users.noreply.github.com> | 2024-04-11 03:36:12 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-11 03:36:12 +0200 |
commit | 4271fb9870fff99be2895216a839a9484aa5dbf3 (patch) | |
tree | 78a714c01fac85ed34856e9c295f9028efed576b /g4f/Provider/unfinished | |
parent | Merge pull request #1817 from hlohaus/bugfix (diff) | |
parent | Add ReplicateImage to provider list (diff) | |
download | gpt4free-0.2.9.4.tar gpt4free-0.2.9.4.tar.gz gpt4free-0.2.9.4.tar.bz2 gpt4free-0.2.9.4.tar.lz gpt4free-0.2.9.4.tar.xz gpt4free-0.2.9.4.tar.zst gpt4free-0.2.9.4.zip |
Diffstat (limited to '')
-rw-r--r-- | g4f/Provider/unfinished/Replicate.py | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/g4f/Provider/unfinished/Replicate.py b/g4f/Provider/unfinished/Replicate.py new file mode 100644 index 00000000..aaaf31b3 --- /dev/null +++ b/g4f/Provider/unfinished/Replicate.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +import asyncio + +from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin +from ..helper import format_prompt, filter_none +from ...typing import AsyncResult, Messages +from ...requests import StreamSession, raise_for_status +from ...image import ImageResponse +from ...errors import ResponseError, MissingAuthError + +class Replicate(AsyncGeneratorProvider, ProviderModelMixin): + url = "https://replicate.com" + working = True + default_model = "mistralai/mixtral-8x7b-instruct-v0.1" + api_base = "https://api.replicate.com/v1/models/" + + @classmethod + async def create_async_generator( + cls, + model: str, + messages: Messages, + api_key: str = None, + proxy: str = None, + timeout: int = 180, + system_prompt: str = None, + max_new_tokens: int = None, + temperature: float = None, + top_p: float = None, + top_k: float = None, + stop: list = None, + extra_data: dict = {}, + headers: dict = {}, + **kwargs + ) -> AsyncResult: + model = cls.get_model(model) + if api_key is None: + raise MissingAuthError("api_key is missing") + headers["Authorization"] = f"Bearer {api_key}" + async with StreamSession( + proxies={"all": proxy}, + headers=headers, + timeout=timeout + ) as session: + data = { + "stream": True, + "input": { + "prompt": format_prompt(messages), + **filter_none( + system_prompt=system_prompt, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stop_sequences=",".join(stop) if stop else None + ), + **extra_data + }, + } + url = f"{cls.api_base.rstrip('/')}/{model}/predictions" + async with session.post(url, json=data) as response: + await raise_for_status(response) + result = await response.json() + if "id" not in result: + raise ResponseError(f"Invalid response: {result}") + async with session.get(result["urls"]["stream"], headers={"Accept": "text/event-stream"}) as response: + await raise_for_status(response) + event = None + async for line in response.iter_lines(): + if line.startswith(b"event: "): + event = line[7:] + elif event == b"output": + if line.startswith(b"data: "): + yield line[6:].decode() + elif not line.startswith(b"id: "): + continue#yield "+"+line.decode() + elif event == b"done": + break
\ No newline at end of file |