import gradio as gr
import spaces
from transformers import pipeline, AutoTokenizer
import torch
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Caches for tokenizers and text-generation pipelines
tokenizers = {}
pipelines = {}
# Predefined list of models to compare (can be expanded)
model_options = {
"Foundation-Sec-8B": "fdtn-ai/Foundation-Sec-8B",
"Llama-3.1-8B": "meta-llama/Llama-3.1-8B",
}
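# Note (assumption): meta-llama/Llama-3.1-8B is a gated model on the Hugging Face Hub;
# the Space typically needs granted access and an HF token available to download it.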
# Initialize models at startup
for model_name, model_path in model_options.items():
try:
logger.info(f"Initializing text generation model: {model_path}")
tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
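        # Load weights in bfloat16 and let device_map="auto" (via accelerate) place them on the available device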
pipelines[model_path] = pipeline(
"text-generation",
model=model_path,
tokenizer=tokenizers[model_path],
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True
)
logger.info(f"Model initialized successfully: {model_path}")
except Exception as e:
logger.error(f"Error initializing model {model_path}: {str(e)}")
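# On ZeroGPU Spaces, @spaces.GPU requests a GPU allocation for the duration of each decorated call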
@spaces.GPU
def generate_text_local(model_path, prompt, max_new_tokens=512, temperature=0.7, top_p=0.95):
"""Local text generation"""
try:
# Use the already initialized model
if model_path in pipelines:
model_pipeline = pipelines[model_path]
# Log GPU usage information
device_info = next(model_pipeline.model.parameters()).device
logger.info(f"Running text generation with {model_path} on device: {device_info}")
outputs = model_pipeline(
prompt,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
top_p=top_p,
clean_up_tokenization_spaces=True,
)
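            # The pipeline output echoes the prompt; strip it so only the newly generated text is returned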
return outputs[0]["generated_text"].replace(prompt, "").strip()
else:
return f"Error: Model {model_path} not initialized"
except Exception as e:
logger.error(f"Error in text generation with {model_path}: {str(e)}")
return f"Error: {str(e)}"
# Defined at module level (outside create_demo) so the click handler can reference it directly
def generate_responses(prompt, max_tokens, temperature, top_p, selected_models):
    if len(selected_models) != 2:
        return "Error: Please select exactly two models to compare.", ""
    # Dictionary to store the result from each selected model
responses = {}
    futures_to_model = {}  # Maps each future back to its model name
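    # Submit both generations to a thread pool so the two models run concurrently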
with ThreadPoolExecutor(max_workers=len(selected_models)) as executor:
        # Submit a generation task for each selected model
futures = []
for model_name in selected_models:
model_path = model_options[model_name]
future = executor.submit(
generate_text_local,
model_path,
prompt,
                max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p
)
futures.append(future)
futures_to_model[future] = model_name
        # Collect results as they complete
for future in as_completed(futures):
model_name = futures_to_model[future]
responses[model_name] = future.result()
    # Prefix each response with its model name before returning
model1_output = f"{selected_models[0]} Output:\n\n{responses.get(selected_models[0], '')}"
model2_output = f"{selected_models[1]} Output:\n\n{responses.get(selected_models[1], '')}"
return model1_output, model2_output
# Build Gradio app
def create_demo():
with gr.Blocks() as demo:
gr.Markdown("# AI Model Comparison Tool for Security Analysis 🔒")
gr.Markdown(
"""
Compare how different AI models analyze security vulnerabilities side-by-side.
Select two models, input security-related text, and see how each model processes vulnerability information!
"""
)
# Input Section
with gr.Row():
prompt = gr.Textbox(
value="""CVE-2021-44228 is a remote code execution flaw in Apache Log4j2 via unsafe JNDI lookups ("Log4Shell"). The CWE is CWE-502.
CVE-2017-0144 is a remote code execution vulnerability in Microsoft's SMBv1 server ("EternalBlue") due to a buffer overflow. The CWE is CWE-119.
CVE-2014-0160 is an information-disclosure bug in OpenSSL's heartbeat extension ("Heartbleed") causing out-of-bounds reads. The CWE is CWE-125.
CVE-2017-5638 is a remote code execution issue in Apache Struts 2's Jakarta Multipart parser stemming from improper input validation of the Content-Type header. The CWE is CWE-20.
CVE-2019-0708 is a remote code execution vulnerability in Microsoft's Remote Desktop Services ("BlueKeep") triggered by a use-after-free. The CWE is CWE-416.
CVE-2015-10011 is a vulnerability about OpenDNS OpenResolve improper log output neutralization. The CWE is""",
label="Prompt"
)
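        # The default prompt is a few-shot CVE -> CWE completion, so only a short continuation
        # (a CWE identifier) is expected; this is why "Max new tokens" defaults to a small value below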
with gr.Row():
max_new_tokens = gr.Slider(minimum=1, maximum=2048, value=3, step=1, label="Max new tokens")
temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
top_p = gr.Slider(
minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
)
# Model Selection Section
selected_models = gr.CheckboxGroup(
choices=list(model_options.keys()),
label="Select exactly two model to compare",
value=list(model_options.keys())[:2], # Default models
)
# Dynamic Response Section
response_box1 = gr.Textbox(label="Response from Model 1", interactive=False)
response_box2 = gr.Textbox(label="Response from Model 2", interactive=False)
# Add a button for generating responses
submit_button = gr.Button("Generate Responses")
submit_button.click(
generate_responses,
inputs=[prompt, max_new_tokens, temperature, top_p, selected_models],
outputs=[response_box1, response_box2], # Link to response boxes
)
return demo
if __name__ == "__main__":
demo = create_demo()
demo.launch() |