hz2475's picture
init
72f684c
raw
history blame contribute delete
573 Bytes
import torch.nn as nn
class Adapter(nn.Module):
    """Linear projection, optionally followed by LayerNorm, then dropout.

    Args:
        in_features: Size of the last dimension of the input fed to ``fc``.
        out_features: Size of the projected output (and of the LayerNorm,
            when enabled).
        adapter_norm: ``"layer_norm"`` enables an ``nn.LayerNorm`` over the
            projected features; any other value disables normalization and
            leaves ``self.norm`` as ``None``.
        query_length: Stored on the instance as ``self.query_length``; this
            module's own ``forward`` does not read it (presumably consumed by
            callers — verify).
        dropout_prob: Drop probability passed to ``nn.Dropout``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        adapter_norm: str = "layer_norm",
        query_length: int = 1,
        dropout_prob: float = 0.1,
    ) -> None:
        super().__init__()
        # Submodules are created in a fixed order (projection, norm, dropout)
        # so parameter registration / state_dict key order stays stable.
        self.fc = nn.Linear(in_features, out_features)
        wants_layer_norm = adapter_norm == "layer_norm"
        if wants_layer_norm:
            self.norm = nn.LayerNorm(out_features)
        else:
            self.norm = None
        self.dropout = nn.Dropout(dropout_prob)
        self.query_length = query_length

    def forward(self, x):
        """Project ``x``, normalize if a LayerNorm is configured, apply dropout.

        ``nn.Linear`` acts on the last dimension, so ``x`` is expected to have
        ``in_features`` as its trailing size.
        """
        projected = self.fc(x)
        if self.norm is None:
            return self.dropout(projected)
        return self.dropout(self.norm(projected))