From: @liu_xiao_93
Tag: tags/v1.2.0-rc1
@@ -73,11 +73,8 @@ std::string InitDefaultFormat(const AnfNodePtr &node) {
     MS_EXCEPTION_IF_NULL(data_format_ptr);
     int64_t data_format;
     bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format);
-    if (!result) {
-      auto attr = GetValue<std::string>(data_format_ptr);
-      if (attr == kOpFormat_NCDHW) {
-        return kOpFormat_NCDHW;
-      }
+    if (result && data_format == Format::NCDHW) {
+      return kOpFormat_NCDHW;
     }
   } else if (AnfAlgo::IsRealKernel(node)) {
     auto formats = AnfAlgo::GetAllOutputFormats(node);
@@ -432,11 +432,8 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
     MS_EXCEPTION_IF_NULL(data_format_ptr);
     int64_t data_format;
     bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format);
-    if (!result) {
-      auto attr = GetValue<std::string>(data_format_ptr);
-      if (attr == kOpFormat_NCDHW) {
-        ResetInFormat(cnode, kOpFormat_NCDHW);
-      }
+    if (result && data_format == Format::NCDHW) {
+      ResetInFormat(cnode, kOpFormat_NCDHW);
     }
   }
   AnfAlgo::SetGraphId(graph_id_, cnode.get());
@@ -26,8 +26,8 @@ namespace abstract {
 int64_t GetAndCheckFormat(const ValuePtr &value) {
   int64_t data_format;
   bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
-  if (!result || (data_format != Format::NHWC && data_format != Format::NCHW)) {
-    MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW and NHWC";
+  if (!result || (data_format != Format::NHWC && data_format != Format::NCHW && data_format != Format::NCDHW)) {
+    MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW, NHWC and NCDHW";
   }
   return data_format;
 }
@@ -30,19 +30,21 @@
 namespace mindspore {
 static std::map<std::string, int64_t> DataFormatToEnumMap = {
-  {"NCHW", Format::NCHW}, {"NHWC", Format::NHWC}, {"NHWC4", Format::NHWC4},
-  {"HWKC", Format::HWKC}, {"HWCK", Format::HWCK}, {"KCHW", Format::KCHW},
-  {"CKHW", Format::CKHW}, {"KHWC", Format::KHWC}, {"CHWK", Format::CHWK},
-  {"HW", Format::HW}, {"HW4", Format::HW4}, {"NC", Format::NC},
-  {"NC4", Format::NC4}, {"NC4HW4", Format::NC4HW4}, {"NUM_OF_FORMAT", Format::NUM_OF_FORMAT},
+  {"NCHW", Format::NCHW}, {"NHWC", Format::NHWC}, {"NHWC4", Format::NHWC4},
+  {"HWKC", Format::HWKC}, {"HWCK", Format::HWCK}, {"KCHW", Format::KCHW},
+  {"CKHW", Format::CKHW}, {"KHWC", Format::KHWC}, {"CHWK", Format::CHWK},
+  {"HW", Format::HW}, {"HW4", Format::HW4}, {"NC", Format::NC},
+  {"NC4", Format::NC4}, {"NC4HW4", Format::NC4HW4}, {"NUM_OF_FORMAT", Format::NUM_OF_FORMAT},
+  {"NCDHW", Format::NCDHW},
 };
 static std::map<int64_t, std::string> DataFormatToStrMap = {
-  {Format::NCHW, "NCHW"}, {Format::NHWC, "NHWC"}, {Format::NHWC4, "NHWC4"},
-  {Format::HWKC, "HWKC"}, {Format::HWCK, "HWCK"}, {Format::KCHW, "KCHW"},
-  {Format::CKHW, "CKHW"}, {Format::KHWC, "KHWC"}, {Format::CHWK, "CHWK"},
-  {Format::HW, "HW"}, {Format::HW4, "HW4"}, {Format::NC, "NC"},
-  {Format::NC4, "NC4"}, {Format::NC4HW4, "NC4HW4"}, {Format::NUM_OF_FORMAT, "NUM_OF_FORMAT"},
+  {Format::NCHW, "NCHW"}, {Format::NHWC, "NHWC"}, {Format::NHWC4, "NHWC4"},
+  {Format::HWKC, "HWKC"}, {Format::HWCK, "HWCK"}, {Format::KCHW, "KCHW"},
+  {Format::CKHW, "CKHW"}, {Format::KHWC, "KHWC"}, {Format::CHWK, "CHWK"},
+  {Format::HW, "HW"}, {Format::HW4, "HW4"}, {Format::NC, "NC"},
+  {Format::NC4, "NC4"}, {Format::NC4HW4, "NC4HW4"}, {Format::NUM_OF_FORMAT, "NUM_OF_FORMAT"},
+  {Format::NCDHW, "NCDHW"},
 };
 static std::map<std::string, int64_t> ReductionToEnumMap = {
@@ -59,7 +59,8 @@ enum Format : int64_t {
   NC = 11,
   NC4 = 12,
   NC4HW4 = 13,
-  NUM_OF_FORMAT = 14
+  NUM_OF_FORMAT = 14,
+  NCDHW = 15
 };
 enum ActivationType : int64_t {
   NO_ACTIVATION = 0,
@@ -36,14 +36,11 @@ adam_apply_one_op_info = TBERegOp("AdamApplyOne") \
     .output(0, "output0", False, "required", "all") \
     .output(1, "output1", False, "required", "all") \
     .output(2, "output2", False, "required", "all") \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default) \
+    .op_pattern("dynamicFormat") \
+    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
+                  DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
+                  DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
+                  DataType.None_None) \
     .get_op_info()
@@ -30,6 +30,18 @@ bias_add_grad_op_info = TBERegOp("BiasAddGrad") \
     .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \
+    .dtype_format(DataType.F16_FracNZ, DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_FracNZ, DataType.F32_NHWC) \
+    .dtype_format(DataType.F16_Default, DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_Default, DataType.F32_NHWC) \
+    .dtype_format(DataType.F16_NDC1HWC0, DataType.F32_Default) \
+    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_Default) \
+    .dtype_format(DataType.F16_NDC1HWC0, DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NHWC) \
+    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F32_Default) \
+    .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_Default) \
+    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F32_NHWC) \
+    .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_NHWC) \
     .get_op_info()
@@ -26,10 +26,8 @@ prelu_op_info = TBERegOp("PReLU") \
     .input(0, "x", False, "required", "all") \
     .input(1, "weight", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
-    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .op_pattern("dynamicFormat") \
+    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
     .get_op_info()
@@ -27,11 +27,10 @@ prelu_grad_op_info = TBERegOp("PReLUGrad") \
     .input(1, "features", False, "required", "all") \
     .input(2, "weights", False, "required", "all") \
     .output(0, "dx", False, "required", "all") \
-    .output(0, "da", False, "required", "all") \
-    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
-                  DataType.F32_5HD, DataType.F32_5HD) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default) \
+    .output(1, "da", False, "required", "all") \
+    .op_pattern("dynamicFormat") \
+    .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None,
+                  DataType.None_None) \
     .get_op_info()
@@ -39,12 +39,20 @@ sgd_op_info = TBERegOp("SGD") \
                   DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ,
                   DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_Default, DataType.F16_NDC1HWC0,
+                  DataType.F32_Default, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \
+    .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_FRACTAL_Z_3D, DataType.F32_Default, DataType.F16_FRACTAL_Z_3D,
+                  DataType.F32_Default, DataType.F16_FRACTAL_Z_3D, DataType.F16_FRACTAL_Z_3D) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD,
                   DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                   DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ,
                   DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_Default, DataType.F32_NDC1HWC0,
+                  DataType.F32_Default, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \
+    .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_FRACTAL_Z_3D, DataType.F32_Default, DataType.F32_FRACTAL_Z_3D,
+                  DataType.F32_Default, DataType.F32_FRACTAL_Z_3D, DataType.F32_FRACTAL_Z_3D) \
     .get_op_info()
@@ -229,13 +229,13 @@ class BiasAddGrad(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self, data_format="NCHW"):
         self.init_prim_io_names(inputs=['dout'], outputs=['output'])
-        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
+        self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError("NHWC format only support in GPU target.")
         self.add_prim_attr('data_format', self.format)

     def infer_shape(self, d_output):
-        channel = d_output[1] if self.format == "NCHW" else d_output[-1]
+        channel = d_output[-1] if self.format == "NHWC" else d_output[1]
         return (channel,)

     def infer_dtype(self, dout_dtype):
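
Reviewer note on the infer_shape change above: flipping the test from == "NCHW" to == "NHWC" is what lets NCDHW work without a third branch, because NCHW and NCDHW both keep the channel at index 1 and NHWC is the only channels-last layout handled here. A minimal standalone sketch of that rule (the helper name channel_of is hypothetical, not part of this patch):

def channel_of(shape, data_format):
    # NHWC is the only channels-last layout here; NCHW and NCDHW both
    # keep the channel at index 1, so a single else-branch covers them.
    return shape[-1] if data_format == "NHWC" else shape[1]

assert channel_of([8, 3, 32, 32], "NCHW") == 3
assert channel_of([8, 32, 32, 3], "NHWC") == 3
assert channel_of([8, 3, 16, 32, 32], "NCDHW") == 3
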
@@ -2056,11 +2056,11 @@ class BiasAdd(PrimitiveWithCheck):
     except for the channel axis.

     Args:
-        data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW',
+        data_format (str): The format of input and output data. It should be 'NHWC', 'NCHW' or 'NCDHW',
             default is 'NCHW'.

     Inputs:
-        - **input_x** (Tensor) - The input tensor. The shape can be 2-4 dimensions.
+        - **input_x** (Tensor) - The input tensor. The shape can be 2-5 dimensions.
         - **bias** (Tensor) - The bias tensor, with shape :math:`(C)`. The shape of
           `bias` must be the same as `input_x`'s channel dimension.
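
For context, a hedged usage sketch of what the updated docstring describes: BiasAdd applied to a 5-D NCDHW input. The shapes are illustrative only and the example assumes the standard mindspore.ops.operations import path.

import numpy as np
import mindspore as ms
from mindspore.ops import operations as P

bias_add = P.BiasAdd(data_format="NCDHW")
x = ms.Tensor(np.ones((2, 3, 4, 5, 6)), ms.float32)  # N, C, D, H, W
b = ms.Tensor(np.ones((3,)), ms.float32)             # length matches the C axis
out = bias_add(x, b)                                 # same shape as x: (2, 3, 4, 5, 6)
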
@@ -2082,15 +2082,17 @@ class BiasAdd(PrimitiveWithCheck):
     @prim_attr_register
     def __init__(self, data_format="NCHW"):
         self.init_prim_io_names(inputs=['x', 'b'], outputs=['output'])
-        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
+        self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError("NHWC format only support in GPU target.")
         self.add_prim_attr('data_format', self.format)

     def check_shape(self, x_shape, b_shape):
         validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
+        if self.format == "NCDHW" and len(x_shape) != 5:
+            raise ValueError("NCDHW format only support 5-dims input.")
         validator.check_equal_int(len(b_shape), 1, "bias rank", self.name)
-        x_channel = x_shape[1] if self.format == "NCHW" else x_shape[-1]
+        x_channel = x_shape[-1] if self.format == "NHWC" else x_shape[1]
         if np.all(np.array(x_shape) != -1):
             validator.check("b_shape[0]", b_shape[0], "x_channel", x_channel, Rel.EQ, self.name)
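
To summarize the new check_shape behavior for review, here is a plain-Python mirror of the validation it performs (hypothetical helper, not part of the patch): NCDHW inputs must be 5-D, and the 1-D bias must match the channel dimension.

def check_bias_add_shapes(x_shape, b_shape, data_format="NCHW"):
    # Mirrors the validator calls above: rank >= 2, NCDHW implies 5-D,
    # bias is 1-D, and bias length equals the channel dimension when the
    # input shape is fully known (no -1 placeholders).
    if len(x_shape) < 2:
        raise ValueError("x rank should be at least 2.")
    if data_format == "NCDHW" and len(x_shape) != 5:
        raise ValueError("NCDHW format only support 5-dims input.")
    if len(b_shape) != 1:
        raise ValueError("bias should be a 1-D tensor.")
    x_channel = x_shape[-1] if data_format == "NHWC" else x_shape[1]
    if -1 not in x_shape and b_shape[0] != x_channel:
        raise ValueError(f"bias length {b_shape[0]} does not match channel {x_channel}.")

check_bias_add_shapes([2, 3, 4, 5, 6], [3], "NCDHW")  # passes silently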