mgyigit commited on
Commit
969a6ef
·
verified ·
1 Parent(s): fa6133d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +234 -185
app.py CHANGED
@@ -13,50 +13,50 @@ import time
13
 
14
  class DrugGENConfig:
15
  # Inference configuration
16
- submodel='DrugGEN'
17
- inference_model="/home/user/app/experiments/models/DrugGEN/"
18
- sample_num=100
19
 
20
  # Data configuration
21
- inf_smiles='/home/user/app/data/chembl_test.smi'
22
- train_smiles='/home/user/app/data/chembl_train.smi'
23
- inf_batch_size=1
24
- mol_data_dir='/home/user/app/data'
25
- features=False
26
 
27
  # Model configuration
28
- act='relu'
29
- max_atom=45
30
- dim=128
31
- depth=1
32
- heads=8
33
- mlp_ratio=3
34
- dropout=0.
35
 
36
  # Seed configuration
37
- set_seed=True
38
- seed=10
39
 
40
- disable_correction=False
41
 
42
 
43
  class DrugGENAKT1Config(DrugGENConfig):
44
- submodel='DrugGEN'
45
- inference_model="/home/user/app/experiments/models/DrugGEN-akt1/"
46
- train_drug_smiles='/home/user/app/data/akt_train.smi'
47
- max_atom=45
48
 
49
 
50
  class DrugGENCDK2Config(DrugGENConfig):
51
- submodel='DrugGEN'
52
- inference_model="/home/user/app/experiments/models/DrugGEN-cdk2/"
53
- train_drug_smiles='/home/user/app//data/cdk2_train.smi'
54
- max_atom=38
55
 
56
 
57
  class NoTargetConfig(DrugGENConfig):
58
- submodel="NoTarget"
59
- inference_model="/home/user/app/experiments/models/NoTarget/"
60
 
61
 
62
  model_configs = {
@@ -66,60 +66,60 @@ model_configs = {
66
  }
67
 
68
 
69
-
70
- def function(model_name: str, num_molecules: int, seed_num: int):
71
- '''
72
- Returns:
73
- image, metrics_df, file_path, basic_metrics, advanced_metrics
74
- '''
75
- if model_name == "DrugGEN-NoTarget":
76
- model_name = "NoTarget"
77
 
 
 
 
78
  config = model_configs[model_name]
79
- config.sample_num = num_molecules
80
-
81
- if config.sample_num > 250:
82
- raise gr.Error("You have requested to generate more than the allowed limit of 250 molecules. Please reduce your request to 250 or fewer.")
83
 
84
- if seed_num is None or seed_num.strip() == "":
 
 
 
 
 
 
 
 
 
 
 
 
85
  config.seed = random.randint(0, 10000)
86
  else:
87
- try:
88
- config.seed = int(seed_num)
89
- except ValueError:
90
- raise gr.Error("The seed must be an integer value!")
91
-
92
- if model_name != "NoTarget":
93
- model_name = "DrugGEN"
 
 
 
 
 
 
 
 
 
 
94
 
95
  inferer = Inference(config)
96
  start_time = time.time()
97
  scores = inferer.inference() # This returns a DataFrame with specific columns
98
  et = time.time() - start_time
99
 
100
- score_df = pd.DataFrame({
101
- "Runtime (seconds)": [et],
102
- "Validity": [scores["validity"].iloc[0]],
103
- "Uniqueness": [scores["uniqueness"].iloc[0]],
104
- "Novelty (Train)": [scores["novelty"].iloc[0]],
105
- "Novelty (Test)": [scores["novelty_test"].iloc[0]],
106
- "Drug Novelty": [scores["drug_novelty"].iloc[0]],
107
- "Max Length": [scores["max_len"].iloc[0]],
108
- "Mean Atom Type": [scores["mean_atom_type"].iloc[0]],
109
- "SNN ChEMBL": [scores["snn_chembl"].iloc[0]],
110
- "SNN Drug": [scores["snn_drug"].iloc[0]],
111
- "Internal Diversity": [scores["IntDiv"].iloc[0]],
112
- "QED": [scores["qed"].iloc[0]],
113
- "SA Score": [scores["sa"].iloc[0]]
114
- })
115
-
116
  # Create basic metrics dataframe
117
  basic_metrics = pd.DataFrame({
118
  "Validity": [scores["validity"].iloc[0]],
119
  "Uniqueness": [scores["uniqueness"].iloc[0]],
120
  "Novelty (Train)": [scores["novelty"].iloc[0]],
121
- "Novelty (Test)": [scores["novelty_test"].iloc[0]],
122
- "Drug Novelty": [scores["drug_novelty"].iloc[0]],
123
  "Runtime (s)": [round(et, 2)]
124
  })
125
 
@@ -129,13 +129,13 @@ def function(model_name: str, num_molecules: int, seed_num: int):
129
  "SA Score": [scores["sa"].iloc[0]],
130
  "Internal Diversity": [scores["IntDiv"].iloc[0]],
131
  "SNN ChEMBL": [scores["snn_chembl"].iloc[0]],
132
- "SNN Drug": [scores["snn_drug"].iloc[0]],
133
- "Max Length": [scores["max_len"].iloc[0]]
134
  })
