diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h index ecbb5461e2..9d5aa0f412 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/conv2d_gpu_kernel.h @@ -31,31 +31,7 @@ namespace kernel { template class Conv2dGpuFwdKernel : public GpuKernel { public: - Conv2dGpuFwdKernel() - : cudnn_handle_(nullptr), - input_desc_(nullptr), - output_desc_(nullptr), - filter_desc_(nullptr), - conv_desc_(nullptr), - padded_desc_(nullptr), - cudnn_data_type_(CUDNN_DATA_FLOAT), - compute_format_(CUDNN_TENSOR_NCHW), - old_height_(0), - old_width_(0), - pad_height_(0), - pad_width_(0), - pad_top_(0), - pad_left_(0), - n_(0), - c_(0), - group_(1), - is_null_input_(false), - input_size_(0), - filter_size_(0), - output_size_(0), - padded_size_(0), - workspace_size_(0), - use_pad_(true) {} + Conv2dGpuFwdKernel() { ResetResource(); } /* All member defaults now come from ResetResource(). NOTE(review): ResetResource() is declared override, so this is a virtual call made from a constructor; during construction it resolves to this class's own version. */ ~Conv2dGpuFwdKernel() override { DestroyResource(); } const std::vector &GetInputSizeList() const override { return input_size_list_; } const std::vector &GetOutputSizeList() const override { return output_size_list_; } @@ -194,6 +170,38 @@ class Conv2dGpuFwdKernel : public GpuKernel { return true; } + void ResetResource() noexcept override { /* Puts every cached member back to the defaults the removed constructor initializer list used, and additionally clears stride_, dilation_ and the three size lists. The handle and descriptor pointers are only set to nullptr here, nothing is destroyed. NOTE(review): presumably DestroyResource() has to run before this when descriptors are live, otherwise they would leak - confirm against the caller. */ + cudnn_handle_ = nullptr; + input_desc_ = nullptr; + output_desc_ = nullptr; + filter_desc_ = nullptr; + conv_desc_ = nullptr; + padded_desc_ = nullptr; + cudnn_data_type_ = CUDNN_DATA_FLOAT; + compute_format_ = CUDNN_TENSOR_NCHW; + old_height_ = 0; + old_width_ = 0; + pad_height_ = 0; + pad_width_ = 0; + pad_top_ = 0; + pad_left_ = 0; + n_ = 0; + c_ = 0; + stride_.clear(); + dilation_.clear(); + group_ = 1; + is_null_input_ = false; + input_size_ = 0; + filter_size_ = 0; + output_size_ = 0; + padded_size_ = 0; + workspace_size_ = 0; + use_pad_ = true; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); + } + 
void DestroyResource() noexcept override { CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyConvolutionDescriptor(conv_desc_), "cudnnDestroyConvolutionDescriptor failed"); diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 7c7ade82c2..25abf5e75d 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -57,6 +57,8 @@ AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const Primitiv const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc index d745bca2de..4aad2b8f3a 100644 --- a/mindspore/core/abstract/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -14,6 +14,7 @@ * limitations under the License. 
*/
 
+#include <cmath>
 #include "abstract/infer_functions.h"
 #include "abstract/utils.h"
 #include "abstract/param_validator.h"
@@ -267,6 +268,221 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit
   return std::make_shared(rets);
 }
 
