Browse Source

!4054 [AutoParallel] Use uniqueid to manage input tensors

Merge pull request !4054 from Chong/wd
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
6657adfaef
3 changed files with 30 additions and 0 deletions
  1. +3
    -0
      mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h
  2. +25
    -0
      mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
  3. +2
    -0
      mindspore/ccsrc/frontend/parallel/step_auto_parallel.h

+ 3
- 0
mindspore/ccsrc/frontend/parallel/auto_parallel/graph_costmodel.h View File

@@ -197,6 +197,9 @@ class CostGraph {
inputs_tensor_name_list_.push_back(inputs_tensor_name);
}
const std::vector<std::vector<std::string>> get_inputs_tensor_name_list() const { return inputs_tensor_name_list_; }
void set_inputs_tensor_name_list(const std::vector<std::vector<std::string>> &inputs_tensor_name_list) {
inputs_tensor_name_list_ = inputs_tensor_name_list;
}
void add_tuple_getitem(const std::pair<std::string, std::string> &tuple_getitem) {
auto ret = tuple_getitem_list_.insert(tuple_getitem);
if (ret.second == false) {


+ 25
- 0
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc View File

@@ -527,6 +527,10 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
<< " does not match the Prim: " << prim->name();
}

// Needed by rec_parser
ModifyInputsTensorNameListIfOperatorInfoCreated(current_op_ptr->name(), cnode->UniqueId());

cnode->set_user_data<OperatorInfo>(current_op_ptr);
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
@@ -1124,6 +1128,27 @@ CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim
return nullptr;
}

void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, const std::string &uniqueid) {
size_t iter_ops = 0;
for (auto op : entire_costgraph->GetOperators()) {
if (op->name() == name) {
break;
}
iter_ops = iter_ops + 1;
}

std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
for (size_t i = 0; i < input_tensor_names.size(); i++) {
for (size_t j = 0; j < input_tensor_names[i].size(); j++) {
if (input_tensor_names[i][j] == uniqueid) {
input_tensor_names[i][j] = input_tensor_names[iter_ops][0];
}
}
}

entire_costgraph->set_inputs_tensor_name_list(input_tensor_names);
}

Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {


+ 2
- 0
mindspore/ccsrc/frontend/parallel/step_auto_parallel.h View File

@@ -59,6 +59,8 @@ std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::st
std::vector<std::vector<std::string>> input_tensor_names);

CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node);

void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, const std::string &uniqueid);
} // namespace parallel
} // namespace mindspore
#endif // PARALLEL_STEP_AUTO_PARALLEL_H_

Loading…
Cancel
Save