|
|
|
@@ -2167,7 +2167,6 @@ class Concat(PrimitiveWithInfer): |
|
|
|
x_shp = input_x['shape'] |
|
|
|
x_type = input_x['dtype'] |
|
|
|
_, all_shp, _ = get_concat_offset(x_shp, x_type, axis, self.name) |
|
|
|
self.add_prim_attr('T', x_type[0].element_type()) |
|
|
|
self.add_prim_attr('inputNums', len(x_shp)) |
|
|
|
ret_shp = x_shp[0].copy() |
|
|
|
value = None |
|
|
|
@@ -2616,7 +2615,6 @@ class Select(PrimitiveWithInfer): |
|
|
|
return x_shape |
|
|
|
|
|
|
|
def infer_dtype(self, cond_type, x_type, y_type): |
|
|
|
self.add_prim_attr('T', x_type) |
|
|
|
validator.check_subclass("x_type", x_type, mstype.tensor, self.name) |
|
|
|
validator.check_subclass("y_type", y_type, mstype.tensor, self.name) |
|
|
|
validator.check_tensor_dtype_valid("cond", cond_type, [mstype.bool_], self.name) |
|
|
|
|