Bils committed on
Commit
17d10a7
·
verified ·
1 Parent(s): 12a1ead

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -193
app.py CHANGED
@@ -1,9 +1,6 @@
 
1
  import os
2
- import requests
3
  import torch
4
- import scipy.io.wavfile as wav
5
- import streamlit as st
6
- from io import BytesIO
7
  from transformers import (
8
  AutoTokenizer,
9
  AutoModelForCausalLM,
@@ -11,228 +8,111 @@ from transformers import (
11
  AutoProcessor,
12
  MusicgenForConditionalGeneration
13
  )
14
- from streamlit_lottie import st_lottie
15
-
16
- # ---------------------------------------------------------------------
17
- # 1) PAGE CONFIGURATION
18
- # ---------------------------------------------------------------------
19
- st.set_page_config(
20
- page_title="AI Radio Imaging with Llama 3",
21
- page_icon="🎧",
22
- layout="wide"
23
- )
24
-
25
- # ---------------------------------------------------------------------
26
- # 2) CUSTOM CSS / UI DESIGN
27
- # ---------------------------------------------------------------------
28
- CUSTOM_CSS = """
29
- <style>
30
- body {
31
- background-color: #121212;
32
- color: #FFFFFF;
33
- font-family: "Helvetica Neue", sans-serif;
34
- }
35
- .block-container {
36
- max-width: 1100px;
37
- padding: 1rem 1.5rem;
38
- }
39
- h1, h2, h3 {
40
- color: #1DB954;
41
- }
42
- .stButton>button {
43
- background-color: #1DB954 !important;
44
- color: #FFFFFF !important;
45
- border-radius: 24px;
46
- padding: 0.6rem 1.2rem;
47
- }
48
- .stButton>button:hover {
49
- background-color: #1ed760 !important;
50
- }
51
- textarea, input, select {
52
- border-radius: 8px !important;
53
- background-color: #282828 !important;
54
- color: #FFFFFF !important;
55
- }
56
- audio {
57
- width: 100%;
58
- margin-top: 1rem;
59
- }
60
- .footer-note {
61
- text-align: center;
62
- font-size: 14px;
63
- opacity: 0.7;
64
- margin-top: 2rem;
65
- }
66
- #MainMenu, footer {visibility: hidden;}
67
- </style>
68
- """
69
- st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
70
-
71
- # ---------------------------------------------------------------------
72
- # 3) LOAD LOTTIE ANIMATION
73
- # ---------------------------------------------------------------------
74
- @st.cache_data
75
- def load_lottie_url(url: str):
76
- r = requests.get(url)
77
- if r.status_code != 200:
78
- return None
79
- return r.json()
80
-
81
- LOTTIE_URL = "https://assets3.lottiefiles.com/temp/lf20_Q6h5zV.json"
82
- lottie_animation = load_lottie_url(LOTTIE_URL)
83
 
84
  # ---------------------------------------------------------------------
85
- # 4) LOAD LLAMA 3 (GATED MODEL)
86
  # ---------------------------------------------------------------------
87
- @st.cache_resource
88
- def load_llama_pipeline(model_id: str, device: str, token: str):
89
  try:
90
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
91
  model = AutoModelForCausalLM.from_pretrained(
92
  model_id,
93
  use_auth_token=token,
94
- torch_dtype=torch.float16 if device == "auto" else torch.float32,
95
- device_map=device,
96
  low_cpu_mem_usage=True
97
  )
98
- text_gen_pipeline = pipeline(
99
- "text-generation",
100
- model=model,
101
- tokenizer=tokenizer,
102
- device_map=device
103
- )
104
- return text_gen_pipeline
105
  except Exception as e:
106
- st.error(f"Error loading Llama model: {e}")
107
- raise
108
 
109
  # ---------------------------------------------------------------------
110
- # 5) GENERATE RADIO SCRIPT
111
  # ---------------------------------------------------------------------
112
- def generate_radio_script(user_input: str, pipeline_llama) -> str:
113
- system_prompt = (
114
- "You are a top-tier radio imaging producer using Llama 3. "
115
- "Take the user's concept and craft a short, creative promo script."
116
- )
117
- combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
118
-
119
- result = pipeline_llama(
120
- combined_prompt,
121
- max_new_tokens=200,
122
- do_sample=True,
123
- temperature=0.9
124
- )
125
- output_text = result[0]["generated_text"]
126
- if "Refined script:" in output_text:
127
- output_text = output_text.split("Refined script:", 1)[-1].strip()
128
- output_text += "\n\n(Generated by Llama 3 - Radio Imaging)"
129
- return output_text
130
 
131
  # ---------------------------------------------------------------------
132
- # 6) LOAD MUSICGEN
133
  # ---------------------------------------------------------------------
134
- @st.cache_resource
135
  def load_musicgen_model():
136
- mg_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
137
- mg_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
138
- return mg_model, mg_processor
 
 
 
139
 
140
  # ---------------------------------------------------------------------
141
- # 7) HEADER
142
  # ---------------------------------------------------------------------
143
- st.title("🎧 AI Radio Imaging with Llama 3")
144
- st.subheader("Create engaging radio promos with Llama 3 + MusicGen")
145
- st.markdown("""Create **radio imaging promos** and **jingles** easily. Ensure you have access to
146
- **meta-llama/Meta-Llama-3-70B** on Hugging Face and provide your token below.""")
147
-
148
- if lottie_animation:
149
- st_lottie(lottie_animation, height=180, loop=True, key="radio_lottie")
150
-
151
- st.markdown("---")
 
 
 
152
 
153
  # ---------------------------------------------------------------------
154
- # 8) USER INPUT
155
  # ---------------------------------------------------------------------
156
- st.subheader("🎤 Step 1: Describe Your Promo Idea")
157
- prompt = st.text_area(
158
- "Example: 'A 15-second hype jingle for a morning talk show, fun and energetic.'",
159
- height=120
160
- )
161
 
162
- col_model, col_device = st.columns(2)
163
- with col_model:
164
- llama_model_id = st.text_input(
165
- "Llama 3 Model ID",
166
- value="meta-llama/Meta-Llama-3-70B",
167
- help="Enter the exact model ID from Hugging Face."
168
- )
169
- with col_device:
170
- device_option = st.selectbox(
171
- "Device",
172
- ["auto", "cpu"],
173
- help="Choose GPU (auto) or CPU."
174
- )
175
 
176
- hf_token = os.getenv("HF_TOKEN")
177
- if not hf_token:
178
- st.error("No HF_TOKEN found. Please set it in your environment.")
179
- st.stop()
180
 
181
- if st.button("✍ Generate Promo Script"):
182
- if not prompt.strip():
183
- st.error("Please provide a concept first.")
184
- else:
185
- with st.spinner("Generating script..."):
186
- try:
187
- llama_pipeline = load_llama_pipeline(llama_model_id, device_option, hf_token)
188
- final_script = generate_radio_script(prompt, llama_pipeline)
189
- st.success("Promo script generated!")
190
- st.text_area("Generated Script", value=final_script, height=200)
191
- except Exception as e:
192
- st.error(f"Llama generation error: {e}")
193
 
194
- st.markdown("---")
195
 
196
  # ---------------------------------------------------------------------
197
- # 9) GENERATE AUDIO WITH MUSICGEN
198
  # ---------------------------------------------------------------------
199
- st.subheader("🎵 Step 2: Generate Audio")
200
- audio_length = st.slider("Track Length (tokens)", 128, 1024, 512, 64)
201
-
202
- if st.button("🎧 Create Audio"):
203
- if "final_script" not in st.session_state:
204
- st.error("Please generate a script first.")
205
- else:
206
- with st.spinner("Generating audio..."):
207
- try:
208
- mg_model, mg_processor = load_musicgen_model()
209
- inputs = mg_processor(
210
- text=[st.session_state["final_script"]],
211
- padding=True,
212
- return_tensors="pt"
213
- )
214
- audio_values = mg_model.generate(**inputs, max_new_tokens=audio_length)
215
- sr = mg_model.config.audio_encoder.sampling_rate
216
- output_file = "radio_jingle.wav"
217
 
218
- audio_data = audio_values[0, 0].cpu().numpy()
219
- normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
220
- wav.write(output_file, rate=sr, data=normalized_audio)
221
 
222
- st.success("Audio generated! Play it below:")
223
- st.audio(output_file)
224
- except Exception as e:
225
- st.error(f"MusicGen error: {e}")
226
 
227
  # ---------------------------------------------------------------------
228
- # 10) FOOTER
229
  # ---------------------------------------------------------------------
230
- st.markdown("---")
231
- st.markdown(
232
- """
233
- <div class="footer-note">
234
- © 2025 AI Radio Imaging – Built with Hugging Face & Streamlit
235
- </div>
236
- """,
237
- unsafe_allow_html=True
238
- )
 
1
+ import gradio as gr
2
  import os
 
3
  import torch
 
 
 
4
  from transformers import (
5
  AutoTokenizer,
6
  AutoModelForCausalLM,
 
8
  AutoProcessor,
9
  MusicgenForConditionalGeneration
10
  )
11
+ import scipy.io.wavfile as wav
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # ---------------------------------------------------------------------
14
+ # Load Llama 3 Model
15
  # ---------------------------------------------------------------------
16
def load_llama_pipeline(model_id: str, token: str, device: str = "cpu"):
    """Build a text-generation pipeline for a (possibly gated) Llama checkpoint.

    Args:
        model_id: Hugging Face model repository ID.
        token: Hugging Face access token used to download the checkpoint.
        device: ``"cuda"`` loads the weights in float16 with
            ``device_map="auto"``; any other value loads in float32 on CPU.

    Returns:
        The text-generation pipeline on success, or the exception message as
        a ``str`` on failure. Callers distinguish the two cases with
        ``isinstance(result, str)``.
    """
    try:
        on_gpu = device == "cuda"
        # NOTE(review): `use_auth_token` is deprecated in newer transformers
        # releases in favour of `token` -- confirm the pinned version before
        # switching.
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            use_auth_token=token,
            torch_dtype=torch.float16 if on_gpu else torch.float32,
            device_map="auto" if on_gpu else None,
            low_cpu_mem_usage=True,
        )
        pipeline_kwargs = {"model": model, "tokenizer": tokenizer}
        if not on_gpu:
            # A model placed by `device_map="auto"` must not also be given an
            # explicit `device`; the pipeline then uses the existing
            # placement. Only the plain CPU load passes one.
            pipeline_kwargs["device"] = -1
        return pipeline("text-generation", **pipeline_kwargs)
    except Exception as e:
        # Error-as-string contract relied on by the Gradio callback.
        return str(e)
 
