diff options
author | abc <98614666+xtekky@users.noreply.github.com> | 2023-10-20 20:04:13 +0200 |
---|---|---|
committer | abc <98614666+xtekky@users.noreply.github.com> | 2023-10-20 20:04:13 +0200 |
commit | 8e7e694d81e674db63049145a35972df8ad2e3fa (patch) | |
tree | d6000ee808385a20f85e03e3173e3aa16a605757 /g4f/api/__init__.py | |
parent | ~ (diff) | |
download | gpt4free-8e7e694d81e674db63049145a35972df8ad2e3fa.tar gpt4free-8e7e694d81e674db63049145a35972df8ad2e3fa.tar.gz gpt4free-8e7e694d81e674db63049145a35972df8ad2e3fa.tar.bz2 gpt4free-8e7e694d81e674db63049145a35972df8ad2e3fa.tar.lz gpt4free-8e7e694d81e674db63049145a35972df8ad2e3fa.tar.xz gpt4free-8e7e694d81e674db63049145a35972df8ad2e3fa.tar.zst gpt4free-8e7e694d81e674db63049145a35972df8ad2e3fa.zip |
Diffstat (limited to '')
-rw-r--r-- | g4f/api/__init__.py | 346 |
1 files changed, 195 insertions, 151 deletions
diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py index b19a721b..ecc70a13 100644 --- a/g4f/api/__init__.py +++ b/g4f/api/__init__.py @@ -1,162 +1,206 @@ +import g4f +import time import json import random import string -import time - -# import requests -from flask import Flask, request -from flask_cors import CORS -# from transformers import AutoTokenizer - -from g4f import ChatCompletion - -app = Flask(__name__) -CORS(app) - - -@app.route("/") -def index(): - return "interference api, url: http://127.0.0.1:1337" - - -@app.route("/chat/completions", methods=["POST"]) -def chat_completions(): - model = request.get_json().get("model", "gpt-3.5-turbo") - stream = request.get_json().get("stream", False) - messages = request.get_json().get("messages") - - response = ChatCompletion.create(model=model, stream=stream, messages=messages) - - completion_id = "".join(random.choices(string.ascii_letters + string.digits, k=28)) - completion_timestamp = int(time.time()) - - if not stream: - return { - "id": f"chatcmpl-{completion_id}", - "object": "chat.completion", - "created": completion_timestamp, - "model": model, - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": response, - }, - "finish_reason": "stop", - } - ], - "usage": { - "prompt_tokens": None, - "completion_tokens": None, - "total_tokens": None, - }, - } - - def streaming(): - for chunk in response: - completion_data = { - "id": f"chatcmpl-{completion_id}", - "object": "chat.completion.chunk", - "created": completion_timestamp, - "model": model, - "choices": [ +import logging + +from typing import Union +from loguru import logger +from waitress import serve +from ._logging import hook_logging +from ._tokenizer import tokenize +from flask_cors import CORS +from werkzeug.serving import WSGIRequestHandler +from werkzeug.exceptions import default_exceptions +from werkzeug.middleware.proxy_fix import ProxyFix + +from flask import ( + Flask, + jsonify, + make_response, + request, +) + +class Api: + __default_ip = '127.0.0.1' + __default_port = 1337 + + def __init__(self, engine: g4f, debug: bool = True, sentry: bool = False) -> None: + self.engine = engine + self.debug = debug + self.sentry = sentry + self.log_level = logging.DEBUG if debug else logging.WARN + + hook_logging(level=self.log_level, format='[%(asctime)s] %(levelname)s in %(module)s: %(message)s') + self.logger = logging.getLogger('waitress') + + self.app = Flask(__name__) + self.app.wsgi_app = ProxyFix(self.app.wsgi_app, x_port=1) + self.app.after_request(self.__after_request) + + def run(self, bind_str, threads=8): + host, port = self.__parse_bind(bind_str) + + CORS(self.app, resources={r'/v1/*': {'supports_credentials': True, 'expose_headers': [ + 'Content-Type', + 'Authorization', + 'X-Requested-With', + 'Accept', + 'Origin', + 'Access-Control-Request-Method', + 'Access-Control-Request-Headers', + 'Content-Disposition'], 'max_age': 600}}) + + self.app.route('/v1/models', methods=['GET'])(self.models) + self.app.route('v1/models/<model_id>', methods=['GET'])(self.model_info) + + self.app.route('/v1/chat/completions', methods=['POST'])(self.chat_completions) + self.app.route('/v1/completions', methods=['POST'])(self.completions) + + for ex in default_exceptions: + self.app.register_error_handler(ex, self.__handle_error) + + if not self.debug: + self.logger.warning('Serving on http://{}:{}'.format(host, port)) + + WSGIRequestHandler.protocol_version = 'HTTP/1.1' + serve(self.app, host=host, port=port, ident=None, threads=threads) + + def __handle_error(self, e: Exception): + self.logger.error(e) + + return make_response(jsonify({ + 'code': e.code, + 'message': str(e.original_exception if self.debug and hasattr(e, 'original_exception') else e.name)}), 500) + + @staticmethod + def __after_request(resp): + resp.headers['X-Server'] = 'g4f/%s' % g4f.version + + return resp + + def __parse_bind(self, bind_str): + sections = bind_str.split(':', 2) + if len(sections) < 2: + try: + port = int(sections[0]) + return self.__default_ip, port + except ValueError: + return sections[0], self.__default_port + + return sections[0], int(sections[1]) + + async def home(self): + return 'Hello world | https://127.0.0.1:1337/v1' + + async def chat_completions(self): + model = request.json.get('model', 'gpt-3.5-turbo') + stream = request.json.get('stream', False) + messages = request.json.get('messages') + + logger.info(f'model: {model}, stream: {stream}, request: {messages[-1]["content"]}') + + response = self.engine.ChatCompletion.create(model=model, + stream=stream, messages=messages) + + completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28)) + completion_timestamp = int(time.time()) + + if not stream: + prompt_tokens, _ = tokenize(''.join([message['content'] for message in messages])) + completion_tokens, _ = tokenize(response) + + return { + 'id': f'chatcmpl-{completion_id}', + 'object': 'chat.completion', + 'created': completion_timestamp, + 'model': model, + 'choices': [ { - "index": 0, - "delta": { - "content": chunk, + 'index': 0, + 'message': { + 'role': 'assistant', + 'content': response, }, - "finish_reason": None, + 'finish_reason': 'stop', } ], + 'usage': { + 'prompt_tokens': prompt_tokens, + 'completion_tokens': completion_tokens, + 'total_tokens': prompt_tokens + completion_tokens, + }, } - content = json.dumps(completion_data, separators=(",", ":")) - yield f"data: {content}\n\n" - time.sleep(0.1) + def streaming(): + try: + for chunk in response: + completion_data = { + 'id': f'chatcmpl-{completion_id}', + 'object': 'chat.completion.chunk', + 'created': completion_timestamp, + 'model': model, + 'choices': [ + { + 'index': 0, + 'delta': { + 'content': chunk, + }, + 'finish_reason': None, + } + ], + } - end_completion_data = { - "id": f"chatcmpl-{completion_id}", - "object": "chat.completion.chunk", - "created": completion_timestamp, - "model": model, - "choices": [ - { - "index": 0, - "delta": {}, - "finish_reason": "stop", + content = json.dumps(completion_data, separators=(',', ':')) + yield f'data: {content}\n\n' + time.sleep(0.03) + + end_completion_data = { + 'id': f'chatcmpl-{completion_id}', + 'object': 'chat.completion.chunk', + 'created': completion_timestamp, + 'model': model, + 'choices': [ + { + 'index': 0, + 'delta': {}, + 'finish_reason': 'stop', + } + ], } - ], - } - content = json.dumps(end_completion_data, separators=(",", ":")) - yield f"data: {content}\n\n" - - return app.response_class(streaming(), mimetype="text/event-stream") - - -# Get the embedding from huggingface -# def get_embedding(input_text, token): -# huggingface_token = token -# embedding_model = "sentence-transformers/all-mpnet-base-v2" -# max_token_length = 500 - -# # Load the tokenizer for the 'all-mpnet-base-v2' model -# tokenizer = AutoTokenizer.from_pretrained(embedding_model) -# # Tokenize the text and split the tokens into chunks of 500 tokens each -# tokens = tokenizer.tokenize(input_text) -# token_chunks = [ -# tokens[i : i + max_token_length] -# for i in range(0, len(tokens), max_token_length) -# ] - -# # Initialize an empty list -# embeddings = [] - -# # Create embeddings for each chunk -# for chunk in token_chunks: -# # Convert the chunk tokens back to text -# chunk_text = tokenizer.convert_tokens_to_string(chunk) - -# # Use the Hugging Face API to get embeddings for the chunk -# api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{embedding_model}" -# headers = {"Authorization": f"Bearer {huggingface_token}"} -# chunk_text = chunk_text.replace("\n", " ") - -# # Make a POST request to get the chunk's embedding -# response = requests.post( -# api_url, -# headers=headers, -# json={"inputs": chunk_text, "options": {"wait_for_model": True}}, -# ) - -# # Parse the response and extract the embedding -# chunk_embedding = response.json() -# # Append the embedding to the list -# embeddings.append(chunk_embedding) - -# # averaging all the embeddings -# # this isn't very effective -# # someone a better idea? -# num_embeddings = len(embeddings) -# average_embedding = [sum(x) / num_embeddings for x in zip(*embeddings)] -# embedding = average_embedding -# return embedding - - -# @app.route("/embeddings", methods=["POST"]) -# def embeddings(): -# input_text_list = request.get_json().get("input") -# input_text = " ".join(map(str, input_text_list)) -# token = request.headers.get("Authorization").replace("Bearer ", "") -# embedding = get_embedding(input_text, token) - -# return { -# "data": [{"embedding": embedding, "index": 0, "object": "embedding"}], -# "model": "text-embedding-ada-002", -# "object": "list", -# "usage": {"prompt_tokens": None, "total_tokens": None}, -# } - - -def run_api(): - app.run(host="0.0.0.0", port=1337) + + content = json.dumps(end_completion_data, separators=(',', ':')) + yield f'data: {content}\n\n' + + logger.success(f'model: {model}, stream: {stream}') + + except GeneratorExit: + pass + + return self.app.response_class(streaming(), mimetype='text/event-stream') + + async def completions(self): + return 'not working yet', 500 + + async def model_info(self, model_name): + model_info = (g4f.ModelUtils.convert[model_name]) + + return jsonify({ + 'id' : model_name, + 'object' : 'model', + 'created' : 0, + 'owned_by' : model_info.base_provider + }) + + async def models(self): + model_list = [{ + 'id' : model, + 'object' : 'model', + 'created' : 0, + 'owned_by' : 'g4f'} for model in g4f.Model.__all__()] + + return jsonify({ + 'object': 'list', + 'data': model_list}) +
\ No newline at end of file |