File size: 673 Bytes
72fc481
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch.nn as nn
import numpy as np
from abc import abstractmethod


class BaseModel(nn.Module):
    """
    Base class for all models.

    Subclasses must override :meth:`forward`. Printing a model appends the
    number of trainable parameters to the standard ``nn.Module`` summary.
    """

    @abstractmethod
    def forward(self, *inputs):
        """
        Forward pass logic.

        ``nn.Module`` does not use ``ABCMeta``, so the ``@abstractmethod``
        marker is informational only. A subclass that forgets to override
        this method is detected when it is called, via the
        ``NotImplementedError`` below.

        :param inputs: Positional inputs to the model.
        :return: Model output
        :raises NotImplementedError: If a subclass does not override it.
        """
        raise NotImplementedError

    def __str__(self):
        """
        Return the module summary followed by the trainable parameter count.

        Only parameters with ``requires_grad=True`` are counted.
        """
        # numel() always returns an int, including for 0-dim (scalar)
        # parameters. np.prod(p.size()) returns the float 1.0 for an empty
        # shape, which would turn the printed total into e.g. "9.0".
        params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return super().__str__() + f'\nTrainable parameters: {params}'