import gradio as gr
from utils.watermark import Watermarker
from utils.config import load_config
from renderers.highlighter import highlight_common_words, highlight_common_words_dict, reparaphrased_sentences_html
from renderers.tree import generate_subplot1, generate_subplot2
from pathlib import Path
import time
from typing import Dict, List, Tuple, Any
import plotly.graph_objects as go


class WatermarkerInterface:
    """Gradio-facing wrapper around the Watermarker pipeline.

    Each ``handle_*`` method runs one pipeline stage and converts its output
    into the HTML fragments / Plotly figures the UI components expect.
    State shared between stages (the indexed common n-grams and their
    highlight colors) is cached on the instance after Step 1.
    """

    def __init__(self, config):
        self.pipeline = Watermarker(config)
        # (index, ngram) pairs for the original prompt; filled by Step 1.
        self.common_grams = {}
        # (ngram, color) pairs used by the tree renderers; filled by Step 1.
        self.highlight_info = []
        self.masked_sentences = []

    def handle_paraphrase(self, prompt: str) -> Tuple[str, str, str, str]:
        """Run paraphrasing (Step 1) and return highlighted HTML views.

        Returns a 4-tuple: (highlighted prompt, highlighted accepted
        sentences, highlighted discarded sentences, execution-time message).
        """
        start_time = time.time()
        self.pipeline.Paraphrase(prompt)

        # --- Step 1: index the n-grams of the original prompt ---------------
        seen_ngrams = {}                # gram -> sequential index (shared with paraphrases)
        original_indexed_ngrams = []    # [(index, gram), ...] in first-occurrence order
        original_sentence = self.pipeline.user_prompt
        original_ngrams = self.pipeline.common_grams.get(original_sentence, {})

        # Pair each gram with the position of its first occurrence, then sort
        # so indices are assigned in reading order.
        ngram_occurrences = [
            (min(indices, key=lambda x: x[0])[0], gram)
            for gram, indices in original_ngrams.items()
        ]
        ngram_occurrences.sort()
        for idx, (_position, gram) in enumerate(ngram_occurrences, start=1):
            seen_ngrams[gram] = idx
            original_indexed_ngrams.append((idx, gram))
        print("Original Indexed N-grams:", original_indexed_ngrams)

        # Assign a cycling color per indexed n-gram for the highlighters.
        colors = ["red", "blue", "green", "purple", "orange"]
        highlight_info = [
            (ngram, colors[i % len(colors)])
            for i, (_index, ngram) in enumerate(original_indexed_ngrams)
        ]
        common_grams = original_indexed_ngrams
        # Cache for later steps (masking/sampling tree renderers).
        self.highlight_info = highlight_info
        self.common_grams = common_grams

        # --- Step 2: index paraphrased sentences, reusing original indices --
        paraphrase_indexed_ngrams = {}
        for sentence in self.pipeline.paraphrased_sentences:
            sentence_ngrams = []
            sentence_ngrams_dict = self.pipeline.common_grams.get(sentence, {})
            for gram in sentence_ngrams_dict:
                if gram in seen_ngrams:
                    # Same gram as the original prompt -> same index.
                    index = seen_ngrams[gram]
                else:
                    # New gram -> append a fresh index and remember it.
                    index = len(seen_ngrams) + 1
                    seen_ngrams[gram] = index
                sentence_ngrams.append((index, gram))
            sentence_ngrams.sort()
            paraphrase_indexed_ngrams[sentence] = sentence_ngrams
        print("Paraphrase Indexed N-grams:", paraphrase_indexed_ngrams)

        # --- Step 3: render highlighted HTML via the renderer helpers -------
        highlighted_prompt = highlight_common_words(
            common_grams,
            [self.pipeline.user_prompt],
            "Original Prompt with Highlighted Common Sequences",
        )
        highlighted_accepted = highlight_common_words_dict(
            common_grams,
            self.pipeline.selected_sentences,
            "Accepted Paraphrased Sentences with Entailment Scores",
        )
        highlighted_discarded = highlight_common_words_dict(
            common_grams,
            self.pipeline.discarded_sentences,
            "Discarded Paraphrased Sentences with Entailment Scores",
        )

        execution_time = f"\nStep 1 completed in {time.time() - start_time:.2f} seconds\n"
        return highlighted_prompt, highlighted_accepted, highlighted_discarded, execution_time

    def handle_masking(self) -> List[Any]:
        """Run masking (Step 2) and build one tree figure per sentence.

        Returns a flat list of exactly 10 Plotly figures followed by the
        execution-time message, matching the Gradio ``outputs`` wiring.
        """
        start_time = time.time()
        masking_results = self.pipeline.Masking()
        trees = []
        highlight_info = self.highlight_info
        common_grams = self.common_grams

        # First pass: group masked variants by their original sentence.
        sentence_to_masked = {}
        for strategy, sentence_dict in masking_results.items():
            for sent, data in sentence_dict.items():
                if sent not in sentence_to_masked:
                    sentence_to_masked[sent] = []
                try:
                    if not isinstance(data, dict):
                        print(f"[ERROR] Data is not a dictionary for {sent} with strategy {strategy}")
                        continue
                    masked_sentence = data.get("masked_sentence", "")
                    if masked_sentence:
                        sentence_to_masked[sent].append((masked_sentence, strategy))
                except Exception as e:
                    print(f"Error processing {strategy} for sentence {sent}: {e}")

        # Second pass: one consolidated multi-strategy figure per sentence.
        for original_sentence, masked_sentences_data in sentence_to_masked.items():
            if not masked_sentences_data:
                continue
            masked_sentences = [ms[0] for ms in masked_sentences_data]
            strategies = [ms[1] for ms in masked_sentences_data]
            try:
                fig = generate_subplot1(
                    original_sentence,
                    masked_sentences,
                    strategies,
                    highlight_info,
                    common_grams,
                )
                trees.append(fig)
            except Exception as e:
                print(f"Error generating multi-strategy tree: {e}")
                trees.append(go.Figure())

        # Pad so the 10 plot slots in the UI always receive a figure.
        while len(trees) < 10:
            trees.append(go.Figure())

        execution_time = f"\nStep 2 completed in {time.time() - start_time:.2f} seconds\n"
        return trees[:10] + [execution_time]

    def handle_sampling(self) -> List[Any]:
        """Run sampling (Step 3) and build one tree figure per sentence.

        Returns a flat list of exactly 10 Plotly figures followed by the
        execution-time message, matching the Gradio ``outputs`` wiring.
        """
        start_time = time.time()
        sampling_results = self.pipeline.Sampling()
        trees = []

        # Regroup results as: original -> masking strategy -> masked/sampled.
        organized_results = {}
        for sampling_strategy, masking_dict in sampling_results.items():
            for masking_strategy, sentences in masking_dict.items():
                for original_sentence, data in sentences.items():
                    if original_sentence not in organized_results:
                        organized_results[original_sentence] = {}
                    if masking_strategy not in organized_results[original_sentence]:
                        organized_results[original_sentence][masking_strategy] = {
                            "masked_sentence": data.get("masked_sentence", ""),
                            "sampled_sentences": {},
                        }
                    organized_results[original_sentence][masking_strategy]["sampled_sentences"][sampling_strategy] = \
                        data.get("sampled_sentence", "")

        for original_sentence, data in organized_results.items():
            masked_sentences = []
            all_sampled_sentences = []
            # Cap at 3 masking strategies per figure; list() guards against
            # mutation during iteration.
            for masking_strategy, masking_data in list(data.items())[:3]:
                masked_sentence = masking_data.get("masked_sentence", "")
                if masked_sentence:
                    masked_sentences.append(masked_sentence)
                for sampling_strategy, sampled_sentence in masking_data.get("sampled_sentences", {}).items():
                    if sampled_sentence:
                        all_sampled_sentences.append(sampled_sentence)
            if masked_sentences:
                try:
                    fig = generate_subplot2(
                        masked_sentences,
                        all_sampled_sentences,
                        self.highlight_info,
                        self.common_grams,
                    )
                    trees.append(fig)
                except Exception as e:
                    print(f"Error generating subplot for {original_sentence}: {e}")
                    trees.append(go.Figure())

        # Pad so the 10 plot slots in the UI always receive a figure.
        while len(trees) < 10:
            trees.append(go.Figure())

        execution_time = f"\nStep 3 completed in {time.time() - start_time:.2f} seconds\n"
        return trees[:10] + [execution_time]

    def handle_reparaphrasing(self) -> List[Any]:
        """Run re-paraphrasing (Step 4) and format each batch as HTML.

        Returns a flat list of exactly 120 HTML strings followed by the
        execution-time message, matching the Gradio ``outputs`` wiring.
        """
        start_time = time.time()
        results = self.pipeline.re_paraphrasing()
        html_outputs = []

        for sampling_strategy, masking_dict in results.items():
            for masking_strategy, sentences in masking_dict.items():
                for original_sent, data in sentences.items():
                    # .get() tolerates entries without re-paraphrases.
                    if data.get("re_paraphrased_sentences"):
                        html = reparaphrased_sentences_html(data["re_paraphrased_sentences"])
                        html_outputs.append(html)

        # Pad so the 120 tab slots in the UI always receive content.
        while len(html_outputs) < 120:
            html_outputs.append("")

        execution_time = f"\nStep 4 completed in {time.time() - start_time:.2f} seconds\n"
        return html_outputs[:120] + [execution_time]


