# Author: danieldk (HF Staff)
# Last commit 2bc2c3b: "Reflect repo name change"
import torch
import torch.nn as nn
from torch.nn import functional as F
class LinearImplicitBackward(nn.Module):
    """Affine projection computed with ``torch.nn.functional.linear``.

    The class declares no parameters and no ``__init__``: ``forward`` reads
    ``self.weight`` and ``self.bias``, which must already exist on the
    instance when it is called. NOTE(review): presumably these come from the
    module whose ``forward`` this class stands in for -- confirm with caller.

    Unlike ``LinearBackward`` / ``LinearNoBackward`` it declares no
    ``has_backward`` class attribute; NOTE(review): the name suggests this
    exercises the "implicit" / unspecified case -- confirm with consumer.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``F.linear(input, self.weight, self.bias)``.

        ``self.bias`` may be ``None``, in which case no bias is added.
        """
        weight, bias = self.weight, self.bias
        return F.linear(input, weight, bias)
class LinearBackward(nn.Module):
    """Affine projection via ``F.linear`` that is flagged ``has_backward = True``.

    No parameters are created here: ``forward`` uses ``self.weight`` and
    ``self.bias``, which must be present on the instance before it is called
    (NOTE(review): presumably provided by the module being substituted).
    """

    # Class-level flag. NOTE(review): by its name, this advertises that the
    # layer may be used when gradients are needed; confirm how it is consumed.
    has_backward = True

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Project ``input`` with the instance's ``weight`` and optional ``bias``."""
        return F.linear(input, self.weight, bias=self.bias)
class LinearNoBackward(nn.Module):
    """Affine projection via ``F.linear`` that is flagged ``has_backward = False``.

    Like the other variants in this module, it owns no parameters: the
    ``weight`` and ``bias`` read in ``forward`` must already be attributes of
    the instance (NOTE(review): presumably supplied by the host module).
    """

    # Class-level flag. NOTE(review): by its name, this marks the layer as not
    # suitable when a backward pass is required; confirm with the consumer.
    has_backward = False

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Compute ``F.linear`` of ``input`` with ``self.weight`` / ``self.bias``."""
        w = self.weight
        b = self.bias
        projected = F.linear(input, w, b)
        return projected
# Public API of this module: the three linear-layer variants.
__all__ = [
    "LinearImplicitBackward",
    "LinearBackward",
    "LinearNoBackward",
]