from fastapi import FastAPI, HTTPException, Depends, status, BackgroundTasks
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
from jose import JWTError, jwt
from datetime import datetime, timedelta, timezone
from openai import OpenAI
from pathlib import Path
from typing import List, Optional, Dict, Literal
from datasets import Dataset, load_dataset
from sentence_transformers import SentenceTransformer
from huggingface_hub import login
from contextlib import asynccontextmanager
import pandas as pd
import numpy as np
import torch as t
import os
import logging
import hashlib
from functools import lru_cache
from diskcache import Cache
import json
import asyncio


logging.basicConfig(level=logging.INFO)

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: make sure local state exists, warm the embedding models so the
    # first /search request does not pay the loading cost, and launch the
    # background Hugging Face sync loop. QUEUE_FILE and _hf_sync_loop are
    # defined further down in this module; they are resolved here at startup,
    # after the module has been fully imported.
    os.makedirs("./cache", exist_ok=True)
    Path(QUEUE_FILE).touch(exist_ok=True)
    get_embedding_models()
    sync_task = asyncio.create_task(_hf_sync_loop())
    yield
    # Shutdown: stop the background sync loop.
    sync_task.cancel()


app = FastAPI(lifespan=lifespan)

cache = Cache('./cache')


# The values below are development fallbacks; set PRIME_AUTH and PROLONGED_AUTH
# in the environment to use real secrets in deployment.
SECRET_KEY = os.environ.get("PRIME_AUTH", "c0369f977b69e717dc16f6fc574039eb2b1ebde38014d2be")
REFRESH_SECRET_KEY = os.environ.get("PROLONGED_AUTH", "916018771b29084378c9362c0cd9e631fd4927b8aea07f91")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
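
# Token flow: POST /login returns an access token (30 minutes) and a refresh
# token (7 days); POST /refresh exchanges a valid refresh token for a fresh
# access token; POST /logout blacklists the presented access token until it
# would have expired anyway.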


class QueryInput(BaseModel):
    query: str


class SearchResult(BaseModel):
    text: str
    similarity: float
    model_type: Literal["WhereIsAI_UAE_Large_V1", "BAAI_bge_large_en_v1.5"]


class TokenResponse(BaseModel):
    access_token: str
    refresh_token: str
    token_type: str


class SaveInput(BaseModel):
    user_type: str
    username: str
    query: str
    retrieved_text: str
    model_type: Literal["WhereIsAI_UAE_Large_V1", "BAAI_bge_large_en_v1.5"]
    reaction: str
    confidence_score: float


class SaveBatchInput(BaseModel):
    items: List[SaveInput]


class RefreshRequest(BaseModel):
    refresh_token: str


@lru_cache(maxsize=2)
def get_embedding_models():
    """Load and cache both embedding models"""
    return {
        "uae-large": SentenceTransformer("WhereIsAI/UAE-Large-V1", device="cpu"),
        "bge-large": SentenceTransformer("BAAI/bge-large-en-v1.5", device="cpu")
    }
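
# lru_cache makes this a per-process singleton: both SentenceTransformer models
# are loaded once (at startup, via the lifespan handler) and reused by every
# request afterwards.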


def get_cached_embeddings(text: str, model_type: str) -> Optional[List[float]]:
    """Try to get embeddings from cache"""
    # Key on a stable content hash so cached vectors stay valid across restarts
    # (the built-in hash() is salted per process).
    cache_key = f"{model_type}_{hashlib.sha256(text.encode('utf-8')).hexdigest()}"
    return cache.get(cache_key)


def set_cached_embeddings(text: str, model_type: str, embeddings: List[float]):
    """Store embeddings in cache"""
    cache_key = f"{model_type}_{hashlib.sha256(text.encode('utf-8')).hexdigest()}"
    cache.set(cache_key, embeddings, expire=86400)


@lru_cache(maxsize=1)
def load_dataframe():
    """Load and cache the parquet dataframe"""
    database_file = Path(__file__).parent / "[embed] The Alchemy of Happiness (Ghazzālī, Claud Field).parquet"
    return pd.read_parquet(database_file)


def cosine_similarity(embedding_0, embedding_1):
    dot_product = sum(a * b for a, b in zip(embedding_0, embedding_1))
    norm_0 = sum(a * a for a in embedding_0) ** 0.5
    norm_1 = sum(b * b for b in embedding_1) ** 0.5
    return dot_product / (norm_0 * norm_1)
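
# encode() is called with normalize_embeddings=True, so query vectors are unit
# length; if the stored document vectors are normalised too, the denominator
# above is ~1 and the score is effectively a plain dot product.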


def generate_embedding(model, text: str, model_type: str) -> List[float]:
    cached_embedding = get_cached_embeddings(text, model_type)
    if cached_embedding is not None:
        return cached_embedding

    embedding = model.encode(
        text,
        convert_to_tensor=True,
        normalize_embeddings=True
    )
    # Move the tensor to the CPU and convert it to a plain Python list.
    embedding = np.array(t.Tensor.cpu(embedding)).tolist()

    set_cached_embeddings(text, model_type, embedding)
    return embedding
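
# Illustrative usage (hypothetical query text):
#   models = get_embedding_models()
#   vec = generate_embedding(models["uae-large"], "What is happiness?", "uae-large")
#   len(vec)  # 1024 for both embedding models used here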


def search_query(st_models, query: str, df: pd.DataFrame, n: int = 1) -> List[Dict]:
    uae_embedding = generate_embedding(st_models["uae-large"], query, "uae-large")
    bge_embedding = generate_embedding(st_models["bge-large"], query, "bge-large")

    # Note: this writes the similarity columns onto the cached dataframe that
    # load_dataframe() shares across requests.
    df['uae_similarities'] = df["WhereIsAI_UAE_Large_V1"].apply(
        lambda x: cosine_similarity(x, uae_embedding)
    )
    df['bge_similarities'] = df["BAAI_bge_large_en_v1.5"].apply(
        lambda x: cosine_similarity(x, bge_embedding)
    )

    uae_results = df.nlargest(n, 'uae_similarities')
    bge_results = df.nlargest(n, 'bge_similarities')

    results = []

    for _, row in uae_results.iterrows():
        results.append({
            "text": row["ext"],
            "similarity": float(row["uae_similarities"]),
            "model_type": "WhereIsAI_UAE_Large_V1"
        })

    for _, row in bge_results.iterrows():
        results.append({
            "text": row["ext"],
            "similarity": float(row["bge_similarities"]),
            "model_type": "BAAI_bge_large_en_v1.5"
        })

    return results
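
# With n=1 this returns one best match per embedding model, e.g. (illustrative
# values):
#   [{"text": "...", "similarity": 0.83, "model_type": "WhereIsAI_UAE_Large_V1"},
#    {"text": "...", "similarity": 0.79, "model_type": "BAAI_bge_large_en_v1.5"}]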


def load_credentials():
    credentials = {}
    for i in range(1, 51):
        username = os.environ.get(f"login_{i}")
        password = os.environ.get(f"password_{i}")
        if username and password:
            credentials[username] = password
    return credentials
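
# Credentials come from paired environment variables, e.g. (hypothetical):
#   login_1=alice  password_1=s3cret
# Up to 50 pairs (login_1 .. login_50, password_1 .. password_50) are read.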


def authenticate_user(username: str, password: str):
    credentials_dict = load_credentials()
    if username in credentials_dict and credentials_dict[username] == password:
        return username
    return None


def create_token(data: dict, expires_delta: timedelta, secret_key: str):
    to_encode = data.copy()
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
    return encoded_jwt
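
# The encoded claim set is minimal, e.g. (illustrative):
#   {"sub": "alice", "exp": 1735689600}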


def verify_token(token: str, secret_key: str):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    return username


def verify_access_token(token: str = Depends(oauth2_scheme)):
    username = verify_token(token, SECRET_KEY)

    # Reject tokens that were invalidated via /logout.
    if cache.get(f"blacklist_{token}"):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has been revoked",
            headers={"WWW-Authenticate": "Bearer"},
        )

    return username


@app.get("/")
def index() -> FileResponse:
    """Serve the custom HTML page from the static directory"""
    file_path = Path(__file__).parent / "static" / "index.html"
    return FileResponse(path=str(file_path), media_type="text/html")


@app.post("/login", response_model=TokenResponse)
def login_app(form_data: OAuth2PasswordRequestForm = Depends()):
    username = authenticate_user(form_data.username, form_data.password)
    if not username:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    refresh_token_expires = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    access_token = create_token(
        data={"sub": username},
        expires_delta=access_token_expires,
        secret_key=SECRET_KEY
    )
    refresh_token = create_token(
        data={"sub": username},
        expires_delta=refresh_token_expires,
        secret_key=REFRESH_SECRET_KEY
    )
    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer"
    }
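
# Example login request (illustrative; host and credentials depend on the
# deployment):
#   curl -X POST http://localhost:7860/login \
#        -d "username=alice&password=s3cret"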


@app.post("/refresh", response_model=TokenResponse)
async def refresh(refresh_request: RefreshRequest):
    """
    Endpoint to refresh an access token using a valid refresh token.
    Returns a new access token and the existing refresh token.
    """
    try:
        username = verify_token(refresh_request.refresh_token, REFRESH_SECRET_KEY)

        access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        access_token = create_token(
            data={"sub": username},
            expires_delta=access_token_expires,
            secret_key=SECRET_KEY
        )

        return {
            "access_token": access_token,
            "refresh_token": refresh_request.refresh_token,
            "token_type": "bearer"
        }

    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )
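
# Example refresh request (illustrative):
#   curl -X POST http://localhost:7860/refresh \
#        -H "Content-Type: application/json" \
#        -d '{"refresh_token": "<token returned by /login>"}'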


@app.post("/logout")
def logout(
    token: str = Depends(oauth2_scheme),
    username: str = Depends(verify_access_token)
):
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        exp_timestamp = payload.get("exp")
        if exp_timestamp is None:
            raise HTTPException(status_code=400, detail="Token missing expiration time")

        current_time = datetime.now(timezone.utc).timestamp()
        remaining_time = exp_timestamp - current_time

        if remaining_time > 0:
            # Blacklist the token only for as long as it would remain valid,
            # so the cache entry expires together with the token.
            cache_key = f"blacklist_{token}"
            cache.set(cache_key, True, expire=remaining_time)

        return {"message": "Successfully logged out"}

    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        )


@app.post("/search", response_model=List[SearchResult])
async def search(
    query_input: QueryInput,
    username: str = Depends(verify_access_token),
):
    try:
        st_models = get_embedding_models()
        df = load_dataframe()

        results = search_query(st_models, query_input.query, df, n=1)
        return [SearchResult(**result) for result in results]

    except Exception as e:
        logging.error(f"Search error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Search failed: {str(e)}"
        )
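
# Example search request (illustrative; requires a Bearer token from /login):
#   curl -X POST http://localhost:7860/search \
#        -H "Authorization: Bearer <access_token>" \
#        -H "Content-Type: application/json" \
#        -d '{"query": "What does al-Ghazali say about the soul?"}'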


QUEUE_FILE = "./save_queue.jsonl"
PUSH_INTERVAL_S = 300            # how often the sync loop wakes up (seconds)
QUEUE_THRESHOLD = 100            # push once at least this many records are queued
MAX_PUSH_INTERVAL_S = 47 * 3600  # ... or when this much time has passed since the last push
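
# Feedback accepted by /save is appended to QUEUE_FILE and uploaded to the
# Hugging Face dataset in batches: every PUSH_INTERVAL_S seconds the loop below
# checks the queue and pushes when it holds at least QUEUE_THRESHOLD records or
# when MAX_PUSH_INTERVAL_S has elapsed since the last push.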


async def _hf_sync_loop():
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        logging.error("HF_TOKEN not set for Hugging Face authentication")
        return
    login(token=hf_token)

    last_push_time = datetime.now(timezone.utc).timestamp()

    while True:
        await asyncio.sleep(PUSH_INTERVAL_S)
        try:
            if not os.path.exists(QUEUE_FILE):
                continue
            with open(QUEUE_FILE, "r") as f:
                lines = f.read().splitlines()
            queue_len = len(lines)
            now = datetime.now(timezone.utc).timestamp()
            time_since_last_push = now - last_push_time

            logging.info(f"Queue length: {queue_len}, Time since last push: {time_since_last_push}")

            if queue_len >= QUEUE_THRESHOLD or time_since_last_push >= MAX_PUSH_INTERVAL_S:
                if not lines:
                    last_push_time = now
                    continue
                new_records = [json.loads(line) for line in lines]

                # Append the queued records to the existing dataset and push it back.
                dataset = load_dataset(
                    "HumbleBeeAI/al-ghazali-rag-retrieval-evaluation",
                    split="train"
                )
                data = dataset.to_dict()

                for rec in new_records:
                    for k, v in rec.items():
                        data.setdefault(k, []).append(v)
                updated = Dataset.from_dict(data)

                updated.push_to_hub(
                    "HumbleBeeAI/al-ghazali-rag-retrieval-evaluation",
                    token=hf_token
                )

                # Truncate the queue once the batch has been pushed.
                open(QUEUE_FILE, "w").close()
                last_push_time = now
        except Exception as e:
            logging.error(f"Background sync failed: {e}")


# Startup work (cache directory, queue file, embedding warm-up, background
# sync task) is handled in the lifespan handler defined near the top of this
# module.


@app.post("/save") |
|
async def save_data( |
|
save_input: SaveBatchInput, |
|
username: str = Depends(verify_access_token) |
|
): |
|
records = [] |
|
for item in save_input.items: |
|
records.append({ |
|
"user_type": item.user_type, |
|
"username": item.username, |
|
"query": item.query, |
|
"retrieved_text": item.retrieved_text, |
|
"model_type": item.model_type, |
|
"reaction": item.reaction, |
|
"timestamp": datetime.now(timezone.utc).isoformat().replace('+00:00','Z'), |
|
"confidence_score": item.confidence_score |
|
}) |
|
|
|
with open(QUEUE_FILE, "a") as f: |
|
for r in records: |
|
f.write(json.dumps(r) + "\n") |
|
return {"message": "Your data is queued for batch upload."} |


app.mount("/home", StaticFiles(directory="static", html=True), name="home")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)