From: @liu_xiao_93 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -73,11 +73,8 @@ std::string InitDefaultFormat(const AnfNodePtr &node) { | |||||
| MS_EXCEPTION_IF_NULL(data_format_ptr); | MS_EXCEPTION_IF_NULL(data_format_ptr); | ||||
| int64_t data_format; | int64_t data_format; | ||||
| bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format); | bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format); | ||||
| if (!result) { | |||||
| auto attr = GetValue<std::string>(data_format_ptr); | |||||
| if (attr == kOpFormat_NCDHW) { | |||||
| return kOpFormat_NCDHW; | |||||
| } | |||||
| if (result && data_format == Format::NCDHW) { | |||||
| return kOpFormat_NCDHW; | |||||
| } | } | ||||
| } else if (AnfAlgo::IsRealKernel(node)) { | } else if (AnfAlgo::IsRealKernel(node)) { | ||||
| auto formats = AnfAlgo::GetAllOutputFormats(node); | auto formats = AnfAlgo::GetAllOutputFormats(node); | ||||
| @@ -432,11 +432,8 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) { | |||||
| MS_EXCEPTION_IF_NULL(data_format_ptr); | MS_EXCEPTION_IF_NULL(data_format_ptr); | ||||
| int64_t data_format; | int64_t data_format; | ||||
| bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format); | bool result = CheckAndConvertUtils::GetDataFormatEnumValue(data_format_ptr, &data_format); | ||||
| if (!result) { | |||||
| auto attr = GetValue<std::string>(data_format_ptr); | |||||
| if (attr == kOpFormat_NCDHW) { | |||||
| ResetInFormat(cnode, kOpFormat_NCDHW); | |||||
| } | |||||
| if (result && data_format == Format::NCDHW) { | |||||
| ResetInFormat(cnode, kOpFormat_NCDHW); | |||||
| } | } | ||||
| } | } | ||||
| AnfAlgo::SetGraphId(graph_id_, cnode.get()); | AnfAlgo::SetGraphId(graph_id_, cnode.get()); | ||||
| @@ -26,8 +26,8 @@ namespace abstract { | |||||
| int64_t GetAndCheckFormat(const ValuePtr &value) { | int64_t GetAndCheckFormat(const ValuePtr &value) { | ||||
| int64_t data_format; | int64_t data_format; | ||||
| bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format); | bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format); | ||||
| if (!result || (data_format != Format::NHWC && data_format != Format::NCHW)) { | |||||
| MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW and NHWC"; | |||||
| if (!result || (data_format != Format::NHWC && data_format != Format::NCHW && data_format != Format::NCDHW)) { | |||||
| MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW, NHWC and NCDHW"; | |||||
| } | } | ||||
| return data_format; | return data_format; | ||||
| } | } | ||||
| @@ -30,19 +30,21 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| static std::map<std::string, int64_t> DataFormatToEnumMap = { | static std::map<std::string, int64_t> DataFormatToEnumMap = { | ||||
| {"NCHW", Format::NCHW}, {"NHWC", Format::NHWC}, {"NHWC4", Format::NHWC4}, | |||||
| {"HWKC", Format::HWKC}, {"HWCK", Format::HWCK}, {"KCHW", Format::KCHW}, | |||||
| {"CKHW", Format::CKHW}, {"KHWC", Format::KHWC}, {"CHWK", Format::CHWK}, | |||||
| {"HW", Format::HW}, {"HW4", Format::HW4}, {"NC", Format::NC}, | |||||
| {"NC4", Format::NC4}, {"NC4HW4", Format::NC4HW4}, {"NUM_OF_FORMAT", Format::NUM_OF_FORMAT}, | |||||
| {"NCHW", Format::NCHW}, {"NHWC", Format::NHWC}, {"NHWC4", Format::NHWC4}, | |||||
| {"HWKC", Format::HWKC}, {"HWCK", Format::HWCK}, {"KCHW", Format::KCHW}, | |||||
| {"CKHW", Format::CKHW}, {"KHWC", Format::KHWC}, {"CHWK", Format::CHWK}, | |||||
| {"HW", Format::HW}, {"HW4", Format::HW4}, {"NC", Format::NC}, | |||||
| {"NC4", Format::NC4}, {"NC4HW4", Format::NC4HW4}, {"NUM_OF_FORMAT", Format::NUM_OF_FORMAT}, | |||||
| {"NCDHW", Format::NCDHW}, | |||||
| }; | }; | ||||
| static std::map<int64_t, std::string> DataFormatToStrMap = { | static std::map<int64_t, std::string> DataFormatToStrMap = { | ||||
| {Format::NCHW, "NCHW"}, {Format::NHWC, "NHWC"}, {Format::NHWC4, "NHWC4"}, | |||||
| {Format::HWKC, "HWKC"}, {Format::HWCK, "HWCK"}, {Format::KCHW, "KCHW"}, | |||||
| {Format::CKHW, "CKHW"}, {Format::KHWC, "KHWC"}, {Format::CHWK, "CHWK"}, | |||||
| {Format::HW, "HW"}, {Format::HW4, "HW4"}, {Format::NC, "NC"}, | |||||
| {Format::NC4, "NC4"}, {Format::NC4HW4, "NC4HW4"}, {Format::NUM_OF_FORMAT, "NUM_OF_FORMAT"}, | |||||
| {Format::NCHW, "NCHW"}, {Format::NHWC, "NHWC"}, {Format::NHWC4, "NHWC4"}, | |||||
| {Format::HWKC, "HWKC"}, {Format::HWCK, "HWCK"}, {Format::KCHW, "KCHW"}, | |||||
| {Format::CKHW, "CKHW"}, {Format::KHWC, "KHWC"}, {Format::CHWK, "CHWK"}, | |||||
| {Format::HW, "HW"}, {Format::HW4, "HW4"}, {Format::NC, "NC"}, | |||||
| {Format::NC4, "NC4"}, {Format::NC4HW4, "NC4HW4"}, {Format::NUM_OF_FORMAT, "NUM_OF_FORMAT"}, | |||||
| {Format::NCDHW, "NCDHW"}, | |||||
| }; | }; | ||||
| static std::map<std::string, int64_t> ReductionToEnumMap = { | static std::map<std::string, int64_t> ReductionToEnumMap = { | ||||
| @@ -59,7 +59,8 @@ enum Format : int64_t { | |||||
| NC = 11, | NC = 11, | ||||
| NC4 = 12, | NC4 = 12, | ||||
| NC4HW4 = 13, | NC4HW4 = 13, | ||||
| NUM_OF_FORMAT = 14 | |||||
| NUM_OF_FORMAT = 14, | |||||
| NCDHW = 15 | |||||
| }; | }; | ||||
| enum ActivationType : int64_t { | enum ActivationType : int64_t { | ||||
| NO_ACTIVATION = 0, | NO_ACTIVATION = 0, | ||||
| @@ -36,14 +36,11 @@ adam_apply_one_op_info = TBERegOp("AdamApplyOne") \ | |||||
| .output(0, "output0", False, "required", "all") \ | .output(0, "output0", False, "required", "all") \ | ||||
| .output(1, "output1", False, "required", "all") \ | .output(1, "output1", False, "required", "all") \ | ||||
| .output(2, "output2", False, "required", "all") \ | .output(2, "output2", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, | |||||
| DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default) \ | |||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None, | |||||
| DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None, | |||||
| DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None, | |||||
| DataType.None_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -30,6 +30,18 @@ bias_add_grad_op_info = TBERegOp("BiasAddGrad") \ | |||||
| .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \ | .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | .dtype_format(DataType.F32_Default, DataType.F32_Default) \ | ||||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \ | .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \ | ||||
| .dtype_format(DataType.F16_FracNZ, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.F32_FracNZ, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.F16_NDC1HWC0, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F16_NDC1HWC0, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_Default) \ | |||||
| .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F32_NHWC) \ | |||||
| .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_NHWC) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -26,10 +26,8 @@ prelu_op_info = TBERegOp("PReLU") \ | |||||
| .input(0, "x", False, "required", "all") \ | .input(0, "x", False, "required", "all") \ | ||||
| .input(1, "weight", False, "required", "all") \ | .input(1, "weight", False, "required", "all") \ | ||||
| .output(0, "y", False, "required", "all") \ | .output(0, "y", False, "required", "all") \ | ||||
| .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ | |||||
| .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | |||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -27,11 +27,10 @@ prelu_grad_op_info = TBERegOp("PReLUGrad") \ | |||||
| .input(1, "features", False, "required", "all") \ | .input(1, "features", False, "required", "all") \ | ||||
| .input(2, "weights", False, "required", "all") \ | .input(2, "weights", False, "required", "all") \ | ||||
| .output(0, "dx", False, "required", "all") \ | .output(0, "dx", False, "required", "all") \ | ||||
| .output(0, "da", False, "required", "all") \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, | |||||
| DataType.F32_5HD, DataType.F32_5HD) \ | |||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | |||||
| DataType.F32_Default, DataType.F32_Default) \ | |||||
| .output(1, "da", False, "required", "all") \ | |||||
| .op_pattern("dynamicFormat") \ | |||||
| .dtype_format(DataType.None_None, DataType.None_None, DataType.None_None, DataType.None_None, | |||||
| DataType.None_None) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -39,12 +39,20 @@ sgd_op_info = TBERegOp("SGD") \ | |||||
| DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ | ||||
| .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ, | .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_FracZ, | ||||
| DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \ | DataType.F16_Default, DataType.F16_FracZ, DataType.F16_FracZ) \ | ||||
| .dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0, DataType.F32_Default, DataType.F16_NDC1HWC0, | |||||
| DataType.F32_Default, DataType.F16_NDC1HWC0, DataType.F16_NDC1HWC0) \ | |||||
| .dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_FRACTAL_Z_3D, DataType.F32_Default, DataType.F16_FRACTAL_Z_3D, | |||||
| DataType.F32_Default, DataType.F16_FRACTAL_Z_3D, DataType.F16_FRACTAL_Z_3D) \ | |||||
| .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD, | .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_Default, DataType.F32_5HD, | ||||
| DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \ | DataType.F32_Default, DataType.F32_5HD, DataType.F32_5HD) \ | ||||
| .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, | ||||
| DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ | ||||
| .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ, | .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_Default, DataType.F32_FracZ, | ||||
| DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \ | DataType.F32_Default, DataType.F32_FracZ, DataType.F32_FracZ) \ | ||||
| .dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0, DataType.F32_Default, DataType.F32_NDC1HWC0, | |||||
| DataType.F32_Default, DataType.F32_NDC1HWC0, DataType.F32_NDC1HWC0) \ | |||||
| .dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_FRACTAL_Z_3D, DataType.F32_Default, DataType.F32_FRACTAL_Z_3D, | |||||
| DataType.F32_Default, DataType.F32_FRACTAL_Z_3D, DataType.F32_FRACTAL_Z_3D) \ | |||||
| .get_op_info() | .get_op_info() | ||||
| @@ -229,13 +229,13 @@ class BiasAddGrad(PrimitiveWithInfer): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, data_format="NCHW"): | def __init__(self, data_format="NCHW"): | ||||
| self.init_prim_io_names(inputs=['dout'], outputs=['output']) | self.init_prim_io_names(inputs=['dout'], outputs=['output']) | ||||
| self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) | |||||
| self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name) | |||||
| if context.get_context("device_target") != "GPU" and self.format == "NHWC": | if context.get_context("device_target") != "GPU" and self.format == "NHWC": | ||||
| raise ValueError("NHWC format only support in GPU target.") | raise ValueError("NHWC format only support in GPU target.") | ||||
| self.add_prim_attr('data_format', self.format) | self.add_prim_attr('data_format', self.format) | ||||
| def infer_shape(self, d_output): | def infer_shape(self, d_output): | ||||
| channel = d_output[1] if self.format == "NCHW" else d_output[-1] | |||||
| channel = d_output[-1] if self.format == "NHWC" else d_output[1] | |||||
| return (channel,) | return (channel,) | ||||
| def infer_dtype(self, dout_dtype): | def infer_dtype(self, dout_dtype): | ||||
| @@ -2056,11 +2056,11 @@ class BiasAdd(PrimitiveWithCheck): | |||||
| except for the channel axis. | except for the channel axis. | ||||
| Args: | Args: | ||||
| data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW', | |||||
| data_format (str): The format of input and output data. It should be 'NHWC', 'NCHW' or 'NCDHW', | |||||
| default is 'NCHW'. | default is 'NCHW'. | ||||
| Inputs: | Inputs: | ||||
| - **input_x** (Tensor) - The input tensor. The shape can be 2-4 dimensions. | |||||
| - **input_x** (Tensor) - The input tensor. The shape can be 2-5 dimensions. | |||||
| - **bias** (Tensor) - The bias tensor, with shape :math:`(C)`. The shape of | - **bias** (Tensor) - The bias tensor, with shape :math:`(C)`. The shape of | ||||
| `bias` must be the same as `input_x`'s channel dimension. | `bias` must be the same as `input_x`'s channel dimension. | ||||
| @@ -2082,15 +2082,17 @@ class BiasAdd(PrimitiveWithCheck): | |||||
| @prim_attr_register | @prim_attr_register | ||||
| def __init__(self, data_format="NCHW"): | def __init__(self, data_format="NCHW"): | ||||
| self.init_prim_io_names(inputs=['x', 'b'], outputs=['output']) | self.init_prim_io_names(inputs=['x', 'b'], outputs=['output']) | ||||
| self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name) | |||||
| self.format = validator.check_string(data_format, ['NCHW', 'NHWC', 'NCDHW'], 'format', self.name) | |||||
| if context.get_context("device_target") != "GPU" and self.format == "NHWC": | if context.get_context("device_target") != "GPU" and self.format == "NHWC": | ||||
| raise ValueError("NHWC format only support in GPU target.") | raise ValueError("NHWC format only support in GPU target.") | ||||
| self.add_prim_attr('data_format', self.format) | self.add_prim_attr('data_format', self.format) | ||||
| def check_shape(self, x_shape, b_shape): | def check_shape(self, x_shape, b_shape): | ||||
| validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name) | validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name) | ||||
| if self.format == "NCDHW" and len(x_shape) != 5: | |||||
| raise ValueError("NCDHW format only support 5-dims input.") | |||||
| validator.check_equal_int(len(b_shape), 1, "bias rank", self.name) | validator.check_equal_int(len(b_shape), 1, "bias rank", self.name) | ||||
| x_channel = x_shape[1] if self.format == "NCHW" else x_shape[-1] | |||||
| x_channel = x_shape[-1] if self.format == "NHWC" else x_shape[1] | |||||
| if np.all(np.array(x_shape) != -1): | if np.all(np.array(x_shape) != -1): | ||||
| validator.check("b_shape[0]", b_shape[0], "x_channel", x_channel, Rel.EQ, self.name) | validator.check("b_shape[0]", b_shape[0], "x_channel", x_channel, Rel.EQ, self.name) | ||||