File size: 6,688 Bytes
0a8cafa
297f353
4368d98
74f37a5
 
4368d98
7841db2
74f37a5
 
 
 
 
 
7841db2
18eee30
 
 
 
10dfcb1
 
18eee30
4368d98
10dfcb1
 
18eee30
 
74f37a5
18eee30
 
 
 
 
 
 
 
 
74f37a5
18eee30
 
 
7841db2
18eee30
 
 
 
 
 
 
47eb8cc
 
 
 
18eee30
 
 
 
 
 
 
 
 
fc804b6
18eee30
 
 
74f37a5
18eee30
74f37a5
0a8cafa
11559c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fc9e70
7841db2
f3d87e2
11559c2
10dfcb1
 
11559c2
 
10dfcb1
f3d87e2
10dfcb1
ae21d92
10dfcb1
4368d98
b827456
 
 
 
 
 
 
 
 
 
 
4368d98
10dfcb1
 
 
03aaedd
10dfcb1
 
 
 
 
4b78c6c
 
 
74f37a5
03aaedd
4b78c6c
10dfcb1
ae21d92
4b78c6c
4368d98
10dfcb1
5c30376
 
10dfcb1
 
18aa313
4368d98
10dfcb1
 
f3d87e2
0a8cafa
 
7841db2
5c30376
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import gradio as gr
import spaces
from transformers import pipeline, AutoTokenizer
import torch
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

# Root logging setup: INFO level, each record prefixed with time, logger name
# and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Caches keyed by Hugging Face model path (the values of ``model_options``).
tokenizers, pipelines = {}, {}

# Display name -> Hugging Face model path for every model offered in the UI.
# Add entries here to make more models selectable.
model_options = dict(
    [
        ("Foundation-Sec-8B", "fdtn-ai/Foundation-Sec-8B"),
        ("Llama-3.1-8B", "meta-llama/Llama-3.1-8B"),
    ]
)

