joey1101 commited on
Commit
5e4841e
·
verified ·
1 Parent(s): 14e0981

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -138
app.py CHANGED
@@ -1,60 +1,53 @@
1
  ##########################################
2
- # Step 0: Import required libraries
3
  ##########################################
4
- import streamlit as st # Web interface framework
5
- from transformers import (
6
  pipeline,
7
  SpeechT5Processor,
8
  SpeechT5ForTextToSpeech,
9
  SpeechT5HifiGan,
10
  AutoModelForCausalLM,
11
  AutoTokenizer
12
- ) # AI model components
13
- from datasets import load_dataset # Voice embeddings
14
- import torch # Tensor computation
15
- import soundfile as sf # Audio file handling
16
- import time # Execution timing
17
 
18
  ##########################################
19
- # Initial configuration (MUST be first)
20
  ##########################################
21
- st.set_page_config(
22
  page_title="Just Comment",
23
  page_icon="💬",
24
- layout="centered",
25
- initial_sidebar_state="collapsed"
26
  )
27
 
28
  ##########################################
29
- # Optimized model loading with caching
30
  ##########################################
31
  @st.cache_resource(show_spinner=False)
32
- def _load_models():
33
- """Load and cache models with maximum optimization"""
34
- # Initialize device-agnostic model loading
35
  device = "cuda" if torch.cuda.is_available() else "cpu"
36
 
37
- # Load emotion classifier with optimized settings
38
  emotion_pipe = pipeline(
39
  "text-classification",
40
  model="Thea231/jhartmann_emotion_finetuning",
41
  device=device,
42
- truncation=True,
43
- padding=True
44
  )
45
 
46
- # Load text generation model with 4-bit quantization
47
- textgen_tokenizer = AutoTokenizer.from_pretrained(
48
- "Qwen/Qwen1.5-0.5B",
49
- use_fast=True
50
- )
51
- textgen_model = AutoModelForCausalLM.from_pretrained(
52
  "Qwen/Qwen1.5-0.5B",
53
  torch_dtype=torch.float16,
54
  device_map="auto"
55
  )
56
 
57
- # Load TTS components with hardware acceleration
58
  tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
59
  tts_model = SpeechT5ForTextToSpeech.from_pretrained(
60
  "microsoft/speecht5_tts",
@@ -65,169 +58,137 @@ def _load_models():
65
  torch_dtype=torch.float16
66
  ).to(device)
67
 
68
- # Preload speaker embeddings
69
- speaker_embeddings = torch.tensor(
70
  load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
71
  ).unsqueeze(0).to(device)
72
 
73
  return {
74
- 'emotion': emotion_pipe,
75
- 'textgen_tokenizer': textgen_tokenizer,
76
- 'textgen_model': textgen_model,
77
- 'tts_processor': tts_processor,
78
- 'tts_model': tts_model,
79
- 'tts_vocoder': tts_vocoder,
80
- 'speaker_embeddings': speaker_embeddings,
81
- 'device': device
82
  }
83
 
84
  ##########################################
85
- # UI Components
86
  ##########################################
87
- def _display_interface():
88
- """Render optimized user interface"""
89
  st.title("Just Comment")
90
- st.markdown(f"### I'm listening to you, my friend~") # f-string usage
91
-
92
- return st.text_area(
93
  "📝 Enter your comment:",
94
- placeholder="Type your message here...",
95
  height=150,
96
- key="user_input"
97
  )
98
 
99
  ##########################################
100
- # Core Processing Functions
101
  ##########################################
102
- def _analyze_emotion(text, classifier):
103
- """Fast emotion analysis with early stopping"""
104
- start_time = time.time()
105
- results = classifier(text[:512], return_all_scores=True)[0] # Limit input length
106
- valid_emotions = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'}
107
-
108
- # Find dominant emotion
109
- dominant = max(
110
- (e for e in results if e['label'].lower() in valid_emotions),
111
  key=lambda x: x['score'],
112
- default={'label': 'neutral', 'score': 1.0}
113
  )
114
-
115
- st.write(f"⏱️ Emotion analysis time: {time.time()-start_time:.2f}s")
116
- return dominant
117
 
