# Spaces: Sleeping
import logging
import os
import sys
from urllib.parse import quote

import httpx
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from optimum.neuron import utils
# Configure logging: INFO and above go to stdout, each record prefixed with
# its timestamp, logger name and level.
logging.basicConfig(
    stream=sys.stdout,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger shared by the handlers in this file.
logger = logging.getLogger(__name__)
app = FastAPI()

# Add CORS middleware.
# Origins, methods and headers are fully open ("*"). Credentials are therefore
# disabled: a wildcard origin combined with allow_credentials=True is not a
# valid CORS configuration and would let any site issue credentialed requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Get the absolute path to the static directory (sibling of this file)
static_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static")
logger.info(f"Static directory path: {static_dir}")

# Get the absolute path to the templates directory (sibling of this file)
templates_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")
logger.info(f"Templates directory path: {templates_dir}")

# Mount static files under /static and set up the Jinja2 template loader
app.mount("/static", StaticFiles(directory=static_dir), name="static")
templates = Jinja2Templates(directory=templates_dir)
async def health_check():
    """Log the call and return a fixed ``{"status": "healthy"}`` payload.

    NOTE(review): no route decorator is visible in this view; confirm how
    this handler is registered on ``app``.
    """
    logger.info("Health check endpoint called")
    payload = {"status": "healthy"}
    return payload
async def home(request: Request):
    """Render ``index.html`` with the request and the site's base URL.

    When the ``SPACE_ID`` environment variable is set (i.e. running in
    Spaces), every ``http://`` in the base URL is rewritten to ``https://``;
    otherwise the request's own scheme is kept.

    NOTE(review): no route decorator is visible in this view; confirm how
    this handler is registered on ``app``.
    """
    logger.info("Home page requested")

    page_base_url = str(request.base_url)

    # Use HTTPS only for Spaces, otherwise keep the request's protocol.
    running_in_spaces = os.getenv("SPACE_ID") is not None
    if running_in_spaces:
        page_base_url = page_base_url.replace("http://", "https://")

    context = {"request": request, "base_url": page_base_url}
    return templates.TemplateResponse("index.html", context)
async def get_model_list():
    """Return the de-duplicated list of cached inference models as JSON.

    Each entry of the lookup result is unpacked as
    ``(architecture, org, model_id)`` and emitted as
    ``{"id": "<org>/<model_id>", "name": "<org>/<model_id>", "type": architecture}``.
    When the same ``org/model_id`` appears more than once, the first
    occurrence wins.

    Returns:
        A ``JSONResponse`` holding the list of model dicts, or a 500
        ``JSONResponse`` with ``error`` and ``type`` keys if anything fails.

    NOTE(review): no route decorator is visible in this view; confirm how
    this handler is registered on ``app``.
    """
    logger.info("Fetching model list")
    try:
        # Debug aid: record only whether a token is configured, never its value.
        logger.info(f"HF_TOKEN present: {bool(os.getenv('HF_TOKEN'))}")

        # get_hub_cached_models is a synchronous call (presumably doing Hub
        # network I/O - confirm); run it in a worker thread so it cannot
        # stall the event loop while it waits.
        model_list = await run_in_threadpool(
            utils.get_hub_cached_models, mode="inference"
        )
        logger.info(f"Found {len(model_list)} models")

        # Keyed by full id; dicts preserve insertion order, so the first
        # occurrence of each id is kept, in the order it was seen.
        unique_models = {}
        for architecture, org, model_id in model_list:
            full_model_id = f"{org}/{model_id}"
            if full_model_id not in unique_models:
                unique_models[full_model_id] = {
                    "id": full_model_id,
                    "name": full_model_id,
                    "type": architecture,
                }
        models = list(unique_models.values())

        logger.info(f"Returning {len(models)} unique models")
        return JSONResponse(content=models)
    except Exception as e:
        # Top-level boundary for this handler: log the message and the full
        # traceback, then report the failure to the client as JSON.
        logger.error(f"Error fetching models: {str(e)}")
        logger.error("Full error details:", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": str(e), "type": str(type(e).__name__)},
        )
async def get_model_info_endpoint(model_id: str):
    """Fetch the cached configurations for ``model_id`` from the Hub lookup API.

    Args:
        model_id: Model identifier appended to the lookup URL; may contain
            ``/`` (e.g. ``org/name``).

    Returns:
        ``JSONResponse`` with ``{"configurations": [...]}`` taken from the
        ``cached_configs`` key of the upstream JSON (empty list if absent).
        Error responses: 400 for an id containing ``.``/``..`` path
        segments, 504 on timeout, 500 on any other HTTP or unexpected error.

    NOTE(review): no route decorator is visible in this view; confirm how
    this handler is registered on ``app``.
    """
    logger.info(f"Fetching configurations for model: {model_id}")

    # model_id comes from the request: refuse dot segments so the upstream
    # path cannot be redirected to a different API location.
    if any(segment in (".", "..") for segment in model_id.split("/")):
        logger.error(f"Rejected invalid model id: {model_id}")
        return JSONResponse(status_code=400, content={"error": "Invalid model id"})

    try:
        # Define the base URL for the HuggingFace API
        # NOTE(review): the host is written with a trailing dot
        # ("huggingface.co."); left unchanged - confirm it is intentional.
        base_url = "https://huggingface.co./api/integrations/aws/v1/lookup"
        # Percent-encode everything except "/" so characters such as "?" or
        # "#" in the id cannot alter the query or fragment of the request.
        encoded_model_id = quote(model_id, safe="/")
        api_url = f"{base_url}/{encoded_model_id}"

        # Make async HTTP request with timeout:
        # 15s default per operation, 5s to establish the connection.
        timeout = httpx.Timeout(15.0, connect=5.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.get(api_url)
            response.raise_for_status()
            data = response.json()
            configs = data.get("cached_configs", [])
            logger.info(f"Found {len(configs)} configurations for model {model_id}")
            return JSONResponse(content={"configurations": configs})
    except httpx.TimeoutException as e:
        # Must precede the httpx.HTTPError handler: timeouts are a subclass.
        logger.error(f"Timeout while fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=504,  # Gateway Timeout
            content={"error": "Request timed out while fetching model configurations"},
        )
    except httpx.HTTPError as e:
        logger.error(f"HTTP error fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to fetch model configurations: {str(e)}"},
        )
    except Exception as e:
        logger.error(f"Error fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": str(e)},
        )
