From: @liangzhibo
Reviewed-by: @kingxian
Signed-off-by: @kingxian
pull/12757/MERGE
@@ -18,7 +18,22 @@
 namespace mindspore {
 namespace transform {
-std::unordered_map<std::string, std::string> IOFormatMap::io_format_map_ = {{"MatMul", "ND"}, {"Conv3D", "format"}};
+std::unordered_map<std::string, std::string> IOFormatMap::io_format_map_ = {{"BasicLSTMCell", "ND"},
+                                                                            {"BasicLSTMCellInputGrad", "ND"},
+                                                                            {"BasicLSTMCellCStateGrad", "ND"},
+                                                                            {"Dequant", "ND"},
+                                                                            {"DynamicGRUV2", "ND"},
+                                                                            {"DynamicGRUV2Grad", "ND"},
+                                                                            {"DynamicRNN", "ND"},
+                                                                            {"DynamicRNNGrad", "ND"},
+                                                                            {"MatMul", "ND"},
+                                                                            {"Quant", "ND"},
+                                                                            {"BasicLSTMCellWeightGrad", "HWCN"},
+                                                                            {"ExtractImagePatches", "NCHW"},
+                                                                            {"Conv3D", "format"},
+                                                                            {"Conv3DBackpropFilter", "format"},
+                                                                            {"Conv3DBackpropInput", "format"},
+                                                                            {"Conv3DTranspose", "format"}};
 std::unordered_map<std::string, std::string> &IOFormatMap::get() { return io_format_map_; }
 }  // namespace transform
 }  // namespace mindspore
@@ -294,6 +294,9 @@ std::string GetOpIOFormat(const AnfNodePtr &anf) {
     MS_LOG(ERROR) << "The anf is not a Primitive.";
     return ret;
   }
+  if (prim->HasAttr("io_format")) {
+    return GetValue<std::string>(prim->GetAttr("io_format"));
+  }
   auto io_format_map = IOFormatMap::get();
   auto iter = io_format_map.find(prim->name());
   if (iter == io_format_map.end()) {
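
A minimal sketch of the lookup order these two hunks establish, for review convenience. ResolveIOFormat is an illustrative name, not code from this patch; the default return value and the handling of the "format" sentinel are assumptions inferred from the map entries above, not confirmed by the diff.

// Sketch only: the resolution order implied by the hunks above.
std::string ResolveIOFormat(const PrimitivePtr &prim) {
  // 1. An explicit "io_format" attribute set on the primitive wins.
  if (prim->HasAttr("io_format")) {
    return GetValue<std::string>(prim->GetAttr("io_format"));
  }
  // 2. Otherwise fall back to the per-op entry in IOFormatMap.
  const auto &io_format_map = IOFormatMap::get();
  auto iter = io_format_map.find(prim->name());
  if (iter == io_format_map.end()) {
    return "NCHW";  // assumption: mirrors the surrounding function's default `ret`
  }
  // 3. Concrete entries ("ND", "HWCN", "NCHW") are returned as-is; the Conv3D
  //    family's "format" sentinel presumably tells the caller to read the
  //    op's own data_format attribute instead of naming a fixed layout.
  return iter->second;
}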
@@ -393,7 +393,6 @@ class Conv3DBackpropFilter(PrimitiveWithInfer):
         self.add_prim_attr('groups', self.group)
         self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
         self.add_prim_attr('data_format', self.format)
-        self.add_prim_attr('io_format', self.format)

     def __infer__(self, x, doutput, w_size):
         w_size_v = w_size['value']
@@ -1367,7 +1366,6 @@ class DynamicRNNGrad(PrimitiveWithInfer):
                  time_major=True,
                  forget_bias=0.0):
         self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
-        self.add_prim_attr("io_format", "ND")

     def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape,
                     c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape):
@@ -1478,7 +1476,6 @@ class DynamicGRUV2Grad(PrimitiveWithInfer):
         self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
         self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
         self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
-        self.add_prim_attr("io_format", "ND")

     def infer_shape(self, x_shape, winput_shape, whidden_shape, y_shape, init_h_shape, h_shape,
                     dy_shape, dh_shape, update_shape, reset_shape, new_shape, hnew_shape, seq_shape, mask_shape):
@@ -2063,7 +2060,6 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
     def __init__(self, forget_bias, activation):
         self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
         self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
-        self.add_prim_attr("io_format", "ND")

     def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
         # dhy and dcy should be same shape
@@ -2110,10 +2106,9 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
 class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
     """Computes the weight gradients of BasicLSTM."""

     @prim_attr_register
     def __init__(self):
-        self.add_prim_attr("io_format", "HWCN")
+        pass

     def infer_shape(self, x_shape, h_shape, dgate_shape):
         validator.check_equal_int(len(x_shape), 2, "x rank", self.name)
@@ -2145,7 +2140,6 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer):
     def __init__(self, keep_prob):
         self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
         self.keep_prob = validator.check_float_range(keep_prob, 0.0, 1.0, Rel.INC_BOTH, "keep_prob", self.name)
-        self.add_prim_attr("io_format", "ND")

     def infer_shape(self, dgate_shape, w_shape):
         validator.check_equal_int(len(dgate_shape), 2, "dgate rank", self.name)
@@ -74,7 +74,6 @@ class ExtractImagePatches(PrimitiveWithInfer):
         _check_tuple_or_list("rate", rates, self.name)
         self.padding = validator.check_string(padding.upper(), ['VALID', 'SAME'], 'padding', self.name)
         self.add_prim_attr("padding", self.padding)
-        self.add_prim_attr("io_format", "NCHW")
         self.is_ge = context.get_context("enable_ge")

     def infer_shape(self, input_x):
@@ -213,7 +212,6 @@ class Quant(PrimitiveWithInfer):
         self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
         self.round_mode = validator.check_string(round_mode, ["Round", "Floor", "Ceil", "Trunc"],
                                                  "round_mode", self.name)
-        self.add_prim_attr("io_format", "ND")

     def infer_shape(self, x_shape):
         return x_shape
@@ -265,7 +263,6 @@ class Dequant(PrimitiveWithInfer):
         self.sqrt_mode = validator.check_value_type("sqrt_mode", sqrt_mode, [bool], self.name)
         self.relu_flag = validator.check_value_type("relu_flag", relu_flag, [bool], self.name)
         self.add_prim_attr("dtype", mstype.float16)
-        self.add_prim_attr("io_format", "ND")

     def infer_shape(self, x_shape, deq_scale_shape):
         return x_shape
@@ -7222,7 +7222,6 @@ class BasicLSTMCell(PrimitiveWithInfer):
         self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
         self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
         self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
-        self.add_prim_attr("io_format", "ND")

     def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
         validator.check_int(len(x_shape), 2, Rel.EQ, "x rank", self.name)
@@ -7373,7 +7372,6 @@ class DynamicRNN(PrimitiveWithInfer):
         self.cell_type = validator.check_string(cell_type, ['LSTM'], "cell_type", self.name)
         self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
         self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
-        self.add_prim_attr("io_format", "ND")

     def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
         validator.check_int(len(x_shape), 3, Rel.EQ, "x_shape", self.name)
@@ -7533,7 +7531,6 @@ class DynamicGRUV2(PrimitiveWithInfer):
         self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
         self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
         self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
-        self.add_prim_attr("io_format", "ND")

     def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape):
         validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
@@ -8024,7 +8021,6 @@ class Conv3DBackpropInput(PrimitiveWithInfer):
         self.add_prim_attr('groups', self.group)
         self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
         self.add_prim_attr('data_format', self.format)
-        self.add_prim_attr('io_format', self.format)

     def __infer__(self, w, doutput, x_size):
         validator.check_equal_int(len(w['shape']), 5, 'The dimension of weight ', self.name)
@@ -8202,7 +8198,6 @@ class Conv3DTranspose(PrimitiveWithInfer):
         self.add_prim_attr('groups', self.group)
         self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
         self.add_prim_attr('data_format', self.format)
-        self.add_prim_attr('io_format', self.format)
         self.output_padding = _check_3d_int_or_tuple('output_padding', output_padding, self.name,
                                                      allow_five=True, ret_five=True, greater_zero=False)
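
One practical consequence of the attribute check added in GetOpIOFormat: a primitive can still pin its I/O format explicitly, since the attribute is consulted before IOFormatMap. A hypothetical illustration (not part of this diff), using the public add_prim_attr API:

import mindspore.ops as ops

# Hypothetical override: GetOpIOFormat() reads the "io_format" attribute
# first, so setting it takes precedence over the map entry ("ND" for MatMul).
matmul = ops.MatMul()
matmul.add_prim_attr("io_format", "ND")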