From: @liangzhibo Reviewed-by: @kingxian Signed-off-by: @kingxianpull/12757/MERGE
| @@ -18,7 +18,22 @@ | |||
| namespace mindspore { | |||
| namespace transform { | |||
| std::unordered_map<std::string, std::string> IOFormatMap::io_format_map_ = {{"MatMul", "ND"}, {"Conv3D", "format"}}; | |||
| std::unordered_map<std::string, std::string> IOFormatMap::io_format_map_ = {{"BasicLSTMCell", "ND"}, | |||
| {"BasicLSTMCellInputGrad", "ND"}, | |||
| {"BasicLSTMCellCStateGrad", "ND"}, | |||
| {"Dequant", "ND"}, | |||
| {"DynamicGRUV2", "ND"}, | |||
| {"DynamicGRUV2Grad", "ND"}, | |||
| {"DynamicRNN", "ND"}, | |||
| {"DynamicRNNGrad", "ND"}, | |||
| {"MatMul", "ND"}, | |||
| {"Quant", "ND"}, | |||
| {"BasicLSTMCellWeightGrad", "HWCN"}, | |||
| {"ExtractImagePatches", "NCHW"}, | |||
| {"Conv3D", "format"}, | |||
| {"Conv3DBackpropFilter", "format"}, | |||
| {"Conv3DBackpropInput", "format"}, | |||
| {"Conv3DTranspose", "format"}}; | |||
| std::unordered_map<std::string, std::string> &IOFormatMap::get() { return io_format_map_; } | |||
| } // namespace transform | |||
| } // namespace mindspore | |||
| @@ -294,6 +294,9 @@ std::string GetOpIOFormat(const AnfNodePtr &anf) { | |||
| MS_LOG(ERROR) << "The anf is not a Primitive."; | |||
| return ret; | |||
| } | |||
| if (prim->HasAttr("io_format")) { | |||
| return GetValue<std::string>(prim->GetAttr("io_format")); | |||
| } | |||
| auto io_format_map = IOFormatMap::get(); | |||
| auto iter = io_format_map.find(prim->name()); | |||
| if (iter == io_format_map.end()) { | |||
| @@ -393,7 +393,6 @@ class Conv3DBackpropFilter(PrimitiveWithInfer): | |||
| self.add_prim_attr('groups', self.group) | |||
| self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name) | |||
| self.add_prim_attr('data_format', self.format) | |||
| self.add_prim_attr('io_format', self.format) | |||
| def __infer__(self, x, doutput, w_size): | |||
| w_size_v = w_size['value'] | |||
| @@ -1367,7 +1366,6 @@ class DynamicRNNGrad(PrimitiveWithInfer): | |||
| time_major=True, | |||
| forget_bias=0.0): | |||
| self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape, | |||
| c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape): | |||
| @@ -1478,7 +1476,6 @@ class DynamicGRUV2Grad(PrimitiveWithInfer): | |||
| self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name) | |||
| self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name) | |||
| self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, x_shape, winput_shape, whidden_shape, y_shape, init_h_shape, h_shape, | |||
| dy_shape, dh_shape, update_shape, reset_shape, new_shape, hnew_shape, seq_shape, mask_shape): | |||
| @@ -2063,7 +2060,6 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer): | |||
| def __init__(self, forget_bias, activation): | |||
| self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name) | |||
| self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape): | |||
| # dhy and dcy should be same shape | |||
| @@ -2110,10 +2106,9 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer): | |||
| class BasicLSTMCellWeightGrad(PrimitiveWithInfer): | |||
| """Computes the weight gradients of BasicLSTM.""" | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| self.add_prim_attr("io_format", "HWCN") | |||
| pass | |||
| def infer_shape(self, x_shape, h_shape, dgate_shape): | |||
| validator.check_equal_int(len(x_shape), 2, "x rank", self.name) | |||
| @@ -2145,7 +2140,6 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer): | |||
| def __init__(self, keep_prob): | |||
| self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name) | |||
| self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, dgate_shape, w_shape): | |||
| validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name) | |||
| @@ -74,7 +74,6 @@ class ExtractImagePatches(PrimitiveWithInfer): | |||
| _check_tuple_or_list("rate", rates, self.name) | |||
| self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name) | |||
| self.add_prim_attr("padding", self.padding) | |||
| self.add_prim_attr("io_format", "NCHW") | |||
| self.is_ge = context.get_context("enable_ge") | |||
| def infer_shape(self, input_x): | |||
| @@ -213,7 +212,6 @@ class Quant(PrimitiveWithInfer): | |||
| self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) | |||
| self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"], | |||
| "round_mode", self.name) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| @@ -265,7 +263,6 @@ class Dequant(PrimitiveWithInfer): | |||
| self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name) | |||
| self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name) | |||
| self.add_prim_attr("dtype", mstype.float16) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, x_shape, deq_scale_shape): | |||
| return x_shape | |||
| @@ -7222,7 +7222,6 @@ class BasicLSTMCell(PrimitiveWithInfer): | |||
| self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name) | |||
| self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name) | |||
| self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape): | |||
| validator.check_int(len(x_shape), 2, Rel.EQ, "x rank", self.name) | |||
| @@ -7373,7 +7372,6 @@ class DynamicRNN(PrimitiveWithInfer): | |||
| self.cell_type = validator.check_string(cell_type, ['LSTM'], "cell_type", self.name) | |||
| self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name) | |||
| self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape): | |||
| validator.check_int(len(x_shape), 3, Rel.EQ, "x_shape", self.name) | |||
| @@ -7533,7 +7531,6 @@ class DynamicGRUV2(PrimitiveWithInfer): | |||
| self.activation = validator.check_string(activation, ['tanh'], "activation", self.name) | |||
| self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name) | |||
| self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name) | |||
| self.add_prim_attr("io_format", "ND") | |||
| def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape): | |||
| validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name) | |||
| @@ -8024,7 +8021,6 @@ class Conv3DBackpropInput(PrimitiveWithInfer): | |||
| self.add_prim_attr('groups', self.group) | |||
| self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name) | |||
| self.add_prim_attr('data_format', self.format) | |||
| self.add_prim_attr('io_format', self.format) | |||
| def __infer__(self, w, doutput, x_size): | |||
| validator.check_equal_int(len(w['shape']), 5, 'The dimension of weight ', self.name) | |||
| @@ -8202,7 +8198,6 @@ class Conv3DTranspose(PrimitiveWithInfer): | |||
| self.add_prim_attr('groups', self.group) | |||
| self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name) | |||
| self.add_prim_attr('data_format', self.format) | |||
| self.add_prim_attr('io_format', self.format) | |||
| self.output_padding = _check_3d_int_or_tuple('output_padding', output_padding, self.name, | |||
| allow_five=True, ret_five=True, greater_zero=False) | |||