Spaces:
Sleeping
Sleeping
File size: 5,522 Bytes
deb3471 bdc22f6 05e2837 deb3471 d352fe2 d6a1cfd d352fe2 deb3471 05e2837 d352fe2 aa37823 deb3471 d352fe2 aa37823 deb3471 d352fe2 deb3471 d352fe2 933f6ff 05e2837 deb3471 d352fe2 deb3471 08e0ab7 e0174a0 deb3471 d352fe2 deb3471 08e0ab7 deb3471 d352fe2 deb3471 08e0ab7 deb3471 08e0ab7 deb3471 08e0ab7 deb3471 d352fe2 deb3471 d6a1cfd d352fe2 deb3471 d6a1cfd deb3471 d352fe2 deb3471 aa37823 05e2837 aa37823 bdc22f6 05e2837 bdc22f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import logging
import os
import sys

import httpx
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from optimum.neuron import utils
# Configure the root logger once for the whole process: INFO level, a single
# stdout handler, and a "time - logger - level - message" line format.
# (logging.basicConfig does nothing if the root logger already has handlers.)
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Logger named after this module; the endpoint handlers log through it.
logger = logging.getLogger(name=__name__)
app = FastAPI()

# CORS: allow any origin, method and header to call the JSON API.
# Credentials are intentionally NOT enabled. Starlette's documentation states
# that allow_origins / allow_methods / allow_headers must be listed explicitly
# (not "*") for credentialed requests. Combining "*" with
# allow_credentials=True makes the middleware reflect the caller's Origin on
# cookie-bearing requests, so any site could make credentialed calls.
# Nothing visible in this module reads cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Resolve asset directories relative to this file (not the process's current
# working directory) so the app finds them however it is launched.
_app_dir = os.path.dirname(os.path.abspath(__file__))

static_dir = os.path.join(_app_dir, "static")
logger.info(f"Static directory path: {static_dir}")

templates_dir = os.path.join(_app_dir, "templates")
logger.info(f"Templates directory path: {templates_dir}")

# Serve files under /static and set up Jinja2 rendering for HTML pages.
app.mount("/static", StaticFiles(directory=static_dir), name="static")
templates = Jinja2Templates(directory=templates_dir)
@app.get("/health")
async def health_check():
    """Liveness probe: log the hit and report a fixed healthy status.

    Returns:
        dict: always ``{"status": "healthy"}``.
    """
    logger.info("Health check endpoint called")
    status_payload = {"status": "healthy"}
    return status_payload
@app.get("/")
async def home(request: Request):
    """Render ``index.html``, passing the request and the app's base URL.

    When the ``SPACE_ID`` environment variable is set (i.e. running on
    Hugging Face Spaces), ``http://`` in the base URL is rewritten to
    ``https://``. Otherwise the request's own scheme is kept.

    NOTE(review): presumably Spaces terminates TLS at a proxy, so the app sees
    plain-http requests even though clients use https. Confirm.
    """
    logger.info("Home page requested")
    request_base = str(request.base_url)
    on_spaces = os.getenv("SPACE_ID") is not None
    base_url = request_base.replace("http://", "https://") if on_spaces else request_base
    context = {"request": request, "base_url": base_url}
    return templates.TemplateResponse("index.html", context)
@app.get("/api/models")
async def get_model_list():
    """Return the de-duplicated list of Hub-cached models for inference.

    Calls ``utils.get_hub_cached_models(mode="inference")``. It returns
    ``(architecture, org, model_id)`` tuples, which are collapsed to one entry
    per ``org/model_id``; the first architecture encountered wins.

    Returns:
        JSONResponse: on success, a list of ``{"id", "name", "type"}`` objects.
        On failure, status 500 with ``{"error": <message>, "type": <exception
        class name>}``.
    """
    logger.info("Fetching model list")
    try:
        logger.info(f"HF_TOKEN present: {bool(os.getenv('HF_TOKEN'))}")
        # get_hub_cached_models is synchronous (presumably network I/O against
        # the Hub). Calling it directly in this coroutine would block the event
        # loop and stall every other request, so run it in the threadpool.
        model_list = await run_in_threadpool(
            utils.get_hub_cached_models, mode="inference"
        )
        logger.info(f"Found {len(model_list)} models")

        models = []
        seen_models = set()
        for architecture, org, model_id in model_list:
            full_model_id = f"{org}/{model_id}"
            if full_model_id in seen_models:
                continue
            seen_models.add(full_model_id)
            models.append(
                {"id": full_model_id, "name": full_model_id, "type": architecture}
            )

        logger.info(f"Returning {len(models)} unique models")
        return JSONResponse(content=models)
    except Exception as e:
        # Top-level boundary for the endpoint: log the message with the full
        # traceback (logger.exception attaches exc_info) and report a 500.
        logger.exception(f"Error fetching models: {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": str(e), "type": str(type(e).__name__)},
        )
