Merge pull request !4897 from wenchunjiang/adapte_to_ifbyif_noinlinetags/v0.7.0-beta
| @@ -30,6 +30,10 @@ namespace { | |||||
| void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { | void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { | ||||
| MS_EXCEPTION_IF_NULL(cnode_ptr); | MS_EXCEPTION_IF_NULL(cnode_ptr); | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| if (AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimCall) || | |||||
| AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPartial)) { | |||||
| return; | |||||
| } | |||||
| std::vector<AnfNodePtr> plant_inputs; | std::vector<AnfNodePtr> plant_inputs; | ||||
| std::vector<int> dyn_input_sizes; | std::vector<int> dyn_input_sizes; | ||||
| plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr)); | plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr)); | ||||
| @@ -26,22 +26,19 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | namespace { | ||||
| AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNodePtr &tuple_anf, | |||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *transed_nodes) { | |||||
| AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNodePtr &tuple_anf) { | |||||
| MS_EXCEPTION_IF_NULL(tuple_anf); | MS_EXCEPTION_IF_NULL(tuple_anf); | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(transed_nodes); | |||||
| if (!AnfAlgo::IsTupleOutput(tuple_anf)) { | if (!AnfAlgo::IsTupleOutput(tuple_anf)) { | ||||
| return tuple_anf; | return tuple_anf; | ||||
| } | } | ||||
| auto transed_node_it = transed_nodes->find(tuple_anf); | |||||
| if (transed_node_it != transed_nodes->end()) { | |||||
| return transed_node_it->second; | |||||
| } | |||||
| auto kernel_graph = graph->cast<KernelGraphPtr>(); | auto kernel_graph = graph->cast<KernelGraphPtr>(); | ||||
| if (kernel_graph->FindTupleParameterToMakeTupleMap(tuple_anf)) { | |||||
| return kernel_graph->FindTupleParameterToMakeTupleMap(tuple_anf); | |||||
| } | |||||
| auto make_tuple = kernel_graph->TransTupleToMakeTuple(tuple_anf); | auto make_tuple = kernel_graph->TransTupleToMakeTuple(tuple_anf); | ||||
| (*transed_nodes)[tuple_anf] = make_tuple; | |||||
| kernel_graph->InsertTupleParameterToMakeTupleMap(tuple_anf, make_tuple); | |||||
| // replace graph inputs if input is a parameter | // replace graph inputs if input is a parameter | ||||
| kernel_graph->ReplaceGraphInput(tuple_anf, make_tuple); | kernel_graph->ReplaceGraphInput(tuple_anf, make_tuple); | ||||
| return make_tuple; | return make_tuple; | ||||
| @@ -61,7 +58,6 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func | |||||
| } | } | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> transed_nodes; | |||||
| if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) { | if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) { | ||||
| auto real_input = AnfAlgo::GetTupleGetItemRealInput(cnode); | auto real_input = AnfAlgo::GetTupleGetItemRealInput(cnode); | ||||
| MS_EXCEPTION_IF_NULL(real_input); | MS_EXCEPTION_IF_NULL(real_input); | ||||
| @@ -77,7 +73,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func | |||||
| const auto &input = cnode->inputs()[i]; | const auto &input = cnode->inputs()[i]; | ||||
| if (input->Type() != nullptr && AnfAlgo::IsRealKernel(input) && AnfAlgo::IsTupleOutput(input) && | if (input->Type() != nullptr && AnfAlgo::IsRealKernel(input) && AnfAlgo::IsTupleOutput(input) && | ||||
| !AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall)) { | !AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall)) { | ||||
| cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input, &transed_nodes)); | |||||
| cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input)); | |||||
| cnode_input_changed = true; | cnode_input_changed = true; | ||||
| } | } | ||||
| } | } | ||||
| @@ -523,12 +523,22 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod | |||||
| } | } | ||||
| } | } | ||||
| // Wraps `orig_inputs` in a MakeTuple CNode created in `graph` and passes that node to InsertDependToGraph. | |||||
| // NOTE(review): presumably this keeps the original call/partial arguments referenced by the graph once the | |||||
| // control node's own inputs have been replaced -- confirm against InsertDependToGraph. | |||||
| void AscendControlParser::AttachOriginalInputsToGraph(NotNull<KernelGraphPtr> graph, | |||||
| const std::vector<AnfNodePtr> orig_inputs) { | |||||
| // The MakeTuple primitive value node goes first, followed by the original inputs in their given order. | |||||
| std::vector<AnfNodePtr> tuple_node_inputs; | |||||
| tuple_node_inputs.push_back(mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))); | |||||
| tuple_node_inputs.insert(tuple_node_inputs.end(), orig_inputs.begin(), orig_inputs.end()); | |||||
| auto tuple_node = graph->NewCNode(tuple_node_inputs); | |||||
| InsertDependToGraph(graph, NOT_NULL(tuple_node)); | |||||
| } | |||||
| void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | ||||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | const NotNull<std::set<KernelGraphPtr> *> memo) { | ||||
| MS_LOG(INFO) << "Process call func " << cur_node->DebugString(); | MS_LOG(INFO) << "Process call func " << cur_node->DebugString(); | ||||
| // 1 get kernel graph | // 1 get kernel graph | ||||
| const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs(); | |||||
| std::vector<AnfNodePtr> origin_inputs = cur_node->inputs(); | |||||
| if (kCNodeCallArg >= origin_inputs.size()) { | if (kCNodeCallArg >= origin_inputs.size()) { | ||||
| MS_LOG(EXCEPTION) << "Index out of range,size:" << origin_inputs.size(); | MS_LOG(EXCEPTION) << "Index out of range,size:" << origin_inputs.size(); | ||||
| } | } | ||||
| @@ -555,6 +565,8 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP | |||||
| cur_node->set_inputs(new_inputs); | cur_node->set_inputs(new_inputs); | ||||
| cur_node->set_abstract(nullptr); | cur_node->set_abstract(nullptr); | ||||
| AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>({call_kg}), cur_node.get()); | AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>({call_kg}), cur_node.get()); | ||||
| origin_inputs.assign(origin_inputs.begin() + kCNodeCallArg + 1, origin_inputs.end()); | |||||
| AttachOriginalInputsToGraph(kg, origin_inputs); | |||||
| MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString(); | MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString(); | ||||
| } | } | ||||
| @@ -587,11 +599,13 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod | |||||
| for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { | for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { | ||||
| // 3.1 branch kernel graph and args | // 3.1 branch kernel graph and args | ||||
| KernelGraphPtr branch_fg; | KernelGraphPtr branch_fg; | ||||
| std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||||
| std::vector<AnfNodePtr> origin_inputs; | |||||
| std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||||
| child_graphs.push_back(branch_fg); | child_graphs.push_back(branch_fg); | ||||
| // 3.2 recurse sub graph | // 3.2 recurse sub graph | ||||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); | CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); | ||||
| new_switch_inputs.push_back(branch_label); | new_switch_inputs.push_back(branch_label); | ||||
| AttachOriginalInputsToGraph(kg, origin_inputs); | |||||
| } | } | ||||
| std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]); | std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]); | ||||
| @@ -635,11 +649,13 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull | |||||
| for (size_t i = 0; i < branch_partial.size(); ++i) { | for (size_t i = 0; i < branch_partial.size(); ++i) { | ||||
| // 3.1 branch kernel graph and args | // 3.1 branch kernel graph and args | ||||
| KernelGraphPtr branch_fg; | KernelGraphPtr branch_fg; | ||||
| std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||||
| std::vector<AnfNodePtr> origin_inputs; | |||||
| std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||||
| child_graphs.push_back(branch_fg); | child_graphs.push_back(branch_fg); | ||||
| // 3.2 recurse sub graph | // 3.2 recurse sub graph | ||||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); | CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); | ||||
| new_switch_inputs.push_back(branch_label); | new_switch_inputs.push_back(branch_label); | ||||
| AttachOriginalInputsToGraph(kg, origin_inputs); | |||||
| } | } | ||||
| new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); | new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); | ||||
| cur_node->set_inputs(new_switch_inputs); | cur_node->set_inputs(new_switch_inputs); | ||||
| @@ -76,6 +76,7 @@ class AscendControlParser { | |||||
| static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode); | static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode); | ||||
| static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph, | static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph, | ||||
| const NotNull<std::set<KernelGraphPtr> *> memo); | const NotNull<std::set<KernelGraphPtr> *> memo); | ||||
| static void AttachOriginalInputsToGraph(NotNull<KernelGraphPtr> graph, const std::vector<AnfNodePtr> orig_inputs); | |||||
| }; | }; | ||||
| class AscendControlParser::ReferenceCounter { | class AscendControlParser::ReferenceCounter { | ||||
| public: | public: | ||||
| @@ -162,6 +162,19 @@ class KernelGraph : public FuncGraph { | |||||
| void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) { | void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) { | ||||
| child_graph_result_ = child_graph_result; | child_graph_result_ = child_graph_result; | ||||
| } | } | ||||
| // Records `make_tuple` as the node mapped to tuple `param`; an entry that already exists for `param` is kept. | |||||
| void InsertTupleParameterToMakeTupleMap(const AnfNodePtr &param, const AnfNodePtr &make_tuple) { | |||||
| // emplace() leaves the map untouched when the key is already present, so the first mapping always wins. | |||||
| (void)tuple_parameter_to_make_tuple_map_.emplace(param, make_tuple); | |||||
| } | |||||
| // Returns the node recorded for `param` in the tuple-parameter map, or nullptr when no entry exists. | |||||
| AnfNodePtr FindTupleParameterToMakeTupleMap(const AnfNodePtr &param) { | |||||
| // Single lookup: reuse the iterator rather than hashing the key again through operator[]. | |||||
| auto iter = tuple_parameter_to_make_tuple_map_.find(param); | |||||
| if (iter == tuple_parameter_to_make_tuple_map_.end()) { | |||||
| return nullptr; | |||||
| } | |||||
| return iter->second; | |||||
| } | |||||
| private: | private: | ||||
| // remove value node from graph | // remove value node from graph | ||||
| @@ -229,6 +242,7 @@ class KernelGraph : public FuncGraph { | |||||
| std::unordered_map<AnfNodePtr, std::unordered_map<int, std::pair<AnfNodePtr, bool>>> internal_outputs_to_front_map_; | std::unordered_map<AnfNodePtr, std::unordered_map<int, std::pair<AnfNodePtr, bool>>> internal_outputs_to_front_map_; | ||||
| std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_; | std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_; | ||||
| uint32_t current_epoch_; | uint32_t current_epoch_; | ||||
| std::unordered_map<AnfNodePtr, AnfNodePtr> tuple_parameter_to_make_tuple_map_; | |||||
| }; | }; | ||||
| } // namespace session | } // namespace session | ||||
| using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | ||||