|
|
|
@@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype |
|
|
|
import mindspore.context as context |
|
|
|
from mindspore._checkparam import check_bool, check_typename |
|
|
|
from mindspore._extends import cell_attr_register |
|
|
|
from mindspore.communication.management import get_local_rank_size, get_rank |
|
|
|
from mindspore.communication.management import get_group_size, get_rank |
|
|
|
from mindspore.communication import management |
|
|
|
from mindspore._checkparam import check_int_positive |
|
|
|
from ..cell import Cell |
|
|
|
@@ -71,7 +71,7 @@ class _BatchNorm(Cell): |
|
|
|
self.group = check_int_positive(group) |
|
|
|
if self.group != 1: |
|
|
|
self.rank_id = get_rank() |
|
|
|
self.rank_size = get_local_rank_size() |
|
|
|
self.rank_size = get_group_size() |
|
|
|
self.device_list = [i for i in range(0, self.rank_size)] |
|
|
|
self.rank_list = self.list_group(self.device_list, self.group) |
|
|
|
self.rank_list_idx = len(self.rank_list) |
|
|
|
|