|
|
|
@@ -69,16 +69,17 @@ class _BatchNorm(Cell): |
|
|
|
self.beta = Parameter(initializer( |
|
|
|
beta_init, num_features), name="beta", requires_grad=affine) |
|
|
|
self.group = check_int_positive(group) |
|
|
|
self.rank_id = get_rank() |
|
|
|
self.rank_size = get_local_rank_size() |
|
|
|
self.device_list = [i for i in range(0, self.rank_size)] |
|
|
|
self.rank_list = self.list_group(self.device_list, self.group) |
|
|
|
self.rank_list_idx = len(self.rank_list) |
|
|
|
for i in range(self.rank_list_idx): |
|
|
|
if self.rank_id in self.rank_list[i] and self.group != 1: |
|
|
|
self.is_global = True |
|
|
|
management.create_group('group' + str(i), self.rank_list[i]) |
|
|
|
self.all_reduce = _GlobalBNHelper('group' + str(i)) |
|
|
|
if self.group != 1: |
|
|
|
self.rank_id = get_rank() |
|
|
|
self.rank_size = get_local_rank_size() |
|
|
|
self.device_list = [i for i in range(0, self.rank_size)] |
|
|
|
self.rank_list = self.list_group(self.device_list, self.group) |
|
|
|
self.rank_list_idx = len(self.rank_list) |
|
|
|
for i in range(self.rank_list_idx): |
|
|
|
if self.rank_id in self.rank_list[i] and self.group != 1: |
|
|
|
self.is_global = True |
|
|
|
management.create_group('group' + str(i), self.rank_list[i]) |
|
|
|
self.all_reduce = _GlobalBNHelper('group' + str(i)) |
|
|
|
self.shape = P.Shape() |
|
|
|
self.reduce_mean = P.ReduceMean() |
|
|
|
self.square = P.Square() |
|
|
|
@@ -110,7 +111,8 @@ class _BatchNorm(Cell): |
|
|
|
|
|
|
|
def list_group(self, world_rank, group_size): |
|
|
|
if group_size > get_local_rank_size(): |
|
|
|
raise ValueError("group size can not be greater than local rank size, group size is {}, local_rank_size is {}".format(group_size, get_local_rank_size())) |
|
|
|
raise ValueError("group size can not be greater than local rank size, group size is {}, local_rank_size is {}".format( |
|
|
|
group_size, get_local_rank_size())) |
|
|
|
if len(world_rank) % group_size != 0: |
|
|
|
raise ValueError("please make your group size correct.") |
|
|
|
world_rank_list = zip(*(iter(world_rank),) *group_size) |
|
|
|
@@ -322,6 +324,9 @@ class GlobalBatchNorm(_BatchNorm): |
|
|
|
>>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) |
|
|
|
>>> global_bn_op(input) |
|
|
|
""" |
|
|
|
def _check_data_dim(self, x): |
|
|
|
if x.dim == 0: |
|
|
|
pass |
|
|
|
|
|
|
|
class LayerNorm(Cell): |
|
|
|
r""" |
|
|
|
|