Accelerate's Big Model Inference
[!TIP]
Make sure you have Accelerate v0.9.0 or later and PyTorch v1.9.0 or later installed.
From Transformers v4.20.0, the [~PreTrainedModel.from_pretrained] method uses Accelerate's Big Model Inference feature to load very large models efficiently. Big Model Inference first creates a model skeleton on PyTorch's meta device, which records parameter shapes and data types but allocates no memory for the values. The parameters are only materialized when the pretrained weights are loaded. This way, you never keep two copies of the model in memory at the same time (one for the randomly initialized model and one for the pretrained weights), and peak memory usage is only the size of the full model.
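As a rough illustration of the meta device idea, the sketch below builds a single layer on PyTorch's meta device. The `torch.nn.Linear` layer and its size are stand-ins chosen for illustration (an assumption, not how Transformers builds its skeleton internally); the point is only that a meta tensor carries shape and data type without any storage behind it.

```python
import torch
from torch import nn

# A layer created on the meta device records shapes and dtypes only;
# no memory is allocated for the parameter values.
skeleton = nn.Linear(4096, 4096, device="meta")

# The metadata is still available, so the would-be size can be computed
# without ever materializing the weights.
n_params = sum(p.numel() for p in skeleton.parameters())
n_bytes = sum(p.numel() * p.element_size() for p in skeleton.parameters())

print(skeleton.weight.is_meta)  # whether the weight lives on the meta device
print(n_params, n_bytes)
```

Because nothing is allocated, a skeleton like this can be created for a model far larger than the available RAM.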
To enable Big Model Inference in Transformers, set low_cpu_mem_usage=True in the [~PreTrainedModel.from_pretrained] method.
from transformers import AutoModelForCausalLM
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", low_cpu_mem_usage=True)
Accelerate automatically dispatches the model weights across all available devices, filling the fastest device (GPU) first and then offloading to slower devices (CPU, and even the hard drive). Enable this by setting device_map="auto" in the [~PreTrainedModel.from_pretrained] method. When you pass the device_map parameter, low_cpu_mem_usage is automatically set to True, so you don't need to specify it.
from transformers import AutoModelForCausalLM

# these loading methods are equivalent
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", device_map="auto")
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", device_map="auto", low_cpu_mem_usage=True)
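The "fastest device first, then spill over" behavior can be pictured with a toy placement routine. This is a deliberately simplified sketch, not Accelerate's actual algorithm: the function name `toy_device_map`, the layer sizes, and the memory budgets are all hypothetical, and real placement accounts for details this ignores.

```python
def toy_device_map(layer_sizes, budgets):
    """Greedy sketch: fill devices in priority order, spilling to the next one.

    layer_sizes: {layer_name: size}, in model order.
    budgets: {device: capacity}, ordered from fastest to slowest device.
    """
    devices = list(budgets)    # dict order is the priority order
    remaining = dict(budgets)
    device_map = {}
    idx = 0
    for name, size in layer_sizes.items():
        # Move on to the next (slower) device once the current one is full.
        while idx < len(devices) and remaining[devices[idx]] < size:
            idx += 1
        if idx == len(devices):
            raise ValueError(f"{name} does not fit on any device")
        device_map[name] = devices[idx]
        remaining[devices[idx]] -= size
    return device_map


# Hypothetical sizes and budgets in arbitrary units.
layers = {f"model.layers.{i}": 2 for i in range(6)}
budgets = {0: 5, "cpu": 4, "disk": float("inf")}
toy_map = toy_device_map(layers, budgets)
print(toy_map)
```

With these numbers, the first two layers land on GPU 0, the next two on the CPU, and the rest on disk.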
You can also write your own device_map by mapping each layer to a device. The map must assign every model parameter to a device, but you don't have to detail where each submodule of a layer goes if the entire layer is on the same device.
device_map = {"model.layers.1": 0, "model.layers.14": 1, "model.layers.31": "cpu", "lm_head": "disk"}
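Since a hand-written map has to cover every parameter, it can help to check coverage before loading. The helper below is a hypothetical sketch (`uncovered_parameters` is not part of Transformers or Accelerate) that treats each device_map key as a module-name prefix, which is how a single entry for a layer covers all of its submodules. The parameter names are illustrative.

```python
def uncovered_parameters(param_names, device_map):
    """Return the parameter names that no device_map key covers by prefix."""
    def covered(name):
        # A key covers itself and anything nested under it ("key.<...>").
        # Appending "." keeps "model.layers.1" from matching "model.layers.10".
        return any(name == key or name.startswith(key + ".") for key in device_map)

    return [name for name in param_names if not covered(name)]


# Illustrative parameter names; a real model exposes them through
# model.named_parameters().
names = [
    "model.layers.1.self_attn.q_proj.weight",
    "model.layers.1.mlp.up_proj.weight",
    "model.layers.10.mlp.up_proj.weight",
    "lm_head.weight",
]
partial_map = {"model.layers.1": 0, "lm_head": "disk"}
missing = uncovered_parameters(names, partial_map)
print(missing)
```

An empty result means every listed parameter is assigned to some device.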
Access the hf_device_map attribute to see how Accelerate split the model across devices.
gemma.hf_device_map
{'model.embed_tokens': 0,
 'model.layers.0': 0,
 'model.layers.1': 0,
 'model.layers.2': 0,
 'model.layers.3': 0,
 'model.layers.4': 0,
 'model.layers.5': 0,
 'model.layers.6': 0,
 'model.layers.7': 0,
 'model.layers.8': 0,
 'model.layers.9': 0,
 'model.layers.10': 0,
 'model.layers.11': 0,
 'model.layers.12': 0,
 'model.layers.13': 0,
 'model.layers.14': 'cpu',
 'model.layers.15': 'cpu',
 'model.layers.16': 'cpu',
 'model.layers.17': 'cpu',
 'model.layers.18': 'cpu',
 'model.layers.19': 'cpu',
 'model.layers.20': 'cpu',
 'model.layers.21': 'cpu',
 'model.layers.22': 'cpu',
 'model.layers.23': 'cpu',
 'model.layers.24': 'cpu',
 'model.layers.25': 'cpu',
 'model.layers.26': 'cpu',
 'model.layers.27': 'cpu',
 'model.layers.28': 'cpu',
 'model.layers.29': 'cpu',
 'model.layers.30': 'cpu',
 'model.layers.31': 'cpu',
 'model.norm': 'cpu',
 'lm_head': 'cpu'}
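A long device map is easier to read when summarized per device. The sketch below counts modules per device with `collections.Counter`; `example_map` is a shortened, hypothetical stand-in for `gemma.hf_device_map` so the snippet runs without loading a model.

```python
from collections import Counter

# Shortened, hypothetical stand-in for gemma.hf_device_map.
example_map = {
    "model.embed_tokens": 0,
    "model.layers.0": 0,
    "model.layers.1": 0,
    "model.layers.2": "cpu",
    "model.norm": "cpu",
    "lm_head": "cpu",
}

# Count how many modules ended up on each device.
modules_per_device = Counter(example_map.values())
for device, count in modules_per_device.items():
    print(f"{device}: {count} modules")
```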
Model data type
PyTorch model weights are normally instantiated as torch.float32, which becomes a problem if you want to load a model in a different data type. For example, you'd need twice as much memory: once to load the weights in torch.float32 and again to load them in your desired data type, like torch.float16.
[!WARNING]
Due to how PyTorch is designed, the torch_dtype parameter only supports floating data types.
To avoid wasting memory like this, explicitly set the torch_dtype parameter to the desired data type, or set torch_dtype="auto" to load the weights in the most memory-efficient way (the data type is automatically derived from the model weights).
import torch
from transformers import AutoModelForCausalLM
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype=torch.float16)
from transformers import AutoModelForCausalLM
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype="auto")
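To make the memory difference between data types concrete, the sketch below estimates the footprint of the weights alone as parameter count times bytes per parameter. The helper `weight_footprint_gb` and the round parameter count are assumptions for illustration, not measured values for any particular checkpoint, and the estimate ignores activations and other runtime overhead.

```python
import torch


def weight_footprint_gb(n_params, dtype):
    """Approximate memory for the weights alone, in GB (1 GB = 1e9 bytes)."""
    bytes_per_param = torch.finfo(dtype).bits // 8
    return n_params * bytes_per_param / 1e9


n_params = 7_000_000_000  # hypothetical round number, not a measured count

fp32_gb = weight_footprint_gb(n_params, torch.float32)
fp16_gb = weight_footprint_gb(n_params, torch.float16)
print(f"float32: {fp32_gb:.0f} GB, float16: {fp16_gb:.0f} GB")
```

Under these assumptions, the torch.float16 weights take half the memory of the torch.float32 weights.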