From: @TFbunny
Reviewed-by: @liangchenghui, @wuxuejian
Signed-off-by: @liangchenghui, @wuxuejian
pull/15858/MERGE
@@ -26,11 +26,10 @@ __global__ void HsigmoidKernel(size_t size, const T *input, T *output) {
 }
 
 template <typename T>
-__global__ void HsigmoidGradKernel(size_t size, const T *dout, T *output) {
+__global__ void HsigmoidGradKernel(size_t size, const T *dout, const T *x, T *output) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
     T value = dout[pos] / static_cast<T>(6);
-    value = value > static_cast<T>(1) ? static_cast<T>(0) : value;
-    output[pos] = value > static_cast<T>(0) ? value : static_cast<T>(0);
+    output[pos] = (x[pos] > static_cast<T>(-3) && x[pos] < static_cast<T>(3)) ? value : static_cast<T>(0);
   }
 }
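The old backward kernel clamped dout / 6 against 0 and 1 and never looked at the forward input x, so it produced wrong gradients for negative dout and for x outside (-3, 3). HSigmoid is max(0, min(1, (x + 3) / 6)), so its derivative depends only on x: 1/6 inside (-3, 3) and 0 elsewhere. A minimal NumPy sketch of the intended forward/backward semantics (illustrative only; the names are not from the PR):

import numpy as np

def hsigmoid(x):
    # hsigmoid(x) = max(0, min(1, (x + 3) / 6))
    return np.clip((x + 3.0) / 6.0, 0.0, 1.0)

def hsigmoid_grad(dout, x):
    # dx = dout / 6 where -3 < x < 3, else 0 -- this is why the kernel now needs x
    return np.where((x > -3.0) & (x < 3.0), dout / 6.0, 0.0)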
@@ -40,12 +39,14 @@ void CalHSigmoid(const size_t &size, const T *input, T *output, cudaStream_t cud
 }
 
 template <typename T>
-void CalHSigmoidGrad(const size_t &size, const T *dout, T *output, cudaStream_t cuda_stream) {
-  HsigmoidGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dout, output);
+void CalHSigmoidGrad(const size_t &size, const T *dout, const T *x, T *output, cudaStream_t cuda_stream) {
+  HsigmoidGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dout, x, output);
 }
 
 template void CalHSigmoid<half>(const size_t &size, const half *input, half *output, cudaStream_t cuda_stream);
 template void CalHSigmoid<float>(const size_t &size, const float *input, float *output, cudaStream_t cuda_stream);
-template void CalHSigmoidGrad<half>(const size_t &size, const half *dout, half *output, cudaStream_t cuda_stream);
-template void CalHSigmoidGrad<float>(const size_t &size, const float *dout, float *output, cudaStream_t cuda_stream);
+template void CalHSigmoidGrad<half>(const size_t &size, const half *dout, const half *x, half *output,
+                                    cudaStream_t cuda_stream);
+template void CalHSigmoidGrad<float>(const size_t &size, const float *dout, const float *x, float *output,
+                                     cudaStream_t cuda_stream);
@@ -24,6 +24,6 @@ template <typename T>
 void CalHSigmoid(const size_t &size, const T *input, T *output, cudaStream_t cuda_stream);
 
 template <typename T>
-void CalHSigmoidGrad(const size_t &size, const T *dout, T *output, cudaStream_t cuda_stream);
+void CalHSigmoidGrad(const size_t &size, const T *dout, const T *x, T *output, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_HSIGMOID_IMPL_CUH_
@@ -38,8 +38,9 @@ class HSigmoidGradKernel : public GpuKernel {
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     VARIABLE_NOT_USED(workspace);
     T *input = GetDeviceAddress<T>(inputs, 0);
+    T *x = GetDeviceAddress<T>(inputs, 1);
     T *output = GetDeviceAddress<T>(outputs, 0);
-    CalHSigmoidGrad(input_size_, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+    CalHSigmoidGrad(input_size_, input, x, output, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
@@ -74,7 +75,6 @@ class HSigmoidGradKernel : public GpuKernel {
  protected:
   void InitSizeLists() override {
     input_size_list_.push_back(input_size_ * sizeof(T));
-    // though we are not using this mem, we still need to allocate
     input_size_list_.push_back(input_size_ * sizeof(T));
     output_size_list_.push_back(input_size_ * sizeof(T));
   }
@@ -420,6 +420,9 @@ AbstractBasePtr InferImplHSigmoid(const AnalysisEnginePtr &, const PrimitivePtr
                                   const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tensor.
   CheckArgsSize(primitive->name(), args_spec_list, 1);
+  // add check, types other than half and float are from cpu
+  auto tensor = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 0);
+  (void)CheckTensorDType(tensor, {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32}, "Input of HSigmoid should be %s");
   return args_spec_list[0]->Broaden();
 }
@@ -427,6 +430,12 @@ AbstractBasePtr InferImplHSigmoidGrad(const AnalysisEnginePtr &, const Primitive
                                       const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tensor.
   CheckArgsSize(primitive->name(), args_spec_list, 2);
+  // add check, types other than half and float are from cpu
+  auto dout = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 0);
+  auto x = CheckArg<AbstractTensor>(primitive->name(), args_spec_list, 1);
+  (void)CheckTensorDType(dout, {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32},
+                         "Dout of HSigmoidGrad should be %s");
+  (void)CheckTensorDType(x, {kInt8, kInt16, kInt32, kInt64, kFloat16, kFloat32}, "X of HSigmoidGrad should be %s");
   return args_spec_list[1]->Broaden();
 }
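Conceptually, the added CheckTensorDType calls only gate the element types accepted by the inference rule; integer dtypes are allowed because those inputs come from the CPU kernels. A rough Python equivalent of what the allowlist enforces (a sketch with hypothetical names, not MindSpore's actual implementation):

import numpy as np

# Hypothetical stand-in for the dtype allowlist added above.
ALLOWED = (np.int8, np.int16, np.int32, np.int64, np.float16, np.float32)

def check_tensor_dtype(arr, arg_name, prim_name="HSigmoidGrad"):
    # Mirrors the error template "X of HSigmoidGrad should be %s".
    if arr.dtype.type not in ALLOWED:
        raise TypeError(f"{arg_name} of {prim_name} should be one of {ALLOWED}, got {arr.dtype}")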
@@ -55,29 +55,29 @@ class DynamicNet(nn.Cell):
 def generate_testcases(nptype):
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    x = np.array([-1, -2, 0, 2, 1]).astype(nptype)
+    x = np.array([-1, -2, 0, 4, 5]).astype(nptype)
     net = Net()
     output = net(Tensor(x))
-    expect = np.array([0.33333334, 0.16666667, 0.5, 0.8333333, 0.6666667]).astype(nptype)
+    expect = np.array([0.33333334, 0.16666667, 0.5, 1, 1]).astype(nptype)
     np.testing.assert_almost_equal(output.asnumpy(), expect)
-    sens = np.array([-1.45, -2.63, 0.34, 6.43, 34.6]).astype(nptype)
+    sens = np.array([-1.45, 0.63, 0.34, 6.43, 34.6]).astype(nptype)
     backward_net = Grad(Net())
     output = backward_net(Tensor(x), Tensor(sens))
-    expect = np.array([0, 0, 5.66666685e-02, 0, 0]).astype(nptype)
+    expect = np.array([-0.2416667, 0.1049999, 5.66666685e-02, 0, 0]).astype(nptype)
     np.testing.assert_almost_equal(output[0].asnumpy(), expect)
 
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    x = np.array([-1, -2, 0, 2, 1]).astype(nptype)
+    x = np.array([-1, -2, 0, 4, 5]).astype(nptype)
     net = Net()
     output = net(Tensor(x))
-    expect = np.array([0.33333334, 0.16666667, 0.5, 0.8333333, 0.6666667]).astype(nptype)
+    expect = np.array([0.33333334, 0.16666667, 0.5, 1, 1]).astype(nptype)
     np.testing.assert_almost_equal(output.asnumpy(), expect)
-    sens = np.array([-1.45, -2.63, 0.34, 6.43, 34.6]).astype(nptype)
+    sens = np.array([-1.45, 0.63, 0.34, 6.43, 34.6]).astype(nptype)
     backward_net = Grad(Net())
     output = backward_net(Tensor(x), Tensor(sens))
-    expect = np.array([0, 0, 5.66666685e-02, 0, 0]).astype(nptype)
+    expect = np.array([-0.2416667, 0.1049999, 5.66666685e-02, 0, 0]).astype(nptype)
     np.testing.assert_almost_equal(output[0].asnumpy(), expect)
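For reference, the updated expectations can be reproduced with a plain NumPy model of HSigmoid and its gradient (a standalone sketch, not part of the PR's test harness):

import numpy as np

x = np.array([-1, -2, 0, 4, 5], dtype=np.float32)
sens = np.array([-1.45, 0.63, 0.34, 6.43, 34.6], dtype=np.float32)

# hsigmoid(x) = max(0, min(1, (x + 3) / 6))
forward = np.clip((x + 3.0) / 6.0, 0.0, 1.0)
# derivative is 1/6 inside (-3, 3) and 0 elsewhere, scaled by the incoming sens
backward = np.where((x > -3.0) & (x < 3.0), sens / 6.0, 0.0)

print(forward)   # ~ [0.33333334, 0.16666667, 0.5, 1.0, 1.0]
print(backward)  # ~ [-0.2416667, 0.105, 0.05666667, 0.0, 0.0]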