Spaces:
Sleeping
Sleeping
Rivalcoder
commited on
Commit
·
a85ce4c
1
Parent(s):
884137e
Edit
Browse files
app.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
import torch
|
2 |
-
import gradio as gr
|
3 |
from fastapi import FastAPI, File, UploadFile
|
4 |
from fastapi.responses import JSONResponse
|
5 |
from transformers import ConvNextForImageClassification, AutoImageProcessor
|
@@ -29,49 +28,28 @@ processor = AutoImageProcessor.from_pretrained("facebook/convnext-base-224")
|
|
29 |
# FastAPI app
|
30 |
app = FastAPI()
|
31 |
|
32 |
-
#
|
33 |
def predict(image: Image.Image):
|
34 |
-
# Preprocess the image
|
35 |
inputs = processor(images=image, return_tensors="pt")
|
36 |
-
|
37 |
-
# Perform inference
|
38 |
with torch.no_grad():
|
39 |
outputs = model(**inputs)
|
40 |
predicted_class = torch.argmax(outputs.logits, dim=1).item()
|
41 |
-
|
42 |
return predicted_class, class_names[predicted_class]
|
43 |
|
44 |
-
#
|
45 |
@app.post("/predict/")
|
46 |
async def predict_endpoint(file: UploadFile = File(...)):
|
47 |
try:
|
48 |
-
# Read and process the image
|
49 |
img_bytes = await file.read()
|
50 |
-
img = Image.open(io.BytesIO(img_bytes))
|
51 |
-
|
52 |
-
# Get the prediction
|
53 |
predicted_class, predicted_name = predict(img)
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
except Exception as e:
|
59 |
return JSONResponse(content={"error": str(e)}, status_code=500)
|
60 |
|
61 |
-
#
|
62 |
-
|
63 |
-
|
64 |
-
return f"Predicted Class: {predicted_name}"
|
65 |
-
|
66 |
-
# Gradio Interface
|
67 |
-
iface = gr.Interface(fn=gradio_predict, inputs=gr.Image(type="pil"), outputs=gr.Textbox())
|
68 |
-
|
69 |
-
# Serve Gradio interface on FastAPI
|
70 |
-
@app.get("/gradio/")
|
71 |
-
async def gradio_interface():
|
72 |
-
return iface.launch(share=True, inline=True)
|
73 |
-
|
74 |
-
# Run the FastAPI app using Uvicorn
|
75 |
-
if __name__ == "__main__":
|
76 |
-
import uvicorn
|
77 |
-
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
1 |
import torch
|
|
|
2 |
from fastapi import FastAPI, File, UploadFile
|
3 |
from fastapi.responses import JSONResponse
|
4 |
from transformers import ConvNextForImageClassification, AutoImageProcessor
|
|
|
28 |
# FastAPI app
|
29 |
app = FastAPI()
|
30 |
|
31 |
+
# Prediction helper
|
32 |
def predict(image: Image.Image):
|
|
|
33 |
inputs = processor(images=image, return_tensors="pt")
|
|
|
|
|
34 |
with torch.no_grad():
|
35 |
outputs = model(**inputs)
|
36 |
predicted_class = torch.argmax(outputs.logits, dim=1).item()
|
|
|
37 |
return predicted_class, class_names[predicted_class]
|
38 |
|
39 |
+
# Endpoint: /predict
|
40 |
@app.post("/predict/")
|
41 |
async def predict_endpoint(file: UploadFile = File(...)):
|
42 |
try:
|
|
|
43 |
img_bytes = await file.read()
|
44 |
+
img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
|
|
|
|
45 |
predicted_class, predicted_name = predict(img)
|
46 |
+
return JSONResponse(content={
|
47 |
+
"predicted_class": predicted_class,
|
48 |
+
"predicted_name": predicted_name
|
49 |
+
})
|
50 |
except Exception as e:
|
51 |
return JSONResponse(content={"error": str(e)}, status_code=500)
|
52 |
|
53 |
+
# Required for Hugging Face Spaces (do NOT run uvicorn manually)
|
54 |
+
# Just expose the app
|
55 |
+
app = app
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|