Bils commited on
Commit
2de59b3
·
verified ·
1 Parent(s): a38649c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -68
app.py CHANGED
@@ -16,86 +16,174 @@ import spaces
16
  from TTS.api import TTS
17
  from TTS.utils.synthesizer import Synthesizer
18
 
19
- # Load environment variables
 
 
20
  load_dotenv()
21
- hf_token = os.getenv("HF_TOKEN")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  # ---------------------------------------------------------------------
24
  # Script Generation Function
25
  # ---------------------------------------------------------------------
26
  @spaces.GPU(duration=100)
27
  def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
 
 
 
 
28
  try:
29
- tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
30
- model = AutoModelForCausalLM.from_pretrained(
31
- model_id,
32
- use_auth_token=token,
33
- torch_dtype=torch.float16,
34
- device_map="auto",
35
- trust_remote_code=True,
36
- )
37
- llama_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
38
 
39
  # System prompt with clear structure instructions
40
  system_prompt = (
41
- f"You are an expert radio imaging producer specializing in sound design and music. "
42
  f"Based on the user's concept and the selected duration of {duration} seconds, produce the following: "
43
- f"1. A concise voice-over script. Prefix this section with 'Voice-Over Script:'.\n"
44
- f"2. Suggestions for sound design. Prefix this section with 'Sound Design Suggestions:'.\n"
45
- f"3. Music styles or track recommendations. Prefix this section with 'Music Suggestions:'."
46
  )
47
 
48
  combined_prompt = f"{system_prompt}\nUser concept: {user_prompt}\nOutput:"
49
- result = llama_pipeline(combined_prompt, max_new_tokens=300, do_sample=True, temperature=0.8)
50
 
51
- # Parsing output
52
- generated_text = result[0]["generated_text"].split("Output:")[-1].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- # Extract sections based on prefixes
55
- voice_script = generated_text.split("Voice-Over Script:")[1].split("Sound Design Suggestions:")[0].strip() if "Voice-Over Script:" in generated_text else "No voice-over script found."
56
- sound_design = generated_text.split("Sound Design Suggestions:")[1].split("Music Suggestions:")[0].strip() if "Sound Design Suggestions:" in generated_text else "No sound design suggestions found."
57
- music_suggestions = generated_text.split("Music Suggestions:")[1].strip() if "Music Suggestions:" in generated_text else "No music suggestions found."
 
 
 
 
 
 
 
58
 
59
  return voice_script, sound_design, music_suggestions
 
60
  except Exception as e:
61
  return f"Error generating script: {e}", "", ""
62
 
 
63
  # ---------------------------------------------------------------------
64
  # Voice-Over Generation Function (Inactive)
65
  # ---------------------------------------------------------------------
66
  @spaces.GPU(duration=100)
67
  def generate_voice(script: str, speaker: str = "default"):
 
 
 
68
  try:
69
- # Placeholder for inactive state
70
  return "Voice-over generation is currently inactive."
71
  except Exception as e:
72
  return f"Error: {e}"
73
 
 
74
  # ---------------------------------------------------------------------
75
- # Music Generation Function (facebook/musicgen-medium)
76
  # ---------------------------------------------------------------------
77
  @spaces.GPU(duration=100)
78
  def generate_music(prompt: str, audio_length: int):
 
 
 
 
79
  try:
80
- # Load facebook/musicgen-medium model
81
- musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium")
82
- musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
83
 
84
- # Move the model to the appropriate device (CUDA or CPU)
85
  device = "cuda" if torch.cuda.is_available() else "cpu"
86
- musicgen_model.to(device)
87
-
88
- # Prepare inputs
89
  inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
90
 
91
- # Generate music
92
- outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
 
93
 
94
- # Process audio data
95
  audio_data = outputs[0, 0].cpu().numpy()
 
96
  normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
97
 
98
- # Save generated music to a file
99
  output_path = f"{tempfile.gettempdir()}/musicgen_medium_generated_music.wav"
