|
|
|
@@ -46,12 +46,13 @@ def _init_allreduce_operators(length, split_indices): |
|
|
|
return op_list |
|
|
|
|
|
|
|
|
|
|
|
def _init_allreduce_operators_by_parameters(parameters): |
|
|
|
def _init_allreduce_operators_by_parameters(parameters, split_indices): |
|
|
|
""" initialize allreduce communication operators by parameters""" |
|
|
|
op_list = () |
|
|
|
param_fusion = False |
|
|
|
last_comm_fusion = None |
|
|
|
first_parameter_flag = True |
|
|
|
index = 1 |
|
|
|
for parameter in parameters: |
|
|
|
comm_fusion = parameter.comm_fusion |
|
|
|
if first_parameter_flag: |
|
|
|
@@ -63,10 +64,15 @@ def _init_allreduce_operators_by_parameters(parameters): |
|
|
|
last_comm_fusion = comm_fusion |
|
|
|
op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP) |
|
|
|
op.add_prim_attr('fusion', comm_fusion) |
|
|
|
op.add_prim_attr('index', comm_fusion) |
|
|
|
op.add_prim_attr('index', index) |
|
|
|
index += 1 |
|
|
|
op_list = op_list + (op,) |
|
|
|
if not param_fusion: |
|
|
|
op_list = () |
|
|
|
if split_indices and split_indices[-1] == len(parameters) - 1: |
|
|
|
op_list = _init_allreduce_operators(len(parameters), split_indices) |
|
|
|
param_fusion = True |
|
|
|
else: |
|
|
|
op_list = () |
|
|
|
return op_list, param_fusion |
|
|
|
|
|
|
|
|
|
|
|
@@ -385,7 +391,7 @@ class DistributedGradReducer(Cell): |
|
|
|
self.op_list = _init_allreduce_operators(len(parameters), split_indices) |
|
|
|
else: |
|
|
|
self.split_fusion = True |
|
|
|
self.op_list, param_fusion = _init_allreduce_operators_by_parameters(parameters) |
|
|
|
self.op_list, param_fusion = _init_allreduce_operators_by_parameters(parameters, split_indices) |
|
|
|
if not param_fusion: |
|
|
|
self.split_fusion = False |
|
|
|
self.allreduce = AllReduce().add_prim_attr('fusion', fusion_type) |
|
|
|
|