"""Gradio front end for NotaGen sheet-music generation.

The user picks a (period, composer, instrumentation) prompt, the model output
is streamed to the page while inference runs, and the resulting ABC score is
converted to XML / PDF / MIDI / MP3 plus per-page PNG previews.
"""
# import spaces
# import zero
import datetime
import os
import queue
import random
import subprocess
import sys
import threading
import time
from io import TextIOBase

import gradio as gr

from convert import abc2xml, xml2, pdf2img
from inference import inference_patch
from inference import postprocess_inst_names

title_html = """
NotaGen is a model for generating sheet music in ABC notation format. Select a period, composer, and instrumentation to generate classical-style music!
"""

# Read prompt combinations. Each usable line is "<period>_<composer>_<instrument>".
with open('prompts.txt', 'r', encoding='utf-8') as f:
    prompts = f.readlines()

valid_combinations = set()
for prompt in prompts:
    parts = prompt.strip().split('_')
    # Skip blank or malformed lines (e.g. a trailing newline at end of file)
    # instead of crashing at import time with an IndexError.
    if len(parts) < 3:
        continue
    valid_combinations.add((parts[0], parts[1], parts[2]))

# Prepare dropdown options
periods = sorted({p for p, _, _ in valid_combinations})
composers = sorted({c for _, c, _ in valid_combinations})
instruments = sorted({i for _, _, i in valid_combinations})


def update_components(period, composer):
    """Recompute the composer and instrumentation dropdowns.

    Args:
        period: Currently selected period, or a falsy value if none.
        composer: Currently selected composer, or a falsy value if none.

    Returns:
        A two-element list of ``gr.update`` objects: first for the composer
        dropdown, second for the instrumentation dropdown.
    """
    if not period:
        # Nothing selected yet: both dependent dropdowns are emptied and locked.
        return [
            gr.update(choices=[], value=None, interactive=False),
            gr.update(choices=[], value=None, interactive=False)
        ]

    valid_composers = sorted({c for p, c, _ in valid_combinations if p == period})
    valid_instruments = sorted(
        {i for p, c, i in valid_combinations if p == period and c == composer}
    ) if composer else []

    return [
        gr.update(
            choices=valid_composers,
            # Keep the current composer only if it is still valid for this period.
            value=composer if composer in valid_composers else None,
            interactive=True
        ),
        gr.update(
            choices=valid_instruments,
            value=None,
            interactive=bool(valid_instruments)
        )
    ]


class RealtimeStream(TextIOBase):
    """File-like object that forwards every written chunk to a queue.

    Installed as ``sys.stdout`` during inference so that printed text can be
    relayed to the frontend.
    """

    def __init__(self, queue):
        super().__init__()
        self.queue = queue

    def write(self, text):
        """Enqueue *text* and return the number of characters accepted."""
        self.queue.put(text)
        return len(text)


