# Spaces:
# Runtime error
# Runtime error
from typing import List, Tuple

import torch
class KVCache:
    """Per-layer key/value cache that grows by concatenation on each update.

    ``key_cache[i]`` and ``value_cache[i]`` hold the accumulated key and value
    tensors for layer ``i``. A layer's first update stores the given tensors
    as-is; every later update for that layer concatenates the new states onto
    the stored ones along dimension -2. Layers must be populated in order
    (0, 1, 2, ...) so that list position always equals ``layer_idx``.
    """

    def __init__(self) -> None:
        # One entry per layer; list position == layer_idx.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def num_items(self) -> int:
        """Return the number of cached positions, read from layer 0.

        This is the size of dimension -2 of the first layer's key tensor,
        or 0 when nothing has been cached yet.
        """
        if not self.key_cache:
            return 0
        # NOTE(review): assumes dim -2 is the sequence axis, e.g. a
        # (batch, num_heads, seq_len, head_dim) layout — confirm with callers.
        return self.key_cache[0].shape[-2]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Add new key/value states for ``layer_idx`` and return the full cache.

        Args:
            key_states: New key states for this layer.
            value_states: New value states for this layer.
            layer_idx: Index of the layer being updated. It must be an
                already-cached layer or exactly the next one
                (``len(self.key_cache)``).

        Returns:
            The accumulated ``(keys, values)`` tensors for ``layer_idx`` after
            the update.

        Raises:
            IndexError: If ``layer_idx`` would skip over uncached layers. The
                check runs before any mutation, so the cache is left intact.
        """
        num_layers = len(self.key_cache)
        if layer_idx > num_layers:
            # Appending here would store the states at position `num_layers`,
            # not `layer_idx`, silently misaligning every later layer.
            raise IndexError(
                f"layer_idx {layer_idx} skips uncached layers; "
                f"expected at most {num_layers}"
            )
        if layer_idx == num_layers:
            # First time this layer is seen: store the states as-is.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # Existing layer: extend the cached states along dim -2.
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=-2
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=-2
            )
        return self.key_cache[layer_idx], self.value_cache[layer_idx]