|
import torch |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
|
|
|
|
class LinearImplicitBackward(nn.Module):
    """Linear map computed with ``F.linear`` that declares no ``has_backward`` flag.

    Unlike its siblings, this class deliberately defines no ``has_backward``
    class attribute. NOTE(review): presumably so that code probing for the flag
    falls back to its default; confirm against the consumer.

    This class does not create ``weight`` or ``bias`` itself (there is no
    ``__init__``). Both must be assigned on the instance before ``forward`` is
    called. ``bias`` may be ``None``, which ``F.linear`` accepts.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``F.linear(input, self.weight, self.bias)``.

        Raises:
            AttributeError: if ``weight`` or ``bias`` was never assigned.
        """
        weight, bias = self.weight, self.bias
        return F.linear(input, weight, bias)
|
|
|
|
|
class LinearBackward(nn.Module):
    """Linear map computed with ``F.linear`` that sets ``has_backward = True``.

    NOTE(review): ``has_backward`` is presumably read by external code to decide
    whether a backward pass is supported for this module; confirm with the
    consumer.

    There is no ``__init__``, so ``weight`` and ``bias`` are not created here.
    Assign them on the instance before calling ``forward``. ``bias`` may be
    ``None``.
    """

    has_backward: bool = True

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``F.linear(input, self.weight, self.bias)``.

        Raises:
            AttributeError: if ``weight`` or ``bias`` was never assigned.
        """
        weight = self.weight
        bias = self.bias
        output = F.linear(input, weight, bias)
        return output
|
|
|
|
|
class LinearNoBackward(nn.Module):
    """Linear map computed with ``F.linear`` that sets ``has_backward = False``.

    NOTE(review): ``has_backward`` is presumably read by external code to decide
    whether a backward pass is supported for this module; confirm with the
    consumer.

    There is no ``__init__``, so ``weight`` and ``bias`` are not created here.
    Assign them on the instance before calling ``forward``. ``bias`` may be
    ``None``.
    """

    has_backward: bool = False

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``F.linear(input, self.weight, self.bias)``.

        Raises:
            AttributeError: if ``weight`` or ``bias`` was never assigned.
        """
        params = (self.weight, self.bias)
        return F.linear(input, *params)
|
|
|
|
|
# Public names exported by ``from <this module> import *``.
__all__ = [
    "LinearImplicitBackward",
    "LinearBackward",
    "LinearNoBackward",
]
|
|