|
|
|
@@ -527,6 +527,10 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no |
|
|
|
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() |
|
|
|
<< " does not match the Prim: " << prim->name(); |
|
|
|
} |
|
|
|
|
|
|
|
// Needed by rec_parser |
|
|
|
ModifyInputsTensorNameListIfOperatorInfoCreated(current_op_ptr->name(), cnode->UniqueId()); |
|
|
|
|
|
|
|
cnode->set_user_data<OperatorInfo>(current_op_ptr); |
|
|
|
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() |
|
|
|
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() |
|
|
|
@@ -1124,6 +1128,27 @@ CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, const std::string &uniqueid) { |
|
|
|
size_t iter_ops = 0; |
|
|
|
for (auto op : entire_costgraph->GetOperators()) { |
|
|
|
if (op->name() == name) { |
|
|
|
break; |
|
|
|
} |
|
|
|
iter_ops = iter_ops + 1; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list(); |
|
|
|
for (size_t i = 0; i < input_tensor_names.size(); i++) { |
|
|
|
for (size_t j = 0; j < input_tensor_names[i].size(); j++) { |
|
|
|
if (input_tensor_names[i][j] == uniqueid) { |
|
|
|
input_tensor_names[i][j] = input_tensor_names[iter_ops][0]; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
entire_costgraph->set_inputs_tensor_name_list(input_tensor_names); |
|
|
|
} |
|
|
|
|
|
|
|
Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) { |
|
|
|
if (CostModelContext::GetInstance()->is_multi_subgraphs()) { |
|
|
|
if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) { |
|
|
|
|