Duplicated from 3DAIGC/LAM
17cd746
1
2
3
4
5
6
7
import torch.nn as nn


def count_parameters_in_MB(model):
    """Return the number of parameter elements, in millions.

    Despite the ``_MB`` suffix, the result is an element count divided by
    1e6, not a size in megabytes: element byte width is not considered.

    Args:
        model: Either an ``nn.Module`` (its ``parameters()`` are counted) or
            an iterable of objects exposing ``numel()`` (e.g. tensors).

    Returns:
        float: Total ``numel()`` across all items, divided by 1e6.
    """
    # A module exposes its tensors through parameters(); anything else is
    # treated as an iterable of tensor-like objects already.
    params = model.parameters() if isinstance(model, nn.Module) else model
    return sum(p.numel() for p in params) / 1e6