Browse Source

support MaxPool3D(CPU/GPU), AvgPool3D(CPU/GPU) and Conv3D(CPU)

tags/v1.3.0
zuochuanyong 5 years ago
parent
commit
7864914515
17 changed files with 1208 additions and 90 deletions
  1. +1
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
  2. +47
    -22
      mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv_cpu_kernel.cc
  3. +16
    -8
      mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv_cpu_kernel.h
  4. +9
    -7
      mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc
  5. +24
    -11
      mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.cc
  6. +4
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.h
  7. +28
    -22
      mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h
  8. +8
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.cc
  9. +75
    -17
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h
  10. +1
    -0
      mindspore/core/base/core_ops.h
  11. +2
    -1
      mindspore/ops/operations/__init__.py
  12. +98
    -0
      mindspore/ops/operations/nn_ops.py
  13. +179
    -0
      tests/st/ops/cpu/test_avgpool_op.py
  14. +85
    -2
      tests/st/ops/cpu/test_conv_op.py
  15. +179
    -0
      tests/st/ops/cpu/test_maxpool_op.py
  16. +273
    -0
      tests/st/ops/gpu/test_avgpool_gpu_op.py
  17. +179
    -0
      tests/st/ops/gpu/test_maxpool_gpu_op.py

+ 1
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h View File

@@ -34,6 +34,7 @@ const char KERNEL_SIZE[] = "kernel_size";
const char STRIDE[] = "stride";
const char STRIDES[] = "strides";
const char DILATION[] = "dilation";
const char DILATIONS[] = "dilations";
const char FORMAT[] = "format";
const char PAD[] = "pad";
const char PAD_LIST[] = "pad_list";


mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.cc → mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv_cpu_kernel.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.h"
#include "backend/kernel_compiler/cpu/mkldnn/conv_cpu_kernel.h"
#include <string>
#include <algorithm>
#include "utils/ms_utils.h"
@@ -22,19 +22,29 @@

namespace mindspore {
namespace kernel {
void Conv2dCPUKernel::InitKernel(const CNodePtr &kernel_node) {
constexpr size_t kConvInputTensorNum = 2;
constexpr size_t kShapeSize4D = 4;
constexpr size_t kShapeSize5D = 5;
constexpr size_t kKernelStartAxis = 2;

void ConvCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> src_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
std::vector<size_t> weight_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
if (src_shape.size() != 4 || weight_shape.size() != 4) {
MS_LOG(EXCEPTION) << "conv2d only support nchw input!";
size_t src_dim = src_shape.size();
size_t weight_dim = weight_shape.size();
if (src_dim < kShapeSize4D || src_dim > kShapeSize5D || src_dim != weight_dim) {
MS_LOG(EXCEPTION) << "conv only supports 4D/5D input!";
}
std::vector<size_t> kernel_size;
for (size_t i = kKernelStartAxis; i < src_dim; ++i) {
kernel_size.emplace_back(weight_shape[i]);
}
std::vector<size_t> kernel_size({weight_shape[2], weight_shape[3]});
size_t group = LongToSize(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, GROUP));
if (group != 1) {
if (src_shape[1] % group != 0) {
MS_LOG(EXCEPTION) << "conv2d channels should be divided by group!";
MS_LOG(EXCEPTION) << "conv channels should be divided by group!";
}
weight_shape.insert(weight_shape.begin(), group);
weight_shape[1] = weight_shape[1] / group;
@@ -44,35 +54,50 @@ void Conv2dCPUKernel::InitKernel(const CNodePtr &kernel_node) {
dnnl::memory::desc dst_desc = GetDefaultMemDesc(dst_shape);
std::vector<int> stride_ori;
std::vector<int> dilation_ori;
auto stride_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, STRIDE);
auto dilation_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, DILATION);
auto stride_attr = src_dim == kShapeSize4D ? STRIDE : STRIDES;
auto dilation_attr = src_dim == kShapeSize4D ? DILATION : DILATIONS;
auto stride_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, stride_attr);
auto dilation_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, dilation_attr);
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_ori),
[](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(dilation_me.begin(), dilation_me.end(), std::back_inserter(dilation_ori),
[](const int64_t &value) { return static_cast<int>(value); });
if (stride_ori.size() != src_dim) {
MS_LOG(EXCEPTION) << "conv stride size must be " << src_dim << "D!";
}
if (stride_ori[0] != 1 || stride_ori[1] != 1) {
MS_LOG(EXCEPTION) << "conv2d stride only support 1 in N axis and C axis!";
}
if (dilation_ori.size() != 4) {
MS_LOG(EXCEPTION) << "conv2d dilation must be 4d!";
if (dilation_ori.size() != src_dim) {
MS_LOG(EXCEPTION) << "conv dilation size must be " << src_dim << "D!";
}
if (dilation_ori[0] != 1 || dilation_ori[1] != 1) {
MS_LOG(EXCEPTION) << "conv2d dilation only support 1 in N axis and C axis!";
}

std::vector<int> stride{stride_ori[2], stride_ori[3]};
std::vector<int> dilation{dilation_ori[2], dilation_ori[3]};
dnnl::memory::dims strides{stride_ori[2], stride_ori[3]};
dnnl::memory::dims dilates{dilation_ori[2] - 1, dilation_ori[3] - 1};
std::vector<int> stride;
std::vector<int> dilation;
dnnl::memory::dims strides;
dnnl::memory::dims dilates;
for (size_t i = kKernelStartAxis; i < src_dim; ++i) {
stride.emplace_back(stride_ori[i]);
strides.emplace_back(stride_ori[i]);
dilation.emplace_back(dilation_ori[i]);
dilates.emplace_back(dilation_ori[i] - 1);
}
std::vector<int> int_padding_l;
std::vector<int> int_padding_r;
const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE);
GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r, dilation);
if (int_padding_l.size() != 2 || int_padding_r.size() != 2) {
if (int_padding_l.size() + kKernelStartAxis != src_dim || int_padding_r.size() + kKernelStartAxis != src_dim) {
MS_LOG(EXCEPTION) << "get padding failed";
}
dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]};
dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]};
dnnl::memory::dims padding_l;
dnnl::memory::dims padding_r;
for (size_t i = 0; i < int_padding_l.size(); ++i) {
padding_l.emplace_back(int_padding_l[i]);
padding_r.emplace_back(int_padding_r[i]);
}
dnnl::convolution_forward::desc desc =
dnnl::convolution_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::convolution_auto, src_desc,
weights_desc, dst_desc, strides, dilates, padding_l, padding_r);
@@ -84,10 +109,10 @@ void Conv2dCPUKernel::InitKernel(const CNodePtr &kernel_node) {
AddArgument(DNNL_ARG_DST, dst_desc);
}

