Merge pull request !4897 from wenchunjiang/adapte_to_ifbyif_noinlinetags/v0.7.0-beta
| @@ -30,6 +30,10 @@ namespace { | |||
| void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) { | |||
| MS_EXCEPTION_IF_NULL(cnode_ptr); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| if (AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimCall) || | |||
| AnfAlgo::CheckPrimitiveType(cnode_ptr, prim::kPrimPartial)) { | |||
| return; | |||
| } | |||
| std::vector<AnfNodePtr> plant_inputs; | |||
| std::vector<int> dyn_input_sizes; | |||
| plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr)); | |||
| @@ -26,22 +26,19 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNodePtr &tuple_anf, | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *transed_nodes) { | |||
| AnfNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const AnfNodePtr &tuple_anf) { | |||
| MS_EXCEPTION_IF_NULL(tuple_anf); | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| MS_EXCEPTION_IF_NULL(transed_nodes); | |||
| if (!AnfAlgo::IsTupleOutput(tuple_anf)) { | |||
| return tuple_anf; | |||
| } | |||
| auto transed_node_it = transed_nodes->find(tuple_anf); | |||
| if (transed_node_it != transed_nodes->end()) { | |||
| return transed_node_it->second; | |||
| } | |||
| auto kernel_graph = graph->cast<KernelGraphPtr>(); | |||
| if (kernel_graph->FindTupleParameterToMakeTupleMap(tuple_anf)) { | |||
| return kernel_graph->FindTupleParameterToMakeTupleMap(tuple_anf); | |||
| } | |||
| auto make_tuple = kernel_graph->TransTupleToMakeTuple(tuple_anf); | |||
| (*transed_nodes)[tuple_anf] = make_tuple; | |||
| kernel_graph->InsertTupleParameterToMakeTupleMap(tuple_anf, make_tuple); | |||
| // replace graph inputs if input is a parameter | |||
| kernel_graph->ReplaceGraphInput(tuple_anf, make_tuple); | |||
| return make_tuple; | |||
| @@ -61,7 +58,6 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> transed_nodes; | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) { | |||
| auto real_input = AnfAlgo::GetTupleGetItemRealInput(cnode); | |||
| MS_EXCEPTION_IF_NULL(real_input); | |||
| @@ -77,7 +73,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func | |||
| const auto &input = cnode->inputs()[i]; | |||
| if (input->Type() != nullptr && AnfAlgo::IsRealKernel(input) && AnfAlgo::IsTupleOutput(input) && | |||
| !AnfAlgo::CheckPrimitiveType(input, prim::kPrimCall)) { | |||
| cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input, &transed_nodes)); | |||
| cnode->set_input(i, ConvertTupleInputToMakeTuple(func_graph, input)); | |||
| cnode_input_changed = true; | |||
| } | |||
| } | |||
| @@ -523,12 +523,22 @@ void AscendControlParser::LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNod | |||
| } | |||
| } | |||
| void AscendControlParser::AttachOriginalInputsToGraph(NotNull<KernelGraphPtr> graph, | |||
| const std::vector<AnfNodePtr> orig_inputs) { | |||
| std::vector<AnfNodePtr> make_tuple_inputs = { | |||
| mindspore::NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name()))}; | |||
| std::copy(orig_inputs.begin(), orig_inputs.end(), std::back_inserter(make_tuple_inputs)); | |||
| auto make_tuple = graph->NewCNode(make_tuple_inputs); | |||
| InsertDependToGraph(graph, NOT_NULL(make_tuple)); | |||
| } | |||
| void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodePtr> cur_node, const CNodePtr &next_node, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| MS_LOG(INFO) << "Process call func " << cur_node->DebugString(); | |||
| // 1 get kernel graph | |||
| const std::vector<AnfNodePtr> &origin_inputs = cur_node->inputs(); | |||
| std::vector<AnfNodePtr> origin_inputs = cur_node->inputs(); | |||
| if (kCNodeCallArg >= origin_inputs.size()) { | |||
| MS_LOG(EXCEPTION) << "Index out of range,size:" << origin_inputs.size(); | |||
| } | |||
| @@ -555,6 +565,8 @@ void AscendControlParser::RecurseCall(NotNull<KernelGraphPtr> kg, NotNull<CNodeP | |||
| cur_node->set_inputs(new_inputs); | |||
| cur_node->set_abstract(nullptr); | |||
| AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue<std::vector<KernelGraphPtr>>({call_kg}), cur_node.get()); | |||
| origin_inputs.assign(origin_inputs.begin() + kCNodeCallArg + 1, origin_inputs.end()); | |||
| AttachOriginalInputsToGraph(kg, origin_inputs); | |||
| MS_LOG(INFO) << "Succeed processing call func " << cur_node->DebugString(); | |||
| } | |||
| @@ -587,11 +599,13 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod | |||
| for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) { | |||
| // 3.1 branch kernel graph and args | |||
| KernelGraphPtr branch_fg; | |||
| std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||
| std::vector<AnfNodePtr> origin_inputs; | |||
| std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||
| child_graphs.push_back(branch_fg); | |||
| // 3.2 recurse sub graph | |||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); | |||
| new_switch_inputs.push_back(branch_label); | |||
| AttachOriginalInputsToGraph(kg, origin_inputs); | |||
| } | |||
| std::swap(new_switch_inputs[kCNodeSwitchTrue], new_switch_inputs[kCNodeSwitchFalse]); | |||
| @@ -635,11 +649,13 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull | |||
| for (size_t i = 0; i < branch_partial.size(); ++i) { | |||
| // 3.1 branch kernel graph and args | |||
| KernelGraphPtr branch_fg; | |||
| std::tie(branch_fg, std::ignore) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||
| std::vector<AnfNodePtr> origin_inputs; | |||
| std::tie(branch_fg, origin_inputs) = ParsePartial(NOT_NULL(origin_switch_inputs[i])); | |||
| child_graphs.push_back(branch_fg); | |||
| // 3.2 recurse sub graph | |||
| CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo); | |||
| new_switch_inputs.push_back(branch_label); | |||
| AttachOriginalInputsToGraph(kg, origin_inputs); | |||
| } | |||
| new_switch_inputs.insert(new_switch_inputs.end(), branch_partial.begin(), branch_partial.end()); | |||
| cur_node->set_inputs(new_switch_inputs); | |||
| @@ -76,6 +76,7 @@ class AscendControlParser { | |||
| static bool CheckLabelIndex(uint32_t index, uint32_t label_index, const CNodePtr &cnode); | |||
| static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo); | |||
| static void AttachOriginalInputsToGraph(NotNull<KernelGraphPtr> graph, const std::vector<AnfNodePtr> orig_inputs); | |||
| }; | |||
| class AscendControlParser::ReferenceCounter { | |||
| public: | |||
| @@ -162,6 +162,19 @@ class KernelGraph : public FuncGraph { | |||
| void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) { | |||
| child_graph_result_ = child_graph_result; | |||
| } | |||
| void InsertTupleParameterToMakeTupleMap(const AnfNodePtr ¶m, const AnfNodePtr &make_tuple) { | |||
| if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) { | |||
| return; | |||
| } | |||
| tuple_parameter_to_make_tuple_map_[param] = make_tuple; | |||
| } | |||
| AnfNodePtr FindTupleParameterToMakeTupleMap(const AnfNodePtr ¶m) { | |||
| if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) { | |||
| return tuple_parameter_to_make_tuple_map_[param]; | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| private: | |||
| // remove value node form graph | |||
| @@ -229,6 +242,7 @@ class KernelGraph : public FuncGraph { | |||
| std::unordered_map<AnfNodePtr, std::unordered_map<int, std::pair<AnfNodePtr, bool>>> internal_outputs_to_front_map_; | |||
| std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_; | |||
| uint32_t current_epoch_; | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> tuple_parameter_to_make_tuple_map_; | |||
| }; | |||
| } // namespace session | |||
| using KernelGraphPtr = std::shared_ptr<session::KernelGraph>; | |||