
Ruurd committed · Commit a5ca1bf · 1 Parent(s): f7efac8

input_size?

Files changed (1)
  1. llama_diffusion_model.py +2 -3
llama_diffusion_model.py CHANGED
@@ -213,12 +213,11 @@ class CustomTransformerModel(PreTrainedModel):
         self.llama = get_peft_model(self.llama, lora_config)
         self.llama.print_trainable_parameters()  # Print number of trainable parameters
         self.llama = self.llama.to(torch.float16)
-        self.input_size = 256


     def forward(self, input_ids, labels=None, **kwargs):
         batch_size, seq_length = input_ids.shape
-        assert seq_length == self.input_size, f"Expected input length input_size, got {seq_length}"
+        assert seq_length == 256, f"Expected input length input_size, got {seq_length}"

         with autocast("cuda", dtype=torch.float16):  # ✅ Correct future-proof usage

@@ -233,7 +232,7 @@ class CustomTransformerModel(PreTrainedModel):
         loss = None

         if labels is not None:
-            assert labels.shape == (batch_size, self.input_size), f"Labels shape mismatch: expected (batch, input_size), got {labels.shape}"
+            assert labels.shape == (batch_size, 256), f"Labels shape mismatch: expected (batch, input_size), got {labels.shape}"

             # Compute loss
             loss_fct = torch.nn.CrossEntropyLoss()
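
With self.input_size dropped, forward() now hard-codes a sequence length of 256, so callers must pad or truncate every batch to exactly 256 tokens (and give labels the same shape) before calling the model. A minimal sketch of that preparation, assuming a generic LLaMA-style tokenizer; the checkpoint name, pad-token choice, and the commented-out model call are illustrative assumptions, not part of this commit:

    import torch
    from transformers import AutoTokenizer

    # Assumed checkpoint for illustration only; the Space's actual base model is not shown in this diff.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # LLaMA tokenizers usually ship without a pad token

    # Pad/truncate to exactly 256 tokens so the hard-coded asserts in forward() hold.
    batch = tokenizer(
        ["an example prompt"],
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    input_ids = batch["input_ids"]   # shape: (batch_size, 256)
    labels = input_ids.clone()       # labels must also have shape (batch_size, 256)

    # model = CustomTransformerModel(config)   # construction not shown in this diff
    # outputs = model(input_ids=input_ids, labels=labels)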