nyasukun commited on
Commit
11559c2
·
1 Parent(s): 18aa313
Files changed (1) hide show
  1. app.py +42 -45
app.py CHANGED
@@ -64,14 +64,53 @@ def generate_text_local(model_path, prompt, max_new_tokens=512, temperature=0.7,
64
  logger.error(f"Error in text generation with {model_path}: {str(e)}")
65
  return f"Error: {str(e)}"
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  # Build Gradio app
68
  def create_demo():
69
  with gr.Blocks() as demo:
70
- gr.Markdown("# AI Model Comparison Tool 🌟")
71
  gr.Markdown(
72
  """
73
- Compare responses from two AI models side-by-side.
74
- Select two models, ask a question, and compare their responses in real time!
75
  """
76
  )
77
 
@@ -110,48 +149,6 @@ CVE-2015-10011 is a vulnerability about OpenDNS OpenResolve improper log output
110
  response_box1 = gr.Textbox(label="Response from Model 1", interactive=False)
111
  response_box2 = gr.Textbox(label="Response from Model 2", interactive=False)
112
 
113
- # Function to generate responses
114
- def generate_responses(
115
- prompt, max_tokens, temperature, top_p, selected_models
116
- ):
117
- if len(selected_models) != 2:
118
- return "Error: Please select exactly two models to compare.", ""
119
-
120
- if len(selected_models) == 0:
121
- return "Error: Please select at least one model", ""
122
-
123
- # 選択されたモデルの結果を格納する辞書
124
- responses = {}
125
- futures_to_model = {} # 各futureとモデルを紐づけるための辞書
126
-
127
- with ThreadPoolExecutor(max_workers=len(selected_models)) as executor:
128
- # 各モデルに対してタスクを提出
129
- futures = []
130
- for model_name in selected_models:
131
- model_path = model_options[model_name]
132
- future = executor.submit(
133
- generate_text_local,
134
- model_path,
135
- prompt,
136
- max_new_tokens = max_new_tokens,
137
- do_sample = True,
138
- temperature = temperature,
139
- top_p = top_p
140
- )
141
- futures.append(future)
142
- futures_to_model[future] = model_name
143
-
144
- # 結果の収集
145
- for future in as_completed(futures):
146
- model_name = futures_to_model[future]
147
- responses[model_name] = future.result()
148
-
149
- # モデル名を冒頭に付加して返す
150
- model1_output = f"{selected_models[0]} Output:\n\n{responses.get(selected_models[0], '')}"
151
- model2_output = f"{selected_models[1]} Output:\n\n{responses.get(selected_models[1], '')}"
152
-
153
- return model1_output, model2_output
154
-
155
  # Add a button for generating responses
156
  submit_button = gr.Button("Generate Responses")
157
  submit_button.click(
 
64
  logger.error(f"Error in text generation with {model_path}: {str(e)}")
65
  return f"Error: {str(e)}"
66
 
67
+ # Move the generate_responses function outside of create_demo
68
+ def generate_responses(prompt, max_tokens, temperature, top_p, selected_models):
69
+ if len(selected_models) != 2:
70
+ return "Error: Please select exactly two models to compare.", ""
71
+
72
+ if len(selected_models) == 0:
73
+ return "Error: Please select at least one model", ""
74
+
75
+ # 選択されたモデルの結果を格納する辞書
76
+ responses = {}
77
+ futures_to_model = {} # 各futureとモデルを紐づけるための辞書
78
+
79
+ with ThreadPoolExecutor(max_workers=len(selected_models)) as executor:
80
+ # 各モデルに対してタスクを提出
81
+ futures = []
82
+ for model_name in selected_models:
83
+ model_path = model_options[model_name]
84
+ future = executor.submit(
85
+ generate_text_local,
86
+ model_path,
87
+ prompt,
88
+ max_new_tokens=max_tokens, # Fixed parameter name to match the function
89
+ temperature=temperature,
90
+ top_p=top_p
91
+ )
92
+ futures.append(future)
93
+ futures_to_model[future] = model_name
94
+
95
+ # 結果の収集
96
+ for future in as_completed(futures):
97
+ model_name = futures_to_model[future]
98
+ responses[model_name] = future.result()
99
+
100
+ # モデル名を冒頭に付加して返す
101
+ model1_output = f"{selected_models[0]} Output:\n\n{responses.get(selected_models[0], '')}"
102
+ model2_output = f"{selected_models[1]} Output:\n\n{responses.get(selected_models[1], '')}"
103
+
104
+ return model1_output, model2_output
105
+
106
  # Build Gradio app
107
  def create_demo():
108
  with gr.Blocks() as demo:
109
+ gr.Markdown("# AI Model Comparison Tool for Security Analysis 🔒")
110
  gr.Markdown(
111
  """
112
+ Compare how different AI models analyze security vulnerabilities side-by-side.
113
+ Select two models, input security-related text, and see how each model processes vulnerability information!
114
  """
115
  )
116
 
 
149
  response_box1 = gr.Textbox(label="Response from Model 1", interactive=False)
150
  response_box2 = gr.Textbox(label="Response from Model 2", interactive=False)
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  # Add a button for generating responses
153
  submit_button = gr.Button("Generate Responses")
154
  submit_button.click(