File size: 1,305 Bytes
4a40efc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
# Resolve an optional fused LayerNorm implementation at import time.
# Preference order: NVIDIA apex first, then the xformers Triton kernel.
# If neither package can be imported, FusedLayerNorm is bound to None so
# that callers can detect the absence and skip the replacement.
try:
    from apex.normalization import FusedLayerNorm
except ImportError as e:
    try:
        # apex is unavailable; fall back to the xformers implementation.
        from xformers.triton import FusedLayerNorm
    except ImportError as e:
        # No fused implementation available; None acts as the sentinel.
        FusedLayerNorm = None


def replace_all_layernorms(model):
    """Swap every ``torch.nn.LayerNorm`` submodule of *model* for ``FusedLayerNorm``.

    The module tree is walked recursively and matching children are replaced
    in place. Each replacement is built with the same ``normalized_shape``,
    ``eps`` and affine flag as the module it replaces, is moved to the same
    device/dtype, and receives a copy of the existing affine parameters so
    that calling this on an already-loaded model does not reset them.

    If no fused implementation could be imported (``FusedLayerNorm is None``)
    a warning is printed and *model* is returned untouched. *model* itself is
    never replaced, only its descendants.

    Args:
        model: Root ``torch.nn.Module`` whose descendants are inspected.

    Returns:
        The same *model* object, modified in place.
    """
    if FusedLayerNorm is None:
        print("WARNING: apex.normalization & xformers.triton.FusedLayerNorm is not found, "
              "skip using FusedLayerNorm")
        return model

    import inspect

    # apex declares FusedLayerNorm(normalized_shape, eps, elementwise_affine)
    # whereas the xformers Triton variant declares
    # FusedLayerNorm(normalized_shape, affine, eps). Passing eps/affine
    # positionally would therefore swap them for xformers, so pick the right
    # keyword name from the constructor signature and pass both by keyword.
    ctor_params = inspect.signature(FusedLayerNorm).parameters
    if "elementwise_affine" in ctor_params:
        affine_kwarg = "elementwise_affine"
    else:
        affine_kwarg = "affine"

    def _swap(parent):
        for child_name, child in parent.named_children():
            if not isinstance(child, torch.nn.LayerNorm):
                _swap(child)
                continue
            fused = FusedLayerNorm(
                child.normalized_shape,
                eps=child.eps,
                **{affine_kwarg: child.elementwise_affine})
            if child.weight is not None:
                # Keep the replacement on the same device/dtype and carry the
                # learned affine parameters over. strict=False tolerates a
                # source LayerNorm that was built without a bias term.
                fused.to(device=child.weight.device, dtype=child.weight.dtype)
                fused.load_state_dict(child.state_dict(), strict=False)
            setattr(parent, child_name, fused)

    _swap(model)
    return model


def replace_all_groupnorms(model):
    """Swap every ``torch.nn.GroupNorm`` submodule of *model* for apex's ``GroupNorm``.

    The module tree is walked recursively and matching children are replaced
    in place. Each replacement is built with the same ``num_groups``,
    ``num_channels``, ``eps`` and ``affine`` flag as the module it replaces,
    is moved to the same device/dtype, and receives a copy of the existing
    affine parameters so that calling this on an already-loaded model does
    not reset them.

    If ``apex.contrib.group_norm`` cannot be imported a warning is printed
    and *model* is returned untouched. *model* itself is never replaced, only
    its descendants.

    Args:
        model: Root ``torch.nn.Module`` whose descendants are inspected.

    Returns:
        The same *model* object, modified in place.
    """
    try:
        from apex.contrib.group_norm import GroupNorm
    except ImportError:
        print("WARNING: apex.contrib.group_norm is not found, skip using apex groupnorm")
        return model

    def _swap(parent):
        for child_name, child in parent.named_children():
            if not isinstance(child, torch.nn.GroupNorm):
                _swap(child)
                continue
            fused = GroupNorm(
                child.num_groups, child.num_channels,
                eps=child.eps, affine=child.affine)
            if child.affine:
                # Keep the replacement on the same device/dtype and carry the
                # learned affine parameters over instead of re-initialising.
                fused.to(device=child.weight.device, dtype=child.weight.dtype)
                fused.load_state_dict(child.state_dict(), strict=False)
            setattr(parent, child_name, fused)

    _swap(model)
    return model