summaryrefslogtreecommitdiffstats
path: root/g4f/api
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/api/__init__.py341
-rw-r--r--g4f/api/run.py2
2 files changed, 139 insertions, 204 deletions
diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py
index fec5606f..43bca2a5 100644
--- a/g4f/api/__init__.py
+++ b/g4f/api/__init__.py
@@ -1,163 +1,137 @@
-import typing
-from .. import BaseProvider
-import g4f; g4f.debug.logging = True
+from fastapi import FastAPI, Response, Request
+from typing import List, Union, Any, Dict, AnyStr
+from ._tokenizer import tokenize
+from .. import BaseProvider
+
import time
import json
import random
import string
-import logging
-
-from typing import Union
-from loguru import logger
-from waitress import serve
-from ._logging import hook_logging
-from ._tokenizer import tokenize
-from flask_cors import CORS
-from werkzeug.serving import WSGIRequestHandler
-from werkzeug.exceptions import default_exceptions
-from werkzeug.middleware.proxy_fix import ProxyFix
-
-from flask import (
- Flask,
- jsonify,
- make_response,
- request,
-)
+import uvicorn
+import nest_asyncio
+import g4f
class Api:
- __default_ip = '127.0.0.1'
- __default_port = 1337
-
def __init__(self, engine: g4f, debug: bool = True, sentry: bool = False,
- list_ignored_providers:typing.List[typing.Union[str, BaseProvider]]=None) -> None:
- self.engine = engine
- self.debug = debug
- self.sentry = sentry
- self.list_ignored_providers = list_ignored_providers
- self.log_level = logging.DEBUG if debug else logging.WARN
-
- hook_logging(level=self.log_level, format='[%(asctime)s] %(levelname)s in %(module)s: %(message)s')
- self.logger = logging.getLogger('waitress')
-
- self.app = Flask(__name__)
- self.app.wsgi_app = ProxyFix(self.app.wsgi_app, x_port=1)
- self.app.after_request(self.__after_request)
-
- def run(self, bind_str, threads=8):
- host, port = self.__parse_bind(bind_str)
-
- CORS(self.app, resources={r'/v1/*': {'supports_credentials': True, 'expose_headers': [
- 'Content-Type',
- 'Authorization',
- 'X-Requested-With',
- 'Accept',
- 'Origin',
- 'Access-Control-Request-Method',
- 'Access-Control-Request-Headers',
- 'Content-Disposition'], 'max_age': 600}})
-
- self.app.route('/v1/models', methods=['GET'])(self.models)
- self.app.route('/v1/models/<model_id>', methods=['GET'])(self.model_info)
-
- self.app.route('/v1/chat/completions', methods=['POST'])(self.chat_completions)
- self.app.route('/v1/completions', methods=['POST'])(self.completions)
-
- for ex in default_exceptions:
- self.app.register_error_handler(ex, self.__handle_error)
-
- if not self.debug:
- self.logger.warning(f'Serving on http://{host}:{port}')
-
- WSGIRequestHandler.protocol_version = 'HTTP/1.1'
- serve(self.app, host=host, port=port, ident=None, threads=threads)
-
- def __handle_error(self, e: Exception):
- self.logger.error(e)
-
- return make_response(jsonify({
- 'code': e.code,
- 'message': str(e.original_exception if self.debug and hasattr(e, 'original_exception') else e.name)}), 500)
-
- @staticmethod
- def __after_request(resp):
- resp.headers['X-Server'] = f'g4f/{g4f.version}'
-
- return resp
-
- def __parse_bind(self, bind_str):
- sections = bind_str.split(':', 2)
- if len(sections) < 2:
+ list_ignored_providers: List[Union[str, BaseProvider]] = None) -> None:
+ self.engine = engine
+ self.debug = debug
+ self.sentry = sentry
+ self.list_ignored_providers = list_ignored_providers
+
+ self.app = FastAPI()
+ nest_asyncio.apply()
+
+ JSONObject = Dict[AnyStr, Any]
+ JSONArray = List[Any]
+ JSONStructure = Union[JSONArray, JSONObject]
+
+ @self.app.get("/")
+ async def read_root():
+ return Response(content=json.dumps({"info": "g4f API"}, indent=4), media_type="application/json")
+
+ @self.app.get("/v1")
+ async def read_root_v1():
+ return Response(content=json.dumps({"info": "Go to /v1/chat/completions or /v1/models."}, indent=4), media_type="application/json")
+
+ @self.app.get("/v1/models")
+ async def models():
+ model_list = [{
+ 'id': model,
+ 'object': 'model',
+ 'created': 0,
+ 'owned_by': 'g4f'} for model in g4f.Model.__all__()]
+
+ return Response(content=json.dumps({
+ 'object': 'list',
+ 'data': model_list}, indent=4), media_type="application/json")
+
+ @self.app.get("/v1/models/{model_name}")
+ async def model_info(model_name: str):
try:
- port = int(sections[0])
- return self.__default_ip, port
- except ValueError:
- return sections[0], self.__default_port
-
- return sections[0], int(sections[1])
-
- async def home(self):
- return 'Hello world | https://127.0.0.1:1337/v1'
-
- async def chat_completions(self):
- model = request.json.get('model', 'gpt-3.5-turbo')
- stream = request.json.get('stream', False)
- messages = request.json.get('messages')
-
- logger.info(f'model: {model}, stream: {stream}, request: {messages[-1]["content"]}')
-
- config = None
- proxy = None
-
- try:
- config = json.load(open("config.json","r",encoding="utf-8"))
- proxy = config["proxy"]
-
- except Exception:
- pass
-
- if proxy != None:
- response = self.engine.ChatCompletion.create(model=model,
- stream=stream, messages=messages,
- ignored=self.list_ignored_providers,
- proxy=proxy)
- else:
- response = self.engine.ChatCompletion.create(model=model,
- stream=stream, messages=messages,
- ignored=self.list_ignored_providers)
-
- completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
- completion_timestamp = int(time.time())
-
- if not stream:
- prompt_tokens, _ = tokenize(''.join([message['content'] for message in messages]))
- completion_tokens, _ = tokenize(response)
-
- return {
- 'id': f'chatcmpl-{completion_id}',
- 'object': 'chat.completion',
- 'created': completion_timestamp,
- 'model': model,
- 'choices': [
- {
- 'index': 0,
- 'message': {
- 'role': 'assistant',
- 'content': response,
- },
- 'finish_reason': 'stop',
- }
- ],
- 'usage': {
- 'prompt_tokens': prompt_tokens,
- 'completion_tokens': completion_tokens,
- 'total_tokens': prompt_tokens + completion_tokens,
- },
+ model_info = (g4f.ModelUtils.convert[model_name])
+
+ return Response(content=json.dumps({
+ 'id': model_name,
+ 'object': 'model',
+ 'created': 0,
+ 'owned_by': model_info.base_provider
+ }, indent=4), media_type="application/json")
+ except:
+ return Response(content=json.dumps({"error": "The model does not exist."}, indent=4), media_type="application/json")
+
+ @self.app.post("/v1/chat/completions")
+ async def chat_completions(request: Request, item: JSONStructure = None):
+ item_data = {
+ 'model': 'gpt-3.5-turbo',
+ 'stream': False,
}
- def streaming():
+ item_data.update(item or {})
+ model = item_data.get('model')
+ stream = item_data.get('stream')
+ messages = item_data.get('messages')
+
try:
- for chunk in response:
- completion_data = {
+ response = g4f.ChatCompletion.create(model=model, stream=stream, messages=messages)
+ except:
+ return Response(content=json.dumps({"error": "An error occurred while generating the response."}, indent=4), media_type="application/json")
+
+ completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
+ completion_timestamp = int(time.time())
+
+ if not stream:
+ prompt_tokens, _ = tokenize(''.join([message['content'] for message in messages]))
+ completion_tokens, _ = tokenize(response)
+
+ json_data = {
+ 'id': f'chatcmpl-{completion_id}',
+ 'object': 'chat.completion',
+ 'created': completion_timestamp,
+ 'model': model,
+ 'choices': [
+ {
+ 'index': 0,
+ 'message': {
+ 'role': 'assistant',
+ 'content': response,
+ },
+ 'finish_reason': 'stop',
+ }
+ ],
+ 'usage': {
+ 'prompt_tokens': prompt_tokens,
+ 'completion_tokens': completion_tokens,
+ 'total_tokens': prompt_tokens + completion_tokens,
+ },
+ }
+
+ return Response(content=json.dumps(json_data, indent=4), media_type="application/json")
+
+ def streaming():
+ try:
+ for chunk in response:
+ completion_data = {
+ 'id': f'chatcmpl-{completion_id}',
+ 'object': 'chat.completion.chunk',
+ 'created': completion_timestamp,
+ 'model': model,
+ 'choices': [
+ {
+ 'index': 0,
+ 'delta': {
+ 'content': chunk,
+ },
+ 'finish_reason': None,
+ }
+ ],
+ }
+
+ content = json.dumps(completion_data, separators=(',', ':'))
+ yield f'data: {content}\n\n'
+ time.sleep(0.03)
+
+ end_completion_data = {
'id': f'chatcmpl-{completion_id}',
'object': 'chat.completion.chunk',
'created': completion_timestamp,
@@ -165,63 +139,24 @@ class Api:
'choices': [
{
'index': 0,
- 'delta': {
- 'content': chunk,
- },
- 'finish_reason': None,
+ 'delta': {},
+ 'finish_reason': 'stop',
}
],
}
- content = json.dumps(completion_data, separators=(',', ':'))
+ content = json.dumps(end_completion_data, separators=(',', ':'))
yield f'data: {content}\n\n'
- time.sleep(0.03)
- end_completion_data = {
- 'id': f'chatcmpl-{completion_id}',
- 'object': 'chat.completion.chunk',
- 'created': completion_timestamp,
- 'model': model,
- 'choices': [
- {
- 'index': 0,
- 'delta': {},
- 'finish_reason': 'stop',
- }
- ],
- }
-
- content = json.dumps(end_completion_data, separators=(',', ':'))
- yield f'data: {content}\n\n'
-
- logger.success(f'model: {model}, stream: {stream}')
-
- except GeneratorExit:
- pass
-
- return self.app.response_class(streaming(), mimetype='text/event-stream')
-
- async def completions(self):
- return 'not working yet', 500
-
- async def model_info(self, model_name):
- model_info = (g4f.ModelUtils.convert[model_name])
-
- return jsonify({
- 'id' : model_name,
- 'object' : 'model',
- 'created' : 0,
- 'owned_by' : model_info.base_provider
- })
-
- async def models(self):
- model_list = [{
- 'id' : model,
- 'object' : 'model',
- 'created' : 0,
- 'owned_by' : 'g4f'} for model in g4f.Model.__all__()]
-
- return jsonify({
- 'object': 'list',
- 'data': model_list})
- \ No newline at end of file
+ except GeneratorExit:
+ pass
+
+ return Response(content=json.dumps(streaming(), indent=4), media_type="application/json")
+
+ @self.app.post("/v1/completions")
+ async def completions():
+ return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json")
+
+ def run(self, ip):
+ split_ip = ip.split(":")
+ uvicorn.run(app=self.app, host=split_ip[0], port=int(split_ip[1]), use_colors=False)
diff --git a/g4f/api/run.py b/g4f/api/run.py
index 12bf9eed..88f34741 100644
--- a/g4f/api/run.py
+++ b/g4f/api/run.py
@@ -3,4 +3,4 @@ import g4f.api
if __name__ == "__main__":
print(f'Starting server... [g4f v-{g4f.version}]')
- g4f.api.Api(g4f).run('127.0.0.1:1337', 8) \ No newline at end of file
+ g4f.api.Api(engine = g4f, debug = True).run(ip = "127.0.0.1:1337")