def convert_files(abc_content, period, composer, instrumentation):
    """Write the ABC score to disk and convert it to the other formats.

    Files are created in the current working directory and named
    ``<timestamp>_<period>_<composer>_<instrumentation>[_postinst].<ext>``.

    Args:
        abc_content: ABC notation text to save and convert.
        period: Period part of the prompt (used in the file name).
        composer: Composer part of the prompt (used in the file name).
        instrumentation: Instrumentation part of the prompt (used in the file name).

    Returns:
        A dict with keys ``abc``, ``xml``, ``pdf``, ``mid``, ``mp3`` (file
        paths), ``pages`` (number of preview images), ``current_page`` (0) and
        ``base`` (file name stem used for the page PNGs).

    Raises:
        gr.Error: If any prompt part is missing or a conversion step fails.
    """
    if not all([period, composer, instrumentation]):
        raise gr.Error("Please complete a valid generation first before saving")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    prompt_str = f"{period}_{composer}_{instrumentation}"
    filename_base = f"{timestamp}_{prompt_str}"

    abc_filename = f"{filename_base}.abc"
    with open(abc_filename, "w", encoding="utf-8") as f:
        f.write(abc_content)

    # instrumentation replacement
    postprocessed_inst_abc = postprocess_inst_names(abc_content)
    filename_base_postinst = f"{filename_base}_postinst"
    with open(filename_base_postinst + ".abc", "w", encoding="utf-8") as f:
        f.write(postprocessed_inst_abc)

    # Convert files
    file_paths = {'abc': abc_filename}
    try:
        # abc2xml
        abc2xml(filename_base)
        abc2xml(filename_base_postinst)

        # xml2pdf
        xml2(filename_base, 'pdf')

        # xml2mid
        xml2(filename_base, 'mid')
        xml2(filename_base_postinst, 'mid')

        # xml2mp3
        xml2(filename_base, 'mp3')
        xml2(filename_base_postinst, 'mp3')

        # Convert the PDF into one PNG image per page
        images = pdf2img(filename_base)
        for i, image in enumerate(images):
            image.save(f"{filename_base}_page_{i+1}.png", "PNG")

        # The PDF comes from the raw score; XML/MIDI/MP3 come from the
        # instrument-name post-processed variant.
        file_paths.update({
            'xml': f"{filename_base_postinst}.xml",
            'pdf': f"{filename_base}.pdf",
            'mid': f"{filename_base_postinst}.mid",
            'mp3': f"{filename_base_postinst}.mp3",
            'pages': len(images),
            'current_page': 0,
            'base': filename_base
        })
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise gr.Error(f"File processing failed: {str(e)}") from e

    return file_paths


# Page navigation control function
def update_page(direction, data):
    """Move the sheet-music preview one page backward or forward.

    Args:
        direction: ``"prev"`` or ``"next"``; any other value leaves the page
            unchanged.
        data: State dict holding ``'pages'``, ``'current_page'`` and
            ``'base'``, or a falsy value when nothing has been generated.
            The dict is updated in place.

    Returns:
        A tuple ``(image_path, prev_button_update, next_button_update, data)``.
    """
    if not data:
        return None, gr.update(interactive=False), gr.update(interactive=False), data

    if direction == "prev" and data['current_page'] > 0:
        data['current_page'] -= 1
    elif direction == "next" and data['current_page'] < data['pages'] - 1:
        data['current_page'] += 1

    current_page_index = data['current_page']
    # Update image path (page files are numbered from 1)
    new_image = f"{data['base']}_page_{current_page_index+1}.png"

    # When current_page==0, prev_btn is disabled; when current_page==pages-1, next_btn is disabled
    prev_btn_state = gr.update(interactive=(current_page_index > 0))
    next_btn_state = gr.update(interactive=(current_page_index < data['pages'] - 1))

    return new_image, prev_btn_state, next_btn_state, data


