From 2283eac723bc9147bc35e4b6423c83201e65fc62 Mon Sep 17 00:00:00 2001 From: wanyiming Date: Fri, 18 Dec 2020 20:13:58 +0800 Subject: [PATCH] add_resize_ops --- .../backend/kernel_compiler/common_utils.cc | 21 + .../backend/kernel_compiler/common_utils.h | 18 + .../cpu/resize_bilinear_cpu_kernel.cc | 113 ++++ .../cpu/resize_bilinear_cpu_kernel.h | 58 ++ .../cpu/resize_bilinear_grad_cpu_kernel.cc | 106 ++++ .../cpu/resize_bilinear_grad_cpu_kernel.h | 62 ++ .../cpu/resize_nearest_neighbor_cpu_kernel.cc | 94 +++ .../cpu/resize_nearest_neighbor_cpu_kernel.h | 68 ++ ...resize_nearest_neighbor_grad_cpu_kernel.cc | 91 +++ .../resize_nearest_neighbor_grad_cpu_kernel.h | 68 ++ .../ops/cpu/test_resize_bilinear_grad_op.py | 83 +++ tests/st/ops/cpu/test_resize_bilinear_op.py | 571 +++++++++++++++++ .../test_resize_nearest_neighbor_grad_op.py | 93 +++ .../cpu/test_resize_nearest_neighbor_op.py | 589 ++++++++++++++++++ 14 files changed, 2035 insertions(+) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/resize_bilinear_cpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/resize_bilinear_cpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/resize_bilinear_grad_cpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/resize_bilinear_grad_cpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_cpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_cpu_kernel.h create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_grad_cpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_grad_cpu_kernel.h create mode 100644 tests/st/ops/cpu/test_resize_bilinear_grad_op.py create mode 100644 tests/st/ops/cpu/test_resize_bilinear_op.py create mode 100644 tests/st/ops/cpu/test_resize_nearest_neighbor_grad_op.py create mode 100755 tests/st/ops/cpu/test_resize_nearest_neighbor_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc index ec85eb367a..77e1b295b5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc @@ -826,5 +826,26 @@ std::string GetProcessorStr(const AnfNodePtr &anf_node) { return processor; } + +float Scaling(size_t in_size, size_t out_size, bool align_corners) { + return (align_corners && out_size > 1) ? (in_size - 1) / static_cast(out_size - 1) + : in_size / static_cast(out_size); +} + +float ScaleGrid(const int x, const float scale) { return static_cast(x) * scale; } + +void ComputeInterpolationWeights(const size_t out_size, const size_t in_size, const float scale, + CachedInterpolation *interpolation) { + interpolation[out_size].lower = 0; + interpolation[out_size].upper = 0; + for (size_t i = 0; i <= out_size - 1; ++i) { + const float in = ScaleGrid(i, scale); + const float in_f = std::floor(in); + interpolation[i].lower = std::max(static_cast(in_f), static_cast(0)); + interpolation[i].upper = std::min(static_cast(std::ceil(in)), in_size - 1); + interpolation[i].lerp = in - in_f; + } +} + } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h index b23c7334b5..24092a55b7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.h +++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.h @@ -102,6 +102,16 @@ void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector GetReduceAttrAxis(const CNodePtr &cnode); std::string GetProcessorStr(const AnfNodePtr &anf_node); +float Scaling(size_t in_size, size_t out_size, bool align_corners); +float ScaleGrid(const int x, const float scale); +struct CachedInterpolation { + size_t lower; + size_t upper; + float lerp; +}; + +void ComputeInterpolationWeights(const size_t out_size, const size_t in_size, const float scale, + CachedInterpolation *interpolation); template inline std::string Vector2Str(const std::vector &inputs) { @@ -113,6 +123,14 @@ inline std::string Vector2Str(const std::vector &inputs) { } return ""; } + +template +inline T ComputeLerp(T top_left, T top_right, T bottom_left, T bottom_right, T x_lerp, T y_lerp) { + T top = top_left + (top_right - top_left) * x_lerp; + T bottom = bottom_left + (bottom_right - bottom_left) * x_lerp; + return top + (bottom - top) * y_lerp; +} + } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_bilinear_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_bilinear_cpu_kernel.cc new file mode 100644 index 0000000000..0e70c8d70a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_bilinear_cpu_kernel.cc @@ -0,0 +1,113 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/resize_bilinear_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace kernel { + +void ResizeBilinearCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + size_ = AnfAlgo::GetNodeAttr>(kernel_node, SIZE); + align_corners_ = AnfAlgo::GetNodeAttr(kernel_node, "align_corners"); + dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); + + size_t in_height = shape_[2]; + size_t in_width = shape_[3]; + size_t out_height = size_[0]; + size_t out_width = size_[1]; + height_scale = Scaling(in_height, out_height, align_corners_); + width_scale = Scaling(in_width, out_width, align_corners_); +} + +bool ResizeBilinearCPUKernel::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + if (dtype_ == kNumberTypeFloat16) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat32) { + LaunchKernel(inputs, outputs); + } + return true; +} + +template +void ResizeBilinearCPUKernel::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + + size_t batch_size = shape_[0]; + size_t channel = shape_[1]; + size_t in_height = shape_[2]; + size_t in_width = shape_[3]; + size_t out_height = size_[0]; + size_t out_width = size_[1]; + size_t out_hw_size = out_height * out_width; + size_t in_hw_size = in_height * in_width; + size_t bhwc_size = in_hw_size * channel * batch_size; + + if (out_height == in_height && out_width == in_width) { + for (size_t i = 0; i < bhwc_size; ++i) { + output_addr[i] = static_cast(input_addr[i]); + } + } + + std::vector ys(out_height + 1); + std::vector xs(out_width + 1); + + ComputeInterpolationWeights(out_height, in_height, height_scale, ys.data()); + ComputeInterpolationWeights(out_width, in_width, width_scale, xs.data()); + + for (size_t b = 0; b < batch_size; ++b) { + for (size_t c = 0; c < channel; ++c) { + for (size_t h = 0; h < out_height; ++h) { + const T1 *ys_input_lower_ptr = input_addr + ys[h].lower * in_width; + const T1 *ys_input_upper_ptr = input_addr + ys[h].upper * in_width; + const T2 ys_lerp = T2(ys[h].lerp); + for (size_t w = 0; w < out_width; ++w) { + const size_t xs_lower = xs[w].lower; + const size_t xs_upper = xs[w].upper; + const T2 xs_lerp = T2(xs[w].lerp); + const T2 top_left(ys_input_lower_ptr[xs_lower]); + const T2 top_right(ys_input_lower_ptr[xs_upper]); + const T2 bottom_left(ys_input_upper_ptr[xs_lower]); + const T2 bottom_right(ys_input_upper_ptr[xs_upper]); + output_addr[h * out_width + w] = + ComputeLerp(top_left, top_right, bottom_left, bottom_right, xs_lerp, ys_lerp); + } + } + output_addr += out_hw_size; + input_addr += in_hw_size; + } + } +} + +void ResizeBilinearCPUKernel::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(EXCEPTION) << "ResizeBilinear needs 1 inputs, but gets " << input_num; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "ResizeBilinear expects 1 output, but gets" << output_num; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_bilinear_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_bilinear_cpu_kernel.h new file mode 100644 index 0000000000..58dc881f3f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_bilinear_cpu_kernel.h @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_BILINEAR_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_BILINEAR_CPU_KERNEL_H_ +#include +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class ResizeBilinearCPUKernel : public CPUKernel { + public: + ResizeBilinearCPUKernel() = default; + ~ResizeBilinearCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + private: + void CheckParam(const CNodePtr &kernel_node); + TypeId dtype_{kTypeUnknown}; + bool align_corners_ = false; + float height_scale; + float width_scale; + std::vector size_; + std::vector shape_; +}; + +MS_REG_CPU_KERNEL(ResizeBilinear, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32), + ResizeBilinearCPUKernel); + +MS_REG_CPU_KERNEL(ResizeBilinear, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ResizeBilinearCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_BILINEAR_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_bilinear_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_bilinear_grad_cpu_kernel.cc new file mode 100644 index 0000000000..bf28dc231f --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_bilinear_grad_cpu_kernel.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/resize_bilinear_grad_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace kernel { + +void ResizeBilinearGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + size_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); + align_corners_ = AnfAlgo::GetNodeAttr(kernel_node, "align_corners"); + dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); + + size_t in_height = shape_[2]; + size_t in_width = shape_[3]; + size_t out_height = size_[2]; + size_t out_width = size_[3]; + + height_scale = Scaling(out_height, in_height, align_corners_); + width_scale = Scaling(out_width, in_width, align_corners_); +} + +bool ResizeBilinearGradCPUKernel::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + if (dtype_ == kNumberTypeFloat16) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat32) { + LaunchKernel(inputs, outputs); + } + return true; +} + +template +void ResizeBilinearGradCPUKernel::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + auto dloss_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + + size_t batch_size = shape_[0]; + size_t channel = shape_[1]; + size_t in_height = shape_[2]; + size_t in_width = shape_[3]; + size_t out_height = size_[2]; + size_t out_width = size_[3]; + size_t out_hw_size = out_height * out_width; + size_t in_hw_size = in_height * in_width; + + for (size_t b = 0; b < batch_size; ++b) { + for (size_t c = 0; c < channel; ++c) { + for (size_t h = 0; h < in_height; ++h) { + const float in_y = static_cast(h) * height_scale; + const size_t top_y_index = std::max(static_cast(floorf(in_y)), static_cast(0)); + const size_t bottom_y_index = std::min(static_cast(ceilf(in_y)), out_height - 1); + const float y_lerp = in_y - floorf(in_y); + const float inverse_y_lerp = 1.0 - y_lerp; + for (size_t w = 0; w < in_width; ++w) { + const float in_x = static_cast(w) * width_scale; + const size_t left_x_index = std::max(static_cast(floorf(in_x)), static_cast(0)); + const size_t right_x_index = std::min(static_cast(ceilf(in_x)), out_width - 1); + const float x_lerp = in_x - floorf(in_x); + const float inverse_x_lerp = 1.0 - x_lerp; + output_addr[top_y_index * out_width + left_x_index] += + dloss_addr[h * in_width + w] * T(inverse_y_lerp * inverse_x_lerp); + output_addr[top_y_index * out_width + right_x_index] += + dloss_addr[h * in_width + w] * T(inverse_y_lerp * x_lerp); + output_addr[bottom_y_index * out_width + left_x_index] += + dloss_addr[h * in_width + w] * T(y_lerp * inverse_x_lerp); + output_addr[bottom_y_index * out_width + right_x_index] += dloss_addr[h * in_width + w] * T(y_lerp * x_lerp); + } + } + output_addr += out_hw_size; + dloss_addr += in_hw_size; + } + } +} + +void ResizeBilinearGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 2) { + MS_LOG(EXCEPTION) << "ResizeBilinearGrad needs 2 inputs, but gets " << input_num; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "ResizeBilinear Gradexpects 1 output, but gets" << output_num; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_bilinear_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_bilinear_grad_cpu_kernel.h new file mode 100644 index 0000000000..be87ceb50c --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_bilinear_grad_cpu_kernel.h @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_BILINEAR_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_BILINEAR_GRAD_CPU_KERNEL_H_ +#include +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class ResizeBilinearGradCPUKernel : public CPUKernel { + public: + ResizeBilinearGradCPUKernel() = default; + ~ResizeBilinearGradCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + private: + void CheckParam(const CNodePtr &kernel_node); + TypeId dtype_{kTypeUnknown}; + bool align_corners_ = false; + float height_scale; + float width_scale; + std::vector size_; + std::vector shape_; +}; + +MS_REG_CPU_KERNEL( + ResizeBilinearGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ResizeBilinearGradCPUKernel); + +MS_REG_CPU_KERNEL( + ResizeBilinearGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ResizeBilinearGradCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_BILINEAR_GRAD_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_cpu_kernel.cc new file mode 100644 index 0000000000..61661ce684 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_cpu_kernel.cc @@ -0,0 +1,94 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/resize_nearest_neighbor_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace kernel { + +void ResizeNearestNeighborCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + std::vector input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + std::vector output_size = AnfAlgo::GetNodeAttr>(kernel_node, SIZE); + align_corners_ = AnfAlgo::GetNodeAttr(kernel_node, "align_corners"); + dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); + batch_size_ = input_shape[0]; + channel_ = input_shape[1]; + in_height_ = input_shape[2]; + in_width_ = input_shape[3]; + out_height_ = output_size[0]; + out_width_ = output_size[1]; + height_scale_ = Scaling(in_height_, out_height_, align_corners_); + width_scale_ = Scaling(in_width_, out_width_, align_corners_); + output_size_ = batch_size_ * channel_ * out_height_ * out_width_; +} + +bool ResizeNearestNeighborCPUKernel::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + if (dtype_ == kNumberTypeFloat16) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat32) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeInt32) { + LaunchKernel(inputs, outputs); + } + return true; +} + +template +void ResizeNearestNeighborCPUKernel::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + auto input_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + + if (out_height_ == in_height_ && out_width_ == in_width_) { + for (size_t i = 0; i < output_size_; ++i) { + output_addr[i] = input_addr[i]; + } + } + + for (size_t i = 0; i < output_size_; ++i) { + size_t pos0 = i / (channel_ * out_height_ * out_width_) % batch_size_; + size_t pos1 = i / (out_height_ * out_width_) % channel_; + size_t pos2 = i / (out_width_) % out_height_; + size_t pos3 = i % out_width_; + const size_t in_y = std::min((align_corners_) ? static_cast(roundf(pos2 * height_scale_)) + : static_cast(floorf(pos2 * height_scale_)), + in_height_ - 1); + const size_t in_x = std::min((align_corners_) ? static_cast(roundf(pos3 * width_scale_)) + : static_cast(floorf(pos3 * width_scale_)), + in_width_ - 1); + size_t input_pos = + pos0 * channel_ * in_height_ * in_width_ + pos1 * in_height_ * in_width_ + in_y * in_width_ + in_x; + output_addr[i] = input_addr[input_pos]; + } +} + +void ResizeNearestNeighborCPUKernel::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(EXCEPTION) << "ResizeBilinear needs 1 inputs, but gets " << input_num; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "ResizeBilinear expects 1 output, but gets" << output_num; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_cpu_kernel.h new file mode 100644 index 0000000000..4f83b002da --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_cpu_kernel.h @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_CPU_KERNEL_H_ +#include +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class ResizeNearestNeighborCPUKernel : public CPUKernel { + public: + ResizeNearestNeighborCPUKernel() = default; + ~ResizeNearestNeighborCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + private: + void CheckParam(const CNodePtr &kernel_node); + TypeId dtype_{kTypeUnknown}; + bool align_corners_{false}; + size_t batch_size_{0}; + size_t channel_{0}; + size_t in_height_{0}; + size_t in_width_{0}; + size_t out_height_{0}; + size_t out_width_{0}; + size_t output_size_{0}; + float height_scale_{1.0}; + float width_scale_{1.0}; +}; + +MS_REG_CPU_KERNEL(ResizeNearestNeighbor, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ResizeNearestNeighborCPUKernel); + +MS_REG_CPU_KERNEL(ResizeNearestNeighbor, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ResizeNearestNeighborCPUKernel); + +MS_REG_CPU_KERNEL(ResizeNearestNeighbor, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ResizeNearestNeighborCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_CPU_KERNEL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_grad_cpu_kernel.cc new file mode 100644 index 0000000000..50f12a216a --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_grad_cpu_kernel.cc @@ -0,0 +1,91 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/cpu/resize_nearest_neighbor_grad_cpu_kernel.h" +#include "runtime/device/cpu/cpu_device_address.h" +#include "backend/kernel_compiler/common_utils.h" + +namespace mindspore { +namespace kernel { + +void ResizeNearestNeighborGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { + CheckParam(kernel_node); + std::vector input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + std::vector output_size = AnfAlgo::GetOutputInferShape(kernel_node, 0); + align_corners_ = AnfAlgo::GetNodeAttr(kernel_node, "align_corners"); + dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); + batch_size_ = input_shape[0]; + channel_ = input_shape[1]; + in_height_ = input_shape[2]; + in_width_ = input_shape[3]; + out_height_ = output_size[2]; + out_width_ = output_size[3]; + height_scale_ = Scaling(out_height_, in_height_, align_corners_); + width_scale_ = Scaling(out_width_, in_width_, align_corners_); +} + +bool ResizeNearestNeighborGradCPUKernel::Launch(const std::vector &inputs, + const std::vector &, + const std::vector &outputs) { + if (dtype_ == kNumberTypeFloat16) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeFloat32) { + LaunchKernel(inputs, outputs); + } else if (dtype_ == kNumberTypeInt32) { + LaunchKernel(inputs, outputs); + } + return true; +} + +template +void ResizeNearestNeighborGradCPUKernel::LaunchKernel(const std::vector &inputs, + const std::vector &outputs) { + auto dloss_addr = reinterpret_cast(inputs[0]->addr); + auto output_addr = reinterpret_cast(outputs[0]->addr); + size_t in_hw_size = in_width_ * in_height_; + size_t out_hw_size = out_width_ * out_height_; + + for (size_t b = 0; b < batch_size_; ++b) { + for (size_t c = 0; c < channel_; ++c) { + for (size_t h = 0; h < in_height_; ++h) { + const size_t out_y = std::min((align_corners_) ? static_cast(roundf(h * height_scale_)) + : static_cast(floorf(h * height_scale_)), + out_height_ - 1); + for (size_t w = 0; w < in_width_; ++w) { + const size_t out_x = std::min((align_corners_) ? static_cast(roundf(w * width_scale_)) + : static_cast(floorf(w * width_scale_)), + out_width_ - 1); + output_addr[out_y * out_width_ + out_x] += dloss_addr[h * in_width_ + w]; + } + } + output_addr += out_hw_size; + dloss_addr += in_hw_size; + } + } +} + +void ResizeNearestNeighborGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(EXCEPTION) << "ResizeBilinearGrad needs 1 inputs, but gets " << input_num; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(EXCEPTION) << "ResizeBilinear Gradexpects 1 output, but gets" << output_num; + } +} +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_grad_cpu_kernel.h new file mode 100644 index 0000000000..f2a2e89ebd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/resize_nearest_neighbor_grad_cpu_kernel.h @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_GRAD_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_GRAD_CPU_KERNEL_H_ +#include +#include +#include +#include +#include "backend/kernel_compiler/cpu/cpu_kernel.h" +#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" + +namespace mindspore { +namespace kernel { +class ResizeNearestNeighborGradCPUKernel : public CPUKernel { + public: + ResizeNearestNeighborGradCPUKernel() = default; + ~ResizeNearestNeighborGradCPUKernel() override = default; + + void InitKernel(const CNodePtr &kernel_node) override; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override; + + template + void LaunchKernel(const std::vector &inputs, const std::vector &outputs); + + private: + void CheckParam(const CNodePtr &kernel_node); + TypeId dtype_{kTypeUnknown}; + bool align_corners_{false}; + size_t batch_size_{0}; + size_t channel_{0}; + size_t in_height_{0}; + size_t in_width_{0}; + size_t out_height_{0}; + size_t out_width_{0}; + float height_scale_{1.0}; + float width_scale_{1.0}; +}; + +MS_REG_CPU_KERNEL(ResizeNearestNeighborGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ResizeNearestNeighborGradCPUKernel); + +MS_REG_CPU_KERNEL(ResizeNearestNeighborGrad, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + ResizeNearestNeighborGradCPUKernel); + +MS_REG_CPU_KERNEL(ResizeNearestNeighborGrad, + KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ResizeNearestNeighborGradCPUKernel); +} // namespace kernel +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_RESIZE_NEAREST_NEIGHBOR_GRAD_CPU_KERNEL_H_ diff --git a/tests/st/ops/cpu/test_resize_bilinear_grad_op.py b/tests/st/ops/cpu/test_resize_bilinear_grad_op.py new file mode 100644 index 0000000000..7d2b27dfb1 --- /dev/null +++ b/tests/st/ops/cpu/test_resize_bilinear_grad_op.py @@ -0,0 +1,83 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class ResizeBilinearGradAlignCornerT(nn.Cell): + def __init__(self): + super(ResizeBilinearGradAlignCornerT, self).__init__() + self.ResizeBilinearGradAlignCornerT = G.ResizeBilinearGrad( + align_corners=True) + + def construct(self, dy, size): + return self.ResizeBilinearGradAlignCornerT(dy, size) + + +class ResizeBilinearGradAlignCornerF(nn.Cell): + def __init__(self): + super(ResizeBilinearGradAlignCornerF, self).__init__() + self.ResizeBilinearGradAlignCornerF = G.ResizeBilinearGrad(align_corners=False) + + def construct(self, dy, size): + return self.ResizeBilinearGradAlignCornerF(dy, size) + + +def test_ResizeBilinearGradAlignCornerT(): + dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float32) + + orign_image = np.array( + [[[[1.1, 2.2, 3.2, 2.5], [3.3, 4.4, 5.7, 8.1], [3.3, 4.4, 5.7, 8.1], [3.3, 4.4, 5.7, 8.1]]]]).astype(np.float16) + expect = np.array([[[[1., 0., 0., 2.], + [0., 0., 0., 0.], + [0., 0., 0., 0.], + [3., 0., 0., 4.]]]]).astype(np.float16) + rnn = ResizeBilinearGradAlignCornerT() + output = rnn(Tensor(dy), Tensor(orign_image)) + assert np.all(output.asnumpy() == expect) + + orign_image = np.array( + [[[[1.1, 2.2, 3.2, 2.5], [3.3, 4.4, 5.7, 8.1], [3.3, 4.4, 5.7, 8.1], [3.3, 4.4, 5.7, 8.1]]]]).astype(np.float32) + expect = np.array([[[[1., 0., 0., 2.], + [0., 0., 0., 0.], + [0., 0., 0., 0.], + [3., 0., 0., 4.]]]]).astype(np.float32) + rnn = ResizeBilinearGradAlignCornerT() + output = rnn(Tensor(dy), Tensor(orign_image)) + assert np.all(output.asnumpy() == expect) + + +def test_ResizeBilinearGradAlignCornerF(): + dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) + + orign_image = np.array([[[[1.1, 2.2], [3.3, 4.4]]]]).astype(np.float16) + expect = np.array([[[[2.25, 0.75], + [0.75, 4.25]]]]).astype(np.float16) + rnn = ResizeBilinearGradAlignCornerF() + output = rnn(Tensor(dy), Tensor(orign_image)) + assert np.all(output.asnumpy() == expect) + + orign_image = np.array([[[[1.1, 2.2], [3.3, 4.4]]]]).astype(np.float32) + expect = np.array([[[[2.25, 0.75], + [0.75, 4.25]]]]).astype(np.float32) + rnn = ResizeBilinearGradAlignCornerF() + output = rnn(Tensor(dy), Tensor(orign_image)) + assert np.all(output.asnumpy() == expect) diff --git a/tests/st/ops/cpu/test_resize_bilinear_op.py b/tests/st/ops/cpu/test_resize_bilinear_op.py new file mode 100644 index 0000000000..dab90236b5 --- /dev/null +++ b/tests/st/ops/cpu/test_resize_bilinear_op.py @@ -0,0 +1,571 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +from mindspore import context, Tensor +from mindspore.ops import operations as P +from mindspore import nn + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetResizeBilinear(nn.Cell): + def __init__(self, size=None, align_corner=False): + super(NetResizeBilinear, self).__init__() + self.op = P.ResizeBilinear(size=size, align_corners=align_corner) + + def construct(self, inputs): + return self.op(inputs) + + +def test_resize_nn_grayscale_integer_ratio_half(datatype=np.float16): + input_tensor = Tensor(np.array( + [[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]]).astype(datatype)) + + # larger h and w + resize_nn = NetResizeBilinear((9, 9)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.09997559, 0.13330078, 0.16662598, 0.19995117, 0.23331706, + 0.26668295, 0.30004883, 0.30004883, 0.30004883], + [0.19995117, 0.23328993, 0.26662868, 0.29996747, 0.33333334, + 0.36669925, 0.40006512, 0.40006512, 0.40006512], + [0.29992676, 0.33327907, 0.36663142, 0.39998373, 0.4333496, + 0.4667155, 0.5000814, 0.5000814, 0.5000814], + [0.39990234, 0.43326822, 0.46663412, 0.5, 0.5333659, + 0.5667318, 0.60009766, 0.60009766, 0.60009766], + [0.5, 0.5333116, 0.5666233, 0.59993494, 0.6333008, + 0.66666675, 0.7000326, 0.7000326, 0.7000326], + [0.60009766, 0.633355, 0.66661245, 0.6998698, 0.7332357, + 0.7666016, 0.79996747, 0.79996747, 0.79996747], + [0.7001953, 0.73339844, 0.76660156, 0.7998047, 0.8331706, + 0.8665365, 0.89990234, 0.89990234, 0.89990234], + [0.7001953, 0.73339844, 0.76660156, 0.7998047, 0.8331706, + 0.8665365, 0.89990234, 0.89990234, 0.89990234], + [0.7001953, 0.73339844, 0.76660156, 0.7998047, 0.8331706, + 0.8665365, 0.89990234, 0.89990234, 0.89990234]]]]).astype(np.float32)) + error = np.ones(shape=[9, 9]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # smaller h and w + resize_nn = NetResizeBilinear((1, 1)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.09997559]]]]).astype(np.float32)) + error = np.ones(shape=[1, 1]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # smaller h, larger w + resize_nn = NetResizeBilinear((1, 6)) + output = resize_nn(input_tensor) + expected_output = Tensor( + np.array([[[[0.09997559, 0.14996338, 0.19995117, 0.25, 0.30004883, 0.30004883]]]]).astype(np.float32)) + error = np.ones(shape=[1, 6]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # larger h, smaller w + resize_nn = NetResizeBilinear((6, 1)) + output = resize_nn(input_tensor) + expected_output = Tensor( + np.array([[[[0.09997559], [0.24993896], [0.39990234], [0.5500488], [0.7001953], [0.7001953]]]]).astype( + np.float32)) + error = np.ones(shape=[6, 1]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # smaller h, same w + resize_nn = NetResizeBilinear((1, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor( + np.array([[[[0.09997559, 0.19995117, 0.30004883]]]]).astype(np.float32)) + error = np.ones(shape=[1, 3]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # larger h, same w + resize_nn = NetResizeBilinear((6, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.09997559, 0.19995117, 0.30004883], + [0.24993896, 0.3499756, 0.45007324], + [0.39990234, 0.5, 0.60009766], + [0.5500488, 0.64990234, 0.75], + [0.7001953, 0.7998047, 0.89990234], + [0.7001953, 0.7998047, 0.89990234]]]]).astype(np.float32)) + error = np.ones(shape=[6, 3]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # same h, smaller w + resize_nn = NetResizeBilinear((3, 1)) + output = resize_nn(input_tensor) + expected_output = Tensor( + np.array([[[[0.09997559], [0.39990234], [0.7001953]]]]).astype(np.float32)) + error = np.ones(shape=[3, 1]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # same h, larger w + resize_nn = NetResizeBilinear((3, 6)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.09997559, 0.14996338, 0.19995117, 0.25, 0.30004883, + 0.30004883], + [0.39990234, 0.44995117, 0.5, 0.5500488, 0.60009766, + 0.60009766], + [0.7001953, 0.75, 0.7998047, 0.8498535, 0.89990234, + 0.89990234]]]]).astype(np.float32)) + error = np.ones(shape=[3, 6]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # same w, same h (identity) + resize_nn = NetResizeBilinear((3, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array( + [[[[0.09997559, 0.19995117, 0.30004883], + [0.39990234, 0.5, 0.60009766], + [0.7001953, 0.7998047, 0.89990234]]]]).astype(np.float32)) + error = np.ones(shape=[3, 3]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + +def test_resize_nn_grayscale_integer_ratio_float(datatype=np.float32): + input_tensor = Tensor(np.array( + [[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]]).astype(datatype)) + + # larger h and w + resize_nn = NetResizeBilinear((9, 9)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.13333334, 0.16666667, 0.2, 0.23333335, 0.26666668, 0.3, 0.3, 0.3], + [0.20000002, 0.23333335, 0.26666668, 0.3, 0.33333337, 0.3666667, 0.40000004, + 0.40000004, 0.40000004], + [0.3, 0.33333337, 0.36666667, 0.40000004, 0.43333337, 0.4666667, 0.5, 0.5, + 0.5], + [0.4, 0.43333334, 0.46666667, 0.5, 0.53333336, 0.5666667, 0.6, 0.6, 0.6], + [0.5, 0.53333336, 0.56666666, 0.6, 0.6333333, 0.66666675, 0.70000005, + 0.70000005, 0.70000005], + [0.6, 0.6333334, 0.6666667, 0.70000005, 0.73333335, 0.7666667, 0.8, 0.8, 0.8], + [0.7, 0.73333335, 0.76666665, 0.8, 0.8333333, 0.8666667, 0.9, 0.9, 0.9], + [0.7, 0.73333335, 0.76666665, 0.8, 0.8333333, 0.8666667, 0.9, 0.9, 0.9], + [0.7, 0.73333335, 0.76666665, 0.8, 0.8333333, 0.8666667, 0.9, 0.9, + 0.9]]]]).astype(np.float32)) + error = np.ones(shape=[9, 9]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # smaller h and w + resize_nn = NetResizeBilinear((1, 1)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1]]]]).astype(np.float32)) + error = np.ones(shape=[1, 1]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # smaller h, larger w + resize_nn = NetResizeBilinear((1, 6)) + output = resize_nn(input_tensor) + expected_output = Tensor( + np.array([[[[0.1, 0.15, 0.2, 0.25, 0.3, 0.3]]]]).astype(np.float32)) + error = np.ones(shape=[1, 6]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # larger h, smaller w + resize_nn = NetResizeBilinear((6, 1)) + output = resize_nn(input_tensor) + expected_output = Tensor( + np.array([[[[0.1], [0.25], [0.4], [0.55], [0.7], [0.7]]]]).astype(np.float32)) + error = np.ones(shape=[6, 1]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # smaller h, same w + resize_nn = NetResizeBilinear((1, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor( + np.array([[[[0.1, 0.2, 0.3]]]]).astype(np.float32)) + error = np.ones(shape=[1, 3]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # larger h, same w + resize_nn = NetResizeBilinear((6, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3], + [0.25, 0.35000002, 0.45000002], + [0.4, 0.5, 0.6], + [0.55, 0.65, 0.75], + [0.7, 0.8, 0.9], + [0.7, 0.8, 0.9]]]]).astype(np.float32)) + error = np.ones(shape=[6, 3]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # same h, smaller w + resize_nn = NetResizeBilinear((3, 1)) + output = resize_nn(input_tensor) + expected_output = Tensor( + np.array([[[[0.1], [0.4], [0.7]]]]).astype(np.float32)) + error = np.ones(shape=[3, 1]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # same h, larger w + resize_nn = NetResizeBilinear((3, 6)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.15, 0.2, 0.25, 0.3, 0.3], + [0.4, 0.45, 0.5, 0.55, 0.6, 0.6], + [0.7, 0.75, 0.8, 0.85, 0.9, 0.9]]]]).astype(np.float32)) + error = np.ones(shape=[3, 6]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # same w, same h (identity) + resize_nn = NetResizeBilinear((3, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array( + [[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]]).astype(np.float32)) + error = np.ones(shape=[3, 3]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + +def test_resize_nn_grayscale_not_integer_ratio_half(datatype=np.float16): + input_tensor = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.9, 0.0, 0.1, 0.2]]]]).astype(datatype)) + + # larger h and w + resize_nn = NetResizeBilinear((7, 7)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.09997559, 0.15710449, 0.21425085, 0.2714495, 0.3285784, + 0.38563755, 0.39990234], + [0.27141464, 0.3285734, 0.3857422, 0.44294086, 0.5000399, + 0.55703926, 0.57128906], + [0.44285366, 0.5000423, 0.5572336, 0.6144322, 0.67150134, + 0.7284409, 0.7426758], + [0.6142578, 0.50819117, 0.44293588, 0.5001146, 0.5571937, + 0.6141731, 0.62841797], + [0.78564453, 0.4346799, 0.18574369, 0.2428925, 0.3000015, + 0.3570706, 0.3713379], + [0.89990234, 0.3856724, 0.01428223, 0.07141115, 0.12854005, + 0.18566895, 0.19995117], + [0.89990234, 0.3856724, 0.01428223, 0.07141115, 0.12854005, + 0.18566895, 0.19995117]]]]).astype(np.float32)) + error = np.ones(shape=[7, 7]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # smaller h and w + resize_nn = NetResizeBilinear((2, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor( + np.array([[[[0.09997559, 0.23331706, 0.36661786], + [0.6999512, 0.33339438, 0.46661377]]]]).astype(np.float32)) + error = np.ones(shape=[2, 3]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # smaller h, larger w + resize_nn = NetResizeBilinear((2, 7)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.09997559, 0.15710449, 0.21425085, 0.2714495, 0.3285784, + 0.38563755, 0.39990234], + [0.6999512, 0.47143552, 0.3143398, 0.37150356, 0.4285976, + 0.48562187, 0.49987793]]]]).astype(np.float32)) + error = np.ones(shape=[2, 7]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # larger h, smaller w + resize_nn = NetResizeBilinear((5, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.09997559, 0.23331706, 0.36661786], + [0.33999026, 0.47340494, 0.6066081], + [0.5799805, 0.51343584, 0.64660645], + [0.8199219, 0.15335283, 0.28662106], + [0.89990234, 0.0333252, 0.16662598]]]]).astype(np.float32)) + error = np.ones(shape=[5, 3]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # smaller h, same w + resize_nn = NetResizeBilinear((2, 4)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.09997559, 0.19995117, 0.30004883, 0.39990234], + [0.6999512, 0.30004883, 0.40008545, 0.49987793]]]]).astype(np.float32)) + error = np.ones(shape=[2, 4]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # larger h, same w + resize_nn = NetResizeBilinear((8, 4)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.09997559, 0.19995117, 0.30004883, 0.39990234], + [0.24998474, 0.3500061, 0.45010376, 0.5498657], + [0.3999939, 0.50006104, 0.6001587, 0.6998291], + [0.5499878, 0.52508545, 0.62516785, 0.724823], + [0.6999512, 0.30004883, 0.40008545, 0.49987793], + [0.84991455, 0.07501221, 0.17500305, 0.27493286], + [0.89990234, 0., 0.09997559, 0.19995117], + [0.89990234, 0., 0.09997559, 0.19995117]]]]).astype(np.float32)) + error = np.ones(shape=[8, 4]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # same h, smaller w + resize_nn = NetResizeBilinear((3, 2)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.09997559, 0.30004883], + [0.5, 0.7001953], + [0.89990234, 0.09997559]]]]).astype(np.float32)) + error = np.ones(shape=[3, 2]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # same h, larger w + resize_nn = NetResizeBilinear((3, 6)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.09997559, 0.16662598, 0.23331706, 0.30004883, 0.36661786, + 0.39990234], + [0.5, 0.56673175, 0.63346356, 0.7001953, 0.76660156, + 0.7998047], + [0.89990234, 0.2999674, 0.0333252, 0.09997559, 0.16662598, + 0.19995117]]]]).astype(np.float32)) + error = np.ones(shape=[3, 6]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # same w, same h (identity) + resize_nn = NetResizeBilinear((3, 4)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.09997559, 0.19995117, 0.30004883, 0.39990234], + [0.5, 0.60009766, 0.7001953, 0.7998047], + [0.89990234, 0., 0.09997559, 0.19995117]]]]).astype(np.float32)) + error = np.ones(shape=[3, 4]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + +def test_resize_nn_grayscale_not_integer_ratio_float(datatype=np.float32): + input_tensor = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.9, 0.0, 0.1, 0.2]]]]).astype(datatype)) + + # larger h and w + resize_nn = NetResizeBilinear((7, 7)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.15714286, 0.21428573, 0.27142859, 0.32857144, 0.3857143, 0.4], + [0.27142859, 0.32857144, 0.38571432, 0.44285715, 0.5, 0.55714285, 0.5714286], + [0.44285715, 0.5, 0.5571429, 0.6142857, 0.67142856, 0.7285714, 0.74285716], + [0.6142857, 0.5081633, 0.4428572, 0.5, 0.55714285, 0.6142857, 0.62857145], + [0.78571427, 0.43469384, 0.1857143, 0.24285716, 0.3, 0.35714287, 0.37142855], + [0.9, 0.38571423, 0.01428572, 0.07142859, 0.12857144, 0.1857143, 0.2], + [0.9, 0.38571423, 0.01428572, 0.07142859, 0.12857144, 0.1857143, + 0.2]]]]).astype(np.float32)) + error = np.ones(shape=[7, 7]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # smaller h and w + resize_nn = NetResizeBilinear((2, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor( + np.array([[[[0.1, 0.23333335, 0.36666667], + [0.7, 0.33333334, 0.46666667]]]]).astype(np.float32)) + error = np.ones(shape=[2, 3]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # smaller h, larger w + resize_nn = NetResizeBilinear((2, 7)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.15714286, 0.21428573, 0.27142859, 0.32857144, + 0.3857143, 0.4], + [0.7, 0.47142854, 0.31428576, 0.37142858, 0.42857143, + 0.4857143, 0.5]]]]).astype(np.float32)) + error = np.ones(shape=[2, 7]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # larger h, smaller w + resize_nn = NetResizeBilinear((5, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.23333335, 0.36666667], + [0.34, 0.47333336, 0.6066667], + [0.58000004, 0.5133333, 0.64666665], + [0.82000005, 0.1533333, 0.28666663], + [0.9, 0.03333334, 0.16666669]]]]).astype(np.float32)) + error = np.ones(shape=[5, 3]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # smaller h, same w + resize_nn = NetResizeBilinear((2, 4)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4], + [0.7, 0.3, 0.4, 0.5]]]]).astype(np.float32)) + error = np.ones(shape=[2, 4]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # larger h, same w + resize_nn = NetResizeBilinear((8, 4)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4], + [0.25, 0.35000002, 0.45, 0.55], + [0.4, 0.5, 0.6, 0.70000005], + [0.55, 0.52500004, 0.625, 0.725], + [0.7, 0.3, 0.4, 0.5], + [0.84999996, 0.07499999, + 0.17500001, 0.27499998], + [0.9, 0., 0.1, 0.2], + [0.9, 0., 0.1, 0.2]]]]).astype(np.float32)) + error = np.ones(shape=[8, 4]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # same h, smaller w + resize_nn = NetResizeBilinear((3, 2)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.3], + [0.5, 0.7], + [0.9, 0.1]]]]).astype(np.float32)) + error = np.ones(shape=[3, 2]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # same h, larger w + resize_nn = NetResizeBilinear((3, 6)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.16666667, 0.23333335, 0.3, 0.36666667, 0.4], + [0.5, 0.56666666, 0.6333333, 0.7, 0.76666665, 0.8], + [0.9, 0.29999995, 0.03333334, 0.1, 0.16666669, 0.2]]]]).astype(np.float32)) + error = np.ones(shape=[3, 6]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + # same w, same h (identity) + resize_nn = NetResizeBilinear((3, 4)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.9, 0., 0.1, 0.2]]]]).astype(np.float32)) + error = np.ones(shape=[3, 4]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + +def test_resize_nn_grayscale_multiple_images_half(datatype=np.float16): + input_tensor = Tensor(np.array([[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]], + [[[0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [0.1, 0.2, 0.3]]], + [[[0.7, 0.8, 0.9], [0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]]]).astype(datatype)) + + resize_nn = NetResizeBilinear((2, 6)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.09997559, 0.14996338, 0.19995117, 0.25, 0.30004883, 0.30004883], + [0.5500488, 0.5999756, 0.64990234, 0.6999512, 0.75, 0.75]]], + [[[0.39990234, 0.44995117, 0.5, 0.5500488, 0.60009766, 0.60009766], + [0.40008545, 0.4499817, 0.49987793, 0.54992676, 0.5999756, 0.5999756]]], + [[[0.7001953, 0.75, 0.7998047, 0.8498535, 0.89990234, 0.89990234], + [0.24993896, 0.29995728, 0.3499756, 0.4000244, 0.45007324, + 0.45007324]]]]).astype(np.float32)) + + error = np.ones(shape=[3, 3, 2, 6]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + +def test_resize_nn_grayscale_multiple_images_float(datatype=np.float32): + input_tensor = Tensor(np.array([[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]], + [[[0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [0.1, 0.2, 0.3]]], + [[[0.7, 0.8, 0.9], [0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]]]).astype(datatype)) + + resize_nn = NetResizeBilinear((2, 6)) + output = resize_nn(input_tensor) + + expected_output = Tensor(np.array([[[[0.1, 0.15, 0.2, 0.25, 0.3, 0.3], + [0.55, 0.6, 0.65, 0.70000005, 0.75, 0.75]]], + [[[0.4, 0.45, 0.5, 0.55, 0.6, 0.6], + [0.4, 0.45, 0.5, 0.55, 0.6, 0.6]]], + [[[0.7, 0.75, 0.8, 0.85, 0.9, 0.9], + [0.25, 0.3, 0.35000002, 0.4, 0.45000002, 0.45000002]]]]).astype(np.float32)) + + error = np.ones(shape=[3, 3, 2, 6]) * 1.0e-6 + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + + +def test_resize_nn_grayscale_align_corners_half(datatype=np.float16): + input_tensor = Tensor( + np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(datatype)) + + resize_nn_corners_aligned = NetResizeBilinear( + size=(3, 7), align_corner=True) + output_corners_aligned = resize_nn_corners_aligned(input_tensor) + + resize_nn = NetResizeBilinear((3, 7)) + output = resize_nn(input_tensor) + + expected_output_align = Tensor(np.array([[[[0.09997559, 0.14996338, 0.19995117, 0.25, 0.30004883, + 0.3499756, 0.39990234], + [0.2999878, 0.3500061, 0.4000244, 0.45007324, 0.5001221, + 0.5499878, 0.5998535], + [0.5, 0.5500488, 0.60009766, 0.6501465, 0.7001953, + 0.75, 0.7998047]]]]).astype(np.float32)) + expected_output = Tensor(np.array([[[[0.09997559, 0.15710449, 0.21425085, 0.2714495, 0.3285784, + 0.38563755, 0.39990234], + [0.36665854, 0.42383394, 0.4810152, 0.53821385, 0.59529626, + 0.6522624, 0.6665039], + [0.5, 0.55719864, 0.61439735, 0.671596, 0.72865516, + 0.7855748, 0.7998047]]]]).astype(np.float32)) + + error = np.ones(shape=[3, 7]) * 1.0e-6 + diff_align = output_corners_aligned.asnumpy() - expected_output_align.asnumpy() + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + assert np.all(abs(diff_align) < error) + + +def test_resize_nn_grayscale_align_corners_float(datatype=np.float32): + input_tensor = Tensor( + np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(datatype)) + + resize_nn_corners_aligned = NetResizeBilinear( + size=(3, 7), align_corner=True) + output_corners_aligned = resize_nn_corners_aligned(input_tensor) + + resize_nn = NetResizeBilinear((3, 7)) + output = resize_nn(input_tensor) + + expected_output_align = Tensor(np.array([[[[0.1, 0.15, 0.2, 0.25, 0.3, + 0.35000002, 0.4], + [0.3, 0.35000002, 0.40000004, 0.45, 0.5, + 0.55, 0.6], + [0.5, 0.55, 0.6, 0.65, 0.7, + 0.75, 0.8]]]]).astype(datatype)) + expected_output = Tensor(np.array([[[[0.1, 0.15714286, 0.21428573, 0.27142859, 0.32857144, + 0.3857143, 0.4], + [0.36666667, 0.42380953, 0.48095244, 0.53809524, 0.5952381, + 0.65238094, 0.6666667], + [0.5, 0.55714285, 0.61428577, 0.67142856, 0.7285714, + 0.78571427, 0.8]]]]).astype(datatype)) + + error = np.ones(shape=[3, 7]) * 1.0e-6 + diff_align = output_corners_aligned.asnumpy() - expected_output_align.asnumpy() + diff = output.asnumpy() - expected_output.asnumpy() + assert np.all(abs(diff) < error) + assert np.all(abs(diff_align) < error) diff --git a/tests/st/ops/cpu/test_resize_nearest_neighbor_grad_op.py b/tests/st/ops/cpu/test_resize_nearest_neighbor_grad_op.py new file mode 100644 index 0000000000..2e478defae --- /dev/null +++ b/tests/st/ops/cpu/test_resize_nearest_neighbor_grad_op.py @@ -0,0 +1,93 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class ResizeNearestNeighborGradAlignCornerT(nn.Cell): + def __init__(self, size=None): + super(ResizeNearestNeighborGradAlignCornerT, self).__init__() + self.ResizeNearestNeighborGradAlignCornerT = G.ResizeNearestNeighborGrad( + align_corners=True) + self.size = size + + def construct(self, dy): + return self.ResizeNearestNeighborGradAlignCornerT(dy, self.size) + + +class ResizeNearestNeighborGradAlignCornerF(nn.Cell): + def __init__(self, size=None): + super(ResizeNearestNeighborGradAlignCornerF, self).__init__() + self.ResizeNearestNeighborGradAlignCornerF = G.ResizeNearestNeighborGrad( + align_corners=False) + self.size = size + + def construct(self, dy): + return self.ResizeNearestNeighborGradAlignCornerF(dy, self.size) + + +def test_ResizeNearestNeighborGradAlignCornerT(): + dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float32) + size = (4, 4) + expect = np.array( + [[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.float32) + rnn = ResizeNearestNeighborGradAlignCornerT(size=size) + output = rnn(Tensor(dy)) + assert np.all(output.asnumpy() == expect) + dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float16) + size = (4, 4) + expect = np.array( + [[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.float16) + rnn = ResizeNearestNeighborGradAlignCornerT(size=size) + output = rnn(Tensor(dy)) + assert np.all(output.asnumpy() == expect) + dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.int32) + size = (4, 4) + expect = np.array( + [[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.int32) + rnn = ResizeNearestNeighborGradAlignCornerT(size=size) + output = rnn(Tensor(dy)) + assert np.all(output.asnumpy() == expect) + + +def test_ResizeNearestNeighborGradAlignCornerF(): + dy = np.array( + [[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) + size = (2, 2) + expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.float32) + rnn = ResizeNearestNeighborGradAlignCornerF(size=size) + output = rnn(Tensor(dy)) + assert np.all(output.asnumpy() == expect) + dy = np.array( + [[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) + size = (2, 2) + expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.float16) + rnn = ResizeNearestNeighborGradAlignCornerF(size=size) + output = rnn(Tensor(dy)) + assert np.all(output.asnumpy() == expect) + dy = np.array( + [[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.int32) + size = (2, 2) + expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.int32) + rnn = ResizeNearestNeighborGradAlignCornerF(size=size) + output = rnn(Tensor(dy)) + assert np.all(output.asnumpy() == expect) diff --git a/tests/st/ops/cpu/test_resize_nearest_neighbor_op.py b/tests/st/ops/cpu/test_resize_nearest_neighbor_op.py new file mode 100755 index 0000000000..0e3c45d31a --- /dev/null +++ b/tests/st/ops/cpu/test_resize_nearest_neighbor_op.py @@ -0,0 +1,589 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np + +from mindspore import context, Tensor +from mindspore.ops import operations as P +from mindspore import nn + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetResizeNearestNeighbor(nn.Cell): + def __init__(self, size=None, align_corners=False): + super(NetResizeNearestNeighbor, self).__init__() + self.op = P.ResizeNearestNeighbor(size=size, align_corners=align_corners) + + def construct(self, inputs): + return self.op(inputs) + + +def resize_nn_grayscale_integer_ratio(datatype): + input_tensor = Tensor(np.array( + [[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]]]).astype(datatype)) + + # larger h and w + resize_nn = NetResizeNearestNeighbor((9, 9)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3], + [0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3], + [0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.3, 0.3, 0.3], + [0.4, 0.4, 0.4, 0.5, 0.5, 0.5, 0.6, 0.6, 0.6], + [0.4, 0.4, 0.4, 0.5, 0.5, 0.5, 0.6, 0.6, 0.6], + [0.4, 0.4, 0.4, 0.5, 0.5, 0.5, 0.6, 0.6, 0.6], + [0.7, 0.7, 0.7, 0.8, 0.8, 0.8, 0.9, 0.9, 0.9], + [0.7, 0.7, 0.7, 0.8, 0.8, 0.8, 0.9, 0.9, 0.9], + [0.7, 0.7, 0.7, 0.8, 0.8, 0.8, 0.9, 0.9, 0.9]]]]).astype(datatype)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # smaller h and w + resize_nn = NetResizeNearestNeighbor((1, 1)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1]]]]).astype(datatype)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # smaller h, larger w + resize_nn = NetResizeNearestNeighbor((1, 6)) + output = resize_nn(input_tensor) + expected_output = Tensor( + np.array([[[[0.1, 0.1, 0.2, 0.2, 0.3, 0.3]]]]).astype(datatype)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # larger h, smaller w + resize_nn = NetResizeNearestNeighbor((6, 1)) + output = resize_nn(input_tensor) + expected_output = Tensor( + np.array([[[[0.1], [0.1], [0.4], [0.4], [0.7], [0.7]]]]).astype(datatype)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # smaller h, same w + resize_nn = NetResizeNearestNeighbor((1, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3]]]]).astype(datatype)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # larger h, same w + resize_nn = NetResizeNearestNeighbor((6, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3], + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + [0.4, 0.5, 0.6], + [0.7, 0.8, 0.9], + [0.7, 0.8, 0.9]]]]).astype(datatype)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # same h, smaller w + resize_nn = NetResizeNearestNeighbor((3, 1)) + output = resize_nn(input_tensor) + expected_output = Tensor( + np.array([[[[0.1], [0.4], [0.7]]]]).astype(datatype)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # same h, larger w + resize_nn = NetResizeNearestNeighbor((3, 6)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.1, 0.2, 0.2, 0.3, 0.3], + [0.4, 0.4, 0.5, 0.5, 0.6, 0.6], + [0.7, 0.7, 0.8, 0.8, 0.9, 0.9]]]]).astype(datatype)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # same w, same h (identity) + resize_nn = NetResizeNearestNeighbor((3, 3)) + output = resize_nn(input_tensor) + np.testing.assert_array_equal(output.asnumpy(), input_tensor.asnumpy()) + + +def resize_nn_grayscale_not_integer_ratio(datatype): + input_tensor = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.9, 0.0, 0.1, 0.2]]]]).astype(datatype)) + + # larger h and w + resize_nn = NetResizeNearestNeighbor((7, 7)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4], + [0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4], + [0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4], + [0.5, 0.5, 0.6, 0.6, 0.7, 0.7, 0.8], + [0.5, 0.5, 0.6, 0.6, 0.7, 0.7, 0.8], + [0.9, 0.9, 0.0, 0.0, 0.1, 0.1, 0.2], + [0.9, 0.9, 0.0, 0.0, 0.1, 0.1, 0.2]]]]).astype(datatype)) + + # smaller h and w + resize_nn = NetResizeNearestNeighbor((2, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor( + np.array([[[[0.1, 0.2, 0.3], [0.5, 0.6, 0.7]]]]).astype(datatype)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # smaller h, larger w + resize_nn = NetResizeNearestNeighbor((2, 7)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4], + [0.5, 0.5, 0.6, 0.6, 0.7, 0.7, 0.8]]]]).astype(datatype)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # larger h, smaller w + resize_nn = NetResizeNearestNeighbor((5, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3], + [0.1, 0.2, 0.3], + [0.5, 0.6, 0.7], + [0.5, 0.6, 0.7], + [0.9, 0.0, 0.1]]]]).astype(datatype)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # smaller h, same w + resize_nn = NetResizeNearestNeighbor((2, 4)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8]]]]).astype(datatype)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # larger h, same w + resize_nn = NetResizeNearestNeighbor((8, 4)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.1, 0.2, 0.3, 0.4], + [0.5, 0.6, 0.7, 0.8], + [0.5, 0.6, 0.7, 0.8], + [0.5, 0.6, 0.7, 0.8], + [0.9, 0.0, 0.1, 0.2], + [0.9, 0.0, 0.1, 0.2]]]]).astype(datatype)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # same h, smaller w + resize_nn = NetResizeNearestNeighbor((3, 2)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.3], + [0.5, 0.7], + [0.9, 0.1]]]]).astype(datatype)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # same h, larger w + resize_nn = NetResizeNearestNeighbor((3, 6)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[0.1, 0.1, 0.2, 0.3, 0.3, 0.4], + [0.5, 0.5, 0.6, 0.7, 0.7, 0.8], + [0.9, 0.9, 0.0, 0.1, 0.1, 0.2]]]]).astype(datatype)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # same w, same h (identity) + resize_nn = NetResizeNearestNeighbor((3, 4)) + output = resize_nn(input_tensor) + np.testing.assert_array_equal(output.asnumpy(), input_tensor.asnumpy()) + + +def test_resize_nn_rgb_integer_ratio(): + input_tensor = Tensor(np.array( + [[[[1, 2, 3], [4, 5, 6], [7, 8, 9]], + [[11, 12, 13], [14, 15, 16], [17, 18, 19]], + [[111, 112, 113], [114, 115, 116], [117, 118, 119]]]]).astype(np.int32)) + + # larger h and w + resize_nn = NetResizeNearestNeighbor((9, 9)) + output = resize_nn(input_tensor) + expected_output_array = np.array([[[[1, 1, 1, 2, 2, 2, 3, 3, 3], + [1, 1, 1, 2, 2, 2, 3, 3, 3], + [1, 1, 1, 2, 2, 2, 3, 3, 3], + [4, 4, 4, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 6, 6, 6], + [4, 4, 4, 5, 5, 5, 6, 6, 6], + [7, 7, 7, 8, 8, 8, 9, 9, 9], + [7, 7, 7, 8, 8, 8, 9, 9, 9], + [7, 7, 7, 8, 8, 8, 9, 9, 9]], + [[11, 11, 11, 12, 12, 12, 13, 13, 13], + [11, 11, 11, 12, 12, 12, 13, 13, 13], + [11, 11, 11, 12, 12, 12, 13, 13, 13], + [14, 14, 14, 15, 15, 15, 16, 16, 16], + [14, 14, 14, 15, 15, 15, 16, 16, 16], + [14, 14, 14, 15, 15, 15, 16, 16, 16], + [17, 17, 17, 18, 18, 18, 19, 19, 19], + [17, 17, 17, 18, 18, 18, 19, 19, 19], + [17, 17, 17, 18, 18, 18, 19, 19, 19]], + [[111, 111, 111, 112, 112, 112, 113, 113, 113], + [111, 111, 111, 112, 112, 112, 113, 113, 113], + [111, 111, 111, 112, 112, 112, 113, 113, 113], + [114, 114, 114, 115, 115, 115, 116, 116, 116], + [114, 114, 114, 115, 115, 115, 116, 116, 116], + [114, 114, 114, 115, 115, 115, 116, 116, 116], + [117, 117, 117, 118, 118, 118, 119, 119, 119], + [117, 117, 117, 118, 118, 118, 119, 119, 119], + [117, 117, 117, 118, 118, 118, 119, 119, 119]]]]) + expected_output = Tensor(np.array(expected_output_array).astype(np.int32)) + + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # smaller h and w + resize_nn = NetResizeNearestNeighbor((1, 1)) + output = resize_nn(input_tensor) + expected_output = Tensor( + np.array([[[[1]], [[11]], [[111]]]]).astype(np.int32)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # smaller h, larger w + resize_nn = NetResizeNearestNeighbor((1, 6)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[1, 1, 2, 2, 3, 3]], + [[11, 11, 12, 12, 13, 13]], + [[111, 111, 112, 112, 113, 113]]]]).astype(np.int32)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # larger h, smaller w + resize_nn = NetResizeNearestNeighbor((6, 1)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[1], [1], [4], [4], [7], [7]], + [[11], [11], [14], [14], [17], [17]], + [[111], [111], [114], [114], [117], [117]]]]).astype(np.int32)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # smaller h, same w + resize_nn = NetResizeNearestNeighbor((1, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[1, 2, 3]], + [[11, 12, 13]], + [[111, 112, 113]]]]).astype(np.int32)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # larger h, same w + resize_nn = NetResizeNearestNeighbor((6, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[1, 2, 3], + [1, 2, 3], + [4, 5, 6], + [4, 5, 6], + [7, 8, 9], + [7, 8, 9]], + [[11, 12, 13], + [11, 12, 13], + [14, 15, 16], + [14, 15, 16], + [17, 18, 19], + [17, 18, 19]], + [[111, 112, 113], + [111, 112, 113], + [114, 115, 116], + [114, 115, 116], + [117, 118, 119], + [117, 118, 119]]]]).astype(np.int32)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # same h, smaller w + resize_nn = NetResizeNearestNeighbor((3, 1)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[1], [4], [7]], + [[11], [14], [17]], + [[111], [114], [117]]]]).astype(np.int32)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # same h, larger w + resize_nn = NetResizeNearestNeighbor((3, 6)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[1, 1, 2, 2, 3, 3], + [4, 4, 5, 5, 6, 6], + [7, 7, 8, 8, 9, 9]], + [[11, 11, 12, 12, 13, 13], + [14, 14, 15, 15, 16, 16], + [17, 17, 18, 18, 19, 19]], + [[111, 111, 112, 112, 113, 113], + [114, 114, 115, 115, 116, 116], + [117, 117, 118, 118, 119, 119]]]]).astype(np.int32)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # same w, same h (identity) + resize_nn = NetResizeNearestNeighbor((3, 3)) + output = resize_nn(input_tensor) + np.testing.assert_array_equal(output.asnumpy(), input_tensor.asnumpy()) + + +def test_resize_nn_rgb_not_integer_ratio(): + input_tensor = Tensor(np.array([[[[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 0, 1, 2]], + [[11, 12, 13, 14], + [15, 16, 17, 18], + [19, 10, 11, 12]], + [[111, 112, 113, 114], + [115, 116, 117, 118], + [119, 110, 111, 112]]]]).astype(np.int32)) + + # larger h and w + resize_nn = NetResizeNearestNeighbor((7, 7)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[1, 1, 2, 2, 3, 3, 4], + [1, 1, 2, 2, 3, 3, 4], + [1, 1, 2, 2, 3, 3, 4], + [5, 5, 6, 6, 7, 7, 8], + [5, 5, 6, 6, 7, 7, 8], + [9, 9, 0, 0, 1, 1, 2], + [9, 9, 0, 0, 1, 1, 2]], + [[11, 11, 12, 12, 13, 13, 14], + [11, 11, 12, 12, 13, 13, 14], + [11, 11, 12, 12, 13, 13, 14], + [15, 15, 16, 16, 17, 17, 18], + [15, 15, 16, 16, 17, 17, 18], + [19, 19, 10, 10, 11, 11, 12], + [19, 19, 10, 10, 11, 11, 12]], + [[111, 111, 112, 112, 113, 113, 114], + [111, 111, 112, 112, 113, 113, 114], + [111, 111, 112, 112, 113, 113, 114], + [115, 115, 116, 116, 117, 117, 118], + [115, 115, 116, 116, 117, 117, 118], + [119, 119, 110, 110, 111, 111, 112], + [119, 119, 110, 110, 111, 111, 112]]]]).astype(np.int32)) + + # smaller h and w + resize_nn = NetResizeNearestNeighbor((2, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[1, 2, 3], [5, 6, 7]], + [[11, 12, 13], [15, 16, 17]], + [[111, 112, 113], [115, 116, 117]]]]).astype(np.int32)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # smaller h, larger w + resize_nn = NetResizeNearestNeighbor((2, 7)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[1, 1, 2, 2, 3, 3, 4], + [5, 5, 6, 6, 7, 7, 8]], + [[11, 11, 12, 12, 13, 13, 14], + [15, 15, 16, 16, 17, 17, 18]], + [[111, 111, 112, 112, 113, 113, 114], + [115, 115, 116, 116, 117, 117, 118]]]]).astype(np.int32)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # larger h, smaller w + resize_nn = NetResizeNearestNeighbor((5, 3)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[1, 2, 3], + [1, 2, 3], + [5, 6, 7], + [5, 6, 7], + [9, 0, 1]], + [[11, 12, 13], + [11, 12, 13], + [15, 16, 17], + [15, 16, 17], + [19, 10, 11]], + [[111, 112, 113], + [111, 112, 113], + [115, 116, 117], + [115, 116, 117], + [119, 110, 111]]]]).astype(np.int32)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # smaller h, same w + resize_nn = NetResizeNearestNeighbor((2, 4)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[11, 12, 13, 14], + [15, 16, 17, 18]], + [[111, 112, 113, 114], + [115, 116, 117, 118]]]]).astype(np.int32)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # larger h, same w + resize_nn = NetResizeNearestNeighbor((8, 4)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4], + [5, 6, 7, 8], + [5, 6, 7, 8], + [5, 6, 7, 8], + [9, 0, 1, 2], + [9, 0, 1, 2]], + [[11, 12, 13, 14], + [11, 12, 13, 14], + [11, 12, 13, 14], + [15, 16, 17, 18], + [15, 16, 17, 18], + [15, 16, 17, 18], + [19, 10, 11, 12], + [19, 10, 11, 12]], + [[111, 112, 113, 114], + [111, 112, 113, 114], + [111, 112, 113, 114], + [115, 116, 117, 118], + [115, 116, 117, 118], + [115, 116, 117, 118], + [119, 110, 111, 112], + [119, 110, 111, 112]]]]).astype(np.int32)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # same h, smaller w + resize_nn = NetResizeNearestNeighbor((3, 2)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[1, 3], [5, 7], [9, 1]], + [[11, 13], [15, 17], [19, 11]], + [[111, 113], [115, 117], [119, 111]]]]).astype(np.int32)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # same h, larger w + resize_nn = NetResizeNearestNeighbor((3, 6)) + output = resize_nn(input_tensor) + expected_output = Tensor(np.array([[[[1, 1, 2, 3, 3, 4], + [5, 5, 6, 7, 7, 8], + [9, 9, 0, 1, 1, 2]], + [[11, 11, 12, 13, 13, 14], + [15, 15, 16, 17, 17, 18], + [19, 19, 10, 11, 11, 12]], + [[111, 111, 112, 113, 113, 114], + [115, 115, 116, 117, 117, 118], + [119, 119, 110, 111, 111, 112]]]]).astype(np.int32)) + np.testing.assert_array_equal(expected_output.asnumpy(), output.asnumpy()) + + # same w, same h (identity) + resize_nn = NetResizeNearestNeighbor((3, 4)) + output = resize_nn(input_tensor) + np.testing.assert_array_equal(output.asnumpy(), input_tensor.asnumpy()) + + +def resize_nn_grayscale_multiple_images(datatype): + input_tensor = Tensor(np.array([[[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]], + [[[0.4, 0.5, 0.6], [0.7, 0.8, 0.9], [0.1, 0.2, 0.3]]], + [[[0.7, 0.8, 0.9], [0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]]]).astype(datatype)) + + resize_nn = NetResizeNearestNeighbor((2, 6)) + output = resize_nn(input_tensor) + + expected_output = Tensor(np.array([[[[0.1, 0.1, 0.2, 0.2, 0.3, 0.3], + [0.4, 0.4, 0.5, 0.5, 0.6, 0.6]]], + [[[0.4, 0.4, 0.5, 0.5, 0.6, 0.6], + [0.7, 0.7, 0.8, 0.8, 0.9, 0.9]]], + [[[0.7, 0.7, 0.8, 0.8, 0.9, 0.9], + [0.1, 0.1, 0.2, 0.2, 0.3, 0.3]]]]).astype(datatype)) + + np.testing.assert_array_equal(output.asnumpy(), expected_output.asnumpy()) + + +def resize_nn_grayscale_align_corners(datatype): + input_tensor = Tensor( + np.array([[[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]]]).astype(datatype)) + + resize_nn_corners_aligned = NetResizeNearestNeighbor( + (3, 7), align_corners=True) + output_corners_aligned = resize_nn_corners_aligned(input_tensor) + + resize_nn = NetResizeNearestNeighbor((3, 7)) + output = resize_nn(input_tensor) + + expected_output = Tensor(np.array([[[[0.1, 0.2, 0.2, 0.3, 0.3, 0.4, 0.4], + [0.5, 0.6, 0.6, 0.7, 0.7, 0.8, 0.8], + [0.5, 0.6, 0.6, 0.7, 0.7, 0.8, 0.8]]]]).astype(datatype)) + + np.testing.assert_array_equal( + output_corners_aligned.asnumpy(), expected_output.asnumpy()) + np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, + output.asnumpy(), expected_output.asnumpy()) + + +def test_resize_nn_rgb_multiple(): + input_tensor = Tensor(np.array([[[[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], + [[11, 12, 13, 14, 15], [16, 17, 18, 19, 20]], + [[111, 112, 113, 114, 115], [116, 117, 118, 119, 120]]], + [[[11, 12, 13, 14, 15], [16, 17, 18, 19, 20]], + [[111, 112, 113, 114, 115], [116, 117, 118, 119, 120]], + [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]], + [[[111, 112, 113, 114, 115], [116, 117, 118, 119, 120]], + [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], + [[11, 12, 13, 14, 15], [16, 17, 18, 19, 20]]]]).astype(np.int32)) + + resize_nn = NetResizeNearestNeighbor((5, 2)) + output = resize_nn(input_tensor) + + expected_output = Tensor(np.array([[[[1, 3], [1, 3], [1, 3], [6, 8], [6, 8]], + [[11, 13], [11, 13], [11, 13], [16, 18], [16, 18]], + [[111, 113], [111, 113], [111, 113], [116, 118], [116, 118]]], + [[[11, 13], [11, 13], [11, 13], [16, 18], [16, 18]], + [[111, 113], [111, 113], [111, 113], [116, 118], [116, 118]], + [[1, 3], [1, 3], [1, 3], [6, 8], [6, 8]]], + [[[111, 113], [111, 113], [111, 113], [116, 118], [116, 118]], + [[1, 3], [1, 3], [1, 3], [6, 8], [6, 8]], + [[11, 13], [11, 13], [11, 13], [16, 18], [16, 18]]]]).astype(np.int32)) + + np.testing.assert_array_equal(output.asnumpy(), expected_output.asnumpy()) + + +def test_resize_nn_rgb_align_corners(): + input_tensor = Tensor(np.array([[[[1, 2, 3, 4], [5, 6, 7, 8]], + [[11, 12, 13, 14], [15, 16, 17, 18]], + [[21, 22, 23, 24], [25, 26, 27, 28]]]]).astype(np.int32)) + + resize_nn_corners_aligned = NetResizeNearestNeighbor( + (5, 2), align_corners=True) + output_corners_aligned = resize_nn_corners_aligned(input_tensor) + + resize_nn = NetResizeNearestNeighbor((5, 2)) + output = resize_nn(input_tensor) + + expected_output = Tensor(np.array([[[[1, 4], [1, 4], [5, 8], [5, 8], [5, 8]], + [[11, 14], [11, 14], [15, 18], + [15, 18], [15, 18]], + [[21, 24], [21, 24], [25, 28], [25, 28], [25, 28]]]]).astype(np.int32)) + + np.testing.assert_array_equal( + output_corners_aligned.asnumpy(), expected_output.asnumpy()) + np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, + output.asnumpy(), expected_output.asnumpy()) + + +def test_resize_nn_grayscale_integer_ratio_half(): + resize_nn_grayscale_integer_ratio(np.float16) + + +def test_resize_nn_grayscale_integer_ratio_float(): + resize_nn_grayscale_integer_ratio(np.float32) + + +def test_resize_nn_grayscale_not_integer_ratio_half(): + resize_nn_grayscale_not_integer_ratio(np.float16) + + +def test_resize_nn_grayscale_not_integer_ratio_float(): + resize_nn_grayscale_not_integer_ratio(np.float32) + + +def test_resize_nn_grayscale_multiple_half(): + resize_nn_grayscale_multiple_images(np.float16) + + +def test_resize_nn_grayscale_multiple_float(): + resize_nn_grayscale_multiple_images(np.float32) + + +def test_resize_nn_grayscale_align_corners_half(): + resize_nn_grayscale_align_corners(np.float16) + + +def test_resize_nn_grayscale_align_corners_float(): + resize_nn_grayscale_align_corners(np.float32) + + +if __name__ == "__main__": + test_resize_nn_grayscale_integer_ratio_half() + test_resize_nn_grayscale_integer_ratio_float() + test_resize_nn_grayscale_not_integer_ratio_half() + test_resize_nn_grayscale_not_integer_ratio_float() + test_resize_nn_grayscale_multiple_half() + test_resize_nn_grayscale_multiple_float() + test_resize_nn_grayscale_align_corners_half() + test_resize_nn_grayscale_align_corners_float() + test_resize_nn_rgb_integer_ratio() + test_resize_nn_rgb_not_integer_ratio() + test_resize_nn_rgb_multiple() + test_resize_nn_rgb_align_corners()