100
  write(output_path, 44100, normalized_audio)
101
 
@@ -106,11 +194,13 @@ def generate_music(prompt: str, audio_length: int):
106
 
107
 
108
  # ---------------------------------------------------------------------
109
- # Audio Blending Function with Ducking (Inactive)
110
  # ---------------------------------------------------------------------
111
  def blend_audio(voice_path: str, music_path: str, ducking: bool):
 
 
 
112
  try:
113
- # Placeholder for inactive state
114
  return "Audio blending functionality is currently inactive."
115
  except Exception as e:
116
  return f"Error: {e}"
@@ -121,73 +211,92 @@ def blend_audio(voice_path: str, music_path: str, ducking: bool):
121
  # ---------------------------------------------------------------------
122
  with gr.Blocks() as demo:
123
  gr.Markdown("""
124
- # 🎧 AI Promo Studio 🚀
125
- Welcome to **AI Promo Studio**, your one-stop solution for creating stunning and professional radio promos with ease!
126
- Whether you're a sound designer, radio producer, or content creator, our AI-driven tools, powered by advanced LLM Llama models, empower you to bring your vision to life in just a few steps.
127
  """)
128
 
129
  with gr.Tabs():
130
  # Step 1: Generate Script
131
  with gr.Tab("Step 1: Generate Script"):
132
  with gr.Row():
133
- user_prompt = gr.Textbox(label="Promo Idea", placeholder="E.g., A 30-second promo for a morning show.")
134
- llama_model_id = gr.Textbox(label="Llama Model ID", value="meta-llama/Meta-Llama-3-8B-Instruct")
135
- duration = gr.Slider(label="Duration (seconds)", minimum=15, maximum=60, step=15, value=30)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  generate_script_button = gr.Button("Generate Script")
138
- script_output = gr.Textbox(label="Generated Voice-Over Script", lines=5)
139
- sound_design_output = gr.Textbox(label="Sound Design Suggestions", lines=3)
140
- music_suggestion_output = gr.Textbox(label="Music Suggestions", lines=3)
141
 
142
  generate_script_button.click(
143
- fn=lambda user_prompt, model_id, duration: generate_script(user_prompt, model_id, hf_token, duration),
144
  inputs=[user_prompt, llama_model_id, duration],
145
  outputs=[script_output, sound_design_output, music_suggestion_output],
146
  )
147
 
148
- # Step 2: Generate Voice
149
  with gr.Tab("Step 2: Generate Voice"):
150
  gr.Markdown("""
151
- **Note:** Voice-over generation is currently inactive.
152
- This feature will be available in future updates!
153
  """)
154
-
155
 
156
  # Step 3: Generate Music
157
  with gr.Tab("Step 3: Generate Music"):
158
  with gr.Row():
159
- audio_length = gr.Slider(label="Music Length (tokens)", minimum=128, maximum=1024, step=64, value=512)
160
-
 
 
 
 
 
 
161
  generate_music_button = gr.Button("Generate Music")
162
- music_output = gr.Audio(label="Generated Music", type="filepath")
163
 
164
  generate_music_button.click(
165
- fn=lambda music_suggestion, audio_length: generate_music(music_suggestion, audio_length),
166
  inputs=[music_suggestion_output, audio_length],
167
  outputs=[music_output],
168
  )
169
 
170
- # Step 4: Blend Audio
171
  with gr.Tab("Step 4: Blend Audio"):
172
  gr.Markdown("""
173
- **Note:** Audio blending functionality is currently inactive.
174
- This feature will be available in future updates!
175
  """)
176
-
177
 
 
178
  gr.Markdown("""
179
- <hr>
180
- <p style="text-align: center; font-size: 0.9em;">
181
- Created with ❤️ by <a href="https://bilsimaging.com" target="_blank">bilsimaging.com</a>
182
- </p>
183
  """)
184
 
185
- # Add visitor badge HTML
186
  gr.HTML("""
187
- <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold">
188
- <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold&countColor=%23263759" />
189
- </a>
190
  """)