135
 
136
- output_file_path = f'/home/user/app/experiments/inference/{model_name}/inference_drugs.txt'
137
-
138
- new_path = f'{model_name}_denovo_mols.smi'
139
  os.rename(output_file_path, new_path)
140
 
141
  with open(new_path) as f:
@@ -143,13 +143,14 @@ def function(model_name: str, num_molecules: int, seed_num: int):
143
 
144
  generated_molecule_list = inference_drugs.split("\n")[:-1]
145
 
 
146
  rng = random.Random(config.seed)
147
- if num_molecules > 12:
148
- selected_molecules = rng.choices(generated_molecule_list, k=12)
149
  else:
150
- selected_molecules = generated_molecule_list
151
-
152
- selected_molecules = [Chem.MolFromSmiles(mol) for mol in selected_molecules if Chem.MolFromSmiles(mol) is not None]
153
 
154
  drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
155
  drawOptions.prepareMolsBeforeDrawing = False
@@ -160,7 +161,6 @@ def function(model_name: str, num_molecules: int, seed_num: int):
160
  molsPerRow=3,
161
  subImgSize=(400, 400),
162
  maxMols=len(selected_molecules),
163
- # legends=None,
164
  returnPNG=False,
165
  drawOptions=drawOptions,
166
  highlightAtomLists=None,
@@ -170,7 +170,6 @@ def function(model_name: str, num_molecules: int, seed_num: int):
170
  return molecule_image, new_path, basic_metrics, advanced_metrics
171
 
172
 
173
-
174
  with gr.Blocks(theme=gr.themes.Ocean()) as demo:
175
  # Add custom CSS for styling
176
  gr.HTML("""
@@ -186,44 +185,40 @@ with gr.Blocks(theme=gr.themes.Ocean()) as demo:
186
  </style>
187
  """)
188
 
