|
|
|
@@ -156,6 +156,7 @@ class AllGather(PrimitiveWithInfer): |
|
|
|
self.add_prim_attr('group', _get_group(group)) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
validator.check_integer("x shape", len(x_shape), 0, Rel.GT, self.name) |
|
|
|
x_shape[0] = x_shape[0] * self.rank_size |
|
|
|
return x_shape |
|
|
|
|
|
|
|
|