# Adapted from https://github.com/ddlBoJack/SLAM-LLM/blob/main/src/slam_llm/models/encoder.py

import deepspeed
import torch
import torch.nn as nn

from egogpt.utils import rank0_print

from .model import ModelDimensions, Whisper


def load_zero_partitions(
    model,
    state_dict,
    is_deepspeed_zero3_enabled,
    pretrained_model_path,
    ignore_mismatched_sizes=False,
):
    """
    adept from pytorch lightning and transformers
    with deepspeed.zero.Init():
        model = MyModel()
    state_dict = torch.load(model_path, map_location="cpu")
    load_zero_partitions(model, prefix="")
    """

    # Compare the checkpoint's keys against the keys the model expects so that
    # missing and unexpected weights can be reported after loading.
    model_state_dict = model.state_dict()
    expected_keys = list(model_state_dict.keys())
    loaded_keys = list(state_dict.keys())
    missing_keys = list(set(expected_keys) - set(loaded_keys))
    unexpected_keys = list(set(loaded_keys) - set(expected_keys))

    # mismatched_keys holds (key, checkpoint_shape, model_shape) tuples for weights
    # in the checkpoint whose shape does not match the corresponding model weight.
    mismatched_keys = []
    if ignore_mismatched_sizes:
        for checkpoint_key in loaded_keys:
            model_key = checkpoint_key

            if (
                model_key in model_state_dict
                and state_dict[checkpoint_key].shape
                != model_state_dict[model_key].shape
            ):
                mismatched_keys.append(
                    (
                        checkpoint_key,
                        state_dict[checkpoint_key].shape,
                        model_state_dict[model_key].shape,
                    )
                )
                del state_dict[checkpoint_key]
    # copy state_dict so _load_from_state_dict can modify it
    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module, prefix=""):
        local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
        # Positional args for Module._load_from_state_dict: (state_dict, prefix,
        # local_metadata, strict, missing_keys, unexpected_keys, error_msgs).
        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
        if is_deepspeed_zero3_enabled:
            # because zero3 puts placeholders in model params, this context
            # manager gathers (unpartitions) the params of the current layer, then loads from
            # the state dict and then re-partitions them again
            with deepspeed.zero.GatheredParameters(
                list(module.parameters(recurse=False)), modifier_rank=0
            ):
                # Only rank 0 copies the weights in; `modifier_rank=0` broadcasts
                # its modifications to all ranks when the context exits.
                if torch.distributed.get_rank() == 0:
                    module._load_from_state_dict(*args)
        else:
            module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    # Make sure we are able to load base models as well as derived models (with heads)
    start_prefix = ""
    model_to_load = model
    load(model_to_load, prefix=start_prefix)
    del state_dict
    if len(error_msgs) > 0:
        error_msg = "\n\t".join(error_msgs)
        if "size mismatch" in error_msg:
            error_msg += "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
        raise RuntimeError(
            f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}"
        )
    if len(unexpected_keys) > 0:
        rank0_print(
            f"Some weights of the model checkpoint at {pretrained_model_path} were not used when"
            f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
            f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
            " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
            " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
            f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
            " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
        )
    else:
        rank0_print(
            f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n"
        )
    if len(missing_keys) > 0:
        rank0_print(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_path} and are newly initialized: {missing_keys}\nYou should probably"
            " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
        )
    elif len(mismatched_keys) == 0:
        rank0_print(
            f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
            f" {pretrained_model_path}.\nIf your task is similar to the task the model of the checkpoint"
            f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
            " training."
        )
    if len(mismatched_keys) > 0:
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        rank0_print(
            f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
            f" {pretrained_model_path} and are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
            " to use it for predictions and inference."
        )
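
# The ZeRO-3 mechanism used above, shown in isolation (a sketch, not part of the
# loading path; assumes an initialized torch.distributed process group under
# DeepSpeed ZeRO stage 3, and hypothetical `layer` / `full_weight` objects):
#
#     with deepspeed.zero.GatheredParameters(
#         list(layer.parameters(recurse=False)), modifier_rank=0
#     ):
#         # Inside the context every rank temporarily sees the full weights.
#         if torch.distributed.get_rank() == 0:
#             layer.weight.data.copy_(full_weight)
#     # On exit the parameters are re-partitioned; rank 0's edits are broadcast.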


class WhisperWrappedEncoder(nn.Module):
    def __init__(self, config, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.speech_encoder_name = config.speech_encoder

        if not delay_load:
            rank0_print(f"Loading speech encoder: {self.speech_encoder_name}")
            self.load_model(config)

    def load_model(self, model_config):
        if self.is_loaded:
            rank0_print(
                "{} is already loaded, `load_model` called again, skipping.".format(
                    self.speech_encoder_name
                )
            )
            return

        def replace_layer_norm(module):
            # whisper.model.LayerNorm upcasts its input to float32 before
            # normalizing and casts back afterwards; swap in the standard
            # nn.LayerNorm (copying the learned weights) for compatibility with
            # mixed-precision DeepSpeed training.
            from whisper.model import LayerNorm

            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    old_params = child.state_dict()
                    new_layer_norm = nn.LayerNorm(
                        child.normalized_shape,
                        eps=child.eps,
                        elementwise_affine=child.elementwise_affine,
                    )
                    new_layer_norm.load_state_dict(old_params)
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        # Load the checkpoint manually rather than via `whisper.load_model` so the
        # weights can be loaded into ZeRO-3 partitioned parameters.
        checkpoint = torch.load(self.speech_encoder_name, map_location="cpu")
        dims = ModelDimensions(**checkpoint["dims"])
        model = Whisper(dims)
        # ZeRO-3 is assumed to be enabled here; set to False when training without it.
        deepspeed3_enabled = True
        load_zero_partitions(
            model,
            checkpoint["model_state_dict"],
            deepspeed3_enabled,
            self.speech_encoder_name,
        )
        self.encoder = model.encoder
        replace_layer_norm(self.encoder)
        self.encoder.requires_grad_(False)  # keep the speech encoder frozen

        self.is_loaded = True

    def forward(self, audio):
        return self.encoder(audio)
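
# A minimal usage sketch (the checkpoint path is hypothetical; `config` only needs
# a `speech_encoder` attribute pointing at an OpenAI-Whisper-style .pt file that
# contains "dims" and "model_state_dict"; `load_model` assumes a DeepSpeed ZeRO-3
# run, so torch.distributed must already be initialized):
#
#     from types import SimpleNamespace
#
#     config = SimpleNamespace(speech_encoder="/path/to/whisper.pt")
#     encoder = WhisperWrappedEncoder(config)
#     # 30 s of audio -> (batch, n_mels, 3000) log-mel frames; n_mels is 80 for
#     # most Whisper checkpoints and 128 for large-v3.
#     mel = torch.randn(1, 80, 3000)
#     with torch.no_grad():
#         features = encoder(mel)  # -> (1, 1500, d_model) encoder states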