# Spaces: Sleeping  (hosting-page status text captured along with this file; not program code)
from fastapi import FastAPI, Depends, HTTPException, Request | |
from fastapi.responses import HTMLResponse # β Add this line | |
from fastapi.security import APIKeyHeader | |
from pydantic import BaseModel | |
from model import ner_pipeline | |
import logging | |
import time | |
import json | |
import os | |
import secrets | |
from slowapi import Limiter, _rate_limit_exceeded_handler | |
from slowapi.util import get_remote_address | |
from slowapi.errors import RateLimitExceeded | |
# Module-wide logging at INFO level, with a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application plus slowapi rate limiting keyed on the client address
# returned by get_remote_address. The limiter is published on app.state, and
# RateLimitExceeded is converted into a response by slowapi's stock handler.
app = FastAPI()
app.state.limiter = limiter = Limiter(key_func=get_remote_address)
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
# Path of the JSON file that persists issued API keys. It can be overridden
# with the API_KEY_FILE environment variable and defaults to the previous
# hard-coded location.
# NOTE(review): /tmp is typically wiped when the container restarts, so keys
# stored there may not survive a restart. Point API_KEY_FILE at persistent
# storage if that matters.
API_KEY_FILE = os.getenv("API_KEY_FILE", "/tmp/api_keys.json")

# Protected handlers read the caller's key from the "X-API-KEY" request header.
token_header = APIKeyHeader(name="X-API-KEY")

# Load the key store, or create an empty one (and its file) on first start.
# Layout, as written by save_keys: {"users": {<api_key>: {...metadata...}}}.
if os.path.exists(API_KEY_FILE):
    with open(API_KEY_FILE, "r", encoding="utf-8") as f:
        API_KEYS_STORE = json.load(f)
else:
    API_KEYS_STORE = {"users": {}}
    with open(API_KEY_FILE, "w", encoding="utf-8") as f:
        json.dump(API_KEYS_STORE, f)

# Maps api_key -> per-key metadata (handlers store "usage_count" and "label").
API_KEYS = API_KEYS_STORE.get("users", {})

# Admin credential for the key-management handlers. It is None when the
# variable is unset, and then no header value can match it.
ADMIN_KEY = os.getenv("ADMIN_KEY")
class TextRequest(BaseModel):
    """JSON body accepted by the NER handler: the raw text to analyse."""

    text: str
class RegisterRequest(BaseModel):
    """JSON body for issuing a new API key.

    ``label`` is a free-form tag stored next to the generated key. It defaults
    to ``"user"``.
    """

    label: str = "user"
def save_keys():
    """Write the in-memory key registry back to ``API_KEY_FILE`` as JSON.

    ``API_KEYS`` is stored under the ``"users"`` entry of ``API_KEYS_STORE``.
    The whole store is then dumped with a 2-space indent.

    The write is atomic. JSON goes to a uniquely named temporary file in the
    same directory, which then replaces ``API_KEY_FILE`` via ``os.replace``.
    An interrupted or overlapping save therefore cannot leave a truncated or
    interleaved file behind; such a file would make the ``json.load`` done at
    start-up fail. Because ``tempfile.mkstemp`` creates files with mode 0600,
    the saved key file ends up readable by its owner only.
    """
    import tempfile

    API_KEYS_STORE["users"] = API_KEYS
    target_dir = os.path.dirname(os.path.abspath(API_KEY_FILE))
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(API_KEYS_STORE, f, indent=2)
        # The temp file is in the same directory, so the rename stays on one
        # filesystem and is atomic.
        os.replace(tmp_path, API_KEY_FILE)
    except BaseException:
        # Best-effort removal of the temp file; the original error is re-raised.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
def verify_token(x_api_key: str = Depends(token_header)):
    """FastAPI dependency that accepts only keys present in ``API_KEYS``.

    Args:
        x_api_key: Value of the ``X-API-KEY`` header, extracted by ``token_header``.

    Returns:
        The same key, so handlers can use it to look up the caller's record.

    Raises:
        HTTPException: 403 ("Unauthorized") when the key is not registered.
    """
    if x_api_key not in API_KEYS:
        raise HTTPException(status_code=403, detail="Unauthorized")
    return x_api_key


