from fastapi import FastAPI, Request, HTTPException, Depends, Header
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
from typing import Union, List  # typing imports needed for the request model
import numpy as np
import logging, os

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Dependency that validates the Authorization header
async def check_authorization(authorization: str = Header(..., alias="Authorization")):
    # Require the "Bearer <token>" format, then strip the prefix to get the token
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid Authorization header format")
    
    token = authorization[len("Bearer "):]
    if token != os.environ.get("AUTHORIZATION"):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return token

app = FastAPI()

try:
    # Load the Sentence Transformer model
    model = SentenceTransformer("BAAI/bge-large-zh-v1.5")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise HTTPException(status_code=500, detail="Model loading failed")

class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]  # accept a single string or a list of strings

@app.post("/v1/embeddings")
async def embeddings(request: EmbeddingRequest, authorization: str = Depends(check_authorization)):
    input_data = request.input
    # Normalize the input to a list
    inputs = [input_data] if isinstance(input_data, str) else input_data

    if not inputs:
        # Empty input: return an OpenAI-style response with no embeddings
        return {
            "object": "list",
            "data": [],
            "model": "BAAI/bge-large-zh-v1.5",
            "usage": {"prompt_tokens": 0, "total_tokens": 0},
        }

    # Compute the embeddings (a 2D numpy array, one row per input)
    embeddings = model.encode(inputs, normalize_embeddings=True)

    # Build an OpenAI-compatible response
    data_entries = []
    for idx, embed in enumerate(embeddings):
        data_entries.append({
            "object": "embedding",
            "embedding": embed.tolist(),  # 每个embed是一维数组
            "index": idx
        })

    return {
        "object": "list",
        "data": data_entries,  # 包含每个输入的嵌入对象
        "model": "BAAI/bge-large-zh-v1.5",
        "usage": {
            "prompt_tokens": sum(len(text) for text in inputs),  # 粗略估计token数
            "total_tokens": sum(len(text) for text in inputs)
        }
    }
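
# --- Example request (a minimal sketch, not part of the server) ---
# Assumes the app is served locally, e.g. `uvicorn main:app --port 8000`, with the
# AUTHORIZATION environment variable set; the module name, host, and port are
# illustrative assumptions, not defined in the file above.
#
#   import os, requests
#
#   resp = requests.post(
#       "http://localhost:8000/v1/embeddings",
#       headers={"Authorization": f"Bearer {os.environ['AUTHORIZATION']}"},
#       json={"input": ["你好，世界", "hello world"]},
#   )
#   resp.raise_for_status()
#   data = resp.json()["data"]
#   print(len(data), len(data[0]["embedding"]))  # number of inputs, embedding dimension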