# Page-capture residue, not part of the program: "Spaces: Sleeping"
import io
import os
import tempfile
from typing import Optional

import gradio as gr
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from gradio.routes import mount_gradio_app
from PIL import Image
from starlette.middleware.cors import CORSMiddleware
from transformers import AutoImageProcessor, ConvNextForImageClassification
# Human-readable labels for the classifier outputs.
# predict() looks a label up by the argmax index of the logits, so the order
# of this list must not change.
class_names = [
    "Acne and Rosacea Photos",
    "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
    "Atopic Dermatitis Photos",
    "Bullous Disease Photos",
    "Cellulitis Impetigo and other Bacterial Infections",
    "Eczema Photos",
    "Exanthems and Drug Eruptions",
    "Hair Loss Photos Alopecia and other Hair Diseases",
    "Herpes HPV and other STDs Photos",
    "Light Diseases and Disorders of Pigmentation",
    "Lupus and other Connective Tissue diseases",
    "Melanoma Skin Cancer Nevi and Moles",
    "Nail Fungus and other Nail Disease",
    "Poison Ivy Photos and other Contact Dermatitis",
    "Psoriasis pictures Lichen Planus and related diseases",
    "Scabies Lyme Disease and other Infestations and Bites",
    "Seborrheic Keratoses and other Benign Tumors",
    "Systemic Disease",
    "Tinea Ringworm Candidiasis and other Fungal Infections",
    "Urticaria Hives",
    "Vascular Tumors",
    "Vasculitis Photos",
    "Warts Molluscum and other Viral Infections",
]
# Load the pretrained ConvNeXt backbone, then replace its classification head
# with one sized for this task before restoring the fine-tuned weights.
model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224")
# Derive the head dimensions instead of hard-coding them: the input width comes
# from the head being replaced and the output width from the label list, so the
# head can never drift out of sync with class_names.
model.classifier = torch.nn.Linear(
    in_features=model.classifier.in_features,
    out_features=len(class_names),
)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
model.load_state_dict(
    torch.load(
        "./models/convnext_base_finetuned.pth",
        map_location="cpu",
        weights_only=True,
    )
)
# Switch to inference mode.
model.eval()
# Image preprocessing is loaded from the same pretrained checkpoint name.
processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")
# FastAPI app setup
app = FastAPI()
# CORS middleware to allow cross-origin requests.
# Credentials are disabled because they must not be combined with wildcard
# origins: browsers reject "Access-Control-Allow-Origin: *" on credentialed
# requests, and allowing credentials for every origin would let any site make
# authenticated calls on a user's behalf.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for demo purposes
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Run the classifier on a single PIL image
def predict(image: Image.Image):
    """Classify one image and return ``(class_index, class_label)``.

    The image is preprocessed with the module-level ``processor``, passed
    through the module-level ``model`` with gradient tracking disabled, and
    the index of the highest logit is looked up in ``class_names``.
    """
    model_inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**model_inputs).logits
    # .item() requires a single-element tensor, i.e. exactly one image.
    class_index = logits.argmax(dim=1).item()
    class_label = class_names[class_index]
    return class_index, class_label
# FastAPI route for prediction via API
# NOTE(review): no route registration was present for this handler, so it was
# unreachable; "/predict" is the path chosen here -- adjust it if existing
# clients expect a different URL.
@app.post("/predict")
async def predict_from_upload(file: UploadFile = File(...)):
    """Classify an uploaded image file and return the prediction as JSON.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        A ``JSONResponse`` with ``predicted_class`` and ``predicted_name`` on
        success, or ``{"error": ...}`` with status 400 when the upload cannot
        be decoded as an image and status 500 for any other failure.
    """
    try:
        img_bytes = await file.read()
        try:
            img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError) as exc:
            # An undecodable upload is a client error, not a server fault.
            return JSONResponse(content={"error": str(exc)}, status_code=400)
        # Inference is blocking CPU work; run it in the threadpool so it does
        # not stall the event loop for other requests.
        predicted_class, predicted_name = await run_in_threadpool(predict, img)
    except Exception as exc:
        # Endpoint boundary: report unexpected failures as JSON.
        return JSONResponse(content={"error": str(exc)}, status_code=500)
    return JSONResponse(content={
        "predicted_class": predicted_class,
        "predicted_name": predicted_name
    })
# Redirect root to Gradio interface
# The handler must be registered on the app to take effect; without a route
# decorator it is never called.
@app.get("/")
def redirect_root_to_gradio():
    """Redirect requests for the site root to the Gradio UI at ``/gradio``."""
    return RedirectResponse(url="/gradio")
# Gradio interface for testing
def gradio_interface(image):
    """Return the prediction for ``image`` as a display string.

    Args:
        image: The PIL image supplied by the Gradio image input, or ``None``
            when the form is submitted without an image.

    Returns:
        ``"<label> (Class <index>)"`` for a valid image, or a short prompt
        asking for an upload when no image was provided.
    """
    # Gradio passes None when nothing was uploaded; answer with a readable
    # message instead of letting the preprocessing step fail.
    if image is None:
        return "Please upload an image to classify."
    predicted_class, predicted_name = predict(image)
    return f"{predicted_name} (Class {predicted_class})"
# Browser UI: a single image input wired to gradio_interface, with the
# result shown as plain text.
gradio_app = gr.Interface(
    title="Skin Disease Classifier",
    description="Upload a skin image to classify the condition using a fine-tuned ConvNeXt model.",
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),  # hand the callback a PIL image
    outputs="text",
)
# Serve the UI from the same FastAPI application under /gradio;
# mount_gradio_app returns the app, which is rebound to `app`.
app = mount_gradio_app(app=app, blocks=gradio_app, path="/gradio")
# Local entry point: when executed as a script, serve the app with uvicorn
# on all interfaces at port 7860.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")