"""FastAPI front-end listing models from the Optimum Neuron hub cache.

Endpoints:
    GET /health                 -- liveness probe.
    GET /                       -- renders templates/index.html.
    GET /api/models             -- de-duplicated list of cached models.
    GET /api/models/{model_id}  -- cached configurations for one model, looked
                                   up via the Hugging Face AWS integration API.
    GET /static/{path}          -- static assets (see NOTE on the handler).
"""
import logging
import os
import sys

import httpx
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from optimum.neuron import utils

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

app = FastAPI()

# Add CORS middleware
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site issue credentialed requests; consider restricting allow_origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Get the absolute path to the static directory
static_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static")
logger.info(f"Static directory path: {static_dir}")

# Get the absolute path to the templates directory
templates_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")
logger.info(f"Templates directory path: {templates_dir}")

# Mount static files and templates
app.mount("/static", StaticFiles(directory=static_dir), name="static")
templates = Jinja2Templates(directory=templates_dir)

# Base URL of the Hugging Face AWS-integration lookup API. The model id is
# appended as a path suffix. (Previously written with a stray trailing dot in
# the hostname, "huggingface.co.", which can break Host/SNI handling.)
HF_LOOKUP_BASE_URL = "https://huggingface.co/api/integrations/aws/v1/lookup"


@app.get("/health")
async def health_check():
    """Liveness probe: always returns ``{"status": "healthy"}``."""
    logger.info("Health check endpoint called")
    return {"status": "healthy"}


@app.get("/")
async def home(request: Request):
    """Render ``index.html`` with the request and the app's public base URL.

    When the ``SPACE_ID`` environment variable is set (running in Spaces), the
    base URL is forced to HTTPS; otherwise the request's own scheme is kept.
    """
    logger.info("Home page requested")
    # Check if we're running in Spaces
    is_spaces = os.getenv("SPACE_ID") is not None

    # Use HTTPS only for Spaces, otherwise use the request's protocol
    base_url = str(request.base_url)
    if is_spaces:
        base_url = base_url.replace("http://", "https://")

    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "base_url": base_url
        }
    )


@app.get("/api/models")
async def get_model_list():
    """Return the de-duplicated list of models present in the Neuron hub cache.

    Returns:
        JSON list of ``{"id": ..., "name": ..., "type": ...}`` objects, where
        ``id`` and ``name`` are both ``"<org>/<model_id>"`` and ``type`` is the
        architecture reported by ``utils.get_hub_cached_models``. Only the
        first entry for each ``org/model_id`` is kept, in cache order.
        On any error, a 500 response with ``{"error": ..., "type": ...}``.
    """
    logger.info("Fetching model list")
    try:
        # Add debug logging
        logger.info(f"HF_TOKEN present: {bool(os.getenv('HF_TOKEN'))}")
        # get_hub_cached_models is synchronous (presumably doing Hub network
        # I/O). Calling it directly inside this async endpoint would block the
        # event loop, so it runs in the worker thread pool instead.
        model_list = await run_in_threadpool(utils.get_hub_cached_models, mode="inference")
        logger.info(f"Found {len(model_list)} models")

        models = []
        seen_models = set()

        # Each cache entry is an (architecture, org, model_id) triple; several
        # entries may refer to the same model, so keep only the first.
        for architecture, org, model_id in model_list:
            full_model_id = f"{org}/{model_id}"
            if full_model_id in seen_models:
                continue
            seen_models.add(full_model_id)
            models.append({
                "id": full_model_id,
                "name": full_model_id,
                "type": architecture
            })

        logger.info(f"Returning {len(models)} unique models")
        return JSONResponse(content=models)
    except Exception as e:
        # Enhanced error logging
        logger.error(f"Error fetching models: {str(e)}")
        logger.error("Full error details:", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": str(e), "type": type(e).__name__}
        )


@app.get("/api/models/{model_id:path}")
async def get_model_info_endpoint(model_id: str):
    """Fetch the cached configurations for ``model_id`` from the HF lookup API.

    Args:
        model_id: Model identifier, which may contain ``/`` (e.g. ``org/name``).

    Returns:
        ``{"configurations": [...]}`` holding the upstream ``cached_configs``
        list, or ``[]`` if that key is absent. Errors map to JSON responses:
        504 on timeout, and 500 on any other HTTP or unexpected error.
    """
    logger.info(f"Fetching configurations for model: {model_id}")
    try:
        api_url = f"{HF_LOOKUP_BASE_URL}/{model_id}"

        # httpx timeouts are per phase: 5s to establish the connection, and
        # 15s each for read / write / pool acquisition.
        timeout = httpx.Timeout(15.0, connect=5.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.get(api_url)
            response.raise_for_status()

        data = response.json()
        configs = data.get("cached_configs", [])

        logger.info(f"Found {len(configs)} configurations for model {model_id}")
        return JSONResponse(content={"configurations": configs})
    except httpx.TimeoutException as e:
        # Must precede httpx.HTTPError: TimeoutException is a subclass of it.
        logger.error(f"Timeout while fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=504,  # Gateway Timeout
            content={"error": "Request timed out while fetching model configurations"}
        )
    except httpx.HTTPError as e:
        logger.error(f"HTTP error fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to fetch model configurations: {str(e)}"}
        )
    except Exception as e:
        logger.error(f"Error fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": str(e)}
        )


def _resolve_static_path(path):
    """Map a request ``path`` to a regular file inside ``static_dir``.

    Returns the resolved absolute file path, or ``None`` in any of these cases:
    the path escapes ``static_dir`` (``..`` segments, an absolute path, or a
    symlink pointing outside), or the target is not a regular file.
    """
    static_root = os.path.realpath(static_dir)
    candidate = os.path.realpath(os.path.join(static_root, path))
    try:
        inside = os.path.commonpath([static_root, candidate]) == static_root
    except ValueError:
        # e.g. paths on different drives on Windows -- definitely outside.
        return None
    if not inside or not os.path.isfile(candidate):
        return None
    return candidate


@app.get("/static/{path:path}")
async def static_files(path: str, request: Request):
    """Serve a file from ``static_dir`` with an explicit CSS/JS content type.

    Returns a 404 JSON response when the file does not exist, is not a regular
    file, or lies outside ``static_dir``.

    NOTE(review): this route is registered after ``app.mount("/static", ...)``.
    Starlette matches routes in registration order, so the mount presumably
    serves ``/static/*`` first and this handler may be unreachable -- confirm.
    """
    logger.info(f"Static file requested: {path}")
    file_path = _resolve_static_path(path)
    if file_path is None:
        return JSONResponse(status_code=404, content={"error": "File not found"})

    response = FileResponse(file_path)
    # Ensure proper content type (platform mimetypes tables can map .js
    # incorrectly, so these two are pinned explicitly).
    if path.endswith('.css'):
        response.headers["content-type"] = "text/css"
    elif path.endswith('.js'):
        response.headers["content-type"] = "application/javascript"
    return response