# Source: Hugging Face Space (status: Running on Zero).
# File size: 1,685 bytes; revisions 326d9e6, 1e22daf.
import torch
import torch.nn as nn
import torch.nn.functional as F
class NeuralNet(nn.Module):
    """Four-layer fully connected classifier with ReLU activations.

    Layer sizes: input_size -> hidden_size1 -> hidden_size2 -> hidden_size3
    -> num_classes. ``forward`` returns the raw output of the last linear
    layer; no softmax is applied.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, hidden_size3, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        # One Dropout module is enough: re-assigning the same attribute only
        # replaces the previous module. Dropout has no parameters, so saved
        # state_dicts are unaffected.
        # NOTE(review): forward() never applies self.dropout, so it currently
        # has no effect in train or eval mode. Kept so the module layout/repr
        # stay the same; confirm whether dropout was meant to be used.
        self.dropout = nn.Dropout(0.1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, hidden_size3)
        self.fc4 = nn.Linear(hidden_size3, num_classes)

    def forward(self, x):
        """Map input features through three ReLU layers to ``num_classes`` scores."""
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        return self.fc4(out)
def load_models(device="cuda"):
    """Build both ensemble members, load their checkpoints and move them to *device*.

    Args:
        device: Target device for the returned models. Defaults to ``"cuda"``
            (the previous hard-coded behaviour); pass ``"cpu"`` on hosts
            without a GPU.

    Returns:
        ``(model_protT5, model_cat)``, both in eval mode:
        ``model_protT5`` is ``NeuralNet(1024, 200, 100, 50, 2)`` loaded from
        ``checkpoints/model17-protT5.pt``; ``model_cat`` is
        ``NeuralNet(2304, 200, 100, 100, 2)`` loaded from
        ``checkpoints/model-esm-protT5-5.pt`` (presumably concatenated
        ESM + ProtT5 features, judging by the file name — confirm).
    """

    def _load(model, path):
        # Deserialize onto CPU first so a checkpoint saved from a GPU can be
        # read on any host, then move the model to the requested device.
        # NOTE(review): torch.load uses pickle; only load trusted checkpoint files.
        state_dict = torch.load(path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        model.eval().to(device)
        return model

    model_protT5 = _load(NeuralNet(1024, 200, 100, 50, 2), "checkpoints/model17-protT5.pt")
    model_cat = _load(NeuralNet(2304, 200, 100, 100, 2), "checkpoints/model-esm-protT5-5.pt")
    return model_protT5, model_cat
def predict_ensemble(X_protT5, X_concat, model_protT5, model_cat, weight1=0.60, weight2=0.30):
    """Predict class indices from a weighted sum of two models' raw outputs.

    Args:
        X_protT5: Input batch fed to ``model_protT5``.
        X_concat: Input batch fed to ``model_cat``.
        model_protT5: First ensemble member.
        model_cat: Second ensemble member.
        weight1: Weight applied to ``model_cat``'s outputs.
        weight2: Weight applied to ``model_protT5``'s outputs.

    Returns:
        Tensor of predicted class indices: the argmax over dim 1 of
        ``weight1 * model_cat(X_concat) + weight2 * model_protT5(X_protT5)``,
        on ``model_protT5``'s device. Assumes 2-D (batch, features) inputs,
        giving one index per row — confirm against callers.
    """
    # Each input goes to the device of the model that consumes it, so the two
    # models are not required to live on the same device.
    prot_device = next(model_protT5.parameters()).device
    cat_device = next(model_cat.parameters()).device
    with torch.no_grad():
        outputs1 = model_cat(X_concat.to(cat_device))
        outputs2 = model_protT5(X_protT5.to(prot_device))
        # Combine on model_protT5's device; a no-op move when both models share it.
        ensemble_outputs = weight1 * outputs1.to(prot_device) + weight2 * outputs2
    # argmax returns the first maximal index per row, like torch.max(..., 1)[1].
    return ensemble_outputs.argmax(dim=1)
|