@app.get("/api/models/{model_id:path}")
async def get_model_info_endpoint(model_id: str):
    """Return the cached configurations for ``model_id``.

    Queries the Hugging Face AWS-integration lookup endpoint and forwards its
    ``cached_configs`` list, or an empty list when that key is absent.

    Args:
        model_id: Hub model id such as ``org/name``. The ``:path`` converter
            lets it contain ``/``.

    Returns:
        JSONResponse: ``{"configurations": [...]}`` on success.
        Status 504 if the upstream request times out.
        Status 500 for any other HTTP error (including non-2xx upstream
        responses) or unexpected failure.
    """
    logger.info(f"Fetching configurations for model: {model_id}")
    # Canonical host: no trailing dot after ".co". A trailing-dot FQDN can fail
    # TLS hostname verification and Host-header based routing.
    base_url = "https://huggingface.co/api/integrations/aws/v1/lookup"
    api_url = f"{base_url}/{model_id}"
    # httpx timeouts are per operation, not a total deadline: at most 5s to
    # connect, then 15s for each read / write / pool wait.
    timeout = httpx.Timeout(15.0, connect=5.0)
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.get(api_url)
            response.raise_for_status()
            data = response.json()
        configs = data.get("cached_configs", [])
        logger.info(f"Found {len(configs)} configurations for model {model_id}")
        return JSONResponse(content={"configurations": configs})
    # TimeoutException subclasses HTTPError, so it must be caught first.
    except httpx.TimeoutException as e:
        logger.error(f"Timeout while fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=504,  # Gateway Timeout
            content={"error": "Request timed out while fetching model configurations"},
        )
    except httpx.HTTPError as e:
        logger.error(f"HTTP error fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": f"Failed to fetch model configurations: {str(e)}"},
        )
    except Exception as e:
        # e.g. a non-JSON body, or a JSON payload that is not an object.
        logger.error(f"Error fetching configurations for model {model_id}: {str(e)}", exc_info=True)
        return JSONResponse(
            status_code=500,
            content={"error": str(e)},
        )
@app.get("/static/{path:path}")
async def static_files(path: str, request: Request):
    """Serve a file from ``static_dir``, forcing CSS/JS content types.

    Only regular files that resolve to a location inside ``static_dir`` are
    served. Missing files, directories, and paths that would escape the
    directory (``../x``, absolute paths, symlinks pointing outside) all get a
    JSON 404.

    NOTE(review): ``app.mount("/static", ...)`` is registered earlier in this
    module, before this route. Starlette matches routes in registration order,
    so the mount most likely handles every ``/static/...`` request and this
    handler never runs. Confirm before relying on it.
    """
    logger.info(f"Static file requested: {path}")
    root = os.path.realpath(static_dir)
    file_path = os.path.realpath(os.path.join(root, path))
    # Containment check: os.path.join discards `root` for absolute `path`, and
    # `..` segments or symlinks could otherwise point outside the directory.
    try:
        inside_root = os.path.commonpath([root, file_path]) == root
    except ValueError:  # e.g. paths on different drives (Windows)
        inside_root = False
    # isfile (not exists): FileResponse cannot stream a directory.
    if not inside_root or not os.path.isfile(file_path):
        return JSONResponse(status_code=404, content={"error": "File not found"})

    response = FileResponse(file_path)
    # Override the guessed media type for CSS/JS so browsers accept them.
    if path.endswith('.css'):
        response.headers["content-type"] = "text/css"
    elif path.endswith('.js'):
        response.headers["content-type"] = "application/javascript"
    return response