def create_gradio_interface(config):
    """Build and return the Gradio Blocks app wired to the pipeline steps."""
    interface = WatermarkerInterface(config)

    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
        # CSS to enable scrolling for reparaphrased sentences and sampling plots
        demo.css = """
        /* Set fixed height for the reparaphrased tabs container only */
        .gradio-container .tabs[id="reparaphrased-tabs"],
        .gradio-container .tabs[id="sampling-tabs"] {
            overflow-x: hidden;
            white-space: normal;
            border-radius: 8px;
            max-height: 600px; /* Set fixed height for the entire tabs component */
            overflow-y: auto; /* Enable vertical scrolling inside the container */
        }
        /* Tab content styling for reparaphrased and sampling tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tabitem,
        .gradio-container .tabs[id="sampling-tabs"] .tabitem {
            overflow-x: hidden;
            white-space: normal;
            display: block;
            border-radius: 8px;
        }
        /* Make the tab navigation fixed at the top for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav {
            display: flex;
            overflow-x: auto;
            white-space: nowrap;
            scrollbar-width: thin;
            border-radius: 8px;
            scrollbar-color: #888 #f1f1f1;
            position: sticky;
            top: 0;
            background: white;
            z-index: 100;
        }
        /* Dropdown menu for scrollable tabs styling */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown {
            position: relative;
            display: inline-block;
        }
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content {
            display: none;
            position: absolute;
            background-color: #f9f9f9;
            min-width: 160px;
            box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
            z-index: 1;
            max-height: 300px;
            overflow-y: auto;
        }
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown:hover .tab-dropdown-content,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown:hover .tab-dropdown-content {
            display: block;
        }
        /* Scrollbar styling for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar {
            height: 8px;
            border-radius: 8px;
        }
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar-track,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar-track {
            background: #f1f1f1;
            border-radius: 8px;
        }
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav::-webkit-scrollbar-thumb,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav::-webkit-scrollbar-thumb {
            background: #888;
            border-radius: 8px;
        }
        /* Tab button styling for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-item,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-item {
            flex: 0 0 auto;
            border-radius: 8px;
        }
        /* Plot container styling specifically for sampling tabs */
        .gradio-container .tabs[id="sampling-tabs"] .plot-container {
            min-height: 600px;
            max-height: 1800px;
            overflow-y: auto;
        }
        /* Ensure text wraps in HTML components */
        .gradio-container .prose {
            white-space: normal;
            word-wrap: break-word;
            overflow-wrap: break-word;
        }
        /* Dropdown button styling for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown button,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown button {
            background-color: #f0f0f0;
            border: 1px solid #ddd;
            border-radius: 4px;
            padding: 5px 10px;
            cursor: pointer;
            margin: 2px;
        }
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown button:hover,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown button:hover {
            background-color: #e0e0e0;
        }
        /* Style dropdown content items for scrollable tabs */
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content div,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content div {
            padding: 8px 12px;
            cursor: pointer;
        }
        .gradio-container .tabs[id="reparaphrased-tabs"] .tab-nav .tab-dropdown-content div:hover,
        .gradio-container .tabs[id="sampling-tabs"] .tab-nav .tab-dropdown-content div:hover {
            background-color: #e0e0e0;
        }
        /* Custom styling for execution time display */
        .execution-time {
            text-align: right;
            padding: 8px 16px;
            font-family: inherit;
            color: #555;
            font-size: 0.9rem;
            font-style: italic;
            margin-left: auto;
            width: 100%;
            border-top: 1px solid #eee;
            margin-top: 8px;
        }
        /* Layout for section headers with execution time */
        .section-header {
            display: flex;
            justify-content: space-between;
            align-items: center;
            width: 100%;
            margin-bottom: 12px;
        }
        .section-header h3 {
            margin: 0;
        }
        """

        gr.Markdown("# **AIISC Watermarking Model**")

        with gr.Column():
            gr.Markdown("## Input Prompt")
            user_input = gr.Textbox(
                label="Enter Your Prompt",
                placeholder="Type your text here..."
            )

        # --- Step 1: paraphrasing ------------------------------------------
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 1: Paraphrasing, LCS and Entailment Analysis")
            with gr.Column(scale=1):
                step1_time = gr.HTML()

        paraphrase_button = gr.Button("Generate Paraphrases")
        highlighted_user_prompt = gr.HTML(label="Highlighted User Prompt")

        with gr.Tabs():
            with gr.TabItem("Accepted Paraphrased Sentences"):
                highlighted_accepted_sentences = gr.HTML()
            with gr.TabItem("Discarded Paraphrased Sentences"):
                highlighted_discarded_sentences = gr.HTML()

        # --- Step 2: masking ------------------------------------------------
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 2: Where to Mask?")
            with gr.Column(scale=1):
                step2_time = gr.HTML()

        masking_button = gr.Button("Apply Masking")

        gr.Markdown("### Masked Sentence Trees")
        tree1_plots = []
        with gr.Tabs() as tree1_tabs:
            for i in range(10):
                with gr.TabItem(f"Masked Sentence {i+1}"):
                    tree1 = gr.Plot()
                    tree1_plots.append(tree1)

        # --- Step 3: sampling -----------------------------------------------
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 3: How to Mask?")
            with gr.Column(scale=1):
                step3_time = gr.HTML()

        sampling_button = gr.Button("Sample Words")

        gr.Markdown("### Sampled Sentence Trees")
        tree2_plots = []
        # elem_id makes this tab container scrollable via the CSS above.
        with gr.Tabs(elem_id="sampling-tabs") as tree2_tabs:
            for i in range(10):
                with gr.TabItem(f"Sampled Sentence {i+1}"):
                    # Custom class on the container enables proper styling.
                    with gr.Column(elem_classes=["plot-container"]):
                        tree2 = gr.Plot()
                        tree2_plots.append(tree2)

        # --- Step 4: re-paraphrasing ----------------------------------------
        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("## Step 4: Re-paraphrasing")
            with gr.Column(scale=1):
                step4_time = gr.HTML()

        reparaphrase_button = gr.Button("Re-paraphrase")

        gr.Markdown("### Reparaphrased Sentences")
        reparaphrased_sentences_tabs = []
        with gr.Tabs(elem_id="reparaphrased-tabs") as reparaphrased_tabs:
            for i in range(120):
                with gr.TabItem(f"Reparaphrased Batch {i+1}"):
                    reparaphrased_sent_html = gr.HTML()
                    reparaphrased_sentences_tabs.append(reparaphrased_sent_html)

        # Connect the interface functions to the buttons.
        paraphrase_button.click(
            interface.handle_paraphrase,
            inputs=user_input,
            outputs=[
                highlighted_user_prompt,
                highlighted_accepted_sentences,
                highlighted_discarded_sentences,
                step1_time
            ]
        )

        masking_button.click(
            interface.handle_masking,
            inputs=None,
            outputs=tree1_plots + [step2_time]
        )

        sampling_button.click(
            interface.handle_sampling,
            inputs=None,
            outputs=tree2_plots + [step3_time]
        )

        reparaphrase_button.click(
            interface.handle_reparaphrasing,
            inputs=None,
            outputs=reparaphrased_sentences_tabs + [step4_time]
        )

    return demo


if __name__ == "__main__":
    project_root = Path(__file__).parent.parent
    config_path = project_root / "utils" / "config.yaml"
    config = load_config(config_path)['PECCAVI_TEXT']
    create_gradio_interface(config).launch()