189
- with gr.Row():
190
- with gr.Column(scale=1):
191
- gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
192
-
193
- gr.HTML("""
194
- <div style="display: flex; gap: 10px; margin-bottom: 15px;">
195
- <!-- arXiv badge -->
196
- <a href="https://arxiv.org/abs/2302.07868" target="_blank" style="text-decoration: none;">
197
- <div style="
198
- display: inline-block;
199
- background-color: #b31b1b;
200
- color: #ffffff !important; /* Force white text */
201
- padding: 5px 10px;
202
- border-radius: 5px;
203
- font-size: 14px;"
204
- >
205
- <span style="font-weight: bold;">arXiv</span> 2302.07868
206
- </div>
207
- </a>
208
-
209
- <!-- GitHub badge -->
210
- <a href="https://github.com/HUBioDataLab/DrugGEN" target="_blank" style="text-decoration: none;">
211
- <div style="
212
- display: inline-block;
213
- background-color: #24292e;
214
- color: #ffffff !important; /* Force white text */
215
- padding: 5px 10px;
216
- border-radius: 5px;
217
- font-size: 14px;"
218
- >
219
- <span style="font-weight: bold;">GitHub</span> Repository
220
- </div>
221
- </a>
222
  </div>
223
- """)
224
-
225
- with gr.Accordion("About DrugGEN Models", open=False):
226
- gr.Markdown("""
 
 
227
  ## Model Variations
228
 
229
  ### DrugGEN-AKT1
@@ -233,104 +228,158 @@ This model is designed to generate molecules targeting the human AKT1 protein (U
233
  This model is designed to generate molecules targeting the human CDK2 protein (UniProt ID: P24941).
234
 
235
  ### DrugGEN-NoTarget
236
- This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein. It's useful for:
237
- - Exploring chemical space
238
- - Generating diverse scaffolds
239
- - Creating molecules with drug-like properties
240
 
241
  For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
242
- """)
243
-
244
- with gr.Accordion("Understanding the Metrics", open=False):
245
- gr.Markdown("""
246
  ## Evaluation Metrics
247
 
248
  ### Basic Metrics
249
  - **Validity**: Percentage of generated molecules that are chemically valid
250
  - **Uniqueness**: Percentage of unique molecules among valid ones
251
- - **Runtime**: Time taken to generate the requested molecules
252
 
253
  ### Novelty Metrics
254
  - **Novelty (Train)**: Percentage of molecules not found in the training set
255
- - **Novelty (Test)**: Percentage of molecules not found in the test set
256
- - **Drug Novelty**: Percentage of molecules not found in known inhibitors of the target protein
257
 
258
  ### Structural Metrics
259
- - **Max Length**: Maximum component length in the generated molecules
260
  - **Mean Atom Type**: Average distribution of atom types
261
  - **Internal Diversity**: Diversity within the generated set (higher is more diverse)
262
 
263
  ### Drug-likeness Metrics
264
  - **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
265
- - **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is easier)
266
 
267
  ### Similarity Metrics
268
  - **SNN ChEMBL**: Similarity to ChEMBL molecules (higher means more similar to known drug-like compounds)
269
- - **SNN Drug**: Similarity to known drugs (higher means more similar to approved drugs)
270
- """)
271
-
272
- model_name = gr.Radio(
273
- choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
274
- value="DrugGEN-AKT1",
275
- label="Select Target Model",
276
- info="Choose which protein target or general model to use for molecule generation"
277
- )
278
-
279
- num_molecules = gr.Slider(
280
- minimum=10,
281
- maximum=250,
282
- value=100,
283
- step=10,
284
- label="Number of Molecules to Generate",
285
- info="This space runs on a CPU, which may result in slower performance. Generating 200 molecules takes approximately 6 minutes. Therefore, We set a 250-molecule cap. On a GPU, the model can generate 10,000 molecules in the same amount of time. Please check our GitHub repo for running our models on GPU."
286
- )
287
-
288
- seed_num = gr.Textbox(
289
- label="Random Seed (Optional)",
290
- value="",
291
- info="Set a specific seed for reproducible results, or leave empty for random generation"
292
- )
293
-
294
- submit_button = gr.Button(
295
- value="Generate Molecules",
296
- variant="primary",
297
- size="lg"
298
- )
299
-
300
- with gr.Column(scale=2):
301
- basic_metrics_df = gr.Dataframe(
302
- headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Test)", "Novelty (Drug)", "Runtime (s)"],
303
- elem_id="basic-metrics"
304
- )
305
-
306
- advanced_metrics_df = gr.Dataframe(
307
- headers=["QED", "SA Score", "Internal Diversity", "SNN (ChEMBL)", "SNN (Drug)", "Max Length"],
308
- elem_id="advanced-metrics"
309
- )
310
-
311
- file_download = gr.File(
312
- label="Download All Generated Molecules (SMILES format)",
313
- )
314
-
315
- image_output = gr.Image(
316
- label="Structures of Randomly Selected Generated Molecules",
317
- elem_id="molecule_display"
318
- )
319
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
  gr.Markdown("### Created by the HUBioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")
322
 
323
- submit_button.click(
324
- function,
325
- inputs=[model_name, num_molecules, seed_num],
 
326
  outputs=[
327
  image_output,
328
  file_download,
329
  basic_metrics_df,
330
  advanced_metrics_df
331
- ],
332
- api_name="inference"
333
  )
334
- #demo.queue(concurrency_count=1)
 
 
 
 
 
 
 
 
 
 
 
 
335
  demo.queue()
336
  demo.launch()
 
13
 
14
  class DrugGENConfig:
15
  # Inference configuration
16
+ submodel = 'DrugGEN'
17
+ inference_model = "/home/user/app/experiments/models/DrugGEN/"
18
+ sample_num = 100
19
 
20
  # Data configuration
21
+ inf_smiles = '/home/user/app/data/chembl_test.smi'
22
+ train_smiles = '/home/user/app/data/chembl_train.smi'
23
+ inf_batch_size = 1
24
+ mol_data_dir = '/home/user/app/data'
25
+ features = False
26
 
27
  # Model configuration
28
+ act = 'relu'
29
+ max_atom = 45
30
+ dim = 128
31
+ depth = 1
32
+ heads = 8
33
+ mlp_ratio = 3
34
+ dropout = 0.
35
 
36
  # Seed configuration
37
+ set_seed = True
38
+ seed = 10
39
 
40
+ disable_correction = False
41
 