# Load every model listed in ``model_options`` once at startup so that requests
# reuse the same pipelines. Loading is best-effort: a model that fails is logged
# with its traceback and skipped, and ``generate_text_local`` later reports it as
# "not initialized" instead of the whole app failing to start.
for model_path in model_options.values():
    try:
        logger.info("Initializing text generation model: %s", model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        # NOTE(review): trust_remote_code=True executes Python code shipped in
        # the model repository; keep it only for repositories you trust.
        text_pipeline = pipeline(
            "text-generation",
            model=model_path,
            tokenizer=tokenizer,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
    except Exception as e:
        # Keep starting up with the remaining models, but record the full
        # traceback rather than just the message.
        logger.exception("Error initializing model %s: %s", model_path, e)
        continue
    # Register both objects only once the pipeline exists, so ``tokenizers``
    # and ``pipelines`` always list the same set of models.
    tokenizers[model_path] = tokenizer
    pipelines[model_path] = text_pipeline
    logger.info("Model initialized successfully: %s", model_path)

@spaces.GPU
def generate_text_local(model_path, prompt, max_new_tokens=512, temperature=0.7, top_p=0.95):
    """Generate a sampled continuation of ``prompt`` with a preloaded model.

    Args:
        model_path: Hugging Face model path; must be a key of ``pipelines``.
        prompt: Text for the model to continue.
        max_new_tokens: Upper bound on generated tokens. It is coerced to
            ``int`` in case the UI slider delivers a float.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        Only the newly generated text (the prompt is not included), stripped
        of surrounding whitespace. If the model was not loaded, or generation
        raises, a string starting with ``"Error: "`` is returned instead;
        errors are returned rather than raised so the UI can display them.
    """
    model_pipeline = pipelines.get(model_path)
    if model_pipeline is None:
        return f"Error: Model {model_path} not initialized"

    try:
        # Log where the weights live, to confirm the GPU is actually used.
        device_info = next(model_pipeline.model.parameters()).device
        logger.info("Running text generation with %s on device: %s", model_path, device_info)

        outputs = model_pipeline(
            prompt,
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            clean_up_tokenization_spaces=True,
            # Let the pipeline return only the new text. Removing the prompt
            # with str.replace fails whenever decoding does not reproduce the
            # prompt byte-for-byte. It also deletes any repetition of the
            # prompt inside the generated text.
            return_full_text=False,
        )
        return outputs[0]["generated_text"].strip()
    except Exception as e:
        logger.exception("Error in text generation with %s: %s", model_path, e)
        return f"Error: {str(e)}"

def generate_responses(prompt, max_tokens, temperature, top_p, selected_models):
    """Run ``prompt`` through the two selected models concurrently.

    Args:
        prompt: Input text, sent unchanged to each model.
        max_tokens: Maximum number of new tokens per model.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        selected_models: Display names (keys of ``model_options``). Exactly
            two are required; ``None`` is treated as an empty selection.

    Returns:
        A ``(model1_output, model2_output)`` tuple of strings. Each element is
        prefixed with its model's display name, in the order the models were
        selected. If the selection is not exactly two models, the first
        element is an error message and the second is empty.
    """
    # Tolerate None (e.g. a cleared checkbox group) as "nothing selected".
    selected_models = list(selected_models or [])
    # A single check covers the empty case too: an empty selection also
    # gets the "exactly two" message.
    if len(selected_models) != 2:
        return "Error: Please select exactly two models to compare.", ""

    responses = {}  # display name -> generated text (or error string)
    with ThreadPoolExecutor(max_workers=len(selected_models)) as executor:
        # Submit one generation per model; map each future to its model name.
        future_to_model = {
            executor.submit(
                generate_text_local,
                model_options[model_name],
                prompt,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            ): model_name
            for model_name in selected_models
        }
        # Collect results as they finish; selection order is restored below.
        for future in as_completed(future_to_model):
            model_name = future_to_model[future]
            try:
                responses[model_name] = future.result()
            except Exception as e:
                # generate_text_local returns its own errors as strings. Any
                # exception that still escapes the future is reported in that
                # model's box instead of discarding the other model's output.
                logger.exception("Generation failed for %s", model_name)
                responses[model_name] = f"Error: {e}"

    # Prefix each output with its model's display name.
    return tuple(
        f"{name} Output:\n\n{responses.get(name, '')}" for name in selected_models
    )

# Build Gradio app
def create_demo():
    """Build the side-by-side model comparison UI.

    The layout contains:
      * a title and short description;
      * a prompt textbox pre-filled with a few-shot "CVE ... The CWE is" example;
      * sliders for max new tokens, temperature and top-p;
      * a checkbox group over every entry in ``model_options``;
      * two read-only response boxes;
      * a button that wires all of the above to ``generate_responses``.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# AI Model Comparison Tool for Security Analysis 🔒")
        gr.Markdown(
            """
            Compare how different AI models analyze security vulnerabilities side-by-side.  
            Select two models, input security-related text, and see how each model processes vulnerability information!
            """
        )

        # Input Section
        with gr.Row():
            prompt = gr.Textbox(
                value="""CVE-2021-44228 is a remote code execution flaw in Apache Log4j2 via unsafe JNDI lookups ("Log4Shell"). The CWE is CWE-502.

CVE-2017-0144 is a remote code execution vulnerability in Microsoft's SMBv1 server ("EternalBlue") due to a buffer overflow. The CWE is CWE-119.

CVE-2014-0160 is an information-disclosure bug in OpenSSL's heartbeat extension ("Heartbleed") causing out-of-bounds reads. The CWE is CWE-125.

CVE-2017-5638 is a remote code execution issue in Apache Struts 2's Jakarta Multipart parser stemming from improper input validation of the Content-Type header. The CWE is CWE-20.

CVE-2019-0708 is a remote code execution vulnerability in Microsoft's Remote Desktop Services ("BlueKeep") triggered by a use-after-free. The CWE is CWE-416.

CVE-2015-10011 is a vulnerability about OpenDNS OpenResolve improper log output neutralization. The CWE is""",
                label="Prompt"
            )

        # Generation parameters, passed through to generate_responses.
        with gr.Row():
            # NOTE(review): the default of 3 tokens is presumably meant to be
            # just enough to finish the trailing "The CWE is" with a CWE id;
            # raise it for free-form prompts. TODO confirm.
            max_new_tokens = gr.Slider(minimum=1, maximum=2048, value=3, step=1, label="Max new tokens")
            temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"
            )

        # Model Selection Section: generate_responses requires exactly two.
        selected_models = gr.CheckboxGroup(
            choices=list(model_options.keys()),
            label="Select exactly two models to compare",
            value=list(model_options.keys())[:2],  # Default: first two models
        )

        # Response Section: one read-only box per selected model, in selection order.
        response_box1 = gr.Textbox(label="Response from Model 1", interactive=False)
        response_box2 = gr.Textbox(label="Response from Model 2", interactive=False)

        # Button that runs both models and fills the two response boxes.
        submit_button = gr.Button("Generate Responses")
        submit_button.click(
            generate_responses,
            inputs=[prompt, max_new_tokens, temperature, top_p, selected_models],
            outputs=[response_box1, response_box2],
        )

    return demo

if __name__ == "__main__":
    # Script entry point: build the UI (still bound to ``demo``) and serve it.
    (demo := create_demo()).launch()