async def static_files(path: str, request: Request):
    """Serve a single file from the static directory.

    Args:
        path: Request-supplied path, interpreted relative to ``static_dir``.
        request: The incoming request (unused; kept for the handler signature).

    Returns:
        A ``FileResponse`` for an existing regular file inside ``static_dir``
        (with an explicit content type for ``.css`` and ``.js``), otherwise a
        404 ``JSONResponse`` with ``{"error": "File not found"}``.

    NOTE(review): no route decorator is visible in this view; confirm how
    this handler is registered on ``app``.
    """
    logger.info(f"Static file requested: {path}")

    # `path` is untrusted: normalise it and require the result to stay inside
    # the static directory, so "../" sequences or an absolute path cannot
    # expose arbitrary files. join(root, "") yields the root with a trailing
    # separator, which prevents matching sibling dirs such as "static_old".
    static_root = os.path.abspath(static_dir)
    file_path = os.path.abspath(os.path.join(static_root, path))
    inside_static_root = file_path.startswith(os.path.join(static_root, ""))

    # isfile (not exists): a directory cannot be sent as a file response.
    if not inside_static_root or not os.path.isfile(file_path):
        return JSONResponse(status_code=404, content={"error": "File not found"})

    response = FileResponse(file_path)
    # Ensure proper content type
    if path.endswith('.css'):
        response.headers["content-type"] = "text/css"
    elif path.endswith('.js'):
        response.headers["content-type"] = "application/javascript"
    return response