|
|
|
@@ -1316,7 +1316,7 @@ class Concat(PrimitiveWithInfer): |
|
|
|
axis = self.axis |
|
|
|
x_shp = input_x['shape'] |
|
|
|
x_type = input_x['dtype'] |
|
|
|
_, all_shp, _ = _get_concat_offset(x_shp, x_type, axis) |
|
|
|
_, all_shp, _ = _get_concat_offset(x_shp, x_type, axis, self.name) |
|
|
|
self.add_prim_attr('T', x_type[0].element_type()) |
|
|
|
self.add_prim_attr('inputNums', len(x_shp)) |
|
|
|
ret_shp = x_shp[0].copy() |
|
|
|
|