oucgc1996 commited on
Commit
52628d3
·
verified ·
1 Parent(s): 30cee1d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -40,13 +40,15 @@ def CTXGen(X0, X1, X2, τ, g_num, model_name):
40
  msa_data = pd.read_csv('conoData_C0.csv')
41
  msa = msa_data['Sequences'].tolist()
42
  msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
43
- msa = random.choice(msa)
44
- X4 = msa.split("|")[3]
45
- X5 = msa.split("|")[4]
46
- X6 = msa.split("|")[5]
47
- print(X4)
48
- print(X5)
49
- print(X6)
 
 
50
  model.eval()
51
  with torch.no_grad():
52
  new_seq = None
@@ -74,7 +76,7 @@ def CTXGen(X0, X1, X2, τ, g_num, model_name):
74
  logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)
75
 
76
  cls_mask_logits_parent = logits_parent[0, 1, :]
77
- cls_probability_parent, cls_mask_probs_parent = torch.topk((torch.softmax(cls_mask_logits_parent, dim=-1)), k=10)
78
 
79
  seqseq_parent[2] = "[MASK]"
80
  input_ids_parent = vocab_mlm.__getitem__(seqseq_parent)
 
40
  msa_data = pd.read_csv('conoData_C0.csv')
41
  msa = msa_data['Sequences'].tolist()
42
  msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
43
+ if not msa:
44
+ X4 = ""
45
+ X5 = ""
46
+ X6 = ""
47
+ else:
48
+ msa = random.choice(msa)
49
+ X4 = msa.split("|")[3]
50
+ X5 = msa.split("|")[4]
51
+ X6 = msa.split("|")[5]
52
  model.eval()
53
  with torch.no_grad():
54
  new_seq = None
 
76
  logits_parent = model(torch.tensor([input_ids_parent]).to(device), idx_msaseq_parent)
77
 
78
  cls_mask_logits_parent = logits_parent[0, 1, :]
79
+ cls_probability_parent, cls_mask_probs_parent = torch.topk((torch.softmax(cls_mask_logits_parent, dim=-1)), k=53)
80
 
81
  seqseq_parent[2] = "[MASK]"
82
  input_ids_parent = vocab_mlm.__getitem__(seqseq_parent)