Feature Extraction
Transformers
Safetensors
ModularStarEncoder
custom_code
andreagurioli1995 commited on
Commit
739f9ab
·
verified ·
1 Parent(s): 8512abe

Update modularStarEncoder.py

Browse files
Files changed (1) hide show
  1. modularStarEncoder.py +2 -1
modularStarEncoder.py CHANGED
@@ -208,9 +208,10 @@ def get_pooling_mask(
208
  if DEVICE<0:
209
  DEVICE = "cpu"
210
  ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
211
- ranges.to(DEVICE)
212
 
213
  pooling_mask = (repeated_idx <= ranges).long()
 
214
 
215
  return pooling_mask
216
 
 
208
  if DEVICE<0:
209
  DEVICE = "cpu"
210
  ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
211
+
212
 
213
  pooling_mask = (repeated_idx <= ranges).long()
214
+ pooling_mask.to(DEVICE)
215
 
216
  return pooling_mask
217