diff --git a/mindspore/ccsrc/transform/graph_ir/io_format_map.cc b/mindspore/ccsrc/transform/graph_ir/io_format_map.cc
index 8d265cc7e4..61520589d6 100644
--- a/mindspore/ccsrc/transform/graph_ir/io_format_map.cc
+++ b/mindspore/ccsrc/transform/graph_ir/io_format_map.cc
@@ -18,7 +18,22 @@
 
 namespace mindspore {
 namespace transform {
-std::unordered_map<std::string, std::string> IOFormatMap::io_format_map_ = {{"MatMul", "ND"}, {"Conv3D", "format"}};
+std::unordered_map<std::string, std::string> IOFormatMap::io_format_map_ = {{"BasicLSTMCell", "ND"},
+                                                                            {"BasicLSTMCellInputGrad", "ND"},
+                                                                            {"BasicLSTMCellCStateGrad", "ND"},
+                                                                            {"Dequant", "ND"},
+                                                                            {"DynamicGRUV2", "ND"},
+                                                                            {"DynamicGRUV2Grad", "ND"},
+                                                                            {"DynamicRNN", "ND"},
+                                                                            {"DynamicRNNGrad", "ND"},
+                                                                            {"MatMul", "ND"},
+                                                                            {"Quant", "ND"},
+                                                                            {"BasicLSTMCellWeightGrad", "HWCN"},
+                                                                            {"ExtractImagePatches", "NCHW"},
+                                                                            {"Conv3D", "format"},
+                                                                            {"Conv3DBackpropFilter", "format"},
+                                                                            {"Conv3DBackpropInput", "format"},
+                                                                            {"Conv3DTranspose", "format"}};
 std::unordered_map<std::string, std::string> &IOFormatMap::get() { return io_format_map_; }
 }  // namespace transform
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc
index 36df9eb6df..5ec2bfa980 100644
--- a/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc
+++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_util.cc
@@ -294,6 +294,9 @@ std::string GetOpIOFormat(const AnfNodePtr &anf) {
     MS_LOG(ERROR) << "The anf is not a Primitive.";
     return ret;
   }
+  if (prim->HasAttr("io_format")) {
+    return GetValue<std::string>(prim->GetAttr("io_format"));
+  }
   auto io_format_map = IOFormatMap::get();
   auto iter = io_format_map.find(prim->name());
   if (iter == io_format_map.end()) {
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index 0bc44bcce0..6b78671364 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -393,7 +393,6 @@ class Conv3DBackpropFilter(PrimitiveWithInfer):
         self.add_prim_attr('groups', self.group)
         self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
         self.add_prim_attr('data_format', self.format)
-        self.add_prim_attr('io_format', self.format)
 
     def __infer__(self, x, doutput, w_size):
         w_size_v = w_size['value']
@@ -1367,7 +1366,6 @@ class DynamicRNNGrad(PrimitiveWithInfer):
                  time_major=True,
                  forget_bias=0.0):
         self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
-        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape, c_shape,
                     dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape):
@@ -1478,7 +1476,6 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
         self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
         self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
         self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
-        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, x_shape, winput_shape, whidden_shape, y_shape, init_h_shape, h_shape,
                     dy_shape, dh_shape, update_shape, reset_shape, new_shape, hnew_shape, seq_shape, mask_shape):
@@ -2063,7 +2060,6 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
     def __init__(self, forget_bias, activation):
         self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
         self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
-        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
         # dhy and dcy should be same shape
@@ -2110,10 +2106,9 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
 
 class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
     """Computes the weight gradients of BasicLSTM."""
-
     @prim_attr_register
     def __init__(self):
-        self.add_prim_attr("io_format", "HWCN")
+        pass
 
     def infer_shape(self, x_shape, h_shape, dgate_shape):
         validator.check_equal_int(len(x_shape), 2, "x rank", self.name)
@@ -2145,7 +2140,6 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer):
     def __init__(self, keep_prob):
         self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
         self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
-        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, dgate_shape, w_shape):
         validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name)
diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py
index 07353a0e44..93da3906d8 100644
--- a/mindspore/ops/operations/_inner_ops.py
+++ b/mindspore/ops/operations/_inner_ops.py
@@ -74,7 +74,6 @@ class ExtractImagePatches(PrimitiveWithInfer):
         _check_tuple_or_list("rate", rates, self.name)
         self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
         self.add_prim_attr("padding", self.padding)
-        self.add_prim_attr("io_format", "NCHW")
         self.is_ge = context.get_context("enable_ge")
 
     def infer_shape(self, input_x):
@@ -213,7 +212,6 @@ class Quant(PrimitiveWithInfer):
         self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
         self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
                                                  "round_mode", self.name)
-        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, x_shape):
         return x_shape
@@ -265,7 +263,6 @@ class Dequant(PrimitiveWithInfer):
         self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
         self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
         self.add_prim_attr("dtype", mstype.float16)
-        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, x_shape, deq_scale_shape):
         return x_shape
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 13dc3a35a0..cd3ed5eb55 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -7222,7 +7222,6 @@ class BasicLSTMCell(PrimitiveWithInfer):
         self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
         self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
         self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
-        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
         validator.check_int(len(x_shape), 2, Rel.EQ, "x rank", self.name)
@@ -7373,7 +7372,6 @@ class DynamicRNN(PrimitiveWithInfer):
         self.cell_type = validator.check_string(cell_type, ['LSTM'], "cell_type", self.name)
         self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
         self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
-        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
         validator.check_int(len(x_shape), 3, Rel.EQ, "x_shape", self.name)
@@ -7533,7 +7531,6 @@ class DynamicGRUV2(PrimitiveWithInfer):
         self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
         self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
         self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
-        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape):
         validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
@@ -8024,7 +8021,6 @@ class Conv3DBackpropInput(PrimitiveWithInfer):
         self.add_prim_attr('groups', self.group)
         self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
         self.add_prim_attr('data_format', self.format)
-        self.add_prim_attr('io_format', self.format)
 
     def __infer__(self, w, doutput, x_size):
         validator.check_equal_int(len(w['shape']), 5, 'The dimension of weight ', self.name)
@@ -8202,7 +8198,6 @@ class Conv3DTranspose(PrimitiveWithInfer):
         self.add_prim_attr('groups', self.group)
         self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
         self.add_prim_attr('data_format', self.format)
-        self.add_prim_attr('io_format', self.format)
         self.output_padding = _check_3d_int_or_tuple('output_padding', output_padding, self.name,
                                                      allow_five=True, ret_five=True, greater_zero=False)
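
Note: taken together, the patch moves the per-op "io_format" registration out of the Python primitive constructors and into the C++ IOFormatMap, while GetOpIOFormat now gives an explicit "io_format" attribute on the primitive priority over the map lookup. A minimal Python sketch of the resulting behavior (illustrative usage written against the patched sources, not part of the patch itself):

    from mindspore.ops import operations as P

    # Ops such as DynamicRNN no longer register "io_format" in __init__;
    # the GE converter falls back to the IOFormatMap entry ("ND") instead.
    rnn = P.DynamicRNN()
    assert "io_format" not in rnn.attrs

    # GetOpIOFormat checks HasAttr("io_format") before consulting the map,
    # so an explicit per-primitive attribute still takes precedence.
    rnn.add_prim_attr("io_format", "ND")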