118
- def _generate_prompt(text, emotion):
119
- """Optimized prompt templates for all emotions"""
120
- prompt_templates = {
121
- "sadness": f"Sadness detected: {{input}}\nRespond with: 1. Empathy 2. Support 3. Solution\nResponse:",
122
- "joy": f"Joy detected: {{input}}\nRespond with: 1. Thanks 2. Appreciation 3. Engagement\nResponse:",
123
- "love": f"Love detected: {{input}}\nRespond with: 1. Warmth 2. Community 3. Exclusive Offer\nResponse:",
124
- "anger": f"Anger detected: {{input}}\nRespond with: 1. Apology 2. Action 3. Compensation\nResponse:",
125
- "fear": f"Fear detected: {{input}}\nRespond with: 1. Reassurance 2. Safety 3. Support\nResponse:",
126
- "surprise": f"Surprise detected: {{input}}\nRespond with: 1. Acknowledgement 2. Solution 3. Follow-up\nResponse:",
127
- "neutral": f"Feedback: {{input}}\nRespond professionally:\n1. Acknowledgement\n2. Assistance\n3. Next Steps\nResponse:"
128
  }
129
- return prompt_templates[emotion.lower()].format(input=text[:300]) # Limit input length
130
-
131
- def _process_response(raw_text):
132
- """Fast response processing with validation"""
133
- # Extract response after last marker
134
- response = raw_text.split("Response:")[-1].strip()
135
-
136
- # Ensure complete sentences
137
- if '.' in response:
138
- response = response.rsplit('.', 1)[0] + '.'
139
-
140
- # Length control
141
- return response[:200] if len(response) > 50 else "Thank you for your feedback. We'll respond shortly."
142
 
143
- def _generate_text(user_input, models):
144
- """Ultra-fast text generation pipeline"""
145
- start_time = time.time()
 
146
 
147
- # Emotion analysis
148
- emotion = _analyze_emotion(user_input, models['emotion'])
149
 
150
- # Generate prompt
151
- prompt = _generate_prompt(user_input, emotion['label'])
152
-
153
- # Tokenize and generate
154
- inputs = models['textgen_tokenizer'](
155
  prompt,
156
  return_tensors="pt",
157
- max_length=128,
158
  truncation=True
159
- ).to(models['device'])
160
 
161
- outputs = models['textgen_model'].generate(
162
  inputs.input_ids,
163
- max_new_tokens=80, # Strict limit for speed
164
  temperature=0.7,
165
  top_p=0.9,
166
  do_sample=True,
167
- pad_token_id=models['textgen_tokenizer'].eos_token_id
168
  )
169
 
170
- # Decode and process
171
- generated = models['textgen_tokenizer'].decode(
172
- outputs[0],
173
- skip_special_tokens=True
174
- )
175
 
176
- st.write(f"⏱️ Text generation time: {time.time()-start_time:.2f}s")
177
- return _process_response(generated)
 
 
178
 
179
- def _generate_speech(text, models):
180
- """Hardware-accelerated speech synthesis"""
181
- start_time = time.time()
182
-
183
- # Process text
184
- inputs = models['tts_processor'](
185
  text=text[:150], # Limit text length
186
  return_tensors="pt"
187
- ).to(models['device'])
188
 
189
- # Generate audio
190
- with torch.inference_mode():
191
- spectrogram = models['tts_model'].generate_speech(
192
  inputs["input_ids"],
193
- models['speaker_embeddings']
194
  )
195
- waveform = models['tts_vocoder'](spectrogram)
196
-
197
- # Save optimized audio file
198
- sf.write("response.wav", waveform.cpu().numpy(), 16000)
199
 
200
- st.write(f"⏱️ Speech synthesis time: {time.time()-start_time:.2f}s")
201
- return "response.wav"
202
 
203
  ##########################################
204
- # Main Application Flow
205
  ##########################################
206
  def main():
207
- """Optimized execution flow"""
208
- # Load models first
209
- ml_models = _load_models()
210
 
211
- # Display interface
212
- user_input = _display_interface()
213
 
