Fix dimension miscalculation for the mask list in the forward pass caused by floating-point precision error
Background
During the forward pass, a dimension mismatch was observed when processing the mask list, which caused subsequent operations to fail.
Root Cause
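Calling sum() on the float mask values and only then applying .long() could produce an inaccurate dimension count. Summing floats accumulates rounding error (especially in low-precision dtypes such as float16, where not every integer is exactly representable). Because .long() truncates toward zero, a sum that lands even slightly below the true integer, such as 2.9999, is converted to 2.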
Modification
Moved the .long() conversion before the sum() call. Each mask value is now converted to an integer first, so the summation is exact integer arithmetic and the precision error cannot occur.
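To illustrate why the order matters, here is a minimal pure-Python sketch that emulates a naive half-precision accumulator using the standard `struct` module's binary16 (`"e"`) format. The helper names (`to_fp16`, `count_sum_then_cast`, `count_cast_then_sum`) and the float16 dtype are assumptions for illustration, not the code changed in this PR. Real PyTorch reduction kernels may accumulate in higher precision internally, so the exact size of the error depends on the dtype and backend.

```python
import struct
from typing import Sequence


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 binary16 value (round-half-to-even)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def count_sum_then_cast(mask: Sequence[float]) -> int:
    """Old order (hypothetical emulation): accumulate in half precision, then truncate."""
    acc = 0.0
    for v in mask:
        # Naive accumulator that rounds to fp16 after every addition.
        acc = to_fp16(acc + to_fp16(v))
    return int(acc)  # int() truncates toward zero, as .long() does for floats


def count_cast_then_sum(mask: Sequence[float]) -> int:
    """New order: convert each element to an integer first, then sum exactly."""
    return sum(int(v) for v in mask)


mask = [1.0] * 3000
print(count_sum_then_cast(mask))  # 2048
print(count_cast_then_sum(mask))  # 3000
```

With 3000 ones, the half-precision accumulator stalls at 2048 because float16 cannot represent 2049 (it rounds back to 2048), while the cast-first version counts every element. In PyTorch terms this roughly corresponds to replacing `mask.sum().long()` with `mask.long().sum()`.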
Impact
This change is confined to the mask-processing logic in the forward pass and does not affect other functionality. The unit tests were run and pass.
Thank you for fixing this!!!