File size: 294 Bytes
c9d7b4f
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
import os
from transformers import AutoModelForCausalLM
# Model id or local path to load. It is read from the environment so the same
# script can inspect different checkpoints without being edited.
model_name = os.getenv('MODEL_NAME')
if not model_name:
    # Fail fast with an actionable message. Without this check, None (or '')
    # would be handed to from_pretrained, which expects a model id or path.
    raise SystemExit(
        "MODEL_NAME environment variable must be set to a model id or path"
    )

# Load the causal-LM weights in bfloat16. device_map="auto" lets transformers
# decide where the weights are placed.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="bfloat16",
)

# Print the model name followed by two parameter totals: one summed manually
# over model.parameters(), and one from transformers' num_parameters() helper.
# Printing both lets the two counts be compared side by side.
print(model_name, sum(p.numel() for p in model.parameters()), model.num_parameters())