# @spaces.GPU(duration=600)
def generate_music(period, composer, instrumentation):
    """Run inference for the chosen prompt and stream progress to the UI.

    This is a generator. Every yield must return the same number of values;
    there are 6, corresponding to:
        1) process_output (intermediate inference information)
        2) final_output (final ABC, or a status / error message)
        3) pdf_image (path to the PNG of the first page of the PDF)
        4) audio_player (mp3 path)
        5) pdf_state (state for page navigation)
        6) download_files (update for the downloadable file list)

    Raises:
        gr.Error: If the (period, composer, instrumentation) combination is
            not listed in ``valid_combinations``.
    """
    # Set a different random seed each time based on current timestamp
    random_seed = int(time.time()) % 10000
    random.seed(random_seed)

    # For numpy if you're using it
    try:
        import numpy as np
        np.random.seed(random_seed)
    except ImportError:
        pass

    # For torch if you're using it
    try:
        import torch
        torch.manual_seed(random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(random_seed)
    except ImportError:
        pass

    if (period, composer, instrumentation) not in valid_combinations:
        # If the combination is invalid, raise an error
        raise gr.Error("Invalid prompt combination! Please re-select from the period options")

    output_queue = queue.Queue()
    original_stdout = sys.stdout
    # NOTE(review): this swaps the process-wide sys.stdout, so overlapping
    # generations would share/override each other's stream - confirm the app
    # only serves one generation at a time.
    sys.stdout = RealtimeStream(output_queue)

    result_container = []
    error_container = []

    def run_inference():
        try:
            # Use downloaded model weights path for inference
            result = inference_patch(period, composer, instrumentation)
            result_container.append(result)
        except Exception as exc:
            # Remember the failure for the UI, then re-raise so the thread's
            # default exception hook still logs the traceback.
            error_container.append(exc)
            raise
        finally:
            sys.stdout = original_stdout

    thread = threading.Thread(target=run_inference)
    thread.start()

    process_output = ""
    final_output_abc = ""
    pdf_image = None
    audio_file = None
    pdf_state = None

    # First continuously read intermediate output
    while thread.is_alive():
        try:
            text = output_queue.get(timeout=0.1)
            process_output += text
            # No final ABC yet, files not yet converted
            yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)
        except queue.Empty:
            continue

    # After thread ends, get all remaining items from the queue
    while not output_queue.empty():
        text = output_queue.get()
        process_output += text

    if error_container:
        # Inference itself failed: report that directly instead of trying to
        # convert an empty score and showing a misleading conversion error.
        yield process_output, f"Error during generation: {str(error_container[0])}", None, None, None, gr.update(value=None, visible=False)
        return

    # Final inference result
    final_result = result_container[0] if result_container else ""

    # Display file conversion prompt
    final_output_abc = "Converting files..."
    yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=None, visible=False)

    # Convert files
    try:
        file_paths = convert_files(final_result, period, composer, instrumentation)
        final_output_abc = final_result

        # Get the first image and mp3 file
        if file_paths['pages'] > 0:
            pdf_image = f"{file_paths['base']}_page_1.png"
            audio_file = file_paths['mp3']
            pdf_state = file_paths  # Directly use the converted information dictionary as state

        # Prepare download file list: only files that actually exist on disk,
        # in a fixed abc / xml / pdf / mid / mp3 order.
        download_list = [
            file_paths[kind]
            for kind in ('abc', 'xml', 'pdf', 'mid', 'mp3')
            if kind in file_paths and os.path.exists(file_paths[kind])
        ]
    except Exception as e:
        # If conversion fails, return error message to output box
        yield process_output, f"Error converting files: {str(e)}", None, None, None, gr.update(value=None, visible=False)
        return

    # Final yield with all information - modify here to make component visible
    yield process_output, final_output_abc, pdf_image, audio_file, pdf_state, gr.update(value=download_list, visible=True)


def get_file(file_type, period, composer, instrumentation):
    """Return the most recently modified ``*.<file_type>`` file in the CWD.

    Simplified lookup for Gradio downloads: ``period``, ``composer`` and
    ``instrumentation`` are accepted but not used to select the file.

    Returns:
        The file name, or ``None`` if no file with that extension exists.
    """
    # Here you actually need to return based on specific file paths saved earlier, simplified for demo
    # If matching by timestamp, you can store all converted files in a directory and get the latest
    # This is just an example:
    possible_files = [f for f in os.listdir('.') if f.endswith(f'.{file_type}')]
    if not possible_files:
        return None
    # Simply return the latest
    possible_files.sort(key=os.path.getmtime)
    return possible_files[-1]


css = """
/* Compact button style */
button[size="sm"] {
    padding: 4px 8px !important;
    margin: 2px !important;
    min-width: 60px;
}

/* PDF preview area */
#pdf-preview {
    border-radius: 8px;  /* Rounded corners */
    box-shadow: 0 2px 8px rgba(0,0,0,0.1);  /* Shadow */
}

.page-btn {
    padding: 12px !important;  /* Increase clickable area */
    margin: auto !important;  /* Vertical center */
}

/* Button hover effect */
.page-btn:hover {
    background: #f0f0f0 !important;
    transform: scale(1.05);
}

/* Layout adjustment */
.gr-row {
    gap: 10px !important;  /* Element spacing */
}

/* Audio player */
.audio-panel {
    margin-top: 15px !important;
    max-width: 400px;
}

#audio-preview audio {
    height: 200px !important;
}

/* Save functionality area */
.save-as-row {
    margin-top: 15px;
    padding: 10px;
    border-top: 1px solid #eee;
}

/* Download files styling */
.download-files {
    margin-top: 15px;
    border-radius: 8px;
    box-shadow: 0 2px 8px rgba(0,0,0,0.1);
}

/* Social icons styling */
.title-container {
    display: flex;
    align-items: center;
    gap: 15px;
    margin-bottom: 10px;
}

.title-text {
    margin: 0;
    font-size: 1.8em;
}

.social-icons {
    display: flex;
    gap: 10px;
}

.social-icon {
    display: inline-flex;
    align-items: center;
    justify-content: center;
    width: 32px;
    height: 32px;
    border-radius: 50%;
    background-color: #f5f5f5;
    text-decoration: none;
    transition: transform 0.2s, background-color 0.2s;
}

.social-icon:hover {
    transform: scale(1.1);
    background-color: #e0e0e0;
}

.social-icon img {
    width: 20px;
    height: 20px;
}
"""

