"""Skin-condition image classifier served over FastAPI with a Gradio front end.

A ConvNeXt image-classification model is given a new linear head sized to
``class_names`` and initialised from a local fine-tuned checkpoint. The model
is exposed two ways:

* ``POST /predict/`` -- JSON API taking an uploaded image file.
* ``/gradio``        -- interactive Gradio UI (``/`` redirects here).
"""
import io
from typing import Tuple

import gradio as gr
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from gradio.routes import mount_gradio_app
from PIL import Image
from starlette.concurrency import run_in_threadpool
from starlette.middleware.cors import CORSMiddleware
from transformers import AutoImageProcessor, ConvNextForImageClassification

# Class names, indexed by the model's output logit position.
class_names = [
    'Acne and Rosacea Photos',
    'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions',
    'Atopic Dermatitis Photos',
    'Bullous Disease Photos',
    'Cellulitis Impetigo and other Bacterial Infections',
    'Eczema Photos',
    'Exanthems and Drug Eruptions',
    'Hair Loss Photos Alopecia and other Hair Diseases',
    'Herpes HPV and other STDs Photos',
    'Light Diseases and Disorders of Pigmentation',
    'Lupus and other Connective Tissue diseases',
    'Melanoma Skin Cancer Nevi and Moles',
    'Nail Fungus and other Nail Disease',
    'Poison Ivy Photos and other Contact Dermatitis',
    'Psoriasis pictures Lichen Planus and related diseases',
    'Scabies Lyme Disease and other Infestations and Bites',
    'Seborrheic Keratoses and other Benign Tumors',
    'Systemic Disease',
    'Tinea Ringworm Candidiasis and other Fungal Infections',
    'Urticaria Hives',
    'Vascular Tumors',
    'Vasculitis Photos',
    'Warts Molluscum and other Viral Infections',
]

# Load model and processor.
model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224")
# Replace the pretrained head with one sized to our label set; deriving the
# width from class_names keeps the head and the label list from drifting apart.
model.classifier = torch.nn.Linear(in_features=1024, out_features=len(class_names))
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(
    torch.load("./models/convnext_base_finetuned.pth", map_location="cpu")
)
model.eval()
processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")

# FastAPI app
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): wildcard origins together with credentials is very
    # permissive -- restrict to known front-end origins before production.
    allow_origins=["*"],  # Adjust for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Predict function
def predict(image: Image.Image) -> Tuple[int, str]:
    """Classify a single PIL image.

    Args:
        image: The image to classify. Non-RGB images are converted to RGB
            before preprocessing.

    Returns:
        A ``(class_index, class_name)`` pair, where ``class_name`` is
        ``class_names[class_index]``.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)
    predicted_class = torch.argmax(outputs.logits, dim=1).item()
    return predicted_class, class_names[predicted_class]


# FastAPI route
@app.post("/predict/")
async def predict_endpoint(file: UploadFile = File(...)):
    """Classify an uploaded image and return the prediction as JSON.

    Responses:
        200: ``{"predicted_class": int, "predicted_name": str}``
        400: ``{"error": str}`` when the upload cannot be decoded as an image.
        500: ``{"error": str}`` for any other failure.
    """
    try:
        img_bytes = await file.read()
        try:
            # Image.open is lazy; convert() forces a full decode so corrupt
            # or truncated uploads fail here rather than inside the model.
            img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        except (OSError, ValueError, Image.DecompressionBombError) as e:
            # Bad, empty, or non-image upload is the client's fault, not ours.
            return JSONResponse(
                content={"error": f"Invalid image file: {e}"},
                status_code=400,
            )
        # Inference is blocking CPU work; run it off the event loop so other
        # requests are not stalled while the model executes.
        predicted_class, predicted_name = await run_in_threadpool(predict, img)
        return JSONResponse(content={
            "predicted_class": predicted_class,
            "predicted_name": predicted_name
        })
    except Exception as e:
        # Top-level boundary: report unexpected failures as JSON.
        return JSONResponse(content={"error": str(e)}, status_code=500)


@app.get("/")
def redirect_root_to_gradio():
    """Redirect the site root to the mounted Gradio UI."""
    return RedirectResponse(url="/gradio")


# Gradio interface
def gradio_interface(image):
    """Gradio callback: classify ``image`` and format the result as text.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None`` when
            the user submits without providing one.

    Returns:
        ``"<class name> (Class <index>)"``, or a prompt to upload an image
        when no image was given.
    """
    if image is None:
        return "Please upload an image first."
    predicted_class, predicted_name = predict(image)
    return f"{predicted_name} (Class {predicted_class})"


gradio_app = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Skin Disease Classifier",
    description="Upload a skin image to classify the condition using a fine-tuned ConvNeXt model."
)

# Mount Gradio in FastAPI
app = mount_gradio_app(app, gradio_app, path="/gradio")