29
 
30
  # ---------------------------------------------------------------------
31
+ # Generate Radio Script
32
  # ---------------------------------------------------------------------
33
def generate_script(user_input: str, pipeline_llama):
    """Turn a user's promo concept into a short radio script.

    Args:
        user_input: Free-text description of the promo idea.
        pipeline_llama: Callable text-generation pipeline; called with the
            prompt plus sampling keyword arguments and expected to return a
            list whose first item has a ``"generated_text"`` entry.

    Returns:
        The generated script with the prompt removed, or a string starting
        with ``"Error generating script:"`` if generation fails.
    """
    marker = "Refined script:"
    try:
        system_prompt = (
            "You are a top-tier radio imaging producer using Llama 3. "
            "Take the user's concept and craft a short, creative promo script."
        )
        combined_prompt = f"{system_prompt}\nUser concept: {user_input}\n{marker}"
        result = pipeline_llama(
            combined_prompt,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.9,
        )
        generated = result[0]["generated_text"]
        # When the output echoes the prompt, drop exactly that prefix so a
        # marker inside the user's text or the model's continuation cannot
        # cut the script short.
        if generated.startswith(combined_prompt):
            return generated[len(combined_prompt):].strip()
        # Otherwise split on the FIRST marker only; everything after it is
        # script content even if the model repeats the marker.
        return generated.split(marker, 1)[-1].strip()
    except Exception as e:
        return f"Error generating script: {e}"
 
 
 
 
 
 
 
