|
|
|
@@ -81,7 +81,7 @@ class _BatchNorm(Cell): |
|
|
|
self.parallel_mode = context.get_auto_parallel_context("parallel_mode") |
|
|
|
global SYNC_BN_GROUP_NAME |
|
|
|
# for GlobalBatchNorm |
|
|
|
if self.group_device_num != 1 and self.parallel_mode != context.ParallelMode.STAND_ALONE: |
|
|
|
if self.group_device_num != 1: |
|
|
|
self.rank_id = get_rank() |
|
|
|
self.rank_size = get_group_size() |
|
|
|
self.device_list = [i for i in range(0, self.rank_size)] |
|
|
|
|