|
|
|
@@ -281,9 +281,9 @@ class THOR_GPU(Optimizer): |
|
|
|
degree = _get_device_num() |
|
|
|
if self.conv_layer_count > 0: |
|
|
|
if not split_indices: |
|
|
|
- self.split_indices = split_indices |
|
|
|
- else: |
|
|
|
self.split_indices = [len(self.matrix_A) - 1] |
|
|
|
+ else: |
|
|
|
+ self.split_indices = split_indices |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2") |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4") |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6") |
|
|
|
@@ -294,9 +294,9 @@ class THOR_GPU(Optimizer): |
|
|
|
self.grad_reducer_G = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=8) |
|
|
|
else: |
|
|
|
if not split_indices: |
|
|
|
- self.split_indices = split_indices |
|
|
|
- else: |
|
|
|
self.split_indices = [len(self.params) - 1] |
|
|
|
+ else: |
|
|
|
+ self.split_indices = split_indices |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum3") |
|
|
|
self.grad_reducer_g = DistributedGradReducer(self.params, mean, degree, fusion_type=3) |
|
|
|
|
|
|
|
@@ -595,9 +595,9 @@ class THOR_Ascend(Optimizer): |
|
|
|
degree = _get_device_num() |
|
|
|
if self.conv_layer_count > 0: |
|
|
|
if not split_indices: |
|
|
|
- self.split_indices = split_indices |
|
|
|
- else: |
|
|
|
self.split_indices = [len(self.matrix_A) - 1] |
|
|
|
+ else: |
|
|
|
+ self.split_indices = split_indices |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2") |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4") |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6") |
|
|
|
@@ -608,9 +608,9 @@ class THOR_Ascend(Optimizer): |
|
|
|
self.grad_reducer_G = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=8) |
|
|
|
else: |
|
|
|
if not split_indices: |
|
|
|
- self.split_indices = split_indices |
|
|
|
- else: |
|
|
|
self.split_indices = [len(self.params) - 1] |
|
|
|
+ else: |
|
|
|
+ self.split_indices = split_indices |
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum3") |
|
|
|
self.grad_reducer_g = DistributedGradReducer(self.params, mean, degree, fusion_type=3) |
|
|
|
|
|
|
|
|