File size: 294 Bytes
c9d7b4f |
1 2 3 4 5 6 7 8 9 10 |
import os
from transformers import AutoModelForCausalLM
# Model identifier to load, read from the MODEL_NAME environment variable.
model_name = os.getenv('MODEL_NAME')
if not model_name:
    # os.getenv returns None when the variable is unset; stop here with a
    # clear message instead of passing None (or "") to from_pretrained.
    raise SystemExit("MODEL_NAME environment variable must be set to a model name or path")

# Load the causal-LM checkpoint named by MODEL_NAME.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="bfloat16",
)

# Print the model name followed by two parameter counts: one summed by hand
# over model.parameters(), one as reported by model.num_parameters().
print(model_name, sum(p.numel() for p in model.parameters()), model.num_parameters())
|