From e6a41b1938edb9adcf8e7682092f3320cc33e3ad Mon Sep 17 00:00:00 2001 From: ZhengQihao3f3f3f Date: Tue, 27 Apr 2021 16:46:04 +0800 Subject: [PATCH] Add ops test PR 1 --- .../gpu/cuda_impl/adaptive_avg_pool2d_impl.cu | 168 ++++++++++++++++++ .../cuda_impl/adaptive_avg_pool2d_impl.cuh | 25 +++ .../gpu/nn/adaptive_avg_pool2d_gpu_kernel.cc | 31 ++++ .../gpu/nn/adaptive_avg_pool2d_gpu_kernel.h | 120 +++++++++++++ mindspore/ops/operations/__init__.py | 7 +- mindspore/ops/operations/nn_ops.py | 80 ++++++++- .../st/ops/gpu/test_adaptive_avg_pool2d_op.py | 104 +++++++++++ 7 files changed, 526 insertions(+), 9 deletions(-) create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adaptive_avg_pool2d_impl.cu create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adaptive_avg_pool2d_impl.cuh create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adaptive_avg_pool2d_gpu_kernel.cc create mode 100644 mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adaptive_avg_pool2d_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_adaptive_avg_pool2d_op.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adaptive_avg_pool2d_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adaptive_avg_pool2d_impl.cu new file mode 100644 index 0000000000..92ac1de446 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adaptive_avg_pool2d_impl.cu @@ -0,0 +1,168 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/cuda_impl/adaptive_avg_pool2d_impl.cuh" + +__device__ inline uint start_index(uint a, uint b, uint c) { + return floorf(__uint2float_rn(a * c) / __uint2float_rn(b)); +} + +__device__ inline uint end_index(uint a, uint b, uint c) { + return ceilf(__uint2float_rn((a + 1) * c) / __uint2float_rn(b)); +} + +template +__global__ void AdaptiveAvgPool2DKernel(const uint size, const uint input_height, const uint input_width, + const uint output_height, const uint output_width, T *input_data, + T *output_data) { + for (uint c = blockIdx.x * blockDim.x + threadIdx.x; c < size; c += gridDim.x * blockDim.x) { + T *input_ptr = input_data + c * input_height * input_width; + T *output_ptr = output_data + c * output_height * output_width; + + for (uint oh = 0; oh < output_height; oh++) { + uint ih0 = start_index(oh, output_height, input_height); + uint ih1 = end_index(oh, output_height, input_height); + uint kh = ih1 - ih0; + + for (uint ow = 0; ow < output_width; ow++) { + uint iw0 = start_index(ow, output_width, input_width); + uint iw1 = end_index(ow, output_width, input_width); + uint kw = iw1 - iw0; + + // compute local average + T sum = 0; + for (uint ih = ih0; ih < ih1; ih++) { + for (uint iw = iw0; iw < iw1; iw++) { + sum += input_ptr[ih * input_width + iw]; + } + } + output_ptr[oh * output_width + ow] = sum / kh / kw; + } + } + } +} + +template <> +__global__ void AdaptiveAvgPool2DKernel(const uint size, const uint input_height, const uint input_width, + const uint output_height, const uint output_width, float *input_data, + float *output_data) { + for (uint c = blockIdx.x * blockDim.x + threadIdx.x; c < size; c += gridDim.x * blockDim.x) { + float *input_ptr = input_data + c * input_height * input_width; + float *output_ptr = output_data + c * output_height * output_width; + + for (uint oh = 0; oh < output_height; oh++) { 
+ uint ih0 = start_index(oh, output_height, input_height); + uint ih1 = end_index(oh, output_height, input_height); + uint kh = ih1 - ih0; + + for (uint ow = 0; ow < output_width; ow++) { + uint iw0 = start_index(ow, output_width, input_width); + uint iw1 = end_index(ow, output_width, input_width); + uint kw = iw1 - iw0; + + // compute local average + float sum = 0; + for (uint ih = ih0; ih < ih1; ih++) { + for (uint iw = iw0; iw < iw1; iw++) { + sum += input_ptr[ih * input_width + iw]; + } + } + output_ptr[oh * output_width + ow] = sum / __uint2float_rn(kh * kw); + } + } + } +} + +template <> +__global__ void AdaptiveAvgPool2DKernel(const uint size, const uint input_height, const uint input_width, + const uint output_height, const uint output_width, half *input_data, + half *output_data) { + for (uint c = blockIdx.x * blockDim.x + threadIdx.x; c < size; c += gridDim.x * blockDim.x) { + half *input_ptr = input_data + c * input_height * input_width; + half *output_ptr = output_data + c * output_height * output_width; + + for (uint oh = 0; oh < output_height; oh++) { + uint ih0 = start_index(oh, output_height, input_height); + uint ih1 = end_index(oh, output_height, input_height); + uint kh = ih1 - ih0; + + for (uint ow = 0; ow < output_width; ow++) { + uint iw0 = start_index(ow, output_width, input_width); + uint iw1 = end_index(ow, output_width, input_width); + uint kw = iw1 - iw0; + + // compute local average + half sum = 0; + for (uint ih = ih0; ih < ih1; ih++) { + for (uint iw = iw0; iw < iw1; iw++) { + sum += input_ptr[ih * input_width + iw]; + } + } + output_ptr[oh * output_width + ow] = sum / __uint2half_rn(kh * kw); + } + } + } +} + +template <> +__global__ void AdaptiveAvgPool2DKernel(const uint size, const uint input_height, const uint input_width, + const uint output_height, const uint output_width, double *input_data, + double *output_data) { + for (uint c = blockIdx.x * blockDim.x + threadIdx.x; c < size; c += gridDim.x * blockDim.x) { + double 
*input_ptr = input_data + c * input_height * input_width; + double *output_ptr = output_data + c * output_height * output_width; + + for (uint oh = 0; oh < output_height; oh++) { + uint ih0 = start_index(oh, output_height, input_height); + uint ih1 = end_index(oh, output_height, input_height); + uint kh = ih1 - ih0; + + for (uint ow = 0; ow < output_width; ow++) { + uint iw0 = start_index(ow, output_width, input_width); + uint iw1 = end_index(ow, output_width, input_width); + uint kw = iw1 - iw0; + + // compute local average + double sum = 0; + for (uint ih = ih0; ih < ih1; ih++) { + for (uint iw = iw0; iw < iw1; iw++) { + sum += input_ptr[ih * input_width + iw]; + } + } + output_ptr[oh * output_width + ow] = sum / __uint2double_rn(kh * kw); + } + } + } +} + +template +void ApplyAdaptiveAvgPool2D(const uint size, const uint input_height, const uint input_width, const uint output_height, + const uint output_width, T *input_data, T *output_data, cudaStream_t cuda_stream) { + AdaptiveAvgPool2DKernel<<>>( + size, input_height, input_width, output_height, output_width, input_data, output_data); +} + +template void ApplyAdaptiveAvgPool2D(const uint size, const uint input_height, const uint input_width, + const uint output_height, const uint output_width, float *input_data, + float *output_data, cudaStream_t cuda_stream); + +template void ApplyAdaptiveAvgPool2D(const uint size, const uint input_height, const uint input_width, + const uint output_height, const uint output_width, half *input_data, + half *output_data, cudaStream_t cuda_stream); + +template void ApplyAdaptiveAvgPool2D(const uint size, const uint input_height, const uint input_width, + const uint output_height, const uint output_width, double *input_data, + double *output_data, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adaptive_avg_pool2d_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adaptive_avg_pool2d_impl.cuh new file mode 100644 index 
0000000000..88dda5cf5e --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/adaptive_avg_pool2d_impl.cuh @@ -0,0 +1,25 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAPTIVEAVGPOOL2D_IMPL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAPTIVEAVGPOOL2D_IMPL_H_ + +#include "runtime/device/gpu/cuda_common.h" +template +void ApplyAdaptiveAvgPool2D(const uint size, const uint input_height, const uint input_width, const uint output_height, + const uint output_width, T *input_data, T *output_data, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAPTIVEAVGPOOL2D_IMPL_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adaptive_avg_pool2d_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adaptive_avg_pool2d_gpu_kernel.cc new file mode 100644 index 0000000000..81c4234319 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adaptive_avg_pool2d_gpu_kernel.cc @@ -0,0 +1,31 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/kernel_compiler/gpu/nn/adaptive_avg_pool2d_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(AdaptiveAvgPool2D, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + AdaptiveAvgPool2DKernel, half) +MS_REG_GPU_KERNEL_ONE(AdaptiveAvgPool2D, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + AdaptiveAvgPool2DKernel, float) +MS_REG_GPU_KERNEL_ONE(AdaptiveAvgPool2D, + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + AdaptiveAvgPool2DKernel, double) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adaptive_avg_pool2d_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adaptive_avg_pool2d_gpu_kernel.h new file mode 100644 index 0000000000..299f47efbd --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/adaptive_avg_pool2d_gpu_kernel.h @@ -0,0 +1,120 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAPTIVEAVGPOOL2D_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAPTIVEAVGPOOL2D_GPU_KERNEL_H_ + +#include +#include +#include +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/adaptive_avg_pool2d_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class AdaptiveAvgPool2DKernel : public GpuKernel { + public: + AdaptiveAvgPool2DKernel() + : input_size_(0), + output_size_(0), + len(0), + input_height(0), + input_width(0), + output_height(0), + output_width(0), + size(0) {} + ~AdaptiveAvgPool2DKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector & /*workspace*/, + const std::vector &outputs, void *stream_ptr) override { + T *input_addr = GetDeviceAddress(inputs, 0); + T *output_addr = GetDeviceAddress(outputs, 0); + + ApplyAdaptiveAvgPool2D(size, input_height, input_width, output_height, output_width, input_addr, output_addr, + reinterpret_cast(stream_ptr)); + + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + auto shape_addr = AnfAlgo::GetNodeAttr>(kernel_node, "output_size"); + if (shape_addr.size() == 1) { + output_height = shape_addr[0]; + output_width = shape_addr[0]; + } else if (shape_addr.size() == 2) { + output_height = static_cast(shape_addr[1]); + output_width = static_cast(shape_addr[0]); + } else { + MS_LOG(ERROR) << "Input Error."; + } + + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) 
{ + MS_LOG(ERROR) << "Input number is " << input_num << ", but adaptive_avg_pool2d needs 1 inputs."; + return false; + } + + input_size_ = sizeof(T); + output_size_ = sizeof(T); + + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + len = static_cast(input_shape.size()); + input_height = static_cast(input_shape[len - 1]); + input_width = static_cast(input_shape[len - 2]); + size = static_cast(len == 3 ? input_shape[0] : input_shape[0] * input_shape[1]); + for (uint i = 0; i < len; i++) { + input_size_ *= input_shape[i]; + } + + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (size_t i = 0; i < output_shape.size(); i++) { + output_size_ *= output_shape[i]; + } + + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + size_t input_size_; + size_t output_size_; + uint len; + uint input_height; + uint input_width; + uint output_height; + uint output_width; + uint size; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAPTIVEAVGPOOL2D_GPU_KERNEL_H_ diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index a877d23758..f4db6cd591 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -41,7 +41,8 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, TensorSummary, HistogramSummary, Print, Assert) from .control_ops import GeSwitch, Merge -from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey, +from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, 
LambApplyOptimizerAssign, LambApplyWeightAssign, + MakeRefKey, FusedWeightScaleApplyMomentum, AdamWeightDecay) from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul, @@ -83,7 +84,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam FusedSparseFtrl, FusedSparseProximalAdagrad, ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent, - ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK) + ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, AdaptiveAvgPool2D) from . import _quant_ops from ._quant_ops import * from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode, @@ -107,7 +108,6 @@ from .sponge_ops import (BondForce, BondEnergy, BondAtomEnergy, BondForceWithAto GetCenterOfGeometry, MDTemperature, NeighborListUpdate, MDIterationLeapFrogLiujian, CrdToUintCrd, MDIterationSetupRandState, TransferCrd) - __all__ = [ 'Unique', 'ReverseSequence', @@ -469,6 +469,7 @@ __all__ = [ "CrdToUintCrd", "MDIterationSetupRandState", "TransferCrd", + "AdaptiveAvgPool2D" ] diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index c34d5c4bb6..8f1e2cc3b8 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -129,6 +129,74 @@ class Flatten(PrimitiveWithInfer): return input_x +class AdaptiveAvgPool2D(PrimitiveWithInfer): + r""" + AdaptiveAvgPool2D operation. + + This operator applies a 2D adaptive average pooling to an input signal composed of multiple input planes. + That is, for any input size, the size of the specified output is H x W. + The number of output features is equal to the number of input planes. + + Args: + output_size (Union[int, tuple]): The target output size is H x W. 
+ output_size can be a tuple, or a single H for H x H, and H x W can be int or None + which means the output size is the same as the input. + + Inputs: + - **input_x** (Tensor) - The input of AdaptiveAvgPool2D, which is a 3D or 4D tensor, + with float16, float32, float64 data type. + + Outputs: + Tensor, with the same type and same dimensions as the input_x. + + Raises: + ValueError: If `output_size` is a tuple and its length is not 2. + TypeError: If `input_x` is not a tensor. + TypeError: If dtype of `input_x` is not float16, float32, float64. + ValueError: If `input_x` dimension is less than or more than output_size dimension. + + Supported Platforms: + ``GPU`` + + Examples: + >>> input_x = Tensor(np.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + >>> [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + >>> [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]), mindspore.float32) + >>> adaptive_avg_pool_2d = ops.AdaptiveAvgPool2D((2, 2)) + >>> output = adaptive_avg_pool_2d(input_x) + >>> print(output) + [[[3.0, 4.0], [6.0, 7.0]], + [[3.0, 4.0], [6.0, 7.0]], + [[3.0, 4.0], [6.0, 7.0]]] + """ + + @prim_attr_register + def __init__(self, output_size): + validator.check_value_type("output_size", output_size, [int, tuple], self.name) + if isinstance(output_size, tuple): + validator.check_int(len(output_size), 2, Rel.EQ, 'output_size', self.name) + self.output_size = (output_size, output_size) if isinstance(self.output_size, int) else output_size + + def infer_shape(self, x_shape): + if len(x_shape) <= len(self.output_size): + raise ValueError("{} dimension should be larger than {} dimension".format(x_shape, self.output_size)) + validator.check_int(len(x_shape), 5, Rel.LT, 'input_x_dimensions', self.name) + for input_x_dimension in x_shape: + validator.check_int(input_x_dimension, 0, Rel.GT, 'input_x dimension', self.name) + zipped = zip(self.output_size, x_shape[-len(self.output_size):]) + out_size = [i if i else j for i, j in zipped] 
+ for item in out_size: + validator.check_value_type("item of output_size", item, [int], self.name) + self.add_prim_attr('output_size', (out_size)) + output_shape = x_shape[:len(x_shape) - len(out_size)] + out_size + return output_shape + + def infer_dtype(self, x_dtype): + validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float16, mstype.float32, mstype.float64], + self.name) + return x_dtype + + class Softmax(Primitive): r""" Softmax operation. @@ -3298,6 +3366,7 @@ class Gelu(PrimitiveWithInfer): Same as operator GeLU. Gelu will be deprecated in the future. Please use GeLU instead. """ + @deprecated("1.1", "GeLU", True) @prim_attr_register def __init__(self): @@ -3354,13 +3423,12 @@ class GeLU(Primitive): self.init_prim_io_names(inputs=['x'], outputs=['output']) - - class FastGelu(PrimitiveWithInfer): """ Same as operator FastGeLU. FastGelu will be deprecated in the future. Please use FastGeLU instead. """ + @deprecated("1.1", "FastGeLU", True) @prim_attr_register def __init__(self): @@ -8111,11 +8179,11 @@ class Conv3DTranspose(PrimitiveWithInfer): output_padding = (self.output_padding[2], self.output_padding[3], self.output_padding[4]) if self.pad_mode != 'pad' and output_padding != (0, 0, 0): raise ValueError(f"For '{self.name}', when output_padding is not 0, pad_mode should be set as 'pad'.") - validator.check_int_range(self.kernel_size[0]*self.kernel_size[1]*self.kernel_size[2], 1, 343, Rel.INC_BOTH, + validator.check_int_range(self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2], 1, 343, Rel.INC_BOTH, 'The product of height, width and depth of kernel_size belonging [1, 343]', self.name) - validator.check_int_range(self.stride[0]*self.stride[1]*self.stride[2], 1, 343, Rel.INC_BOTH, + validator.check_int_range(self.stride[0] * self.stride[1] * self.stride[2], 1, 343, Rel.INC_BOTH, 'The product of height, width and depth of stride belonging [1, 343]', self.name) - validator.check_int_range(self.stride[1]*self.stride[2], 1, 256, 
Rel.INC_BOTH, + validator.check_int_range(self.stride[1] * self.stride[2], 1, 256, Rel.INC_BOTH, 'The product of height, width and depth of stride belonging [1, 256]', self.name) validator.check_int_range(self.output_padding[2], 0, max(self.dilation[2], self.stride[2]), Rel.INC_LEFT, 'output_padding_d belonging [0, max(stride_d, dilation_d))', self.name) @@ -8180,7 +8248,7 @@ class Conv3DTranspose(PrimitiveWithInfer): self.add_prim_attr('pad_list', self.pad_list) self.add_prim_attr('output_padding', self.output_padding) - output_shape = (x_shape[0], w_shape[1]*self.group, d_out, h_out, w_out) + output_shape = (x_shape[0], w_shape[1] * self.group, d_out, h_out, w_out) self.add_prim_attr('input_size', output_shape) out = { 'value': None, diff --git a/tests/st/ops/gpu/test_adaptive_avg_pool2d_op.py b/tests/st/ops/gpu/test_adaptive_avg_pool2d_op.py new file mode 100644 index 0000000000..df20f0e812 --- /dev/null +++ b/tests/st/ops/gpu/test_adaptive_avg_pool2d_op.py @@ -0,0 +1,104 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import numpy as np +import pytest + +import mindspore +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor, ops +from mindspore.ops import operations as P +from mindspore.common.api import ms_function + +context.set_context(device_target='GPU') + + +class Net(nn.Cell): + def __init__(self, output_size): + super(Net, self).__init__() + self.adaptive_avg_pool2d = P.AdaptiveAvgPool2D(output_size) + + @ms_function + def construct(self, x): + return self.adaptive_avg_pool2d(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_net_normal(): + x = np.random.randn(1, 32, 9, 9) + net = Net((3, 5)) + output = net(Tensor(x, mindspore.float32)) + expect_shape = (1, 32, 3, 5) + assert output.asnumpy().shape == expect_shape + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_net_single(): + x = np.random.randn(1, 32, 7, 9) + net = Net(5) + output = net(Tensor(x, mindspore.float32)) + expect_shape = (1, 32, 5, 5) + assert output.asnumpy().shape == expect_shape + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_net_none(): + x = np.random.randn(1, 32, 7, 9) + net = Net((None, 5)) + output = net(Tensor(x, mindspore.float32)) + expect_shape = (1, 32, 7, 5) + assert output.asnumpy().shape == expect_shape + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_net_value(): + x = np.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]) + net = Net((2, 2)) + output = net(Tensor(x)) + expect_shape = (3, 2, 2) + expect_output = np.array([[[3.0, 4.0], [6.0, 7.0]], + [[3.0, 4.0], [6.0, 7.0]], + [[3.0, 4.0], [6.0, 7.0]]]) + assert output.asnumpy().shape == 
expect_shape + assert (output.asnumpy() == expect_output).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_net_pynative(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + x = np.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]) + adaptive_avg_pool_2d = ops.AdaptiveAvgPool2D((2, 2)) + output = adaptive_avg_pool_2d(Tensor(x)) + expect_shape = (3, 2, 2) + expect_output = np.array([[[3.0, 4.0], [6.0, 7.0]], + [[3.0, 4.0], [6.0, 7.0]], + [[3.0, 4.0], [6.0, 7.0]]]) + assert output.asnumpy().shape == expect_shape + assert (output.asnumpy() == expect_output).all()