nyasukun's picture
logging device
47eb8cc
import gradio as gr
import spaces
from transformers import pipeline, AutoTokenizer
import torch
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
# Configure the root logger once (INFO and above, timestamped records) and
# create this module's named logger.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
# Loaded objects, keyed by Hugging Face repo id (the values of model_options).
# Both start empty and are populated by the startup initialization code.
tokenizers, pipelines = {}, {}

# Display name shown in the UI -> Hugging Face Hub repo id.
# Add entries here to offer more models for comparison.
model_options = dict(
    [
        ("Foundation-Sec-8B", "fdtn-ai/Foundation-Sec-8B"),
        ("Llama-3.1-8B", "meta-llama/Llama-3.1-8B"),
    ]
)
# Initialize models at startup.
def _initialize_models():
    """Load a tokenizer and a text-generation pipeline for every configured model.

    Iterates the repo ids in ``model_options`` and stores successfully loaded
    objects in the module-level ``tokenizers`` and ``pipelines`` dicts, keyed
    by repo id. A model that fails to load is logged with its traceback and
    skipped, so the app still starts with whichever models did load. Requests
    for a skipped model are answered with a "not initialized" error.
    """
    for model_path in model_options.values():
        logger.info("Initializing text generation model: %s", model_path)
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model_pipeline = pipeline(
                "text-generation",
                model=model_path,
                tokenizer=tokenizer,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                # NOTE(review): trust_remote_code=True lets the Hub repo run
                # arbitrary Python at load time. If neither configured repo
                # ships custom modeling code, consider dropping this flag.
                trust_remote_code=True,
            )
        except Exception as e:
            # logger.exception keeps the traceback, which the bare message lost.
            logger.exception("Error initializing model %s: %s", model_path, e)
            continue
        # Register both objects only after both loaded, so the two dicts
        # never disagree about which models are available.
        tokenizers[model_path] = tokenizer
        pipelines[model_path] = model_pipeline
        logger.info("Model initialized successfully: %s", model_path)


_initialize_models()
@spaces.GPU
def generate_text_local(model_path, prompt, max_new_tokens=512, temperature=0.7, top_p=0.95):
    """Generate a sampled continuation of *prompt* with a pre-loaded pipeline.

    Args:
        model_path: Hugging Face repo id; must be a key of ``pipelines``.
        prompt: Text for the model to continue.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature (sampling is always enabled).
        top_p: Nucleus-sampling probability mass.

    Returns:
        The generated continuation with the prompt excluded and surrounding
        whitespace stripped. On failure, returns a string starting with
        ``"Error: "`` rather than raising, so the UI can display it: either the
        model was never loaded, or generation raised.
    """
    model_pipeline = pipelines.get(model_path)
    if model_pipeline is None:
        return f"Error: Model {model_path} not initialized"

    try:
        # Log where the model's (first) parameters live, to confirm GPU use.
        device_info = next(model_pipeline.model.parameters()).device
        logger.info("Running text generation with %s on device: %s", model_path, device_info)
        outputs = model_pipeline(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            clean_up_tokenization_spaces=True,
            # Ask for the continuation only. Removing the prompt afterwards
            # with str.replace would also delete any later repeat of the
            # prompt text inside the generated output.
            return_full_text=False,
        )
        return outputs[0]["generated_text"].strip()
    except Exception as e:
        logger.exception("Error in text generation with %s: %s", model_path, e)
        return f"Error: {e}"
