from fastapi import FastAPI, HTTPException, Depends, status, BackgroundTasks
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel
from jose import JWTError, jwt
from datetime import datetime, timedelta, timezone
from openai import OpenAI
from pathlib import Path
from typing import List, Optional, Dict, Literal
from datasets import Dataset, load_dataset
from sentence_transformers import SentenceTransformer
from huggingface_hub import login
from contextlib import asynccontextmanager
import pandas as pd
import numpy as np
import torch as t
import os
import logging
from functools import lru_cache
from diskcache import Cache
import json
import asyncio

# Configure logging
logging.basicConfig(level=logging.INFO)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Preload the embedding models. Note: this lifespan is not currently passed to
    # FastAPI(); startup work is handled by startup_event further down.
    get_embedding_models()
    yield


# Initialize FastAPI app
app = FastAPI()

# Initialize disk cache
cache = Cache('./cache')

# JWT Configuration
SECRET_KEY = os.environ.get("PRIME_AUTH", "c0369f977b69e717dc16f6fc574039eb2b1ebde38014d2be")
REFRESH_SECRET_KEY = os.environ.get("PROLONGED_AUTH", "916018771b29084378c9362c0cd9e631fd4927b8aea07f91")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")


# Pydantic models
class QueryInput(BaseModel):
    query: str


class SearchResult(BaseModel):
    text: str
    similarity: float
    model_type: Literal["WhereIsAI_UAE_Large_V1", "BAAI_bge_large_en_v1.5"]


class TokenResponse(BaseModel):
    access_token: str
    refresh_token: str
    token_type: str


class SaveInput(BaseModel):
    user_type: str
    username: str
    query: str
    retrieved_text: str
    model_type: Literal["WhereIsAI_UAE_Large_V1", "BAAI_bge_large_en_v1.5"]
    reaction: str
    confidence_score: float


class SaveBatchInput(BaseModel):
    items: List[SaveInput]


class RefreshRequest(BaseModel):
    refresh_token: str


# Cache management
@lru_cache(maxsize=2)  # Cache both models
def get_embedding_models():
    """Load and cache both embedding models"""
    return {
        "uae-large": SentenceTransformer("WhereIsAI/UAE-Large-V1", device="cpu"),
        "bge-large": SentenceTransformer("BAAI/bge-large-en-v1.5", device="cpu"),
    }


def get_cached_embeddings(text: str, model_type: str) -> Optional[List[float]]:
    """Try to get embeddings from cache"""
    cache_key = f"{model_type}_{hash(text)}"
    return cache.get(cache_key)


def set_cached_embeddings(text: str, model_type: str, embeddings: List[float]):
    """Store embeddings in cache"""
    cache_key = f"{model_type}_{hash(text)}"
    cache.set(cache_key, embeddings, expire=86400)  # Cache for 24 hours


@lru_cache(maxsize=1)
def load_dataframe():
    """Load and cache the parquet dataframe"""
    database_file = Path(__file__).parent / "[embed] The Alchemy of Happiness (Ghazzālī, Claud Field).parquet"
    return pd.read_parquet(database_file)


# Utility functions
def cosine_similarity(embedding_0, embedding_1):
    dot_product = sum(a * b for a, b in zip(embedding_0, embedding_1))
    norm_0 = sum(a * a for a in embedding_0) ** 0.5
    norm_1 = sum(b * b for b in embedding_1) ** 0.5
    return dot_product / (norm_0 * norm_1)


def generate_embedding(model, text: str, model_type: str) -> List[float]:
    cached_embedding = get_cached_embeddings(text, model_type)
    if cached_embedding is not None:
        return cached_embedding
    # Generate new embedding
    embedding = model.encode(
        text,
        convert_to_tensor=True,
        normalize_embeddings=True  # Important for UAE and BGE models
    )
    # Move the tensor to CPU and convert to a plain Python list so it can be cached as JSON-friendly data
    embedding = embedding.cpu().numpy().tolist()
    set_cached_embeddings(text, model_type, embedding)
    return embedding
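
# Aside (an optional sketch, not wired into the app): because both models are asked
# for normalized embeddings (normalize_embeddings=True above), cosine similarity
# reduces to a plain dot product. A vectorized helper such as the hypothetical
# cosine_similarity_np below could stand in for the pure-Python loop in
# cosine_similarity when the stored vectors are array-like; it is illustrative only.
def cosine_similarity_np(embedding_0, embedding_1) -> float:
    """Vectorized cosine similarity over array-like inputs (illustrative sketch)."""
    a = np.asarray(embedding_0, dtype=np.float64)
    b = np.asarray(embedding_1, dtype=np.float64)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
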
def search_query(st_models, query: str, df: pd.DataFrame, n: int = 1) -> List[Dict]:
    # Generate embeddings with both models
    uae_embedding = generate_embedding(st_models["uae-large"], query, "uae-large")
    bge_embedding = generate_embedding(st_models["bge-large"], query, "bge-large")

    # Calculate similarities
    df['uae_similarities'] = df["WhereIsAI_UAE_Large_V1"].apply(
        lambda x: cosine_similarity(x, uae_embedding)
    )
    df['bge_similarities'] = df["BAAI_bge_large_en_v1.5"].apply(
        lambda x: cosine_similarity(x, bge_embedding)
    )

    # Get top results for each model
    uae_results = df.nlargest(n, 'uae_similarities')
    bge_results = df.nlargest(n, 'bge_similarities')

    # Format results
    results = []
    for _, row in uae_results.iterrows():
        results.append({
            "text": row["text"],
            "similarity": float(row["uae_similarities"]),
            "model_type": "WhereIsAI_UAE_Large_V1"
        })
    for _, row in bge_results.iterrows():
        results.append({
            "text": row["text"],
            "similarity": float(row["bge_similarities"]),
            "model_type": "BAAI_bge_large_en_v1.5"
        })
    return results


# Authentication functions
def load_credentials():
    credentials = {}
    for i in range(1, 51):
        username = os.environ.get(f"login_{i}")
        password = os.environ.get(f"password_{i}")
        if username and password:
            credentials[username] = password
    return credentials


def authenticate_user(username: str, password: str):
    credentials_dict = load_credentials()
    if username in credentials_dict and credentials_dict[username] == password:
        return username
    return None


def create_token(data: dict, expires_delta: timedelta, secret_key: str):
    to_encode = data.copy()
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=ALGORITHM)
    return encoded_jwt


def verify_token(token: str, secret_key: str):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, secret_key, algorithms=[ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    return username


def verify_access_token(token: str = Depends(oauth2_scheme)):
    username = verify_token(token, SECRET_KEY)
    # Check if token is blacklisted
    if cache.get(f"blacklist_{token}"):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token has been revoked",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return username


# Endpoints
@app.get("/")
def index() -> FileResponse:
    """Serve the custom HTML page from the static directory"""
    file_path = Path(__file__).parent / "static" / "index.html"
    return FileResponse(path=str(file_path), media_type="text/html")


@app.post("/login", response_model=TokenResponse)
def login_app(form_data: OAuth2PasswordRequestForm = Depends()):
    username = authenticate_user(form_data.username, form_data.password)
    if not username:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    refresh_token_expires = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
    access_token = create_token(
        data={"sub": username},
        expires_delta=access_token_expires,
        secret_key=SECRET_KEY
    )
    refresh_token = create_token(
        data={"sub": username},
        expires_delta=refresh_token_expires,
        secret_key=REFRESH_SECRET_KEY
    )
    return {
        "access_token": access_token,
        "refresh_token": refresh_token,
        "token_type": "bearer"
    }
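
# Illustrative client flow (a sketch only, not part of the service): the base URL and
# the httpx dependency are assumptions, and the credential values are placeholders;
# the endpoint paths and payload shapes match the routes defined in this file.
#
#   import httpx
#   tokens = httpx.post("http://localhost:7860/login",
#                       data={"username": "<login_1>", "password": "<password_1>"}).json()
#   hits = httpx.post("http://localhost:7860/search",
#                     json={"query": "What does the text say about happiness?"},
#                     headers={"Authorization": f"Bearer {tokens['access_token']}"}).json()
#   tokens = httpx.post("http://localhost:7860/refresh",
#                       json={"refresh_token": tokens["refresh_token"]}).json()
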
@app.post("/refresh", response_model=TokenResponse)
async def refresh(refresh_request: RefreshRequest):
    """
    Endpoint to refresh an access token using a valid refresh token.
    Returns a new access token and the existing refresh token.
    """
    try:
        # Verify the refresh token
        username = verify_token(refresh_request.refresh_token, REFRESH_SECRET_KEY)

        # Create new access token
        access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        access_token = create_token(
            data={"sub": username},
            expires_delta=access_token_expires,
            secret_key=SECRET_KEY
        )
        return {
            "access_token": access_token,
            "refresh_token": refresh_request.refresh_token,  # Return the same refresh token
            "token_type": "bearer"
        }
    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Could not validate credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )


@app.post("/logout")
def logout(
    token: str = Depends(oauth2_scheme),
    username: str = Depends(verify_access_token)
):
    try:
        # Decode token to get expiration time
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        exp_timestamp = payload.get("exp")
        if exp_timestamp is None:
            raise HTTPException(status_code=400, detail="Token missing expiration time")

        # Calculate remaining token validity
        current_time = datetime.now(timezone.utc).timestamp()
        remaining_time = exp_timestamp - current_time
        if remaining_time > 0:
            # Add to blacklist cache with TTL matching token expiration
            cache_key = f"blacklist_{token}"
            cache.set(cache_key, True, expire=remaining_time)
        return {"message": "Successfully logged out"}
    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token",
            headers={"WWW-Authenticate": "Bearer"},
        )


@app.post("/search", response_model=List[SearchResult])
async def search(
    query_input: QueryInput,
    username: str = Depends(verify_access_token),
):
    try:
        st_models = get_embedding_models()
        df = load_dataframe()
        results = search_query(st_models, query_input.query, df, n=1)
        return [SearchResult(**result) for result in results]
    except Exception as e:
        logging.error(f"Search error: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Search failed: {str(e)}"
        )
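
# The /save endpoint below appends one JSON object per line to QUEUE_FILE, and the
# background loop drains the file in batches. An illustrative queued line (field
# names match SaveInput plus the server-side timestamp; all values are placeholders):
#
#   {"user_type": "<user_type>", "username": "<username>", "query": "<query>",
#    "retrieved_text": "<retrieved_text>", "model_type": "WhereIsAI_UAE_Large_V1",
#    "reaction": "<reaction>", "timestamp": "2024-01-01T00:00:00Z", "confidence_score": 0.0}
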
# Batch-upload queue configuration
QUEUE_FILE = "./save_queue.jsonl"
PUSH_INTERVAL_S = 300  # check the queue every 5 minutes
QUEUE_THRESHOLD = 100  # push once at least this many records are queued
MAX_PUSH_INTERVAL_S = 47 * 3600  # ...or push anyway after 47 hours


# Background task to batch-push queued records to the Hugging Face Hub
async def _hf_sync_loop():
    # Authenticate once for private repo access
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        logging.error("HF_TOKEN not set for Hugging Face authentication")
        return
    login(token=hf_token)

    last_push_time = datetime.now(timezone.utc).timestamp()
    while True:
        await asyncio.sleep(PUSH_INTERVAL_S)
        try:
            # Count lines in queue file
            if not os.path.exists(QUEUE_FILE):
                continue
            with open(QUEUE_FILE, "r") as f:
                lines = f.read().splitlines()
            queue_len = len(lines)

            now = datetime.now(timezone.utc).timestamp()
            time_since_last_push = now - last_push_time
            logging.info(f"Queue length: {queue_len}, Time since last push: {time_since_last_push}")

            # Only push if the size threshold or the maximum interval is reached
            if queue_len >= QUEUE_THRESHOLD or time_since_last_push >= MAX_PUSH_INTERVAL_S:
                if not lines:
                    last_push_time = now
                    continue
                new_records = [json.loads(l) for l in lines]

                # Load the remote dataset with auth
                dataset = load_dataset(
                    "HumbleBeeAI/al-ghazali-rag-retrieval-evaluation",
                    split="train"
                )
                data = dataset.to_dict()

                # Append new records column-wise (assumes each record carries the same keys)
                for rec in new_records:
                    for k, v in rec.items():
                        data.setdefault(k, []).append(v)
                updated = Dataset.from_dict(data)

                # Push with token
                updated.push_to_hub(
                    "HumbleBeeAI/al-ghazali-rag-retrieval-evaluation",
                    token=hf_token
                )

                # Clear the queue
                open(QUEUE_FILE, "w").close()
                last_push_time = now
        except Exception as e:
            logging.error(f"Background sync failed: {e}")


@app.on_event("startup")
async def startup_event():
    os.makedirs("./cache", exist_ok=True)
    Path(QUEUE_FILE).touch(exist_ok=True)
    # Start the background sync loop
    asyncio.create_task(_hf_sync_loop())


@app.post("/save")
async def save_data(
    save_input: SaveBatchInput,
    username: str = Depends(verify_access_token)
):
    records = []
    for item in save_input.items:
        records.append({
            "user_type": item.user_type,
            "username": item.username,
            "query": item.query,
            "retrieved_text": item.retrieved_text,
            "model_type": item.model_type,
            "reaction": item.reaction,
            "timestamp": datetime.now(timezone.utc).isoformat().replace('+00:00', 'Z'),
            "confidence_score": item.confidence_score
        })
    # Append to the local queue; the background loop pushes it to the Hub in batches
    with open(QUEUE_FILE, "a") as f:
        for r in records:
            f.write(json.dumps(r) + "\n")
    return {"message": "Your data is queued for batch upload."}


# Mount the static files
app.mount("/home", StaticFiles(directory="static", html=True), name="home")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
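
# Deployment note (informal sketch): configuration comes from environment variables;
# the names below are taken from the os.environ lookups in this file, and the
# descriptions state their intended roles, not prescribed values.
#
#   PRIME_AUTH        - secret used to sign access tokens (falls back to the hardcoded default)
#   PROLONGED_AUTH    - secret used to sign refresh tokens (falls back to the hardcoded default)
#   HF_TOKEN          - Hugging Face token with write access to the evaluation dataset repo
#   login_1..login_50 / password_1..password_50 - username/password pairs accepted by /login
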