Rivalcoder commited on
Commit
884137e
·
1 Parent(s): 4ec17bd
Files changed (3) hide show
  1. app.py +77 -0
  2. models/convnext_base_finetuned.pth +3 -0
  3. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from fastapi import FastAPI, File, UploadFile
4
+ from fastapi.responses import JSONResponse
5
+ from transformers import ConvNextForImageClassification, AutoImageProcessor
6
+ from PIL import Image
7
+ import io
8
+
9
+ # Class names (for skin diseases)
10
+ class_names = [
11
+ 'Acne and Rosacea Photos', 'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions', 'Atopic Dermatitis Photos',
12
+ 'Bullous Disease Photos', 'Cellulitis Impetigo and other Bacterial Infections', 'Eczema Photos', 'Exanthems and Drug Eruptions',
13
+ 'Hair Loss Photos Alopecia and other Hair Diseases', 'Herpes HPV and other STDs Photos', 'Light Diseases and Disorders of Pigmentation',
14
+ 'Lupus and other Connective Tissue diseases', 'Melanoma Skin Cancer Nevi and Moles', 'Nail Fungus and other Nail Disease',
15
+ 'Poison Ivy Photos and other Contact Dermatitis', 'Psoriasis pictures Lichen Planus and related diseases',
16
+ 'Scabies Lyme Disease and other Infestations and Bites', 'Seborrheic Keratoses and other Benign Tumors', 'Systemic Disease',
17
+ 'Tinea Ringworm Candidiasis and other Fungal Infections', 'Urticaria Hives', 'Vascular Tumors', 'Vasculitis Photos',
18
+ 'Warts Molluscum and other Viral Infections'
19
+ ]
20
+
21
+ # Load model and processor
22
+ model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224")
23
+ model.classifier = torch.nn.Linear(in_features=1024, out_features=23)
24
+ model.load_state_dict(torch.load("./models/convnext_base_finetuned.pth", map_location="cpu"))
25
+ model.eval()
26
+
27
+ processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")
28
+
29
+ # FastAPI app
30
+ app = FastAPI()
31
+
32
+ # Helper function for processing the image
33
+ def predict(image: Image.Image):
34
+ # Preprocess the image
35
+ inputs = processor(images=image, return_tensors="pt")
36
+
37
+ # Perform inference
38
+ with torch.no_grad():
39
+ outputs = model(**inputs)
40
+ predicted_class = torch.argmax(outputs.logits, dim=1).item()
41
+
42
+ return predicted_class, class_names[predicted_class]
43
+
44
+ # FastAPI endpoint to handle image upload and prediction
45
+ @app.post("/predict/")
46
+ async def predict_endpoint(file: UploadFile = File(...)):
47
+ try:
48
+ # Read and process the image
49
+ img_bytes = await file.read()
50
+ img = Image.open(io.BytesIO(img_bytes))
51
+
52
+ # Get the prediction
53
+ predicted_class, predicted_name = predict(img)
54
+
55
+ # Return the result as JSON
56
+ return JSONResponse(content={"predicted_class": predicted_class, "predicted_name": predicted_name})
57
+
58
+ except Exception as e:
59
+ return JSONResponse(content={"error": str(e)}, status_code=500)
60
+
61
+ # Gradio function to integrate with the FastAPI prediction
62
+ def gradio_predict(image: Image.Image):
63
+ predicted_class, predicted_name = predict(image)
64
+ return f"Predicted Class: {predicted_name}"
65
+
66
+ # Gradio Interface
67
+ iface = gr.Interface(fn=gradio_predict, inputs=gr.Image(type="pil"), outputs=gr.Textbox())
68
+
69
+ # Serve Gradio interface on FastAPI
70
+ @app.get("/gradio/")
71
+ async def gradio_interface():
72
+ return iface.launch(share=True, inline=True)
73
+
74
+ # Run the FastAPI app using Uvicorn
75
+ if __name__ == "__main__":
76
+ import uvicorn
77
+ uvicorn.run(app, host="0.0.0.0", port=8000)
models/convnext_base_finetuned.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c8dadf0c017fd3749a0dc291a1d9249bdba618c6351964d1fe65a85c07a578b
3
+ size 350500802
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ torchvision
5
+ opencv-python
6
+ numpy
7
+ Pillow
8
+ python-multipart
9
+ gradio
10
+ transformers
11
+