# Source: Rivalcoder, commit 2118cd6
import io
import logging
import os
import tempfile
from typing import Optional

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse, RedirectResponse
from transformers import ConvNextForImageClassification, AutoImageProcessor
from PIL import Image
import gradio as gr
from starlette.concurrency import run_in_threadpool
from starlette.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from gradio.routes import mount_gradio_app
# Human-readable labels for the classifier outputs. A label's position in this
# list is the class index the model predicts (see predict()).
class_names = [
    'Acne and Rosacea Photos',
    'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions',
    'Atopic Dermatitis Photos',
    'Bullous Disease Photos',
    'Cellulitis Impetigo and other Bacterial Infections',
    'Eczema Photos',
    'Exanthems and Drug Eruptions',
    'Hair Loss Photos Alopecia and other Hair Diseases',
    'Herpes HPV and other STDs Photos',
    'Light Diseases and Disorders of Pigmentation',
    'Lupus and other Connective Tissue diseases',
    'Melanoma Skin Cancer Nevi and Moles',
    'Nail Fungus and other Nail Disease',
    'Poison Ivy Photos and other Contact Dermatitis',
    'Psoriasis pictures Lichen Planus and related diseases',
    'Scabies Lyme Disease and other Infestations and Bites',
    'Seborrheic Keratoses and other Benign Tumors',
    'Systemic Disease',
    'Tinea Ringworm Candidiasis and other Fungal Infections',
    'Urticaria Hives',
    'Vascular Tumors',
    'Vasculitis Photos',
    'Warts Molluscum and other Viral Infections',
]
# Load the pretrained ConvNeXt backbone, replace its classification head with one
# sized for `class_names`, then restore the fine-tuned weights.
model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224")
# Derive the head's shape rather than hard-coding 1024/23. The new layer then
# always matches the backbone's feature width and the number of labels, and
# cannot drift out of sync with `class_names`.
model.classifier = torch.nn.Linear(
    in_features=model.classifier.in_features,
    out_features=len(class_names),
)
# The checkpoint is consumed as a state_dict (tensors only), so weights_only=True
# is sufficient. It also avoids unpickling arbitrary objects from the file.
# NOTE: requires torch >= 1.13.
model.load_state_dict(
    torch.load(
        "./models/convnext_base_finetuned.pth",
        map_location="cpu",
        weights_only=True,
    )
)
# Inference mode for layers whose behaviour differs between train and eval.
model.eval()
processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")
# FastAPI app setup
app = FastAPI()
# CORS: allow any origin so browser-based demos can call /api/predict.
# Credentials stay disabled on purpose. Combining a wildcard origin with
# allow_credentials=True makes Starlette reflect the caller's Origin and send
# Access-Control-Allow-Credentials: true, so any site could make credentialed
# cross-origin requests. Nothing in this file relies on cookies or auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for demo purposes
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Core inference helper shared by the JSON API and the Gradio UI.
def predict(image: Image.Image):
    """Classify one PIL image with the fine-tuned ConvNeXt model.

    Args:
        image: The image to classify; it is passed as-is to the image processor.

    Returns:
        A ``(class_index, class_name)`` tuple. ``class_index`` is the argmax over
        the model logits; ``class_name`` is the matching ``class_names`` entry.
    """
    pixel_batch = processor(images=image, return_tensors="pt")
    # No gradients are needed for inference; skipping them saves memory and time.
    with torch.no_grad():
        logits = model(**pixel_batch).logits
    best_index = logits.argmax(dim=1).item()
    return best_index, class_names[best_index]
# FastAPI route for prediction via API
@app.post("/api/predict")
async def predict_from_upload(file: UploadFile = File(...)):
    """API endpoint for image uploads.

    Reads the uploaded file, decodes it as an RGB image and classifies it.

    Returns:
        200 with ``{"predicted_class": int, "predicted_name": str}`` on success.
        400 with ``{"error": ...}`` if the upload cannot be decoded as an image
        (corrupt, truncated, unsupported, or exceeding Pillow's
        decompression-bomb limit).
        500 with ``{"error": ...}`` for any other failure; it is also logged.
    """
    try:
        img_bytes = await file.read()
        try:
            # convert() forces a full decode, so truncated files fail here
            # rather than later inside the model.
            img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        except (OSError, Image.DecompressionBombError) as exc:
            # UnidentifiedImageError is an OSError subclass. A bad upload is a
            # client error, not a server fault.
            return JSONResponse(
                content={"error": f"Invalid image file: {exc}"},
                status_code=400,
            )
        # predict() is synchronous, CPU-bound model inference. Awaiting it in a
        # worker thread keeps the event loop (and the mounted Gradio UI)
        # responsive while the model runs.
        predicted_class, predicted_name = await run_in_threadpool(predict, img)
    except Exception as exc:
        logging.getLogger(__name__).exception(
            "Prediction failed for upload %r", file.filename
        )
        return JSONResponse(content={"error": str(exc)}, status_code=500)
    return JSONResponse(content={
        "predicted_class": predicted_class,
        "predicted_name": predicted_name,
    })
# Landing route: the root URL has no page of its own, so forward to the UI.
@app.get("/")
def redirect_root_to_gradio():
    """Redirect requests for ``/`` to the Gradio interface mounted at ``/gradio``."""
    gradio_url = "/gradio"
    return RedirectResponse(gradio_url)
# Gradio interface for testing
def gradio_interface(image):
    """Gradio callback: classify the uploaded image and format the result.

    Args:
        image: The PIL image from the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        ``"<class name> (Class <index>)"`` for a real image, or a short prompt
        asking for an upload when no image was provided.
    """
    # An empty image component yields None. Without this guard, the image
    # processor raises and the UI shows only a generic error.
    if image is None:
        return "Please upload an image."
    predicted_class, predicted_name = predict(image)
    return f"{predicted_name} (Class {predicted_class})"
# Build the Gradio demo around gradio_interface, then serve it under /gradio on
# the same FastAPI app. One server hosts both the JSON API and the UI.
_image_input = gr.Image(type="pil")
gradio_app = gr.Interface(
    title="Skin Disease Classifier",
    description="Upload a skin image to classify the condition using a fine-tuned ConvNeXt model.",
    fn=gradio_interface,
    inputs=_image_input,
    outputs="text",
)
app = mount_gradio_app(app, gradio_app, path="/gradio")
# Script entry point: serve the combined API + Gradio app on all interfaces.
if __name__ == "__main__":
    # Imported here so that merely importing this module (e.g. under another
    # ASGI server) does not require uvicorn.
    import uvicorn

    uvicorn.run(app=app, host="0.0.0.0", port=7860)