Spaces:
Configuration error
Configuration error
import torch.nn as nn | |
import numpy as np | |
from abc import abstractmethod | |
class BaseModel(nn.Module):
    """
    Base class for all models.

    Subclasses must override :meth:`forward`. Converting a model to a string
    (``str(model)`` / ``print(model)``) yields the usual ``nn.Module`` module
    tree followed by a line reporting the number of trainable parameters.
    """

    # nn.Module does not use ABCMeta, so @abstractmethod is NOT enforced at
    # instantiation time; it documents intent. The NotImplementedError below is
    # what actually fires if a subclass forgets to override forward.
    @abstractmethod
    def forward(self, *inputs):
        """
        Forward pass logic.

        :param inputs: model inputs, as defined by the concrete subclass
        :return: Model output
        :raises NotImplementedError: if a subclass did not override this method
        """
        raise NotImplementedError

    def __str__(self):
        """
        Model prints with number of trainable parameters.

        :return: the ``nn.Module`` string representation followed by a line
            ``Trainable parameters: N``, where N is the total element count of
            all parameters with ``requires_grad=True`` (frozen ones are skipped).
        """
        # numel() returns a Python int for every parameter, including 0-dim
        # (scalar) ones; np.prod(p.size()) returned the float 1.0 for those,
        # which turned the whole total into a float (e.g. "17.0").
        params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return super().__str__() + '\nTrainable parameters: {}'.format(params)