File size: 660 Bytes
64af3aa
 
f60b4b0
64af3aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn
from torch.nn import functional as F


class LinearImplicitBackward(nn.Module):
    """Affine layer computed with ``F.linear`` that never declares ``has_backward``.

    The class creates neither ``weight`` nor ``bias``; both must already be
    present on the instance when ``forward`` runs (``bias`` may be ``None``,
    which ``F.linear`` accepts). NOTE(review): presumably the caller attaches
    them externally — confirm.

    Unlike its sibling classes, no ``has_backward`` attribute is defined here.
    NOTE(review): presumably this exercises the "attribute absent" case of
    whatever code inspects that flag — confirm.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``F.linear`` of *input* with this module's weight and bias."""
        weight, bias = self.weight, self.bias
        return F.linear(input, weight, bias=bias)


class LinearBackward(nn.Module):
    """Affine layer computed with ``F.linear`` that declares ``has_backward = True``.

    ``weight`` and ``bias`` are not created by this class; they must be set on
    the instance before ``forward`` is called (``bias`` may be ``None``).
    NOTE(review): presumably some consumer reads ``has_backward`` to decide how
    to treat this module — confirm.
    """

    # Class-level flag, so it can be read without instantiating the module.
    has_backward = True

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``F.linear(input, self.weight, self.bias)``."""
        return F.linear(
            input,
            self.weight,
            bias=self.bias,
        )


class LinearNoBackward(nn.Module):
    """Affine layer computed with ``F.linear`` that declares ``has_backward = False``.

    As with the other variants, ``weight`` and ``bias`` are not created here and
    must be present on the instance when ``forward`` runs (``bias`` may be
    ``None``). NOTE(review): the ``False`` flag presumably tells an external
    consumer to treat this module differently — confirm.
    """

    # Explicitly opt out at class level (contrast with an absent attribute).
    has_backward = False

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the affine map held in ``self.weight`` / ``self.bias`` to *input*."""
        w = self.weight
        b = self.bias
        result = F.linear(input, w, b)
        return result


# Public API: the three F.linear-based module variants, which differ only in
# how they declare ``has_backward`` (absent, True, False).
__all__ = [
    "LinearImplicitBackward",
    "LinearBackward",
    "LinearNoBackward",
]