summaryrefslogtreecommitdiffstats
path: root/g4f/api/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'g4f/api/__init__.py')
-rw-r--r--g4f/api/__init__.py57
1 files changed, 52 insertions, 5 deletions
diff --git a/g4f/api/__init__.py b/g4f/api/__init__.py
index 02ba5260..21e69388 100644
--- a/g4f/api/__init__.py
+++ b/g4f/api/__init__.py
@@ -4,6 +4,7 @@ import logging
import json
import uvicorn
import secrets
+import os
from fastapi import FastAPI, Response, Request
from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse
@@ -13,13 +14,16 @@ from starlette.exceptions import HTTPException
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
+from starlette.responses import FileResponse
from pydantic import BaseModel
from typing import Union, Optional
import g4f
import g4f.debug
from g4f.client import AsyncClient, ChatCompletion
+from g4f.providers.response import BaseConversation
from g4f.client.helper import filter_none
+from g4f.image import is_accepted_format, images_dir
from g4f.typing import Messages
from g4f.cookies import read_cookie_files
@@ -63,6 +67,7 @@ class ChatCompletionsConfig(BaseModel):
api_key: Optional[str] = None
web_search: Optional[bool] = None
proxy: Optional[str] = None
+ conversation_id: str = None
class ImageGenerationConfig(BaseModel):
prompt: str
@@ -98,6 +103,7 @@ class Api:
self.client = AsyncClient()
self.g4f_api_key = g4f_api_key
self.get_g4f_api_key = APIKeyHeader(name="g4f-api-key")
+ self.conversations: dict[str, dict[str, BaseConversation]] = {}
def register_authorization(self):
@self.app.middleware("http")
@@ -179,12 +185,21 @@ class Api:
async def chat_completions(config: ChatCompletionsConfig, request: Request = None, provider: str = None):
try:
config.provider = provider if config.provider is None else config.provider
+ if config.provider is None:
+ config.provider = AppConfig.provider
if config.api_key is None and request is not None:
auth_header = request.headers.get("Authorization")
if auth_header is not None:
- auth_header = auth_header.split(None, 1)[-1]
- if auth_header and auth_header != "Bearer":
- config.api_key = auth_header
+ api_key = auth_header.split(None, 1)[-1]
+ if api_key and api_key != "Bearer":
+ config.api_key = api_key
+
+ conversation = return_conversation = None
+ if config.conversation_id is not None and config.provider is not None:
+ return_conversation = True
+ if config.conversation_id in self.conversations:
+ if config.provider in self.conversations[config.conversation_id]:
+ conversation = self.conversations[config.conversation_id][config.provider]
# Create the completion response
response = self.client.chat.completions.create(
@@ -194,6 +209,11 @@ class Api:
"provider": AppConfig.provider,
"proxy": AppConfig.proxy,
**config.dict(exclude_none=True),
+ **{
+ "conversation_id": None,
+ "return_conversation": return_conversation,
+ "conversation": conversation
+ }
},
ignored=AppConfig.ignored_providers
),
@@ -206,7 +226,13 @@ class Api:
async def streaming():
try:
async for chunk in response:
- yield f"data: {json.dumps(chunk.to_json())}\n\n"
+ if isinstance(chunk, BaseConversation):
+ if config.conversation_id is not None and config.provider is not None:
+ if config.conversation_id not in self.conversations:
+ self.conversations[config.conversation_id] = {}
+ self.conversations[config.conversation_id][config.provider] = chunk
+ else:
+ yield f"data: {json.dumps(chunk.to_json())}\n\n"
except GeneratorExit:
pass
except Exception as e:
@@ -222,7 +248,13 @@ class Api:
@self.app.post("/v1/images/generate")
@self.app.post("/v1/images/generations")
- async def generate_image(config: ImageGenerationConfig):
+ async def generate_image(config: ImageGenerationConfig, request: Request):
+ if config.api_key is None:
+ auth_header = request.headers.get("Authorization")
+ if auth_header is not None:
+ api_key = auth_header.split(None, 1)[-1]
+ if api_key and api_key != "Bearer":
+ config.api_key = api_key
try:
response = await self.client.images.generate(
prompt=config.prompt,
@@ -234,6 +266,9 @@ class Api:
proxy = config.proxy
)
)
+ for image in response.data:
+ if hasattr(image, "url") and image.url.startswith("/"):
+ image.url = f"{request.base_url}{image.url.lstrip('/')}"
return JSONResponse(response.to_json())
except Exception as e:
logger.exception(e)
@@ -243,6 +278,18 @@ class Api:
async def completions():
return Response(content=json.dumps({'info': 'Not working yet.'}, indent=4), media_type="application/json")
+ @self.app.get("/images/{filename}")
+ async def get_image(filename):
+ target = os.path.join(images_dir, filename)
+
+ if not os.path.isfile(target):
+ return Response(status_code=404)
+
+ with open(target, "rb") as f:
+ content_type = is_accepted_format(f.read(12))
+
+ return FileResponse(target, media_type=content_type)
+
def format_exception(e: Exception, config: Union[ChatCompletionsConfig, ImageGenerationConfig], image: bool = False) -> str:
last_provider = {} if not image else g4f.get_last_provider(True)
provider = (AppConfig.image_provider if image else AppConfig.provider) if config.provider is None else config.provider