|
|
|
@@ -158,7 +158,10 @@ Status ReduceMethod::InferForwardCommunication() { |
|
|
|
size_t size = stra.size(); |
|
|
|
// judge if the reduce dim is partitioned. |
|
|
|
Shape group_creat_map; |
|
|
|
if (dev_matrix_shape_.size() > size) { |
|
|
|
|
|
|
|
// if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix, |
|
|
|
// it need to handle the first dimention of map. |
|
|
|
if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) { |
|
|
|
group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1))); |
|
|
|
} |
|
|
|
for (size_t index = 0; index < size; ++index) { |
|
|
|
@@ -169,6 +172,18 @@ Status ReduceMethod::InferForwardCommunication() { |
|
|
|
} |
|
|
|
group_creat_map.push_back(SizeToInt(size) - SizeToInt(index) - 1); |
|
|
|
} |
|
|
|
|
|
|
|
// if repeated calculation and the repeated_calc_num_ insert to the last dimension of dev matrix, |
|
|
|
// it need to handle the group_creat_map and insert the 0 to the last dimension of the group_creat_map. |
|
|
|
if (repeated_num_in_dev_matrix_right_ && (repeated_calc_num_ > 1)) { |
|
|
|
for (auto &ele : group_creat_map) { |
|
|
|
if (ele == MAP_NONE) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
ele += 1; |
|
|
|
} |
|
|
|
group_creat_map.push_back(0); |
|
|
|
} |
|
|
|
std::vector<Group> forward_group; |
|
|
|
if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed."; |
|
|
|
@@ -220,9 +235,13 @@ Status ReduceMeanInfo::InferForwardCommunication() { |
|
|
|
size_t size = stra.size(); |
|
|
|
// judge if the reduce dim is partitioned. |
|
|
|
Shape group_creat_map; |
|
|
|
if (dev_matrix_shape_.size() > size) { |
|
|
|
|
|
|
|
// if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix, |
|
|
|
// it need to handle the first dimention of map. |
|
|
|
if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) { |
|
|
|
group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1))); |
|
|
|
} |
|
|
|
|
|
|
|
for (size_t index = 0; index < size; ++index) { |
|
|
|
auto pos = |
|
|
|
std::find_if(dim_list.begin(), dim_list.end(), [index](const int32_t &dim) { return SizeToInt(index) == dim; }); |
|
|
|
@@ -231,6 +250,19 @@ Status ReduceMeanInfo::InferForwardCommunication() { |
|
|
|
} |
|
|
|
group_creat_map.push_back(SizeToInt(size) - SizeToInt(index) - 1); |
|
|
|
} |
|
|
|
|
|
|
|
// if repeated calculation and the repeated_calc_num_ insert to the last dimension of dev matrix, |
|
|
|
// it need to handle the group_creat_map and insert the 0 to the last dimension of the group_creat_map. |
|
|
|
if (repeated_num_in_dev_matrix_right_ && (repeated_calc_num_ > 1)) { |
|
|
|
for (auto &ele : group_creat_map) { |
|
|
|
if (ele == MAP_NONE) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
ele += 1; |
|
|
|
} |
|
|
|
group_creat_map.push_back(0); |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<Group> forward_group; |
|
|
|
if (CreateGroupByTensorMap(group_creat_map, &forward_group) != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << name_ << ": InferForwardCommunication group failed."; |
|
|
|
|