|
|
|
@@ -17,7 +17,7 @@ |
|
|
|
|
|
|
|
from ..._checkparam import Validator as validator |
|
|
|
from ..._checkparam import Rel |
|
|
|
from ...communication.management import get_rank, get_group_size, GlobalComm, get_group |
|
|
|
from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group |
|
|
|
from ...common import dtype as mstype |
|
|
|
from ..primitive import PrimitiveWithInfer, prim_attr_register |
|
|
|
|
|
|
|
@@ -88,10 +88,10 @@ class AllReduce(PrimitiveWithInfer): |
|
|
|
raise TypeError("The operation of AllReduce should be str.") |
|
|
|
if op == ReduceOp.PROD: |
|
|
|
raise RuntimeError("The operation of AllReduce 'prod' is not supported yet.") |
|
|
|
if not isinstance(get_group(group), str): |
|
|
|
if not isinstance(_get_group(group), str): |
|
|
|
raise TypeError("The group of AllReduce should be str.") |
|
|
|
self.op = op |
|
|
|
self.add_prim_attr('group', get_group(group)) |
|
|
|
self.add_prim_attr('group', _get_group(group)) |
|
|
|
self.add_prim_attr('fusion', 0) |
|
|
|
|
|
|
|
def vm_impl(self, x): |
|
|
|
@@ -149,12 +149,12 @@ class AllGather(PrimitiveWithInfer): |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, group=GlobalComm.WORLD_COMM_GROUP): |
|
|
|
validator.check_value_type('group', get_group(group), (str,), self.name) |
|
|
|
self.rank = get_rank(get_group(group)) |
|
|
|
self.rank_size = get_group_size(get_group(group)) |
|
|
|
validator.check_value_type('group', _get_group(group), (str,), self.name) |
|
|
|
self.rank = get_rank(_get_group(group)) |
|
|
|
self.rank_size = get_group_size(_get_group(group)) |
|
|
|
validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name) |
|
|
|
self.add_prim_attr('rank_size', self.rank_size) |
|
|
|
self.add_prim_attr('group', get_group(group)) |
|
|
|
self.add_prim_attr('group', _get_group(group)) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
x_shape[0] = x_shape[0] * self.rank_size |
|
|
|
@@ -205,11 +205,11 @@ class ReduceScatter(PrimitiveWithInfer): |
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP): |
|
|
|
validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name) |
|
|
|
validator.check_value_type('group', get_group(group), (str,), self.name) |
|
|
|
validator.check_value_type('group', _get_group(group), (str,), self.name) |
|
|
|
self.op = op |
|
|
|
self.rank_size = get_group_size(get_group(group)) |
|
|
|
self.rank_size = get_group_size(_get_group(group)) |
|
|
|
self.add_prim_attr('rank_size', self.rank_size) |
|
|
|
self.add_prim_attr('group', get_group(group)) |
|
|
|
self.add_prim_attr('group', _get_group(group)) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
if x_shape[0] % self.rank_size != 0: |
|
|
|
@@ -268,8 +268,8 @@ class Broadcast(PrimitiveWithInfer): |
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP): |
|
|
|
validator.check_value_type('root_rank', root_rank, (int,), self.name) |
|
|
|
validator.check_value_type('group', get_group(group), (str,), self.name) |
|
|
|
self.add_prim_attr('group', get_group(group)) |
|
|
|
validator.check_value_type('group', _get_group(group), (str,), self.name) |
|
|
|
self.add_prim_attr('group', _get_group(group)) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
return x_shape |
|
|
|
@@ -306,11 +306,11 @@ class _AlltoAll(PrimitiveWithInfer): |
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP): |
|
|
|
"""init AlltoAll""" |
|
|
|
validator.check_value_type('group', get_group(group), (str,), self.name) |
|
|
|
validator.check_value_type('group', _get_group(group), (str,), self.name) |
|
|
|
self.split_count = split_count |
|
|
|
self.split_dim = split_dim |
|
|
|
self.concat_dim = concat_dim |
|
|
|
self.add_prim_attr('group', get_group(group)) |
|
|
|
self.add_prim_attr('group', _get_group(group)) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
x_shape[self.concat_dim] = x_shape[self.concat_dim] * self.split_count |
|
|
|
|