from fastapi import FastAPI, Request, HTTPException, Depends, Header
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
from typing import Union, List  # type imports needed for the request model
import numpy as np
import logging, os
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Dependency that validates the Authorization header
async def check_authorization(authorization: str = Header(..., alias="Authorization")):
    # Strip the "Bearer " prefix and compare the token against the AUTHORIZATION env var
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid Authorization header format")
    token = authorization[len("Bearer "):]
    if token != os.environ.get("AUTHORIZATION"):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return token
app = FastAPI()
try:
    # Load the Sentence Transformer model
    model = SentenceTransformer("BAAI/bge-large-zh-v1.5")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    # HTTPException only makes sense inside a request handler; fail startup instead
    raise RuntimeError("Model loading failed") from e
class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]  # a single string or a list of strings
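# Example request bodies accepted by EmbeddingRequest (values are illustrative only):
#   {"input": "今天天气不错"}
#   {"input": ["今天天气不错", "hello world"]}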
@app.post("/v1/embeddings")
async def embeddings(request: EmbeddingRequest, authorization: str = Depends(check_authorization)):
    input_data = request.input
    # Normalize the input to a list so single strings and batches share one code path
    inputs = [input_data] if isinstance(input_data, str) else input_data
    if not inputs:
        # Empty input: return an empty result in the same response shape as below
        return {
            "object": "list",
            "data": [],
            "model": "BAAI/bge-large-zh-v1.5",
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }
    # Compute the embeddings (a 2-D numpy array, one row per input, L2-normalized)
    embeddings = model.encode(inputs, normalize_embeddings=True)
    # Build an OpenAI-compatible response
    data_entries = []
    for idx, embed in enumerate(embeddings):
        data_entries.append({
            "object": "embedding",
            "embedding": embed.tolist(),  # each embed is a 1-D array
            "index": idx
        })
    return {
        "object": "list",
        "data": data_entries,  # one embedding object per input
        "model": "BAAI/bge-large-zh-v1.5",
        "usage": {
            "prompt_tokens": sum(len(text) for text in inputs),  # rough estimate: character count as token count
            "total_tokens": sum(len(text) for text in inputs)
        }
    }
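
# A minimal sketch for running this service locally for testing. It assumes
# uvicorn is installed and the AUTHORIZATION environment variable is set to the
# expected bearer token; the host, port, and token below are illustrative.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request against the running server (token value is illustrative):
#   curl http://localhost:8000/v1/embeddings \
#     -H "Authorization: Bearer my-secret-token" \
#     -H "Content-Type: application/json" \
#     -d '{"input": ["今天天气不错"]}'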