Rivalcoder committed
Commit 05ffad0 · 1 Parent(s): 3757f34
Files changed (1)
  1. app.py +5 -25
app.py CHANGED

@@ -14,7 +14,7 @@ model = RobertaForSequenceClassification.from_pretrained(model_name)
 emotion_analysis = pipeline("text-classification",
                             model=model,
                             tokenizer=tokenizer,
-                            return_all_scores=True)
+                            top_k=None)  # Replaced return_all_scores with top_k
 
 app = FastAPI()
 
@@ -25,7 +25,7 @@ def save_upload_file(upload_file: UploadFile) -> str:
         with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
             content = upload_file.file.read()
             if suffix == '.json':
-                content = content.decode('utf-8')  # Decode JSON files
+                content = content.decode('utf-8')
             tmp.write(content if isinstance(content, bytes) else content.encode())
             return tmp.name
     finally:
@@ -60,12 +60,10 @@ async def predict_from_upload(file: UploadFile = File(...)):
         os.unlink(temp_path)
         raise HTTPException(status_code=500, detail=str(e))
 
-# Modified gradio_predict to handle both input types correctly
 def gradio_predict(input_data, file_data=None):
     """Handle both direct text and file uploads"""
     try:
-        # Determine input source
-        if file_data is not None:  # File upload takes precedence
+        if file_data is not None:
             temp_path = save_upload_file(file_data)
             if temp_path.endswith('.json'):
                 with open(temp_path, 'r') as f:
@@ -75,7 +73,7 @@ def gradio_predict(input_data, file_data=None):
             with open(temp_path, 'r') as f:
                 text = f.read()
             os.unlink(temp_path)
-        else:  # Use direct text input
+        else:
             text = input_data
 
         if not text.strip():
@@ -92,7 +90,7 @@ def gradio_predict(input_data, file_data=None):
     except Exception as e:
         return {"error": str(e)}
 
-# Updated Gradio interface with proper input handling
+# Simplified Gradio interface without examples
 with gr.Blocks() as demo:
     gr.Markdown("# Text Emotion Analysis")
 
@@ -105,30 +103,12 @@ with gr.Blocks() as demo:
         with gr.Column():
             output = gr.JSON(label="Results")
 
-    # Handle both input methods
     submit_btn.click(
         fn=gradio_predict,
        inputs=[text_input, file_input],
         outputs=output,
         api_name="predict"
     )
-
-    # Examples with both input types
-    gr.Examples(
-        examples=[
-            ["I'm feeling excited about this new project!"],
-            ["This situation makes me anxious and worried"]
-        ],
-        inputs=text_input
-    )
-    gr.Examples(
-        examples=[
-            ["example1.json"],
-            ["example2.txt"]
-        ],
-        inputs=file_input,
-        label="File Examples"
-    )
 
 app = gr.mount_gradio_app(app, demo, path="/")
 
 
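For serving, `gr.mount_gradio_app(app, demo, path="/")` mounts the Gradio Blocks UI onto the existing FastAPI instance, so the JSON endpoints and the web UI are handled by a single ASGI application. A hedged sketch of how that combined app could be run is below; the module name and port are assumptions (7860 is only the usual Hugging Face Spaces default), not something stated in the commit.

```python
# Sketch: run the combined FastAPI + Gradio app with uvicorn.
# "app:app" assumes the file is named app.py and exposes the FastAPI instance as `app`;
# port 7860 is an assumption based on the usual Hugging Face Spaces default.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
```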