summaryrefslogblamecommitdiffstats
path: root/g4f/Provider/Cloudflare.py
blob: 825c50275667e4bbeda7726e9b60645f24948ed4 (plain) (tree)
1
2
3
4
5
6
7
8
9




                                  
 


                                                                                                      

                                                             
                           

                                                                       
                                                                  



                                   
                                                    
                            

                                                      

                                                         
                                                                

                                                             
                                                         
                                                      
     
                      

                














                                                                                     






                                     
                               

                                


                                    

                                                                                      
                
                                 

                           
                                     
                          
         







                                                                                     



                                                   

                                                                                             
                                                         
                                         
                                   
from __future__ import annotations

import asyncio
import json
import uuid

from ..typing import AsyncResult, Messages, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies

class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
    """Provider for the Cloudflare AI playground (https://playground.ai.cloudflare.com).

    Request arguments (including cookies) are obtained via
    ``get_args_from_nodriver`` and cached on the class in ``_args``. They are
    reused by later calls and discarded when a request fails its status check,
    so the next call acquires fresh ones.
    """

    label = "Cloudflare AI"
    url = "https://playground.ai.cloudflare.com"
    api_endpoint = "https://playground.ai.cloudflare.com/api/inference"
    models_url = "https://playground.ai.cloudflare.com/api/models"
    working = True
    supports_stream = True
    supports_system_message = True
    supports_message_history = True
    default_model = "@cf/meta/llama-3.1-8b-instruct"
    # Short alias -> full Cloudflare model id. Each alias maps to exactly one id:
    # dict keys are unique, and a repeated key in a literal silently keeps only
    # its last value.
    model_aliases = {
        "llama-2-7b": "@cf/meta/llama-2-7b-chat-int8",
        "llama-3-8b": "@hf/meta-llama/meta-llama-3-8b-instruct",
        "llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-fp8",
        "llama-3.2-1b": "@cf/meta/llama-3.2-1b-instruct",
        "qwen-1.5-7b": "@cf/qwen/qwen1.5-7b-chat-awq",
    }
    # Cached keyword arguments for Session / StreamSession, shared by all calls
    # (class attribute). None means "not acquired yet" or "invalidated after a
    # failed request".
    _args: dict | None = None

    @classmethod
    def get_models(cls) -> list[str]:
        """Return the model ids listed by the playground's models endpoint.

        The list is fetched once and cached in ``cls.models``. If no request
        arguments are cached yet, they are acquired synchronously by running
        ``get_args_from_nodriver`` through ``asyncio.run``.

        Raises:
            Whatever ``raise_for_status`` raises for a failed request; the
            cached ``_args`` are discarded before the error propagates.
        """
        if not cls.models:
            if cls._args is None:
                # check_nested=True: presumably ensures asyncio.run() below can
                # be used even when called from inside a running event loop.
                get_running_loop(check_nested=True)
                pending = get_args_from_nodriver(cls.url, cookies={
                    # Random placeholder value for the __cf_bm cookie.
                    '__cf_bm': uuid.uuid4().hex,
                })
                cls._args = asyncio.run(pending)
            # Keep a local reference so a concurrent reset of cls._args to None
            # cannot break the cookie update below.
            args = cls._args
            with Session(**args) as session:
                response = session.get(cls.models_url)
                try:
                    raise_for_status(response)
                except Exception:
                    # The cached args may be stale; force re-acquisition next time.
                    cls._args = None
                    raise
                json_data = response.json()
                # Tolerate a missing/null "models" key and entries without a name.
                cls.models = [
                    model["name"]
                    for model in json_data.get("models") or []
                    if model.get("name")
                ]
                args["cookies"] = merge_cookies(args["cookies"], response)
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        max_tokens: int = 2048,
        cookies: Cookies = None,
        timeout: int = 300,
        **kwargs
    ) -> AsyncResult:
        """Stream a chat completion from the playground's inference endpoint.

        Args:
            model: Resolved through ``cls.get_model`` before being sent.
            messages: Chat history, sent unchanged as the ``messages`` field.
            proxy: Passed to ``get_args_from_nodriver``; only used when no
                request arguments are cached yet.
            max_tokens: Sent as the ``max_tokens`` field.
            cookies: Passed to ``get_args_from_nodriver``; ignored once
                request arguments are cached.
            timeout: Passed to ``get_args_from_nodriver``; ignored once
                request arguments are cached.
            **kwargs: Accepted for interface compatibility; unused.

        Yields:
            Truthy ``response`` values from the JSON payloads of ``data: ``
            lines, skipping ``</s>`` and stopping at ``data: [DONE]``.

        Raises:
            Whatever ``raise_for_status`` raises for a failed request; the
            cached ``_args`` are discarded before the error propagates.
        """
        model = cls.get_model(model)
        if cls._args is None:
            cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
        # Local reference: a concurrent failing call may reset cls._args to None.
        args = cls._args
        data = {
            "messages": messages,
            "lora": None,
            "model": model,
            "max_tokens": max_tokens,
            "stream": True
        }
        async with StreamSession(**args) as session:
            async with session.post(
                cls.api_endpoint,
                json=data,
            ) as response:
                try:
                    await raise_for_status(response)
                except Exception:
                    # The cached args may be stale; force re-acquisition next call.
                    cls._args = None
                    raise
                args["cookies"] = merge_cookies(args["cookies"], response)
                async for line in response.iter_lines():
                    if not line.startswith(b'data: '):
                        continue
                    if line == b'data: [DONE]':
                        break
                    # Guard only the parsing step: a malformed payload is skipped,
                    # but errors thrown into this generator at `yield` are not
                    # swallowed.
                    try:
                        content = json.loads(line[6:].decode())
                    except (UnicodeDecodeError, json.JSONDecodeError):
                        continue
                    if not isinstance(content, dict):
                        continue
                    chunk = content.get("response")
                    if chunk and chunk != '</s>':
                        yield chunk