Update modularStarEncoder.py
Browse files- modularStarEncoder.py +2 -1
modularStarEncoder.py
CHANGED
@@ -208,9 +208,10 @@ def get_pooling_mask(
|
|
208 |
if DEVICE<0:
|
209 |
DEVICE = "cpu"
|
210 |
ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
|
211 |
-
|
212 |
|
213 |
pooling_mask = (repeated_idx <= ranges).long()
|
|
|
214 |
|
215 |
return pooling_mask
|
216 |
|
|
|
208 |
if DEVICE<0:
|
209 |
DEVICE = "cpu"
|
210 |
ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
|
211 |
+
|
212 |
|
213 |
pooling_mask = (repeated_idx <= ranges).long()
|
214 |
+
pooling_mask.to(DEVICE)
|
215 |
|
216 |
return pooling_mask
|
217 |
|