From: @zhupuxu Reviewed-by: @zhunaipan,@kingxian Signed-off-by: @kingxiantags/v1.2.0-rc1
| @@ -2171,7 +2171,6 @@ class Concat(PrimitiveWithInfer): | |||||
| x_shp = input_x['shape'] | x_shp = input_x['shape'] | ||||
| x_type = input_x['dtype'] | x_type = input_x['dtype'] | ||||
| _, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name) | _, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name) | ||||
| self.add_prim_attr('T', x_type[0].element_type()) | |||||
| self.add_prim_attr('inputNums', len(x_shp)) | self.add_prim_attr('inputNums', len(x_shp)) | ||||
| ret_shp = x_shp[0].copy() | ret_shp = x_shp[0].copy() | ||||
| value = None | value = None | ||||
| @@ -2662,7 +2661,6 @@ class Select(PrimitiveWithInfer): | |||||
| return x_shape | return x_shape | ||||
| def infer_dtype(self, cond_type, x_type, y_type): | def infer_dtype(self, cond_type, x_type, y_type): | ||||
| self.add_prim_attr('T', x_type) | |||||
| validator.check_subclass("x_type", x_type, mstype.tensor, self.name) | validator.check_subclass("x_type", x_type, mstype.tensor, self.name) | ||||
| validator.check_subclass("y_type", y_type, mstype.tensor, self.name) | validator.check_subclass("y_type", y_type, mstype.tensor, self.name) | ||||
| validator.check_tensor_dtype_valid("cond", cond_type, [mstype.bool_], self.name) | validator.check_tensor_dtype_valid("cond", cond_type, [mstype.bool_], self.name) | ||||
| @@ -316,7 +316,6 @@ class _Reduce(PrimitiveWithInfer): | |||||
| """Initialize Reduce""" | """Initialize Reduce""" | ||||
| validator.check_value_type('keep_dims', keep_dims, [bool], self.name) | validator.check_value_type('keep_dims', keep_dims, [bool], self.name) | ||||
| self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) | self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) | ||||
| self.add_prim_attr("io_format", "ND") | |||||
| def __call__(self, x, axis=()): | def __call__(self, x, axis=()): | ||||
| args = [x, axis] | args = [x, axis] | ||||
| @@ -756,7 +755,6 @@ class MatMul(PrimitiveWithCheck): | |||||
| cls_name = self.name | cls_name = self.name | ||||
| validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) | validator.check_value_type("transpose_a", transpose_a, [bool], cls_name) | ||||
| validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) | validator.check_value_type("transpose_b", transpose_b, [bool], cls_name) | ||||
| self.add_prim_attr("io_format", "ND") | |||||
| def check_shape_size(self, x1, x2): | def check_shape_size(self, x1, x2): | ||||
| if len(x1) != 2 or len(x2) != 2: | if len(x1) != 2 or len(x2) != 2: | ||||