214
  if user_input:
215
- total_start = time.time()
216
-
217
  # Text generation
218
- with st.spinner("🚀 Analyzing & generating response..."):
219
- text_response = _generate_text(user_input, ml_models)
220
 
221
- # Display results
222
- st.subheader(f"📄 Generated Response")
223
- st.markdown(f"```\n{text_response}\n```")
224
 
225
  # Audio generation
226
- with st.spinner("🔊 Converting to speech..."):
227
- audio_file = _generate_speech(text_response, ml_models)
228
- st.audio(audio_file, format="audio/wav")
229
-
230
- st.write(f"⏱️ Total execution time: {time.time()-total_start:.2f}s")
231
 
232
  if __name__ == "__main__":
233
  main()
 
1
  ##########################################
2
+ # Step 0: Essential imports
3
  ##########################################
4
+ import streamlit as st # Web interface
5
+ from transformers import ( # AI components
6
  pipeline,
7
  SpeechT5Processor,
8
  SpeechT5ForTextToSpeech,
9
  SpeechT5HifiGan,
10
  AutoModelForCausalLM,
11
  AutoTokenizer
12
+ )
13
+ from datasets import load_dataset # Voice data
14
+ import torch # Tensor operations
15
+ import soundfile as sf # Audio processing
 
16
 
17
  ##########################################
18
+ # Initial configuration (MUST BE FIRST)
19
  ##########################################
20
+ st.set_page_config( # Set page config first
21
  page_title="Just Comment",
22
  page_icon="💬",
23
+ layout="centered"
 
24
  )
25
 
26
  ##########################################
27
+ # Optimized model loader with caching
28
  ##########################################
29
  @st.cache_resource(show_spinner=False)
30
+ def _load_components():
31
+ """Load and cache all models with hardware optimization"""
 
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
 
34
+ # Emotion classifier (fast)
35
  emotion_pipe = pipeline(
36
  "text-classification",
37
  model="Thea231/jhartmann_emotion_finetuning",
38
  device=device,
39
+ truncation=True
 
40
  )
41
 
42
+ # Text generator (optimized)
43
+ text_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B")
44
+ text_model = AutoModelForCausalLM.from_pretrained(
 
 
 
45
  "Qwen/Qwen1.5-0.5B",
46
  torch_dtype=torch.float16,
47
  device_map="auto"
48
  )
49
 
50
+ # TTS system (accelerated)
51
  tts_processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
52
  tts_model = SpeechT5ForTextToSpeech.from_pretrained(
53
  "microsoft/speecht5_tts",
 
58
  torch_dtype=torch.float16
59
  ).to(device)
60
 
61
+ # Preloaded voice profile
62
+ speaker_emb = torch.tensor(
63
  load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")[7306]["xvector"]
64
  ).unsqueeze(0).to(device)
65
 
66
  return {
67
+ "emotion": emotion_pipe,
68
+ "text_model": text_model,
69
+ "text_tokenizer": text_tokenizer,
70
+ "tts_processor": tts_processor,
71
+ "tts_model": tts_model,
72
+ "tts_vocoder": tts_vocoder,
73
+ "speaker_emb": speaker_emb,
74
+ "device": device
75
  }
76
 
77
  ##########################################
78
+ # User interface components
79
  ##########################################
80
+ def _show_interface():
81
+ """Render input interface"""
82
  st.title("Just Comment")
83
+ st.markdown(f"### I'm listening to you, my friend~")
84
+ return st.text_area( # Input field
 
85
  "📝 Enter your comment:",
86
+ placeholder="Share your thoughts...",
87
  height=150,
88
+ key="input"
89
  )
90
 
91
  ##########################################
92
+ # Core processing functions
93
  ##########################################
