Rivalcoder commited on
Commit
2118cd6
·
1 Parent(s): 91f9d2d
Files changed (2) hide show
  1. app.py +24 -10
  2. requirements.txt +1 -0
app.py CHANGED
@@ -8,8 +8,11 @@ import gradio as gr
8
  from starlette.middleware.cors import CORSMiddleware
9
  from fastapi.staticfiles import StaticFiles
10
  from gradio.routes import mount_gradio_app
 
 
 
11
 
12
- # Class names
13
  class_names = [
14
  'Acne and Rosacea Photos', 'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions', 'Atopic Dermatitis Photos',
15
  'Bullous Disease Photos', 'Cellulitis Impetigo and other Bacterial Infections', 'Eczema Photos', 'Exanthems and Drug Eruptions',
@@ -21,7 +24,7 @@ class_names = [
21
  'Warts Molluscum and other Viral Infections'
22
  ]
23
 
24
- # Load model and processor
25
  model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224")
26
  model.classifier = torch.nn.Linear(in_features=1024, out_features=23)
27
  model.load_state_dict(torch.load("./models/convnext_base_finetuned.pth", map_location="cpu"))
@@ -29,17 +32,19 @@ model.eval()
29
 
30
  processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")
31
 
32
- # FastAPI app
33
  app = FastAPI()
 
 
34
  app.add_middleware(
35
  CORSMiddleware,
36
- allow_origins=["*"], # Adjust for production
37
  allow_credentials=True,
38
  allow_methods=["*"],
39
  allow_headers=["*"],
40
  )
41
 
42
- # Predict function
43
  def predict(image: Image.Image):
44
  inputs = processor(images=image, return_tensors="pt")
45
  with torch.no_grad():
@@ -47,9 +52,10 @@ def predict(image: Image.Image):
47
  predicted_class = torch.argmax(outputs.logits, dim=1).item()
48
  return predicted_class, class_names[predicted_class]
49
 
50
- # FastAPI route
51
- @app.post("/predict/")
52
- async def predict_endpoint(file: UploadFile = File(...)):
 
53
  try:
54
  img_bytes = await file.read()
55
  img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
@@ -61,15 +67,18 @@ async def predict_endpoint(file: UploadFile = File(...)):
61
  except Exception as e:
62
  return JSONResponse(content={"error": str(e)}, status_code=500)
63
 
 
64
  @app.get("/")
65
  def redirect_root_to_gradio():
66
  return RedirectResponse(url="/gradio")
67
 
68
- # Gradio interface
69
  def gradio_interface(image):
 
70
  predicted_class, predicted_name = predict(image)
71
  return f"{predicted_name} (Class {predicted_class})"
72
 
 
73
  gradio_app = gr.Interface(
74
  fn=gradio_interface,
75
  inputs=gr.Image(type="pil"),
@@ -78,5 +87,10 @@ gradio_app = gr.Interface(
78
  description="Upload a skin image to classify the condition using a fine-tuned ConvNeXt model."
79
  )
80
 
81
- # Mount Gradio in FastAPI
82
  app = mount_gradio_app(app, gradio_app, path="/gradio")
 
 
 
 
 
 
8
  from starlette.middleware.cors import CORSMiddleware
9
  from fastapi.staticfiles import StaticFiles
10
  from gradio.routes import mount_gradio_app
11
+ import tempfile
12
+ import os
13
+ from typing import Optional
14
 
15
+ # Class names for skin disease classification
16
  class_names = [
17
  'Acne and Rosacea Photos', 'Actinic Keratosis Basal Cell Carcinoma and other Malignant Lesions', 'Atopic Dermatitis Photos',
18
  'Bullous Disease Photos', 'Cellulitis Impetigo and other Bacterial Infections', 'Eczema Photos', 'Exanthems and Drug Eruptions',
 
24
  'Warts Molluscum and other Viral Infections'
25
  ]
26
 
27
+ # Load the ConvNeXt model and processor
28
  model = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224")
29
  model.classifier = torch.nn.Linear(in_features=1024, out_features=23)
30
  model.load_state_dict(torch.load("./models/convnext_base_finetuned.pth", map_location="cpu"))
 
32
 
33
  processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")
34
 
35
+ # FastAPI app setup
36
  app = FastAPI()
37
+
38
+ # CORS Middleware to allow cross-origin requests
39
  app.add_middleware(
40
  CORSMiddleware,
41
+ allow_origins=["*"], # Allow all origins for demo purposes
42
  allow_credentials=True,
43
  allow_methods=["*"],
44
  allow_headers=["*"],
45
  )
46
 
47
+ # Function to predict the skin disease from an image
48
  def predict(image: Image.Image):
49
  inputs = processor(images=image, return_tensors="pt")
50
  with torch.no_grad():
 
52
  predicted_class = torch.argmax(outputs.logits, dim=1).item()
53
  return predicted_class, class_names[predicted_class]
54
 
55
+ # FastAPI route for prediction via API
56
+ @app.post("/api/predict")
57
+ async def predict_from_upload(file: UploadFile = File(...)):
58
+ """API endpoint for image uploads"""
59
  try:
60
  img_bytes = await file.read()
61
  img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
 
67
  except Exception as e:
68
  return JSONResponse(content={"error": str(e)}, status_code=500)
69
 
70
+ # Redirect root to Gradio interface
71
  @app.get("/")
72
  def redirect_root_to_gradio():
73
  return RedirectResponse(url="/gradio")
74
 
75
+ # Gradio interface for testing
76
  def gradio_interface(image):
77
+ """Gradio function to handle the prediction from image"""
78
  predicted_class, predicted_name = predict(image)
79
  return f"{predicted_name} (Class {predicted_class})"
80
 
81
+ # Gradio app setup
82
  gradio_app = gr.Interface(
83
  fn=gradio_interface,
84
  inputs=gr.Image(type="pil"),
 
87
  description="Upload a skin image to classify the condition using a fine-tuned ConvNeXt model."
88
  )
89
 
90
+ # Mount Gradio app into FastAPI
91
  app = mount_gradio_app(app, gradio_app, path="/gradio")
92
+
93
+ # For running the app locally with uvicorn
94
+ if __name__ == "__main__":
95
+ import uvicorn
96
+ uvicorn.run(app, host="0.0.0.0", port=7860)
requirements.txt CHANGED
@@ -8,4 +8,5 @@ Pillow
8
  python-multipart
9
  gradio
10
  transformers
 
11
 
 
8
  python-multipart
9
  gradio
10
  transformers
11
+ starlette
12