diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cu
index 70032336fb..18bd527ce3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cu
@@ -26,11 +26,10 @@ __global__ void HsigmoidKernel(size_t size, const T *input, T *output) {
 }
 
 template <typename T>
-__global__ void HsigmoidGradKernel(size_t size, const T *dout, T *output) {
+__global__ void HsigmoidGradKernel(size_t size, const T *dout, const T *x, T *output) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
     T value = dout[pos] / static_cast<T>(6);
-    value = value > static_cast<T>(1) ? static_cast<T>(0) : value;
-    output[pos] = value > static_cast<T>(0) ? value : static_cast<T>(0);
+    output[pos] = (x[pos] > static_cast<T>(-3) && x[pos] < static_cast<T>(3)) ? value : static_cast<T>(0);
   }
 }
 
@@ -40,12 +39,14 @@ void CalHSigmoid(const size_t &size, const T *input, T *output, cudaStream_t cuda_stream) {
 }
 
 template <typename T>
-void CalHSigmoidGrad(const size_t &size, const T *dout, T *output, cudaStream_t cuda_stream) {
-  HsigmoidGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dout, output);
+void CalHSigmoidGrad(const size_t &size, const T *dout, const T *x, T *output, cudaStream_t cuda_stream) {
+  HsigmoidGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dout, x, output);
 }
 
 template void CalHSigmoid<half>(const size_t &size, const half *input, half *output, cudaStream_t cuda_stream);
 template void CalHSigmoid<float>(const size_t &size, const float *input, float *output, cudaStream_t cuda_stream);
-template void CalHSigmoidGrad<half>(const size_t &size, const half *dout, half *output, cudaStream_t cuda_stream);
-template void CalHSigmoidGrad<float>(const size_t &size, const float *dout, float *output, cudaStream_t cuda_stream);
+template void CalHSigmoidGrad<half>(const size_t &size, const half *dout, const half *x, half *output,
+                                    cudaStream_t cuda_stream);
+template void CalHSigmoidGrad<float>(const size_t &size, const float *dout, const float *x, float *output,
+                                     cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cuh
index 99a3e377f5..f596b9cfb2 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cuh
@@ -24,6 +24,6 @@ template <typename T>
 void CalHSigmoid(const size_t &size, const T *input, T *output, cudaStream_t cuda_stream);
 template <typename T>
-void CalHSigmoidGrad(const size_t &size, const T *dout, T *output, cudaStream_t cuda_stream);
+void CalHSigmoidGrad(const size_t &size, const T *dout, const T *x, T *output, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_HSIGMOID_IMPL_CUH_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_grad_gpu_kernel.h
index 75cd061496..6cfc8034ff 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/hsigmoid_grad_gpu_kernel.h
@@ -38,8 +38,9 @@ class HSigmoidGradKernel : public GpuKernel {
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     VARIABLE_NOT_USED(workspace);
     T *input = GetDeviceAddress<T>(inputs, 0);
+    T *x = GetDeviceAddress<T>(inputs, 1);
     T *output = GetDeviceAddress<T>(outputs, 0);
-    CalHSigmoidGrad(input_size_, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalHSigmoidGrad(input_size_, input, x, output, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
 
@@ -74,7 +75,6 @@ class HSigmoidGradKernel : public GpuKernel {
  protected:
   void InitSizeLists() override {
     input_size_list_.push_back(input_size_ * sizeof(T));
-    // though we are not using this mem, we still need to allocate
     input_size_list_.push_back(input_size_ * sizeof(T));
     output_size_list_.push_back(input_size_ * sizeof(T));
   }
diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc
index d9ae0e73cb..b90d075b0c 100644
--- a/mindspore/core/abstract/prim_nn.cc
+++ b/mindspore/core/abstract/prim_nn.cc
@@ -420,6 +420,9 @@ AbstractBasePtr InferImplHSigmoid(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tensor.
   CheckArgsSize(primitive->name(), args_spec_list, 1);
+  // add check, types other than half and float are from cpu
+  auto tensor = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 0);
+  (void)CheckTensorDType(tensor, {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32}, "Input of HSigmoid should be %s");
   return args_spec_list[0]->Broaden();
 }
 
@@ -427,6 +430,12 @@ AbstractBasePtr InferImplHSigmoidGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tensor.
   CheckArgsSize(primitive->name(), args_spec_list, 2);
+  // add check, types other than half and float are from cpu
+  auto dout = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 0);
+  auto x = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 1);
+  (void)CheckTensorDType(dout, {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32},
+                         "Dout of HSigmoidGrad should be %s");
+  (void)CheckTensorDType(x, {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32}, "X of HSigmoidGrad should be %s");
   return args_spec_list[1]->Broaden();
 }
 
diff --git a/tests/st/ops/gpu/test_hsigmoid_op.py b/tests/st/ops/gpu/test_hsigmoid_op.py
index 435ca80d3c..09a772357d 100644
--- a/tests/st/ops/gpu/test_hsigmoid_op.py
+++ b/tests/st/ops/gpu/test_hsigmoid_op.py
@@ -55,29 +55,29 @@ class DynamicNet(nn.Cell):
 
 def generate_testcases(nptype):
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    x = np.array([-1, -2, 0, 2, 1]).astype(nptype)
+    x = np.array([-1, -2, 0, 4, 5]).astype(nptype)
     net = Net()
     output = net(Tensor(x))
-    expect = np.array([0.33333334, 0.16666667, 0.5, 0.8333333, 0.6666667]).astype(nptype)
+    expect = np.array([0.33333334, 0.16666667, 0.5, 1, 1]).astype(nptype)
     np.testing.assert_almost_equal(output.asnumpy(), expect)
 
-    sens = np.array([-1.45, -2.63, 0.34, 6.43, 34.6]).astype(nptype)
+    sens = np.array([-1.45, 0.63, 0.34, 6.43, 34.6]).astype(nptype)
    backward_net = Grad(Net())
     output = backward_net(Tensor(x), Tensor(sens))
-    expect = np.array([0, 0, 5.66666685e-02, 0, 0]).astype(nptype)
+    expect = np.array([-0.2416667, 0.1049999, 5.66666685e-02, 0, 0]).astype(nptype)
     np.testing.assert_almost_equal(output[0].asnumpy(), expect)
 
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    x = np.array([-1, -2, 0, 2, 1]).astype(nptype)
+    x = np.array([-1, -2, 0, 4, 5]).astype(nptype)
     net = Net()
     output = net(Tensor(x))
-    expect = np.array([0.33333334, 0.16666667, 0.5, 0.8333333, 0.6666667]).astype(nptype)
+    expect = np.array([0.33333334, 0.16666667, 0.5, 1, 1]).astype(nptype)
     np.testing.assert_almost_equal(output.asnumpy(), expect)
 
-    sens = np.array([-1.45, -2.63, 0.34, 6.43, 34.6]).astype(nptype)
+    sens = np.array([-1.45, 0.63, 0.34, 6.43, 34.6]).astype(nptype)
     backward_net = Grad(Net())
     output = backward_net(Tensor(x), Tensor(sens))
-    expect = np.array([0, 0, 5.66666685e-02, 0, 0]).astype(nptype)
+    expect = np.array([-0.2416667, 0.1049999, 5.66666685e-02, 0, 0]).astype(nptype)
     np.testing.assert_almost_equal(output[0].asnumpy(), expect)
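
Reviewer note (not part of the patch): the new expected values follow directly from the HardSigmoid definition used by this kernel, hsigmoid(x) = clamp((x + 3) / 6, 0, 1), whose gradient is dout / 6 for -3 < x < 3 and 0 where the output saturates; that is exactly what the corrected HsigmoidGradKernel now computes from the forward input x. A minimal NumPy sketch that reproduces the test vectors above (the helper names are illustrative, not from the repository):

import numpy as np

def hsigmoid(x):
    # clamp((x + 3) / 6, 0, 1), the same piecewise form as HsigmoidKernel
    return np.clip((x + 3.0) / 6.0, 0.0, 1.0)

def hsigmoid_grad(dout, x):
    # dout / 6 inside the linear region (-3, 3), zero in the saturated regions,
    # mirroring the corrected HsigmoidGradKernel which now reads the forward input x
    return np.where((x > -3.0) & (x < 3.0), dout / 6.0, 0.0)

x = np.array([-1, -2, 0, 4, 5], dtype=np.float32)
sens = np.array([-1.45, 0.63, 0.34, 6.43, 34.6], dtype=np.float32)

print(hsigmoid(x))             # ~ [0.33333334, 0.16666667, 0.5, 1.0, 1.0]
print(hsigmoid_grad(sens, x))  # ~ [-0.2416667, 0.105, 0.05666667, 0.0, 0.0]

Note that the old kernel clamped value = dout / 6 by its own sign and magnitude instead of gating on x, which is why the previous expected gradients were zero everywhere except the x = 0 entry.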