Rivalcoder committed on
Commit
2981dff
·
1 Parent(s): dd2fa11
Files changed (2) hide show
  1. app.py +72 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ from transformers import pipeline
3
+ import gradio as gr
4
+ from fastapi import FastAPI, UploadFile, File, Request
5
+ import os
6
+ from typing import Optional
7
+
8
# Load the audio-classification pipeline once, at import time
classifier = pipeline(
    task="audio-classification",
    model="superb/hubert-large-superb-er",
)

# FastAPI application object; the Gradio UI is mounted onto it later in this file
app = FastAPI()
13
+
14
def save_upload_file(upload_file: "UploadFile") -> str:
    """Copy an uploaded file to a temporary file and return its path.

    The temporary file keeps the extension of the upload's original filename
    (no extension when the upload has no filename). It is created with
    ``delete=False``, so the caller is responsible for removing it.

    Args:
        upload_file: A FastAPI ``UploadFile``, or any object exposing
            ``filename`` (``str`` or ``None``) and ``file`` (a readable
            binary file object).

    Returns:
        Path of the newly created temporary file.
    """
    chunk_size = 1024 * 1024
    try:
        # ``filename`` can be None; fall back to an empty name so splitext
        # yields an empty suffix instead of raising TypeError.
        suffix = os.path.splitext(upload_file.filename or "")[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Copy in fixed-size chunks so a large upload is never held in
            # memory all at once.
            for chunk in iter(lambda: upload_file.file.read(chunk_size), b""):
                tmp.write(chunk)
            return tmp.name
    finally:
        # Always release the upload's underlying file, even on failure.
        upload_file.file.close()
23
+
24
@app.post("/api/predict")
async def predict_from_upload(file: UploadFile = File(...)):
    """Classify an audio file uploaded as multipart form data.

    Args:
        file: The uploaded audio file.

    Returns:
        ``{"predictions": <classifier output>}`` on success. If saving or
        classifying the upload raises, a JSON response with body
        ``{"error": "<message>"}`` and HTTP status 500.
    """
    # Local import keeps this handler independent of the module's
    # top-level import list.
    from fastapi.responses import JSONResponse

    temp_path = None
    try:
        # Save the uploaded file temporarily
        temp_path = save_upload_file(file)

        # NOTE(review): the classifier runs synchronously inside this
        # coroutine; consider offloading it to a thread pool if concurrent
        # request latency matters.
        predictions = classifier(temp_path)

        return {"predictions": predictions}
    except Exception as e:
        # FastAPI does not interpret a ``(body, status)`` tuple as a status
        # override (the tuple would be serialised as a JSON array with
        # status 200), so the 500 is set on an explicit response object.
        return JSONResponse(status_code=500, content={"error": str(e)})
    finally:
        # Remove the temp file whether or not classification succeeded.
        if temp_path is not None:
            os.unlink(temp_path)
40
+
41
+ # Gradio interface for testing
42
def gradio_predict(audio_file):
    """Classify an audio clip for the Gradio interface.

    Args:
        audio_file: A filesystem path (``str``), an upload-like object
            accepted by ``save_upload_file``, or ``None`` when no audio was
            submitted.

    Returns:
        Dict mapping each predicted label to its score; an empty dict when
        ``audio_file`` is ``None``.
    """
    # Nothing was submitted: return an empty result instead of crashing.
    if audio_file is None:
        return {}

    temp_path = None
    if isinstance(audio_file, str):  # Path from Gradio upload
        audio_path = audio_file
    else:  # Direct file object
        temp_path = save_upload_file(audio_file)
        audio_path = temp_path

    try:
        predictions = classifier(audio_path)
    finally:
        # Clean up only a temp file created here, even if classification fails.
        if temp_path is not None:
            os.unlink(temp_path)

    return {p["label"]: p["score"] for p in predictions}
56
+
57
# Gradio demo: a single audio input (passed to the handler as a file path)
# and a label output showing the top five classes
demo = gr.Interface(
    gradio_predict,
    gr.Audio(label="Upload Audio", type="filepath"),
    gr.Label(num_top_classes=5),
    description="Upload an audio file to analyze emotional content",
    title="Audio Emotion Recognition",
)

# Serve the Gradio UI from the root path of the FastAPI app
app = gr.mount_gradio_app(app, demo, path="/")

# Local entry point: run the combined app with uvicorn
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ transformers
3
+ torch
4
+ librosa
5
+ gradio
6
+ python-multipart
7
+ uvicorn