import torch
from transformers import MambaConfig, MambaModel, Mamba2Config, Mamba2Model

print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name()}")
    print(f"CUDA version: {torch.version.cuda}")
device = "cuda" if torch.cuda.is_available() else "cpu"

batch, channel, height, width = 256, 16, 8, 8
x = torch.randn(batch, channel, height, width).to(device)
print(f'x: {x.shape}')

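# Mamba consumes 1D token sequences, so flatten the spatial grid (H x W)
# into a length-L sequence of C-dimensional tokens, channels last.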
B, C, H, W = x.shape
x = x.permute(0, 2, 3, 1)  # [B, H, W, C]
print(f'Permuted x: {x.shape}')

x = x.reshape(B, H * W, C)  # [B, L, C], L = H * W
print(f'Reshaped x: {x.shape}')

# Initializing a Mamba configuration.
# vocab_size=0 is acceptable here: the token embedding table is never used,
# since we feed inputs_embeds directly; hidden_size must match the channel
# dimension of the flattened input.
configuration = MambaConfig(vocab_size=0, hidden_size=channel, num_hidden_layers=2)
# configuration = Mamba2Config(hidden_size=channel)  # note: Mamba2 may also need
# its head settings (e.g. head_dim/num_heads) adjusted for small hidden sizes

# Initializing a model (with random weights) from the configuration
model = MambaModel(configuration).to(device)
# model = Mamba2Model(configuration).to(device)
print(f'Model: {model}')

# Accessing the model configuration
configuration = model.config
print(f'Configuration: {configuration}')

# Forward pass on the flattened [B, L, C] sequence; with return_dict=True,
# indexing with [0] is equivalent to reading .last_hidden_state:
# y = model(inputs_embeds=x).last_hidden_state
y = model(inputs_embeds=x, return_dict=True)[0]
print(f'y: {y.shape}')

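# Restore the spatial layout: [B, L, C] -> [B, H, W, C] -> [B, C, H, W]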
y = y.reshape(B, H, W, -1)
print(f'Reshaped y: {y.shape}')

y = y.permute(0, 3, 1, 2)  # [B, C, H, W]
print(f'Permuted y: {y.shape}')
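
# Sanity check: hidden_size was set to `channel`, so the permute -> flatten ->
# Mamba -> unflatten round trip reproduces the input shape.
assert y.shape == (batch, channel, height, width)

# A minimal sketch of the same pipeline packaged as a reusable layer;
# `MambaSpatialBlock` is an illustrative name, not part of transformers.
import torch.nn as nn

class MambaSpatialBlock(nn.Module):
    """Run a Mamba backbone over the flattened spatial grid of a [B, C, H, W] tensor."""

    def __init__(self, channels: int, num_layers: int = 2):
        super().__init__()
        config = MambaConfig(vocab_size=0, hidden_size=channels, num_hidden_layers=num_layers)
        self.backbone = MambaModel(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        seq = x.permute(0, 2, 3, 1).reshape(B, H * W, C)      # [B, L, C]
        out = self.backbone(inputs_embeds=seq).last_hidden_state
        return out.reshape(B, H, W, -1).permute(0, 3, 1, 2)   # [B, C, H, W]

block = MambaSpatialBlock(channel).to(device)
x2 = torch.randn(batch, channel, height, width, device=device)
print(f'Block output: {block(x2).shape}')  # torch.Size([256, 16, 8, 8])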