File size: 3,211 Bytes
884137e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import io

import gradio as gr
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from PIL import Image
from transformers import ConvNextForImageClassification, AutoImageProcessor

# Human-readable skin-disease labels. Position i is the label for classifier
# output index i (predict() indexes this list with the argmax), so the order
# must match the label indices used when the checkpoint was fine-tuned.
# NOTE(review): order is alphabetical, presumably the training folder order — confirm.
class_names = [
    "Acne and Rosacea Photos",
    "Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions",
    "Atopic Dermatitis Photos",
    "Bullous Disease Photos",
    "Cellulitis Impetigo and other Bacterial Infections",
    "Eczema Photos",
    "Exanthems and Drug Eruptions",
    "Hair Loss Photos Alopecia and other Hair Diseases",
    "Herpes HPV and other STDs Photos",
    "Light Diseases and Disorders of Pigmentation",
    "Lupus and other Connective Tissue diseases",
    "Melanoma Skin Cancer Nevi and Moles",
    "Nail Fungus and other Nail Disease",
    "Poison Ivy Photos and other Contact Dermatitis",
    "Psoriasis pictures Lichen Planus and related diseases",
    "Scabies Lyme Disease and other Infestations and Bites",
    "Seborrheic Keratoses and other Benign Tumors",
    "Systemic Disease",
    "Tinea Ringworm Candidiasis and other Fungal Infections",
    "Urticaria Hives",
    "Vascular Tumors",
    "Vasculitis Photos",
    "Warts Molluscum and other Viral Infections",
]

# Load the pretrained ConvNeXt backbone and replace its classification head with
# one sized for the skin-disease labels above.
model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224")
# Derive the head's shape instead of hard-coding 1024/23: in_features comes from
# the head being replaced, and out_features from the label list, so the head and
# class_names cannot silently drift apart.
model.classifier = torch.nn.Linear(
    in_features=model.classifier.in_features,
    out_features=len(class_names),
)
# The checkpoint is loaded as a state_dict (a mapping of tensors), so
# weights_only=True is sufficient. It also keeps torch.load from unpickling
# arbitrary Python objects found in the file.
model.load_state_dict(
    torch.load(
        "./models/convnext_base_finetuned.pth",
        map_location="cpu",
        weights_only=True,
    )
)
model.eval()  # inference mode: eval() changes the behavior of layers such as dropout

# The preprocessing config comes from the same base checkpoint as the backbone.
processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")

# FastAPI app
app = FastAPI()

# Helper function for processing the image
def predict(image: Image.Image):
    """Classify a single image with the fine-tuned ConvNeXt model.

    Args:
        image: PIL image to classify. Non-RGB images (e.g. RGBA PNGs,
            grayscale "L", palette "P", CMYK JPEGs) are converted to RGB
            first, because the backbone and processor expect 3-channel input.

    Returns:
        A tuple ``(predicted_class, predicted_name)``: the argmax class index
        over the model's logits and the matching entry of ``class_names``.
    """
    # Uploaded files come in many PIL modes; without this, 4-channel or
    # 1-channel inputs fail during normalization or are misread.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")

    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Batch of one image, so argmax over dim=1 yields a single class index.
    predicted_class = torch.argmax(logits, dim=1).item()
    return predicted_class, class_names[predicted_class]

# FastAPI endpoint to handle image upload and prediction
@app.post("/predict/")
async def predict_endpoint(file: UploadFile = File(...)):
    """Classify an uploaded image and return the prediction as JSON.

    Responses:
        200: ``{"predicted_class": int, "predicted_name": str}``
        400: ``{"error": ...}`` when the upload cannot be decoded as an image.
        500: ``{"error": ...}`` for any other failure (read or inference).
    """
    try:
        img_bytes = await file.read()

        # Image.open is lazy, so load() is called here to surface corrupt or
        # truncated uploads as a client error (400) rather than a server error.
        try:
            img = Image.open(io.BytesIO(img_bytes))
            img.load()
        except (OSError, Image.DecompressionBombError) as e:
            # PIL's UnidentifiedImageError is a subclass of OSError.
            return JSONResponse(content={"error": f"Invalid image file: {e}"}, status_code=400)

        # Model inference is synchronous and CPU-bound. Running it in the
        # threadpool keeps it from blocking the event loop, which serves all
        # other requests, including the Gradio UI.
        predicted_class, predicted_name = await run_in_threadpool(predict, img)

    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)

    return JSONResponse(content={"predicted_class": predicted_class, "predicted_name": predicted_name})

# Gradio function to integrate with the FastAPI prediction
def gradio_predict(image: Image.Image):
    """Gradio callback: classify *image* and return a display string.

    With ``gr.Image(type="pil")`` Gradio passes ``None`` when the user
    submits without uploading anything. That case returns a prompt instead of
    letting ``predict`` fail on it.
    """
    if image is None:
        return "Please upload an image."
    _, predicted_name = predict(image)
    return f"Predicted Class: {predicted_name}"

# Gradio UI: takes an uploaded image (delivered as a PIL image) and shows the
# text returned by gradio_predict.
iface = gr.Interface(
    fn=gradio_predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
)

# Serve Gradio interface on FastAPI
# Mount the Gradio UI as a sub-application of this FastAPI app, served at
# /gradio by the same uvicorn server.
# Why not iface.launch() inside a request handler:
# - launch() starts a separate Gradio server.
# - When run from a script, launch() blocks the calling thread, which here
#   would stall the event loop.
# - Its return value is not a valid HTTP response.
# Note that share=True (a public tunnel link) has no equivalent for a mounted app.
app = gr.mount_gradio_app(app, iface, path="/gradio")

# Run the FastAPI app using Uvicorn
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly, so it is
    # imported here rather than at module top level.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)