44
 
45
  # ---------------------------------------------------------------------
46
# Load MusicGen Model
# ---------------------------------------------------------------------
# Successful loads are kept here so the weights are read once per process
# instead of on every request. Failures are never stored, so a later call
# can retry.
_MUSICGEN_CACHE = {}


def load_musicgen_model():
    """Load the ``facebook/musicgen-small`` model and its processor.

    Returns:
        ``(model, processor)`` on success. On failure, ``(None, message)``
        where ``message`` is the exception text; callers detect failure by
        checking whether the second element is a ``str``.
    """
    checkpoint = "facebook/musicgen-small"
    if checkpoint in _MUSICGEN_CACHE:
        return _MUSICGEN_CACHE[checkpoint]
    try:
        model = MusicgenForConditionalGeneration.from_pretrained(checkpoint)
        processor = AutoProcessor.from_pretrained(checkpoint)
    except Exception as e:
        return None, str(e)
    _MUSICGEN_CACHE[checkpoint] = (model, processor)
    return model, processor
55
 
56
  # ---------------------------------------------------------------------
57
+ # Generate Audio
58
  # ---------------------------------------------------------------------
59
def generate_audio(prompt: str, audio_length: int, mg_model, mg_processor,
                   output_file: str = "radio_jingle.wav"):
    """Generate audio for ``prompt`` with MusicGen and write it as 16-bit WAV.

    Args:
        prompt: Text description passed to the MusicGen processor.
        audio_length: Number of new tokens to generate (``max_new_tokens``).
        mg_model: Loaded MusicGen model.
        mg_processor: Processor matching ``mg_model``.
        output_file: Destination path of the WAV file.

    Returns:
        ``output_file`` on success, or a string starting with
        ``"Error generating audio:"`` on failure.
    """
    try:
        inputs = mg_processor(text=[prompt], padding=True, return_tensors="pt")
        # UI sliders may deliver the value as a float; generate() needs an int.
        outputs = mg_model.generate(**inputs, max_new_tokens=int(audio_length))
        sr = mg_model.config.audio_encoder.sampling_rate
        audio_data = outputs[0, 0].cpu().numpy()

        # Peak-normalise to the int16 range. An all-zero (silent) clip is
        # written unscaled; dividing by a zero peak would produce NaNs.
        peak = float(abs(audio_data).max())
        scaled = audio_data / peak * 32767 if peak > 0 else audio_data
        normalized_audio = scaled.astype("int16")

        wav.write(output_file, rate=sr, data=normalized_audio)
        return output_file
    except Exception as e:
        # The "Error" prefix lets callers tell a failure message from a path.
        return f"Error generating audio: {e}"