bool Conv2dCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() < 2 || outputs.empty()) {
bool ConvCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
if (inputs.size() < kConvInputTensorNum || outputs.empty()) {
MS_LOG(EXCEPTION) << "error input output size!";
}
SetArgumentHandle(DNNL_ARG_SRC, inputs[0]->addr);

mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv2d_cpu_kernel.h → mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/conv_cpu_kernel.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONV2D_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONV2D_CPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONV_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONV_CPU_KERNEL_H_

#include <vector>
#include <memory>
@@ -22,10 +22,10 @@

namespace mindspore {
namespace kernel {
class Conv2dCPUKernel : public MKLCPUKernel {
class ConvCPUKernel : public MKLCPUKernel {
public:
Conv2dCPUKernel() = default;
~Conv2dCPUKernel() override = default;
ConvCPUKernel() = default;
~ConvCPUKernel() override = default;

void InitKernel(const CNodePtr &kernel_node) override;

@@ -36,8 +36,16 @@ class Conv2dCPUKernel : public MKLCPUKernel {
MS_REG_CPU_KERNEL(
Conv2D,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
Conv2dCPUKernel);
ConvCPUKernel);
MS_REG_CPU_KERNEL(
Conv3D,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ConvCPUKernel);
MS_REG_CPU_KERNEL(
Conv3D,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ConvCPUKernel);
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONV2D_CPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CONV_CPU_KERNEL_H_

+ 9
- 7
mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.cc View File

@@ -27,12 +27,14 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa
const std::vector<int> &stride, std::vector<int> *padding_l, std::vector<int> *padding_r,
const std::vector<int> &dilation) {
MS_EXCEPTION_IF_NULL(kernel_node);
if (src_shape.size() < 2) {
auto dim = src_shape.size();
if (dim < 2) {
MS_LOG(EXCEPTION) << "set pad only support src dim >= 2!";
}
std::vector<int> weight_height;
weight_height.emplace_back(src_shape[src_shape.size() - 2]);
weight_height.emplace_back(src_shape[src_shape.size() - 1]);
for (size_t i = 2; i < dim; ++i) {
weight_height.emplace_back(src_shape[i]);
}

MS_LOG(INFO) << "pad mode: " << pad_mode;
if (pad_mode == PAD_MODE_LOWER_SAME || pad_mode == PAD_MODE_UPPER_SAME) {
@@ -47,10 +49,10 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa
}
} else if (pad_mode == PAD_MODE_LOWER_VALID || pad_mode == PAD_MODE_UPPER_VALID) {
MS_LOG(INFO) << "pad valid";
padding_l->emplace_back(0);
padding_l->emplace_back(0);
padding_r->emplace_back(0);
padding_r->emplace_back(0);
for (size_t i = 0; i < dim - 2; ++i) {
padding_l->emplace_back(0);
padding_r->emplace_back(0);
}
} else {
std::vector<int> pad;
std::vector<int64_t> pad_me = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, PAD_LIST);


+ 24
- 11
mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.cc View File

@@ -41,28 +41,41 @@ void PoolingCPUKernel::InitKernel(const CNodePtr &kernel_node) {
[](const int64_t &value) { return static_cast<int>(value); });
(void)std::transform(strides_me.begin(), strides_me.end(), std::back_inserter(strides),
[](const int64_t &value) { return static_cast<int>(value); });
if (origin_kernel_sizes.size() != 4 || strides.size() != 4) {
auto dim = origin_kernel_sizes.size();
if (dim < 4 || dim > 5 || dim != strides.size()) {
MS_LOG(EXCEPTION) << "invalid kernel size " << origin_kernel_sizes.size() << " or stride size " << strides.size();
}
std::vector<int> stride{strides[2], strides[3]};
dnnl::memory::dims strides_dims{strides[2], strides[3]};
dnnl::memory::dims kernels_dims{origin_kernel_sizes[2], origin_kernel_sizes[3]};
const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE);
std::vector<int> stride;
dnnl::memory::dims kernels_dims;
dnnl::memory::dims strides_dims;
std::vector<size_t> kernel_size;
std::vector<int> dummy_dilation;
for (size_t i = 2; i < dim; ++i) {
stride.emplace_back(strides[i]);
kernels_dims.emplace_back(origin_kernel_sizes[i]);
strides_dims.emplace_back(strides[i]);
kernel_size.emplace_back(IntToSize(origin_kernel_sizes[i]));
dummy_dilation.emplace_back(1);
}

std::vector<int> int_padding_l;
std::vector<int> int_padding_r;
std::vector<size_t> kernel_size({IntToSize(origin_kernel_sizes[2]), IntToSize(origin_kernel_sizes[3])});
std::vector<int> dummy_dilation{1, 1};
const std::string pad_mode = AnfAlgo::GetNodeAttr<std::string>(kernel_node, PAD_MODE);
GetPadding(kernel_node, pad_mode, src_shape, kernel_size, stride, &int_padding_l, &int_padding_r, dummy_dilation);
if (int_padding_l.size() != 2 || int_padding_r.size() != 2) {
if (int_padding_l.size() != dim - 2 || int_padding_r.size() != dim - 2) {
MS_LOG(EXCEPTION) << "pooling get padding failed";
}
dnnl::memory::dims padding_l{int_padding_l[0], int_padding_l[1]};
dnnl::memory::dims padding_r{int_padding_r[0], int_padding_r[1]};
dnnl::memory::dims padding_l;
dnnl::memory::dims padding_r;
for (size_t i = 0; i < dim - 2; ++i) {
padding_l.emplace_back(int_padding_l[i]);
padding_r.emplace_back(int_padding_r[i]);
}
dnnl::pooling_forward::desc desc =
dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_max, src_desc, dst_desc,
strides_dims, kernels_dims, padding_l, padding_r);
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name == prim::kPrimAvgPool->name()) {
if (kernel_name == prim::kPrimAvgPool->name() || kernel_name == prim::kPrimAvgPool3D->name()) {
desc = dnnl::pooling_forward::desc(dnnl::prop_kind::forward_training, dnnl::algorithm::pooling_avg, src_desc,
dst_desc, strides_dims, kernels_dims, padding_l, padding_r);
}


+ 4
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/pooling_cpu_kernel.h View File

@@ -41,8 +41,12 @@ class PoolingCPUKernel : public MKLCPUKernel {

MS_REG_CPU_KERNEL(MaxPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
PoolingCPUKernel);
MS_REG_CPU_KERNEL(MaxPool3D, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
PoolingCPUKernel);
MS_REG_CPU_KERNEL(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
PoolingCPUKernel);
MS_REG_CPU_KERNEL(AvgPool3D, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
PoolingCPUKernel);
} // namespace kernel
} // namespace mindspore



+ 28
- 22
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel.h View File

@@ -24,6 +24,8 @@
#include <utility>
#include <map>
#include <memory>
#include <numeric>
#include <functional>
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/gpu/kernel_constants.h"
#include "runtime/device/gpu/gpu_device_manager.h"
@@ -140,22 +142,15 @@ class GpuKernel : public KernelMod {
if (shape.size() != len) {
MS_EXCEPTION(ValueError) << "Invalid size of input shape " << shape.size() << "-D with dimA " << len << "-D.";
}
if (format == "NCHW" || format == "DefaultFormat") {
dimA[0] = SizeToInt(shape[0]);
dimA[1] = SizeToInt(shape[1]);
dimA[2] = SizeToInt(shape[2]);
dimA[3] = SizeToInt(shape[3]);
if (format == "NCHW" || format == "DefaultFormat" || format == "NCDHW") {
for (size_t i = 0; i < len; ++i) {
dimA[i] = SizeToInt(shape[i]);
}
} else if (format == "NHWC") {
dimA[0] = SizeToInt(shape[0]);
dimA[1] = SizeToInt(shape[3]);
dimA[2] = SizeToInt(shape[1]);
dimA[3] = SizeToInt(shape[2]);
} else if (format == "NCDHW") {
dimA[0] = SizeToInt(shape[0]);
dimA[1] = SizeToInt(shape[1]);
dimA[2] = SizeToInt(shape[2]);
dimA[3] = SizeToInt(shape[3]);
dimA[4] = SizeToInt(shape[4]);
} else {
MS_LOG(ERROR) << "Unsupported data format " << format;
}
@@ -164,22 +159,15 @@ class GpuKernel : public KernelMod {
if (shape.size() != len) {
MS_EXCEPTION(ValueError) << "Invalid size of input shape " << shape.size() << "-D with strideA " << len << "-D.";
}
if (format == "NCHW" || format == "DefaultFormat") {
strideA[0] = SizeToInt(shape[1] * shape[2] * shape[3]);
strideA[1] = SizeToInt(shape[2] * shape[3]);
strideA[2] = SizeToInt(shape[3]);
strideA[3] = 1;
if (format == "NCHW" || format == "DefaultFormat" || format == "NCDHW") {
for (size_t i = 0; i < len; ++i) {
strideA[i] = SizeToInt(accumulate(shape.begin() + i + 1, shape.end(), 1, std::multiplies<size_t>()));
}
} else if (format == "NHWC") {
strideA[0] = SizeToInt(shape[1] * shape[2] * shape[3]);
strideA[1] = 1;
strideA[2] = SizeToInt(shape[2] * shape[3]);
strideA[3] = SizeToInt(shape[3]);
} else if (format == "NCDHW") {
strideA[0] = SizeToInt(shape[1] * shape[2] * shape[3] * shape[4]);
strideA[1] = SizeToInt(shape[2] * shape[3] * shape[4]);
strideA[2] = SizeToInt(shape[3] * shape[4]);
strideA[3] = SizeToInt(shape[4]);
strideA[4] = 1;
} else {
MS_LOG(ERROR) << "Unsupported data format " << format;
}
@@ -201,6 +189,24 @@ class GpuKernel : public KernelMod {
}
}
// Decode a 5-D shape into its N/C/D/H/W extents according to the data layout.
// An unsupported format only logs an error and leaves the outputs untouched.
void SetNCDHW(const std::vector<size_t> &shape, int *n, int *c, int *d, int *h, int *w, const std::string &format) {
  const bool channels_first = (format == "NCDHW" || format == "DefaultFormat");
  if (!channels_first && format != "NDHWC") {
    MS_LOG(ERROR) << "Unsupported data format " << format;
    return;
  }
  *n = SizeToInt(shape[0]);
  if (channels_first) {
    *c = SizeToInt(shape[1]);
    *d = SizeToInt(shape[2]);
    *h = SizeToInt(shape[3]);
    *w = SizeToInt(shape[4]);
  } else {  // NDHWC: channel axis is last.
    *d = SizeToInt(shape[1]);
    *h = SizeToInt(shape[2]);
    *w = SizeToInt(shape[3]);
    *c = SizeToInt(shape[4]);
  }
}
inline void CheckBroadcast4TensorOp(const std::vector<int> &A, const std::vector<int> &B,
const std::vector<int> &Out) {
if (A != Out && B != Out) {


+ 8
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.cc View File

@@ -26,5 +26,13 @@ MS_REG_GPU_KERNEL_ONE(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat32).Add
PoolingGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(AvgPool, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
PoolingGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(MaxPool3D, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
PoolingGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(MaxPool3D, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
PoolingGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(AvgPool3D, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
PoolingGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(AvgPool3D, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
PoolingGpuFwdKernel, half)
} // namespace kernel
} // namespace mindspore

+ 75
- 17
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/pooling_gpu_kernel.h View File

@@ -38,10 +38,13 @@ class PoolingGpuFwdKernel : public GpuKernel {
pooling_mode_(CUDNN_POOLING_MAX),
cudnn_data_type_(CUDNN_DATA_FLOAT),
compute_format_(CUDNN_TENSOR_NCHW),
old_depth_(0),
old_height_(0),
old_width_(0),
pad_depth_(0),
pad_height_(0),
pad_width_(0),
pad_front_(0),
pad_top_(0),
pad_left_(0),
n_(0),
@@ -83,6 +86,8 @@ class PoolingGpuFwdKernel : public GpuKernel {
auto format_attr = GetAttr<std::string>(kernel_node, "format");
if (format_attr == kOpFormat_NHWC) {
data_format_ = kOpFormat_NHWC;
} else if (format_attr == kOpFormat_NDHWC) {
data_format_ = kOpFormat_NDHWC;
}
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);

@@ -94,24 +99,33 @@ class PoolingGpuFwdKernel : public GpuKernel {
return true;
}
CHECK_TENSOR_SIZE(input_shape);
SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
const int nbDims = 4;
int dimA[4];
int strideAin[4];
int dimAout[4];
int strideAout[4];
SetDimA(input_shape, dimA, 4, data_format_);
SetStrideA(input_shape, strideAin, 4, data_format_);
SetDimA(output_shape, dimAout, 4, data_format_);
SetStrideA(output_shape, strideAout, 4, data_format_);
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensorNdDescriptor(input_descriptor_, cudnn_data_type_, nbDims, dimA, strideAin),
"cudnnSetTensor4dDescriptor failed");
auto dim = input_shape.size();
if (dim == 4) {
SetNCHW(input_shape, &n_, &c_, &old_height_, &old_width_, data_format_);
} else if (dim == 5) {
SetNCDHW(input_shape, &n_, &c_, &old_depth_, &old_height_, &old_width_, data_format_);
}
const int kMaxDims = 5;
int dimA[kMaxDims];
int strideAin[kMaxDims];
int dimAout[kMaxDims];
int strideAout[kMaxDims];
SetDimA(input_shape, dimA, dim, data_format_);
SetStrideA(input_shape, strideAin, dim, data_format_);
SetDimA(output_shape, dimAout, dim, data_format_);
SetStrideA(output_shape, strideAout, dim, data_format_);
CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
cudnnSetTensorNdDescriptor(input_descriptor_, cudnn_data_type_, dim, dimA, strideAin),
"cudnnSetTensorNdDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(
kernel_node_, cudnnSetTensorNdDescriptor(output_descriptor_, cudnn_data_type_, nbDims, dimAout, strideAout),
"cudnnSetTensor4dDescriptor failed");
kernel_node_, cudnnSetTensorNdDescriptor(output_descriptor_, cudnn_data_type_, dim, dimAout, strideAout),
"cudnnSetTensorNdDescriptor failed");
SetPoolingMode(kernel_node);
SetPad(kernel_node);
if (dim == 4) {
SetPad(kernel_node);
} else if (dim == 5) {
SetPad3D(kernel_node);
}
InitSizeLists();
return true;
}
@@ -161,7 +175,7 @@ class PoolingGpuFwdKernel : public GpuKernel {

void SetPoolingMode(const CNodePtr &kernel_node) {
mode_ = AnfAlgo::GetCNodeName(kernel_node);
if (mode_ == "AvgPool") {
if (mode_ == "AvgPool" || mode_ == "AvgPool3D") {
pooling_mode_ = CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
pad_value_ = 0.0;
} else {
@@ -212,6 +226,47 @@ class PoolingGpuFwdKernel : public GpuKernel {
"cudnnSetPoolingNdDescriptor failed");
}

// Configures the cuDNN pooling descriptor for the 5-D (NCDHW) pooling case.
// Reads the "pad_mode", "kernel_size" and "strides" attributes from the node;
// for SAME padding it computes per-dimension pad totals and hands the lower
// half to cuDNN. NOTE(review): cudnnSetPoolingNdDescriptor takes a single pad
// value per dimension, so an odd total pad drops its extra cell — confirm this
// matches the framework's SAME semantics.
void SetPad3D(const CNodePtr &kernel_node) {
  pad_mode_ = GetValue<std::string>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("pad_mode"));
  // "kernel_size" is a 5-element attribute; indices 2..4 are the D/H/W window extents.
  std::vector<int> window;
  std::vector<int64_t> window_me =
      GetValue<std::vector<int64_t>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("kernel_size"));
  (void)std::transform(window_me.begin(), window_me.end(), std::back_inserter(window),
                       [](const int64_t &value) { return static_cast<int>(value); });
  int window_depth = window[2];
  int window_height = window[3];
  int window_width = window[4];
  // "strides" is likewise 5-element; it is appended to the member stride_.
  std::vector<int64_t> stride_me =
      GetValue<std::vector<int64_t>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides"));
  (void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride_),
                       [](const int64_t &value) { return static_cast<int>(value); });
  int windowDimA[3] = {window_depth, window_height, window_width};
  int paddingA[3] = {0, 0, 0};  // stays all-zero unless pad_mode is SAME
  int strideA[3] = {stride_[2], stride_[3], stride_[4]};
  if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
    // Total SAME pad per dim: (ceil(in / stride) - 1) * stride + window - in, clamped at 0.
    pad_depth_ =
        std::max<int>(0, (((old_depth_ + stride_[2] - 1) / stride_[2]) - 1) * stride_[2] + window_depth - old_depth_);
    pad_height_ = std::max<int>(
        0, (((old_height_ + stride_[3] - 1) / stride_[3]) - 1) * stride_[3] + window_height - old_height_);
    pad_width_ =
        std::max<int>(0, (((old_width_ + stride_[4] - 1) / stride_[4]) - 1) * stride_[4] + window_width - old_width_);
    // cuDNN accepts one pad value per dimension; use the front/top/left half of the total.
    pad_front_ = pad_depth_ / 2;
    pad_top_ = pad_height_ / 2;
    pad_left_ = pad_width_ / 2;
    paddingA[0] = pad_front_;
    paddingA[1] = pad_top_;
    paddingA[2] = pad_left_;
  } else {
    pad_depth_ = 0;
    pad_height_ = 0;
    pad_width_ = 0;
  }
  CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_,
                              cudnnSetPoolingNdDescriptor(pooling_descriptor_, pooling_mode_, CUDNN_NOT_PROPAGATE_NAN,
                                                          3, windowDimA, paddingA, strideA),
                              "cudnnSetPoolingNdDescriptor failed");
}

cudnnHandle_t cudnn_handle_;
cudnnTensorDescriptor_t input_descriptor_;
cudnnTensorDescriptor_t output_descriptor_;
@@ -226,10 +281,13 @@ class PoolingGpuFwdKernel : public GpuKernel {
std::vector<size_t> workspace_size_list_;
cudnnDataType_t cudnn_data_type_;
cudnnTensorFormat_t compute_format_;
int old_depth_;
int old_height_;
int old_width_;
int pad_depth_;
int pad_height_;
int pad_width_;
int pad_front_;
int pad_top_;
int pad_left_;
int n_;


+ 1
- 0
mindspore/core/base/core_ops.h View File

@@ -226,6 +226,7 @@ inline const PrimitivePtr kPrimMaxPoolWithArgmax = std::make_shared<Primitive>("
inline const PrimitivePtr kPrimMaxPoolGradWithArgmax = std::make_shared<Primitive>("MaxPoolGradWithArgmax");
inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared<Primitive>("ApplyCenteredRMSProp");
inline const PrimitivePtr kPrimAvgPool = std::make_shared<Primitive>("AvgPool");
inline const PrimitivePtr kPrimAvgPool3D = std::make_shared<Primitive>("AvgPool3D");
inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared<Primitive>("AvgPoolGrad");
inline const PrimitivePtr kPrimAvgPoolGradVm = std::make_shared<Primitive>("AvgPoolGradVm");
inline const PrimitivePtr kPrimFusedSparseAdam = std::make_shared<Primitive>("FusedSparseAdam");


+ 2
- 1
mindspore/ops/operations/__init__.py View File

@@ -72,7 +72,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder,
LogSoftmax, MaxPool3D,
MaxPool, DataFormatDimMap,
AvgPool, Conv2DBackpropInput, ComputeAccidentalHits,
AvgPool, AvgPool3D, Conv2DBackpropInput, ComputeAccidentalHits,
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
ResizeBilinear, Sigmoid, SeLU,
SigmoidCrossEntropyWithLogits, NLLLoss, BCEWithLogitsLoss,
@@ -317,6 +317,7 @@ __all__ = [
'UpdateState',
'identity',
'AvgPool',
'AvgPool3D',
# Back Primitive
'Equal',
'EqualCount',


+ 98
- 0
mindspore/ops/operations/nn_ops.py View File

@@ -1935,6 +1935,104 @@ class AvgPool(_Pool):
def __init__(self, kernel_size=1, strides=1, pad_mode="valid", data_format="NCHW"):
super(AvgPool, self).__init__(kernel_size, strides, pad_mode, data_format)

class AvgPool3D(_Pool):
    r"""
    3D average pooling operation.

    Applies a 3D average pooling over an input Tensor which can be regarded as a composition of 3D input planes.
    Typically the input is of shape :math:`(N_{in}, C_{in}, D_{in}, H_{in}, W_{in})`, AvgPool3D outputs the
    regional average in the :math:`(D_{in}, H_{in}, W_{in})`-dimension. Given kernel size
    :math:`ks = (d_{ker}, h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1, s_2)`, the operation is as follows.

    .. math::
        \text{output}(N_i, C_j, d, h, w) = \frac{1}{d_{ker} * h_{ker} * w_{ker}} \sum_{l=0}^{d_{ker}-1}
        \sum_{m=0}^{h_{ker}-1} \sum_{n=0}^{w_{ker}-1} \text{input}(N_i, C_j, s_0 \times d + l,
        s_1 \times h + m, s_2 \times w + n)

    Args:
        kernel_size (Union[int, tuple[int]]): The size of kernel used to take the average value,
            is an int number that represents depth, height and width of the kernel, or a tuple
            of three int numbers that represent depth, height and width respectively. Default: 1.
        strides (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
            the depth, height and width of movement are all strides, or a tuple of three int numbers that
            represent depth, height and width of movement respectively. Default: 1.
        pad_mode (str): The optional value for pad mode, is "same" or "valid", not case sensitive.
            Default: "VALID".

            - same: Adopts the way of completion. The depth, height and width of the output will be the same as
              the input. The total number of padding will be calculated in depth, horizontal and vertical
              directions and evenly distributed to front and back, top and bottom, left and right if possible.
              Otherwise, the last extra padding will be done from the back, the bottom and the right side.

            - valid: Adopts the way of discarding. The possible largest depth, height and width of output
              will be returned without padding. Extra pixels will be discarded.
        data_format (str): The format of input and output data. Currently only 'NCDHW' is accepted.
            Default: 'NCDHW'.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.

    Outputs:
        Tensor, with shape :math:`(N, C_{in}, D_{out}, H_{out}, W_{out})`.

    Raises:
        TypeError: If `kernel_size` or `strides` is neither int nor tuple.
        ValueError: If `pad_mode` is neither 'valid' nor 'same' with not case sensitive.
        ValueError: If `data_format` is not 'NCDHW'.
        ValueError: If `kernel_size` or `strides` is less than 1.
        ValueError: If length of shape of `input` is not equal to 5.

    Supported Platforms:
        ``CPU``

    Examples:
        >>> input = Tensor(np.arange(1 * 2 * 2 * 2 * 3).reshape((1, 2, 2, 2, 3)), mindspore.float32)
        >>> avg_pool3d = P.AvgPool3D(kernel_size=2, strides=1, pad_mode="valid")
        >>> output = avg_pool3d(input)
        >>> print(output)
        [[[[[ 5. 6.]]]
        [[[17. 18.]]]]]
    """

    @prim_attr_register
    def __init__(self, kernel_size=1, strides=1, pad_mode="VALID", data_format="NCDHW"):
        """Initialize AvgPool3D: validate and normalize the pooling attributes."""
        self.init_prim_io_names(inputs=['x'], outputs=['output'])
        validator.check_value_type('kernel_size', kernel_size, [int, tuple], self.name)
        validator.check_value_type('strides', strides, [int, tuple], self.name)
        validator.check_value_type('pad_mode', pad_mode, [str], self.name)
        # pad_mode is upper-cased before checking, so "same"/"valid" in any case are accepted.
        self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
        self.add_prim_attr("pad_mode", self.pad_mode)
        # Only the channels-first 3D layout is supported; any other value raises ValueError.
        self.data_format = validator.check_string(data_format, ['NCDHW'], 'data_format', self.name)
        # ret_five=True yields a 5-element tuple; only the last three entries are the
        # spatial D/H/W extents (see the unpacking in infer_shape).
        self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name,
                                                  allow_five=False, ret_five=True)
        self.add_prim_attr("kernel_size", self.kernel_size)
        self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=False, ret_five=True)
        self.add_prim_attr("strides", self.strides)

    def infer_shape(self, x_shape):
        """Infer the NCDHW output shape from the 5-D input shape and the pooling attributes."""
        validator.check_equal_int(len(x_shape), 5, "x rank", self.name)
        batch, channel, input_d, input_h, input_w = x_shape
        self.add_prim_attr("x_shape", x_shape)
        _, _, kernel_d, kernel_h, kernel_w = self.kernel_size
        _, _, stride_d, stride_h, stride_w = self.strides

        # pad_mode was constrained to VALID/SAME in __init__, so exactly one branch runs.
        if self.pad_mode == "VALID":
            # VALID: windows must fit entirely inside the input; no padding.
            out_d = math.ceil((input_d - (kernel_d - 1)) / stride_d)
            out_h = math.ceil((input_h - (kernel_h - 1)) / stride_h)
            out_w = math.ceil((input_w - (kernel_w - 1)) / stride_w)
        elif self.pad_mode == "SAME":
            # SAME: output extent depends only on the input size and stride.
            out_d = math.ceil(input_d / stride_d)
            out_h = math.ceil(input_h / stride_h)
            out_w = math.ceil(input_w / stride_w)
        out_shape = [batch, channel, out_d, out_h, out_w]

        _check_shape('output', out_shape, self.name)
        return out_shape

    def infer_dtype(self, x_dtype):
        """Check that x is float16/float32 and propagate its dtype unchanged."""
        validator.check_tensor_dtype_valid("x", x_dtype, [mstype.float16, mstype.float32], self.name)
        return x_dtype


class Conv2DBackpropInput(PrimitiveWithInfer):
"""


+ 179
- 0
tests/st/ops/cpu/test_avgpool_op.py View File

@@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from functools import reduce
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@@ -87,6 +89,183 @@ def test_avgpool_k3s2ps():
assert np.allclose(out.asnumpy(), expect_result)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_avg_pool3d_1():
    """AvgPool3D forward on CPU: asymmetric kernel (2, 2, 3), stride 1, VALID padding."""
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    shape = (2, 3, 2, 3, 4)
    values = np.arange(np.prod(shape))
    x = Tensor(values).reshape(shape).astype(np.float32)
    pool = P.AvgPool3D(kernel_size=(2, 2, 3), strides=1, pad_mode='VALID')
    result = pool(x)
    expected = np.array([[[[[9, 10],
                            [13, 14]]],
                          [[[33, 34],
                            [37, 38]]],
                          [[[57, 58],
                            [61, 62]]]],
                         [[[[81, 82],
                            [85, 86]]],
                          [[[105, 106],
                            [109, 110]]],
                          [[[129, 130],
                            [133, 134]]]]])
    assert (result.asnumpy() == expected).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_avg_pool3d_2():
    """AvgPool3D forward on CPU: cubic kernel 2, stride 1, VALID padding."""
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    shape = (2, 3, 2, 3, 4)
    values = np.arange(np.prod(shape))
    x = Tensor(values).reshape(shape).astype(np.float32)
    pool = P.AvgPool3D(kernel_size=2, strides=1, pad_mode='VALID')
    result = pool(x)
    expected = np.array([[[[[8.5, 9.5, 10.5],
                            [12.5, 13.5, 14.5]]],
                          [[[32.5, 33.5, 34.5],
                            [36.5, 37.5, 38.5]]],
                          [[[56.5, 57.5, 58.5],
                            [60.5, 61.5, 62.5]]]],
                         [[[[80.5, 81.5, 82.5],
                            [84.5, 85.5, 86.5]]],
                          [[[104.5, 105.5, 106.5],
                            [108.5, 109.5, 110.5]]],
                          [[[128.5, 129.5, 130.5],
                            [132.5, 133.5, 134.5]]]]])
    assert (result.asnumpy() == expected).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_avg_pool3d_3():
    """AvgPool3D forward, kernel 2, stride 3 (non-overlapping windows), VALID padding."""
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    shape = (2, 3, 2, 3, 4)
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.AvgPool3D(kernel_size=2, strides=3, pad_mode='VALID')(x)
    # Stride 3 leaves a single window per channel.
    expected = np.array([[[[[8.5]]], [[[32.5]]], [[[56.5]]]],
                         [[[[80.5]]], [[[104.5]]], [[[128.5]]]]])
    assert (out.asnumpy() == expected).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_avg_pool3d_4():
    """AvgPool3D forward, kernel (2, 2, 3), stride 1, SAME padding, graph mode."""
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    shape = (2, 3, 2, 3, 4)
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.AvgPool3D(kernel_size=(2, 2, 3), strides=1, pad_mode='SAME')(x)
    # SAME padding keeps the spatial dims; border windows average fewer elements.
    expected = np.array(
        [[[[[8.5, 9, 10, 10.5], [12.5, 13, 14, 14.5], [14.5, 15, 16, 16.5]],
           [[14.5, 15, 16, 16.5], [18.5, 19, 20, 20.5], [20.5, 21, 22, 22.5]]],
          [[[32.5, 33, 34, 34.5], [36.5, 37, 38, 38.5], [38.5, 39, 40, 40.5]],
           [[38.5, 39, 40, 40.5], [42.5, 43, 44, 44.5], [44.5, 45, 46, 46.5]]],
          [[[56.5, 57, 58, 58.5], [60.5, 61, 62, 62.5], [62.5, 63, 64, 64.5]],
           [[62.5, 63, 64, 64.5], [66.5, 67, 68, 68.5], [68.5, 69, 70, 70.5]]]],
         [[[[80.5, 81, 82, 82.5], [84.5, 85, 86, 86.5], [86.5, 87, 88, 88.5]],
           [[86.5, 87, 88, 88.5], [90.5, 91, 92, 92.5], [92.5, 93, 94, 94.5]]],
          [[[104.5, 105, 106, 106.5], [108.5, 109, 110, 110.5], [110.5, 111, 112, 112.5]],
           [[110.5, 111, 112, 112.5], [114.5, 115, 116, 116.5], [116.5, 117, 118, 118.5]]],
          [[[128.5, 129, 130, 130.5], [132.5, 133, 134, 134.5], [134.5, 135, 136, 136.5]],
           [[134.5, 135, 136, 136.5], [138.5, 139, 140, 140.5], [140.5, 141, 142, 142.5]]]]])
    assert (out.asnumpy() == expected).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_avg_pool3d_5():
    """AvgPool3D forward, kernel (2, 2, 3), stride 1, SAME padding, PyNative mode."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
    shape = (2, 3, 2, 3, 4)
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.AvgPool3D(kernel_size=(2, 2, 3), strides=1, pad_mode='SAME')(x)
    # Same expectation as the graph-mode case; only the execution mode differs.
    expected = np.array(
        [[[[[8.5, 9, 10, 10.5], [12.5, 13, 14, 14.5], [14.5, 15, 16, 16.5]],
           [[14.5, 15, 16, 16.5], [18.5, 19, 20, 20.5], [20.5, 21, 22, 22.5]]],
          [[[32.5, 33, 34, 34.5], [36.5, 37, 38, 38.5], [38.5, 39, 40, 40.5]],
           [[38.5, 39, 40, 40.5], [42.5, 43, 44, 44.5], [44.5, 45, 46, 46.5]]],
          [[[56.5, 57, 58, 58.5], [60.5, 61, 62, 62.5], [62.5, 63, 64, 64.5]],
           [[62.5, 63, 64, 64.5], [66.5, 67, 68, 68.5], [68.5, 69, 70, 70.5]]]],
         [[[[80.5, 81, 82, 82.5], [84.5, 85, 86, 86.5], [86.5, 87, 88, 88.5]],
           [[86.5, 87, 88, 88.5], [90.5, 91, 92, 92.5], [92.5, 93, 94, 94.5]]],
          [[[104.5, 105, 106, 106.5], [108.5, 109, 110, 110.5], [110.5, 111, 112, 112.5]],
           [[110.5, 111, 112, 112.5], [114.5, 115, 116, 116.5], [116.5, 117, 118, 118.5]]],
          [[[128.5, 129, 130, 130.5], [132.5, 133, 134, 134.5], [134.5, 135, 136, 136.5]],
           [[134.5, 135, 136, 136.5], [138.5, 139, 140, 140.5], [140.5, 141, 142, 142.5]]]]])
    assert (out.asnumpy() == expected).all()


if __name__ == '__main__':
test_avgpool_k2s1pv()
test_avgpool_k2s2pv()


tests/st/ops/cpu/test_conv2d_op.py → tests/st/ops/cpu/test_conv_op.py View File

@@ -21,7 +21,9 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.parameter import ParameterTuple
from mindspore.ops import operations as P
from mindspore.ops import composite as C

context.set_context(mode=context.GRAPH_MODE, device_target='CPU')

@@ -170,5 +172,86 @@ def test_conv():
assert (loss < error).all()


test_conv2d()
test_conv()
class NetConv3d(nn.Cell):
    """Minimal cell wrapping P.Conv3D: 4 output channels, 2x2x2 kernel, VALID padding."""

    def __init__(self):
        super(NetConv3d, self).__init__()
        # 4 output channels, kernel 2; remaining arguments are the operator defaults
        # spelled out for readability.
        self.conv = P.Conv3D(4, 2, mode=1, pad_mode="valid",
                             pad=0, stride=1, dilation=1, group=1)

    def construct(self, x, w):
        return self.conv(x, w)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_conv3d():
    """Conv3D forward on CPU matches the precomputed result in both execution modes."""
    x = Tensor(np.arange(1 * 3 * 3 * 3 * 3).reshape(1, 3, 3, 3, 3).astype(np.float32))
    w = Tensor(np.arange(4 * 3 * 2 * 2 * 2).reshape(4, 3, 2, 2, 2).astype(np.float32))
    expect = np.array([[[[[12960., 13236.], [13788., 14064.]],
                         [[15444., 15720.], [16272., 16548.]]],
                        [[[32256., 33108.], [34812., 35664.]],
                         [[39924., 40776.], [42480., 43332.]]],
                        [[[51552., 52980.], [55836., 57264.]],
                         [[64404., 65832.], [68688., 70116.]]],
                        [[[70848., 72852.], [76860., 78864.]],
                         [[88884., 90888.], [94896., 96900.]]]]]).astype(np.float32)
    # PyNative first, then graph mode, mirroring the two runtime paths.
    for mode in (context.PYNATIVE_MODE, context.GRAPH_MODE):
        context.set_context(mode=mode, device_target="CPU")
        output = NetConv3d()(x, w)
        assert (output.asnumpy() == expect).all()


class MSConv3dNet(nn.Cell):
    """Single nn.Conv3d layer in NCDHW layout with configurable hyper-parameters."""

    def __init__(self, in_channels, out_channels, kernel_size, pad_mode='pad', padding=0, stride=1, dilation=1,
                 has_bias=False, weight_init='normal'):
        super(MSConv3dNet, self).__init__()
        self.cv1 = nn.Conv3d(in_channels=in_channels,
                             out_channels=out_channels,
                             kernel_size=kernel_size,
                             pad_mode=pad_mode,
                             padding=padding,
                             stride=stride,
                             dilation=dilation,
                             group=1,  # grouped convolution not exercised here
                             has_bias=has_bias,
                             weight_init=weight_init,
                             data_format='NCDHW')

    def construct(self, x):
        return self.cv1(x)


class MSGradNet(nn.Cell):
    """Computes gradients of `network` w.r.t. both its inputs and its trainable parameters.

    `sens_param=True` means the caller supplies the output sensitivity `dy`.
    """

    def __init__(self, network):
        super(MSGradNet, self).__init__()
        self.grad = C.GradOperation(get_all=True, sens_param=True, get_by_list=True)
        self.network = network
        self.params = ParameterTuple(network.trainable_params())

    def construct(self, x, dy):
        grad_fn = self.grad(self.network, self.params)
        return grad_fn(x, dy)

+ 179
- 0
tests/st/ops/cpu/test_maxpool_op.py View File

@@ -13,11 +13,13 @@
# limitations under the License.
# ============================================================================

from functools import reduce
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
@@ -88,3 +90,180 @@ def test_maxpool():
maxpool2d = Net_Pool()
with pytest.raises(Exception):
maxpool2d(x)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_max_pool3d_1():
    """MaxPool3D forward, asymmetric kernel (2, 2, 3), stride 1, VALID padding."""
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    shape = (2, 3, 2, 3, 4)
    # On a 0..N-1 ramp the max of each window is simply its last element.
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.MaxPool3D(kernel_size=(2, 2, 3), strides=1, pad_mode='VALID')(x)
    expected = np.array([[[[[18, 19], [22, 23]]],
                          [[[42, 43], [46, 47]]],
                          [[[66, 67], [70, 71]]]],
                         [[[[90, 91], [94, 95]]],
                          [[[114, 115], [118, 119]]],
                          [[[138, 139], [142, 143]]]]])
    assert (out.asnumpy() == expected).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_max_pool3d_2():
    """MaxPool3D forward, cubic kernel 2, stride 1, VALID padding, graph mode."""
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    shape = (2, 3, 2, 3, 4)
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.MaxPool3D(kernel_size=2, strides=1, pad_mode='VALID')(x)
    expected = np.array([[[[[17, 18, 19], [21, 22, 23]]],
                          [[[41, 42, 43], [45, 46, 47]]],
                          [[[65, 66, 67], [69, 70, 71]]]],
                         [[[[89, 90, 91], [93, 94, 95]]],
                          [[[113, 114, 115], [117, 118, 119]]],
                          [[[137, 138, 139], [141, 142, 143]]]]])
    assert (out.asnumpy() == expected).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_max_pool3d_3():
    """MaxPool3D forward, kernel 2, stride 3 (non-overlapping windows), VALID padding."""
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    shape = (2, 3, 2, 3, 4)
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.MaxPool3D(kernel_size=2, strides=3, pad_mode='VALID')(x)
    # Stride 3 leaves a single window per channel.
    expected = np.array([[[[[17]]], [[[41]]], [[[65]]]],
                         [[[[89]]], [[[113]]], [[[137]]]]])
    assert (out.asnumpy() == expected).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_max_pool3d_4():
    """MaxPool3D forward, kernel (2, 2, 3), stride 1, SAME padding, graph mode."""
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    shape = (2, 3, 2, 3, 4)
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.MaxPool3D(kernel_size=(2, 2, 3), strides=1, pad_mode='SAME')(x)
    # SAME padding keeps spatial dims; edge windows clamp to the border maxima.
    expected = np.array(
        [[[[[17, 18, 19, 19], [21, 22, 23, 23], [21, 22, 23, 23]],
           [[17, 18, 19, 19], [21, 22, 23, 23], [21, 22, 23, 23]]],
          [[[41, 42, 43, 43], [45, 46, 47, 47], [45, 46, 47, 47]],
           [[41, 42, 43, 43], [45, 46, 47, 47], [45, 46, 47, 47]]],
          [[[65, 66, 67, 67], [69, 70, 71, 71], [69, 70, 71, 71]],
           [[65, 66, 67, 67], [69, 70, 71, 71], [69, 70, 71, 71]]]],
         [[[[89, 90, 91, 91], [93, 94, 95, 95], [93, 94, 95, 95]],
           [[89, 90, 91, 91], [93, 94, 95, 95], [93, 94, 95, 95]]],
          [[[113, 114, 115, 115], [117, 118, 119, 119], [117, 118, 119, 119]],
           [[113, 114, 115, 115], [117, 118, 119, 119], [117, 118, 119, 119]]],
          [[[137, 138, 139, 139], [141, 142, 143, 143], [141, 142, 143, 143]],
           [[137, 138, 139, 139], [141, 142, 143, 143], [141, 142, 143, 143]]]]])
    assert (out.asnumpy() == expected).all()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_max_pool3d_5():
    """MaxPool3D forward, kernel (2, 2, 3), stride 1, SAME padding, PyNative mode."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
    shape = (2, 3, 2, 3, 4)
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.MaxPool3D(kernel_size=(2, 2, 3), strides=1, pad_mode='SAME')(x)
    # Same expectation as the graph-mode case; only the execution mode differs.
    expected = np.array(
        [[[[[17, 18, 19, 19], [21, 22, 23, 23], [21, 22, 23, 23]],
           [[17, 18, 19, 19], [21, 22, 23, 23], [21, 22, 23, 23]]],
          [[[41, 42, 43, 43], [45, 46, 47, 47], [45, 46, 47, 47]],
           [[41, 42, 43, 43], [45, 46, 47, 47], [45, 46, 47, 47]]],
          [[[65, 66, 67, 67], [69, 70, 71, 71], [69, 70, 71, 71]],
           [[65, 66, 67, 67], [69, 70, 71, 71], [69, 70, 71, 71]]]],
         [[[[89, 90, 91, 91], [93, 94, 95, 95], [93, 94, 95, 95]],
           [[89, 90, 91, 91], [93, 94, 95, 95], [93, 94, 95, 95]]],
          [[[113, 114, 115, 115], [117, 118, 119, 119], [117, 118, 119, 119]],
           [[113, 114, 115, 115], [117, 118, 119, 119], [117, 118, 119, 119]]],
          [[[137, 138, 139, 139], [141, 142, 143, 143], [141, 142, 143, 143]],
           [[137, 138, 139, 139], [141, 142, 143, 143], [141, 142, 143, 143]]]]])
    assert (out.asnumpy() == expected).all()

+ 273
- 0
tests/st/ops/gpu/test_avgpool_gpu_op.py View File

@@ -0,0 +1,273 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from functools import reduce
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor

context.set_context(mode=context.GRAPH_MODE, device_target='GPU')


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_avgpool_k2s1pv():
    """AvgPool2d on GPU: kernel 2, stride 1, VALID padding over a 1x1x6x6 ramp."""
    x = np.arange(1 * 1 * 6 * 6).reshape((1, 1, 6, 6)).astype(np.float32)
    net = nn.AvgPool2d(kernel_size=2, stride=1, pad_mode='valid')
    out = net(Tensor(x))
    # NOTE(review): dropped the leftover debug print(out); the assertion alone
    # documents the expected contract and keeps CI logs clean.
    expect_result = np.array(
        [[[[3.5, 4.5, 5.5, 6.5, 7.5],
           [9.5, 10.5, 11.5, 12.5, 13.5],
           [15.5, 16.5, 17.5, 18.5, 19.5],
           [21.5, 22.5, 23.5, 24.5, 25.5],
           [27.5, 28.5, 29.5, 30.5, 31.5]]]]
    )
    assert np.allclose(out.asnumpy(), expect_result)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_avgpool_k2s2pv():
    """AvgPool2d on GPU: kernel 2, stride 2, VALID padding over a 1x1x6x6 ramp."""
    x = np.arange(1 * 1 * 6 * 6).reshape((1, 1, 6, 6)).astype(np.float32)
    net = nn.AvgPool2d(kernel_size=2, stride=2, pad_mode='valid')
    out = net(Tensor(x))
    # NOTE(review): dropped the leftover debug print(out) to keep CI logs clean.
    expect_result = np.array(
        [[[[3.5, 5.5, 7.5],
           [15.5, 17.5, 19.5],
           [27.5, 29.5, 31.5]]]]
    )
    assert np.allclose(out.asnumpy(), expect_result)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_avgpool_k3s2pv():
    """AvgPool2d on GPU: kernel 3, stride 2, VALID padding over a 1x1x6x6 ramp."""
    x = np.arange(1 * 1 * 6 * 6).reshape((1, 1, 6, 6)).astype(np.float32)
    net = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='valid')
    out = net(Tensor(x))
    # NOTE(review): dropped the leftover debug print(out) to keep CI logs clean.
    expect_result = np.array(
        [[[[7., 9.],
           [19., 21.]]]]
    )
    assert np.allclose(out.asnumpy(), expect_result)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_avgpool_k3s2ps():
    """AvgPool2d on GPU: kernel 3, stride 2, SAME padding over a 1x1x6x6 ramp."""
    x = np.arange(1 * 1 * 6 * 6).reshape((1, 1, 6, 6)).astype(np.float32)
    net = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')
    out = net(Tensor(x))
    # NOTE(review): dropped the leftover debug print(out) to keep CI logs clean.
    expect_result = np.array(
        [[[[7., 9., 10.5],
           [19., 21., 22.5],
           [28., 30., 31.5]]]]
    )
    assert np.allclose(out.asnumpy(), expect_result)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_avg_pool3d_1():
    """AvgPool3D on GPU, asymmetric kernel (2, 2, 3), stride 1, VALID padding."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    shape = (2, 3, 2, 3, 4)
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.AvgPool3D(kernel_size=(2, 2, 3), strides=1, pad_mode='VALID')(x)
    expected = np.array([[[[[9, 10], [13, 14]]],
                          [[[33, 34], [37, 38]]],
                          [[[57, 58], [61, 62]]]],
                         [[[[81, 82], [85, 86]]],
                          [[[105, 106], [109, 110]]],
                          [[[129, 130], [133, 134]]]]])
    assert (out.asnumpy() == expected).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_avg_pool3d_2():
    """AvgPool3D on GPU, cubic kernel 2, stride 1, VALID padding."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    shape = (2, 3, 2, 3, 4)
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.AvgPool3D(kernel_size=2, strides=1, pad_mode='VALID')(x)
    expected = np.array([[[[[8.5, 9.5, 10.5], [12.5, 13.5, 14.5]]],
                          [[[32.5, 33.5, 34.5], [36.5, 37.5, 38.5]]],
                          [[[56.5, 57.5, 58.5], [60.5, 61.5, 62.5]]]],
                         [[[[80.5, 81.5, 82.5], [84.5, 85.5, 86.5]]],
                          [[[104.5, 105.5, 106.5], [108.5, 109.5, 110.5]]],
                          [[[128.5, 129.5, 130.5], [132.5, 133.5, 134.5]]]]])
    assert (out.asnumpy() == expected).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_avg_pool3d_3():
    """AvgPool3D on GPU, kernel 2, stride 3 (non-overlapping windows), VALID padding."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    shape = (2, 3, 2, 3, 4)
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.AvgPool3D(kernel_size=2, strides=3, pad_mode='VALID')(x)
    # Stride 3 leaves a single window per channel.
    expected = np.array([[[[[8.5]]], [[[32.5]]], [[[56.5]]]],
                         [[[[80.5]]], [[[104.5]]], [[[128.5]]]]])
    assert (out.asnumpy() == expected).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_avg_pool3d_4():
    """AvgPool3D on GPU, kernel (2, 2, 3), stride 1, SAME padding, graph mode."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    shape = (2, 3, 2, 3, 4)
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.AvgPool3D(kernel_size=(2, 2, 3), strides=1, pad_mode='SAME')(x)
    # SAME padding keeps the spatial dims; border windows average fewer elements.
    expected = np.array(
        [[[[[8.5, 9, 10, 10.5], [12.5, 13, 14, 14.5], [14.5, 15, 16, 16.5]],
           [[14.5, 15, 16, 16.5], [18.5, 19, 20, 20.5], [20.5, 21, 22, 22.5]]],
          [[[32.5, 33, 34, 34.5], [36.5, 37, 38, 38.5], [38.5, 39, 40, 40.5]],
           [[38.5, 39, 40, 40.5], [42.5, 43, 44, 44.5], [44.5, 45, 46, 46.5]]],
          [[[56.5, 57, 58, 58.5], [60.5, 61, 62, 62.5], [62.5, 63, 64, 64.5]],
           [[62.5, 63, 64, 64.5], [66.5, 67, 68, 68.5], [68.5, 69, 70, 70.5]]]],
         [[[[80.5, 81, 82, 82.5], [84.5, 85, 86, 86.5], [86.5, 87, 88, 88.5]],
           [[86.5, 87, 88, 88.5], [90.5, 91, 92, 92.5], [92.5, 93, 94, 94.5]]],
          [[[104.5, 105, 106, 106.5], [108.5, 109, 110, 110.5], [110.5, 111, 112, 112.5]],
           [[110.5, 111, 112, 112.5], [114.5, 115, 116, 116.5], [116.5, 117, 118, 118.5]]],
          [[[128.5, 129, 130, 130.5], [132.5, 133, 134, 134.5], [134.5, 135, 136, 136.5]],
           [[134.5, 135, 136, 136.5], [138.5, 139, 140, 140.5], [140.5, 141, 142, 142.5]]]]])
    assert (out.asnumpy() == expected).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_avg_pool3d_5():
    """AvgPool3D on GPU, kernel (2, 2, 3), stride 1, SAME padding, PyNative mode."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    shape = (2, 3, 2, 3, 4)
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.AvgPool3D(kernel_size=(2, 2, 3), strides=1, pad_mode='SAME')(x)
    # Same expectation as the graph-mode case; only the execution mode differs.
    expected = np.array(
        [[[[[8.5, 9, 10, 10.5], [12.5, 13, 14, 14.5], [14.5, 15, 16, 16.5]],
           [[14.5, 15, 16, 16.5], [18.5, 19, 20, 20.5], [20.5, 21, 22, 22.5]]],
          [[[32.5, 33, 34, 34.5], [36.5, 37, 38, 38.5], [38.5, 39, 40, 40.5]],
           [[38.5, 39, 40, 40.5], [42.5, 43, 44, 44.5], [44.5, 45, 46, 46.5]]],
          [[[56.5, 57, 58, 58.5], [60.5, 61, 62, 62.5], [62.5, 63, 64, 64.5]],
           [[62.5, 63, 64, 64.5], [66.5, 67, 68, 68.5], [68.5, 69, 70, 70.5]]]],
         [[[[80.5, 81, 82, 82.5], [84.5, 85, 86, 86.5], [86.5, 87, 88, 88.5]],
           [[86.5, 87, 88, 88.5], [90.5, 91, 92, 92.5], [92.5, 93, 94, 94.5]]],
          [[[104.5, 105, 106, 106.5], [108.5, 109, 110, 110.5], [110.5, 111, 112, 112.5]],
           [[110.5, 111, 112, 112.5], [114.5, 115, 116, 116.5], [116.5, 117, 118, 118.5]]],
          [[[128.5, 129, 130, 130.5], [132.5, 133, 134, 134.5], [134.5, 135, 136, 136.5]],
           [[134.5, 135, 136, 136.5], [138.5, 139, 140, 140.5], [140.5, 141, 142, 142.5]]]]])
    assert (out.asnumpy() == expected).all()


if __name__ == '__main__':
test_avgpool_k2s1pv()
test_avgpool_k2s2pv()
test_avgpool_k3s2pv()
test_avgpool_k3s2ps()

+ 179
- 0
tests/st/ops/gpu/test_maxpool_gpu_op.py View File

@@ -13,11 +13,13 @@
# limitations under the License.
# ============================================================================

from functools import reduce
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor


@@ -77,3 +79,180 @@ def test_maxpool2d():
output = maxpool2d(x)
assert (output.asnumpy() == expect_result).all()
assert (output2.asnumpy() == expect_result2).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_max_pool3d_1():
    """MaxPool3D on GPU, asymmetric kernel (2, 2, 3), stride 1, VALID padding."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    shape = (2, 3, 2, 3, 4)
    # On a 0..N-1 ramp the max of each window is simply its last element.
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.MaxPool3D(kernel_size=(2, 2, 3), strides=1, pad_mode='VALID')(x)
    expected = np.array([[[[[18, 19], [22, 23]]],
                          [[[42, 43], [46, 47]]],
                          [[[66, 67], [70, 71]]]],
                         [[[[90, 91], [94, 95]]],
                          [[[114, 115], [118, 119]]],
                          [[[138, 139], [142, 143]]]]])
    assert (out.asnumpy() == expected).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_max_pool3d_2():
    """MaxPool3D on GPU, cubic kernel 2, stride 1, VALID padding."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    shape = (2, 3, 2, 3, 4)
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.MaxPool3D(kernel_size=2, strides=1, pad_mode='VALID')(x)
    expected = np.array([[[[[17, 18, 19], [21, 22, 23]]],
                          [[[41, 42, 43], [45, 46, 47]]],
                          [[[65, 66, 67], [69, 70, 71]]]],
                         [[[[89, 90, 91], [93, 94, 95]]],
                          [[[113, 114, 115], [117, 118, 119]]],
                          [[[137, 138, 139], [141, 142, 143]]]]])
    assert (out.asnumpy() == expected).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_max_pool3d_3():
    """MaxPool3D on GPU, kernel 2, stride 3 (non-overlapping windows), VALID padding."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    shape = (2, 3, 2, 3, 4)
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.MaxPool3D(kernel_size=2, strides=3, pad_mode='VALID')(x)
    # Stride 3 leaves a single window per channel.
    expected = np.array([[[[[17]]], [[[41]]], [[[65]]]],
                         [[[[89]]], [[[113]]], [[[137]]]]])
    assert (out.asnumpy() == expected).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_max_pool3d_4():
    """MaxPool3D on GPU, kernel (2, 2, 3), stride 1, SAME padding, graph mode."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    shape = (2, 3, 2, 3, 4)
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.MaxPool3D(kernel_size=(2, 2, 3), strides=1, pad_mode='SAME')(x)
    # SAME padding keeps spatial dims; edge windows clamp to the border maxima.
    expected = np.array(
        [[[[[17, 18, 19, 19], [21, 22, 23, 23], [21, 22, 23, 23]],
           [[17, 18, 19, 19], [21, 22, 23, 23], [21, 22, 23, 23]]],
          [[[41, 42, 43, 43], [45, 46, 47, 47], [45, 46, 47, 47]],
           [[41, 42, 43, 43], [45, 46, 47, 47], [45, 46, 47, 47]]],
          [[[65, 66, 67, 67], [69, 70, 71, 71], [69, 70, 71, 71]],
           [[65, 66, 67, 67], [69, 70, 71, 71], [69, 70, 71, 71]]]],
         [[[[89, 90, 91, 91], [93, 94, 95, 95], [93, 94, 95, 95]],
           [[89, 90, 91, 91], [93, 94, 95, 95], [93, 94, 95, 95]]],
          [[[113, 114, 115, 115], [117, 118, 119, 119], [117, 118, 119, 119]],
           [[113, 114, 115, 115], [117, 118, 119, 119], [117, 118, 119, 119]]],
          [[[137, 138, 139, 139], [141, 142, 143, 143], [141, 142, 143, 143]],
           [[137, 138, 139, 139], [141, 142, 143, 143], [141, 142, 143, 143]]]]])
    assert (out.asnumpy() == expected).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_max_pool3d_5():
    """MaxPool3D on GPU, kernel (2, 2, 3), stride 1, SAME padding, PyNative mode."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    shape = (2, 3, 2, 3, 4)
    x = Tensor(np.arange(np.prod(shape), dtype=np.float32).reshape(shape))
    out = P.MaxPool3D(kernel_size=(2, 2, 3), strides=1, pad_mode='SAME')(x)
    # Same expectation as the graph-mode case; only the execution mode differs.
    expected = np.array(
        [[[[[17, 18, 19, 19], [21, 22, 23, 23], [21, 22, 23, 23]],
           [[17, 18, 19, 19], [21, 22, 23, 23], [21, 22, 23, 23]]],
          [[[41, 42, 43, 43], [45, 46, 47, 47], [45, 46, 47, 47]],
           [[41, 42, 43, 43], [45, 46, 47, 47], [45, 46, 47, 47]]],
          [[[65, 66, 67, 67], [69, 70, 71, 71], [69, 70, 71, 71]],
           [[65, 66, 67, 67], [69, 70, 71, 71], [69, 70, 71, 71]]]],
         [[[[89, 90, 91, 91], [93, 94, 95, 95], [93, 94, 95, 95]],
           [[89, 90, 91, 91], [93, 94, 95, 95], [93, 94, 95, 95]]],
          [[[113, 114, 115, 115], [117, 118, 119, 119], [117, 118, 119, 119]],
           [[113, 114, 115, 115], [117, 118, 119, 119], [117, 118, 119, 119]]],
          [[[137, 138, 139, 139], [141, 142, 143, 143], [141, 142, 143, 143]],
           [[137, 138, 139, 139], [141, 142, 143, 143], [141, 142, 143, 143]]]]])
    assert (out.asnumpy() == expected).all()

Loading…
Cancel
Save