Merge pull request !3176 from 34bunny/GPU-ResizeNearestNeighbortags/v0.6.0-beta
| @@ -0,0 +1,31 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighbor, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| ResizeNearestNeighborGpuKernel, float) | |||||
| MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighbor, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||||
| ResizeNearestNeighborGpuKernel, half) | |||||
| MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighbor, | |||||
| KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| ResizeNearestNeighborGpuKernel, int) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,111 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T> | |||||
| class ResizeNearestNeighborGpuKernel : public GpuKernel { | |||||
| public: | |||||
| ResizeNearestNeighborGpuKernel() : align_corners_(false), shape_size_(0), input_size_(0), output_size_(0) {} | |||||
| ~ResizeNearestNeighborGpuKernel() override = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||||
| int size = SizeToInt(output_size_ / sizeof(T)); | |||||
| float h_scale = Scaling(input_shape_[2], output_shape_[2], align_corners_); | |||||
| float w_scale = Scaling(input_shape_[3], output_shape_[3], align_corners_); | |||||
| CalResizeNearestNeighbor(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], output, | |||||
| output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], align_corners_, | |||||
| h_scale, w_scale, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 1) { | |||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but ResizeNearestNeighbor needs 1 input."; | |||||
| return false; | |||||
| } | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| if (output_num != 1) { | |||||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor needs 1 output."; | |||||
| return false; | |||||
| } | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| shape_size_ = input_shape.size(); | |||||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||||
| if (shape_size_ != RESIZENEARESTNEIGHBOR_DIMENSION) { | |||||
| MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only " | |||||
| << RESIZENEARESTNEIGHBOR_DIMENSION << "-D inputs."; | |||||
| return false; | |||||
| } | |||||
| input_size_ = 1; | |||||
| for (size_t i = 0; i < shape_size_; i++) { | |||||
| input_size_ *= input_shape[i]; | |||||
| input_shape_.push_back(input_shape[i]); | |||||
| } | |||||
| input_size_ *= sizeof(T); | |||||
| output_size_ = 1; | |||||
| for (size_t i = 0; i < shape_size_; i++) { | |||||
| output_size_ *= output_shape[i]; | |||||
| output_shape_.push_back(output_shape[i]); | |||||
| } | |||||
| output_size_ *= sizeof(T); | |||||
| align_corners_ = GetAttr<bool>(kernel_node, "align_corners"); | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(input_size_); | |||||
| output_size_list_.push_back(output_size_); | |||||
| } | |||||
| private: | |||||
| float Scaling(const int in_size, const int out_size, bool align_corners) { | |||||
| return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1) | |||||
| : in_size / static_cast<float>(out_size); | |||||
| } | |||||
| bool align_corners_; | |||||
| size_t shape_size_; | |||||
| std::vector<int> input_shape_; | |||||
| std::vector<int> output_shape_; | |||||
| size_t input_size_; | |||||
| size_t output_size_; | |||||
| size_t workspace_size_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GPU_KERNEL_H_ | |||||
| @@ -0,0 +1,31 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighborGrad, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| ResizeNearestNeighborGradGpuKernel, float) | |||||
| MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighborGrad, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||||
| ResizeNearestNeighborGradGpuKernel, half) | |||||
| MS_REG_GPU_KERNEL_ONE(ResizeNearestNeighborGrad, | |||||
| KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| ResizeNearestNeighborGradGpuKernel, int) | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,111 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GRAD_GPU_KERNEL_H_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GRAD_GPU_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel.h" | |||||
| #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh" | |||||
| namespace mindspore { | |||||
| namespace kernel { | |||||
| template <typename T> | |||||
| class ResizeNearestNeighborGradGpuKernel : public GpuKernel { | |||||
| public: | |||||
| ResizeNearestNeighborGradGpuKernel() : align_corners_(false), shape_size_(0), input_size_(0), output_size_(0) {} | |||||
| ~ResizeNearestNeighborGradGpuKernel() override = default; | |||||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||||
| T *input = GetDeviceAddress<T>(inputs, 0); | |||||
| T *output = GetDeviceAddress<T>(outputs, 0); | |||||
| int size = SizeToInt(output_size_ / sizeof(T)); | |||||
| float h_scale = Scaling(input_shape_[2], output_shape_[2], align_corners_); | |||||
| float w_scale = Scaling(input_shape_[3], output_shape_[3], align_corners_); | |||||
| CalResizeNearestNeighborGrad(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], | |||||
| output, output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], | |||||
| align_corners_, h_scale, w_scale, reinterpret_cast<cudaStream_t>(stream_ptr)); | |||||
| return true; | |||||
| } | |||||
| bool Init(const CNodePtr &kernel_node) override { | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||||
| if (input_num != 1) { | |||||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but ResizeNearestNeighbor needs 1 input."; | |||||
| return false; | |||||
| } | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||||
| if (output_num != 1) { | |||||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor needs 1 output."; | |||||
| return false; | |||||
| } | |||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||||
| shape_size_ = input_shape.size(); | |||||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||||
| if (shape_size_ != RESIZENEARESTNEIGHBORGRAD_DIMENSION) { | |||||
| MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only " | |||||
| << RESIZENEARESTNEIGHBORGRAD_DIMENSION << "-D inputs."; | |||||
| return false; | |||||
| } | |||||
| input_size_ = 1; | |||||
| for (size_t i = 0; i < shape_size_; i++) { | |||||
| input_size_ *= input_shape[i]; | |||||
| input_shape_.push_back(input_shape[i]); | |||||
| } | |||||
| input_size_ *= sizeof(T); | |||||
| output_size_ = 1; | |||||
| for (size_t i = 0; i < shape_size_; i++) { | |||||
| output_size_ *= output_shape[i]; | |||||
| output_shape_.push_back(output_shape[i]); | |||||
| } | |||||
| output_size_ *= sizeof(T); | |||||
| align_corners_ = GetAttr<bool>(kernel_node, "align_corners"); | |||||
| InitSizeLists(); | |||||
| return true; | |||||
| } | |||||
| protected: | |||||
| void InitSizeLists() override { | |||||
| input_size_list_.push_back(input_size_); | |||||
| output_size_list_.push_back(output_size_); | |||||
| } | |||||
| private: | |||||
| float Scaling(const int in_size, const int out_size, bool align_corners) { | |||||
| return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1) | |||||
| : in_size / static_cast<float>(out_size); | |||||
| } | |||||
| bool align_corners_; | |||||
| size_t shape_size_; | |||||
| std::vector<int> input_shape_; | |||||
| std::vector<int> output_shape_; | |||||
| size_t input_size_; | |||||
| size_t output_size_; | |||||
| size_t workspace_size_; | |||||
| std::vector<size_t> input_size_list_; | |||||
| std::vector<size_t> output_size_list_; | |||||
| std::vector<size_t> workspace_size_list_; | |||||
| }; | |||||
| } // namespace kernel | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_RESIZE_NEAREST_NEIGHBOR_GRAD_GPU_KERNEL_H_ | |||||
| @@ -0,0 +1,81 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <stdio.h> | |||||
| #include <stdint.h> | |||||
| #include <math.h> | |||||
| #include <algorithm> | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh" | |||||
| template <typename T> | |||||
| __global__ void ResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3, | |||||
| const int s4, T *output, const int d1, const int d2, const int d3, | |||||
| const int d4, bool align_corners, float h_scale, float w_scale) { | |||||
| // initialization | |||||
| // HalfPixelCenters false | |||||
| int input_pos; | |||||
| int pos_array[RESIZENEARESTNEIGHBORGRAD_DIMENSION]; | |||||
| int in_height = s3; | |||||
| int in_width = s4; | |||||
| // for example 4-D: pos = pos_array[0] * output_shape[1] * output_shape[2] * output_shape[3] + | |||||
| // pos_array[1] * output_shape[2] * output_shape[3] + | |||||
| // pos_array[2] * output_shape[3] + | |||||
| // pos_array[3] | |||||
| T h_scale_ = static_cast<T>(h_scale); | |||||
| T w_scale_ = static_cast<T>(w_scale); | |||||
| T out_h_; | |||||
| T out_w_; | |||||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||||
| pos_array[0] = pos / (d2 * d3 * d4) % d1; | |||||
| pos_array[1] = pos / (d3 * d4) % d2; | |||||
| pos_array[2] = pos / (d4) % d3; | |||||
| pos_array[3] = pos % d4; | |||||
| out_h_ = static_cast<T>(pos_array[2]); | |||||
| out_w_ = static_cast<T>(pos_array[3]); | |||||
| const int in_y = | |||||
| min((align_corners) ? static_cast<int>(roundf(out_h_ * h_scale_)) : static_cast<int>(floorf(out_h_ * h_scale_)), | |||||
| in_height - 1); | |||||
| const int in_x = | |||||
| min((align_corners) ? static_cast<int>(roundf(out_w_ * w_scale_)) : static_cast<int>(floorf(out_w_ * w_scale_)), | |||||
| in_width - 1); | |||||
| // pos_array[0] N, pos_array[1] C, in_y H, in_x W | |||||
| input_pos = pos_array[0] * s2 * s3 * s4 + pos_array[1] * s3 * s4 + in_y * s4 + in_x; | |||||
| output[pos] = input[input_pos]; | |||||
| } | |||||
| return; | |||||
| } | |||||
| template <typename T> | |||||
| void CalResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3, | |||||
| const int s4, T *output, const int d1, const int d2, const int d3, const int d4, | |||||
| bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream) { | |||||
| ResizeNearestNeighborGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>( | |||||
| size, input, s1, s2, s3, s4, output, d1, d2, d3, d4, align_corners, h_scale, w_scale); | |||||
| return; | |||||
| } | |||||
| template void CalResizeNearestNeighborGrad<float>(const int size, const float *input, const int s1, const int s2, | |||||
| const int s3, const int s4, float *output, const int d1, const int d2, | |||||
| const int d3, const int d4, bool align_corners, float h_scale, | |||||
| float w_scale, cudaStream_t cuda_stream); | |||||
| template void CalResizeNearestNeighborGrad<half>(const int size, const half *input, const int s1, const int s2, | |||||
| const int s3, const int s4, half *output, const int d1, const int d2, | |||||
| const int d3, const int d4, bool align_corners, float h_scale, | |||||
| float w_scale, cudaStream_t cuda_stream); | |||||
| template void CalResizeNearestNeighborGrad<int>(const int size, const int *input, const int s1, const int s2, | |||||
| const int s3, const int s4, int *output, const int d1, const int d2, | |||||
| const int d3, const int d4, bool align_corners, float h_scale, | |||||
| float w_scale, cudaStream_t cuda_stream); | |||||
| @@ -0,0 +1,28 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_GRAD_IMPL_CUH_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_GRAD_IMPL_CUH_ | |||||
| #include <cuda_runtime.h> | |||||
| #include "runtime/device/gpu/cuda_common.h" | |||||
| #define RESIZENEARESTNEIGHBORGRAD_DIMENSION 4 | |||||
| template <typename T> | |||||
| void CalResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3, | |||||
| const int s4, T *output, const int d1, const int d2, const int d3, const int d4, | |||||
| bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_GRAD_IMPL_CUH_ | |||||
| @@ -0,0 +1,81 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <stdio.h> | |||||
| #include <stdint.h> | |||||
| #include <math.h> | |||||
| #include <algorithm> | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_impl.cuh" | |||||
| template <typename T> | |||||
| __global__ void ResizeNearestNeighbor(const int size, const T *input, const int s1, const int s2, const int s3, | |||||
| const int s4, T *output, const int d1, const int d2, const int d3, const int d4, | |||||
| bool align_corners, float h_scale, float w_scale) { | |||||
| // initialization | |||||
| // HalfPixelCenters false | |||||
| int input_pos; | |||||
| int pos_array[RESIZENEARESTNEIGHBOR_DIMENSION]; | |||||
| int in_height = s3; | |||||
| int in_width = s4; | |||||
| // for example 4-D: pos = pos_array[0] * output_shape[1] * output_shape[2] * output_shape[3] + | |||||
| // pos_array[1] * output_shape[2] * output_shape[3] + | |||||
| // pos_array[2] * output_shape[3] + | |||||
| // pos_array[3] | |||||
| T h_scale_ = static_cast<T>(h_scale); | |||||
| T w_scale_ = static_cast<T>(w_scale); | |||||
| T out_h_; | |||||
| T out_w_; | |||||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||||
| pos_array[0] = pos / (d2 * d3 * d4) % d1; | |||||
| pos_array[1] = pos / (d3 * d4) % d2; | |||||
| pos_array[2] = pos / (d4) % d3; | |||||
| pos_array[3] = pos % d4; | |||||
| out_h_ = static_cast<T>(pos_array[2]); | |||||
| out_w_ = static_cast<T>(pos_array[3]); | |||||
| const int in_y = | |||||
| min((align_corners) ? static_cast<int>(roundf(out_h_ * h_scale_)) : static_cast<int>(floorf(out_h_ * h_scale_)), | |||||
| in_height - 1); | |||||
| const int in_x = | |||||
| min((align_corners) ? static_cast<int>(roundf(out_w_ * w_scale_)) : static_cast<int>(floorf(out_w_ * w_scale_)), | |||||
| in_width - 1); | |||||
| // pos_array[0] N, pos_array[1] C, in_y H, in_x W | |||||
| input_pos = pos_array[0] * s2 * s3 * s4 + pos_array[1] * s3 * s4 + in_y * s4 + in_x; | |||||
| output[pos] = input[input_pos]; | |||||
| } | |||||
| return; | |||||
| } | |||||
| template <typename T> | |||||
| void CalResizeNearestNeighbor(const int size, const T *input, const int s1, const int s2, const int s3, const int s4, | |||||
| T *output, const int d1, const int d2, const int d3, const int d4, bool align_corners, | |||||
| float h_scale, float w_scale, cudaStream_t cuda_stream) { | |||||
| ResizeNearestNeighbor<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, s1, s2, s3, s4, output, d1, d2, | |||||
| d3, d4, align_corners, h_scale, w_scale); | |||||
| return; | |||||
| } | |||||
| template void CalResizeNearestNeighbor<float>(const int size, const float *input, const int s1, const int s2, | |||||
| const int s3, const int s4, float *output, const int d1, const int d2, | |||||
| const int d3, const int d4, bool align_corners, float h_scale, | |||||
| float w_scale, cudaStream_t cuda_stream); | |||||
| template void CalResizeNearestNeighbor<half>(const int size, const half *input, const int s1, const int s2, | |||||
| const int s3, const int s4, half *output, const int d1, const int d2, | |||||
| const int d3, const int d4, bool align_corners, float h_scale, | |||||
| float w_scale, cudaStream_t cuda_stream); | |||||
| template void CalResizeNearestNeighbor<int>(const int size, const int *input, const int s1, const int s2, const int s3, | |||||
| const int s4, int *output, const int d1, const int d2, const int d3, | |||||
| const int d4, bool align_corners, float h_scale, float w_scale, | |||||
| cudaStream_t cuda_stream); | |||||
| @@ -0,0 +1,28 @@ | |||||
| /** | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_IMPL_CUH_ | |||||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_IMPL_CUH_ | |||||
| #include <cuda_runtime.h> | |||||
| #include "runtime/device/gpu/cuda_common.h" | |||||
| #define RESIZENEARESTNEIGHBOR_DIMENSION 4 | |||||
| template <typename T> | |||||
| void CalResizeNearestNeighbor(const int size, const T *input, const int s1, const int s2, const int s3, const int s4, | |||||
| T *output, const int d1, const int d2, const int d3, const int d4, bool align_corners, | |||||
| float h_scale, float w_scale, cudaStream_t cuda_stream); | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_RESIZE_NEAREST_NEIGHBOR_IMPL_CUH_ | |||||
| @@ -2338,10 +2338,10 @@ class ResizeNearestNeighbor(PrimitiveWithInfer): | |||||
| and output tensors are aligned. Default: False. | and output tensors are aligned. Default: False. | ||||
| Inputs: | Inputs: | ||||
| - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`. | |||||
| - **input_x** (Tensor) - The input tensor. The shape of the tensor is :math:`(N, C, H, W)`. | |||||
| Outputs: | Outputs: | ||||
| Tensor, the shape of the output tensor is :math:`(N, NEW\_C, NEW\_H, W)`. | |||||
| Tensor, the shape of the output tensor is :math:`(N, C, NEW\_H, NEW\_W)`. | |||||
| Examples: | Examples: | ||||
| >>> input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32) | >>> input_tensor = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mindspore.float32) | ||||
| @@ -2360,7 +2360,7 @@ class ResizeNearestNeighbor(PrimitiveWithInfer): | |||||
| self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) | self.init_prim_io_names(inputs=['image_in'], outputs=['image_out']) | ||||
| def infer_shape(self, x): | def infer_shape(self, x): | ||||
| validator.check('the dimension of input_x', len(x), '', 2, Rel.GE, self.name) | |||||
| validator.check('the dimension of input_x', len(x), '', 4, Rel.EQ, self.name) | |||||
| return tuple(x)[:-2] + tuple(self.size) | return tuple(x)[:-2] + tuple(self.size) | ||||
| def infer_dtype(self, x): | def infer_dtype(self, x): | ||||
| @@ -0,0 +1,75 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops.operations import _grad_ops as G | |||||
| class ResizeNearestNeighborGradAlignCornerT(nn.Cell): | |||||
| def __init__(self): | |||||
| super(ResizeNearestNeighborGradAlignCornerT, self).__init__() | |||||
| self.ResizeNearestNeighborGradAlignCornerT = G.ResizeNearestNeighborGrad(align_corners=True) | |||||
| def construct(self, dy, size): | |||||
| return self.ResizeNearestNeighborGradAlignCornerT(dy, size) | |||||
| class ResizeNearestNeighborGradAlignCornerF(nn.Cell): | |||||
| def __init__(self): | |||||
| super(ResizeNearestNeighborGradAlignCornerF, self).__init__() | |||||
| self.ResizeNearestNeighborGradAlignCornerF = G.ResizeNearestNeighborGrad(align_corners=False) | |||||
| def construct(self, dy, size): | |||||
| return self.ResizeNearestNeighborGradAlignCornerF(dy, size) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_ResizeNearestNeighborGradAlignCornerT(): | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||||
| dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) | |||||
| size = (2, 2) | |||||
| expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float32) | |||||
| rnn = ResizeNearestNeighborGradAlignCornerT() | |||||
| output = rnn(Tensor(dy), size) | |||||
| assert np.all(output.asnumpy() == expect) | |||||
| dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) | |||||
| size = (2, 2) | |||||
| expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float16) | |||||
| rnn = ResizeNearestNeighborGradAlignCornerT() | |||||
| output = rnn(Tensor(dy), size) | |||||
| assert np.all(output.asnumpy() == expect) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_ResizeNearestNeighborGradAlignCornerF(): | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||||
| dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) | |||||
| size = (2, 2) | |||||
| expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float32) | |||||
| rnn = ResizeNearestNeighborGradAlignCornerF() | |||||
| output = rnn(Tensor(dy), size) | |||||
| assert np.all(output.asnumpy() == expect) | |||||
| dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) | |||||
| size = (2, 2) | |||||
| expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float16) | |||||
| rnn = ResizeNearestNeighborGradAlignCornerF() | |||||
| output = rnn(Tensor(dy), size) | |||||
| assert np.all(output.asnumpy() == expect) | |||||
| @@ -0,0 +1,71 @@ | |||||
| # Copyright 2019 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| import mindspore.nn as nn | |||||
| from mindspore import Tensor | |||||
| from mindspore.ops import operations as P | |||||
| class ResizeNearestNeighborAlignCornerT(nn.Cell): | |||||
| def __init__(self, size): | |||||
| super(ResizeNearestNeighborAlignCornerT, self).__init__() | |||||
| self.ResizeNearestNeighborAlignCornerT = P.ResizeNearestNeighbor(size, align_corners=True) | |||||
| def construct(self, x): | |||||
| return self.ResizeNearestNeighborAlignCornerT(x) | |||||
| class ResizeNearestNeighborAlignCornerF(nn.Cell): | |||||
| def __init__(self, size): | |||||
| super(ResizeNearestNeighborAlignCornerF, self).__init__() | |||||
| self.ResizeNearestNeighborAlignCornerF = P.ResizeNearestNeighbor(size, align_corners=False) | |||||
| def construct(self, x): | |||||
| return self.ResizeNearestNeighborAlignCornerF(x) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_ResizeNearestNeighborAlignCornerT(): | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||||
| input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float32)) | |||||
| expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) | |||||
| rnn = ResizeNearestNeighborAlignCornerT((4, 4)) | |||||
| output = rnn(input_tensor) | |||||
| assert np.all(output.asnumpy() == expect) | |||||
| input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float16)) | |||||
| expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) | |||||
| rnn = ResizeNearestNeighborAlignCornerT((4, 4)) | |||||
| output = rnn(input_tensor) | |||||
| assert np.all(output.asnumpy() == expect) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_gpu_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_ResizeNearestNeighborAlignCornerF(): | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||||
| input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float32)) | |||||
| expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) | |||||
| rnn = ResizeNearestNeighborAlignCornerF((4, 4)) | |||||
| output = rnn(input_tensor) | |||||
| assert np.all(output.asnumpy() == expect) | |||||
| input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.float16)) | |||||
| expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) | |||||
| rnn = ResizeNearestNeighborAlignCornerF((4, 4)) | |||||
| output = rnn(input_tensor) | |||||
| assert np.all(output.asnumpy() == expect) | |||||