
Ruurd committed · verified
Commit 8851563 · 1 Parent(s): 238c8f8

Make attention mask float

Files changed (1):
  llama_diffusion_model.py (+2, -0)
llama_diffusion_model.py CHANGED
@@ -133,6 +133,8 @@ class CustomTransformerModel(PreTrainedModel):
             raise ValueError(f"Unknown masking type: {self.config.masking_type}")
 
         attention_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len).clone()
+        attention_mask = attention_mask.to(dtype=torch.float32)  # required for SDPA and Flash attention
+
 
         with autocast("cuda", dtype=torch.float16):
             outputs = self.llama(
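
For context: torch.nn.functional.scaled_dot_product_attention accepts an attn_mask that is either boolean (True = attend) or floating point, in which case it is added to the attention scores before the softmax. Below is a minimal standalone sketch of the two equivalent forms; the shapes and the causal base_mask are illustrative and not taken from this repo, and whether the repo's base_mask is already in additive (0 / -inf) form is not visible in the diff.

import torch
import torch.nn.functional as F

# Illustrative shapes only.
batch_size, n_heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(batch_size, n_heads, seq_len, head_dim)
k = torch.randn(batch_size, n_heads, seq_len, head_dim)
v = torch.randn(batch_size, n_heads, seq_len, head_dim)

# Boolean keep-mask: True = attend, False = block (causal here as an example).
base_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
bool_mask = base_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, seq_len, seq_len)

# Equivalent float mask: SDPA *adds* a float mask to the attention scores,
# so it must be 0.0 where attending and -inf where blocked. Note that a
# plain 0/1 mask cast to float32 would mean something different here.
float_mask = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=torch.float32)
float_mask.masked_fill_(~bool_mask, float("-inf"))

out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)
out_float = F.scaled_dot_product_attention(q, k, v, attn_mask=float_mask)
assert torch.allclose(out_bool, out_float, atol=1e-6)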