Merge pull request !2962 from chenweifeng/smoothl1losstags/v0.6.0-beta
| @@ -0,0 +1,64 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "smooth_l1_loss_impl.cuh" | |||
| #include "device/gpu/cuda_common.h" | |||
| template <typename T> | |||
| __global__ void SmoothL1LossKernel(const int input_size, const float sigma, const T *prediction, const T *target, | |||
| T *loss) { | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { | |||
| T value = (prediction[i] - target[i]) > 0 ? (prediction[i] - target[i]) : (target[i] - prediction[i]); | |||
| if (value < sigma) { | |||
| loss[i] = static_cast<T>(0.5) * value * value; | |||
| } else { | |||
| loss[i] = value - static_cast<T>(0.5); | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss, | |||
| cudaStream_t stream) { | |||
| SmoothL1LossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, sigma, prediction, target, loss); | |||
| } | |||
| template <typename T> | |||
| __global__ void SmoothL1LossGradKernel(const int input_size, const float sigma, const T *prediction, const T *target, | |||
| const T *dloss, T *dx) { | |||
| for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) { | |||
| T value = prediction[i] - target[i]; | |||
| if (value > static_cast<T>(sigma)) { | |||
| dx[i] = dloss[i]; | |||
| } else if (value < static_cast<T>(-sigma)) { | |||
| dx[i] = -dloss[i]; | |||
| } else { | |||
| dx[i] = value * dloss[i]; | |||
| } | |||
| } | |||
| } | |||
| template <typename T> | |||
| void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss, | |||
| T *dx, cudaStream_t stream) { | |||
| SmoothL1LossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, sigma, prediction, target, | |||
| dloss, dx); | |||
| } | |||
| template void SmoothL1Loss(const int &input_size, const float &sigma, const float *prediction, const float *target, | |||
| float *loss, cudaStream_t stream); | |||
| template void SmoothL1LossGrad(const int &input_size, const float &sigma, const float *prediction, const float *target, | |||
| const float *dloss, float *dx, cudaStream_t stream); | |||
| @@ -0,0 +1,25 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_ | |||
| template <typename T> | |||
| void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss, | |||
| cudaStream_t stream); | |||
| template <typename T> | |||
| void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss, | |||
| T *dx, cudaStream_t stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_ | |||
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/nn/smooth_l1_loss_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| SmoothL1Loss, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| SmoothL1LossGpuKernel, float) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,75 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| #include "kernel/gpu/cuda_impl/smooth_l1_loss_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class SmoothL1LossGpuKernel : public GpuKernel { | |||
| public: | |||
| SmoothL1LossGpuKernel() : input_size_(1), sigma_(1.0) {} | |||
| ~SmoothL1LossGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *prediction = GetDeviceAddress<T>(inputs, 0); | |||
| T *target = GetDeviceAddress<T>(inputs, 1); | |||
| T *loss = GetDeviceAddress<T>(outputs, 0); | |||
| SmoothL1Loss(input_size_, sigma_, prediction, target, loss, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||
| input_size_ *= input_shape[i]; | |||
| } | |||
| sigma_ = GetAttr<float>(kernel_node, "sigma"); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_ * sizeof(T)); | |||
| input_size_list_.push_back(input_size_ * sizeof(T)); | |||
| output_size_list_.push_back(input_size_ * sizeof(T)); | |||
| } | |||
| private: | |||
| size_t input_size_; | |||
| float sigma_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_ | |||
| @@ -0,0 +1,29 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE(SmoothL1LossGrad, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| SmoothL1LossGradGpuKernel, float) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,76 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| #include "kernel/gpu/cuda_impl/smooth_l1_loss_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class SmoothL1LossGradGpuKernel : public GpuKernel { | |||
| public: | |||
| SmoothL1LossGradGpuKernel() : input_size_(1), sigma_(1.0) {} | |||
| ~SmoothL1LossGradGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| T *prediction = GetDeviceAddress<T>(inputs, 0); | |||
| T *target = GetDeviceAddress<T>(inputs, 1); | |||
| T *dloss = GetDeviceAddress<T>(inputs, 2); | |||
| T *dx = GetDeviceAddress<T>(outputs, 0); | |||
| SmoothL1LossGrad(input_size_, sigma_, prediction, target, dloss, dx, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| for (size_t i = 0; i < input_shape.size(); i++) { | |||
| input_size_ *= input_shape[i]; | |||
| } | |||
| sigma_ = GetAttr<float>(kernel_node, "sigma"); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(input_size_ * sizeof(T)); | |||
| input_size_list_.push_back(input_size_ * sizeof(T)); | |||
| output_size_list_.push_back(input_size_ * sizeof(T)); | |||
| } | |||
| private: | |||
| size_t input_size_; | |||
| float sigma_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_ | |||
| @@ -0,0 +1,81 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops import composite as C | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_smoothl1loss(): | |||
| np.random.seed(42) | |||
| prediction = np.random.randn(20).astype(np.float32) | |||
| target = np.random.randn(20).astype(np.float32) | |||
| sigma = 1.0 | |||
| net = nn.SmoothL1Loss(sigma) | |||
| loss = net(Tensor(prediction), Tensor(target)) | |||
| expect = [0.46941718, 0.00382918, 0.16829303, 2.447778, 0.04812113, 0.05953304, | |||
| 2.2302065, 0.07672881, 0.00860204, 0.34798968, 0.00956192, 1.818008, | |||
| 0.03262977, 0.36599946, 2.047463, 0.2168481, 0.7216947, 1.7739174, | |||
| 0.08826803, 1.109165] | |||
| assert np.allclose(loss.asnumpy(), expect) | |||
| class Grad(nn.Cell): | |||
| def __init__(self, network): | |||
| super(Grad, self).__init__() | |||
| self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True) | |||
| self.network = network | |||
| def construct(self, x1, x2, sens): | |||
| gout = self.grad(self.network)(x1, x2, sens) | |||
| return gout | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_smoothl1loss_grad(): | |||
| np.random.seed(42) | |||
| prediction = np.random.randn(20).astype(np.float32) | |||
| target = np.random.randn(20).astype(np.float32) | |||
| sens = np.random.randn(20).astype(np.float32) | |||
| sigma = 1.0 | |||
| net = nn.SmoothL1Loss(sigma) | |||
| grad = Grad(net) | |||
| dx = grad(Tensor(prediction), Tensor(target), Tensor(sens)) | |||
| dx1_expect = [-0.71552587, 0.01499678, -0.06709455, -0.30110368, -0.45868093, | |||
| 0.24838912, -0.46063876, 0.41411355, 0.04507046, -1.4708229, | |||
| 0.04481723, 0.38508227, -0.17292616, -0.52333146, -1.0309995, | |||
| 0.61330026, 0.83921754, -0.3092124, 0.1391843, -0.9755451] | |||
| dx2_expect = [0.71552587, -0.01499678, 0.06709455, 0.30110368, 0.45868093, | |||
| -0.24838912, 0.46063876, -0.41411355, -0.04507046, 1.4708229, | |||
| -0.04481723, -0.38508227, 0.17292616, 0.52333146, 1.0309995, | |||
| -0.61330026, -0.83921754, 0.3092124, -0.1391843, 0.9755451] | |||
| assert np.allclose(dx[0].asnumpy(), dx1_expect) | |||
| assert np.allclose(dx[1].asnumpy(), dx2_expect) | |||