GitOrigin-RevId: 7e78bdae91
@@ -78,6 +78,72 @@ struct ArgmxxOp {
     const wtype INIT;
 };

+template <bool is_max>
+struct ArgmxxOp<dt_float32, is_max> {
+    using stype_ = dt_float32;
+    struct wtype {
+        stype_ key;
+        dt_int32 val;
+        MEGDNN_HOST MEGDNN_DEVICE wtype() {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(stype_ key, dt_int32 val)
+                : key(key), val(val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(wtype& rhs) : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(volatile wtype& rhs)
+                : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(const wtype& rhs)
+                : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(const volatile wtype& rhs)
+                : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE volatile wtype& operator=(const wtype& rhs) volatile {
+            this->key = rhs.key;
+            this->val = rhs.val;
+            return *this;
+        }
+    };
+    MEGDNN_HOST MEGDNN_DEVICE
+    ArgmxxOp(stype_* src, dt_int32* dst, uint32_t A, uint32_t B, uint32_t C)
+            : src(src),
+              dst(dst),
+              A(A),
+              B(B),
+              C(C),
+              INIT(wtype(
+                      is_max ? DTypeTrait<stype_>::min() : DTypeTrait<stype_>::max(),
+                      0)) {}
+    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
+        wtype res;
+        res.key = src[idx];
+        res.val = idx / C % B;
+        return res;
+    }
+    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
+        dst[idx] = val.val;
+    }
+    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
+#if defined(__CUDA_ARCH__)
+        if (isnan(lhs.key))
+#else
+        if (std::isnan(lhs.key))
+#endif
+            return lhs;
+        if (is_max) {
+            if (lhs.key > rhs.key)
+                return lhs;
+            else
+                return rhs;
+        } else {
+            if (lhs.key < rhs.key)
+                return lhs;
+            else
+                return rhs;
+        }
+    }
+    stype_* src;
+    dt_int32* dst;
+    uint32_t A, B, C;
+    const wtype INIT;
+};

 } // namespace argmxx
 } // namespace megdnn

 // vim: syntax=cpp.doxygen
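
Note that `apply` only tests `lhs` for NaN, yet a NaN arriving on the right also propagates: every ordered comparison against NaN is false, so the `else` branch returns the NaN-carrying `rhs`. A minimal host-side sketch of that semantics (plain C++, not MegDNN code; `KV` and `argmax_apply` are illustrative names):

    #include <cassert>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Key/index pair, standing in for ArgmxxOp::wtype.
    struct KV {
        float key;
        int val;
    };

    // Same logic as the is_max branch of apply(): an accumulated NaN is
    // returned directly; a NaN rhs wins because "lhs.key > NaN" is false.
    KV argmax_apply(KV lhs, KV rhs) {
        if (std::isnan(lhs.key))
            return lhs;
        return lhs.key > rhs.key ? lhs : rhs;
    }

    int main() {
        std::vector<float> xs = {1.f, 2.f, NAN, 4.f};
        KV acc{-INFINITY, 0};  // mirrors INIT built from DTypeTrait<float>::min()
        for (std::size_t i = 0; i < xs.size(); ++i)
            acc = argmax_apply(acc, KV{xs[i], static_cast<int>(i)});
        assert(std::isnan(acc.key) && acc.val == 2);  // argmax reports the NaN's index
    }
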
@@ -119,6 +119,28 @@ struct MinOp {
             : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {}
 };

+template <typename src_ctype, typename dst_ctype>
+struct MinOp<src_ctype, dst_ctype, dt_float32> {
+    typedef dt_float32 wtype;
+    const wtype INIT;
+
+    src_ctype* src;
+    dst_ctype* dst;
+    const size_t B;
+
+    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; }
+    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }
+    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
+#if defined(__CUDA_ARCH__)
+        return (isnan(lhs) || lhs < rhs) ? lhs : rhs;
+#else
+        return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs;
+#endif
+    }
+    MEGDNN_HOST MEGDNN_DEVICE MinOp(src_ctype* src, dst_ctype* dst, size_t B)
+            : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {}
+};
+
 template <typename src_ctype, typename dst_ctype, typename wtype_>
 struct MaxOp {
     typedef wtype_ wtype;
@@ -141,6 +163,28 @@ struct MaxOp {
             : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
 };

+template <typename src_ctype, typename dst_ctype>
+struct MaxOp<src_ctype, dst_ctype, dt_float32> {
+    typedef dt_float32 wtype;
+    const wtype INIT;
+
+    src_ctype* src;
+    dst_ctype* dst;
+    const size_t B;
+
+    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; }
+    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }
+    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
+#if defined(__CUDA_ARCH__)
+        return (isnan(lhs) || lhs > rhs) ? lhs : rhs;
+#else
+        return (std::isnan(lhs) || lhs > rhs) ? lhs : rhs;
+#endif
+    }
+    MEGDNN_HOST MEGDNN_DEVICE MaxOp(src_ctype* src, dst_ctype* dst, size_t B)
+            : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
+};
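
For the plain min/max reductions the same trick fits in one expression. A host-only sketch of the specialized `apply`, with standard C++ in place of the MEGDNN macros:

    #include <cassert>
    #include <cmath>
    #include <limits>
    #include <vector>

    // Same expression as MinOp<..., dt_float32>::apply: a NaN lhs is kept
    // explicitly, and a NaN rhs is kept because "lhs < NaN" is false.
    float min_apply(float lhs, float rhs) {
        return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs;
    }

    int main() {
        std::vector<float> xs = {3.f, 1.f, NAN, 2.f};
        float acc = std::numeric_limits<float>::max();  // mirrors INIT
        for (float x : xs)
            acc = min_apply(acc, x);
        assert(std::isnan(acc));  // a single NaN poisons the whole reduction
    }

This deliberately differs from `fminf`, which prefers the non-NaN operand when exactly one argument is NaN.
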
+
 template <typename src_ctype, typename dst_ctype, typename wtype_>
 struct CheckNonFiniteOp {
     typedef wtype_ wtype;
@@ -30,6 +30,10 @@ struct FakeQuantKernOp {
     __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point) {
         ctype x = round(input[idx] / scale) + zero_point;
+        if (isnan(x)) {
+            output[idx] = NAN;
+            return;
+        }
         x = fmaxf(fminf(x, qmax), qmin);
         output[idx] = (x - zero_point) * scale;
     }
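
Without the early return, `fminf(x, qmax)` maps a NaN `x` to `qmax` (IEEE `fmin` prefers the non-NaN operand), so NaN inputs used to come out as the clamped boundary value. A host-side mirror of the new behaviour (the `fake_quant` helper and its signed 8-bit defaults are hypothetical, for illustration only):

    #include <cassert>
    #include <cmath>

    float fake_quant(float input, float scale, float zero_point,
                     float qmin = -128.f, float qmax = 127.f) {
        float x = std::round(input / scale) + zero_point;
        if (std::isnan(x))
            return NAN;  // previously fminf/fmaxf turned NaN into a finite bound
        x = std::fmax(std::fmin(x, qmax), qmin);
        return (x - zero_point) * scale;
    }

    int main() {
        assert(std::isnan(fake_quant(NAN, 0.1f, 0.f)));               // NaN in, NaN out
        assert(std::fabs(fake_quant(1.f, 0.1f, 0.f) - 1.f) < 1e-6f);  // regular path
    }
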
@@ -54,7 +58,7 @@ struct FakeQuantBwdKernOp {
     __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point) {
         ctype x = round(input[idx] / scale) + zero_point;
-        grad[idx] = x <= qmax && x >= qmin ? diff[idx] : 0.0;
+        grad[idx] = isnan(x) ? NAN : x <= qmax && x >= qmin ? diff[idx] : 0.0;
     }

 #if MEGDNN_CC_HOST
@@ -77,6 +81,10 @@ struct FakeQuantKernOpNonContig {
     __device__ void operator()(
             uint32_t, ctype& output, ctype input, ctype scale, ctype zero_point) {
         ctype x = round(input / scale) + zero_point;
+        if (isnan(x)) {
+            output = NAN;
+            return;
+        }
         x = fmaxf(fminf(x, qmax), qmin);
         output = (x - zero_point) * scale;
     }
@@ -96,7 +104,7 @@ struct FakeQuantBwdKernOpNonContig {
             uint32_t, ctype& grad, ctype diff, ctype input, ctype scale,
             ctype zero_point) {
         ctype x = round(input / scale) + zero_point;
-        grad = x <= qmax && x >= qmin ? diff : 0.0;
+        grad = isnan(x) ? NAN : x <= qmax && x >= qmin ? diff : 0.0;
     }

 #if MEGDNN_CC_HOST
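
The backward one-liner relies on `?:` being right-associative. A sketch with the implied parentheses made explicit (`fake_quant_grad` is a made-up name, not the kernel's API):

    #include <cassert>
    #include <cmath>

    // Parses the same way as the kernel line: NaN first, then the range check.
    float fake_quant_grad(float x, float diff, float qmin = -128.f, float qmax = 127.f) {
        return std::isnan(x) ? NAN : ((x <= qmax && x >= qmin) ? diff : 0.0f);
    }

    int main() {
        assert(std::isnan(fake_quant_grad(NAN, 1.f)));  // NaN poisons the gradient
        assert(fake_quant_grad(10.f, 1.f) == 1.f);      // inside range: pass diff
        assert(fake_quant_grad(300.f, 1.f) == 0.f);     // clipped: zero gradient
    }
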
@@ -26,14 +26,18 @@ struct traits;

 template <>
 struct traits<true> {
     static const float init;

-    static bool better_than(float lhs, float rhs) { return lhs > rhs; }
+    static bool better_than(float lhs, float rhs) {
+        return std::isnan(lhs) ? true : lhs > rhs;
+    }
 };

 const float traits<true>::init = std::numeric_limits<float>::lowest();

 template <>
 struct traits<false> {
     static const float init;

-    static float better_than(float lhs, float rhs) { return lhs < rhs; }
+    static bool better_than(float lhs, float rhs) {
+        return std::isnan(lhs) ? true : lhs < rhs;
+    }
 };

 const float traits<false>::init = std::numeric_limits<float>::max();
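
A sketch of the scan these traits drive (the `find_best` driver is hypothetical, not this file's code): once a NaN becomes the current best, `better_than(candidate, best)` is false for every finite candidate, so the first NaN's index is what gets reported.

    #include <cassert>
    #include <cmath>
    #include <cstddef>

    // Linear argmax/argmin scan asking better_than(candidate, best).
    template <bool is_max>
    std::size_t find_best(const float* data, std::size_t n) {
        std::size_t best = 0;
        for (std::size_t i = 1; i < n; ++i) {
            float lhs = data[i], rhs = data[best];
            bool better = std::isnan(lhs) ? true : (is_max ? lhs > rhs : lhs < rhs);
            if (better)
                best = i;
        }
        return best;
    }

    int main() {
        const float xs[] = {1.f, NAN, 7.f, 2.f};
        assert(find_best<true>(xs, 4) == 1);   // argmax lands on the NaN
        assert(find_best<false>(xs, 4) == 1);  // argmin does too
    }

One subtlety: `better_than(NaN, NaN)` is also true, so with several NaNs a later one displaces an earlier one and the last NaN's index wins.
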
@@ -73,25 +73,35 @@ const ctype Trait<Mode::PRODUCT, ctype>::INIT = ctype(1);

 template <typename ctype>
 struct Trait<Mode::MIN, ctype> {
     static const ctype INIT;

     static ctype apply(ctype x, ctype y) { return x < y ? x : y; }
     static ctype visit(ctype x) { return x; }
     static ctype write(ctype x, size_t) { return x; }
 };
 template <typename ctype>
 const ctype Trait<Mode::MIN, ctype>::INIT = DTypeTrait<ctype>::max();

+template <>
+struct Trait<Mode::MIN, dt_float32> {
+    using ctype = dt_float32;
+    static const ctype INIT;
+
+    static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x < y) ? x : y; }
+    static ctype visit(ctype x) { return x; }
+    static ctype write(ctype x, size_t) { return x; }
+};
+const dt_float32 Trait<Mode::MIN, dt_float32>::INIT = DTypeTrait<dt_float32>::max();
+
 template <typename ctype>
 struct Trait<Mode::MAX, ctype> {
     static const ctype INIT;

     static ctype apply(ctype x, ctype y) { return x > y ? x : y; }
     static ctype visit(ctype x) { return x; }
     static ctype write(ctype x, size_t) { return x; }
 };
 template <typename ctype>
 const ctype Trait<Mode::MAX, ctype>::INIT = DTypeTrait<ctype>::min();

+template <>
+struct Trait<Mode::MAX, dt_float32> {
+    using ctype = dt_float32;
+    static const ctype INIT;
+
+    static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x > y) ? x : y; }
+    static ctype visit(ctype x) { return x; }
+    static ctype write(ctype x, size_t) { return x; }
+};
+const dt_float32 Trait<Mode::MAX, dt_float32>::INIT = DTypeTrait<dt_float32>::min();
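
Why a specialization is needed at all: with the generic `x < y ? x : y`, a NaN that has entered the accumulator is discarded on the very next step, because `NaN < y` is false and `y` wins. A small standalone demonstration (host-only sketch, not MegDNN code):

    #include <cassert>
    #include <cmath>
    #include <limits>

    float min_generic(float x, float y) { return x < y ? x : y; }                       // old Trait
    float min_nan_aware(float x, float y) { return (std::isnan(x) || x < y) ? x : y; }  // specialization

    int main() {
        const float xs[] = {2.f, NAN, 5.f};
        float a = std::numeric_limits<float>::max();  // INIT for Mode::MIN
        float b = a;
        for (float x : xs) {
            a = min_generic(a, x);
            b = min_nan_aware(b, x);
        }
        assert(a == 5.f);       // NaN entered the accumulator, then was silently dropped
        assert(std::isnan(b));  // the specialization keeps it to the end
    }
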
+
 template <Mode mode, typename ctype>
 void reduce_fwd(
@@ -21,7 +21,9 @@ using namespace fake_quant;
 TEST_F(CUDA, FAKE_QUANT) {
     std::vector<TestArg> args = get_args();
     auto dtype = dtype::Float32();
-    std::unique_ptr<RNG> rng;
+    UniformFloatRNG rng(-1.0f, 1.0f);
+    const auto nan = std::numeric_limits<float>::quiet_NaN();
+    UniformFloatWithValueRNG rng1 = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);

     for (auto&& arg : args) {
         auto param = arg.param;
@@ -35,6 +37,17 @@ TEST_F(CUDA, FAKE_QUANT) {
                 .set_dtype(2, dtype)
                 .set_dtype(3, dtype)
                 .execs(TensorShapeArray{ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param)
+                .set_dtype(0, dtype)
+                .set_dtype(1, dtype)
+                .set_dtype(2, dtype)
+                .set_dtype(3, dtype)
+                .execs(TensorShapeArray{ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }

     // test noncontiguous layout
     for (auto&& arg : args) {
@@ -53,12 +66,25 @@ TEST_F(CUDA, FAKE_QUANT) {
                 {scale_shape, dtype::Float32()},
                 {zeropoint_shape, dtype::Float32()},
                 ilayout});
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param).execl(
+                {ilayout,
+                 {scale_shape, dtype::Float32()},
+                 {zeropoint_shape, dtype::Float32()},
+                 ilayout});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
 }

 TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
     std::vector<TestArg> args = get_args();
     auto dtype = dtype::Float32();
     UniformFloatRNG rng(-1.0f, 1.0f);
+    const auto nan = std::numeric_limits<float>::quiet_NaN();
+    UniformFloatWithValueRNG rng1 = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);

     for (auto&& arg : args) {
         auto param = arg.param;
@@ -74,6 +100,19 @@ TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
                 .set_dtype(4, dtype)
                 .execs(TensorShapeArray{
                         ishape, ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param)
+                .set_dtype(0, dtype)
+                .set_dtype(1, dtype)
+                .set_dtype(2, dtype)
+                .set_dtype(3, dtype)
+                .set_dtype(4, dtype)
+                .execs(TensorShapeArray{
+                        ishape, ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }

     // test noncontiguous layout
     for (auto&& arg : args) {
@@ -93,6 +132,17 @@ TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
                 {scale_shape, dtype::Float32()},
                 {zeropoint_shape, dtype::Float32()},
                 ilayout});
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param).execl(
+                {ilayout,
+                 ilayout,
+                 {scale_shape, dtype::Float32()},
+                 {zeropoint_shape, dtype::Float32()},
+                 ilayout});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
 }
@@ -54,6 +54,20 @@ TEST_F(CUDA, REDUCE) {
     // very large reduce
     checker.execs({{1, 4194304, 1}, {}});

+    // inputs have nan
+    {
+        const auto nan = std::numeric_limits<float>::quiet_NaN();
+        UniformFloatWithValueRNG rng1 =
+                UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
+        checker.set_allow_invalid_check(true).set_rng(0, &rng1);
+        for (auto mode : {Mode::MIN, Mode::MAX}) {
+            checker.set_param({mode, 1});
+            checker.execs({{2, 64, 32}, {}});
+        }
+        checker.set_allow_invalid_check(false);
+    }
+    checker.set_rng(0, &rng);
+
     auto check = [&](Reduce::Mode mode, DType src_dtype, DType dst_dtype,
                      Reduce::DataType data_type) {
         for (int32_t axis : {0, 1, 2, 3}) {
@@ -21,7 +21,11 @@ def common_test_reduce(opr, ref_opr):
     data2_shape = (2, 9, 12)
     data1 = np.random.random(data1_shape).astype(np.float32)
     data2 = np.random.random(data2_shape).astype(np.float32)
-    cases = [{"input": data1}, {"input": data2}]
+    cases = [
+        {"input": data1},
+        {"input": data2},
+        {"input": np.array([[[1, 2, np.nan, 4], [8, 6, 5, 2], [2, 3, 4, 5]]])},
+    ]

     if opr not in (F.argmin, F.argmax):
         # test default axis
@@ -143,6 +143,11 @@ def test_fakequant():
         assert np.allclose(x.grad.numpy(), x1.grad.numpy())
         assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape)

+        # test nan
+        x = F.full((1, 32, 3, 3), np.nan)
+        y = fake_quant_tensor(x, qparams).numpy()
+        assert np.isnan(y).all()
+
     zero_point = tensor([1.0], dtype=np.float32)
     scale = tensor([4.0], dtype=np.float32)
     run(zero_point, scale)