|
|
|
|
|
|
|
|
|
#pragma once |
|
|
|
#include "namespace_config.h" |
|
#include "philox.cuh" |
|
#include "utils.h" |
|
|
|
namespace FLASH_NAMESPACE { |
|
|
|
// In-register dropout driven by a counter-based Philox generator.
// The random bits for an element are a pure function of (seed, offset, row counter, column
// counter), so the same mask can be regenerated from the same arguments.
// NOTE(review): presumably this is how a backward pass reproduces the forward mask — confirm
// against the callers.
struct Dropout {

    // Philox seed, and this thread's Philox offset. The offset already includes the
    // per-(bid, hid, lane) term added in the constructor.
    const unsigned long long seed, offset;
    // Keep threshold: an element is kept iff its random byte is <= this value. Assuming uniformly
    // distributed random bytes, that is a keep probability of (p_dropout_in_uint8_t + 1) / 256.
    const uint8_t p_dropout_in_uint8_t;

    // Parameters:
    //   seed, offset         - base Philox seed / offset.
    //   p_dropout_in_uint8_t - keep threshold (see the member above).
    //   bid, hid, nheads     - presumably batch index, head index and number of heads — TODO confirm.
    //   tid                  - thread index; only tid % 32 is used.
    // The stored offset is offset + (bid * nheads + hid) * 32 + tid % 32. It is distinct for every
    // (bid, hid, tid % 32) triple as long as 0 <= hid < nheads.
    __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
                                       const uint8_t p_dropout_in_uint8_t,
                                       const int bid, const int hid, const int tid, const int nheads)
        : seed(seed)
        // Evaluate the per-(bid, hid) term in 64 bits. In plain int, (bid * nheads + hid) * 32 is
        // computed before the widening addition and could overflow (undefined behavior).
        // The result is identical whenever the int version did not overflow.
        , offset(offset + (static_cast<unsigned long long>(bid) * nheads + hid) * 32 + tid % 32)
        , p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
    }

    // Applies dropout in place to the register fragment tensor_.
    //
    // Template parameters:
    //   encode_dropout_in_sign_bit - false: dropped elements become T(0).
    //                                true:  dropped elements are negated, so the magnitude is
    //                                       kept and the sign marks the drop.
    // Parameters:
    //   tensor_          - fragment to modify. It is re-viewed through convert_layout_acc_dropout.
    //                      In that view, the loops below touch 8 elements of mode 0, walk mode 1
    //                      as rows, and walk mode 2 (static size, must be even) two columns at a
    //                      time. NOTE(review): the exact shape comes from that helper — confirm.
    //   block_row_start  - Philox row counter for mode-1 index 0. It advances by
    //                      block_row_stride for each subsequent mode-1 index.
    //   block_col_start  - Philox column counter for the first pair of mode-2 columns. It advances
    //                      by 1 per pair.
    // Each (row, col) counter pair yields one Philox call, i.e. 16 random bytes. Those bytes cover
    // the 8 + 8 elements of mode-2 columns n*2 and n*2+1.
    template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
    __forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
                                                  int block_row_start, int block_col_start, int block_row_stride) {
        // A view over tensor_'s own storage (same data pointer), so writes below modify tensor_.
        Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_dropout(tensor_.layout()));
        using T = typename Engine::value_type;
        // Value stored for an element after the keep/drop decision.
        auto encode_dropout = [](bool keep, T val) {
            return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
        };
        // Columns are consumed in pairs: one Philox call per pair.
        static_assert(decltype(size<2>(tensor))::value % 2 == 0);
        // Threshold duplicated into both 16-bit halves, for the packed f16x2 comparison below.
        const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
        const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);

        #pragma unroll
        for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
            uint2 rowcol = make_uint2(block_row_start, block_col_start);
            #pragma unroll
            for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
                // 128 random bits for this counter pair. The (row, col) uint2 is reinterpreted
                // as the 64-bit Philox subsequence.
                uint4 random_uint4 = FLASH_NAMESPACE::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
                // One random byte per element: bytes [0, 8) serve column n*2, bytes [8, 16) serve n*2+1.
                uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);

                if constexpr (!encode_dropout_in_sign_bit
                              && (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
                    // Fast path for 16-bit elements that are simply zeroed when dropped:
                    //  1. Widen each random byte to 16 bits and pack two per 32-bit word.
                    //  2. Compare both halves at once against the duplicated threshold.
                    //     set.le.u32.f16x2 writes 0xFFFF into each half whose comparison holds,
                    //     and 0x0000 otherwise.
                    //  3. AND the mask with the packed pair of elements, which zeroes exactly
                    //     the dropped ones.
                    // Comparing the widened bytes as f16 is equivalent to an unsigned integer
                    // comparison. Every value is <= 0x00FF, i.e. a non-negative f16 subnormal, and
                    // non-negative f16 bit patterns order like integers. No .ftz is used, so the
                    // subnormals are not flushed. Only the random values are compared, so the mask
                    // is also valid for bf16 elements, where an all-zero half is +0 as well.
                    // Per the PTX ISA, f16x2 `set` requires sm_53+.
                    uint16_t rnd_16[16];
                    #pragma unroll
                    for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
                    uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
                    #pragma unroll
                    for (int j = 0; j < 2; j++) {
                        // The column's 8 16-bit elements seen as 4 packed 32-bit words.
                        Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
                        #pragma unroll
                        for (int i = 0; i < 4; i++) {
                            uint32_t mask;
                            asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
                            tensor_uint32(i) &= mask;
                        }
                    }
                } else {
                    // Generic path: per-element keep/drop decision.
                    #pragma unroll
                    for (int j = 0; j < 2; j++) {
                        #pragma unroll
                        for (int i = 0; i < 8; i++) {
                            tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
                        }
                    }
                }
            }
        }
    }

};
|
|
|
} |
|
|