| @@ -0,0 +1,51 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cuh" | |||
| template <typename T> | |||
| __global__ void HsigmoidKernel(size_t size, const T *input, T *output) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { | |||
| T value = (input[pos] + static_cast<T>(3)) / static_cast<T>(6); | |||
| value = value > static_cast<T>(1) ? static_cast<T>(1) : value; | |||
| output[pos] = value > static_cast<T>(0) ? value : static_cast<T>(0); | |||
| } | |||
| } | |||
| template <typename T> | |||
| __global__ void HsigmoidGradKernel(size_t size, const T *dout, T *output) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { | |||
| T value = dout[pos] / static_cast<T>(6); | |||
| value = value > static_cast<T>(1) ? static_cast<T>(0) : value; | |||
| output[pos] = value > static_cast<T>(0) ? value : static_cast<T>(0); | |||
| } | |||
| } | |||
| template <typename T> | |||
| void CalHSigmoid(const size_t &size, const T *input, T *output, cudaStream_t cuda_stream) { | |||
| HsigmoidKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output); | |||
| } | |||
| template <typename T> | |||
| void CalHSigmoidGrad(const size_t &size, const T *dout, T *output, cudaStream_t cuda_stream) { | |||
| HsigmoidGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dout, output); | |||
| } | |||
| template void CalHSigmoid<half>(const size_t &size, const half *input, half *output, cudaStream_t cuda_stream); | |||
| template void CalHSigmoid<float>(const size_t &size, const float *input, float *output, cudaStream_t cuda_stream); | |||
| template void CalHSigmoidGrad<half>(const size_t &size, const half *dout, half *output, cudaStream_t cuda_stream); | |||
| template void CalHSigmoidGrad<float>(const size_t &size, const float *dout, float *output, cudaStream_t cuda_stream); | |||
| @@ -0,0 +1,29 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_HSIGMOID_IMPL_CUH_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_HSIGMOID_IMPL_CUH_ | |||
| #include <cuda_runtime.h> | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| void CalHSigmoid(const size_t &size, const T *input, T *output, cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalHSigmoidGrad(const size_t &size, const T *dout, T *output, cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_HSIGMOID_IMPL_CUH_ | |||
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/hsigmoid_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| HSigmoidKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(HSigmoid, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| HSigmoidKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,88 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSIGMOID_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSIGMOID_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class HSigmoidKernel : public GpuKernel { | |||
| public: | |||
| HSigmoidKernel() { ResetResource(); } | |||
| ~HSigmoidKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| VARIABLE_NOT_USED(workspace); | |||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| CalHSigmoid(input_size_, input, output, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| kernel_node_ = kernel_node; | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but HSigmoid needs 1 inputs."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but HSigmoid has 1 output."; | |||
| return false; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| input_size_ = 1; | |||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||
| input_size_ *= input_shape[i]; | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void ResetResource() noexcept override { | |||
| input_size_ = 1; | |||
| input_size_list_.clear(); | |||
| output_size_list_.clear(); | |||
| workspace_size_list_.clear(); | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_ * sizeof(T)); | |||
| output_size_list_.push_back(input_size_ * sizeof(T)); | |||
| } | |||
| private: | |||
| size_t input_size_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSIGMOID_GPU_KERNEL_H_ | |||
| @@ -0,0 +1,30 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/hsigmoid_grad_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| HSigmoidGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| HSigmoidGradKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| HSigmoidGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| HSigmoidGradKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,91 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSIGMOID_GRAD_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSIGMOID_GRAD_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/hsigmoid_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class HSigmoidGradKernel : public GpuKernel { | |||
| public: | |||
| HSigmoidGradKernel() { ResetResource(); } | |||
| ~HSigmoidGradKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| VARIABLE_NOT_USED(workspace); | |||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| CalHSigmoidGrad(input_size_, input, output, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| kernel_node_ = kernel_node; | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 2) { | |||
| MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but HSigmoidGrad needs 2 inputs."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but HSigmoidGrad has 1 output."; | |||
| return false; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| input_size_ = 1; | |||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||
| input_size_ *= input_shape[i]; | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void ResetResource() noexcept override { | |||
| input_size_ = 1; | |||
| input_size_list_.clear(); | |||
| output_size_list_.clear(); | |||
| workspace_size_list_.clear(); | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_ * sizeof(T)); | |||
| // though we are not using this mem, we still need to allocate | |||
| input_size_list_.push_back(input_size_ * sizeof(T)); | |||
| output_size_list_.push_back(input_size_ * sizeof(T)); | |||
| } | |||
| private: | |||
| size_t input_size_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_HSIGMOID_GRAD_GPU_KERNEL_H_ | |||
| @@ -59,6 +59,10 @@ AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitiveP | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplHSigmoid(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplHSigmoidGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -416,6 +416,20 @@ AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &pri | |||
| return args_spec_list[0]->Broaden(); | |||
| } | |||
| AbstractBasePtr InferImplHSigmoid(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tensor. | |||
| CheckArgsSize(primitive->name(), args_spec_list, 1); | |||
| return args_spec_list[0]->Broaden(); | |||
| } | |||
| AbstractBasePtr InferImplHSigmoidGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tensor. | |||
| CheckArgsSize(primitive->name(), args_spec_list, 2); | |||
| return args_spec_list[1]->Broaden(); | |||
| } | |||
| AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| // Inputs: a tensor. | |||
| @@ -132,6 +132,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimSparseApplyProximalAdagrad, {InferImplSparseApplyProximalAdagrad, nullptr, true}}, | |||
| {prim::kPrimSGD, {InferImplSGD, nullptr, true}}, | |||
| {prim::kPrimCTCGreedyDecoder, {InferImplCTCGreedyDecoder, nullptr, true}}, | |||
| {prim::kPrimHSigmoid, {InferImplHSigmoid, nullptr, true}}, | |||
| {prim::kPrimHSigmoidGrad, {InferImplHSigmoidGrad, nullptr, true}}, | |||
| // Others | |||
| {prim::kPrimIdentity, {InferImplIdentity, nullptr, true}}, | |||
| {prim::kPrimLoad, {InferImplLoad, nullptr, true}}, | |||
| @@ -521,6 +521,8 @@ inline const PrimitivePtr kPrimSubFusion = std::make_shared<Primitive>("SubFusio | |||
| inline const PrimitivePtr kPrimMulFusion = std::make_shared<Primitive>("MulFusion"); | |||
| inline const PrimitivePtr kPrimSigmoid = std::make_shared<Primitive>("Sigmoid"); | |||
| inline const PrimitivePtr kPrimSigmoidGrad = std::make_shared<Primitive>("SigmoidGrad"); | |||
| inline const PrimitivePtr kPrimHSigmoid = std::make_shared<Primitive>("HSigmoid"); | |||
| inline const PrimitivePtr kPrimHSigmoidGrad = std::make_shared<Primitive>("HSigmoidGrad"); | |||
| inline const PrimitivePtr kPrimClip = std::make_shared<Primitive>("Clip"); | |||
| inline const PrimitivePtr kPrimHardTanh = std::make_shared<Primitive>("HardTanh"); | |||
| inline const PrimitivePtr kPrimDepthWiseConv2DTransposeFusion = | |||
| @@ -0,0 +1,111 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops.composite import GradOperation | |||
| from mindspore.ops.operations import _inner_ops as inner | |||
| class Grad(nn.Cell): | |||
| def __init__(self, network): | |||
| super(Grad, self).__init__() | |||
| self.grad = GradOperation(get_all=True, sens_param=True) | |||
| self.network = network | |||
| def construct(self, input_x, dout): | |||
| return self.grad(self.network)(input_x, dout) | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.HSigmoid = P.HSigmoid() | |||
| def construct(self, x): | |||
| return self.HSigmoid(x) | |||
| class DynamicNet(nn.Cell): | |||
| def __init__(self): | |||
| super(DynamicNet, self).__init__() | |||
| self.HSigmoid = P.HSigmoid() | |||
| self.d = inner.GpuConvertToDynamicShape() | |||
| def construct(self, x): | |||
| x = self.d(x) | |||
| return self.HSigmoid(x) | |||
| def generate_testcases(nptype): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| x = np.array([-1, -2, 0, 2, 1]).astype(nptype) | |||
| net = Net() | |||
| output = net(Tensor(x)) | |||
| expect = np.array([0.33333334, 0.16666667, 0.5, 0.8333333, 0.6666667]).astype(nptype) | |||
| np.testing.assert_almost_equal(output.asnumpy(), expect) | |||
| sens = np.array([-1.45, -2.63, 0.34, 6.43, 34.6]).astype(nptype) | |||
| backward_net = Grad(Net()) | |||
| output = backward_net(Tensor(x), Tensor(sens)) | |||
| expect = np.array([0, 0, 5.66666685e-02, 0, 0]).astype(nptype) | |||
| np.testing.assert_almost_equal(output[0].asnumpy(), expect) | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| x = np.array([-1, -2, 0, 2, 1]).astype(nptype) | |||
| net = Net() | |||
| output = net(Tensor(x)) | |||
| expect = np.array([0.33333334, 0.16666667, 0.5, 0.8333333, 0.6666667]).astype(nptype) | |||
| np.testing.assert_almost_equal(output.asnumpy(), expect) | |||
| sens = np.array([-1.45, -2.63, 0.34, 6.43, 34.6]).astype(nptype) | |||
| backward_net = Grad(Net()) | |||
| output = backward_net(Tensor(x), Tensor(sens)) | |||
| expect = np.array([0, 0, 5.66666685e-02, 0, 0]).astype(nptype) | |||
| np.testing.assert_almost_equal(output[0].asnumpy(), expect) | |||
| def generate_dynamic_testcase(nptype): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| x = np.array([-1, -2, 0, 2, 1]).astype(nptype) | |||
| net = DynamicNet() | |||
| output = net(Tensor(x)) | |||
| expect = np.array([0.33333334, 0.16666667, 0.5, 0.8333333, 0.6666667]).astype(nptype) | |||
| np.testing.assert_almost_equal(output.asnumpy(), expect) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_hsigmoid_dynamic_float32(): | |||
| generate_dynamic_testcase(np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_hsigmoid_float32(): | |||
| generate_testcases(np.float32) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_hsigmoid_float16(): | |||
| generate_testcases(np.float16) | |||