joey1101 commited on
Commit
0a4b920
Β·
verified Β·
1 Parent(s): 0e85ac7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -33
app.py CHANGED
@@ -16,71 +16,171 @@ import soundfile as sf # Audio file handling
16
  import sentencepiece # Tokenization dependency
17
 
18
  ##########################################
19
- # Set page config FIRST
20
  ##########################################
21
- st.set_page_config( # Must be the first Streamlit command
22
- page_title="πŸš€ Just Comment - I'm listening to you, my friend~",
23
  page_icon="πŸ’¬",
24
- layout="centered"
 
25
  )
26
 
27
  ##########################################
28
- # Initialize models and resources globally
29
  ##########################################
30
- @st.cache_resource # Cache resources to reduce reload time
31
  def load_models():
32
- """Load all required models once and cache them"""
33
  return {
34
- 'emotion_classifier': pipeline(
 
35
  "text-classification",
36
  model="Thea231/jhartmann_emotion_finetuning"
37
  ),
 
 
 
 
 
 
38
  'tts_processor': SpeechT5Processor.from_pretrained("microsoft/speecht5_tts"),
39
  'tts_model': SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts"),
40
  'tts_vocoder': SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan"),
41
- 'textgen_tokenizer': AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B"),
42
- 'textgen_model': AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B"),
43
  'speaker_embeddings': torch.tensor(
44
  load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
45
  ).unsqueeze(0)
46
  }
47
 
48
  ##########################################
49
- # Streamlit UI Configuration
50
  ##########################################
51
- def setup_ui():
52
- """Configure remaining UI elements"""
53
- st.title("πŸš€ Just Comment - Smart Response Generator")
54
- st.markdown("""
55
- <style>
56
- .reportview-container {background: #f8f9fa;}
57
- .stTextArea textarea {border: 2px solid #dee2e6;}
58
- </style>
59
- """, unsafe_allow_html=True)
60
- return st.text_area("πŸ“ Enter your customer comment:", "", height=150)
 
61
 
62
  ##########################################
63
- # (δΏζŒε…Άδ»–ε‡½ζ•°δΈε˜οΌŒδΈŽδΉ‹ε‰η›ΈεŒ)
64
- # Keep other functions unchanged as previous version
65
  ##########################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  ##########################################
68
- # Main Application Logic
69
  ##########################################
70
  def main():
71
- """Main execution flow"""
72
- models = load_models() # Load models once
73
- user_input = setup_ui()
 
 
74
 
 
75
  if user_input:
76
- with st.spinner("πŸ” Analyzing sentiment and generating response..."):
77
- response = generate_response(user_input, models)
78
-
79
- st.subheader("πŸ’‘ Generated Response:")
80
- st.markdown(f"```\n{response}\n```")
 
 
 
81
 
 
82
  with st.spinner("πŸ”Š Generating voice response..."):
83
- audio_file = generate_speech(response, models)
84
  st.audio(audio_file, format="audio/wav")
85
 
86
  if __name__ == "__main__":
 
16
  import sentencepiece # Tokenization dependency
17
 
18
  ##########################################
19
+ # Initial configuration (MUST be first)
20
  ##########################################
21
+ st.set_page_config(
22
+ page_title="πŸš€ Just Comment - AI Response Generator",
23
  page_icon="πŸ’¬",
24
+ layout="centered",
25
+ initial_sidebar_state="collapsed"
26
  )
27
 
28
  ##########################################
29
+ # Global model loading with caching
30
  ##########################################
31
+ @st.cache_resource(show_spinner=False)
32
  def load_models():
33
+ """Load and cache all ML models"""
34
  return {
35
+ # Emotion classifier
36
+ 'emotion': pipeline(
37
  "text-classification",
38
  model="Thea231/jhartmann_emotion_finetuning"
39
  ),
40
+
41
+ # Text generation models
42
+ 'textgen_tokenizer': AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B"),
43
+ 'textgen_model': AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-0.5B"),
44
+
45
+ # TTS components
46
  'tts_processor': SpeechT5Processor.from_pretrained("microsoft/speecht5_tts"),
47
  'tts_model': SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts"),
48
  'tts_vocoder': SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan"),
49
+
50
+ # Speaker embeddings
51
  'speaker_embeddings': torch.tensor(
52
  load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
53
  ).unsqueeze(0)
54
  }
55
 
56
  ##########################################
57
+ # UI Components
58
  ##########################################
59
+ def render_interface():
60
+ """Create user interface elements"""
61
+ st.title("πŸš€ AI Customer Response Generator")
62
+ st.caption("Analyzes feedback and generates tailored responses")
63
+
64
+ return st.text_area(
65
+ "πŸ“ Paste customer feedback here:",
66
+ placeholder="The product arrived damaged...",
67
+ height=150,
68
+ key="user_input"
69
+ )
70
 
