Rivalcoder commited on
Commit
3757f34
·
1 Parent(s): bda5a7d
Files changed (1) hide show
  1. app.py +46 -34
app.py CHANGED
@@ -16,7 +16,6 @@ emotion_analysis = pipeline("text-classification",
16
  tokenizer=tokenizer,
17
  return_all_scores=True)
18
 
19
- # Create FastAPI app
20
  app = FastAPI()
21
 
22
  def save_upload_file(upload_file: UploadFile) -> str:
@@ -36,47 +35,38 @@ def save_upload_file(upload_file: UploadFile) -> str:
36
  async def predict_from_upload(file: UploadFile = File(...)):
37
  """API endpoint for file uploads"""
38
  try:
39
- # Save the uploaded file temporarily
40
  temp_path = save_upload_file(file)
41
 
42
- # Process based on file type
43
  if temp_path.endswith('.json'):
44
  with open(temp_path, 'r') as f:
45
  data = json.load(f)
46
  text = data.get('description', '')
47
- else: # Assume text file
48
  with open(temp_path, 'r') as f:
49
  text = f.read()
50
 
51
  if not text.strip():
52
  raise HTTPException(status_code=400, detail="No text content found")
53
 
54
- # Analyze text
55
  result = emotion_analysis(text)
56
  emotions = [{'label': e['label'], 'score': float(e['score'])}
57
  for e in sorted(result[0], key=lambda x: x['score'], reverse=True)]
58
 
59
- # Clean up
60
  os.unlink(temp_path)
61
-
62
- return {
63
- "success": True,
64
- "results": emotions
65
- }
66
 
67
  except Exception as e:
68
  if 'temp_path' in locals() and os.path.exists(temp_path):
69
  os.unlink(temp_path)
70
  raise HTTPException(status_code=500, detail=str(e))
71
 
72
- # Gradio interface
73
- def gradio_predict(input_data):
74
  """Handle both direct text and file uploads"""
75
  try:
76
- if isinstance(input_data, str): # Direct text input
77
- text = input_data
78
- else: # File upload
79
- temp_path = save_upload_file(input_data)
80
  if temp_path.endswith('.json'):
81
  with open(temp_path, 'r') as f:
82
  data = json.load(f)
@@ -85,6 +75,8 @@ def gradio_predict(input_data):
85
  with open(temp_path, 'r') as f:
86
  text = f.read()
87
  os.unlink(temp_path)
 
 
88
 
89
  if not text.strip():
90
  return {"error": "No text content found"}
@@ -100,26 +92,46 @@ def gradio_predict(input_data):
100
  except Exception as e:
101
  return {"error": str(e)}
102
 
103
- # Create Gradio interface
104
- demo = gr.Interface(
105
- fn=gradio_predict,
106
- inputs=[
107
- gr.Textbox(label="Enter text directly", lines=5),
108
- gr.File(label="Or upload text/JSON file", file_types=[".txt", ".json"])
109
- ],
110
- outputs=gr.JSON(label="Emotion Analysis"),
111
- title="Text Emotion Analysis",
112
- description="Analyze emotion in text using RoBERTa model",
113
- examples=[
114
- ["I'm feeling absolutely thrilled about this new project!"],
115
- ["This situation is making me extremely anxious and worried."]
116
- ]
117
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- # Mount Gradio app
120
  app = gr.mount_gradio_app(app, demo, path="/")
121
 
122
- # For running locally
123
  if __name__ == "__main__":
124
  import uvicorn
125
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
16
  tokenizer=tokenizer,
17
  return_all_scores=True)
18
 
 
19
  app = FastAPI()
20
 
21
  def save_upload_file(upload_file: UploadFile) -> str:
 
35
  async def predict_from_upload(file: UploadFile = File(...)):
36
  """API endpoint for file uploads"""
37
  try:
 
38
  temp_path = save_upload_file(file)
39
 
 
40
  if temp_path.endswith('.json'):
41
  with open(temp_path, 'r') as f:
42
  data = json.load(f)
43
  text = data.get('description', '')
44
+ else:
45
  with open(temp_path, 'r') as f:
46
  text = f.read()
47
 
48
  if not text.strip():
49
  raise HTTPException(status_code=400, detail="No text content found")
50
 
 
51
  result = emotion_analysis(text)
52
  emotions = [{'label': e['label'], 'score': float(e['score'])}
53
  for e in sorted(result[0], key=lambda x: x['score'], reverse=True)]
54
 
 
55
  os.unlink(temp_path)
56
+ return {"success": True, "results": emotions}
 
 
 
 
57
 
58
  except Exception as e:
59
  if 'temp_path' in locals() and os.path.exists(temp_path):
60
  os.unlink(temp_path)
61
  raise HTTPException(status_code=500, detail=str(e))
62
 
63
+ # Modified gradio_predict to handle both input types correctly
64
+ def gradio_predict(input_data, file_data=None):
65
  """Handle both direct text and file uploads"""
66
  try:
67
+ # Determine input source
68
+ if file_data is not None: # File upload takes precedence
69
+ temp_path = save_upload_file(file_data)
 
70
  if temp_path.endswith('.json'):
71
  with open(temp_path, 'r') as f:
72
  data = json.load(f)
 
75
  with open(temp_path, 'r') as f:
76
  text = f.read()
77
  os.unlink(temp_path)
78
+ else: # Use direct text input
79
+ text = input_data
80
 
81
  if not text.strip():
82
  return {"error": "No text content found"}
 
92
  except Exception as e:
93
  return {"error": str(e)}
94
 
95
+ # Updated Gradio interface with proper input handling
96
+ with gr.Blocks() as demo:
97
+ gr.Markdown("# Text Emotion Analysis")
98
+
99
+ with gr.Row():
100
+ with gr.Column():
101
+ text_input = gr.Textbox(label="Enter text directly", lines=5)
102
+ file_input = gr.File(label="Or upload file", file_types=[".txt", ".json"])
103
+ submit_btn = gr.Button("Analyze")
104
+
105
+ with gr.Column():
106
+ output = gr.JSON(label="Results")
107
+
108
+ # Handle both input methods
109
+ submit_btn.click(
110
+ fn=gradio_predict,
111
+ inputs=[text_input, file_input],
112
+ outputs=output,
113
+ api_name="predict"
114
+ )
115
+
116
+ # Examples with both input types
117
+ gr.Examples(
118
+ examples=[
119
+ ["I'm feeling excited about this new project!"],
120
+ ["This situation makes me anxious and worried"]
121
+ ],
122
+ inputs=text_input
123
+ )
124
+ gr.Examples(
125
+ examples=[
126
+ ["example1.json"],
127
+ ["example2.txt"]
128
+ ],
129
+ inputs=file_input,
130
+ label="File Examples"
131
+ )
132
 
 
133
  app = gr.mount_gradio_app(app, demo, path="/")
134
 
 
135
  if __name__ == "__main__":
136
  import uvicorn
137
  uvicorn.run(app, host="0.0.0.0", port=7860)