Update app.py
app.py
CHANGED
@@ -1,183 +1,184 @@
import torch
import random
import pandas as pd
from utils import create_vocab, setup_seed
from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
import gradio as gr
from gradio_rangeslider import RangeSlider  # currently unused; τ is a plain gr.Slider below
import time  # currently unused

is_stopped = False

seed = random.randint(0, 100000)  # fresh seed per launch, recorded in the output CSV
setup_seed(seed)

device = torch.device("cpu")
vocab_mlm = create_vocab()
vocab_mlm = add_tokens_to_vocab(vocab_mlm)
save_path = 'mlm-model-27.pt'
train_seqs = pd.read_csv('C0_seq.csv')
train_seq = train_seqs['Seq'].tolist()
model = torch.load(save_path, map_location=torch.device('cpu'))
model = model.to(device)

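# Temperature sampling: scale the logits by 1/τ, softmax, then draw one token
# id. τ > 1 flattens the distribution (more diverse tokens); τ < 1 sharpens it.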
def temperature_sampling(logits, temperature):
    logits = logits / temperature
    probabilities = torch.softmax(logits, dim=-1)
    sampled_token = torch.multinomial(probabilities, 1)
    return sampled_token

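# Illustrative sketch only (made-up logits, not part of the app): at τ = 2.0
# the gaps between logits are halved before the multinomial draw.
#   logits = torch.tensor([2.0, 1.0, 0.1])
#   token_id = temperature_sampling(logits, 2.0)  # 1-element LongTensor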
def stop_generation():
    global is_stopped
    is_stopped = True
    return "Generation stopped."

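# CTXGen drives generation. Its inputs mirror the Gradio controls defined at
# the bottom of the file: X0 = parent conotoxin sequence, X1 = target subtype
# token, X2 = potency token ('<high>'/'<low>'), τ = sampling temperature,
# g_num = number of generation attempts.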
def CTXGen(X0, X1, X2, τ, g_num):
    global is_stopped
    is_stopped = False
    X3 = "X" * len(X0)
    msa_data = pd.read_csv('conoData_C0.csv')
    msa = msa_data['Sequences'].tolist()
    msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
    msa = random.choice(msa)
    X4 = msa.split("|")[3]
    X5 = msa.split("|")[4]
    X6 = msa.split("|")[5]

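    # Inference only: eval mode plus no_grad for all scoring and sampling below.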
    model.eval()
    with torch.no_grad():
        new_seq = None
        IDs = []
        generated_seqs = []
        generated_seqs_FINAL = []
        cls_probability_all = []
        act_probability_all = []
        count = 0
        gen_num = g_num
        # Tokens that may never appear in a final peptide: non-standard letters,
        # subtype/potency tags, and special tokens.
        NON_AA = ["B","O","U","Z","X",'<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
                  '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
                  '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
                  '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>',
                  '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
                  '<α9α10>','<α6α3β4>', '<NaTTXS>', '<Na17>','<high>','<low>','[UNK]','[SEP]','[PAD]','[CLS]','[MASK]']

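        # Score the parent first: mask the subtype slot (position 1) and the
        # potency slot (position 2) in turn and record the model's probability
        # for X1 and X2. These become the acceptance thresholds for variants.
        # Note: .index(X1) / .index(X2) below raise ValueError if the parent's
        # label is not among the top-k tokens.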
        seq_parent = [f"{X1}|{X2}|{X0}|{X4}|{X5}|{X6}"]
        padded_seqseq_parent, _, idx_msaseq_parent, _ = get_paded_token_idx_gen(vocab_mlm, seq_parent, new_seq)
        idx_msaseq_parent = torch.tensor(idx_msaseq_parent).unsqueeze(0).to(device)
        seqseq_parent = ["[MASK]" if i == "X" else i for i in padded_seqseq_parent]

        seqseq_parent[1] = "[MASK]"
        input_ids_parent = vocab_mlm[seqseq_parent]
        logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)

        cls_mask_logits_parent = logits_parent[0, 1, :]
        cls_probability_parent, cls_mask_probs_parent = torch.topk(torch.softmax(cls_mask_logits_parent, dim=-1), k=10)

        seqseq_parent[2] = "[MASK]"
        input_ids_parent = vocab_mlm[seqseq_parent]
        logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)
        act_mask_logits_parent = logits_parent[0, 2, :]
        act_probability_parent, act_mask_probs_parent = torch.topk(torch.softmax(act_mask_logits_parent, dim=-1), k=2)

        cls_pos_parent = vocab_mlm.to_tokens(list(cls_mask_probs_parent))
        act_pos_parent = vocab_mlm.to_tokens(list(act_mask_probs_parent))

        cls_proba_parent = cls_probability_parent[cls_pos_parent.index(X1)].item()
        act_proba_parent = act_probability_parent[act_pos_parent.index(X2)].item()

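        # Generation loop: start from a fully masked copy of the parent (X3),
        # then fill one randomly chosen [MASK] per step via temperature
        # sampling until none remain.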
        # Ensure `out` and output.csv exist before the first yield, even if no
        # candidate has passed the filters yet.
        out = pd.DataFrame()
        out.to_csv("output.csv", index=False, encoding='utf-8-sig')
        while count < gen_num:
            if is_stopped:  # stop requested from the UI
                break
            gen_len = len(X0)
            seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
            vocab_mlm.token_to_idx["X"] = 4

            padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
            input_text = ["[MASK]" if i == "X" else i for i in padded_seq]

            gen_length = len(input_text)
            length = gen_length - sum(1 for x in input_text if x != '[MASK]')

            for i in range(length):
                _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
                idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
                idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
                attn_idx = torch.tensor(attn_idx).to(device)

                mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
                mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])

                logits = model(idx_seq, idx_msa, attn_idx)
                mask_logits = logits[0, mask_position.item(), :]

                predicted_token_id = temperature_sampling(mask_logits, τ)

                predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
                input_text[mask_position.item()] = predicted_token
                padded_seq[mask_position.item()] = predicted_token.strip()
                new_seq = padded_seq

            generated_seq = input_text

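            # Re-score the finished candidate exactly as the parent was scored:
            # mask the subtype and potency slots and read their probabilities.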
            generated_seq[1] = "[MASK]"
            input_ids = vocab_mlm[generated_seq]
            logits = model(torch.tensor([input_ids]).to(device), idx_msa)
            cls_mask_logits = logits[0, 1, :]
            cls_probability, cls_mask_probs = torch.topk(torch.softmax(cls_mask_logits, dim=-1), k=10)

            generated_seq[2] = "[MASK]"
            input_ids = vocab_mlm[generated_seq]
            logits = model(torch.tensor([input_ids]).to(device), idx_msa)
            act_mask_logits = logits[0, 2, :]
            act_probability, act_mask_probs = torch.topk(torch.softmax(act_mask_logits, dim=-1), k=2)

            cls_pos = vocab_mlm.to_tokens(list(cls_mask_probs))
            act_pos = vocab_mlm.to_tokens(list(act_mask_probs))

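            # Acceptance filters: at least parent-level probability on both
            # labels, an even cysteine count (pairable disulfides), the parent's
            # length, novelty versus training data and earlier output, and no
            # special or non-amino-acid tokens.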
            if X1 in cls_pos and X2 in act_pos:
                cls_proba = cls_probability[cls_pos.index(X1)].item()
                act_proba = act_probability[act_pos.index(X2)].item()
                generated_seq = generated_seq[generated_seq.index('[MASK]') + 2:generated_seq.index('[SEP]')]
                if cls_proba >= cls_proba_parent and act_proba >= act_proba_parent and generated_seq.count('C') % 2 == 0 and len("".join(generated_seq)) == gen_len:
                    generated_seqs.append("".join(generated_seq))
                    if "".join(generated_seq) not in train_seq and "".join(generated_seq) not in generated_seqs[0:-1] and all(x not in NON_AA for x in generated_seq):
                        generated_seqs_FINAL.append("".join(generated_seq))
                        cls_probability_all.append(cls_proba)
                        act_probability_all.append(act_proba)
                        IDs.append(count + 1)
                        out = pd.DataFrame({
                            'ID': IDs,
                            'Generated_seq': generated_seqs_FINAL,
                            'Subtype': X1,
                            'Subtype_probability': cls_probability_all,
                            'Potency': X2,
                            'Potency_probability': act_probability_all,
                            'Random_seed': seed
                        })
                        out.to_csv("output.csv", index=False, encoding='utf-8-sig')
            count += 1
            yield out, "output.csv"
    return out, "output.csv"

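# Gradio UI: input controls, start/stop buttons, a live results table, and a
# CSV download. Because CTXGen is a generator, outputs update on every yield.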
with gr.Blocks() as demo:
    gr.Markdown("# Conotoxin Optimization Generation")
    with gr.Row():
        X0 = gr.Textbox(label="conotoxin")
        X1 = gr.Dropdown(choices=['<α7>','<AChBP>','<α4β2>','<α3β4>','<Ca22>','<α3β2>', '<Na12>','<α9α10>','<K16>', '<α1β1γδ>',
                                  '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<GluN2B>', '<α75HT3>', '<Na14>',
                                  '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>', '<Ca12>', '<Na16>', '<α6α3β2>',
                                  '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>', '<Kshaker>', '<Na18>',
                                  '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>','<Na13>', '<Na15>', '<α4β4>',
                                  '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>', '<α6α3β4>', '<NaTTXS>', '<Na17>'], label="Subtype")
        X2 = gr.Dropdown(choices=['<high>','<low>'], label="Potency")
        τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
        g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
    with gr.Row():
        start_button = gr.Button("Start Generation")
        stop_button = gr.Button("Stop Generation")
    with gr.Row():
        output_df = gr.DataFrame(label="Generated Conotoxins")
    with gr.Row():
        output_file = gr.File(label="Download generated conotoxins")

    start_button.click(CTXGen, inputs=[X0, X1, X2, τ, g_num], outputs=[output_df, output_file])
    stop_button.click(stop_generation, outputs=None)

demo.launch()