Spaces: Running on Zero
input_size?
llama_diffusion_model.py  +2 -3

llama_diffusion_model.py  CHANGED
@@ -213,12 +213,11 @@ class CustomTransformerModel(PreTrainedModel):
         self.llama = get_peft_model(self.llama, lora_config)
         self.llama.print_trainable_parameters()  # Print number of trainable parameters
         self.llama = self.llama.to(torch.float16)
-        self.input_size = 256


     def forward(self, input_ids, labels=None, **kwargs):
         batch_size, seq_length = input_ids.shape
-        assert seq_length ==
+        assert seq_length == 256, f"Expected input length input_size, got {seq_length}"

         with autocast("cuda", dtype=torch.float16):  # ✅ Correct future-proof usage

@@ -233,7 +232,7 @@ class CustomTransformerModel(PreTrainedModel):
         loss = None

         if labels is not None:
-            assert labels.shape == (batch_size,
+            assert labels.shape == (batch_size, 256), f"Labels shape mismatch: expected (batch, input_size), got {labels.shape}"

             # Compute loss
             loss_fct = torch.nn.CrossEntropyLoss()
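For context, the patch drops the removed self.input_size attribute and hard-codes the expected sequence length of 256 directly into the two assertions in forward. Below is a minimal, self-contained sketch of that assertion pattern in isolation; the names EXPECTED_SEQ_LEN and check_shapes are introduced here purely for illustration and are not part of the repository, and the surrounding model code (LLaMA backbone, LoRA wrapping, autocast, loss computation) is omitted.

# Minimal sketch of the shape checks added in this commit, shown outside the model.
# EXPECTED_SEQ_LEN mirrors the hard-coded 256 that replaced self.input_size;
# the real checks live inside CustomTransformerModel.forward.
from typing import Optional

import torch

EXPECTED_SEQ_LEN = 256  # hypothetical constant standing in for the hard-coded 256


def check_shapes(input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None) -> None:
    """Reproduce the two assertions from the patched forward() in isolation."""
    batch_size, seq_length = input_ids.shape

    # Mirrors: assert seq_length == 256, ...
    assert seq_length == EXPECTED_SEQ_LEN, (
        f"Expected input length {EXPECTED_SEQ_LEN}, got {seq_length}"
    )

    if labels is not None:
        # Mirrors: assert labels.shape == (batch_size, 256), ...
        assert labels.shape == (batch_size, EXPECTED_SEQ_LEN), (
            f"Labels shape mismatch: expected ({batch_size}, {EXPECTED_SEQ_LEN}), "
            f"got {tuple(labels.shape)}"
        )


# Example usage: a batch of 4 sequences of length 256 passes; length 128 raises.
if __name__ == "__main__":
    ok = torch.randint(0, 32000, (4, 256))
    check_shapes(ok, labels=ok.clone())
    try:
        check_shapes(torch.randint(0, 32000, (4, 128)))
    except AssertionError as err:
        print("Caught expected failure:", err)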