Browse Source

add global batch normalization

tags/v0.2.0-alpha
zhaojichen 5 years ago
parent
commit
616b9ea394
2 changed files with 10 additions and 2 deletions
  1. +2
    -2
      mindspore/nn/layer/normalization.py
  2. +8
    -0
      tests/ut/python/hccl_test/manage/api.py

+ 2
- 2
mindspore/nn/layer/normalization.py View File

@@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore._checkparam import check_bool, check_typename
from mindspore._extends import cell_attr_register
from mindspore.communication.management import get_local_rank_size, get_rank
from mindspore.communication.management import get_group_size, get_rank
from mindspore.communication import management
from mindspore._checkparam import check_int_positive
from ..cell import Cell
@@ -71,7 +71,7 @@ class _BatchNorm(Cell):
self.group = check_int_positive(group)
if self.group != 1:
self.rank_id = get_rank()
self.rank_size = get_local_rank_size()
self.rank_size = get_group_size()
self.device_list = [i for i in range(0, self.rank_size)]
self.rank_list = self.list_group(self.device_list, self.group)
self.rank_list_idx = len(self.rank_list)


+ 8
- 0
tests/ut/python/hccl_test/manage/api.py View File

@@ -65,6 +65,14 @@ def get_rank_size(group=None):
return int(group.split("-")[0])
raise ValueError

def get_group_size(group=None):
hccl = Hccl()
if group is None:
return hccl.rank_size
if isinstance(group, str):
return int(group.split("-")[0])
raise ValueError

# pylint: disable=unused-argument
def get_world_rank_from_group_rank(group, group_rank_id):
return group_rank_id


Loading…
Cancel
Save