191
 
 
192
  demo.launch(debug=True)
193
-
 
16
  from TTS.api import TTS
17
  from TTS.utils.synthesizer import Synthesizer
18
 
19
+ # ---------------------------------------------------------------------
20
+ # Load Environment Variables
21
+ # ---------------------------------------------------------------------
22
  load_dotenv()
23
+ HF_TOKEN = os.getenv("HF_TOKEN")
24
+
25
+ # ---------------------------------------------------------------------
26
+ # Global Model Caches
27
+ # ---------------------------------------------------------------------
28
+ # We store models/pipelines in global variables for reuse,
29
+ # so they are only loaded once.
30
+ LLAMA_PIPELINES = {}
31
+ MUSICGEN_MODELS = {}
32
+
33
+ # ---------------------------------------------------------------------
34
+ # Helper Functions
35
+ # ---------------------------------------------------------------------
36
+ def get_llama_pipeline(model_id: str, token: str):
37
+ """
38
+ Returns a cached LLaMA pipeline if available; otherwise, loads it.
39
+ This significantly reduces loading time for repeated calls.
40
+ """
41
+ if model_id in LLAMA_PIPELINES:
42
+ return LLAMA_PIPELINES[model_id]
43
+
44
+ # Load new pipeline and store in cache
45
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
46
+ model = AutoModelForCausalLM.from_pretrained(
47
+ model_id,
48
+ use_auth_token=token,
49
+ torch_dtype=torch.float16,
50
+ device_map="auto",
51
+ trust_remote_code=True,
52
+ )
53
+ text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
54
+ LLAMA_PIPELINES[model_id] = text_pipeline
55
+ return text_pipeline
56
+
57
+
58
+ def get_musicgen_model(model_key: str = "facebook/musicgen-medium"):
59
+ """
60
+ Returns a cached MusicGen model if available; otherwise, loads it.
61
+ """
62
+ if model_key in MUSICGEN_MODELS:
63
+ return MUSICGEN_MODELS[model_key]
64
+
65
+ # Load new MusicGen model and store in cache
66
+ model = MusicgenForConditionalGeneration.from_pretrained(model_key)
67
+ processor = AutoProcessor.from_pretrained(model_key)
68
+
69
+ device = "cuda" if torch.cuda.is_available() else "cpu"
70
+ model.to(device)
71
+
72
+ MUSICGEN_MODELS[model_key] = (model, processor)
73
+ return model, processor
74
+
75
 
76
  # ---------------------------------------------------------------------
77
  # Script Generation Function
78
  # ---------------------------------------------------------------------
79
  @spaces.GPU(duration=100)
80
  def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
81
+ """
82
+ Generates a script, sound design suggestions, and music ideas from a user prompt.
83
+ Returns a tuple of strings: (voice_script, sound_design, music_suggestions).
84
+ """
85
  try:
86
+ text_pipeline = get_llama_pipeline(model_id, token)
 
 
 
 
 
 
 
 
87
 
88
  # System prompt with clear structure instructions
89
  system_prompt = (
90
+ "You are an expert radio imaging producer specializing in sound design and music. "
91
  f"Based on the user's concept and the selected duration of {duration} seconds, produce the following: "
92
+ "1. A concise voice-over script. Prefix this section with 'Voice-Over Script:'.\n"
93
+ "2. Suggestions for sound design. Prefix this section with 'Sound Design Suggestions:'.\n"
94
+ "3. Music styles or track recommendations. Prefix this section with 'Music Suggestions:'."
95
  )
96
 
97
  combined_prompt = f"{system_prompt}\nUser concept: {user_prompt}\nOutput:"
 
98
 
