Merge pull request !17726 from tom_chen/resize_bilineartags/v1.3.0
| @@ -15,6 +15,7 @@ | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/resize_bilinear_impl.cuh" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" | |||
| #include "runtime/device/gpu/cuda_common.h" | |||
| #include "include/cuda_fp16.h" | |||
| template <typename T> | |||
| @@ -46,6 +47,39 @@ __global__ void ResizeBilinear(const T *input, const int n, const int c, const i | |||
| return; | |||
| } | |||
| template <typename T> | |||
| __global__ void ResizeBilinearGrad(const float *input, const int n, const int c, const int input_h, const int input_w, | |||
| const int output_h, const int output_w, const int nchw, const int chw, const int hw, const float h_scale, | |||
| const float w_scale, T *output) { | |||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nchw; pos += blockDim.x * gridDim.x) { | |||
| const int posn = pos / chw; | |||
| const int posc = pos / hw % c; | |||
| const int posh = pos / input_w % input_h; | |||
| const int posw = pos % input_w; | |||
| const float posw_scaled = w_scale * posw; | |||
| const float posh_scaled = h_scale * posh; | |||
| const int w_low = max(static_cast<int>(floorf(posw_scaled)), 0); // NOLINT | |||
| const int w_high = min(static_cast<int>(ceilf(posw_scaled)), output_w - 1); // NOLINT | |||
| const int h_low = max(static_cast<int>(floorf(posh_scaled)), 0); // NOLINT | |||
| const int h_high = min(static_cast<int>(ceilf(posh_scaled)), output_h - 1); // NOLINT | |||
| const float w_alpha = posw_scaled - w_low; | |||
| const float w_beta = 1.0f - w_alpha; | |||
| const float h_alpha = posh_scaled - h_low; | |||
| const float h_beta = 1.0f - h_alpha; | |||
| const float grad = input[pos]; | |||
| const T dp1 = static_cast<T>(h_beta * w_beta * grad); | |||
| const T dp2 = static_cast<T>(h_beta * w_alpha * grad); | |||
| const T dp3 = static_cast<T>(h_alpha * w_beta * grad); | |||
| const T dp4 = static_cast<T>(h_alpha * w_alpha * grad); | |||
| const int output_start = output_h * output_w * (posn * c + posc); | |||
| MsAtomicAdd(&output[output_start + (h_low * output_w) + w_low], dp1); | |||
| MsAtomicAdd(&output[output_start + (h_low * output_w) + w_high], dp2); | |||
| MsAtomicAdd(&output[output_start + (h_high * output_w) + w_low], dp3); | |||
| MsAtomicAdd(&output[output_start + (h_high * output_w) + w_high], dp4); | |||
| } | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void CalResizeBilinear(const T *input, const int n, const int c, const int input_h, const int input_w, | |||
| const int output_h, const int output_w, const float h_scale, const float w_scale, float *output, | |||
| @@ -58,9 +92,28 @@ void CalResizeBilinear(const T *input, const int n, const int c, const int input | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void CalResizeBilinearGrad(const float *input, const int n, const int c, const int input_h, const int input_w, | |||
| const int output_h, const int output_w, const float h_scale, const float w_scale, T *output, | |||
| cudaStream_t cuda_stream) { | |||
| const int hw = input_h * input_w; | |||
| const int chw = c * hw; | |||
| const int nchw = n * chw; | |||
| ResizeBilinearGrad<<<GET_BLOCKS(nchw), GET_THREADS, 0, cuda_stream>>>(input, n, c, input_h, input_w, output_h, | |||
| output_w, nchw, chw, hw, h_scale, w_scale, output); | |||
| return; | |||
| } | |||
| template void CalResizeBilinear<float>(const float *input, const int n, const int c, const int input_h, | |||
| const int input_w, const int output_h, const int output_w, const float h_scale, const float w_scale, float *output, | |||
| cudaStream_t cuda_stream); | |||
| template void CalResizeBilinear<half>(const half *input, const int n, const int c, const int input_h, | |||
| const int input_w, const int output_h, const int output_w, const float h_scale, const float w_scale, float *output, | |||
| cudaStream_t cuda_stream); | |||
| template void CalResizeBilinearGrad<float>(const float *input, const int n, const int c, const int input_h, | |||
| const int input_w, const int output_h, const int output_w, const float h_scale, const float w_scale, float *output, | |||
| cudaStream_t cuda_stream); | |||
| template void CalResizeBilinearGrad<half>(const float *input, const int n, const int c, const int input_h, | |||
| const int input_w, const int output_h, const int output_w, const float h_scale, const float w_scale, half *output, | |||
| cudaStream_t cuda_stream); | |||
| @@ -21,4 +21,8 @@ template <typename T> | |||
| void CalResizeBilinear(const T *input, const int n_, const int c_, const int input_h_, const int input_w_, | |||
| const int output_h_, const int output_w_, const float h_scale, const float w_scale, float *output, | |||
| cudaStream_t cuda_stream); | |||
| template <typename T> | |||
| void CalResizeBilinearGrad(const float *input, const int n_, const int c_, const int input_h_, const int input_w_, | |||
| const int output_h_, const int output_w_, const float h_scale, const float w_scale, T *output, | |||
| cudaStream_t cuda_stream); | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RESIZE_BILINEAR_H_ | |||
| @@ -0,0 +1,31 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/gpu/nn/resize_bilinear_grad_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| ResizeBilinearGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| ResizeBilinearGradGpuKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| ResizeBilinearGrad, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| ResizeBilinearGradGpuKernel, half) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,137 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RESIZE_BILINEAR_GRAD_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RESIZE_BILINEAR_GRAD_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||
| #include "backend/kernel_compiler/gpu/cuda_impl/resize_bilinear_impl.cuh" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class ResizeBilinearGradGpuKernel : public GpuKernel { | |||
| public: | |||
| ResizeBilinearGradGpuKernel() { ResetResource(); } | |||
| ~ResizeBilinearGradGpuKernel() override = default; | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| float *dy = GetDeviceAddress<float>(inputs, 0); | |||
| T *dx = GetDeviceAddress<T>(outputs, 0); | |||
| float h_scale = Scaling(dx_h_, dy_h_, align_corners_); | |||
| float w_scale = Scaling(dx_w_, dy_w_, align_corners_); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemsetAsync(dx, 0, dx_size_, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemsetAsync dx failed"); | |||
| CalResizeBilinearGrad(dy, n_, c_, dy_h_, dy_w_, dx_h_, dx_w_, h_scale, w_scale, dx, | |||
| reinterpret_cast<cudaStream_t>(stream_ptr)); | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) override { | |||
| kernel_node_ = kernel_node; | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 2) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but ResizeBilinearGrad needs 1 input."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeBilinearGrad has 1 output."; | |||
| return false; | |||
| } | |||
| std::vector<size_t> dy_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| std::vector<size_t> x_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| std::vector<size_t> dx_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| if (dy_shape.size() != 4) { | |||
| MS_LOG(ERROR) << "Input is " << dy_shape.size() << "-D, but ResizeBilinearGrad supports only 4-D inputs."; | |||
| return false; | |||
| } | |||
| if (x_shape.size() != 4) { | |||
| MS_LOG(ERROR) << "Input is " << x_shape.size() << "-D, but ResizeBilinearGrad supports only 4-D inputs."; | |||
| return false; | |||
| } | |||
| n_ = SizeToInt(dy_shape[0]); | |||
| c_ = SizeToInt(dy_shape[1]); | |||
| dy_h_ = SizeToInt(dy_shape[2]); | |||
| dy_w_ = SizeToInt(dy_shape[3]); | |||
| dx_h_ = SizeToInt(dx_shape[2]); | |||
| dx_w_ = SizeToInt(dx_shape[3]); | |||
| dy_size_ = sizeof(float); | |||
| for (auto x : dy_shape) { | |||
| dy_size_ *= x; | |||
| } | |||
| dx_size_ = sizeof(T); | |||
| for (auto x : dx_shape) { | |||
| dx_size_ *= x; | |||
| } | |||
| align_corners_ = GetAttr<bool>(kernel_node, "align_corners"); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| void ResetResource() noexcept override { | |||
| align_corners_ = false; | |||
| n_ = 0; | |||
| c_ = 0; | |||
| dy_h_ = 0; | |||
| dy_w_ = 0; | |||
| dx_h_ = 0; | |||
| dx_w_ = 0; | |||
| dy_size_ = 0; | |||
| dx_size_ = 0; | |||
| workspace_size_ = 0; | |||
| input_size_list_.clear(); | |||
| output_size_list_.clear(); | |||
| workspace_size_list_.clear(); | |||
| } | |||
| protected: | |||
| void InitSizeLists() override { | |||
| input_size_list_.push_back(dy_size_); | |||
| output_size_list_.push_back(dx_size_); | |||
| } | |||
| private: | |||
| float Scaling(const int in_size, const int out_size, bool align_corners) { | |||
| return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1) | |||
| : in_size / static_cast<float>(out_size); | |||
| } | |||
| bool align_corners_; | |||
| int n_; | |||
| int c_; | |||
| int dy_h_; | |||
| int dy_w_; | |||
| int dx_h_; | |||
| int dx_w_; | |||
| size_t dy_size_; | |||
| size_t dx_size_; | |||
| size_t workspace_size_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_RESIZE_BILINEAR_GRAD_GPU_KERNEL_H_ | |||
| @@ -0,0 +1,88 @@ | |||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.context as context | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore.ops.operations import _grad_ops as G | |||
| class ResizeBilinearGradNet(nn.Cell): | |||
| def __init__(self, align_corners=False): | |||
| super(ResizeBilinearGradNet, self).__init__() | |||
| self.rb1 = G.ResizeBilinearGrad(align_corners=align_corners) | |||
| def construct(self, dy, size): | |||
| return self.rb1(dy, size) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_resize_bilinear_grad_align_corners(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float32) | |||
| x = np.array([[[[1.1, 2.2, 3.2, 2.5], | |||
| [3.3, 4.4, 5.7, 8.1], | |||
| [3.3, 4.4, 5.7, 8.1], | |||
| [3.3, 4.4, 5.7, 8.1]]]]).astype(np.float16) | |||
| expect = np.array([[[[1., 0., 0., 2.], | |||
| [0., 0., 0., 0.], | |||
| [0., 0., 0., 0.], | |||
| [3., 0., 0., 4.]]]]).astype(np.float16) | |||
| net = ResizeBilinearGradNet(align_corners=True) | |||
| output = net(Tensor(dy), Tensor(x)) | |||
| assert np.all(output.asnumpy() == expect) | |||
| x = np.array([[[[1.1, 2.2, 3.2, 2.5], | |||
| [3.3, 4.4, 5.7, 8.1], | |||
| [3.3, 4.4, 5.7, 8.1], | |||
| [3.3, 4.4, 5.7, 8.1]]]]).astype(np.float32) | |||
| expect = np.array([[[[1., 0., 0., 2.], | |||
| [0., 0., 0., 0.], | |||
| [0., 0., 0., 0.], | |||
| [3., 0., 0., 4.]]]]).astype(np.float32) | |||
| net = ResizeBilinearGradNet(align_corners=True) | |||
| output = net(Tensor(dy), Tensor(x)) | |||
| assert np.all(output.asnumpy() == expect) | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.env_onecard | |||
| def test_resize_bilinear_grad(): | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| dy = np.array([[[[1, 1, 0, 0], | |||
| [1, 1, 0, 0], | |||
| [0, 0, 1, 1], | |||
| [0, 0, 1, 1]]]]).astype(np.float32) | |||
| x = np.array([[[[1.1, 2.2], [3.3, 4.4]]]]).astype(np.float16) | |||
| expect = np.array([[[[2.25, 0.75], | |||
| [0.75, 4.25]]]]).astype(np.float16) | |||
| net = ResizeBilinearGradNet() | |||
| output = net(Tensor(dy), Tensor(x)) | |||
| assert np.all(output.asnumpy() == expect) | |||
| x = np.array([[[[1.1, 2.2], [3.3, 4.4]]]]).astype(np.float32) | |||
| expect = np.array([[[[2.25, 0.75], | |||
| [0.75, 4.25]]]]).astype(np.float32) | |||
| net = ResizeBilinearGradNet() | |||
| output = net(Tensor(dy), Tensor(x)) | |||
| assert np.all(output.asnumpy() == expect) | |||