42
 
43
  class DrugGENAKT1Config(DrugGENConfig):
44
+ submodel = 'DrugGEN'
45
+ inference_model = "/home/user/app/experiments/models/DrugGEN-akt1/"
46
+ train_drug_smiles = '/home/user/app/data/akt_train.smi'
47
+ max_atom = 45
48
 
49
 
50
  class DrugGENCDK2Config(DrugGENConfig):
51
+ submodel = 'DrugGEN'
52
+ inference_model = "/home/user/app/experiments/models/DrugGEN-cdk2/"
53
+ train_drug_smiles = '/home/user/app/data/cdk2_train.smi'
54
+ max_atom = 38
55
 
56
 
57
  class NoTargetConfig(DrugGENConfig):
58
+ submodel = "NoTarget"
59
+ inference_model = "/home/user/app/experiments/models/NoTarget/"
60
 
61
 
62
  model_configs = {
 
66
  }
67
 
68
 
69
+ def run_inference(mode: str, model_name: str, num_molecules: int, seed_num: str, custom_smiles: str):
70
+ """
71
+ Depending on the selected mode, either generate new molecules or evaluate provided SMILES.
 
 
 
 
 
72
 
73
+ Returns:
74
+ image, file_path, basic_metrics, advanced_metrics
75
+ """
76
  config = model_configs[model_name]
 
 
 
 
77
 
78
+ if mode == "Custom Input SMILES":
79
+ # Process the custom input SMILES
80
+ smiles_list = [s.strip() for s in custom_smiles.strip().splitlines() if s.strip() != ""]
81
+ if len(smiles_list) > 100:
82
+ raise gr.Error("You have provided more than the allowed limit of 100 molecules. Please provide 100 or fewer.")
83
+ # Write the custom SMILES to a temporary file and update config
84
+ temp_input_file = "custom_input.smi"
85
+ with open(temp_input_file, "w") as f:
86
+ for s in smiles_list:
87
+ f.write(s + "\n")
88
+ config.inf_smiles = temp_input_file
89
+ config.sample_num = len(smiles_list)
90
+ # Always use a random seed for custom mode
91
  config.seed = random.randint(0, 10000)
92
  else:
93
+ # Classical Generation mode
94
+ config.sample_num = num_molecules
95
+ if config.sample_num > 250:
96
+ raise gr.Error("You have requested to generate more than the allowed limit of 250 molecules. Please reduce your request to 250 or fewer.")
97
+ if seed_num is None or seed_num.strip() == "":
98
+ config.seed = random.randint(0, 10000)
99
+ else:
100
+ try:
101
+ config.seed = int(seed_num)
102
+ except ValueError:
103
+ raise gr.Error("The seed must be an integer value!")
104
+
105
+ # Adjust model name for the inference if not using NoTarget
106
+ if model_name != "DrugGEN-NoTarget":
107
+ target_model_name = "DrugGEN"
108
+ else:
109
+ target_model_name = "NoTarget"
110
 
111
  inferer = Inference(config)
112
  start_time = time.time()
113
  scores = inferer.inference() # This returns a DataFrame with specific columns
114
  et = time.time() - start_time
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  # Create basic metrics dataframe
117
  basic_metrics = pd.DataFrame({
118
  "Validity": [scores["validity"].iloc[0]],
119
  "Uniqueness": [scores["uniqueness"].iloc[0]],
120
  "Novelty (Train)": [scores["novelty"].iloc[0]],
121
+ "Novelty (Inference)": [scores["novelty_test"].iloc[0]],
122
+ "Novelty (Real Inhibitors)": [scores["drug_novelty"].iloc[0]],
123
  "Runtime (s)": [round(et, 2)]
124
  })
125
 
 
129
  "SA Score": [scores["sa"].iloc[0]],
130
  "Internal Diversity": [scores["IntDiv"].iloc[0]],
131
  "SNN ChEMBL": [scores["snn_chembl"].iloc[0]],
132
+ "SNN Real Inhibitors": [scores["snn_drug"].iloc[0]],
133
+ "Average Length": [scores["max_len"].iloc[0]]
134
  })
135
 
136
+ # Process the output file from inference
137
+ output_file_path = f'/home/user/app/experiments/inference/{target_model_name}/inference_drugs.txt'
138
+ new_path = f'{target_model_name}_denovo_mols.smi'
139
  os.rename(output_file_path, new_path)
140
 
141
  with open(new_path) as f:
 
143
 
144
  generated_molecule_list = inference_drugs.split("\n")[:-1]
