|
|
|
@@ -103,6 +103,35 @@ Status UnsortedSegmentOpInfo::InferDevMatrixShape() { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status UnsortedSegmentOpInfo::InferMirrorOps() { |
|
|
|
mirror_ops_.clear(); |
|
|
|
|
|
|
|
// Only the first input could be parameter. |
|
|
|
Shape tensor_map = inputs_tensor_map_[0]; |
|
|
|
std::vector<Group> group; |
|
|
|
if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) { |
|
|
|
MS_LOG(ERROR) << name_ << " : Create group failed."; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
OperatorVector mirror_op; |
|
|
|
OperatorVector op_for_value; |
|
|
|
OperatorVector op_for_value2; |
|
|
|
if (group.empty()) { |
|
|
|
MS_LOG(INFO) << name_ << " : The mirror ops is empty."; |
|
|
|
return SUCCESS; |
|
|
|
} else { |
|
|
|
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum()); |
|
|
|
mirror_ops_.push_back(mirror_op); |
|
|
|
mirror_ops_.push_back(op_for_value); |
|
|
|
mirror_ops_.push_back(op_for_value2); |
|
|
|
std::string group_name = group[0].name(); |
|
|
|
MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name; |
|
|
|
} |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
// As the op converts the vector x1,x2,x3...,xr -> number of segments, xn,..,xr |
|
|
|
// the dimension x1,x2,x3,..,xn is eliminated |
|
|
|
// suppose the strategy of the inputs is (a,b,c,d), (a,b) |
|
|
|
@@ -221,9 +250,6 @@ Status UnsortedSegmentOpInfo::InferForwardCommunication() { |
|
|
|
std::vector<Group> group_list; |
|
|
|
Shape tmp_group_tensor_map = outputs_tensor_map_.at(0); |
|
|
|
if (repeated_calc_num_ > 1) { |
|
|
|
for (size_t i = 1; i < tmp_group_tensor_map.size(); ++i) { |
|
|
|
tmp_group_tensor_map[i] += 1; |
|
|
|
} |
|
|
|
tmp_group_tensor_map.push_back(0); |
|
|
|
} |
|
|
|
if (CreateGroupByTensorMap(tmp_group_tensor_map, &group_list) != SUCCESS) { |
|
|
|
|