Mariam-Elz committed (verified)
Commit 2e06668 · 1 Parent(s): bb92043

Upload imagedream/ldm/modules/ema.py with huggingface_hub

Files changed (1):
  1. imagedream/ldm/modules/ema.py +86 -86
imagedream/ldm/modules/ema.py CHANGED
@@ -1,86 +1,86 @@
import torch
from torch import nn


class LitEma(nn.Module):
    """Exponential moving average (EMA) of a model's trainable parameters,
    kept as buffers so they are saved and loaded with the module state dict."""

    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.m_name2s_name = {}
        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
        self.register_buffer(
            "num_updates",
            torch.tensor(0, dtype=torch.int)
            if use_num_upates
            else torch.tensor(-1, dtype=torch.int),
        )

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove '.' as that character is not allowed in buffer names
                s_name = name.replace(".", "")
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def reset_num_updates(self):
        del self.num_updates
        self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int))

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            # warm-up: use a smaller effective decay for the first updates
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    # shadow <- shadow - (1 - decay) * (shadow - param)
                    shadow_params[sname].sub_(
                        one_minus_decay * (shadow_params[sname] - m_param[key])
                    )
                else:
                    assert key not in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert key not in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
 
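For context (not part of the committed file), here is a minimal usage sketch of the typical LitEma workflow: update the shadow weights after each optimizer step with a forward call, then swap them in via store/copy_to for evaluation and swap back with restore. The toy model, loss, and training loop below are hypothetical and only illustrate the calls into the module above.

import torch
from torch import nn
from imagedream.ldm.modules.ema import LitEma

# hypothetical toy model used only for illustration
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
ema = LitEma(model, decay=0.9999)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(100):
    x = torch.randn(16, 4)
    loss = model(x).pow(2).mean()      # dummy objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema(model)                         # update the EMA shadow weights

# evaluate with EMA weights, then put the raw weights back
ema.store(model.parameters())          # stash current (raw) parameters
ema.copy_to(model)                     # load EMA weights into the model
with torch.no_grad():
    val_out = model(torch.randn(2, 4))
ema.restore(model.parameters())        # restore the raw parameters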