summaryrefslogtreecommitdiffstats
path: root/g4f/local/__init__.py
blob: c9d3d74a04af936353fd7ce2691d9ae140a58133 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from ..typing import Union, Iterator, Messages
from ..stubs  import ChatCompletion, ChatCompletionChunk
from ._engine import LocalProvider
from ._models import models
from ..client import iter_response, filter_none, IterResponse

class LocalClient:
    """Client object that exposes a ``chat`` namespace bound to itself."""

    def __init__(self, **kwargs) -> None:
        """Create the client and attach a :class:`Chat` helper as ``self.chat``.

        Any keyword arguments are accepted but not read by this constructor.
        """
        self.chat: Chat = Chat(self)

    @staticmethod
    def list_models():
        """Return the keys of the module-level ``models`` mapping as a list."""
        return [*models.keys()]
        
class Completions:
    """Completion endpoint that delegates generation to ``LocalProvider``."""

    def __init__(self, client: LocalClient):
        """Remember the owning client on ``self.client``."""
        self.client: LocalClient = client

    def create(
        self,
        messages: Messages,
        model: str,
        stream: bool = False,
        response_format: dict = None,
        max_tokens: int = None,
        stop: Union[list[str], str] = None,
        **kwargs
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        """Run a completion through ``LocalProvider`` and wrap its output.

        Args:
            messages: Conversation passed to the provider as-is.
            model: Model identifier passed to the provider as-is.
            stream: When true, the wrapped iterator itself is returned;
                otherwise only its first item is returned.
            response_format: Forwarded unchanged to ``iter_response``.
            max_tokens: Forwarded to both the provider and ``iter_response``.
            stop: A single stop string or a list of them; a lone string is
                wrapped in a one-element list before being forwarded.
            **kwargs: Extra keyword arguments handed to the provider.

        Returns:
            The iterator produced by ``iter_response`` when ``stream`` is
            true, else the first item pulled from that iterator.
        """
        # Normalise a bare string into the list form used downstream.
        if isinstance(stop, str):
            stop = [stop]

        # NOTE(review): filter_none presumably drops None-valued entries so
        # unset options are not passed to the provider — confirm in ..client.
        optional_args = filter_none(
            max_tokens=max_tokens,
            stop=stop,
        )
        raw_output = LocalProvider.create_completion(
            model, messages, stream,
            **optional_args,
            **kwargs
        )

        wrapped = iter_response(raw_output, stream, response_format, max_tokens, stop)
        if stream:
            return wrapped
        # Non-streaming callers get just the first item of the iterator.
        return next(wrapped)

class Chat:
    """Namespace object holding the ``completions`` endpoint for a client."""

    # Endpoint instance created in __init__ for the given client.
    completions: Completions

    def __init__(self, client: LocalClient):
        """Build a :class:`Completions` for ``client`` and store it."""
        endpoint = Completions(client)
        self.completions = endpoint