# Spaces: Running on Zero
import gradio as gr | |
import spaces | |
from transformers import pipeline, AutoTokenizer | |
import torch | |
import logging | |
from concurrent.futures import ThreadPoolExecutor, as_completed | |
# Logging setup: INFO threshold; each record shows timestamp, logger name,
# severity and message.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger used by the rest of this file.
logger = logging.getLogger(__name__)
# Predefined list of models to compare (can be expanded).
# Maps the display name offered in the UI to the model path handed to
# transformers.
model_options = {
    "Foundation-Sec-8B": "fdtn-ai/Foundation-Sec-8B",
    "Llama-3.1-8B": "meta-llama/Llama-3.1-8B",
}

# Stores for loaded pipelines and tokenizers, both keyed by model path.
pipelines = {}
tokenizers = {}
# Initialize models at startup.
# Each model is loaded inside its own try/except so that a failure on one is
# logged and skipped instead of aborting the whole app; a model that failed to
# load is simply absent from `pipelines`.
for model_path in model_options.values():
    try:
        logger.info("Initializing text generation model: %s", model_path)
        tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
        pipelines[model_path] = pipeline(
            "text-generation",
            model=model_path,
            tokenizer=tokenizers[model_path],
            torch_dtype=torch.bfloat16,
            device_map="auto",
            # NOTE(review): trust_remote_code permits Python code shipped in
            # the model repository to execute locally. Keep it only for
            # repositories you trust; confirm these models actually need it.
            trust_remote_code=True,
        )
        logger.info("Model initialized successfully: %s", model_path)
    except Exception as e:
        # Deliberate best-effort boundary: record the full traceback and
        # continue with the remaining models.
        logger.exception("Error initializing model %s: %s", model_path, e)
def generate_text_local(model_path, prompt, max_new_tokens=512, temperature=0.7, top_p=0.95):
    """Generate a sampled completion for *prompt* with an already-loaded pipeline.

    Args:
        model_path: Key into the module-level ``pipelines`` dict.
        prompt: Text the model should continue.
        max_new_tokens: Forwarded to the pipeline call as ``max_new_tokens``.
        temperature: Forwarded to the pipeline call as ``temperature``.
        top_p: Forwarded to the pipeline call as ``top_p``.

    Returns:
        The generated text with the leading copy of *prompt* removed and
        surrounding whitespace stripped. If the model is not in ``pipelines``
        or generation raises, a string starting with ``"Error:"`` is returned
        instead of raising.
    """
    try:
        # Use the already initialized model.
        if model_path not in pipelines:
            return f"Error: Model {model_path} not initialized"
        model_pipeline = pipelines[model_path]

        # Log which device the model weights live on.
        device_info = next(model_pipeline.model.parameters()).device
        logger.info(
            "Running text generation with %s on device: %s", model_path, device_info
        )

        outputs = model_pipeline(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            clean_up_tokenization_spaces=True,
        )
        generated = outputs[0]["generated_text"]

        # Remove only the echoed prompt at the start of the output
        # (presumably the pipeline returns prompt + completion by default --
        # confirm against the transformers docs). A blanket str.replace would
        # also delete any verbatim repetition of the prompt that the model
        # produced inside the completion itself.
        if generated.startswith(prompt):
            generated = generated[len(prompt):]
        return generated.strip()
    except Exception as e:
        # Boundary for the UI: keep the traceback in the log and hand the
        # caller a displayable message.
        logger.exception("Error in text generation with %s: %s", model_path, e)
        return f"Error: {str(e)}"