71
 
72
  # ---------------------------------------------------------------------
73
+ # Gradio Interface
74
  # ---------------------------------------------------------------------
75
def radio_imaging_app(user_prompt, llama_model_id, hf_token, audio_length):
    """Gradio callback: write a promo script with Llama, then score it with MusicGen.

    Args:
        user_prompt: The user's promo concept.
        llama_model_id: Hugging Face model ID of the Llama checkpoint.
        hf_token: Hugging Face access token; falls back to the ``HF_TOKEN``
            environment variable when left empty.
        audio_length: Number of MusicGen tokens to generate.

    Returns:
        ``(script_text, audio_path)``. Whenever a step fails, ``audio_path``
        is ``None`` and ``script_text`` carries the error message, so an
        error string is never handed to the audio output.
    """
    if not user_prompt or not user_prompt.strip():
        return "Please provide a concept first.", None

    token = hf_token or os.getenv("HF_TOKEN")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load Llama 3 Pipeline (a str result is the loader's error message)
    pipeline_llama = load_llama_pipeline(llama_model_id, token, device=device)
    if isinstance(pipeline_llama, str):
        return pipeline_llama, None

    # Generate Script; stop here on failure instead of feeding the error
    # text to MusicGen as if it were a prompt.
    script = generate_script(user_prompt, pipeline_llama)
    if script.startswith("Error generating script:"):
        return script, None

    # Load MusicGen (a str in the second slot is the loader's error message)
    mg_model, mg_processor = load_musicgen_model()
    if isinstance(mg_processor, str):
        return f"{script}\n\n[MusicGen could not be loaded: {mg_processor}]", None

    # Generate Audio. On failure the helper returns a message rather than a
    # path, so treat anything that is not an existing file as an error.
    audio_file = generate_audio(script, audio_length, mg_model, mg_processor)
    if not os.path.isfile(audio_file):
        return f"{script}\n\n[Audio generation failed: {audio_file}]", None

    return script, audio_file
95
 
96
  # ---------------------------------------------------------------------
97
+ # Interface
98
  # ---------------------------------------------------------------------
99
# Single-page UI: four inputs in one row, one button, and two outputs
# (script text and generated audio).
with gr.Blocks() as demo:
    gr.Markdown("# 🎧 AI Radio Imaging with Llama 3 + MusicGen")
    with gr.Row():
        user_prompt = gr.Textbox(
            label="Enter your promo idea",
            placeholder="E.g., A 15-second hype jingle for a morning talk show, fun and energetic.",
        )
        llama_model_id = gr.Textbox(label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-70B")
        hf_token = gr.Textbox(label="Hugging Face Token", type="password")
        audio_length = gr.Slider(label="Audio Length (tokens)", minimum=128, maximum=1024, step=64, value=512)

    generate_button = gr.Button("Generate Promo Script and Audio")
    script_output = gr.Textbox(label="Generated Script")
    # The callback returns a path to a WAV file, which is what "filepath"
    # expects; "file" is not an accepted `type` in current Gradio releases.
    audio_output = gr.Audio(label="Generated Audio", type="filepath")

    generate_button.click(
        radio_imaging_app,
        inputs=[user_prompt, llama_model_id, hf_token, audio_length],
        outputs=[script_output, audio_output],
    )

# ---------------------------------------------------------------------
# Launch App
# ---------------------------------------------------------------------
demo.launch()