import torch
import torch.nn as nn
import torch.nn.functional as F


class NeuralNet(nn.Module):
    """Feed-forward classifier with three hidden layers and dropout."""

    def __init__(self, input_size, hidden_size1, hidden_size2, hidden_size3, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, hidden_size3)
        self.fc4 = nn.Linear(hidden_size3, num_classes)
        self.dropout = nn.Dropout(0.1)  # shared dropout, applied after each hidden layer

    def forward(self, x):
        out = self.dropout(F.relu(self.fc1(x)))
        out = self.dropout(F.relu(self.fc2(out)))
        out = self.dropout(F.relu(self.fc3(out)))
        return self.fc4(out)  # raw logits; the caller takes the argmax


def load_models():
    # Run on GPU when available, fall back to CPU otherwise; checkpoints are
    # loaded onto CPU first and then moved to the target device.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Classifier over 1024-dim ProtT5 embeddings.
    model_protT5 = NeuralNet(1024, 200, 100, 50, 2)
    model_protT5.load_state_dict(torch.load("checkpoints/model17-protT5.pt", map_location="cpu"))
    model_protT5.to(device).eval()

    # Classifier over 2304-dim concatenated ESM + ProtT5 embeddings.
    model_cat = NeuralNet(2304, 200, 100, 100, 2)
    model_cat.load_state_dict(torch.load("checkpoints/model-esm-protT5-5.pt", map_location="cpu"))
    model_cat.to(device).eval()

    return model_protT5, model_cat


def predict_ensemble(X_protT5, X_concat, model_protT5, model_cat, weight1=0.60, weight2=0.30):
    # Move the inputs onto the same device as the models.
    device = next(model_protT5.parameters()).device
    X_protT5 = X_protT5.to(device)
    X_concat = X_concat.to(device)

    with torch.no_grad():
        outputs1 = model_cat(X_concat)     # logits from the concatenated-embedding model
        outputs2 = model_protT5(X_protT5)  # logits from the ProtT5-only model
        # Weighted ensemble of the two models' logits.
        ensemble_outputs = weight1 * outputs1 + weight2 * outputs2
        predicted = torch.argmax(ensemble_outputs, dim=1)
    return predicted
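

# Usage sketch (illustrative only): assumes per-example feature tensors of
# shape [N, 1024] for ProtT5 and [N, 2304] for the concatenated ESM + ProtT5
# embeddings; the random tensors below stand in for real features, and the
# checkpoint files referenced in load_models() must be present.
if __name__ == "__main__":
    model_protT5, model_cat = load_models()
    X_protT5 = torch.randn(4, 1024)
    X_concat = torch.randn(4, 2304)
    predictions = predict_ensemble(X_protT5, X_concat, model_protT5, model_cat)
    print(predictions)  # tensor of predicted class indices (0 or 1), one per input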