"""Load a causal language model and print its parameter counts.

The model identifier is read from the ``MODEL_NAME`` environment variable.
"""
import os

from transformers import AutoModelForCausalLM

model_name = os.getenv('MODEL_NAME')
if not model_name:
    # os.getenv returns None when the variable is unset; passing that on to
    # from_pretrained would fail later with a far less obvious error.
    raise SystemExit(
        "MODEL_NAME environment variable must be set to a model name or path"
    )

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="bfloat16",
)

# Print the manually summed element count next to the model's own
# num_parameters() value so the two figures can be compared directly.
print(model_name, sum(p.numel() for p in model.parameters()), model.num_parameters())