# NOTE(review): the container nesting below was reconstructed from a source
# whose indentation had been lost; in particular the placement of the
# "Download Files" column under the main row is an assumption - confirm
# against the intended layout.
with gr.Blocks(css=css) as demo:
    gr.HTML(title_html)

    # For storing PDF page count, current page and other information
    pdf_state = gr.State()

    with gr.Column():
        with gr.Row():
            # Left sidebar
            with gr.Column():
                with gr.Row():
                    period_dd = gr.Dropdown(
                        choices=periods,
                        value=None,
                        label="Period",
                        interactive=True
                    )
                    composer_dd = gr.Dropdown(
                        choices=[],
                        value=None,
                        label="Composer",
                        interactive=False
                    )
                    instrument_dd = gr.Dropdown(
                        choices=[],
                        value=None,
                        label="Instrumentation",
                        interactive=False
                    )

                generate_btn = gr.Button("Generate!", variant="primary")

                process_output = gr.Textbox(
                    label="Generation process",
                    interactive=False,
                    lines=2,
                    max_lines=2,
                    placeholder="Generation progress will be shown here..."
                )

                final_output = gr.Textbox(
                    label="Post-processed ABC notation scores",
                    interactive=True,
                    lines=8,
                    max_lines=8,
                    placeholder="Post-processed ABC scores will be shown here..."
                )

                # Audio playback
                audio_player = gr.Audio(
                    label="Audio Preview",
                    format="mp3",
                    interactive=False,
                )

            # Right sidebar
            with gr.Column():
                # Image container
                pdf_image = gr.Image(
                    label="Sheet Music Preview",
                    show_label=False,
                    height=650,
                    type="filepath",
                    elem_id="pdf-preview",
                    interactive=False,
                    show_download_button=False
                )

                # Page navigation buttons
                with gr.Row():
                    prev_btn = gr.Button(
                        "⬅️ Last Page",
                        variant="secondary",
                        size="sm",
                        elem_classes="page-btn"
                    )
                    next_btn = gr.Button(
                        "Next Page ➡️",
                        variant="secondary",
                        size="sm",
                        elem_classes="page-btn"
                    )

        with gr.Column():
            gr.Markdown("**Download Files:**")
            download_files = gr.Files(
                label="Generated Files",
                visible=False,
                elem_classes="download-files",
                type="filepath"  # Make sure this is set to filepath
            )

    # Dropdown linking
    period_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )
    composer_dd.change(
        update_components,
        inputs=[period_dd, composer_dd],
        outputs=[composer_dd, instrument_dd]
    )

    # Click generate button, note outputs must match each yield in generate_music
    generate_btn.click(
        generate_music,
        inputs=[period_dd, composer_dd, instrument_dd],
        outputs=[process_output, final_output, pdf_image, audio_player, pdf_state, download_files]
    )

    # Page navigation: hidden textboxes carry the direction into update_page
    prev_signal = gr.Textbox(value="prev", visible=False)
    next_signal = gr.Textbox(value="next", visible=False)

    prev_btn.click(
        update_page,
        inputs=[prev_signal, pdf_state],  # ✅ Use component
        outputs=[pdf_image, prev_btn, next_btn, pdf_state]
    )

    next_btn.click(
        update_page,
        inputs=[next_signal, pdf_state],  # ✅ Use component
        outputs=[pdf_image, prev_btn, next_btn, pdf_state]
    )

if __name__ == "__main__":
    # Configure GPU/CPU handling
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860
    )