summaryrefslogtreecommitdiffstats
path: root/g4f/gui/server/api.py
blob: 00eb718241b7d8c1208c304401689af0a38fa57b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
from __future__ import annotations

import logging
import os
import asyncio
from typing import Iterator
from flask import send_from_directory
from inspect import signature

from g4f import version, models
from g4f import get_last_provider, ChatCompletion
from g4f.errors import VersionNotFoundError
from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir
from g4f.Provider import ProviderType, __providers__, __map__
from g4f.providers.base_provider import ProviderModelMixin
from g4f.providers.response import BaseConversation, FinishReason
from g4f.client.service import convert_to_provider
from g4f import debug

# Module-level logger; used below to record errors raised while streaming.
logger = logging.getLogger(__name__)
# In-memory conversation cache, keyed first by provider name and then by
# conversation id. Entries are written by Api._create_response_stream and
# read back by Api._prepare_conversation_kwargs.
conversations: dict[str, dict[str, BaseConversation]] = {}

class Api:
    """Backend helpers for the GUI server.

    The static methods list models, providers and version information.
    The instance methods serve stored images, build the keyword arguments
    for ``ChatCompletion.create`` and convert a completion into a stream
    of typed event dictionaries.
    """

    @staticmethod
    def get_models() -> list[str]:
        """Return the model list held in ``models._all_models``."""
        return models._all_models

    @staticmethod
    def get_provider_models(provider: str, api_key: str = None) -> list[dict]:
        """Describe the models offered by a single provider.

        Args:
            provider: Provider name, looked up in ``__map__``.
            api_key: Optional key; forwarded to ``get_models`` only when
                that method declares an ``api_key`` parameter.

        Returns:
            One dict per model with the keys ``model``, ``default``,
            ``vision`` and ``image``. An empty list is returned when the
            provider is unknown or is not a ``ProviderModelMixin``.
        """
        if provider not in __map__:
            return []
        provider_cls: ProviderType = __map__[provider]
        if not issubclass(provider_cls, ProviderModelMixin):
            return []

        accepts_key = "api_key" in signature(provider_cls.get_models).parameters
        if api_key is not None and accepts_key:
            provider_models = provider_cls.get_models(api_key=api_key)
        else:
            provider_models = provider_cls.get_models()

        # Read the optional attributes once, after get_models() has run,
        # and treat a missing/None value the same as "no such models".
        default_vision = getattr(provider_cls, "default_vision_model", None)
        vision_models = getattr(provider_cls, "vision_models", None) or []
        image_models = getattr(provider_cls, "image_models", None)
        return [
            {
                "model": model,
                "default": model == provider_cls.default_model,
                "vision": default_vision == model or model in vision_models,
                "image": False if image_models is None else model in image_models,
            }
            for model in provider_models
        ]

    @staticmethod
    def get_image_models() -> list[dict]:
        """List image-generation and vision capabilities across providers.

        Providers exposing ``image_models`` contribute one entry per image
        model, reported under their parent provider when a ``parent``
        attribute is present. Providers that only expose
        ``default_vision_model`` contribute a single entry with
        ``image_model`` set to ``None``. Each provider name is reported
        at most once.

        Returns:
            Dicts with the keys ``provider``, ``url``, ``label``,
            ``image_model`` and ``vision_model``.
        """
        image_models = []
        seen: set[str] = set()
        for provider in __providers__:
            if hasattr(provider, "image_models"):
                if hasattr(provider, "get_models"):
                    # Result is discarded; presumably this refreshes the
                    # provider's model attributes - TODO confirm.
                    provider.get_models()
                parent = provider
                if hasattr(provider, "parent"):
                    parent = __map__[provider.parent]
                if parent.__name__ not in seen:
                    # image_models may be None (get_provider_models guards
                    # against the same case), so fall back to an empty tuple.
                    for model in provider.image_models or ():
                        image_models.append({
                            "provider": parent.__name__,
                            "url": parent.url,
                            "label": getattr(parent, "label", None),
                            "image_model": model,
                            "vision_model": getattr(parent, "default_vision_model", None),
                        })
                    seen.add(parent.__name__)
            elif hasattr(provider, "default_vision_model") and provider.__name__ not in seen:
                image_models.append({
                    "provider": provider.__name__,
                    "url": provider.url,
                    "label": getattr(provider, "label", None),
                    "image_model": None,
                    "vision_model": provider.default_vision_model,
                })
                seen.add(provider.__name__)
        return image_models

    @staticmethod
    def get_providers() -> dict[str, str]:
        """Map each working provider's class name to a display label.

        The label is the provider's ``label`` (or its class name when the
        label is missing or empty) followed by capability suffixes for
        image generation, image upload, WebDriver use and authentication.
        """
        return {
            provider.__name__: (getattr(provider, "label", None) or provider.__name__)
            + (" (Image Generation)" if getattr(provider, "image_models", None) else "")
            + (" (Image Upload)" if getattr(provider, "default_vision_model", None) else "")
            + (" (WebDriver)" if "webdriver" in provider.get_parameters() else "")
            + (" (Auth)" if provider.needs_auth else "")
            for provider in __providers__
            if provider.working
        }

    @staticmethod
    def get_version():
        """Return the current and latest version numbers.

        ``version`` is ``None`` when reading the current version raises
        ``VersionNotFoundError``.
        """
        try:
            current_version = version.utils.current_version
        except VersionNotFoundError:
            current_version = None
        return {
            "version": current_version,
            "latest_version": version.utils.latest_version,
        }

    def serve_images(self, name):
        """Send the file ``name`` from the images directory.

        The directory is created first via ``ensure_images_dir``.
        """
        ensure_images_dir()
        return send_from_directory(os.path.abspath(images_dir), name)

    def _prepare_conversation_kwargs(self, json_data: dict, kwargs: dict):
        """Build the keyword arguments for ``ChatCompletion.create``.

        Args:
            json_data: Request payload. ``messages`` is required; ``model``,
                ``provider``, ``api_key``, ``web_search``, ``auto_continue``
                and ``conversation_id`` are optional.
            kwargs: Extra arguments to forward. This dict is updated in
                place with the options derived from ``json_data``.

        Returns:
            A new dict with the fixed streaming options followed by the
            entries of ``kwargs``, which therefore take precedence.

        Raises:
            KeyError: If ``json_data`` has no ``messages`` entry.
        """
        model = json_data.get('model') or models.default
        provider = json_data.get('provider')
        messages = json_data['messages']
        api_key = json_data.get("api_key")
        if api_key is not None:
            kwargs["api_key"] = api_key

        do_web_search = json_data.get('web_search')
        if do_web_search and provider:
            provider_handler = convert_to_provider(provider)
            if hasattr(provider_handler, "get_parameters"):
                if "web_search" in provider_handler.get_parameters():
                    # The provider searches on its own; skip the fallback.
                    kwargs['web_search'] = True
                    do_web_search = False
        if do_web_search:
            # Fallback: rewrite the last message through get_search_message.
            from .internet import get_search_message
            messages[-1]["content"] = get_search_message(messages[-1]["content"])
        if json_data.get("auto_continue"):
            kwargs['auto_continue'] = True

        # Resume a stored conversation when one exists for this provider/id.
        conversation_id = json_data.get("conversation_id")
        if conversation_id and provider:
            if provider in conversations and conversation_id in conversations[provider]:
                kwargs["conversation"] = conversations[provider][conversation_id]

        return {
            "model": model,
            "provider": provider,
            "messages": messages,
            "stream": True,
            "ignore_stream": True,
            "return_conversation": True,
            **kwargs
        }

    @staticmethod
    def _install_log_capture() -> None:
        """Reset ``debug.logs`` and make sure log lines are captured into it.

        The capturing wrapper is installed over ``debug.log_handler`` only
        once. It is tagged with a marker attribute so later calls recognise
        it and do not wrap it again; stacking wrappers would append every
        log line once per earlier request.
        """
        debug.logs = []
        downstream = debug.log_handler
        if getattr(downstream, "_gui_log_capture", False):
            return

        def log_handler(text: str):
            debug.logs.append(text)
            downstream(text)

        log_handler._gui_log_capture = True
        debug.log_handler = log_handler

    def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator:
        """Run a completion and yield its output as typed event dicts.

        Args:
            kwargs: Arguments passed to ``ChatCompletion.create``.
            conversation_id: Id under which a returned conversation object
                is stored in the module-level ``conversations`` cache.
            provider: Provider name used as the outer cache key; when
                falsy, conversation objects are not stored.
            download_images: When true, image chunks are passed through
                ``copy_images`` before being emitted.

        Yields:
            Dicts built by ``_format_json`` with one of the types
            ``provider``, ``content``, ``conversation``, ``message``,
            ``preview``, ``log`` or ``error``.
        """
        if debug.logging:
            self._install_log_capture()
        try:
            result = ChatCompletion.create(**kwargs)
            if isinstance(result, ImageResponse):
                yield self._format_json("provider", get_last_provider(True))
                yield self._format_json("content", str(result))
            else:
                first = True
                for chunk in result:
                    if first:
                        # Announce the provider once, before the first chunk.
                        first = False
                        yield self._format_json("provider", get_last_provider(True))
                    if isinstance(chunk, BaseConversation):
                        if provider:
                            if provider not in conversations:
                                conversations[provider] = {}
                            conversations[provider][conversation_id] = chunk
                            yield self._format_json("conversation", conversation_id)
                    elif isinstance(chunk, Exception):
                        # Not inside an except block, so hand the exception
                        # over explicitly to get its traceback logged.
                        logger.error(chunk, exc_info=chunk)
                        yield self._format_json("message", get_error_message(chunk))
                    elif isinstance(chunk, ImagePreview):
                        yield self._format_json("preview", chunk.to_string())
                    elif isinstance(chunk, ImageResponse):
                        images = chunk
                        if download_images:
                            images = asyncio.run(copy_images(chunk.get_list(), chunk.options.get("cookies")))
                            images = ImageResponse(images, chunk.alt)
                        yield self._format_json("content", str(images))
                    elif not isinstance(chunk, FinishReason):
                        yield self._format_json("content", str(chunk))
                    if debug.logs:
                        # Flush log lines collected while producing this chunk.
                        for log in debug.logs:
                            yield self._format_json("log", str(log))
                        debug.logs = []
        except Exception as e:
            logger.exception(e)
            yield self._format_json('error', get_error_message(e))

    def _format_json(self, response_type: str, content):
        """Wrap ``content`` in an event dict.

        The result has a ``type`` key holding ``response_type`` and a second
        key, named after ``response_type``, holding ``content``.
        """
        return {
            'type': response_type,
            response_type: content
        }

def get_error_message(exception: Exception) -> str:
    """Format an exception as ``"<ExceptionType>: <text>"`` for the client.

    When ``get_last_provider()`` returns a provider, its ``__name__`` is
    prepended as ``"<Provider>: "``.
    """
    description = f"{type(exception).__name__}: {exception}"
    last_provider = get_last_provider()
    if last_provider is not None:
        return f"{last_provider.__name__}: {description}"
    return description