|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import copy
import math
import warnings

import torch
import torch.nn as nn

from ...models.utils import is_parallel
|
|
|
__all__ = ["EMA"] |
|
|
|
|
|
def update_ema(ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float) -> None:
    """Pull the floating-point state of ``ema`` towards ``new_state_dict``, in place.

    Every floating-point tensor ``s`` in ``ema.state_dict()`` is overwritten with
    ``s - (1 - decay) * (s - new)``, which equals ``decay * s + (1 - decay) * new``
    up to rounding. Entries whose dtype is not floating point (for example integer
    counters) are skipped and keep their current value.

    Args:
        ema: Module whose state tensors are modified in place.
        new_state_dict: Source tensors. It must contain every floating-point key of
            ``ema``'s state dict; a missing key raises ``KeyError``.
        decay: Fraction of the current ``ema`` value to keep on this update.
    """
    step_size = 1.0 - decay  # weight given to the incoming values
    for name, shadow in ema.state_dict().items():
        if not shadow.dtype.is_floating_point:
            continue
        target = new_state_dict[name].detach()
        # In-place subtraction writes through to the module's own parameter/buffer storage.
        shadow -= step_size * (shadow - target)
|
|
|
|
|
class EMA:
    """Exponential moving average (EMA) of a model's weights.

    A frozen deep copy of the model (unwrapped via ``is_parallel``) is kept in
    ``self.shadows``. Each call to :meth:`step` blends the live model's
    floating-point state into that copy with ``update_ema``. The effective decay
    ramps up from 0 towards ``decay`` as
    ``decay * (1 - exp(-global_step / warmup_steps))``.

    Args:
        model: Model to track. If ``is_parallel(model)`` is true, ``model.module``
            is copied instead.
        decay: Target decay, i.e. the fraction of the shadow value kept per update.
        warmup_steps: Time constant of the decay ramp-up. A value ``<= 0``
            disables the warm-up, so ``decay`` is used from the first step.
    """

    def __init__(self, model: nn.Module, decay: float, warmup_steps: int = 2000):
        self.shadows = copy.deepcopy(self._unwrap(model)).eval()
        self.decay = decay
        self.warmup_steps = warmup_steps

        # The shadow copy is only ever changed in place by `step`, never by backprop.
        self.shadows.requires_grad_(False)

    @staticmethod
    def _unwrap(model: nn.Module) -> nn.Module:
        """Return ``model.module`` when ``is_parallel(model)`` is true, else ``model``."""
        return model.module if is_parallel(model) else model

    def _decay_at(self, global_step: int) -> float:
        """Return the effective decay to apply at ``global_step``."""
        if self.warmup_steps <= 0:
            # No warm-up requested. This also avoids dividing by zero, and a
            # negative time constant would otherwise make the decay blow up.
            return self.decay
        return self.decay * (1 - math.exp(-global_step / self.warmup_steps))

    def step(self, model: nn.Module, global_step: int) -> None:
        """Blend ``model``'s current weights into the shadow copy.

        Args:
            model: Live model; it is unwrapped the same way as in ``__init__``.
            global_step: Current training step, used for the decay warm-up.
        """
        with torch.no_grad():
            msd = self._unwrap(model).state_dict()
            update_ema(self.shadows, msd, self._decay_at(global_step))

    def state_dict(self) -> dict[float, dict[str, torch.Tensor]]:
        """Return the shadow weights in a dict keyed by this EMA's decay value."""
        return {self.decay: self.shadows.state_dict()}

    def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None:
        """Load shadow weights saved under this EMA's decay value.

        Entries stored under other decay values are ignored. If no entry matches
        ``self.decay``, the shadow weights are left unchanged and a
        ``RuntimeWarning`` is emitted instead of silently skipping the load.
        """
        if self.decay not in state_dict:
            warnings.warn(
                f"EMA state_dict has no entry for decay={self.decay!r} "
                f"(available: {list(state_dict)!r}); shadow weights were not loaded.",
                RuntimeWarning,
                stacklevel=2,
            )
            return
        self.shadows.load_state_dict(state_dict[self.decay])
|
|