# Module-level (not nested in create_demo) so it can serve as the Gradio
# callback and be called directly.
def generate_responses(prompt, max_tokens, temperature, top_p, selected_models):
    """Run *prompt* through exactly two selected models concurrently.

    Args:
        prompt: Text sent unchanged to every selected model.
        max_tokens: Maximum number of new tokens per model.
        temperature: Sampling temperature passed to each model.
        top_p: Nucleus-sampling probability mass passed to each model.
        selected_models: Display names (keys of ``model_options``). A ``None``
            value is treated as an empty selection.

    Returns:
        A 2-tuple of strings, one per selected model in selection order, each
        formatted as ``"<name> Output:\\n\\n<response>"``. If the selection does
        not contain exactly two models, returns ``(error_message, "")``.

    Raises:
        KeyError: if a selected name is not a key of ``model_options``.
    """
    selected_models = list(selected_models or [])

    # The UI has exactly two output boxes. This check also covers the empty
    # selection, which needs no separate branch.
    if len(selected_models) != 2:
        return "Error: Please select exactly two models to compare.", ""

    # The slider value may arrive as a float; generation expects an integer.
    max_new_tokens = int(max_tokens)

    # Submit every distinct model at once so the generations can overlap.
    # dict.fromkeys de-duplicates while keeping order, so the same model is
    # never run twice.
    with ThreadPoolExecutor(max_workers=len(selected_models)) as executor:
        futures = {
            name: executor.submit(
                generate_text_local,
                model_options[name],
                prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
            )
            for name in dict.fromkeys(selected_models)
        }
        # Results are keyed by model name, so completion order is irrelevant.
        responses = {name: future.result() for name, future in futures.items()}

    first, second = selected_models
    # Prefix each response with its model name so the two boxes are labelled.
    return (
        f"{first} Output:\n\n{responses[first]}",
        f"{second} Output:\n\n{responses[second]}",
    )
# Few-shot default prompt: known CVE -> CWE pairs, ending with an open
# "The CWE is" for the models to complete.
DEFAULT_PROMPT = """CVE-2021-44228 is a remote code execution flaw in Apache Log4j2 via unsafe JNDI lookups ("Log4Shell"). The CWE is CWE-502.
CVE-2017-0144 is a remote code execution vulnerability in Microsoft's SMBv1 server ("EternalBlue") due to a buffer overflow. The CWE is CWE-119.
CVE-2014-0160 is an information-disclosure bug in OpenSSL's heartbeat extension ("Heartbleed") causing out-of-bounds reads. The CWE is CWE-125.
CVE-2017-5638 is a remote code execution issue in Apache Struts 2's Jakarta Multipart parser stemming from improper input validation of the Content-Type header. The CWE is CWE-20.
CVE-2019-0708 is a remote code execution vulnerability in Microsoft's Remote Desktop Services ("BlueKeep") triggered by a use-after-free. The CWE is CWE-416.
CVE-2015-10011 is a vulnerability about OpenDNS OpenResolve improper log output neutralization. The CWE is"""


# Build Gradio app
def create_demo():
    """Build the Gradio Blocks UI for side-by-side model comparison.

    The layout, from top to bottom:
    - a title and description;
    - the prompt box;
    - generation-parameter sliders;
    - a model picker;
    - two read-only response boxes;
    - a button that calls ``generate_responses``.

    Returns:
        The constructed ``gr.Blocks`` app. It is not launched here.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# AI Model Comparison Tool for Security Analysis 🔒")
        gr.Markdown(
            "\n"
            "Compare how different AI models analyze security vulnerabilities side-by-side.\n"
            "Select two models, input security-related text, and see how each model "
            "processes vulnerability information!\n"
        )

        # Input section: the prompt, then the generation parameters.
        with gr.Row():
            prompt = gr.Textbox(value=DEFAULT_PROMPT, label="Prompt")
        with gr.Row():
            # The low default presumably targets the default prompt, which only
            # needs a short CWE id completed; raise it for free-form prompts.
            max_new_tokens = gr.Slider(minimum=1, maximum=2048, value=3, step=1, label="Max new tokens")
            temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
            )

        # Model selection: pre-selects the first two configured models.
        selected_models = gr.CheckboxGroup(
            choices=list(model_options.keys()),
            label="Select exactly two models to compare",
            value=list(model_options.keys())[:2],
        )

        # Output section: one read-only box per selected model.
        response_box1 = gr.Textbox(label="Response from Model 1", interactive=False)
        response_box2 = gr.Textbox(label="Response from Model 2", interactive=False)

        submit_button = gr.Button("Generate Responses")
        submit_button.click(
            generate_responses,
            inputs=[prompt, max_new_tokens, temperature, top_p, selected_models],
            # generate_responses returns one string per box, in this order.
            outputs=[response_box1, response_box2],
        )
    return demo
# Script entry point: build the UI and launch the Gradio app.
# `demo` is deliberately bound at module level rather than chained inline.
if __name__ == "__main__":
    demo = create_demo()
    demo.launch()