# Instantiate a big model
A barrier to accessing very large pretrained models is the amount of memory required. When loading a pretrained PyTorch model, you usually:

1. Create a model with random weights.
2. Load your pretrained weights.
3. Put those pretrained weights in the model.
The first two steps both require a full version of the model in memory, and if the model weighs several GBs, you may not have enough memory for two copies of it. This problem is amplified in distributed training environments because each process loads a pretrained model and stores two copies in memory.
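To make those three steps concrete, here is a rough sketch of the naive loading pattern described above. The checkpoint name and the local `pytorch_model.bin` path are illustrative assumptions; the point is that `model` and `state_dict` each hold a full copy of the weights at the same time.

```py
import torch
from transformers import AutoConfig, AutoModel

# 1. Create a model with random weights (first full copy in memory).
config = AutoConfig.from_pretrained("BioMistral/BioMistral-7B")
model = AutoModel.from_config(config)

# 2. Load the pretrained weights (second full copy in memory).
state_dict = torch.load("pytorch_model.bin", map_location="cpu")  # illustrative local path

# 3. Put the pretrained weights in the model.
model.load_state_dict(state_dict)
```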
> [!TIP]
> The randomly created model is initialized with "empty" tensors, which take space in memory without filling it. The random values are whatever was in this chunk of memory at the time. To improve loading speed, the `_fast_init` parameter is set to `True` by default to skip the random initialization for all weights that are correctly loaded.
This guide will show you how Transformers can help you load large pretrained models despite their memory requirements.
## Sharded checkpoints
From Transformers v4.18.0, a checkpoint larger than 10GB is automatically sharded by the [`~PreTrainedModel.save_pretrained`] method. It is split into several smaller partial checkpoints, and an index file is created that maps parameter names to the files they're stored in.

The maximum shard size is controlled with the `max_shard_size` parameter. It defaults to 5GB because this makes it easier to run on free-tier GPU instances without running out of memory.
For example, let's shard BioMistral/BioMistral-7B.

```py
import os
import tempfile
from transformers import AutoModel

model = AutoModel.from_pretrained("BioMistral/BioMistral-7B")

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, max_shard_size="5GB")
    print(sorted(os.listdir(tmp_dir)))
```

```
['config.json', 'generation_config.json', 'model-00001-of-00006.safetensors', 'model-00002-of-00006.safetensors', 'model-00003-of-00006.safetensors', 'model-00004-of-00006.safetensors', 'model-00005-of-00006.safetensors', 'model.safetensors.index.json']
```
The sharded checkpoint is reloaded with the [`~PreTrainedModel.from_pretrained`] method.
```py
with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, max_shard_size="5GB")
    new_model = AutoModel.from_pretrained(tmp_dir)
```
The main advantage of sharded checkpoints for big models is that each shard is loaded after the previous one, which caps the memory usage at the model size plus the size of the largest shard.
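To illustrate the idea, here is a rough sketch of shard-by-shard loading. It assumes a locally saved sharded checkpoint at `checkpoint_dir` and the `model` instance from the earlier snippet, and it uses `safetensors.torch.load_file` for illustration; this is a sketch of the concept, not necessarily how Transformers implements it internally.

```py
import os
from safetensors.torch import load_file

checkpoint_dir = "path/to/sharded/checkpoint"  # assumption: a locally saved sharded checkpoint

# Load one shard at a time so only a single shard is ever held in memory alongside the model.
for shard_file in sorted(f for f in os.listdir(checkpoint_dir) if f.endswith(".safetensors")):
    shard_state_dict = load_file(os.path.join(checkpoint_dir, shard_file))
    model.load_state_dict(shard_state_dict, strict=False)  # strict=False: each shard only covers part of the model
    del shard_state_dict  # free the shard before loading the next one
```

In practice you don't need to write this loop yourself; the [`~modeling_utils.load_sharded_checkpoint`] helper described next handles it for you.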
You could also directly load a sharded checkpoint inside a model without the [`~PreTrainedModel.from_pretrained`] method (similar to PyTorch's `load_state_dict()` method for a full checkpoint). In this case, use the [`~modeling_utils.load_sharded_checkpoint`] method.
```py
from transformers.modeling_utils import load_sharded_checkpoint

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, max_shard_size="5GB")
    load_sharded_checkpoint(model, tmp_dir)
```
### Shard metadata
The index file determines which keys are in the checkpoint and where the corresponding weights are stored. This file is loaded like any other JSON file, and you can get a dictionary from it.
```py
import json

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, max_shard_size="5GB")
    with open(os.path.join(tmp_dir, "model.safetensors.index.json"), "r") as f:
        index = json.load(f)

print(index.keys())
```

```
dict_keys(['metadata', 'weight_map'])
```
The `metadata` key provides the total model size.
index["metadata"] | |
{'total_size': 28966928384} | |
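The size is reported as a byte count, so a quick conversion makes it easier to read; for the value above this works out to roughly 27 GiB.

```py
# total_size is in bytes; divide by 1024**3 to get gibibytes
total_bytes = index["metadata"]["total_size"]
print(f"{total_bytes / 1024**3:.2f} GiB")  # ~26.98 GiB for the checkpoint above
```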
The `weight_map` key maps each parameter name (as it appears in the model's `state_dict`) to the shard it's stored in.
index["weight_map"] | |
{'lm_head.weight': 'model-00006-of-00006.safetensors', | |
'model.embed_tokens.weight': 'model-00001-of-00006.safetensors', | |
'model.layers.0.input_layernorm.weight': 'model-00001-of-00006.safetensors', | |
'model.layers.0.mlp.down_proj.weight': 'model-00001-of-00006.safetensors', | |
} | |
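Because `weight_map` is a plain dictionary, you can inspect it directly. The short snippet below (continuing from the `index` loaded above) counts how many tensors each shard file holds and looks up the shard that stores a specific parameter; it is only an illustration of working with the index, not part of the Transformers API.

```py
from collections import Counter

# Count how many tensors are stored in each shard file
shard_counts = Counter(index["weight_map"].values())
for shard, num_tensors in sorted(shard_counts.items()):
    print(shard, num_tensors)

# Look up the shard that stores a specific parameter
print(index["weight_map"]["lm_head.weight"])  # 'model-00006-of-00006.safetensors'
```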
## Accelerate's Big Model Inference
> [!TIP]
> Make sure you have Accelerate v0.9.0 or later and PyTorch v1.9.0 or later installed.
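As a minimal sketch of what this looks like in practice, Big Model Inference is typically enabled by passing `device_map="auto"` to [`~PreTrainedModel.from_pretrained`], which lets Accelerate dispatch the weights across the available devices instead of materializing a full random-weight copy first. The checkpoint, model class, and `torch_dtype` below are illustrative assumptions; the arguments you need may differ for your model and hardware.

```py
import torch
from transformers import AutoModelForCausalLM

# Sketch: device_map="auto" relies on Accelerate's Big Model Inference to place
# weights across available GPUs/CPU while keeping memory usage low during loading.
model = AutoModelForCausalLM.from_pretrained(
    "BioMistral/BioMistral-7B",
    torch_dtype=torch.float16,  # assumption: half precision to reduce memory
    device_map="auto",
)
```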