Rivalcoder commited on
Commit
a85ce4c
·
1 Parent(s): 884137e
Files changed (1) hide show
  1. app.py +10 -32
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import torch
2
- import gradio as gr
3
  from fastapi import FastAPI, File, UploadFile
4
  from fastapi.responses import JSONResponse
5
  from transformers import ConvNextForImageClassification, AutoImageProcessor
@@ -29,49 +28,28 @@ processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")
29
  # FastAPI app
30
  app = FastAPI()
31
 
32
- # Helper function for processing the image
33
  def predict(image: Image.Image):
34
- # Preprocess the image
35
  inputs = processor(images=image, return_tensors="pt")
36
-
37
- # Perform inference
38
  with torch.no_grad():
39
  outputs = model(**inputs)
40
  predicted_class = torch.argmax(outputs.logits, dim=1).item()
41
-
42
  return predicted_class, class_names[predicted_class]
43
 
44
- # FastAPI endpoint to handle image upload and prediction
45
  @app.post("/predict/")
46
  async def predict_endpoint(file: UploadFile = File(...)):
47
  try:
48
- # Read and process the image
49
  img_bytes = await file.read()
50
- img = Image.open(io.BytesIO(img_bytes))
51
-
52
- # Get the prediction
53
  predicted_class, predicted_name = predict(img)
54
-
55
- # Return the result as JSON
56
- return JSONResponse(content={"predicted_class": predicted_class, "predicted_name": predicted_name})
57
-
58
  except Exception as e:
59
  return JSONResponse(content={"error": str(e)}, status_code=500)
60
 
61
- # Gradio function to integrate with the FastAPI prediction
62
- def gradio_predict(image: Image.Image):
63
- predicted_class, predicted_name = predict(image)
64
- return f"Predicted Class: {predicted_name}"
65
-
66
- # Gradio Interface
67
- iface = gr.Interface(fn=gradio_predict, inputs=gr.Image(type="pil"), outputs=gr.Textbox())
68
-
69
- # Serve Gradio interface on FastAPI
70
- @app.get("/gradio/")
71
- async def gradio_interface():
72
- return iface.launch(share=True, inline=True)
73
-
74
- # Run the FastAPI app using Uvicorn
75
- if __name__ == "__main__":
76
- import uvicorn
77
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  import torch
 
2
  from fastapi import FastAPI, File, UploadFile
3
  from fastapi.responses import JSONResponse
4
  from transformers import ConvNextForImageClassification, AutoImageProcessor
 
28
  # FastAPI app
29
  app = FastAPI()
30
 
31
+ # Prediction helper
32
  def predict(image: Image.Image):
 
33
  inputs = processor(images=image, return_tensors="pt")
 
 
34
  with torch.no_grad():
35
  outputs = model(**inputs)
36
  predicted_class = torch.argmax(outputs.logits, dim=1).item()
 
37
  return predicted_class, class_names[predicted_class]
38
 
39
+ # Endpoint: /predict
40
  @app.post("/predict/")
41
  async def predict_endpoint(file: UploadFile = File(...)):
42
  try:
 
43
  img_bytes = await file.read()
44
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
 
 
45
  predicted_class, predicted_name = predict(img)
46
+ return JSONResponse(content={
47
+ "predicted_class": predicted_class,
48
+ "predicted_name": predicted_name
49
+ })
50
  except Exception as e:
51
  return JSONResponse(content={"error": str(e)}, status_code=500)
52
 
53
+ # Required for Hugging Face Spaces (do NOT run uvicorn manually)
54
+ # Just expose the app
55
+ app = app