# An earlier revision of the NER handler used to sit here inside a bare,
# never-executed triple-quoted string. It was dead code and has been removed.
# That revision was registered with @app.post("/ner") and rate-limited with
# @limiter.limit("10/minute").
# NOTE(review): the live ner_predict does not visibly carry those decorators.
# Confirm it is still routed and limited as intended.
from fastapi import Request | |
def ner_predict(body: TextRequest, request: Request, api_key: str = Depends(verify_token)):
    """Run ``ner_pipeline`` on ``body.text`` and record the call against ``api_key``.

    Each entity returned by the pipeline is copied with two fields overridden:

    * ``score`` is cast to a plain ``float``. Presumably the pipeline yields
      NumPy scalars that JSON encoding would reject; confirm.
    * ``word`` is re-read from the request text using the entity's
      ``start``/``end`` character offsets and stripped of surrounding
      whitespace, replacing any ``word`` the pipeline produced.

    The caller's ``usage_count`` is then incremented and persisted via
    ``save_keys()``.

    Returns:
        ``{"entities": [...], "usage": <updated usage_count>}``.

    NOTE(review): no route or rate-limit decorator is visible on this function.
    If they are really absent, the handler is never mounted on ``app``; confirm.
    """
    client_ip = get_remote_address(request)
    logger.info("NER request from IP: %s", client_ip)

    text = body.text
    entities = []
    for raw in ner_pipeline(text):
        entity = dict(raw)
        entity["score"] = float(raw["score"])
        # Use the original character span so the entity text is clean.
        entity["word"] = text[raw["start"]:raw["end"]].strip()
        entities.append(entity)

    account = API_KEYS[api_key]
    account["usage_count"] = account.get("usage_count", 0) + 1
    save_keys()

    return {"entities": entities, "usage": account["usage_count"]}
def register_user(request: RegisterRequest, x_api_key: str = Depends(token_header)):
    """Issue, store and persist a new random API key (admin only).

    Args:
        request: Body carrying the ``label`` stored alongside the new key.
        x_api_key: Value of the ``X-API-KEY`` header; must equal ``ADMIN_KEY``.

    Returns:
        ``{"message": "User registered", "api_key": <new key>}``.

    Raises:
        HTTPException: 403 when the header does not match ``ADMIN_KEY``, or
            when ``ADMIN_KEY`` is unset.

    NOTE(review): no route decorator is visible on this function. Confirm it
    is mounted on ``app``.
    """
    # compare_digest runs in constant time, so response timing does not reveal
    # how much of the admin key a caller guessed. Comparing bytes avoids the
    # TypeError it raises for non-ASCII str input. An unset ADMIN_KEY (None)
    # still rejects everyone, as before.
    if ADMIN_KEY is None or not secrets.compare_digest(
        x_api_key.encode("utf-8"), ADMIN_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Admin access required")
    new_key = secrets.token_urlsafe(32)
    API_KEYS[new_key] = {"usage_count": 0, "label": request.label}
    save_keys()
    return {"message": "User registered", "api_key": new_key}
def list_users(x_api_key: str = Depends(token_header)):
    """Return every registered API key with its metadata (admin only).

    The response exposes the raw API keys themselves, since they are the dict
    keys of ``API_KEYS``.

    Raises:
        HTTPException: 403 when the ``X-API-KEY`` header does not match
            ``ADMIN_KEY``, or when ``ADMIN_KEY`` is unset.

    NOTE(review): no route decorator is visible on this function. Confirm it
    is mounted on ``app``.
    """
    # Constant-time comparison of the admin credential; bytes avoid
    # compare_digest's TypeError on non-ASCII str. Unset ADMIN_KEY rejects all.
    if ADMIN_KEY is None or not secrets.compare_digest(
        x_api_key.encode("utf-8"), ADMIN_KEY.encode("utf-8")
    ):
        raise HTTPException(status_code=403, detail="Admin access required")
    return {"users": API_KEYS}
def root():
    """Serve a small HTML landing page that links to the interactive ``/docs``.

    NOTE(review): no route decorator is visible on this function (presumably
    ``@app.get("/")`` was intended). Confirm the page is actually mounted.
    """
    return HTMLResponse("""
<html>
    <head><title>NER API</title></head>
    <body>
        <h1>✅ NER API is running!</h1>
        <p>Visit <a href='/docs'>/docs</a> to try the API.</p>
    </body>
</html>
""")