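# Monkey patches for running a half-precision diffusers pipeline (this Space runs on
# ZeroGPU). They stop accelerate from casting model outputs back to float32 and replace
# the diffusers/torch normalization forwards with versions that return results in the
# input dtype, so activations stay in fp16/bf16 throughout the model.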
import torch
import accelerate.accelerator
from diffusers.models.normalization import RMSNorm, LayerNorm, FP32LayerNorm, AdaLayerNormContinuous
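
# accelerate normally wraps the model's forward so that autocast outputs are converted
# back to float32; replacing `convert_outputs_to_fp32` with the identity disables that.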
accelerate.accelerator.convert_outputs_to_fp32 = lambda x: x
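
# LayerNorm: compute via the functional API and cast the result back to the input's
# dtype. torch.nn.LayerNorm stores its shape as `normalized_shape`, while diffusers'
# custom LayerNorm stores it as `dim`, hence the fallback below.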
def LayerNorm_forward(self, x):
    shape = getattr(self, "normalized_shape", None) or self.dim
    return torch.nn.functional.layer_norm(x, shape, self.weight, self.bias, self.eps).to(x)

LayerNorm.forward = LayerNorm_forward
torch.nn.LayerNorm.forward = LayerNorm_forward
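
# FP32LayerNorm: normalize in float32 for numerical stability, then return the result
# in the original (half-precision) dtype.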
def FP32LayerNorm_forward(self, x):
    origin_dtype = x.dtype
    return torch.nn.functional.layer_norm(
        x.float(),
        self.normalized_shape,
        self.weight.float() if self.weight is not None else None,
        self.bias.float() if self.bias is not None else None,
        self.eps,
    ).to(origin_dtype)

FP32LayerNorm.forward = FP32LayerNorm_forward
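
# RMSNorm: accumulate the variance in float32, then scale and (optionally) apply the
# learned weight in the input dtype so the output dtype matches the input.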
def RMSNorm_forward(self, hidden_states):
    input_dtype = hidden_states.dtype
    variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
    if self.weight is None:
        return hidden_states.to(input_dtype)
    return hidden_states.to(input_dtype) * self.weight.to(input_dtype)

RMSNorm.forward = RMSNorm_forward
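
# AdaLayerNormContinuous: project the conditioning embedding to per-channel scale/shift
# and modulate the normalized hidden states, without forcing a dtype conversion of the
# conditioning embedding (recent diffusers versions cast it to x.dtype first).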
def AdaLayerNormContinuous_forward(self, x, conditioning_embedding):
    emb = self.linear(self.silu(conditioning_embedding))
    scale, shift = emb.chunk(2, dim=1)
    x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
    return x

AdaLayerNormContinuous.forward = AdaLayerNormContinuous_forward
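
# Usage sketch (the module name below is hypothetical; it assumes the Space loads a
# standard diffusers pipeline): importing this file once at startup, before the model
# is prepared by accelerate, is enough for the patched forwards to take effect, e.g.
#
#     import dtype_patches  # noqa: F401  -- applies the patches as an import side effect
#     pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)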