# vlm-o/model/utils/kv_cache.py
# Uploaded by veerpareek ("Upload 35 files", commit 577d9ca, verified); 947 bytes.
import torch
from typing import List, Tuple
class KVCache:
    """Per-layer store of key and value tensors that grows along ``dim=-2``.

    Each layer index owns one key tensor and one value tensor. The first
    ``update`` for a layer stores the given tensors as-is; later updates
    concatenate the new tensors onto the stored ones along the
    second-to-last dimension.
    """

    def __init__(self) -> None:
        # One entry per layer, indexed by layer_idx. Both lists are always
        # appended to together, so they stay the same length.
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def num_items(self) -> int:
        """Return the size of ``dim=-2`` of the first layer's cached keys.

        Returns:
            0 when no layer has been cached yet, otherwise
            ``self.key_cache[0].shape[-2]``.
        """
        if not self.key_cache:
            return 0
        # NOTE(review): dim -2 is presumably the sequence axis -- confirm
        # against the attention code that feeds this cache.
        return self.key_cache[0].shape[-2]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Add new key/value tensors for one layer and return the full cache.

        Args:
            key_states: Keys to store or append for ``layer_idx``.
            value_states: Values to store or append for ``layer_idx``.
            layer_idx: Index of the layer being updated. A new layer must use
                the next free index (``len(self.key_cache)``). Negative
                values follow normal list indexing into the existing layers.

        Returns:
            The ``(keys, values)`` tensors now cached for ``layer_idx``.

        Raises:
            IndexError: If ``layer_idx`` skips past the next free slot, or is
                a negative index with no matching cached layer.
        """
        num_layers = len(self.key_cache)

        # Reject a skipped layer *before* touching the lists. Appending here
        # would store the tensors under the wrong index and only then fail on
        # the final lookup, leaving the cache corrupted.
        if layer_idx > num_layers:
            raise IndexError(
                f"layer_idx {layer_idx} is out of range: only {num_layers} "
                f"layer(s) cached, so the next new layer must be {num_layers}"
            )

        if layer_idx == num_layers:
            # First time this layer is seen: store the tensors unchanged.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            # Layer already cached: extend it along the second-to-last axis.
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=-2
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=-2
            )

        return self.key_cache[layer_idx], self.value_cache[layer_idx]