From 299509dbfc0e7dfdfbb4befe222cd38501b93632 Mon Sep 17 00:00:00 2001 From: Peilin Wang Date: Mon, 21 Dec 2020 13:54:42 -0500 Subject: [PATCH] changed cast to round to zero if casting from float to integral --- .../kernel_compiler/gpu/cuda_impl/cast_impl.cu | 16 ++++++++-------- tests/st/ops/gpu/test_cast_op.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cast_impl.cu index 41eae121a3..14c4615d37 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cast_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/cast_impl.cu @@ -28,35 +28,35 @@ __device__ __forceinline__ void CastBase(const S *input_addr, T *output_addr) { // half --> integer __device__ __forceinline__ void CastBase(const half *input_addr, uint64_t *output_addr) { - *output_addr = __half2ull_rd((*input_addr)); + *output_addr = __half2ull_rz((*input_addr)); } __device__ __forceinline__ void CastBase(const half *input_addr, int64_t *output_addr) { - *output_addr = __half2ll_rd((*input_addr)); + *output_addr = __half2ll_rz((*input_addr)); } __device__ __forceinline__ void CastBase(const half *input_addr, uint32_t *output_addr) { - *output_addr = __half2uint_rd((*input_addr)); + *output_addr = __half2uint_rz((*input_addr)); } __device__ __forceinline__ void CastBase(const half *input_addr, int32_t *output_addr) { - *output_addr = __half2int_rd((*input_addr)); + *output_addr = __half2int_rz((*input_addr)); } __device__ __forceinline__ void CastBase(const half *input_addr, uint16_t *output_addr) { - *output_addr = __half2ushort_rd((*input_addr)); + *output_addr = __half2ushort_rz((*input_addr)); } __device__ __forceinline__ void CastBase(const half *input_addr, int16_t *output_addr) { - *output_addr = __half2short_rd((*input_addr)); + *output_addr = __half2short_rz((*input_addr)); } __device__ __forceinline__ void CastBase(const half *input_addr, uint8_t *output_addr) { - *output_addr = static_cast(__half2ushort_rd((*input_addr))); + *output_addr = static_cast(__half2ushort_rz((*input_addr))); } __device__ __forceinline__ void CastBase(const half *input_addr, int8_t *output_addr) { - *output_addr = static_cast(__half2short_rd((*input_addr))); + *output_addr = static_cast(__half2short_rz((*input_addr))); } // integer --> half diff --git a/tests/st/ops/gpu/test_cast_op.py b/tests/st/ops/gpu/test_cast_op.py index 8316d31f2f..d63ba2d005 100644 --- a/tests/st/ops/gpu/test_cast_op.py +++ b/tests/st/ops/gpu/test_cast_op.py @@ -604,7 +604,7 @@ def test_cast31(): @pytest.mark.env_onecard def test_cast32(): np.random.seed(10) - x = np.random.rand(*(3, 2)).astype(np.float16) + x = np.random.uniform(-5, 5, (3, 2)).astype(np.float16) x0 = Tensor(x) t0 = mstype.int32 x1 = Tensor(x)