99
+ # Use inference mode for efficient forward passes
100
+ with torch.inference_mode():
101
+ result = text_pipeline(
102
+ combined_prompt,
103
+ max_new_tokens=300,
104
+ do_sample=True,
105
+ temperature=0.8
106
+ )
107
+
108
+ # LLaMA pipeline returns a list of dicts with "generated_text"
109
+ generated_text = result[0]["generated_text"]
110
+
111
+ # Basic parsing to isolate everything after "Output:"
112
+ # (in case the model repeated your system prompt).
113
+ if "Output:" in generated_text:
114
+ generated_text = generated_text.split("Output:")[-1].strip()
115
+
116
+ # Extract sections based on known prefixes
117
+ voice_script = "No voice-over script found."
118
+ sound_design = "No sound design suggestions found."
119
+ music_suggestions = "No music suggestions found."
120
+
121
+ if "Voice-Over Script:" in generated_text:
122
+ parts = generated_text.split("Voice-Over Script:")
123
+ if len(parts) > 1:
124
+ # Everything after "Voice-Over Script:" up until next prefix
125
+ voice_script_part = parts[1]
126
+ voice_script = voice_script_part.split("Sound Design Suggestions:")[0].strip() \
127
+ if "Sound Design Suggestions:" in voice_script_part else voice_script_part.strip()
128
 
129
+ if "Sound Design Suggestions:" in generated_text:
130
+ parts = generated_text.split("Sound Design Suggestions:")
131
+ if len(parts) > 1:
132
+ sound_design_part = parts[1]
133
+ sound_design = sound_design_part.split("Music Suggestions:")[0].strip() \
134
+ if "Music Suggestions:" in sound_design_part else sound_design_part.strip()
135
+
136
+ if "Music Suggestions:" in generated_text:
137
+ parts = generated_text.split("Music Suggestions:")
138
+ if len(parts) > 1:
139
+ music_suggestions = parts[1].strip()
140
 
141
  return voice_script, sound_design, music_suggestions
142
+
143
  except Exception as e:
144
  return f"Error generating script: {e}", "", ""
145
 
146
+
147
  # ---------------------------------------------------------------------
148
  # Voice-Over Generation Function (Inactive)
149
  # ---------------------------------------------------------------------
150
  @spaces.GPU(duration=100)
151
  def generate_voice(script: str, speaker: str = "default"):
152
+ """
153
+ Placeholder for future voice-over generation functionality.
154
+ """
155
  try:
 
156
  return "Voice-over generation is currently inactive."
157
  except Exception as e:
158
  return f"Error: {e}"
159
 
160
+
161
  # ---------------------------------------------------------------------
162
+ # Music Generation Function
163
  # ---------------------------------------------------------------------
164
  @spaces.GPU(duration=100)
165
  def generate_music(prompt: str, audio_length: int):
166
+ """
167
+ Generates music from the 'facebook/musicgen-medium' model based on the prompt.
168
+ Returns the file path to the generated .wav file.
169
+ """
170
  try:
171
+ model_key = "facebook/musicgen-medium"
172
+ musicgen_model, musicgen_processor = get_musicgen_model(model_key)
 
173
 
 
174
  device = "cuda" if torch.cuda.is_available() else "cpu"
175
+ # Prepare input
 
 
176
  inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
177
 
178
+ # Generate music within inference mode
179
+ with torch.inference_mode():
180
+ outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
181
 
 
182
  audio_data = outputs[0, 0].cpu().numpy()
183
+ # Normalize audio to int16 format
184
  normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
185
 
186
+ # Save generated music to a temp file
187
  output_path = f"{tempfile.gettempdir()}/musicgen_medium_generated_music.wav"
188
  write(output_path, 44100, normalized_audio)
189
 
 
194
 
195
 
196
  # ---------------------------------------------------------------------
197
+ # Audio Blending Function (Inactive)
198
  # ---------------------------------------------------------------------
199
  def blend_audio(voice_path: str, music_path: str, ducking: bool):
200
+ """
201
+ Placeholder for future audio blending functionality with optional ducking.
202
+ """
203
  try:
 
204
  return "Audio blending functionality is currently inactive."
205
  except Exception as e:
206
  return f"Error: {e}"
 