94
+ def _fast_emotion(text, analyzer):
95
+ """Rapid emotion detection with input limits"""
96
+ result = analyzer(text[:256], return_all_scores=True)[0] # Limit input length
97
+ emotions = ['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
98
+ return max(
99
+ (e for e in result if e['label'].lower() in emotions),
 
 
 
100
  key=lambda x: x['score'],
101
+ default={'label': 'neutral', 'score': 0}
102
  )
 
 
 
103
 
104
+ def _build_prompt(text, emotion):
105
+ """Template-based prompt engineering"""
106
+ templates = {
107
+ "sadness": f"Sadness detected: {{text}}\nRespond with: 1. Empathy 2. Support 3. Solution\nResponse:",
108
+ "joy": f"Joy detected: {{text}}\nRespond with: 1. Thanks 2. Praise 3. Engagement\nResponse:",
109
+ "love": f"Love detected: {{text}}\nRespond with: 1. Appreciation 2. Connection 3. Offer\nResponse:",
110
+ "anger": f"Anger detected: {{text}}\nRespond with: 1. Apology 2. Action 3. Compensation\nResponse:",
111
+ "fear": f"Fear detected: {{text}}\nRespond with: 1. Reassurance 2. Safety 3. Support\nResponse:",
112
+ "surprise": f"Surprise detected: {{text}}\nRespond with: 1. Acknowledgement 2. Solution 3. Follow-up\nResponse:",
113
+ "neutral": f"Feedback: {{text}}\nProfessional response:\n1. Acknowledgement\n2. Assistance\n3. Next steps\nResponse:"
114
  }
115
+ return templates[emotion.lower()].format(text=text[:200]) # Input truncation
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
+ def _generate_response(text, models):
118
+ """Optimized text generation pipeline"""
119
+ # Emotion detection
120
+ emotion = _fast_emotion(text, models["emotion"])
121
 
122
+ # Prompt construction
123
+ prompt = _build_prompt(text, emotion["label"])
124
 
125
+ # Generate text
126
+ inputs = models["text_tokenizer"](
 
 
 
127
  prompt,
128
  return_tensors="pt",
129
+ max_length=100,
130
  truncation=True
131
+ ).to(models["device"])
132
 
133
+ output = models["text_model"].generate(
134
  inputs.input_ids,
135
+ max_new_tokens=120, # Balanced length
136
  temperature=0.7,
137
  top_p=0.9,
138
  do_sample=True,
139
+ pad_token_id=models["text_tokenizer"].eos_token_id
140
  )
141
 
142
+ # Process output
143
+ full_text = models["text_tokenizer"].decode(output[0], skip_special_tokens=True)
144
+ response = full_text.split("Response:")[-1].strip()
 
 
145
 
146
+ # Ensure completeness
147
+ if "." in response:
148
+ response = response.rsplit(".", 1)[0] + "."
149
+ return response[:200] or "Thank you for your feedback. We'll respond shortly."
150
 
151
+ def _text_to_speech(text, models):
152
+ """High-speed audio synthesis"""
153
+ inputs = models["tts_processor"](
 
 
 
154
  text=text[:150], # Limit text length
155
  return_tensors="pt"
156
+ ).to(models["device"])
157
 
158
+ with torch.inference_mode(): # Accelerated inference
159
+ spectrogram = models["tts_model"].generate_speech(
 
160
  inputs["input_ids"],
161
+ models["speaker_emb"]
162
  )
163
+ audio = models["tts_vocoder"](spectrogram)
 
 
 
164
 
165
+ sf.write("output.wav", audio.cpu().numpy(), 16000)
166
+ return "output.wav"
167
 
168
  ##########################################
169
+ # Main application flow
170
  ##########################################
171
  def main():
172
+ """Primary execution controller"""
173
+ # Load components
174
+ components = _load_components()
175
 
176
+ # Show interface
177
+ user_input = _show_interface()
178
 
179
  if user_input:
 
 
180
  # Text generation
181
+ with st.spinner("🔍 Analyzing..."):
182
+ response = _generate_response(user_input, components)
183
 
184
+ # Display result
185
+ st.subheader(f"📄 Response")
186
+ st.markdown(f"```\n{response}\n```") # f-string formatted
187
 
188
  # Audio generation
189
+ with st.spinner("🔊 Synthesizing..."):
190
+ audio_path = _text_to_speech(response, components)
191
+ st.audio(audio_path, format="audio/wav")
 
 
192
 
193
  if __name__ == "__main__":
194
  main()