
!12875 Adapt some ops for 3D format.

From: @liu_xiao_93
Reviewed-by:
Signed-off-by:
Tag: v1.2.0-rc1
Committer: mindspore-ci-bot
Commit: 7446b67a3b
12 changed files with 59 additions and 46 deletions:

  1. mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc (+2 -5)
  2. mindspore/ccsrc/backend/session/kernel_graph.cc (+2 -5)
  3. mindspore/core/abstract/prim_nn.cc (+2 -2)
  4. mindspore/core/utils/check_convert_utils.cc (+12 -10)
  5. mindspore/core/utils/check_convert_utils.h (+2 -1)
  6. mindspore/ops/_op_impl/tbe/adam_apply_one.py (+5 -8)
  7. mindspore/ops/_op_impl/tbe/bias_add_grad.py (+12 -0)
  8. mindspore/ops/_op_impl/tbe/prelu.py (+2 -4)
  9. mindspore/ops/_op_impl/tbe/prelu_grad.py (+4 -5)
 10. mindspore/ops/_op_impl/tbe/sgd.py (+8 -0)
 11. mindspore/ops/operations/_grad_ops.py (+2 -2)
 12. mindspore/ops/operations/nn_ops.py (+6 -4)

mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc (+2 -5)

@@ -73,11 +73,8 @@ std::string InitDefaultFormat(const AnfNodePtr &node) {
     MS_EXCEPTION_IF_NULL(data_format_ptr);
     int64_t data_format;
     bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format);
-    if (!result) {
-      auto attr = GetValue<std::string>(data_format_ptr);
-      if (attr == kOpFormat_NCDHW) {
-        return kOpFormat_NCDHW;
-      }
+    if (result && data_format == Format::NCDHW) {
+      return kOpFormat_NCDHW;
     }
   } else if (AnfAlgo::IsRealKernel(node)) {
     auto formats = AnfAlgo::GetAllOutputFormats(node);


mindspore/ccsrc/backend/session/kernel_graph.cc (+2 -5)

@@ -432,11 +432,8 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
     MS_EXCEPTION_IF_NULL(data_format_ptr);
     int64_t data_format;
     bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format);
-    if (!result) {
-      auto attr = GetValue<std::string>(data_format_ptr);
-      if (attr == kOpFormat_NCDHW) {
-        ResetInFormat(cnode, kOpFormat_NCDHW);
-      }
+    if (result && data_format == Format::NCDHW) {
+      ResetInFormat(cnode, kOpFormat_NCDHW);
     }
   }
   AnfAlgo::SetGraphId(graph_id_, cnode.get());
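Both C++ hunks above make the same simplification: once "NCDHW" has an entry in DataFormatToEnumMap (see check_convert_utils.cc below), GetDataFormatEnumValue succeeds for it, so the old string-comparison fallback under !result can never fire for NCDHW and is dropped. A minimal Python sketch of the assumed lookup semantics (map abridged, enum values taken from check_convert_utils.h):

    # Sketch only: mirrors the assumed behavior of GetDataFormatEnumValue,
    # which reports success iff the attribute string is in the format map.
    DATA_FORMAT_TO_ENUM = {"NCHW": 0, "NHWC": 1, "NCDHW": 15}  # abridged

    def get_data_format_enum_value(attr):
        """Return an (ok, enum) pair like the C++ out-parameter version."""
        if attr in DATA_FORMAT_TO_ENUM:
            return True, DATA_FORMAT_TO_ENUM[attr]
        return False, None

    ok, fmt = get_data_format_enum_value("NCDHW")
    assert ok and fmt == 15  # the string-compare fallback is no longer needed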


mindspore/core/abstract/prim_nn.cc (+2 -2)

@@ -26,8 +26,8 @@ namespace abstract {
 int64_t GetAndCheckFormat(const ValuePtr &value) {
   int64_t data_format;
   bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
-  if (!result || (data_format != Format::NHWC && data_format != Format::NCHW)) {
-    MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW and NHWC";
+  if (!result || (data_format != Format::NHWC && data_format != Format::NCHW && data_format != Format::NCDHW)) {
+    MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW, NHWC and NCDHW";
   }
   return data_format;
 }


mindspore/core/utils/check_convert_utils.cc (+12 -10)

@@ -30,19 +30,21 @@

 namespace mindspore {
 static std::map<std::string, int64_t> DataFormatToEnumMap = {
-  {"NCHW", Format::NCHW},   {"NHWC", Format::NHWC},     {"NHWC4", Format::NHWC4},
-  {"HWKC", Format::HWKC},   {"HWCK", Format::HWCK},     {"KCHW", Format::KCHW},
-  {"CKHW", Format::CKHW},   {"KHWC", Format::KHWC},     {"CHWK", Format::CHWK},
-  {"HW", Format::HW},       {"HW4", Format::HW4},       {"NC", Format::NC},
-  {"NC4", Format::NC4},     {"NC4HW4", Format::NC4HW4}, {"NUM_OF_FORMAT", Format::NUM_OF_FORMAT},
+  {"NCHW", Format::NCHW},   {"NHWC", Format::NHWC},     {"NHWC4", Format::NHWC4},
+  {"HWKC", Format::HWKC},   {"HWCK", Format::HWCK},     {"KCHW", Format::KCHW},
+  {"CKHW", Format::CKHW},   {"KHWC", Format::KHWC},     {"CHWK", Format::CHWK},
+  {"HW", Format::HW},       {"HW4", Format::HW4},       {"NC", Format::NC},
+  {"NC4", Format::NC4},     {"NC4HW4", Format::NC4HW4}, {"NUM_OF_FORMAT", Format::NUM_OF_FORMAT},
+  {"NCDHW", Format::NCDHW},
 };

 static std::map<int64_t, std::string> DataFormatToStrMap = {
-  {Format::NCHW, "NCHW"},   {Format::NHWC, "NHWC"},     {Format::NHWC4, "NHWC4"},
-  {Format::HWKC, "HWKC"},   {Format::HWCK, "HWCK"},     {Format::KCHW, "KCHW"},
-  {Format::CKHW, "CKHW"},   {Format::KHWC, "KHWC"},     {Format::CHWK, "CHWK"},
-  {Format::HW, "HW"},       {Format::HW4, "HW4"},       {Format::NC, "NC"},
-  {Format::NC4, "NC4"},     {Format::NC4HW4, "NC4HW4"}, {Format::NUM_OF_FORMAT, "NUM_OF_FORMAT"},
+  {Format::NCHW, "NCHW"},   {Format::NHWC, "NHWC"},     {Format::NHWC4, "NHWC4"},
+  {Format::HWKC, "HWKC"},   {Format::HWCK, "HWCK"},     {Format::KCHW, "KCHW"},
+  {Format::CKHW, "CKHW"},   {Format::KHWC, "KHWC"},     {Format::CHWK, "CHWK"},
+  {Format::HW, "HW"},       {Format::HW4, "HW4"},       {Format::NC, "NC"},
+  {Format::NC4, "NC4"},     {Format::NC4HW4, "NC4HW4"}, {Format::NUM_OF_FORMAT, "NUM_OF_FORMAT"},
+  {Format::NCDHW, "NCDHW"},
 };

 static std::map<std::string, int64_t> ReductionToEnumMap = {


mindspore/core/utils/check_convert_utils.h (+2 -1)

@@ -59,7 +59,8 @@ enum Format : int64_t {
   NC = 11,
   NC4 = 12,
   NC4HW4 = 13,
-  NUM_OF_FORMAT = 14
+  NUM_OF_FORMAT = 14,
+  NCDHW = 15
 };
 enum ActivationType : int64_t {
   NO_ACTIVATION = 0,
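The new enumerator is appended after NUM_OF_FORMAT, presumably so every existing Format value keeps its integer encoding; note this means NUM_OF_FORMAT no longer counts the formats. A quick round-trip check of the paired map entries, sketched in Python (maps abridged):

    # Abridged mirrors of DataFormatToEnumMap / DataFormatToStrMap above.
    data_format_to_enum = {"NUM_OF_FORMAT": 14, "NCDHW": 15}
    data_format_to_str = {14: "NUM_OF_FORMAT", 15: "NCDHW"}

    # String -> enum -> string round-trips for the new 3-D format.
    assert data_format_to_str[data_format_to_enum["NCDHW"]] == "NCDHW"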


mindspore/ops/_op_impl/tbe/adam_apply_one.py (+5 -8)

@@ -36,14 +36,11 @@ adam_apply_one_op_info = TBERegOp("AdamApplyOne") \
     .output(0, "output0", False, "required", "all") \
     .output(1, "output1", False, "required", "all") \
     .output(2, "output2", False, "required", "all") \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default) \
+    .op_pattern("dynamicFormat") \
+    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
+                  DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
+                  DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
+                  DataType.None_None) \
     .get_op_info()
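AdamApplyOne here, and PReLU and PReLUGrad below, all switch from hand-enumerated dtype/format rows to op_pattern("dynamicFormat") with a single all-None_None row, which defers the dtype and format choice to TBE's op selection instead of fixing it in the registration. A minimal sketch of the pattern for a hypothetical one-input, one-output op (a real registration also sets async_flag, binfile_name, compute_cost, kernel_name, and so on, omitted here):

    from mindspore.ops.op_info_register import DataType, TBERegOp

    # Hypothetical op: one None_None row under the dynamicFormat pattern
    # replaces one dtype_format row per supported dtype/format combination.
    example_op_info = TBERegOp("ExampleOp") \
        .input(0, "x", False, "required", "all") \
        .output(0, "y", False, "required", "all") \
        .op_pattern("dynamicFormat") \
        .dtype_format(DataType.None_None, DataType.None_None) \
        .get_op_info()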




mindspore/ops/_op_impl/tbe/bias_add_grad.py (+12 -0)

@@ -30,6 +30,18 @@ bias_add_grad_op_info = TBERegOp("BiasAddGrad") \
     .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_FracNZ, DataType.F32_NHWC) \
+    .dtype_format(DataType.F16_Default, DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_Default, DataType.F32_NHWC) \
+    .dtype_format(DataType.F16_NDC1HWC0, DataType.F32_Default) \
+    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_Default) \
+    .dtype_format(DataType.F16_NDC1HWC0, DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NHWC) \
+    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F32_Default) \
+    .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_Default) \
+    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_NHWC) \
     .get_op_info()
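The new rows register 3-D Ascend layouts: NDC1HWC0 and FRACTAL_Z_3D are the 3-D counterparts of NC1HWC0 (5HD) and FracZ. A rough sketch of the NDC1HWC0 shape mapping, assuming the usual C0 = 16 channel tiling (illustrative only, not the kernel's actual packing code):

    import math

    def ndc1hwc0_shape(n, c, d, h, w, c0=16):
        # NCDHW -> NDC1HWC0: pad C to a multiple of C0, split it into
        # (C1, C0) tiles, and hoist depth D ahead of C1.
        c1 = math.ceil(c / c0)
        return (n, d, c1, h, w, c0)

    assert ndc1hwc0_shape(2, 3, 4, 5, 6) == (2, 4, 1, 5, 6, 16)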




mindspore/ops/_op_impl/tbe/prelu.py (+2 -4)

@@ -26,10 +26,8 @@ prelu_op_info = TBERegOp("PReLU") \
     .input(0, "x", False, "required", "all") \
     .input(1, "weight", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
-    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .op_pattern("dynamicFormat") \
+    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
     .get_op_info()




mindspore/ops/_op_impl/tbe/prelu_grad.py (+4 -5)

@@ -27,11 +27,10 @@ prelu_grad_op_info = TBERegOp("PReLUGrad") \
     .input(1, "features", False, "required", "all") \
     .input(2, "weights", False, "required", "all") \
     .output(0, "dx", False, "required", "all") \
-    .output(0, "da", False, "required", "all") \
-    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
-                  DataType.F32_5HD, DataType.F32_5HD) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default) \
+    .output(1, "da", False, "required", "all") \
+    .op_pattern("dynamicFormat") \
+    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
+                  DataType.None_None) \
     .get_op_info()




mindspore/ops/_op_impl/tbe/sgd.py (+8 -0)

@@ -39,12 +39,20 @@ sgd_op_info = TBERegOp("SGD") \
                   DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
                   DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_Default, DataType.F16_NDC1HWC0,
+                  DataType.F32_Default, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \
+    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_FRACTAL_Z_3D, DataType.F32_Default, DataType.F16_FRACTAL_Z_3D,
+                  DataType.F32_Default, DataType.F16_FRACTAL_Z_3D, DataType.F16_FRACTAL_Z_3D) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
                   DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                   DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
                   DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_Default, DataType.F32_NDC1HWC0,
+                  DataType.F32_Default, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
+    .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_FRACTAL_Z_3D, DataType.F32_Default, DataType.F32_FRACTAL_Z_3D,
+                  DataType.F32_Default, DataType.F32_FRACTAL_Z_3D, DataType.F32_FRACTAL_Z_3D) \
     .get_op_info()




mindspore/ops/operations/_grad_ops.py (+2 -2)

@@ -229,13 +229,13 @@ class BiasAddGrad(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, data_format="NCHW"):
         self.init_prim_io_names(inputs=['dout'], outputs=['output'])
-        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
+        self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError("NHWC format only support in GPU target.")
         self.add_prim_attr('data_format', self.format)

     def infer_shape(self, d_output):
-        channel = d_output[1] if self.format == "NCHW" else d_output[-1]
+        channel = d_output[-1] if self.format == "NHWC" else d_output[1]
         return (channel,)

     def infer_dtype(self, dout_dtype):
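The rewritten conditional matters for the new format: under the old `d_output[1] if self.format == "NCHW" else d_output[-1]`, NCDHW would have fallen into the else branch and read the last axis (W) rather than the channel. A plain-Python check of the new rule:

    def channel_axis_size(shape, fmt):
        # Channel is last only for NHWC; NCHW and NCDHW both keep it at index 1.
        return shape[-1] if fmt == "NHWC" else shape[1]

    assert channel_axis_size((8, 3, 32, 32), "NCHW") == 3
    assert channel_axis_size((8, 32, 32, 3), "NHWC") == 3
    assert channel_axis_size((8, 3, 16, 32, 32), "NCDHW") == 3  # not shape[-1]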


mindspore/ops/operations/nn_ops.py (+6 -4)

@@ -2056,11 +2056,11 @@ class BiasAdd(PrimitiveWithCheck):
     except for the channel axis.

     Args:
-        data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW',
+        data_format (str): The format of input and output data. It should be 'NHWC', 'NCHW' or 'NCDHW',
             default is 'NCHW'.

     Inputs:
-        - **input_x** (Tensor) - The input tensor. The shape can be 2-4 dimensions.
+        - **input_x** (Tensor) - The input tensor. The shape can be 2-5 dimensions.
         - **bias** (Tensor) - The bias tensor, with shape :math:`(C)`. The shape of
           `bias` must be the same as `input_x`'s channel dimension.

@@ -2082,15 +2082,17 @@ class BiasAdd(PrimitiveWithCheck):
     @prim_attr_register
     def __init__(self, data_format="NCHW"):
         self.init_prim_io_names(inputs=['x', 'b'], outputs=['output'])
-        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
+        self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError("NHWC format only support in GPU target.")
         self.add_prim_attr('data_format', self.format)

     def check_shape(self, x_shape, b_shape):
         validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
+        if self.format == "NCDHW" and len(x_shape) != 5:
+            raise ValueError("NCDHW format only support 5-dims input.")
         validator.check_equal_int(len(b_shape), 1, "bias rank", self.name)
-        x_channel = x_shape[1] if self.format == "NCHW" else x_shape[-1]
+        x_channel = x_shape[-1] if self.format == "NHWC" else x_shape[1]
         if np.all(np.array(x_shape) != -1):
             validator.check("b_shape[0]", b_shape[0], "x_channel", x_channel, Rel.EQ, self.name)
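With the widened validator and the 5-D rank check, BiasAdd now accepts NCDHW input. A usage sketch (shapes chosen only for illustration):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor
    from mindspore.ops import operations as P

    # Illustrative only: BiasAdd over a 5-D NCDHW input after this change.
    bias_add = P.BiasAdd(data_format="NCDHW")
    x = Tensor(np.zeros((2, 3, 4, 5, 6)), ms.float32)  # (N, C, D, H, W)
    b = Tensor(np.ones(3), ms.float32)                 # one bias per channel C
    y = bias_add(x, b)                                 # shape stays (2, 3, 4, 5, 6)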


