Browse Source

!8455 Add MirrorsOp for UnsortedSegmentOps and Remove Useless Operations for Tensor Map

From: @huangxinjing
Reviewed-by: @yangzhenzhang,@yao_yf,@stsuteng,@zh_qh
Signed-off-by: @stsuteng
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
42ace59a5e
2 changed files with 30 additions and 4 deletions
  1. +29
    -3
      mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc
  2. +1
    -1
      mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.h

+ 29
- 3
mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.cc View File

@@ -103,6 +103,35 @@ Status UnsortedSegmentOpInfo::InferDevMatrixShape() {
return SUCCESS;
}

Status UnsortedSegmentOpInfo::InferMirrorOps() {
mirror_ops_.clear();

// Only the first input could be parameter.
Shape tensor_map = inputs_tensor_map_[0];
std::vector<Group> group;
if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create group failed.";
return FAILED;
}

OperatorVector mirror_op;
OperatorVector op_for_value;
OperatorVector op_for_value2;
if (group.empty()) {
MS_LOG(INFO) << name_ << " : The mirror ops is empty.";
return SUCCESS;
} else {
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
mirror_ops_.push_back(mirror_op);
mirror_ops_.push_back(op_for_value);
mirror_ops_.push_back(op_for_value2);
std::string group_name = group[0].name();
MS_LOG(INFO) << name_ << " : Create the mirror ops success, the group name is " << group_name;
}

return SUCCESS;
}

// As the op converts the vector x1,x2,x3...,xr -> number of segments, xn,..,xr
// the dimension x1,x2,x3,..,xn is eliminated
// suppose the strategy of the inputs is (a,b,c,d), (a,b)
@@ -221,9 +250,6 @@ Status UnsortedSegmentOpInfo::InferForwardCommunication() {
std::vector<Group> group_list;
Shape tmp_group_tensor_map = outputs_tensor_map_.at(0);
if (repeated_calc_num_ > 1) {
for (size_t i = 1; i < tmp_group_tensor_map.size(); ++i) {
tmp_group_tensor_map[i] += 1;
}
tmp_group_tensor_map.push_back(0);
}
if (CreateGroupByTensorMap(tmp_group_tensor_map, &group_list) != SUCCESS) {


+ 1
- 1
mindspore/ccsrc/frontend/parallel/ops_info/unsorted_segment_op_info.h View File

@@ -47,7 +47,7 @@ class UnsortedSegmentOpInfo : public OperatorInfo {
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferForwardCommunication() override;
Status InferMirrorOps() override { return SUCCESS; }
Status InferMirrorOps() override;
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;


Loading…
Cancel
Save