oucgc1996 committed on
Commit 1be6080 · verified · 1 Parent(s): 44d4ac0

Update app.py

Files changed (1)
  1. app.py +183 -182
app.py CHANGED
@@ -1,183 +1,184 @@
- import torch
- import random
- import pandas as pd
- from utils import create_vocab, setup_seed
- from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
- import gradio as gr
- from gradio_rangeslider import RangeSlider
- import time
-
- is_stopped = False
-
- seed = random.randint(0,100000)
- setup_seed(seed)
-
- device = torch.device("cpu")
- vocab_mlm = create_vocab()
- vocab_mlm = add_tokens_to_vocab(vocab_mlm)
- save_path = 'mlm-model-27.pt'
- train_seqs = pd.read_csv('C0_seq.csv')
- train_seq = train_seqs['Seq'].tolist()
- model = torch.load(save_path, map_location=torch.device('cpu'))
- model = model.to(device)
-
- def temperature_sampling(logits, temperature):
-     logits = logits / temperature
-     probabilities = torch.softmax(logits, dim=-1)
-     sampled_token = torch.multinomial(probabilities, 1)
-     return sampled_token
-
- def stop_generation():
-     global is_stopped
-     is_stopped = True
-     return "Generation stopped."
-
- def CTXGen(X0, X1, X2, τ, g_num):
-     global is_stopped
-     is_stopped = False
-     X3 = "X" * len(X0)
-     msa_data = pd.read_csv('conoData_C0.csv')
-     msa = msa_data['Sequences'].tolist()
-     msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
-     msa = random.choice(msa)
-     X4 = msa.split("|")[3]
-     X5 = msa.split("|")[4]
-     X6 = msa.split("|")[5]
-
-     model.eval()
-     with torch.no_grad():
-         new_seq = None
-         IDs = []
-         generated_seqs = []
-         generated_seqs_FINAL = []
-         cls_probability_all = []
-         act_probability_all = []
-         count = 0
-         gen_num = g_num
-         NON_AA = ["B","O","U","Z","X",'<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
-                   '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
-                   '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
-                   '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>',
-                   '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
-                   '<α9α10>','<α6α3β4>', '<NaTTXS>', '<Na17>','<high>','<low>','[UNK]','[SEP]','[PAD]','[CLS]','[MASK]']
-
-         seq_parent = [f"{X1}|{X2}|{X0}|{X4}|{X5}|{X6}"]
-         padded_seqseq_parent, _, idx_msaseq_parent, _ = get_paded_token_idx_gen(vocab_mlm, seq_parent, new_seq)
-         idx_msaseq_parent = torch.tensor(idx_msaseq_parent).unsqueeze(0).to(device)
-         seqseq_parent = ["[MASK]" if i=="X" else i for i in padded_seqseq_parent]
-
-         seqseq_parent[1] = "[MASK]"
-         input_ids_parent = vocab_mlm.__getitem__(seqseq_parent)
-         logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)
-
-         cls_mask_logits_parent = logits_parent[0, 1, :]
-         cls_probability_parent, cls_mask_probs_parent = torch.topk((torch.softmax(cls_mask_logits_parent, dim=-1)), k=10)
-
-         seqseq_parent[2] = "[MASK]"
-         input_ids_parent = vocab_mlm.__getitem__(seqseq_parent)
-         logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)
-         act_mask_logits_parent = logits_parent[0, 2, :]
-         act_probability_parent, act_mask_probs_parent = torch.topk((torch.softmax(act_mask_logits_parent, dim=-1)), k=2)
-
-         cls_pos_parent = vocab_mlm.to_tokens(list(cls_mask_probs_parent))
-         act_pos_parent = vocab_mlm.to_tokens(list(act_mask_probs_parent))
-
-         cls_proba_parent = cls_probability_parent[cls_pos_parent.index(X1)].item()
-         act_proba_parent = act_probability_parent[act_pos_parent.index(X2)].item()
-
-         while count < gen_num:
-             gen_len = len(X0)
-             seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
-             vocab_mlm.token_to_idx["X"] = 4
-
-             padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
-             input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
-
-             gen_length = len(input_text)
-             length = gen_length - sum(1 for x in input_text if x != '[MASK]')
-
-             for i in range(length):
-                 _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
-                 idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
-                 idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
-                 attn_idx = torch.tensor(attn_idx).to(device)
-
-                 mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
-                 mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
-
-                 logits = model(idx_seq,idx_msa, attn_idx)
-                 mask_logits = logits[0, mask_position.item(), :]
-
-                 predicted_token_id = temperature_sampling(mask_logits, τ)
-
-                 predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
-                 input_text[mask_position.item()] = predicted_token
-                 padded_seq[mask_position.item()] = predicted_token.strip()
-                 new_seq = padded_seq
-
-             generated_seq = input_text
-
-             generated_seq[1] = "[MASK]"
-             input_ids = vocab_mlm.__getitem__(generated_seq)
-             logits = model(torch.tensor([input_ids]).to(device), idx_msa)
-             cls_mask_logits = logits[0, 1, :]
-             cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=10)
-
-             generated_seq[2] = "[MASK]"
-             input_ids = vocab_mlm.__getitem__(generated_seq)
-             logits = model(torch.tensor([input_ids]).to(device), idx_msa)
-             act_mask_logits = logits[0, 2, :]
-             act_probability, act_mask_probs = torch.topk((torch.softmax(act_mask_logits, dim=-1)), k=2)
-
-             cls_pos = vocab_mlm.to_tokens(list(cls_mask_probs))
-             act_pos = vocab_mlm.to_tokens(list(act_mask_probs))
-
-             if X1 in cls_pos and X2 in act_pos:
-                 cls_proba = cls_probability[cls_pos.index(X1)].item()
-                 act_proba = act_probability[act_pos.index(X2)].item()
-                 generated_seq = generated_seq[generated_seq.index('[MASK]') + 2:generated_seq.index('[SEP]')]
-                 if cls_proba>=cls_proba_parent and act_proba>=act_proba_parent and generated_seq.count('C') % 2 == 0 and len("".join(generated_seq)) == gen_len:
-                     generated_seqs.append("".join(generated_seq))
-                     if "".join(generated_seq) not in train_seq and "".join(generated_seq) not in generated_seqs[0:-1] and all(x not in NON_AA for x in generated_seq):
-                         generated_seqs_FINAL.append("".join(generated_seq))
-                         cls_probability_all.append(cls_proba)
-                         act_probability_all.append(act_proba)
-                         IDs.append(count+1)
-                         out = pd.DataFrame({
-                             'ID':IDs,
-                             'Generated_seq': generated_seqs_FINAL,
-                             'Subtype': X1,
-                             'Subtype_probability': cls_probability_all,
-                             'Potency': X2,
-                             'Potency_probability': act_probability_all,
-                             'Random_seed': seed
-                         })
-                         out.to_csv("output.csv", index=False, encoding='utf-8-sig')
-                         count += 1
-                         yield out, "output.csv"
-         return out, "output.csv"
-
- with gr.Blocks() as demo:
-     gr.Markdown("# Conotoxin Optimization Generation")
-     with gr.Row():
-         X1 = gr.Dropdown(choices=['<α7>','<AChBP>','<α4β2>','<α3β4>','<Ca22>','<α3β2>', '<Na12>','<α9α10>','<K16>', '<α1β1γδ>',
-                                   '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<GluN2B>', '<α75HT3>', '<Na14>',
-                                   '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>', '<Ca12>', '<Na16>', '<α6α3β2>',
-                                   '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>', '<Kshaker>', '<Na18>',
-                                   '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>','<Na13>', '<Na15>', '<α4β4>',
-                                   '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>', '<α6α3β4>', '<NaTTXS>', '<Na17>'], label="Subtype")
-         X2 = gr.Dropdown(choices=['<high>','<low>'], label="Potency")
-         τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
-         g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
-     with gr.Row():
-         start_button = gr.Button("Start Generation")
-         stop_button = gr.Button("Stop Generation")
-     with gr.Row():
-         output_df = gr.DataFrame(label="Generated Conotoxins")
-     with gr.Row():
-         output_file = gr.File(label="Download generated conotoxins")
-
-     start_button.click(CTXGen, inputs=[X1, X2, τ, g_num], outputs=[output_df, output_file])
-     stop_button.click(stop_generation, outputs=None)
-
+ import torch
+ import random
+ import pandas as pd
+ from utils import create_vocab, setup_seed
+ from dataset_mlm import get_paded_token_idx_gen, add_tokens_to_vocab
+ import gradio as gr
+ from gradio_rangeslider import RangeSlider
+ import time
+
+ is_stopped = False
+
+ seed = random.randint(0,100000)
+ setup_seed(seed)
+
+ device = torch.device("cpu")
+ vocab_mlm = create_vocab()
+ vocab_mlm = add_tokens_to_vocab(vocab_mlm)
+ save_path = 'mlm-model-27.pt'
+ train_seqs = pd.read_csv('C0_seq.csv')
+ train_seq = train_seqs['Seq'].tolist()
+ model = torch.load(save_path, map_location=torch.device('cpu'))
+ model = model.to(device)
+
+ def temperature_sampling(logits, temperature):
+     logits = logits / temperature
+     probabilities = torch.softmax(logits, dim=-1)
+     sampled_token = torch.multinomial(probabilities, 1)
+     return sampled_token
+
+ def stop_generation():
+     global is_stopped
+     is_stopped = True
+     return "Generation stopped."
+
+ def CTXGen(X0, X1, X2, τ, g_num):
+     global is_stopped
+     is_stopped = False
+     X3 = "X" * len(X0)
+     msa_data = pd.read_csv('conoData_C0.csv')
+     msa = msa_data['Sequences'].tolist()
+     msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
+     msa = random.choice(msa)
+     X4 = msa.split("|")[3]
+     X5 = msa.split("|")[4]
+     X6 = msa.split("|")[5]
+
+     model.eval()
+     with torch.no_grad():
+         new_seq = None
+         IDs = []
+         generated_seqs = []
+         generated_seqs_FINAL = []
+         cls_probability_all = []
+         act_probability_all = []
+         count = 0
+         gen_num = g_num
+         NON_AA = ["B","O","U","Z","X",'<K16>', '<α1β1γδ>', '<Ca22>', '<AChBP>', '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<α4β2>',
+                   '<GluN2B>', '<α75HT3>', '<Na14>', '<α7>', '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>',
+                   '<Ca12>', '<Na16>', '<α6α3β2>', '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>',
+                   '<Kshaker>', '<α3β4>', '<Na18>', '<α3β2>', '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>',
+                   '<Na13>', '<Na12>', '<Na15>', '<α4β4>', '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>',
+                   '<α9α10>','<α6α3β4>', '<NaTTXS>', '<Na17>','<high>','<low>','[UNK]','[SEP]','[PAD]','[CLS]','[MASK]']
+
+         seq_parent = [f"{X1}|{X2}|{X0}|{X4}|{X5}|{X6}"]
+         padded_seqseq_parent, _, idx_msaseq_parent, _ = get_paded_token_idx_gen(vocab_mlm, seq_parent, new_seq)
+         idx_msaseq_parent = torch.tensor(idx_msaseq_parent).unsqueeze(0).to(device)
+         seqseq_parent = ["[MASK]" if i=="X" else i for i in padded_seqseq_parent]
+
+         seqseq_parent[1] = "[MASK]"
+         input_ids_parent = vocab_mlm.__getitem__(seqseq_parent)
+         logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)
+
+         cls_mask_logits_parent = logits_parent[0, 1, :]
+         cls_probability_parent, cls_mask_probs_parent = torch.topk((torch.softmax(cls_mask_logits_parent, dim=-1)), k=10)
+
+         seqseq_parent[2] = "[MASK]"
+         input_ids_parent = vocab_mlm.__getitem__(seqseq_parent)
+         logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)
+         act_mask_logits_parent = logits_parent[0, 2, :]
+         act_probability_parent, act_mask_probs_parent = torch.topk((torch.softmax(act_mask_logits_parent, dim=-1)), k=2)
+
+         cls_pos_parent = vocab_mlm.to_tokens(list(cls_mask_probs_parent))
+         act_pos_parent = vocab_mlm.to_tokens(list(act_mask_probs_parent))
+
+         cls_proba_parent = cls_probability_parent[cls_pos_parent.index(X1)].item()
+         act_proba_parent = act_probability_parent[act_pos_parent.index(X2)].item()
+
+         while count < gen_num:
+             gen_len = len(X0)
+             seq = [f"{X1}|{X2}|{X3}|{X4}|{X5}|{X6}"]
+             vocab_mlm.token_to_idx["X"] = 4
+
+             padded_seq, _, _, _ = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
+             input_text = ["[MASK]" if i=="X" else i for i in padded_seq]
+
+             gen_length = len(input_text)
+             length = gen_length - sum(1 for x in input_text if x != '[MASK]')
+
+             for i in range(length):
+                 _, idx_seq, idx_msa, attn_idx = get_paded_token_idx_gen(vocab_mlm, seq, new_seq)
+                 idx_seq = torch.tensor(idx_seq).unsqueeze(0).to(device)
+                 idx_msa = torch.tensor(idx_msa).unsqueeze(0).to(device)
+                 attn_idx = torch.tensor(attn_idx).to(device)
+
+                 mask_positions = [j for j in range(gen_length) if input_text[j] == "[MASK]"]
+                 mask_position = torch.tensor([mask_positions[torch.randint(len(mask_positions), (1,))]])
+
+                 logits = model(idx_seq,idx_msa, attn_idx)
+                 mask_logits = logits[0, mask_position.item(), :]
+
+                 predicted_token_id = temperature_sampling(mask_logits, τ)
+
+                 predicted_token = vocab_mlm.to_tokens(int(predicted_token_id))
+                 input_text[mask_position.item()] = predicted_token
+                 padded_seq[mask_position.item()] = predicted_token.strip()
+                 new_seq = padded_seq
+
+             generated_seq = input_text
+
+             generated_seq[1] = "[MASK]"
+             input_ids = vocab_mlm.__getitem__(generated_seq)
+             logits = model(torch.tensor([input_ids]).to(device), idx_msa)
+             cls_mask_logits = logits[0, 1, :]
+             cls_probability, cls_mask_probs = torch.topk((torch.softmax(cls_mask_logits, dim=-1)), k=10)
+
+             generated_seq[2] = "[MASK]"
+             input_ids = vocab_mlm.__getitem__(generated_seq)
+             logits = model(torch.tensor([input_ids]).to(device), idx_msa)
+             act_mask_logits = logits[0, 2, :]
+             act_probability, act_mask_probs = torch.topk((torch.softmax(act_mask_logits, dim=-1)), k=2)
+
+             cls_pos = vocab_mlm.to_tokens(list(cls_mask_probs))
+             act_pos = vocab_mlm.to_tokens(list(act_mask_probs))
+
+             if X1 in cls_pos and X2 in act_pos:
+                 cls_proba = cls_probability[cls_pos.index(X1)].item()
+                 act_proba = act_probability[act_pos.index(X2)].item()
+                 generated_seq = generated_seq[generated_seq.index('[MASK]') + 2:generated_seq.index('[SEP]')]
+                 if cls_proba>=cls_proba_parent and act_proba>=act_proba_parent and generated_seq.count('C') % 2 == 0 and len("".join(generated_seq)) == gen_len:
+                     generated_seqs.append("".join(generated_seq))
+                     if "".join(generated_seq) not in train_seq and "".join(generated_seq) not in generated_seqs[0:-1] and all(x not in NON_AA for x in generated_seq):
+                         generated_seqs_FINAL.append("".join(generated_seq))
+                         cls_probability_all.append(cls_proba)
+                         act_probability_all.append(act_proba)
+                         IDs.append(count+1)
+                         out = pd.DataFrame({
+                             'ID':IDs,
+                             'Generated_seq': generated_seqs_FINAL,
+                             'Subtype': X1,
+                             'Subtype_probability': cls_probability_all,
+                             'Potency': X2,
+                             'Potency_probability': act_probability_all,
+                             'Random_seed': seed
+                         })
+                         out.to_csv("output.csv", index=False, encoding='utf-8-sig')
+                         count += 1
+                         yield out, "output.csv"
+         return out, "output.csv"
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Conotoxin Optimization Generation")
+     with gr.Row():
+         X0 = gr.Textbox(label="conotoxin")
+         X1 = gr.Dropdown(choices=['<α7>','<AChBP>','<α4β2>','<α3β4>','<Ca22>','<α3β2>', '<Na12>','<α9α10>','<K16>', '<α1β1γδ>',
+                                   '<K13>', '<α1BAR>', '<α1β1ε>', '<α1AAR>', '<GluN3A>', '<GluN2B>', '<α75HT3>', '<Na14>',
+                                   '<GluN2C>', '<NET>', '<NavBh>', '<α6β3β4>', '<Na11>', '<Ca13>', '<Ca12>', '<Na16>', '<α6α3β2>',
+                                   '<GluN2A>', '<GluN2D>', '<K17>', '<α1β1δε>', '<GABA>', '<α9>', '<K12>', '<Kshaker>', '<Na18>',
+                                   '<α6α3β2β3>', '<α1β1δ>', '<α6α3β4β3>', '<α2β2>','<α6β4>', '<α2β4>','<Na13>', '<Na15>', '<α4β4>',
+                                   '<α7α6β2>', '<α1β1γ>', '<NaTTXR>', '<K11>', '<Ca23>', '<α6α3β4>', '<NaTTXS>', '<Na17>'], label="Subtype")
+         X2 = gr.Dropdown(choices=['<high>','<low>'], label="Potency")
+         τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
+         g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
+     with gr.Row():
+         start_button = gr.Button("Start Generation")
+         stop_button = gr.Button("Stop Generation")
+     with gr.Row():
+         output_df = gr.DataFrame(label="Generated Conotoxins")
+     with gr.Row():
+         output_file = gr.File(label="Download generated conotoxins")
+
+     start_button.click(CTXGen, inputs=[X0, X1, X2, τ, g_num], outputs=[output_df, output_file])
+     stop_button.click(stop_generation, outputs=None)
+
  demo.launch()
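
Review note: the functional change in this commit is the new "conotoxin" textbox (X0) and its addition to the click handler's inputs. Before the change, CTXGen(X0, X1, X2, τ, g_num) took five arguments but the handler supplied only four (inputs=[X1, X2, τ, g_num]), so clicking "Start Generation" would presumably fail with a missing-argument TypeError; wiring X0 through fixes that and lets the user supply the parent sequence that CTXGen masks and regenerates.

A minimal sketch of this inputs-to-arguments wiring, with a hypothetical ctxgen_stub standing in for CTXGen (only gradio is assumed):

import gradio as gr

def ctxgen_stub(X0, X1, X2, tau, g_num):
    # Stand-in for CTXGen: each component in `inputs` maps, in order,
    # to one positional argument of the callback.
    return f"parent={X0}, subtype={X1}, potency={X2}, τ={tau}, n={g_num}"

with gr.Blocks() as demo:
    X0 = gr.Textbox(label="conotoxin")  # the component this commit adds
    X1 = gr.Dropdown(choices=['<α7>', '<α9α10>'], label="Subtype")
    X2 = gr.Dropdown(choices=['<high>', '<low>'], label="Potency")
    tau = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
    g_num = gr.Dropdown(choices=[1, 10, 20], label="Number of generations")
    out = gr.Textbox(label="Echo")
    gr.Button("Start Generation").click(ctxgen_stub, inputs=[X0, X1, X2, tau, g_num], outputs=out)

demo.launch()

The temperature_sampling helper in the diff is ordinary temperature-scaled sampling over the masked-position logits. A self-contained illustration of how the τ slider (range 1–2) affects it, assuming only torch:

import torch

def temperature_sampling(logits, temperature):
    # As in app.py: divide logits by τ, softmax, then draw one token index.
    probabilities = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probabilities, 1)

logits = torch.tensor([2.0, 1.0, 0.1])
for tau in (1.0, 2.0):
    counts = torch.zeros(3)
    for _ in range(2000):
        counts[temperature_sampling(logits, tau).item()] += 1
    # Higher τ flattens the distribution, so sampled tokens are more diverse.
    print(f"τ={tau}: {(counts / counts.sum()).tolist()}")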