File size: 293 Bytes
1ca9e3b
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
def save_attn_gradients(self, attn_gradients) -> None:
    """Store ``attn_gradients`` on this object for later retrieval.

    The value is kept by reference (no copy, no validation) under the
    ``attn_gradients`` attribute; any previously stored value is replaced.

    NOTE(review): presumably meant to be used as a gradient hook on an
    attention tensor -- confirm against the caller; the enclosing class is
    not visible from here.
    """
    self.attn_gradients = attn_gradients
    
def get_attn_gradients(self):
    """Return the value currently held in ``self.attn_gradients``.

    The stored object is returned as-is (same reference). If nothing has
    assigned ``attn_gradients`` on this object yet, the attribute lookup
    raises ``AttributeError``.
    """
    stored_gradients = self.attn_gradients
    return stored_gradients

def save_attn_map(self, attention_map) -> None:
    """Store ``attention_map`` on this object for later retrieval.

    The value is kept by reference (no copy, no validation) under the
    ``attention_map`` attribute; any previously stored value is replaced.

    NOTE(review): presumably called from an attention forward pass to
    expose the map for inspection -- confirm against the caller.
    """
    self.attention_map = attention_map
    
def get_attn_map(self):
    """Return the value currently held in ``self.attention_map``.

    The stored object is returned as-is (same reference). If nothing has
    assigned ``attention_map`` on this object yet, the attribute lookup
    raises ``AttributeError``.
    """
    stored_map = self.attention_map
    return stored_map