211
  # ---------------------------------------------------------------------
212
  with gr.Blocks() as demo:
213
  gr.Markdown("""
214
+ # 🎧 AI Promo Studio 🚀
215
+ Welcome to **AI Promo Studio**, your one-stop solution for creating stunning and professional radio promos with ease!
216
+ Whether you're a sound designer, radio producer, or content creator, our AI-driven tools, powered by advanced LLM Llama models, empower you to bring your vision to life in just a few steps.
217
  """)
218
 
219
  with gr.Tabs():
220
  # Step 1: Generate Script
221
  with gr.Tab("Step 1: Generate Script"):
222
  with gr.Row():
223
+ user_prompt = gr.Textbox(
224
+ label="Promo Idea",
225
+ placeholder="E.g., A 30-second promo for a morning show...",
226
+ lines=2
227
+ )
228
+ llama_model_id = gr.Textbox(
229
+ label="LLaMA Model ID",
230
+ value="meta-llama/Meta-Llama-3-8B-Instruct",
231
+ placeholder="Enter a valid Hugging Face model ID"
232
+ )
233
+ duration = gr.Slider(
234
+ label="Desired Promo Duration (seconds)",
235
+ minimum=15,
236
+ maximum=60,
237
+ step=15,
238
+ value=30
239
+ )
240
 
241
  generate_script_button = gr.Button("Generate Script")
242
+ script_output = gr.Textbox(label="Generated Voice-Over Script", lines=5, interactive=False)
243
+ sound_design_output = gr.Textbox(label="Sound Design Suggestions", lines=3, interactive=False)
244
+ music_suggestion_output = gr.Textbox(label="Music Suggestions", lines=3, interactive=False)
245
 
246
  generate_script_button.click(
247
+ fn=lambda user_prompt, model_id, dur: generate_script(user_prompt, model_id, HF_TOKEN, dur),
248
  inputs=[user_prompt, llama_model_id, duration],
249
  outputs=[script_output, sound_design_output, music_suggestion_output],
250
  )
251
 
252
+ # Step 2: Generate Voice (Inactive)
253
  with gr.Tab("Step 2: Generate Voice"):
254
  gr.Markdown("""
255
+ **Note:** Voice-over generation is currently inactive.
256
+ This feature will be available in future updates!
257
  """)
 
258
 
259
  # Step 3: Generate Music
260
  with gr.Tab("Step 3: Generate Music"):
261
  with gr.Row():
262
+ audio_length = gr.Slider(
263
+ label="Music Length (tokens)",
264
+ minimum=128,
265
+ maximum=1024,
266
+ step=64,
267
+ value=512,
268
+ info="Increase tokens for longer audio, but be mindful of inference time."
269
+ )
270
  generate_music_button = gr.Button("Generate Music")
271
+ music_output = gr.Audio(label="Generated Music (WAV)", type="filepath")
272
 
273
  generate_music_button.click(
274
+ fn=lambda music_suggestion, length: generate_music(music_suggestion, length),
275
  inputs=[music_suggestion_output, audio_length],
276
  outputs=[music_output],
277
  )
278
 
279
+ # Step 4: Blend Audio (Inactive)
280
  with gr.Tab("Step 4: Blend Audio"):
281
  gr.Markdown("""
282
+ **Note:** Audio blending functionality is currently inactive.
283
+ This feature will be available in future updates!
284
  """)
 
285
 
286
+ # Footer / Credits
287
  gr.Markdown("""
288
+ <hr>
289
+ <p style="text-align: center; font-size: 0.9em;">
290
+ Created with ❤️ by <a href="https://bilsimaging.com" target="_blank">bilsimaging.com</a>
291
+ </p>
292
  """)
293
 
294
+ # Visitor Badge
295
  gr.HTML("""
296
+ <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold">
297
+ <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold&countColor=%23263759" />
298
+ </a>
299
  """)
300
 
301
+ # Launch the Gradio app
302
  demo.launch(debug=True)