Browse Source

PyNative only support hccl_world_group

tags/v1.1.0
caifubi 5 years ago
parent
commit
ac061052a4
1 changed files with 9 additions and 1 deletions
  1. +9
    -1
      mindspore/ops/operations/comm_ops.py

+ 9
- 1
mindspore/ops/operations/comm_ops.py View File

@@ -21,7 +21,7 @@ from ..._checkparam import Rel
from ...communication.management import get_rank, get_group_size, GlobalComm, _get_group
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register
from ...common.api import context

class ReduceOp:
"""
@@ -45,6 +45,12 @@ class ReduceOp:

target_dtypes = (mstype.int8, mstype.int32, mstype.float16, mstype.float32)

def check_hcom_group_valid(group):
if context.get_context("mode") == context.PYNATIVE_MODE and \
context.get_context("device_target") == "Ascend" and \
group != GlobalComm.WORLD_COMM_GROUP:
raise RuntimeError("Only hccl_world_group is supported in Pynative mode, but got {}".format(group))

class AllReduce(PrimitiveWithInfer):
"""
Reduces the tensor data across all devices in such a way that all devices will get the same final result.
@@ -103,6 +109,7 @@ class AllReduce(PrimitiveWithInfer):
raise TypeError("The operation of AllReduce should be str.")
if not isinstance(_get_group(group), str):
raise TypeError("The group of AllReduce should be str.")
check_hcom_group_valid(group)
self.op = op
self.add_prim_attr('group', _get_group(group))
self.add_prim_attr('fusion', 0)
@@ -407,6 +414,7 @@ class Broadcast(PrimitiveWithInfer):
def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP):
validator.check_value_type('root_rank', root_rank, (int,), self.name)
validator.check_value_type('group', _get_group(group), (str,), self.name)
check_hcom_group_valid(group)
self.add_prim_attr('group', _get_group(group))

def infer_shape(self, x_shape):


Loading…
Cancel
Save