from torch import nn | |
import torchvision | |
class BadNet(nn.Module):
    """Image classifier: a torchvision ResNet-18 followed by an extra linear head.

    The ResNet-18 keeps its own final ``fc`` layer, so its output (of size
    ``self.model.fc.out_features``) is fed into ``self.fc``, which maps it to
    ``output_label`` scores. No activation is applied to the returned scores.

    Currently only CIFAR-10 is assumed as the target dataset; the constructor
    does not take an input-channel count.
    """

    def __init__(self, output_label: int, pretrained: bool = True) -> None:
        """Build the ResNet-18 backbone and the linear classification head.

        Args:
            output_label: Number of output classes (``out_features`` of the head).
            pretrained: Forwarded to ``torchvision.models.resnet18``. Defaults to
                ``True`` (the original behaviour); pass ``False`` to build the
                backbone with randomly initialised weights instead of loading
                the pretrained checkpoint.
        """
        super().__init__()
        self.model = torchvision.models.resnet18(pretrained=pretrained)
        # NOTE(review): the backbone's own ``fc`` is kept, so the head below
        # consumes the backbone's class scores rather than its pooled
        # features (``self.model.fc.in_features``). Replacing ``self.model.fc``
        # is the more common transfer-learning setup, but it would change the
        # state_dict layout, so the original structure is preserved here.
        backbone_out = self.model.fc.out_features
        self.fc = nn.Linear(in_features=backbone_out, out_features=output_label)

    def forward(self, xs):
        """Return the head's scores for the input batch ``xs``.

        ``xs`` is passed straight to the ResNet-18 backbone; presumably a float
        tensor of shape (batch, 3, H, W) — TODO confirm against callers.
        """
        return self.fc(self.model(xs))
# class BadNet(nn.Module): | |
# def __init__(self, input_channels, output_num): | |
# super().__init__() | |
# self.conv1 = nn.Sequential( | |
# nn.Conv2d(in_channels=input_channels, out_channels=16, kernel_size=5, stride=1), | |
# nn.ReLU(), | |
# nn.AvgPool2d(kernel_size=2, stride=2) | |
# ) | |
# self.conv2 = nn.Sequential( | |
# nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1), | |
# nn.ReLU(), | |
# nn.AvgPool2d(kernel_size=2, stride=2) | |
# ) | |
# fc1_input_features = 800 if input_channels == 3 else 512 | |
# self.fc1 = nn.Sequential( | |
# nn.Linear(in_features=fc1_input_features, out_features=512), | |
# nn.ReLU() | |
# ) | |
# self.fc2 = nn.Sequential( | |
# nn.Linear(in_features=512, out_features=output_num), | |
# nn.Softmax(dim=-1) | |
# ) | |
# self.dropout = nn.Dropout(p=.5) | |
# def forward(self, x): | |
# x = self.conv1(x) | |
# x = self.conv2(x) | |
# print(x.shape) | |
# x = x.view(x.size(0), -1) | |
# x = self.fc1(x) | |
# x = self.fc2(x) | |
# return x |