From f37a230fe8c7dcd4959cca8018dcd52bfb17414f Mon Sep 17 00:00:00 2001
From: liu_xiao_93
Date: Fri, 29 Jan 2021 16:48:52 +0800
Subject: [PATCH] Adapt some ops for 3D format.

---
 .../backend/optimizer/ascend/ascend_helper.cc |  7 ++----
 .../ccsrc/backend/session/kernel_graph.cc     |  7 ++----
 mindspore/core/abstract/prim_nn.cc            |  4 ++--
 mindspore/core/utils/check_convert_utils.cc   | 22 ++++++++++---------
 mindspore/core/utils/check_convert_utils.h    |  3 ++-
 mindspore/ops/_op_impl/tbe/adam_apply_one.py  | 13 +++++------
 mindspore/ops/_op_impl/tbe/bias_add_grad.py   | 12 ++++++++++
 mindspore/ops/_op_impl/tbe/prelu.py           |  6 ++---
 mindspore/ops/_op_impl/tbe/prelu_grad.py      |  9 ++++----
 mindspore/ops/_op_impl/tbe/sgd.py             |  8 +++++++
 mindspore/ops/operations/_grad_ops.py         |  4 ++--
 mindspore/ops/operations/nn_ops.py            | 10 +++++----
 12 files changed, 59 insertions(+), 46 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
index 802a40341c..7a314812b2 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
@@ -73,11 +73,8 @@ std::string InitDefaultFormat(const AnfNodePtr &node) {
     MS_EXCEPTION_IF_NULL(data_format_ptr);
     int64_t data_format;
     bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format);
-    if (!result) {
-      auto attr = GetValue<std::string>(data_format_ptr);
-      if (attr == kOpFormat_NCDHW) {
-        return kOpFormat_NCDHW;
-      }
+    if (result && data_format == Format::NCDHW) {
+      return kOpFormat_NCDHW;
     }
   } else if (AnfAlgo::IsRealKernel(node)) {
     auto formats = AnfAlgo::GetAllOutputFormats(node);
diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc
index 557f083d02..13702efe15 100644
--- a/mindspore/ccsrc/backend/session/kernel_graph.cc
+++ b/mindspore/ccsrc/backend/session/kernel_graph.cc
@@ -432,11 +432,8 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
     MS_EXCEPTION_IF_NULL(data_format_ptr);
     int64_t data_format;
     bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format);
-    if (!result) {
-      auto attr = GetValue<std::string>(data_format_ptr);
-      if (attr == kOpFormat_NCDHW) {
-        ResetInFormat(cnode, kOpFormat_NCDHW);
-      }
+    if (result && data_format == Format::NCDHW) {
+      ResetInFormat(cnode, kOpFormat_NCDHW);
     }
   }
   AnfAlgo::SetGraphId(graph_id_, cnode.get());
diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc
index 42294df5ab..fb0d336ed7 100644
--- a/mindspore/core/abstract/prim_nn.cc
+++ b/mindspore/core/abstract/prim_nn.cc
@@ -26,8 +26,8 @@ namespace abstract {
 int64_t GetAndCheckFormat(const ValuePtr &value) {
   int64_t data_format;
   bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
-  if (!result || (data_format != Format::NHWC && data_format != Format::NCHW)) {
-    MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW and NHWC";
+  if (!result || (data_format != Format::NHWC && data_format != Format::NCHW && data_format != Format::NCDHW)) {
+    MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW, NHWC and NCDHW";
   }
   return data_format;
 }
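
GetAndCheckFormat is the single gate through which frontend shape inference
accepts a layout string, so widening it here is what makes NCDHW usable by the
primitives touched below. As a minimal Python sketch of the rule it now
enforces, assuming the enum values from check_convert_utils.h; the helper name
and table are illustrative, not MindSpore API:

    # Illustrative mirror of GetAndCheckFormat (prim_nn.cc); not MindSpore API.
    # Enum values follow the Format enum in check_convert_utils.h (NCDHW = 15 is new).
    FORMAT_TO_ENUM = {"NCHW": 0, "NHWC": 1, "NCDHW": 15}

    def get_and_check_format(value):
        """Map a format string to its enum value, rejecting unsupported layouts."""
        if value not in FORMAT_TO_ENUM:
            raise ValueError("data format is invalid, only support NCHW, NHWC and NCDHW")
        return FORMAT_TO_ENUM[value]
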
Format::NCHW}, {"NHWC", Format::NHWC}, {"NHWC4", Format::NHWC4}, - {"HWKC", Format::HWKC}, {"HWCK", Format::HWCK}, {"KCHW", Format::KCHW}, - {"CKHW", Format::CKHW}, {"KHWC", Format::KHWC}, {"CHWK", Format::CHWK}, - {"HW", Format::HW}, {"HW4", Format::HW4}, {"NC", Format::NC}, - {"NC4", Format::NC4}, {"NC4HW4", Format::NC4HW4}, {"NUM_OF_FORMAT", Format::NUM_OF_FORMAT}, + {"NCHW", Format::NCHW}, {"NHWC", Format::NHWC}, {"NHWC4", Format::NHWC4}, + {"HWKC", Format::HWKC}, {"HWCK", Format::HWCK}, {"KCHW", Format::KCHW}, + {"CKHW", Format::CKHW}, {"KHWC", Format::KHWC}, {"CHWK", Format::CHWK}, + {"HW", Format::HW}, {"HW4", Format::HW4}, {"NC", Format::NC}, + {"NC4", Format::NC4}, {"NC4HW4", Format::NC4HW4}, {"NUM_OF_FORMAT", Format::NUM_OF_FORMAT}, + {"NCDHW", Format::NCDHW}, }; static std::map DataFormatToStrMap = { - {Format::NCHW, "NCHW"}, {Format::NHWC, "NHWC"}, {Format::NHWC4, "NHWC4"}, - {Format::HWKC, "HWKC"}, {Format::HWCK, "HWCK"}, {Format::KCHW, "KCHW"}, - {Format::CKHW, "CKHW"}, {Format::KHWC, "KHWC"}, {Format::CHWK, "CHWK"}, - {Format::HW, "HW"}, {Format::HW4, "HW4"}, {Format::NC, "NC"}, - {Format::NC4, "NC4"}, {Format::NC4HW4, "NC4HW4"}, {Format::NUM_OF_FORMAT, "NUM_OF_FORMAT"}, + {Format::NCHW, "NCHW"}, {Format::NHWC, "NHWC"}, {Format::NHWC4, "NHWC4"}, + {Format::HWKC, "HWKC"}, {Format::HWCK, "HWCK"}, {Format::KCHW, "KCHW"}, + {Format::CKHW, "CKHW"}, {Format::KHWC, "KHWC"}, {Format::CHWK, "CHWK"}, + {Format::HW, "HW"}, {Format::HW4, "HW4"}, {Format::NC, "NC"}, + {Format::NC4, "NC4"}, {Format::NC4HW4, "NC4HW4"}, {Format::NUM_OF_FORMAT, "NUM_OF_FORMAT"}, + {Format::NCDHW, "NCDHW"}, }; static std::map ReductionToEnumMap = { diff --git a/mindspore/core/utils/check_convert_utils.h b/mindspore/core/utils/check_convert_utils.h index fc17e2aaa0..eeb3ab136e 100644 --- a/mindspore/core/utils/check_convert_utils.h +++ b/mindspore/core/utils/check_convert_utils.h @@ -59,7 +59,8 @@ enum Format : int64_t { NC = 11, NC4 = 12, NC4HW4 = 13, - NUM_OF_FORMAT = 14 + NUM_OF_FORMAT = 14, + NCDHW = 15 }; enum ActivationType : int64_t { NO_ACTIVATION = 0, diff --git a/mindspore/ops/_op_impl/tbe/adam_apply_one.py b/mindspore/ops/_op_impl/tbe/adam_apply_one.py index edd12bf556..860da9820c 100644 --- a/mindspore/ops/_op_impl/tbe/adam_apply_one.py +++ b/mindspore/ops/_op_impl/tbe/adam_apply_one.py @@ -36,14 +36,11 @@ adam_apply_one_op_info = TBERegOp("AdamApplyOne") \ .output(0, "output0", False, "required", "all") \ .output(1, "output1", False, "required", "all") \ .output(2, "output2", False, "required", "all") \ - .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, - DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, - DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, - DataType.F16_Default) \ - .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, - DataType.F32_Default) \ + .op_pattern("dynamicFormat") \ + .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None, + DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None, + DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None, + DataType.None_None) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/tbe/bias_add_grad.py 
diff --git a/mindspore/ops/_op_impl/tbe/bias_add_grad.py b/mindspore/ops/_op_impl/tbe/bias_add_grad.py
index 482ed706ff..f0a5c74f95 100644
--- a/mindspore/ops/_op_impl/tbe/bias_add_grad.py
+++ b/mindspore/ops/_op_impl/tbe/bias_add_grad.py
@@ -30,6 +30,18 @@ bias_add_grad_op_info = TBERegOp("BiasAddGrad") \
     .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_FracNZ, DataType.F32_NHWC) \
+    .dtype_format(DataType.F16_Default, DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_Default, DataType.F32_NHWC) \
+    .dtype_format(DataType.F16_NDC1HWC0, DataType.F32_Default) \
+    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_Default) \
+    .dtype_format(DataType.F16_NDC1HWC0, DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NHWC) \
+    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F32_Default) \
+    .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_Default) \
+    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_NHWC) \
     .get_op_info()
 
 
diff --git a/mindspore/ops/_op_impl/tbe/prelu.py b/mindspore/ops/_op_impl/tbe/prelu.py
index 0733e29bee..47bec5ed38 100644
--- a/mindspore/ops/_op_impl/tbe/prelu.py
+++ b/mindspore/ops/_op_impl/tbe/prelu.py
@@ -26,10 +26,8 @@ prelu_op_info = TBERegOp("PReLU") \
     .input(0, "x", False, "required", "all") \
     .input(1, "weight", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
-    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .op_pattern("dynamicFormat") \
+    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
     .get_op_info()
 
 
diff --git a/mindspore/ops/_op_impl/tbe/prelu_grad.py b/mindspore/ops/_op_impl/tbe/prelu_grad.py
index f972428d8f..7c630043e5 100644
--- a/mindspore/ops/_op_impl/tbe/prelu_grad.py
+++ b/mindspore/ops/_op_impl/tbe/prelu_grad.py
@@ -27,11 +27,10 @@ prelu_grad_op_info = TBERegOp("PReLUGrad") \
     .input(1, "features", False, "required", "all") \
     .input(2, "weights", False, "required", "all") \
     .output(0, "dx", False, "required", "all") \
-    .output(0, "da", False, "required", "all") \
-    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
-                  DataType.F32_5HD, DataType.F32_5HD) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default) \
+    .output(1, "da", False, "required", "all") \
+    .op_pattern("dynamicFormat") \
+    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
+                  DataType.None_None) \
     .get_op_info()
 
 
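
PReLU and PReLUGrad take the same dynamicFormat migration, and the grad
registration also fixes a real bug: the second output "da" was registered at
index 0, duplicating "dx", and is now output 1. A quick smoke test of the
forward op, with illustrative shapes and values (a 0.25 slope scales the
negative entries and passes positives through):

    # Smoke test for the PReLU registration; values are illustrative.
    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    prelu = P.PReLU()
    x = Tensor(np.array([[-1.0, 2.0], [3.0, -4.0]], dtype=np.float32))
    w = Tensor(np.array([0.25, 0.25], dtype=np.float32))  # one slope per channel
    print(prelu(x, w))  # [[-0.25, 2.0], [3.0, -1.0]]
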
diff --git a/mindspore/ops/_op_impl/tbe/sgd.py b/mindspore/ops/_op_impl/tbe/sgd.py
index 64ecc9272e..40d1f93317 100644
--- a/mindspore/ops/_op_impl/tbe/sgd.py
+++ b/mindspore/ops/_op_impl/tbe/sgd.py
@@ -39,12 +39,20 @@ sgd_op_info = TBERegOp("SGD") \
                   DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
                   DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_Default, DataType.F16_NDC1HWC0,
+                  DataType.F32_Default, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \
+    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_FRACTAL_Z_3D, DataType.F32_Default, DataType.F16_FRACTAL_Z_3D,
+                  DataType.F32_Default, DataType.F16_FRACTAL_Z_3D, DataType.F16_FRACTAL_Z_3D) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
                   DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                   DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
                   DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_Default, DataType.F32_NDC1HWC0,
+                  DataType.F32_Default, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
+    .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_FRACTAL_Z_3D, DataType.F32_Default, DataType.F32_FRACTAL_Z_3D,
+                  DataType.F32_Default, DataType.F32_FRACTAL_Z_3D, DataType.F32_FRACTAL_Z_3D) \
     .get_op_info()
 
 
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 16f19c4278..8169d28423 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -228,13 +228,13 @@ class BiasAddGrad(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, data_format="NCHW"):
         self.init_prim_io_names(inputs=['dout'], outputs=['output'])
-        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
+        self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError("NHWC format only support in GPU target.")
         self.add_prim_attr('data_format', self.format)
 
     def infer_shape(self, d_output):
-        channel = d_output[1] if self.format == "NCHW" else d_output[-1]
+        channel = d_output[-1] if self.format == "NHWC" else d_output[1]
         return (channel,)
 
     def infer_dtype(self, dout_dtype):
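
The BiasAddGrad branch flip above is behavioral, not cosmetic: the bias
gradient has one element per channel, and now that NCDHW joins NCHW as a
channels-second layout, testing for NHWC (the only channels-last layout here)
is the rule that generalizes. Restated as a hypothetical helper, not part of
MindSpore:

    # The new infer_shape rule of BiasAddGrad, stated directly; illustrative only.
    def bias_grad_shape(d_output_shape, data_format):
        channel = d_output_shape[-1] if data_format == "NHWC" else d_output_shape[1]
        return (channel,)

    assert bias_grad_shape((8, 3, 32, 32), "NCHW") == (3,)
    assert bias_grad_shape((8, 32, 32, 3), "NHWC") == (3,)
    assert bias_grad_shape((8, 16, 4, 32, 32), "NCDHW") == (16,)
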
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 6190518ac5..226c8cdc56 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -2056,11 +2056,11 @@ class BiasAdd(PrimitiveWithCheck):
     except for the channel axis.
 
     Args:
-        data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW',
+        data_format (str): The format of input and output data. It should be 'NHWC', 'NCHW' or 'NCDHW',
             default is 'NCHW'.
 
     Inputs:
-        - **input_x** (Tensor) - The input tensor. The shape can be 2-4 dimensions.
+        - **input_x** (Tensor) - The input tensor. The shape can be 2-5 dimensions.
         - **bias** (Tensor) - The bias tensor, with shape :math:`(C)`. The shape of
           `bias` must be the same as `input_x`'s channel dimension.
@@ -2082,15 +2082,17 @@ class BiasAdd(PrimitiveWithCheck):
     @prim_attr_register
     def __init__(self, data_format="NCHW"):
         self.init_prim_io_names(inputs=['x', 'b'], outputs=['output'])
-        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
+        self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError("NHWC format only support in GPU target.")
         self.add_prim_attr('data_format', self.format)
 
     def check_shape(self, x_shape, b_shape):
         validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
+        if self.format == "NCDHW" and len(x_shape) != 5:
+            raise ValueError("NCDHW format only supports 5-D input.")
         validator.check_equal_int(len(b_shape), 1, "bias rank", self.name)
-        x_channel = x_shape[1] if self.format == "NCHW" else x_shape[-1]
+        x_channel = x_shape[-1] if self.format == "NHWC" else x_shape[1]
         if np.all(np.array(x_shape) != -1):
             validator.check("b_shape[0]", b_shape[0], "x_channel", x_channel, Rel.EQ, self.name)
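
With the validator, docstring, and shape check extended, a 5-D bias add can be
expressed directly. A minimal usage sketch, assuming a backend where the 3-D
TBE formats registered above are available; shapes are illustrative:

    # BiasAdd over an NCDHW tensor; bias length must match the channel dim C.
    import numpy as np
    from mindspore import Tensor
    from mindspore.ops import operations as P

    bias_add = P.BiasAdd(data_format="NCDHW")
    x = Tensor(np.zeros((2, 3, 4, 5, 6), dtype=np.float32))  # (N, C, D, H, W)
    b = Tensor(np.ones(3, dtype=np.float32))                 # one bias per channel
    print(bias_add(x, b).shape)  # (2, 3, 4, 5, 6)
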