import gradio as gr
import spaces
from transformers import pipeline, AutoTokenizer
import torch
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

# Configure logging/logger
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Caches of loaded tokenizers and text-generation pipelines, keyed by the
# Hugging Face model path (the values of ``model_options``).
tokenizers = {}
pipelines = {}

# Predefined list of models to compare (can be expanded).
# Keys are the display names shown in the UI; values are Hugging Face model paths.
model_options = {
    "Foundation-Sec-8B": "fdtn-ai/Foundation-Sec-8B",
    "Llama-3.1-8B": "meta-llama/Llama-3.1-8B",
}


def _load_model(model_path):
    """Load the tokenizer and text-generation pipeline for ``model_path`` into the caches.

    Failures are logged with a traceback and then swallowed, so one broken model
    does not stop the app from starting. ``generate_text_local`` reports a model
    that failed to load when a request comes in.
    """
    try:
        logger.info("Initializing text generation model: %s", model_path)
        tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
        pipelines[model_path] = pipeline(
            "text-generation",
            model=model_path,
            tokenizer=tokenizers[model_path],
            torch_dtype=torch.bfloat16,
            device_map="auto",
            # NOTE(review): this executes Python code shipped in the model repo.
            # Keep it enabled only for trusted repositories.
            trust_remote_code=True,
        )
        logger.info("Model initialized successfully: %s", model_path)
    except Exception as e:
        logger.exception("Error initializing model %s: %s", model_path, e)


# Initialize models at startup
for _model_path in model_options.values():
    _load_model(_model_path)


@spaces.GPU
def generate_text_local(model_path, prompt, max_new_tokens=512, temperature=0.7, top_p=0.95):
    """Generate a continuation of ``prompt`` with a preloaded local pipeline.

    Args:
        model_path: Hugging Face model path; must have been loaded into ``pipelines``.
        prompt: Input text to continue.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (sampling is always enabled).
        top_p: Nucleus-sampling probability mass.

    Returns:
        The newly generated text (without the prompt), stripped of surrounding
        whitespace. On failure this returns an ``"Error: ..."`` string instead of
        raising, so the UI can show it.
    """
    model_pipeline = pipelines.get(model_path)
    if model_pipeline is None:
        return f"Error: Model {model_path} not initialized"

    try:
        # Log which device holds the model weights (useful with device_map="auto").
        device_info = next(model_pipeline.model.parameters()).device
        logger.info("Running text generation with %s on device: %s", model_path, device_info)

        outputs = model_pipeline(
            prompt,
            # The UI slider may deliver a float; generate() expects an int.
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            clean_up_tokenization_spaces=True,
            # Return only the new text. The old str.replace(prompt, "") also
            # deleted any later repeat of the prompt inside the output, and it
            # did nothing when decoding did not reproduce the prompt exactly.
            return_full_text=False,
        )
        return outputs[0]["generated_text"].strip()
    except Exception as e:
        logger.exception("Error in text generation with %s: %s", model_path, e)
        return f"Error: {str(e)}"
def generate_responses(prompt, max_tokens, temperature, top_p, selected_models):
    """Run ``prompt`` through the two selected models concurrently.

    Args:
        prompt: Text sent unchanged to both models.
        max_tokens: Maximum number of new tokens per model.
        temperature: Sampling temperature forwarded to each model.
        top_p: Nucleus-sampling probability mass forwarded to each model.
        selected_models: Display names (keys of ``model_options``); exactly two
            are required.

    Returns:
        A pair of strings, one per selected model in selection order, each
        prefixed with "<model name> Output:". If the selection is invalid, the
        first string holds the error message and the second is empty.
    """
    # A single "exactly two" check also covers the empty/None selection. The
    # previous separate zero-length check came after this one and could never
    # be reached.
    if not selected_models or len(selected_models) != 2:
        return "Error: Please select exactly two models to compare.", ""

    model_paths = [model_options[name] for name in selected_models]

    def _run(model_path):
        # Thin adapter so executor.map can pass one argument per task.
        return generate_text_local(
            model_path,
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        )

    # Query the models in parallel. executor.map yields results in input order,
    # so they line up with selected_models without extra bookkeeping.
    with ThreadPoolExecutor(max_workers=len(model_paths)) as executor:
        results = list(executor.map(_run, model_paths))

    # Prefix each response with the model's display name.
    first_name, second_name = selected_models
    model1_output = f"{first_name} Output:\n\n{results[0]}"
    model2_output = f"{second_name} Output:\n\n{results[1]}"
    return model1_output, model2_output


# Few-shot example prompt: five CVE -> CWE pairs followed by an unfinished line
# for the models to complete.
_DEFAULT_PROMPT = """CVE-2021-44228 is a remote code execution flaw in Apache Log4j2 via unsafe JNDI lookups ("Log4Shell"). The CWE is CWE-502.
CVE-2017-0144 is a remote code execution vulnerability in Microsoft's SMBv1 server ("EternalBlue") due to a buffer overflow. The CWE is CWE-119.
CVE-2014-0160 is an information-disclosure bug in OpenSSL's heartbeat extension ("Heartbleed") causing out-of-bounds reads. The CWE is CWE-125.
CVE-2017-5638 is a remote code execution issue in Apache Struts 2's Jakarta Multipart parser stemming from improper input validation of the Content-Type header. The CWE is CWE-20.
CVE-2019-0708 is a remote code execution vulnerability in Microsoft's Remote Desktop Services ("BlueKeep") triggered by a use-after-free. The CWE is CWE-416.
CVE-2015-10011 is a vulnerability about OpenDNS OpenResolve improper log output neutralization. The CWE is"""


# Build Gradio app
def create_demo():
    """Build the Gradio Blocks UI for side-by-side model comparison.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown("# AI Model Comparison Tool for Security Analysis 🔒")
        gr.Markdown(
            """
            Compare how different AI models analyze security vulnerabilities side-by-side.
            Select two models, input security-related text, and see how each model processes vulnerability information!
            """
        )

        # Input Section
        with gr.Row():
            prompt = gr.Textbox(value=_DEFAULT_PROMPT, label="Prompt")

        with gr.Row():
            # The default of 3 keeps generation short. The sample prompt
            # presumably needs only a brief CWE identifier as its completion.
            max_new_tokens = gr.Slider(minimum=1, maximum=2048, value=3, step=1, label="Max new tokens")
            temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
            )

        # Model Selection Section
        selected_models = gr.CheckboxGroup(
            choices=list(model_options.keys()),
            label="Select exactly two models to compare",
            value=list(model_options.keys())[:2],  # Default models
        )

        # Response Section: one output box per selected model.
        response_box1 = gr.Textbox(label="Response from Model 1", interactive=False)
        response_box2 = gr.Textbox(label="Response from Model 2", interactive=False)

        # Button that triggers generation for both models.
        submit_button = gr.Button("Generate Responses")
        submit_button.click(
            generate_responses,
            inputs=[prompt, max_new_tokens, temperature, top_p, selected_models],
            outputs=[response_box1, response_box2],  # Link to response boxes
        )

    return demo


if __name__ == "__main__":
    demo = create_demo()
    demo.launch()