71
  ##########################################
72
+ # Core Logic Components
 
73
  ##########################################
74
+ def analyze_emotion(text, classifier):
75
+ """Determine dominant emotion with confidence threshold"""
76
+ results = classifier(text, return_all_scores=True)[0]
77
+ top_emotion = max(results, key=lambda x: x['score'])
78
+ return top_emotion if top_emotion['score'] > 0.6 else {'label': 'neutral', 'score': 1.0}
79
+
80
+ def generate_prompt(text, emotion):
81
+ """Create structured prompts for different emotions"""
82
+ prompt_templates = {
83
+ "anger": (
84
+ "Customer complaint: {input}\n"
85
+ "Respond with:\n"
86
+ "1. Apology\n2. Solution steps\n3. Compensation offer\n"
87
+ "Response:"
88
+ ),
89
+ "joy": (
90
+ "Positive feedback: {input}\n"
91
+ "Respond with:\n"
92
+ "1. Appreciation\n2. Highlight strengths\n3. Loyalty benefits\n"
93
+ "Response:"
94
+ ),
95
+ "neutral": (
96
+ "Customer comment: {input}\n"
97
+ "Respond with:\n"
98
+ "1. Acknowledge feedback\n2. Offer assistance\n3. Next steps\n"
99
+ "Response:"
100
+ )
101
+ }
102
+ return prompt_templates.get(emotion.lower(), prompt_templates['neutral']).format(input=text)
103
+
104
+ def process_response(output_text):
105
+ """Ensure response quality and proper formatting"""
106
+ # Remove incomplete sentences
107
+ if '.' in output_text:
108
+ output_text = output_text.rsplit('.', 1)[0] + '.'
109
+
110
+ # Length constraints
111
+ output_text = output_text[:300].strip() # Hard limit at 300 characters
112
+
113
+ # Fallback for short responses
114
+ if len(output_text) < 50:
115
+ return "Thank you for your feedback. We'll review this and contact you shortly."
116
+
117
+ return output_text
118
+
119
+ def generate_text_response(user_input, models):
120
+ """Generate and validate text response"""
121
+ # Emotion analysis
122
+ emotion = analyze_emotion(user_input, models['emotion'])
123
+
124
+ # Prompt engineering
125
+ prompt = generate_prompt(user_input, emotion['label'])
126
+
127
+ # Text generation
128
+ inputs = models['textgen_tokenizer'](prompt, return_tensors="pt")
129
+ outputs = models['textgen_model'].generate(
130
+ inputs.input_ids,
131
+ max_new_tokens=200,
132
+ temperature=0.7,
133
+ do_sample=True,
134
+ top_p=0.9
135
+ )
136
+
137
+ # Decode and process
138
+ full_response = models['textgen_tokenizer'].decode(outputs[0], skip_special_tokens=True)
139
+ return process_response(full_response.split("Response:")[-1].strip())
140
+
141
+ def generate_audio_response(text, models):
142
+ """Convert text to speech"""
143
+ # Process text input
144
+ inputs = models['tts_processor'](text=text, return_tensors="pt")
145
+
146
+ # Generate spectrogram
147
+ spectrogram = models['tts_model'].generate_speech(
148
+ inputs["input_ids"],
149
+ models['speaker_embeddings']
150
+ )
151
+
152
+ # Generate waveform
153
+ with torch.no_grad():
154
+ waveform = models['tts_vocoder'](spectrogram)
155
+
156
+ # Save and return audio
157
+ sf.write("response.wav", waveform.numpy(), samplerate=16000)
158
+ return "response.wav"
159
 
160
  ##########################################
161
+ # Main Application Flow
162
  ##########################################
163
  def main():
164
+ # Load models once
165
+ ml_models = load_models()
166
+
167
+ # Render UI
168
+ user_input = render_interface()
169
 
170
+ # Process input
171
  if user_input:
172
+ # Text generation
173
+ with st.status("πŸ” Analyzing feedback...", expanded=True) as status:
174
+ text_response = generate_text_response(user_input, ml_models)
175
+ status.update(label="βœ… Analysis Complete", state="complete")
176
+
177
+ # Display text response
178
+ st.subheader("πŸ“ Generated Response")
179
+ st.markdown(f"```\n{text_response}\n```")
180
 
181
+ # Audio generation
182
  with st.spinner("πŸ”Š Generating voice response..."):
183
+ audio_file = generate_audio_response(text_response, ml_models)
184
  st.audio(audio_file, format="audio/wav")
185
 
186
  if __name__ == "__main__":