# Gradio click handler; defined at module level rather than inside create_demo.
def generate_responses(prompt, max_tokens, temperature, top_p, selected_models):
    """Run *prompt* through the two selected models concurrently.

    Args:
        prompt: Text passed to every selected model.
        max_tokens: Passed to ``generate_text_local`` as ``max_new_tokens``.
        temperature: Passed through to ``generate_text_local``.
        top_p: Passed through to ``generate_text_local``.
        selected_models: Display names (keys of ``model_options``); exactly
            two are required.

    Returns:
        A 2-tuple of strings, one per selected model in selection order, each
        prefixed with ``"<model name> Output:"``. When the selection is
        invalid, the first element is an error message and the second is
        ``""``.
    """
    # Check the empty case first: testing "!= 2" first would make this more
    # specific message unreachable.
    if not selected_models:
        return "Error: Please select at least one model", ""
    if len(selected_models) != 2:
        return "Error: Please select exactly two models to compare.", ""

    # Results keyed by model display name.
    responses = {}
    with ThreadPoolExecutor(max_workers=len(selected_models)) as executor:
        # Submit one task per model, remembering which future belongs to
        # which model.
        futures_to_model = {
            executor.submit(
                generate_text_local,
                model_options[model_name],
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            ): model_name
            for model_name in selected_models
        }
        # Collect results as they finish, in whatever order that happens.
        for future in as_completed(futures_to_model):
            responses[futures_to_model[future]] = future.result()

    # Prefix each response with its model name, in selection order.
    first, second = selected_models
    model1_output = f"{first} Output:\n\n{responses.get(first, '')}"
    model2_output = f"{second} Output:\n\n{responses.get(second, '')}"
    return model1_output, model2_output
# Build Gradio app
def create_demo():
    """Assemble the model-comparison UI and return it without launching it.

    The UI consists of a prompt box, three sampling sliders, a model checkbox
    group, two read-only response boxes, and a button whose click calls
    ``generate_responses`` with the current widget values.

    Returns:
        The ``gr.Blocks`` instance.
    """
    # NOTE(review): the nesting of widgets inside the two rows below was
    # reconstructed from statement order -- confirm against the intended layout.
    intro_text = """
Compare how different AI models analyze security vulnerabilities side-by-side.
Select two models, input security-related text, and see how each model processes vulnerability information!
"""
    # Few-shot prompt that ends mid-sentence so the model completes the CWE id.
    default_prompt = """CVE-2021-44228 is a remote code execution flaw in Apache Log4j2 via unsafe JNDI lookups ("Log4Shell"). The CWE is CWE-502.
CVE-2017-0144 is a remote code execution vulnerability in Microsoft's SMBv1 server ("EternalBlue") due to a buffer overflow. The CWE is CWE-119.
CVE-2014-0160 is an information-disclosure bug in OpenSSL's heartbeat extension ("Heartbleed") causing out-of-bounds reads. The CWE is CWE-125.
CVE-2017-5638 is a remote code execution issue in Apache Struts 2's Jakarta Multipart parser stemming from improper input validation of the Content-Type header. The CWE is CWE-20.
CVE-2019-0708 is a remote code execution vulnerability in Microsoft's Remote Desktop Services ("BlueKeep") triggered by a use-after-free. The CWE is CWE-416.
CVE-2015-10011 is a vulnerability about OpenDNS OpenResolve improper log output neutralization. The CWE is"""
    model_names = list(model_options)

    with gr.Blocks() as demo:
        gr.Markdown("# AI Model Comparison Tool for Security Analysis 🔒")
        gr.Markdown(intro_text)

        # Input section: the prompt, then the sampling controls.
        with gr.Row():
            prompt_box = gr.Textbox(value=default_prompt, label="Prompt")
        with gr.Row():
            tokens_slider = gr.Slider(
                minimum=1, maximum=2048, value=3, step=1, label="Max new tokens"
            )
            temperature_slider = gr.Slider(
                minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            )

        # Model selection: the first two configured models are pre-checked.
        model_picker = gr.CheckboxGroup(
            choices=model_names,
            label="Select exactly two model to compare",
            value=model_names[:2],
        )

        # Read-only output boxes, one per compared model.
        first_response = gr.Textbox(label="Response from Model 1", interactive=False)
        second_response = gr.Textbox(label="Response from Model 2", interactive=False)

        # Button that triggers generation and fills both output boxes.
        generate_button = gr.Button("Generate Responses")
        generate_button.click(
            generate_responses,
            inputs=[
                prompt_box,
                tokens_slider,
                temperature_slider,
                top_p_slider,
                model_picker,
            ],
            outputs=[first_response, second_response],
        )
    return demo
# Script entry point: build the UI and launch it only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    demo = create_demo()
    demo.launch()