145
 
146
+ # Randomly select up to 12 molecules for display
147
  rng = random.Random(config.seed)
148
+ if len(generated_molecule_list) > 12:
149
+ selected_smiles = rng.choices(generated_molecule_list, k=12)
150
  else:
151
+ selected_smiles = generated_molecule_list
152
+
153
+ selected_molecules = [Chem.MolFromSmiles(mol) for mol in selected_smiles if Chem.MolFromSmiles(mol) is not None]
154
 
155
  drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
156
  drawOptions.prepareMolsBeforeDrawing = False
 
161
  molsPerRow=3,
162
  subImgSize=(400, 400),
163
  maxMols=len(selected_molecules),
 
164
  returnPNG=False,
165
  drawOptions=drawOptions,
166
  highlightAtomLists=None,
 
170
  return molecule_image, new_path, basic_metrics, advanced_metrics
171
 
172
 
 
173
  with gr.Blocks(theme=gr.themes.Ocean()) as demo:
174
  # Add custom CSS for styling
175
  gr.HTML("""
 
185
  </style>
186
  """)
187
 
188
+ gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
189
+
190
+ gr.HTML("""
191
+ <div style="display: flex; gap: 10px; margin-bottom: 15px;">
192
+ <!-- arXiv badge -->
193
+ <a href="https://arxiv.org/abs/2302.07868" target="_blank" style="text-decoration: none;">
194
+ <div style="
195
+ display: inline-block;
196
+ background-color: #b31b1b;
197
+ color: #ffffff !important;
198
+ padding: 5px 10px;
199
+ border-radius: 5px;
200
+ font-size: 14px;">
201
+ <span style="font-weight: bold;">arXiv</span> 2302.07868
202
+ </div>
203
+ </a>
204
+
205
+ <!-- GitHub badge -->
206
+ <a href="https://github.com/HUBioDataLab/DrugGEN" target="_blank" style="text-decoration: none;">
207
+ <div style="
208
+ display: inline-block;
209
+ background-color: #24292e;
210
+ color: #ffffff !important;
211
+ padding: 5px 10px;
212
+ border-radius: 5px;
213
+ font-size: 14px;">
214
+ <span style="font-weight: bold;">GitHub</span> Repository
 
 
 
 
 
 
215
  </div>
216
+ </a>
217
+ </div>
218
+ """)
219
+
220
+ with gr.Accordion("About DrugGEN Models", open=False):
221
+ gr.Markdown("""
222
  ## Model Variations
223
 
224
  ### DrugGEN-AKT1
 
228
  This model is designed to generate molecules targeting the human CDK2 protein (UniProt ID: P24941).
229
 
230
  ### DrugGEN-NoTarget
231
+ This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein.
232
+ - Useful for exploring chemical space, generating diverse scaffolds, and creating molecules with drug-like properties.
 
 
233
 
234
  For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
235
+ """)
236
+
237
+ with gr.Accordion("Understanding the Metrics", open=False):
238
+ gr.Markdown("""
239
  ## Evaluation Metrics
240
 
241
  ### Basic Metrics
242
  - **Validity**: Percentage of generated molecules that are chemically valid
243
  - **Uniqueness**: Percentage of unique molecules among valid ones
244
+ - **Runtime**: Time taken to generate or evaluate the molecules
245
 
246
  ### Novelty Metrics
247
  - **Novelty (Train)**: Percentage of molecules not found in the training set
248
+ - **Novelty (Inference)**: Percentage of molecules not found in the test set
249
+ - **Novelty (Real Inhibitors)**: Percentage of molecules not found in known inhibitors of the target protein
250
 
251
  ### Structural Metrics
252
+ - **Average Length**: Average component length in the generated molecules
253
  - **Mean Atom Type**: Average distribution of atom types
254
  - **Internal Diversity**: Diversity within the generated set (higher is more diverse)
255
 
256
  ### Drug-likeness Metrics
257
  - **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
258
+ - **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is better)
259
 
260
  ### Similarity Metrics
261
  - **SNN ChEMBL**: Similarity to ChEMBL molecules (higher means more similar to known drug-like compounds)
262
+ - **SNN Real Inhibitors**: Similarity to known drugs (higher means more similar to approved drugs)
263
+ """)
264
+
265
+ # Use Gradio Tabs to separate the two modes.
266
+ with gr.Tabs():
267
+ with gr.TabItem("Classical Generation"):
268
+ with gr.Row():
269
+ with gr.Column(scale=1):
270
+ model_name = gr.Radio(
271
+ choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
272
+ value="DrugGEN-AKT1",
273
+ label="Select Target Model",
274
+ info="Choose which protein target or general model to use for molecule generation"
275
+ )
276
+
277
+ num_molecules = gr.Slider(
278
+ minimum=10,
279
+ maximum=250,
280
+ value=100,
281
+ step=10,
282
+ label="Number of Molecules to Generate",
283
+ info="This space runs on a CPU, which may result in slower performance. Generating 200 molecules takes approximately 6 minutes. Therefore, we set a 250-molecule cap."
284
+ )
285
+
286
+ seed_num = gr.Textbox(
287
+ label="Random Seed (Optional)",
288
+ value="",
289
+ info="Set a specific seed for reproducible results, or leave empty for random generation"
290
+ )
291
+
292
+ classical_submit = gr.Button(
293
+ value="Generate Molecules",
294
+ variant="primary",
295
+ size="lg"
296
+ )
297
+ with gr.Column(scale=2):
298
+ basic_metrics_df = gr.Dataframe(
299
+ headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Inference)", "Novelty (Real Inhibitors)", "Runtime (s)"],
300
+ elem_id="basic-metrics"
301
+ )
302
+
303
+ advanced_metrics_df = gr.Dataframe(
304
+ headers=["QED", "SA Score", "Internal Diversity", "SNN (ChEMBL)", "SNN (Real Inhibitors)", "Average Length"],
305
+ elem_id="advanced-metrics"
306
+ )
307
+
308
+ file_download = gr.File(
309
+ label="Download All Generated Molecules (SMILES format)"
310
+ )
311
+
312
+ image_output = gr.Image(
313
+ label="Structures of Randomly Selected Generated Molecules",
314
+ elem_id="molecule_display"
315
+ )
316
+
317
+ with gr.TabItem("Custom Input SMILES"):
318
+ with gr.Row():
319
+ with gr.Column(scale=1):
320
+ # Reuse model selection for custom input
321
+ model_name_custom = gr.Radio(
322
+ choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
323
+ value="DrugGEN-AKT1",
324
+ label="Select Target Model",
325
+ info="Choose which protein target or general model to use for evaluation"
326
+ )
327
+ custom_smiles = gr.Textbox(
328
+ label="Input SMILES (one per line, maximum 100 molecules)",
329
+ placeholder="C(C(=O)O)N\nCCO\n...",
330
+ lines=10
331
+ )
332
+ custom_submit = gr.Button(
333
+ value="Evaluate Custom SMILES",
334
+ variant="primary",
335
+ size="lg"
336
+ )
337
+ with gr.Column(scale=2):
338
+ basic_metrics_df_custom = gr.Dataframe(
339
+ headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Inference)", "Novelty (Real Inhibitors)", "Runtime (s)"],
340
+ elem_id="basic-metrics-custom"
341
+ )
342
+
343
+ advanced_metrics_df_custom = gr.Dataframe(
344
+ headers=["QED", "SA Score", "Internal Diversity", "SNN (ChEMBL)", "SNN (Real Inhibitors)", "Average Length"],
345
+ elem_id="advanced-metrics-custom"
346
+ )
347
+
348
+ file_download_custom = gr.File(
349
+ label="Download All Molecules (SMILES format)"
350
+ )
351
+
352
+ image_output_custom = gr.Image(
353
+ label="Structures of Randomly Selected Molecules",
354
+ elem_id="molecule_display_custom"
355
+ )
356
 
357
  gr.Markdown("### Created by the HUBioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")
358
 
359
+ # Set up the click actions for each tab.
360
+ classical_submit.click(
361
+ run_inference,
362
+ inputs=[gr.Variable("Generate Molecules"), model_name, num_molecules, seed_num, gr.Textbox.update(value="")],
363
  outputs=[
364
  image_output,
365
  file_download,
366
  basic_metrics_df,
367
  advanced_metrics_df
368
+ ],
369
+ api_name="inference_classical"
370
  )
371
+
372
+ custom_submit.click(
373
+ run_inference,
374
+ inputs=[gr.Variable("Custom Input SMILES"), model_name_custom, 0, gr.Textbox.update(value=""), custom_smiles],
375
+ outputs=[
376
+ image_output_custom,
377
+ file_download_custom,
378
+ basic_metrics_df_custom,
379
+ advanced_metrics_df_custom
380
+ ],
381
+ api_name="inference_custom"
382
+ )
383
+
384
  demo.queue()
385
  demo.launch()