Spaces:
Running
Running
File size: 573 Bytes
72f684c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
import torch.nn as nn
class Adapter(nn.Module):
    """Linear projection with optional layer normalization and dropout.

    The forward pass applies, in order: ``nn.Linear(in_features,
    out_features)``, then ``nn.LayerNorm(out_features)`` when enabled,
    then ``nn.Dropout(dropout_prob)``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        adapter_norm: str = "layer_norm",
        query_length: int = 1,
        dropout_prob: float = 0.1,
    ):
        """Build the adapter layers.

        Args:
            in_features: Size of the last dimension of the input.
            out_features: Size of the last dimension of the output.
            adapter_norm: ``"layer_norm"`` enables a LayerNorm over
                ``out_features``; any other value disables normalization
                (``self.norm`` is then ``None``).
            query_length: Stored on the instance as ``self.query_length``;
                not used by this class's own forward pass.
            dropout_prob: Probability passed to ``nn.Dropout``.
        """
        super().__init__()
        self.fc = nn.Linear(in_features, out_features)
        # Only the exact string "layer_norm" turns normalization on.
        self.norm = (
            nn.LayerNorm(out_features) if adapter_norm == "layer_norm" else None
        )
        self.dropout = nn.Dropout(dropout_prob)
        self.query_length = query_length

    def forward(self, x):
        """Project ``x``, optionally normalize it, and apply dropout.

        Args:
            x: Input tensor whose last dimension is ``in_features``
                (the requirement of ``nn.Linear``).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``out_features``.
        """
        out = self.fc(x)
        if self.norm is not None:
            out = self.norm(out)
        return self.dropout(out)