joey1101 committed on
Commit
6597a2f
·
verified ·
1 Parent(s): 5e2d609

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -119
app.py CHANGED
@@ -1,7 +1,7 @@
1
  ##########################################
2
  # Step 0: Import required libraries
3
  ##########################################
4
- import streamlit as st # For web interface
5
  from transformers import (
6
  pipeline,
7
  SpeechT5Processor,
@@ -10,13 +10,13 @@ from transformers import (
10
  AutoModelForCausalLM,
11
  AutoTokenizer
12
  ) # AI model components
13
- from datasets import load_dataset # For voice embeddings
14
- import torch # Tensor computations
15
  import soundfile as sf # Audio file handling
16
- import re # Regular expressions for text processing
17
 
18
  ##########################################
19
- # Initial configuration
20
  ##########################################
21
  st.set_page_config(
22
  page_title="Just Comment",
@@ -26,47 +26,68 @@ st.set_page_config(
26
  )
27
 
28
  ##########################################
29
- # Global model loading with caching
30
  ##########################################
31
@st.cache_resource(show_spinner=False)
def _load_models():
    """Build every model component once and hand them back as a single dict.

    The result is wrapped by ``st.cache_resource``, so the components are
    constructed on the first call and reused afterwards.

    Returns:
        dict: keys ``'emotion'``, ``'textgen_tokenizer'``, ``'textgen_model'``,
        ``'tts_processor'``, ``'tts_model'``, ``'tts_vocoder'`` and
        ``'speaker_embeddings'``.
    """
    textgen_checkpoint = "Qwen/Qwen1.5-0.5B"
    tts_checkpoint = "microsoft/speecht5_tts"

    # Emotion classifier; truncation keeps long comments within model limits.
    classifier = pipeline(
        "text-classification",
        model="Thea231/jhartmann_emotion_finetuning",
        truncation=True,
    )

    # Causal language model (half precision) and its fast tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(textgen_checkpoint, use_fast=True)
    generator = AutoModelForCausalLM.from_pretrained(
        textgen_checkpoint,
        torch_dtype=torch.float16,
    )

    # Text-to-speech stack: processor, acoustic model and vocoder.
    processor = SpeechT5Processor.from_pretrained(tts_checkpoint)
    synthesizer = SpeechT5ForTextToSpeech.from_pretrained(tts_checkpoint)
    vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

    # One fixed x-vector from the validation split, with a batch axis added.
    xvectors = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    voice = torch.tensor(xvectors[7306]["xvector"]).unsqueeze(0)

    bundle = {}
    bundle['emotion'] = classifier
    bundle['textgen_tokenizer'] = tokenizer
    bundle['textgen_model'] = generator
    bundle['tts_processor'] = processor
    bundle['tts_model'] = synthesizer
    bundle['tts_vocoder'] = vocoder
    bundle['speaker_embeddings'] = voice
    return bundle
62
 
63
  ##########################################
64
  # UI Components
65
  ##########################################
66
  def _display_interface():
67
- """Render user interface elements"""
68
- st.title("🚀 Just Comment")
69
- st.markdown("### I'm listening to you, my friend~")
70
 
71
  return st.text_area(
72
  "📝 Enter your comment:",
@@ -79,132 +100,134 @@ def _display_interface():
79
  # Core Processing Functions
80
  ##########################################
81
  def _analyze_emotion(text, classifier):
82
- """Identify dominant emotion with confidence threshold"""
83
- results = classifier(text, return_all_scores=True)[0]
 
84
  valid_emotions = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'}
85
- filtered = [e for e in results if e['label'].lower() in valid_emotions]
86
- return max(filtered, key=lambda x: x['score'])
 
 
 
 
 
 
 
 
87
 
88
  def _generate_prompt(text, emotion):
89
- """Create structured prompts for all emotion types"""
90
  prompt_templates = {
91
- "sadness": (
92
- "Sadness detected: {input}\n"
93
- "Required response structure:\n"
94
- "1. Empathetic acknowledgment\n2. Support offer\n3. Solution proposal\n"
95
- "Response:"
96
- ),
97
- "joy": (
98
- "Joy detected: {input}\n"
99
- "Required response structure:\n"
100
- "1. Enthusiastic thanks\n2. Positive reinforcement\n3. Future engagement\n"
101
- "Response:"
102
- ),
103
- "love": (
104
- "Affection detected: {input}\n"
105
- "Required response structure:\n"
106
- "1. Warm appreciation\n2. Community focus\n3. Exclusive benefit\n"
107
- "Response:"
108
- ),
109
- "anger": (
110
- "Anger detected: {input}\n"
111
- "Required response structure:\n"
112
- "1. Sincere apology\n2. Action steps\n3. Compensation\n"
113
- "Response:"
114
- ),
115
- "fear": (
116
- "Concern detected: {input}\n"
117
- "Required response structure:\n"
118
- "1. Reassurance\n2. Safety measures\n3. Support options\n"
119
- "Response:"
120
- ),
121
- "surprise": (
122
- "Surprise detected: {input}\n"
123
- "Required response structure:\n"
124
- "1. Acknowledge uniqueness\n2. Creative solution\n3. Follow-up\n"
125
- "Response:"
126
- )
127
  }
128
- return prompt_templates.get(emotion.lower(), "").format(input=text)
129
 
130
  def _process_response(raw_text):
131
- """Clean and format generated response"""
132
- # Extract text after last "Response:" marker
133
- processed = raw_text.split("Response:")[-1].strip()
134
 
135
- # Remove incomplete sentences
136
- if '.' in processed:
137
- processed = processed.rsplit('.', 1)[0] + '.'
138
 
139
- # Ensure length between 50-200 characters
140
- return processed[:200].strip() if len(processed) > 50 else "Thank you for your feedback. We value your input and will respond shortly."
141
 
142
def _generate_text_response(input_text, models):
    """Produce the written reply for a user comment.

    Runs emotion detection, builds the matching prompt, samples a
    continuation from the text-generation model on CPU and post-processes it.

    Args:
        input_text: The user's comment.
        models: Dict providing ``'emotion'``, ``'textgen_tokenizer'`` and
            ``'textgen_model'``.

    Returns:
        The cleaned reply text returned by ``_process_response``.
    """
    # Step 1: which emotion dominates the comment?
    detected = _analyze_emotion(input_text, models['emotion'])

    # Step 2: turn comment and emotion label into a prompt.
    prompt_text = _generate_prompt(input_text, detected['label'])

    # Step 3: sample up to 100 new tokens.
    tokenizer = models['textgen_tokenizer']
    encoded = tokenizer(prompt_text, return_tensors="pt").to('cpu')
    generated_ids = models['textgen_model'].generate(
        encoded.input_ids,
        max_new_tokens=100,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=models['textgen_tokenizer'].eos_token_id,
    )

    # Step 4: decode the first sequence and tidy it up.
    decoded = models['textgen_tokenizer'].decode(
        generated_ids[0],
        skip_special_tokens=True,
    )
    return _process_response(decoded)
 
 
 
164
 
165
def _generate_audio_response(text, models):
    """Synthesize *text* as speech and store it in ``response.wav``.

    Args:
        text: Reply text to speak.
        models: Dict providing ``'tts_processor'``, ``'tts_model'``,
            ``'tts_vocoder'`` and ``'speaker_embeddings'``.

    Returns:
        The path of the written file, ``"response.wav"``.
    """
    output_path = "response.wav"

    # Text -> token ids.
    encoded = models['tts_processor'](text=text, return_tensors="pt")

    # Token ids + speaker embedding -> spectrogram.
    mel = models['tts_model'].generate_speech(
        encoded["input_ids"],
        models['speaker_embeddings'],
    )

    # Spectrogram -> waveform; no gradients are needed for inference.
    with torch.no_grad():
        audio = models['tts_vocoder'](mel)

    # Write the waveform at a 16 kHz sample rate.
    sf.write(output_path, audio.numpy(), samplerate=16000)
    return output_path
183
 
184
  ##########################################
185
  # Main Application Flow
186
  ##########################################
187
def main():
    """Drive the page: load models, read the comment, show text and audio."""
    # Cached after the first run.
    models = _load_models()

    # Render the input widgets and read the current comment.
    comment = _display_interface()
    if not comment:
        # Nothing entered yet, so there is nothing to generate.
        return

    # Stage 1: written reply.
    with st.spinner("🔍 Analyzing emotions and generating response..."):
        reply = _generate_text_response(comment, models)

    st.subheader("📄 Generated Response")
    st.markdown(f"```\n{reply}\n```")

    # Stage 2: spoken reply.
    with st.spinner("🔊 Converting to speech..."):
        wav_path = _generate_audio_response(reply, models)
        st.audio(wav_path, format="audio/wav")
 
 
208
 
209
  if __name__ == "__main__":
210
  main()
 
1
  ##########################################
2
  # Step 0: Import required libraries
3
  ##########################################
4
+ import streamlit as st # Web interface framework
5
  from transformers import (
6
  pipeline,
7
  SpeechT5Processor,
 
10
  AutoModelForCausalLM,
11
  AutoTokenizer
12
  ) # AI model components
13
+ from datasets import load_dataset # Voice embeddings
14
+ import torch # Tensor computation
15
  import soundfile as sf # Audio file handling
16
+ import time # Execution timing
17
 
18
  ##########################################
19
+ # Initial configuration (MUST be first)
20
  ##########################################
21
  st.set_page_config(
22
  page_title="Just Comment",
 
26
  )
27
 
28
  ##########################################
29
+ # Optimized model loading with caching
30
  ##########################################
31
@st.cache_resource(show_spinner=False)
def _load_models():
    """Load all model components once and return them with the target device.

    Half precision is used only when CUDA is available; on CPU everything is
    loaded in float32, because float16 inference on CPU is poorly supported
    and can fail or run slowly depending on the PyTorch build. The speaker
    embedding is cast to the same dtype as the TTS model so the two can be
    combined without a dtype mismatch.

    Returns:
        dict: keys ``'emotion'``, ``'textgen_tokenizer'``, ``'textgen_model'``,
        ``'tts_processor'``, ``'tts_model'``, ``'tts_vocoder'``,
        ``'speaker_embeddings'`` and ``'device'`` (``"cuda"`` or ``"cpu"``).
    """
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    dtype = torch.float16 if use_cuda else torch.float32

    # Emotion classifier; truncation/padding are forwarded to the tokenizer.
    emotion_pipe = pipeline(
        "text-classification",
        model="Thea231/jhartmann_emotion_finetuning",
        device=device,
        truncation=True,
        padding=True,
    )

    # Text generation model. No quantization is applied here; only the
    # floating-point precision depends on the device.
    textgen_tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen1.5-0.5B",
        use_fast=True,
    )
    textgen_model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen1.5-0.5B",
        torch_dtype=dtype,
        device_map="auto",
    )

    # Text-to-speech components, all in the same dtype and on the same device.
    tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
    tts_model = SpeechT5ForTextToSpeech.from_pretrained(
        "microsoft/speecht5_tts",
        torch_dtype=dtype,
    ).to(device)
    tts_vocoder = SpeechT5HifiGan.from_pretrained(
        "microsoft/speecht5_hifigan",
        torch_dtype=dtype,
    ).to(device)

    # One fixed x-vector with a batch axis; torch.tensor yields float32, so
    # cast explicitly to match the TTS model.
    speaker_embeddings = torch.tensor(
        load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
    ).unsqueeze(0).to(device=device, dtype=dtype)

    return {
        'emotion': emotion_pipe,
        'textgen_tokenizer': textgen_tokenizer,
        'textgen_model': textgen_model,
        'tts_processor': tts_processor,
        'tts_model': tts_model,
        'tts_vocoder': tts_vocoder,
        'speaker_embeddings': speaker_embeddings,
        'device': device,
    }
83
 
84
  ##########################################
85
  # UI Components
86
  ##########################################
87
  def _display_interface():
88
+ """Render optimized user interface"""
89
+ st.title("Just Comment")
90
+ st.markdown(f"### I'm listening to you, my friend~") # f-string usage
91
 
92
  return st.text_area(
93
  "📝 Enter your comment:",
 
100
  # Core Processing Functions
101
  ##########################################
102
def _analyze_emotion(text, classifier):
    """Classify *text* and report the dominant recognised emotion.

    Only the first 512 characters are sent to the classifier. The elapsed
    time is written to the page.

    Args:
        text: Comment to classify.
        classifier: Callable invoked as ``classifier(snippet, return_all_scores=True)``;
            the first element of its result is iterated as dicts carrying
            ``'label'`` and ``'score'``.

    Returns:
        The highest-scoring dict among the six recognised emotions (first
        one on ties), or ``{'label': 'neutral', 'score': 1.0}`` when none
        of the labels is recognised.
    """
    began = time.time()
    recognised = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'}
    scored_labels = classifier(text[:512], return_all_scores=True)[0]

    # Linear scan keeping the best recognised candidate seen so far.
    best = None
    for candidate in scored_labels:
        if candidate['label'].lower() not in recognised:
            continue
        if best is None or candidate['score'] > best['score']:
            best = candidate

    if best is None:
        best = {'label': 'neutral', 'score': 1.0}

    elapsed = time.time() - began
    st.write(f"⏱️ Emotion analysis time: {elapsed:.2f}s")
    return best
117
 
118
  def _generate_prompt(text, emotion):
119
+ """Optimized prompt templates for all emotions"""
120
  prompt_templates = {
121
+ "sadness": f"Sadness detected: {{input}}\nRespond with: 1. Empathy 2. Support 3. Solution\nResponse:",
122
+ "joy": f"Joy detected: {{input}}\nRespond with: 1. Thanks 2. Appreciation 3. Engagement\nResponse:",
123
+ "love": f"Love detected: {{input}}\nRespond with: 1. Warmth 2. Community 3. Exclusive Offer\nResponse:",
124
+ "anger": f"Anger detected: {{input}}\nRespond with: 1. Apology 2. Action 3. Compensation\nResponse:",
125
+ "fear": f"Fear detected: {{input}}\nRespond with: 1. Reassurance 2. Safety 3. Support\nResponse:",
126
+ "surprise": f"Surprise detected: {{input}}\nRespond with: 1. Acknowledgement 2. Solution 3. Follow-up\nResponse:",
127
+ "neutral": f"Feedback: {{input}}\nRespond professionally:\n1. Acknowledgement\n2. Assistance\n3. Next Steps\nResponse:"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  }
129
+ return prompt_templates[emotion.lower()].format(input=text[:300]) # Limit input length
130
 
131
  def _process_response(raw_text):
132
+ """Fast response processing with validation"""
133
+ # Extract response after last marker
134
+ response = raw_text.split("Response:")[-1].strip()
135
 
136
+ # Ensure complete sentences
137
+ if '.' in response:
138
+ response = response.rsplit('.', 1)[0] + '.'
139
 
140
+ # Length control
141
+ return response[:200] if len(response) > 50 else "Thank you for your feedback. We'll respond shortly."
142
 
143
def _generate_text(user_input, models):
    """Generate the written reply for a user comment and report the timing.

    Args:
        user_input: The user's comment.
        models: Dict providing ``'emotion'``, ``'textgen_tokenizer'``,
            ``'textgen_model'`` and ``'device'``.

    Returns:
        The cleaned reply text returned by ``_process_response``.
    """
    start_time = time.time()

    # Emotion analysis
    emotion = _analyze_emotion(user_input, models['emotion'])

    # Generate prompt
    prompt = _generate_prompt(user_input, emotion['label'])

    # Tokenize (prompt capped at 128 tokens) and move to the model's device
    tokenizer = models['textgen_tokenizer']
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        max_length=128,
        truncation=True,
    ).to(models['device'])

    outputs = models['textgen_model'].generate(
        inputs.input_ids,
        # Pass the mask explicitly: pad and EOS share an id here, so it
        # cannot be inferred reliably from the ids alone.
        attention_mask=inputs.attention_mask,
        max_new_tokens=80,  # Strict limit for speed
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens. The returned sequence starts
    # with the prompt; if truncation removed the trailing "Response:"
    # marker, decoding the whole sequence would leak the prompt into the reply.
    prompt_length = inputs.input_ids.shape[1]
    generated = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    )

    st.write(f"⏱️ Text generation time: {time.time()-start_time:.2f}s")
    return _process_response(generated)
178
 
179
def _generate_speech(text, models):
    """Synthesize *text* as speech, save it to ``response.wav``, report timing.

    Args:
        text: Reply text; only its first 150 characters are spoken.
        models: Dict providing ``'tts_processor'``, ``'tts_model'``,
            ``'tts_vocoder'``, ``'speaker_embeddings'`` and ``'device'``.

    Returns:
        The path of the written file, ``"response.wav"``.
    """
    start_time = time.time()
    tts_model = models['tts_model']

    # Process text
    inputs = models['tts_processor'](
        text=text[:150],  # Limit text length
        return_tensors="pt",
    ).to(models['device'])

    # Match the embedding to the model's precision so they can be combined
    # whether the model was loaded in float16 or float32.
    speaker = models['speaker_embeddings'].to(dtype=tts_model.dtype)

    # Generate audio
    with torch.inference_mode():
        spectrogram = tts_model.generate_speech(
            inputs["input_ids"],
            speaker,
        )
        waveform = models['tts_vocoder'](spectrogram)

    # soundfile cannot write float16 arrays, so upcast before handing the
    # data over (a no-op when the model already runs in float32).
    audio = waveform.float().cpu().numpy()
    sf.write("response.wav", audio, 16000)

    st.write(f"⏱️ Speech synthesis time: {time.time()-start_time:.2f}s")
    return "response.wav"
202
 
203
  ##########################################
204
  # Main Application Flow
205
  ##########################################
206
def main():
    """Drive the page: load models, read the comment, show text, audio and timing."""
    # Cached after the first run.
    models = _load_models()

    # Render the input widgets and read the current comment.
    comment = _display_interface()
    if not comment:
        # Nothing entered yet, so there is nothing to generate.
        return

    began = time.time()

    # Stage 1: written reply.
    with st.spinner("🚀 Analyzing & generating response..."):
        reply = _generate_text(comment, models)

    st.subheader("📄 Generated Response")
    st.markdown(f"```\n{reply}\n```")

    # Stage 2: spoken reply.
    with st.spinner("🔊 Converting to speech..."):
        wav_path = _generate_speech(reply, models)
        st.audio(wav_path, format="audio/wav")

    elapsed = time.time() - began
    st.write(f"⏱️ Total execution time: {elapsed:.2f}s")
231
 
232
  if __name__ == "__main__":
233
  main()