Merge pull request !3760 from 34bunny/GPU-ResizeNearestNeighbor-fixtags/v0.7.0-beta
| @@ -55,15 +55,15 @@ class ResizeNearestNeighborGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | ||||
| if (output_num != 1) { | if (output_num != 1) { | ||||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor needs 1 output."; | |||||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor has 1 output."; | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| shape_size_ = input_shape.size(); | shape_size_ = input_shape.size(); | ||||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | ||||
| if (shape_size_ != RESIZENEARESTNEIGHBOR_DIMENSION) { | if (shape_size_ != RESIZENEARESTNEIGHBOR_DIMENSION) { | ||||
| MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only " | |||||
| << RESIZENEARESTNEIGHBOR_DIMENSION << "-D inputs."; | |||||
| MS_LOG(ERROR) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only " | |||||
| << RESIZENEARESTNEIGHBOR_DIMENSION << "-D inputs."; | |||||
| return false; | return false; | ||||
| } | } | ||||
| input_size_ = 1; | input_size_ = 1; | ||||
| @@ -38,10 +38,10 @@ class ResizeNearestNeighborGradGpuKernel : public GpuKernel { | |||||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | ||||
| T *input = GetDeviceAddress<T>(inputs, 0); | T *input = GetDeviceAddress<T>(inputs, 0); | ||||
| T *output = GetDeviceAddress<T>(outputs, 0); | T *output = GetDeviceAddress<T>(outputs, 0); | ||||
| int size = SizeToInt(output_size_ / sizeof(T)); | |||||
| float h_scale = Scaling(input_shape_[2], output_shape_[2], align_corners_); | |||||
| float w_scale = Scaling(input_shape_[3], output_shape_[3], align_corners_); | |||||
| CalResizeNearestNeighborGrad(size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], | |||||
| int input_size = SizeToInt(input_size_ / sizeof(T)); | |||||
| float h_scale = Scaling(output_shape_[2], input_shape_[2], align_corners_); | |||||
| float w_scale = Scaling(output_shape_[3], input_shape_[3], align_corners_); | |||||
| CalResizeNearestNeighborGrad(input_size, input, input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], | |||||
| output, output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], | output, output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], | ||||
| align_corners_, h_scale, w_scale, reinterpret_cast<cudaStream_t>(stream_ptr)); | align_corners_, h_scale, w_scale, reinterpret_cast<cudaStream_t>(stream_ptr)); | ||||
| return true; | return true; | ||||
| @@ -55,15 +55,15 @@ class ResizeNearestNeighborGradGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | ||||
| if (output_num != 1) { | if (output_num != 1) { | ||||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor needs 1 output."; | |||||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but ResizeNearestNeighbor has 1 output."; | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| shape_size_ = input_shape.size(); | shape_size_ = input_shape.size(); | ||||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | ||||
| if (shape_size_ != RESIZENEARESTNEIGHBORGRAD_DIMENSION) { | if (shape_size_ != RESIZENEARESTNEIGHBORGRAD_DIMENSION) { | ||||
| MS_LOG(EXCEPTION) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only " | |||||
| << RESIZENEARESTNEIGHBORGRAD_DIMENSION << "-D inputs."; | |||||
| MS_LOG(ERROR) << "Input is " << shape_size_ << "-D, but ResizeNearestNeighbor supports only " | |||||
| << RESIZENEARESTNEIGHBORGRAD_DIMENSION << "-D inputs."; | |||||
| return false; | return false; | ||||
| } | } | ||||
| input_size_ = 1; | input_size_ = 1; | ||||
| @@ -18,64 +18,73 @@ | |||||
| #include <stdint.h> | #include <stdint.h> | ||||
| #include <math.h> | #include <math.h> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "backend/kernel_compiler/gpu/cuda_impl/util.cuh" | |||||
| #include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh" | #include "backend/kernel_compiler/gpu/cuda_impl/resize_nearest_neighbor_grad_impl.cuh" | ||||
| template <typename T> | template <typename T> | ||||
| __global__ void ResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3, | |||||
| const int s4, T *output, const int d1, const int d2, const int d3, | |||||
| const int d4, bool align_corners, float h_scale, float w_scale) { | |||||
| __global__ void InitZero(T *output, const int output_size) { | |||||
| for (size_t pos = threadIdx.x + blockIdx.x * blockDim.x; pos < (output_size); pos += gridDim.x * blockDim.x) { | |||||
| output[pos] = static_cast<T>(0); | |||||
| } | |||||
| } | |||||
| template <typename T> | |||||
| __global__ void ResizeNearestNeighborGrad(const int input_size, const T *input, const int s1, const int s2, | |||||
| const int s3, const int s4, T *output, const int d1, const int d2, | |||||
| const int d3, const int d4, bool align_corners, float h_scale, | |||||
| float w_scale) { | |||||
| // initialization | // initialization | ||||
| // HalfPixelCenters false | // HalfPixelCenters false | ||||
| int input_pos; | |||||
| int output_pos; | |||||
| int pos_array[RESIZENEARESTNEIGHBORGRAD_DIMENSION]; | int pos_array[RESIZENEARESTNEIGHBORGRAD_DIMENSION]; | ||||
| int in_height = s3; | |||||
| int in_width = s4; | |||||
| int out_height = d3; | |||||
| int out_width = d4; | |||||
| // for example 4-D: pos = pos_array[0] * output_shape[1] * output_shape[2] * output_shape[3] + | // for example 4-D: pos = pos_array[0] * output_shape[1] * output_shape[2] * output_shape[3] + | ||||
| // pos_array[1] * output_shape[2] * output_shape[3] + | // pos_array[1] * output_shape[2] * output_shape[3] + | ||||
| // pos_array[2] * output_shape[3] + | // pos_array[2] * output_shape[3] + | ||||
| // pos_array[3] | // pos_array[3] | ||||
| T h_scale_ = static_cast<T>(h_scale); | |||||
| T w_scale_ = static_cast<T>(w_scale); | |||||
| T out_h_; | |||||
| T out_w_; | |||||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | |||||
| pos_array[0] = pos / (d2 * d3 * d4) % d1; | |||||
| pos_array[1] = pos / (d3 * d4) % d2; | |||||
| pos_array[2] = pos / (d4) % d3; | |||||
| pos_array[3] = pos % d4; | |||||
| out_h_ = static_cast<T>(pos_array[2]); | |||||
| out_w_ = static_cast<T>(pos_array[3]); | |||||
| const int in_y = | |||||
| min((align_corners) ? static_cast<int>(roundf(out_h_ * h_scale_)) : static_cast<int>(floorf(out_h_ * h_scale_)), | |||||
| in_height - 1); | |||||
| const int in_x = | |||||
| min((align_corners) ? static_cast<int>(roundf(out_w_ * w_scale_)) : static_cast<int>(floorf(out_w_ * w_scale_)), | |||||
| in_width - 1); | |||||
| // pos_array[0] N, pos_array[1] C, in_y H, in_x W | |||||
| input_pos = pos_array[0] * s2 * s3 * s4 + pos_array[1] * s3 * s4 + in_y * s4 + in_x; | |||||
| output[pos] = input[input_pos]; | |||||
| int in_h; | |||||
| int in_w; | |||||
| for (size_t pos = threadIdx.x + blockIdx.x * blockDim.x; pos < (input_size); pos += gridDim.x * blockDim.x) { | |||||
| pos_array[0] = pos / (s2 * s3 * s4) % s1; | |||||
| pos_array[1] = pos / (s3 * s4) % s2; | |||||
| pos_array[2] = pos / (s4) % s3; | |||||
| pos_array[3] = pos % s4; | |||||
| in_h = pos_array[2]; | |||||
| in_w = pos_array[3]; | |||||
| const int out_y = | |||||
| min((align_corners) ? static_cast<int>(roundf(in_h * h_scale)) : static_cast<int>(floorf(in_h * h_scale)), | |||||
| out_height - 1); | |||||
| const int out_x = | |||||
| min((align_corners) ? static_cast<int>(roundf(in_w * w_scale)) : static_cast<int>(floorf(in_w * w_scale)), | |||||
| out_width - 1); | |||||
| // pos_array[0] N, pos_array[1] C, out_y H, out_x W | |||||
| output_pos = pos_array[0] * d2 * d3 * d4 + pos_array[1] * d3 * d4 + out_y * d4 + out_x; | |||||
| ms_atomic_add(&output[output_pos], input[pos]); | |||||
| } | } | ||||
| return; | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| void CalResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3, | |||||
| void CalResizeNearestNeighborGrad(const int input_size, const T *input, const int s1, const int s2, const int s3, | |||||
| const int s4, T *output, const int d1, const int d2, const int d3, const int d4, | const int s4, T *output, const int d1, const int d2, const int d3, const int d4, | ||||
| bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream) { | bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream) { | ||||
| ResizeNearestNeighborGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>( | |||||
| size, input, s1, s2, s3, s4, output, d1, d2, d3, d4, align_corners, h_scale, w_scale); | |||||
| int output_size = d1 * d2 * d3 * d4; | |||||
| InitZero<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(output, output_size); | |||||
| ResizeNearestNeighborGrad<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>( | |||||
| input_size, input, s1, s2, s3, s4, output, d1, d2, d3, d4, align_corners, h_scale, w_scale); | |||||
| return; | return; | ||||
| } | } | ||||
| template void CalResizeNearestNeighborGrad<float>(const int size, const float *input, const int s1, const int s2, | |||||
| template void CalResizeNearestNeighborGrad<float>(const int input_size, const float *input, const int s1, const int s2, | |||||
| const int s3, const int s4, float *output, const int d1, const int d2, | const int s3, const int s4, float *output, const int d1, const int d2, | ||||
| const int d3, const int d4, bool align_corners, float h_scale, | const int d3, const int d4, bool align_corners, float h_scale, | ||||
| float w_scale, cudaStream_t cuda_stream); | float w_scale, cudaStream_t cuda_stream); | ||||
| template void CalResizeNearestNeighborGrad<half>(const int size, const half *input, const int s1, const int s2, | |||||
| template void CalResizeNearestNeighborGrad<half>(const int input_size, const half *input, const int s1, const int s2, | |||||
| const int s3, const int s4, half *output, const int d1, const int d2, | const int s3, const int s4, half *output, const int d1, const int d2, | ||||
| const int d3, const int d4, bool align_corners, float h_scale, | const int d3, const int d4, bool align_corners, float h_scale, | ||||
| float w_scale, cudaStream_t cuda_stream); | float w_scale, cudaStream_t cuda_stream); | ||||
| template void CalResizeNearestNeighborGrad<int>(const int size, const int *input, const int s1, const int s2, | |||||
| template void CalResizeNearestNeighborGrad<int>(const int input_size, const int *input, const int s1, const int s2, | |||||
| const int s3, const int s4, int *output, const int d1, const int d2, | const int s3, const int s4, int *output, const int d1, const int d2, | ||||
| const int d3, const int d4, bool align_corners, float h_scale, | const int d3, const int d4, bool align_corners, float h_scale, | ||||
| float w_scale, cudaStream_t cuda_stream); | float w_scale, cudaStream_t cuda_stream); | ||||
| @@ -21,7 +21,7 @@ | |||||
| #define RESIZENEARESTNEIGHBORGRAD_DIMENSION 4 | #define RESIZENEARESTNEIGHBORGRAD_DIMENSION 4 | ||||
| template <typename T> | template <typename T> | ||||
| void CalResizeNearestNeighborGrad(const int size, const T *input, const int s1, const int s2, const int s3, | |||||
| void CalResizeNearestNeighborGrad(const int input_size, const T *input, const int s1, const int s2, const int s3, | |||||
| const int s4, T *output, const int d1, const int d2, const int d3, const int d4, | const int s4, T *output, const int d1, const int d2, const int d3, const int d4, | ||||
| bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream); | bool align_corners, float h_scale, float w_scale, cudaStream_t cuda_stream); | ||||
| @@ -34,22 +34,20 @@ __global__ void ResizeNearestNeighbor(const int size, const T *input, const int | |||||
| // pos_array[1] * output_shape[2] * output_shape[3] + | // pos_array[1] * output_shape[2] * output_shape[3] + | ||||
| // pos_array[2] * output_shape[3] + | // pos_array[2] * output_shape[3] + | ||||
| // pos_array[3] | // pos_array[3] | ||||
| T h_scale_ = static_cast<T>(h_scale); | |||||
| T w_scale_ = static_cast<T>(w_scale); | |||||
| T out_h_; | |||||
| T out_w_; | |||||
| int out_h; | |||||
| int out_w; | |||||
| for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { | ||||
| pos_array[0] = pos / (d2 * d3 * d4) % d1; | pos_array[0] = pos / (d2 * d3 * d4) % d1; | ||||
| pos_array[1] = pos / (d3 * d4) % d2; | pos_array[1] = pos / (d3 * d4) % d2; | ||||
| pos_array[2] = pos / (d4) % d3; | pos_array[2] = pos / (d4) % d3; | ||||
| pos_array[3] = pos % d4; | pos_array[3] = pos % d4; | ||||
| out_h_ = static_cast<T>(pos_array[2]); | |||||
| out_w_ = static_cast<T>(pos_array[3]); | |||||
| out_h = pos_array[2]; | |||||
| out_w = pos_array[3]; | |||||
| const int in_y = | const int in_y = | ||||
| min((align_corners) ? static_cast<int>(roundf(out_h_ * h_scale_)) : static_cast<int>(floorf(out_h_ * h_scale_)), | |||||
| min((align_corners) ? static_cast<int>(roundf(out_h * h_scale)) : static_cast<int>(floorf(out_h * h_scale)), | |||||
| in_height - 1); | in_height - 1); | ||||
| const int in_x = | const int in_x = | ||||
| min((align_corners) ? static_cast<int>(roundf(out_w_ * w_scale_)) : static_cast<int>(floorf(out_w_ * w_scale_)), | |||||
| min((align_corners) ? static_cast<int>(roundf(out_w * w_scale)) : static_cast<int>(floorf(out_w * w_scale)), | |||||
| in_width - 1); | in_width - 1); | ||||
| // pos_array[0] N, pos_array[1] C, in_y H, in_x W | // pos_array[0] N, pos_array[1] C, in_y H, in_x W | ||||
| input_pos = pos_array[0] * s2 * s3 * s4 + pos_array[1] * s3 * s4 + in_y * s4 + in_x; | input_pos = pos_array[0] * s2 * s3 * s4 + pos_array[1] * s3 * s4 + in_y * s4 + in_x; | ||||
| @@ -0,0 +1,40 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <cuda_fp16.h> | |||||
| inline __device__ float ms_atomic_add(float *address, float val) { return atomicAdd(address, val); } | |||||
| inline __device__ int ms_atomic_add(int *address, int val) { return atomicAdd(address, val); } | |||||
| inline __device__ half ms_atomic_add(half *address, half val) { | |||||
| unsigned int *aligned = | |||||
| reinterpret_cast<unsigned int *>(reinterpret_cast<size_t>(address) - (reinterpret_cast<size_t>(address) & 2)); | |||||
| unsigned int old = *aligned; | |||||
| unsigned int assumed; | |||||
| unsigned short old_as_us; //NOLINT | |||||
| do { | |||||
| assumed = old; | |||||
| old_as_us = static_cast<unsigned short>(reinterpret_cast<size_t>(address) & 2 ? old >> 16 : old & 0xffff); //NOLINT | |||||
| half sum = __float2half_rn(__half2float(__ushort_as_half(old_as_us)) + static_cast<float>(val)); | |||||
| unsigned short sum_as_us = __half_as_ushort(sum); //NOLINT | |||||
| unsigned int sum_as_ui = | |||||
| reinterpret_cast<size_t>(address) & 2 ? (sum_as_us << 16) | (old & 0xffff) : (old & 0xffff0000) | sum_as_us; | |||||
| old = atomicCAS(aligned, assumed, sum_as_ui); | |||||
| } while (assumed != old); | |||||
| __half_raw raw = {old_as_us}; | |||||
| return half(raw); | |||||
| } | |||||
| @@ -43,15 +43,21 @@ class ResizeNearestNeighborGradAlignCornerF(nn.Cell): | |||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| def test_ResizeNearestNeighborGradAlignCornerT(): | def test_ResizeNearestNeighborGradAlignCornerT(): | ||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | ||||
| dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) | |||||
| size = (2, 2) | |||||
| expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float32) | |||||
| dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float32) | |||||
| size = (4, 4) | |||||
| expect = np.array([[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.float32) | |||||
| rnn = ResizeNearestNeighborGradAlignCornerT() | rnn = ResizeNearestNeighborGradAlignCornerT() | ||||
| output = rnn(Tensor(dy), size) | output = rnn(Tensor(dy), size) | ||||
| assert np.all(output.asnumpy() == expect) | assert np.all(output.asnumpy() == expect) | ||||
| dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) | |||||
| size = (2, 2) | |||||
| expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float16) | |||||
| dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.float16) | |||||
| size = (4, 4) | |||||
| expect = np.array([[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.float16) | |||||
| rnn = ResizeNearestNeighborGradAlignCornerT() | |||||
| output = rnn(Tensor(dy), size) | |||||
| assert np.all(output.asnumpy() == expect) | |||||
| dy = np.array([[[[1, 2], [3, 4]]]]).astype(np.int32) | |||||
| size = (4, 4) | |||||
| expect = np.array([[[[1, 0, 0, 2], [0, 0, 0, 0], [0, 0, 0, 0], [3, 0, 0, 4]]]]).astype(np.int32) | |||||
| rnn = ResizeNearestNeighborGradAlignCornerT() | rnn = ResizeNearestNeighborGradAlignCornerT() | ||||
| output = rnn(Tensor(dy), size) | output = rnn(Tensor(dy), size) | ||||
| assert np.all(output.asnumpy() == expect) | assert np.all(output.asnumpy() == expect) | ||||
| @@ -63,13 +69,19 @@ def test_ResizeNearestNeighborGradAlignCornerF(): | |||||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | ||||
| dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) | dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float32) | ||||
| size = (2, 2) | size = (2, 2) | ||||
| expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float32) | |||||
| expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.float32) | |||||
| rnn = ResizeNearestNeighborGradAlignCornerF() | rnn = ResizeNearestNeighborGradAlignCornerF() | ||||
| output = rnn(Tensor(dy), size) | output = rnn(Tensor(dy), size) | ||||
| assert np.all(output.asnumpy() == expect) | assert np.all(output.asnumpy() == expect) | ||||
| dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) | dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.float16) | ||||
| size = (2, 2) | size = (2, 2) | ||||
| expect = np.array([[[[1, 0], [0, 1]]]]).astype(np.float16) | |||||
| expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.float16) | |||||
| rnn = ResizeNearestNeighborGradAlignCornerF() | |||||
| output = rnn(Tensor(dy), size) | |||||
| assert np.all(output.asnumpy() == expect) | |||||
| dy = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.int32) | |||||
| size = (2, 2) | |||||
| expect = np.array([[[[4, 0], [0, 4]]]]).astype(np.int32) | |||||
| rnn = ResizeNearestNeighborGradAlignCornerF() | rnn = ResizeNearestNeighborGradAlignCornerF() | ||||
| output = rnn(Tensor(dy), size) | output = rnn(Tensor(dy), size) | ||||
| assert np.all(output.asnumpy() == expect) | assert np.all(output.asnumpy() == expect) | ||||
| @@ -53,6 +53,11 @@ def test_ResizeNearestNeighborAlignCornerT(): | |||||
| rnn = ResizeNearestNeighborAlignCornerT((4, 4)) | rnn = ResizeNearestNeighborAlignCornerT((4, 4)) | ||||
| output = rnn(input_tensor) | output = rnn(input_tensor) | ||||
| assert np.all(output.asnumpy() == expect) | assert np.all(output.asnumpy() == expect) | ||||
| input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.int32)) | |||||
| expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.int32) | |||||
| rnn = ResizeNearestNeighborAlignCornerT((4, 4)) | |||||
| output = rnn(input_tensor) | |||||
| assert np.all(output.asnumpy() == expect) | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_gpu_training | @pytest.mark.platform_x86_gpu_training | ||||
| @@ -69,3 +74,8 @@ def test_ResizeNearestNeighborAlignCornerF(): | |||||
| rnn = ResizeNearestNeighborAlignCornerF((4, 4)) | rnn = ResizeNearestNeighborAlignCornerF((4, 4)) | ||||
| output = rnn(input_tensor) | output = rnn(input_tensor) | ||||
| assert np.all(output.asnumpy() == expect) | assert np.all(output.asnumpy() == expect) | ||||
| input_tensor = Tensor(np.array([[[[1, 0], [0, 1]]]]).astype(np.int32)) | |||||
| expect = np.array([[[[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]]]).astype(np.int32) | |||||
| rnn = ResizeNearestNeighborAlignCornerF((4, 4)) | |||||
| output = rnn(input_tensor) | |||||
| assert np.all(output.asnumpy() == expect) | |||||