|
|
|
@@ -218,8 +218,6 @@ class HostAllGather(PrimitiveWithInfer): |
|
|
|
validator.check_int_range("rank_id", r, 0, 7, Rel.INC_BOTH, self.name) |
|
|
|
validator.check_value_type("rank_id", r, (int,), self.name) |
|
|
|
self.group_size = len(group) |
|
|
|
self.rank = get_rank() |
|
|
|
validator.check('rank', self.rank, 'group', self.group, Rel.IN, self.name) |
|
|
|
self.add_prim_attr('group', group) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
|