|
|
|
@@ -21,7 +21,7 @@ from ..._checkparam import Rel |
|
|
|
from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group |
|
|
|
from ...common import dtype as mstype |
|
|
|
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register |
|
|
|
|
|
|
|
from ...common.api import context |
|
|
|
|
|
|
|
class ReduceOp: |
|
|
|
""" |
|
|
|
@@ -45,6 +45,12 @@ class ReduceOp: |
|
|
|
|
|
|
|
target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32) |
|
|
|
|
|
|
|
def check_hcom_group_valid(group):
    """Reject non-world communication groups in PyNative mode on Ascend.

    The three conditions are evaluated in the same order as a chained
    ``and`` would evaluate them, stopping at the first one that fails.

    Args:
        group: The communication group to compare against
            ``GlobalComm.WORLD_COMM_GROUP``.

    Raises:
        RuntimeError: If the context mode is ``context.PYNATIVE_MODE``, the
            device target is ``"Ascend"`` and `group` differs from
            ``GlobalComm.WORLD_COMM_GROUP``.
    """
    # Only PyNative mode is restricted; any other mode passes unchecked.
    in_pynative_mode = context.get_context("mode") == context.PYNATIVE_MODE
    if not in_pynative_mode:
        return

    # The restriction applies to the Ascend device target only.
    on_ascend = context.get_context("device_target") == "Ascend"
    if not on_ascend:
        return

    if group != GlobalComm.WORLD_COMM_GROUP:
        raise RuntimeError("Only hccl_world_group is supported in Pynative mode, but got {}".format(group))
|
|
|
|
|
|
|
class AllReduce(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
Reduces the tensor data across all devices in such a way that all devices will get the same final result. |
|
|
|
@@ -103,6 +109,7 @@ class AllReduce(PrimitiveWithInfer): |
|
|
|
raise TypeError("The operation of AllReduce should be str.") |
|
|
|
if not isinstance(_get_group(group), str): |
|
|
|
raise TypeError("The group of AllReduce should be str.") |
|
|
|
check_hcom_group_valid(group) |
|
|
|
self.op = op |
|
|
|
self.add_prim_attr('group', _get_group(group)) |
|
|
|
self.add_prim_attr('fusion', 0) |
|
|
|
@@ -420,6 +427,7 @@ class Broadcast(PrimitiveWithInfer): |
|
|
|
def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
    """Validate the constructor arguments and record the 'group' attribute.

    Args:
        root_rank: Value that must pass the ``int`` type check.
        group: Communication group. The value returned by
            ``_get_group(group)`` must pass the ``str`` type check.
            Defaults to ``GlobalComm.WORLD_COMM_GROUP``.

    Raises:
        RuntimeError: Propagated from ``check_hcom_group_valid`` when the
            group is not allowed in the current context.
    """
    # Type-check both arguments before any primitive attribute is written.
    validator.check_value_type('root_rank', root_rank, (int,), self.name)
    group_for_type_check = _get_group(group)
    validator.check_value_type('group', group_for_type_check, (str,), self.name)

    # The raw (unresolved) group is what gets checked against the context.
    check_hcom_group_valid(group)

    # Resolve the group a second time for the stored attribute, keeping the
    # same number of `_get_group` calls as before.
    group_attr = _get_group(group)
    self.add_prim_attr('group', group_attr)
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
|