From: @yuan_shen_zhou Reviewed-by: @liangchenghui,@liangchenghui Signed-off-by: @liangchenghui,@liangchenghuitags/v1.1.0
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "l2_loss.cuh" | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" | |||
| template <typename T> | |||
| __global__ void L2LossKernel(const size_t input_size, const T *input , T *output) { | |||
| T ret = 0; | |||
| for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < input_size; id += blockDim.x * gridDim.x) { | |||
| ret = (input[id] * input[id]); | |||
| ret /= static_cast<T>(2); | |||
| MsAtomicAdd(output, ret); | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void L2Loss(const size_t input_size, const T *input , T *output, cudaStream_t stream) { | |||
| L2LossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, input, output); | |||
| } | |||
| template void L2Loss<float>(const size_t input_size, const float *input , float *output, cudaStream_t stream); | |||
| template void L2Loss<half>(const size_t input_size, const half *input , half *output, cudaStream_t stream); | |||
| @@ -0,0 +1,21 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_L2_LOSS_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_L2_LOSS_H_ | |||
| template <typename T> | |||
| void L2Loss(const size_t input_size, const T *input , T *output, cudaStream_t stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_L2_LOSS_H_ | |||
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/l2_loss_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(L2Loss, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| L2LossGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE(L2Loss, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| L2LossGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,71 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_L2_LOSS_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_L2_LOSS_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/l2_loss.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class L2LossGpuKernel : public GpuKernel { | |||
| public: | |||
| L2LossGpuKernel() : input_size_(1) {} | |||
| ~L2LossGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||
| L2Loss(input_size_, input, output, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||
| input_size_ *= input_shape[i]; | |||
| } | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_ * sizeof(T)); | |||
| output_size_list_.push_back(sizeof(T)); | |||
| } | |||
| private: | |||
| size_t input_size_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_L2_LOSS_GPU_KERNEL_H_ | |||
| @@ -2157,9 +2157,7 @@ class L2Loss(PrimitiveWithInfer): | |||
| Set `input_x` as x and output as loss. | |||
| .. math:: | |||
| loss = sum(x ** 2) / nelement(x) | |||
| :math:`nelement(x)` represents the number of `input_x`. | |||
| loss = sum(x ** 2) / 2 | |||
| Inputs: | |||
| - **input_x** (Tensor) - A input Tensor. Data type must be float16 or float32. | |||
| @@ -2168,7 +2166,7 @@ class L2Loss(PrimitiveWithInfer): | |||
| Tensor, has the same dtype as `input_x`. The output tensor is the value of loss which is a scalar tensor. | |||
| Supported Platforms: | |||
| ``Ascend`` | |||
| ``Ascend`` ``GPU`` | |||
| Examples | |||
| >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float16) | |||
| @@ -0,0 +1,100 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| import mindspore as ms | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| class L2LossNet(nn.Cell): | |||
| def __init__(self): | |||
| super(L2LossNet, self).__init__() | |||
| self.l2_loss = P.L2Loss() | |||
| def construct(self, x): | |||
| return self.l2_loss(x) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gather_pynative_fp32_22(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| error = 1e-4 | |||
| x = Tensor(np.array([[1., 2.], [3., 4.]]), ms.float32) | |||
| expect = np.array(15, np.float32) | |||
| output = P.L2Loss()(x) | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gather_pynative_fp16_22(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| error = 1e-4 | |||
| x = Tensor(np.array([[1., 2.], [3., 4.]]), ms.float16) | |||
| expect = np.array(15, np.float16) | |||
| output = P.L2Loss()(x) | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gather_pynative_fp32_14(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| error = 1e-4 | |||
| x = Tensor(np.array([1., 2., 3., 4.]), ms.float32) | |||
| expect = np.array(15, np.float32) | |||
| output = P.L2Loss()(x) | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_gather_pynative_fp16_14(): | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| error = 1e-4 | |||
| x = Tensor(np.array([1., 2., 3., 4.]), ms.float16) | |||
| expect = np.array(15, np.float16) | |||
| output = P.L2Loss()(x) | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| def test_gather_graph_fp32_14(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| error = 1e-4 | |||
| x = Tensor(np.array([1., 2., 3., 4.]), ms.float32) | |||
| expect = np.array(15, np.float32) | |||
| l2_loss = L2LossNet() | |||
| output = l2_loss(x) | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| def test_gather_graph_fp16_14(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| error = 1e-4 | |||
| x = Tensor(np.array([1., 2., 3., 4.]), ms.float16) | |||
| expect = np.array(15, np.float16) | |||
| l2_loss = L2LossNet() | |||
| output = l2_loss(x) | |||
| diff = output.asnumpy() - expect | |||
| assert np.all(diff < error) | |||