def save_attn_gradients(self, attn_gradients) -> None:
    """Store ``attn_gradients`` on this instance as ``self.attn_gradients``.

    The value is kept as-is (no copy, no detach); it can be read back with
    ``get_attn_gradients``.

    NOTE(review): presumably registered as a backward hook on the attention
    probabilities so the gradients are captured during ``backward()`` — confirm
    against the caller.

    Args:
        attn_gradients: Object to store; presumably a tensor of attention
            gradients — TODO confirm type/shape with the caller.
    """
    self.attn_gradients = attn_gradients
def get_attn_gradients(self):
    """Return the object last stored in ``self.attn_gradients``.

    Returns the stored object itself (not a copy).

    Raises:
        AttributeError: if ``self.attn_gradients`` has never been assigned on
            this instance (e.g. no prior ``save_attn_gradients`` call and no
            initialization elsewhere, such as in ``__init__``).
    """
    return self.attn_gradients
def save_attn_map(self, attention_map) -> None:
    """Store ``attention_map`` on this instance as ``self.attention_map``.

    The value is kept as-is (no copy, no detach); it can be read back with
    ``get_attn_map``.

    Args:
        attention_map: Object to store; presumably the attention probability
            tensor from a forward pass — TODO confirm type/shape with the caller.
    """
    self.attention_map = attention_map
def get_attn_map(self):
    """Return the object last stored in ``self.attention_map``.

    Returns the stored object itself (not a copy).

    Raises:
        AttributeError: if ``self.attention_map`` has never been assigned on
            this instance (e.g. no prior ``save_attn_map`` call and no
            initialization elsewhere, such as in ``__init__``).
    """
    return self.attention_map