+// Infers the output abstract (element type plus static/min/max shape) of Conv2D.
+//
+// args_spec_list holds exactly two tensors: [0] the input x and [1] the weight w. Both must be rank 4 and
+// are indexed according to the "format" attribute ("NCHW" when the attribute is absent, or "NHWC").
+// Attributes read: format, group, out_channel, kernel_size, stride, dilation, pad, pad_mode.
+// Side effect: stores the padding computed for the static shape on the primitive as attribute "pad_list".
+// Raises through MS_LOG(EXCEPTION) / MS_EXCEPTION_IF_NULL when an attribute is missing or malformed, or when
+// a rank, channel count or kernel size does not match.
+// NOTE(review): the Python infer_dtype removed by this change returned int32 for int8 inputs, whereas this
+// function passes the input element type through unchanged - confirm that this is intended.
+AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+
+  AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(input_x);
+  MS_EXCEPTION_IF_NULL(input_x->shape());
+  ShapeVector x_shape = input_x->shape()->shape();
+  ShapeVector x_min_shape = input_x->shape()->min_shape();
+  ShapeVector x_max_shape = input_x->shape()->max_shape();
+  (void)CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
+
+  AbstractTensorPtr input_w = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(input_w);
+  MS_EXCEPTION_IF_NULL(input_w->shape());
+  ShapeVector w_shape = input_w->shape()->shape();
+  ShapeVector w_min_shape = input_w->shape()->min_shape();
+  ShapeVector w_max_shape = input_w->shape()->max_shape();
+  (void)CheckMinMaxShape(w_shape, &w_min_shape, &w_max_shape);
+
+  // Axes 0..3 of these shapes are indexed below, so anything that is not rank 4 must be rejected first.
+  const size_t kConv2DRank = 4;
+  if (x_shape.size() != kConv2DRank || x_min_shape.size() != kConv2DRank || x_max_shape.size() != kConv2DRank) {
+    MS_LOG(EXCEPTION) << "For " << op_name << ", x must be a 4-D tensor, but got rank " << x_shape.size();
+  }
+  if (w_shape.size() != kConv2DRank) {
+    MS_LOG(EXCEPTION) << "For " << op_name << ", weight must be a 4-D tensor, but got rank " << w_shape.size();
+  }
+
+  std::set<std::string> available_data_format{"NCHW", "NHWC"};
+  int64_t n_axis = 0;
+  int64_t c_axis = 1;
+  int64_t h_axis = 2;
+  int64_t w_axis = 3;
+  auto data_format_ptr = primitive->GetAttr("format");
+  std::string data_format = "NCHW";
+  if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) {
+    data_format = data_format_ptr->cast<StringImmPtr>()->value();
+    if (available_data_format.find(data_format) == available_data_format.end()) {
+      MS_LOG(EXCEPTION) << "Unsupported data format: " << data_format << ". use NCHW or NHWC";
+    }
+    if (data_format == "NHWC") {
+      c_axis = 3;
+      h_axis = 1;
+      w_axis = 2;
+    }
+  }
+
+  // GetAttr may hand back nullptr (see the format lookup above), so check before reading the value.
+  ValuePtr group_attr = primitive->GetAttr("group");
+  MS_EXCEPTION_IF_NULL(group_attr);
+  int64_t group = GetValue<int64_t>(group_attr);
+  if (group <= 0) {
+    MS_LOG(EXCEPTION) << "Invalid group value: " << group << ", should be greater then 0";
+  }
+  if ((x_shape[c_axis] != Shape::SHP_ANY) && (x_shape[c_axis] % group != 0)) {
+    MS_LOG(EXCEPTION) << "x_shape[" << c_axis << "] = " << x_shape[c_axis]
+                      << " (channels) must be divisible by group = " << group;
+  }
+
+  ValuePtr out_channel_attr = primitive->GetAttr("out_channel");
+  MS_EXCEPTION_IF_NULL(out_channel_attr);
+  int64_t out_channel = GetValue<int64_t>(out_channel_attr);
+  if (out_channel <= 0) {
+    MS_LOG(EXCEPTION) << "Invalid out_channel value: " << out_channel << ", should be greater then 0";
+  }
+  if ((w_shape[0] != Shape::SHP_ANY) && (w_shape[0] != out_channel)) {
+    MS_LOG(EXCEPTION) << "w_shape[0] = " << w_shape[0] << " must equal to = " << out_channel;
+  }
+
+  // Reads attribute `name` as a (height, width) pair: a tuple is read at positions h_pos / w_pos, while a
+  // scalar is used for both dimensions.
+  auto get_hw_attr = [&primitive](const std::string &name, size_t h_pos, size_t w_pos, int64_t *h, int64_t *w) {
+    ValuePtr attr = primitive->GetAttr(name);
+    MS_EXCEPTION_IF_NULL(attr);
+    if (attr->isa<ValueTuple>()) {
+      std::vector<ValuePtr> elements = attr->cast<ValueTuplePtr>()->value();
+      if (elements.size() <= std::max(h_pos, w_pos)) {
+        MS_LOG(EXCEPTION) << "Attribute " << name << " has only " << elements.size() << " elements";
+      }
+      *h = GetValue<int64_t>(elements[h_pos]);
+      *w = GetValue<int64_t>(elements[w_pos]);
+    } else {
+      *h = GetValue<int64_t>(attr);
+      *w = *h;
+    }
+  };
+
+  int64_t kernel_h = 0;
+  int64_t kernel_w = 0;
+  get_hw_attr("kernel_size", 0, 1, &kernel_h, &kernel_w);
+  // The Python check permutes w_shape exactly like x_shape for NHWC, so the weight's spatial axes follow
+  // the data format (axes 1 and 2 for NHWC) instead of always being axes 2 and 3.
+  if ((w_shape[h_axis] != Shape::SHP_ANY) && (w_shape[h_axis] != kernel_h)) {
+    MS_LOG(EXCEPTION) << "weight height, w_shape[" << h_axis << "] = " << w_shape[h_axis]
+                      << ", must equal to = " << kernel_h;
+  }
+  if ((w_shape[w_axis] != Shape::SHP_ANY) && (w_shape[w_axis] != kernel_w)) {
+    MS_LOG(EXCEPTION) << "weight width, w_shape[" << w_axis << "] = " << w_shape[w_axis]
+                      << ", must equal to = " << kernel_w;
+  }
+
+  int64_t stride_h = 0;
+  int64_t stride_w = 0;
+  get_hw_attr("stride", 2, 3, &stride_h, &stride_w);
+  // The strides are divisors in the output-size formulas below, so zero or negative values are rejected.
+  if (stride_h <= 0 || stride_w <= 0) {
+    MS_LOG(EXCEPTION) << "Invalid stride: (" << stride_h << ", " << stride_w << "), should be greater than 0";
+  }
+
+  int64_t dilation_h = 0;
+  int64_t dilation_w = 0;
+  get_hw_attr("dilation", 2, 3, &dilation_h, &dilation_w);
+
+  std::vector<int64_t> padding;
+  ValuePtr padding_attr = primitive->GetAttr("pad");
+  MS_EXCEPTION_IF_NULL(padding_attr);
+  if (padding_attr->isa<ValueTuple>()) {
+    std::vector<ValuePtr> padding_vec = padding_attr->cast<ValueTuplePtr>()->value();
+    (void)std::transform(std::begin(padding_vec), std::end(padding_vec), std::back_inserter(padding),
+                         [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
+  } else {
+    int64_t padding_val = GetValue<int64_t>(padding_attr);
+    padding = {padding_val, padding_val, padding_val, padding_val};
+  }
+
+  std::set<std::string> available_pad_mode{"pad", "same", "valid"};
+  ValuePtr pad_mode_attr = primitive->GetAttr("pad_mode");
+  MS_EXCEPTION_IF_NULL(pad_mode_attr);
+  auto pad_mode = GetValue<std::string>(pad_mode_attr);
+  if (available_pad_mode.find(pad_mode) == available_pad_mode.end()) {
+    MS_LOG(EXCEPTION) << "Unsupported pad mode: " << pad_mode << ". use pad, same, valid";
+  }
+  const size_t kPadSize = 4;
+  if ((pad_mode == "pad") && (padding.size() != kPadSize)) {
+    MS_LOG(EXCEPTION) << "Attribute pad must hold 4 values in pad mode, but got " << padding.size();
+  }
+
+  // Computes the output (h, w) and the [top, bottom, left, right] padding for one input size. It is only
+  // called synchronously inside this function, so capturing the locals by reference is safe.
+  auto pad_function = [&](int64_t x_h, int64_t x_w, std::vector<int64_t> &output_hw,
+                          std::vector<int64_t> &pad_list) {
+    if (pad_mode == "valid") {
+      output_hw.push_back(static_cast<int64_t>(std::ceil(((x_h * 1.0) - dilation_h * (kernel_h - 1)) / stride_h)));
+      output_hw.push_back(static_cast<int64_t>(std::ceil(((x_w * 1.0) - dilation_w * (kernel_w - 1)) / stride_w)));
+      pad_list = {0, 0, 0, 0};
+    } else if (pad_mode == "same") {
+      output_hw.push_back(static_cast<int64_t>(std::ceil((x_h * 1.0) / stride_h)));
+      output_hw.push_back(static_cast<int64_t>(std::ceil((x_w * 1.0) / stride_w)));
+      int64_t pad_needed_h = (output_hw[0] - 1) * stride_h + dilation_h * (kernel_h - 1) + 1 - x_h;
+      pad_needed_h = std::max(static_cast<int64_t>(0), pad_needed_h);
+      // pad_needed_* is non-negative at this point, so integer division already rounds down.
+      pad_list.push_back(pad_needed_h / 2);
+      pad_list.push_back(pad_needed_h - pad_list[0]);
+      int64_t pad_needed_w = (output_hw[1] - 1) * stride_w + dilation_w * (kernel_w - 1) + 1 - x_w;
+      pad_needed_w = std::max(static_cast<int64_t>(0), pad_needed_w);
+      pad_list.push_back(pad_needed_w / 2);
+      pad_list.push_back(pad_needed_w - pad_list[2]);
+    } else if (pad_mode == "pad") {
+      pad_list = padding;
+      output_hw.push_back(static_cast<int64_t>(std::floor(
+        1 + ((x_h * 1.0) + pad_list[0] + pad_list[1] - kernel_h - (kernel_h - 1) * (dilation_h - 1)) / stride_h)));
+      output_hw.push_back(static_cast<int64_t>(std::floor(
+        1 + ((x_w * 1.0) + pad_list[2] + pad_list[3] - kernel_w - (kernel_w - 1) * (dilation_w - 1)) / stride_w)));
+    }
+  };
+
+  std::vector<int64_t> output_hw;
+  std::vector<int64_t> pad_list;
+  std::vector<int64_t> output_hw_min;
+  std::vector<int64_t> pad_list_min;
+  std::vector<int64_t> output_hw_max;
+  std::vector<int64_t> pad_list_max;
+  pad_function(x_shape[h_axis], x_shape[w_axis], output_hw, pad_list);
+  // An input dimension equal to Shape::SHP_ANY yields Shape::SHP_ANY for the matching output dimension.
+  if (x_shape[h_axis] == Shape::SHP_ANY) {
+    output_hw[0] = Shape::SHP_ANY;
+  }
+  if (x_shape[w_axis] == Shape::SHP_ANY) {
+    output_hw[1] = Shape::SHP_ANY;
+  }
+  pad_function(x_min_shape[h_axis], x_min_shape[w_axis], output_hw_min, pad_list_min);
+  pad_function(x_max_shape[h_axis], x_max_shape[w_axis], output_hw_max, pad_list_max);
+
+  std::vector<ValuePtr> pad_list_val = {MakeValue(pad_list[0]), MakeValue(pad_list[1]), MakeValue(pad_list[2]),
+                                        MakeValue(pad_list[3])};
+  primitive->set_attr("pad_list", MakeValue(pad_list_val));
+  ShapeVector output_shape;
+  ShapeVector output_shape_min;
+  ShapeVector output_shape_max;
+  if (data_format == "NHWC") {
+    output_shape = {x_shape[n_axis], output_hw[0], output_hw[1], out_channel};
+    output_shape_min = {x_min_shape[n_axis], output_hw_min[0], output_hw_min[1], out_channel};
+    output_shape_max = {x_max_shape[n_axis], output_hw_max[0], output_hw_max[1], out_channel};
+  } else {
+    output_shape = {x_shape[n_axis], out_channel, output_hw[0], output_hw[1]};
+    output_shape_min = {x_min_shape[n_axis], out_channel, output_hw_min[0], output_hw_min[1]};
+    output_shape_max = {x_max_shape[n_axis], out_channel, output_hw_max[0], output_hw_max[1]};
+  }
+
+  ShapePtr output_shape_ptr = std::make_shared<Shape>(output_shape, output_shape_min, output_shape_max);
+  return std::make_shared<AbstractTensor>(input_x->element(), output_shape_ptr);
+}
+
 AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                              const AbstractBasePtrList &args_spec_list) {
   // Inputs: three tensors(doutput, input, filters).
diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 7defa35a14..0af8578033 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -111,6 +111,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}}, {prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}}, {prim::kPrimReluGrad, {InferImplReluGrad, true}}, + {prim::kPrimConv2D, {InferImplConv2D, true}}, {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}}, {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}}, {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}}, diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index b254074479..5f524aa79b 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1207,7 +1207,7 @@ class BatchNorm(PrimitiveWithInfer): return (input_x, scale, bias, input_x, input_x) -class Conv2D(PrimitiveWithInfer): +class Conv2D(PrimitiveWithCheck): r""" 2D convolution layer. 
@@ -1314,65 +1314,16 @@ class Conv2D(PrimitiveWithInfer): self.add_prim_attr('groups', self.group) self.add_prim_attr('offset_a', 0) - def infer_shape(self, x_shape, w_shape, b_shape=None): + def check_shape(self, x_shape, w_shape, b_shape=None): x_shape_norm = x_shape if self.format == "NCHW" else (x_shape[0], x_shape[3], x_shape[1], x_shape[2]) w_shape_norm = w_shape if self.format == "NCHW" else (w_shape[0], w_shape[3], w_shape[1], w_shape[2]) - validator.check_equal_int(len(w_shape_norm), 4, "weight rank", self.name) validator.check_equal_int(len(x_shape_norm), 4, "x rank", self.name) - validator.check(f"x_shape[1] / group", x_shape_norm[1] // self.group, "w_shape[1]", w_shape_norm[1], \ - Rel.EQ, self.name) - validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape_norm[0], Rel.EQ, self.name) - validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape_norm[2:4]), Rel.EQ, self.name) - - kernel_size_h = w_shape_norm[2] - kernel_size_w = w_shape_norm[3] - - stride_h = self.stride[2] - stride_w = self.stride[3] - dilation_h = self.dilation[2] - dilation_w = self.dilation[3] - - if self.pad_mode == "valid": - h_out = math.ceil((x_shape_norm[2] - dilation_h * (kernel_size_h - 1)) / stride_h) - w_out = math.ceil((x_shape_norm[3] - dilation_w * (kernel_size_w - 1)) / stride_w) - pad_top, pad_bottom, pad_left, pad_right = 0, 0, 0, 0 - elif self.pad_mode == "same": - h_out = math.ceil(x_shape_norm[2] / stride_h) - w_out = math.ceil(x_shape_norm[3] / stride_w) - - pad_needed_h = max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape_norm[2]) - pad_top = math.floor(pad_needed_h / 2) - pad_bottom = pad_needed_h - pad_top - pad_needed_w = max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape_norm[3]) - pad_left = math.floor(pad_needed_w / 2) - pad_right = pad_needed_w - pad_left - elif self.pad_mode == 'pad': - pad_top, pad_bottom, pad_left, pad_right = self.padding - - h_out = 1 + 
(x_shape_norm[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) \ - * (dilation_h - 1)) / stride_h - w_out = 1 + (x_shape_norm[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) \ - * (dilation_w - 1)) / stride_w - h_out = math.floor(h_out) - w_out = math.floor(w_out) - - self.pad_list = [pad_top, pad_bottom, pad_left, pad_right] - self.add_prim_attr('pad_list', (pad_top, pad_bottom, pad_left, pad_right)) - out_channel = self.out_channel - out_shape = [x_shape_norm[0], out_channel, h_out, w_out] if self.format == "NCHW" else \ - [x_shape_norm[0], h_out, w_out, out_channel] - _check_shape('output', out_shape, self.name) - return out_shape - - def infer_dtype(self, x_dtype, w_dtype, b_dtype=None): + def check_dtype(self, x_dtype, w_dtype, b_dtype=None): args = {'x': x_dtype, 'w': w_dtype} valid_dtypes = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) - if x_dtype.element_type() == mstype.int8: - return mstype.tensor_type(mstype.int32) - return x_dtype class DepthwiseConv2dNative(PrimitiveWithInfer): diff --git a/tests/st/ops/gpu/test_conv2d_op.py b/tests/st/ops/gpu/test_conv2d_op.py index 6af5fc3965..afa15a818d 100644 --- a/tests/st/ops/gpu/test_conv2d_op.py +++ b/tests/st/ops/gpu/test_conv2d_op.py @@ -20,6 +20,9 @@ import mindspore.context as context import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P +from mindspore.ops.operations import _inner_ops as inner +from mindspore.common.parameter import Parameter +from mindspore.common.initializer import initializer class NetConv2d(nn.Cell): @@ -61,3 +64,171 @@ def test_conv2d(): conv2d = NetConv2d() output = conv2d(x, w) assert (output.asnumpy() == expect).all() + + +class NetConv(nn.Cell): + def __init__(self, weight, x): + super(NetConv, self).__init__() + self.conv = nn.Conv2d(in_channels=3, + out_channels=3, + kernel_size=(5, 3), + stride=2, + 
pad_mode='same', + padding=(0, 0, 0, 0), + dilation=(1, 1), + group=1, + has_bias=False, + weight_init=Tensor(weight) + ) + self.x = Parameter(initializer(Tensor(x), [1, 3, 4, 2]), name="x") + + def construct(self): + return self.conv(self.x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_conv(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + weight = np.array([[[[0.38968208, 0.14398979, 0.7962463], + [-2.1836321, -0.63823014, -0.50588065], + [0.6660469, 0.64673275, -0.13160042], + [1.3683757, 1.4005762, -0.37235805], + [-0.22638111, 0.45427424, -0.10293389]], + [[1.4985064, -0.29318333, -0.92694616], + [1.539068, 0.8937254, -1.2598171], + [0.9658142, -0.63945454, -0.23185322], + [1.363089, -0.41694695, -2.2750475], + [-0.4865508, -1.6938025, 0.609849]], + [[1.1844803, 0.99874926, -1.9475793], + [0.4987858, 0.5307887, -0.04226681], + [0.4529779, -1.1960793, 0.9456575], + [3.133675, 0.2309789, -0.29201075], + [-0.59632736, -0.0789804, -0.69486314]]], + [[[-0.5606142, 0.6420862, 0.2478745], + [0.02717604, 1.5483379, -0.9373383], + [-1.1017276, -0.259478, 1.0311872], + [1.8387799, 0.16468556, 0.33392152], + [-1.8781787, 1.0158662, 1.6527579]], + + [[0.45696944, -0.5652523, -1.5618048], + [-0.30304828, 0.1331878, -0.36955845], + [0.91655576, 0.66612357, 0.3068175], + [-0.45732066, 0.8923335, 1.0542952], + [-0.73519516, 1.0518405, -1.0273266]], + + [[-0.79712886, -0.26814285, 0.12779616], + [1.0367643, -1.6180774, 0.42999932], + [-0.81818223, -0.81502074, 0.882194], + [0.53640485, 0.4178927, 1.6037121], + [0.9256354, -1.1006796, 0.16614541]]], + + [[[-1.5216796, -1.2473261, 0.6549515], + [0.63627815, 0.7221449, 0.02977821], + [-0.61331123, -0.49451825, 0.33852202], + [1.4510741, -1.3818305, -0.791747], + [0.6989747, 0.49558765, 1.0813237]], + + [[-0.03969796, 0.71586496, 0.8326594], + [-0.15443641, 1.0389746, -0.59301984], + [0.7197836, 0.03257621, 1.8398637], + [0.6111736, -0.16166899, 
-2.4869773], + [1.3066711, -1.8003578, 0.17412892]], + + [[-0.31470737, -0.5938182, -1.1311078], + [-0.99081016, 0.4005125, 0.44154453], + [1.0876914, -2.5958562, -0.5914863], + [1.3759689, -0.7741513, 0.19928917], + [1.6792973, 2.2744863, -0.04308867]]]]).astype(np.float32) + x = np.array([[[[-1.4311737, 1.015344], + [0.04431088, -2.2886624], + [1.4832113, 1.240908], + [0.67040104, 0.15266363]], + + [[0.44226435, 1.1461105], + [1.194218, 1.5547837], + [0.23152256, 1.5911953], + [0.11206784, 0.17978816]], + + [[-0.57803905, 0.8039611], + [0.0823025, -0.6134477], + [-1.4171146, 1.6269946], + [0.48878875, 0.9117505]]]]).astype(np.float32) + conv2d = NetConv(weight, x) + output = conv2d() + expected = np.array([[[[2.3498724], + [-1.9199573]], + [[5.376562], + [-5.425745]], + [[5.9105043], + [7.469034]]]]).astype(np.float32) + loss = np.abs(expected - output.asnumpy()) + error = 1e-4 * np.ones(loss.shape) + assert (loss < error).all() + + +class NetConv2dDynamic(nn.Cell): + def __init__(self, axis=0, out_nums=1): + super(NetConv2dDynamic, self).__init__() + self.dynshape = inner.GpuConvertToDynamicShape() + out_channel = 2 + kernel_size = 1 + self.conv = P.Conv2D(out_channel, + kernel_size, + mode=1, + pad_mode="valid", + pad=0, + stride=1, + dilation=1, + group=1) + + def construct(self, x, w): + x_dyn = self.dynshape(x) + w_dyn = self.dynshape(w) + x_conv = self.conv(x_dyn, w_dyn) + return x_conv + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_conv2d_dynamic(): + x1 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32)) + w1 = Tensor(np.arange(2 * 3 * 1 * 1).reshape(2, 3, 1, 1).astype(np.float32)) + expect1 = np.array([[[[45, 48, 51], + [54, 57, 60], + [63, 66, 69]], + [[126, 138, 150], + [162, 174, 186], + [198, 210, 222]]]]).astype(np.float32) + + x2 = Tensor(np.arange(5 * 1 * 2 * 2).reshape(5, 1, 2, 2).astype(np.float32)) + w2 = Tensor(np.arange(2 * 1 * 1 * 1).reshape(2, 1, 1, 
1).astype(np.float32)) + expect2 = np.array([[[[0., 0.], + [0., 0.]], + [[0., 1.], + [2., 3.]]], + [[[0., 0.], + [0., 0.]], + [[4., 5.], + [6., 7.]]], + [[[0., 0.], + [0., 0.]], + [[8., 9.], + [10., 11.]]], + [[[0., 0.], + [0., 0.]], + [[12., 13.], + [14., 15.]]], + [[[0., 0.], + [0., 0.]], + [[16., 17.], + [18., 19.]]]]).astype(np.float32) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + conv2d = NetConv2dDynamic() + output1 = conv2d(x1, w1) + assert (output1.asnumpy() == expect1).all() + output2 = conv2d(x2, w2) + assert (output2.asnumpy() == expect2).all()