import torch


class LinearClassifier(torch.nn.Module):
    """Binary classifier made of a single linear layer followed by a sigmoid.

    This is the logistic-regression model form: each input feature vector is
    projected to one scalar score, which the sigmoid squashes into (0, 1).
    """

    def __init__(self, input_dim: int) -> None:
        """Create the linear projection and the sigmoid activation.

        Args:
            input_dim: Number of input features, i.e. the size of the last
                dimension of the tensor passed to ``forward``.
        """
        super().__init__()
        # Single output unit: one score per input vector.
        self.linear = torch.nn.Linear(input_dim, 1)
        # Kept as a submodule (rather than calling torch.sigmoid inline) so the
        # module's attribute layout stays the same for existing callers.
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sigmoid-activated scores for ``x``.

        Args:
            x: Tensor whose last dimension has size ``input_dim``; any leading
                dimensions (e.g. batch) are preserved by ``torch.nn.Linear``.

        Returns:
            Tensor of the same leading shape as ``x`` with a trailing
            dimension of size 1, every element in (0, 1).
        """
        # NOTE(review): if this is trained with BCELoss, consider exposing the
        # pre-sigmoid logits and using BCEWithLogitsLoss for numerical
        # stability. Not changed here because callers expect probabilities.
        logits = self.linear(x)
        return self.sigmoid(logits)