Signed-off-by: zhoufeng <zhoufeng54@huawei.com>tags/v0.5.0-beta
| @@ -69,7 +69,7 @@ CNodePtr ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNo | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto inputs = cnode->inputs(); | |||
| (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(plant_inputs)); | |||
| } else if (AnfAlgo::IsTupleOutput(input_node)) { | |||
| } else if (input_node->Type() != nullptr && AnfAlgo::IsTupleOutput(input_node)) { | |||
| ConvertTupleOuputToPlantInputs(graph, input_node, &plant_inputs, &dyn_input_sizes); | |||
| } else { | |||
| dyn_input_sizes.push_back(-1); | |||
| @@ -68,8 +68,9 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func | |||
| if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) || IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { | |||
| return nullptr; | |||
| } | |||
| if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), | |||
| [](const AnfNodePtr &node) { return AnfAlgo::IsRealKernel(node) && AnfAlgo::IsTupleOutput(node); })) { | |||
| if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) { | |||
| return node->Type() != nullptr && AnfAlgo::IsRealKernel(node) && AnfAlgo::IsTupleOutput(node); | |||
| })) { | |||
| return ConvertTupleInputToMakeTuple(func_graph, cnode); | |||
| } | |||
| return nullptr; | |||
| @@ -18,6 +18,7 @@ | |||
| #include <memory> | |||
| #include "session/ascend_control_parser.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "utils/union_find_set.h" | |||
| static constexpr size_t kCNodePrim = 0; | |||
| static constexpr size_t kCNodeCallArg = 1; | |||
| @@ -57,6 +58,110 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr | |||
| } | |||
| } | |||
| static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| if (memo->find(kg.get()) != memo->end()) { | |||
| return; | |||
| } | |||
| memo->insert(kg.get()); | |||
| const std::map<AnfNodePtr, std::set<AnfNodePtr>> &real_inputs = kg->real_inputs(); | |||
| for (auto &iter : real_inputs) { | |||
| auto ¶ = iter.first; | |||
| if (para->isa<Parameter>()) { | |||
| union_find_set->Add(para); | |||
| } | |||
| for (auto &arg : iter.second) { | |||
| if (!arg->isa<Parameter>()) { | |||
| continue; | |||
| } | |||
| union_find_set->Add(arg); | |||
| } | |||
| } | |||
| for (auto &child : kg->child_graph_order()) { | |||
| InitUnionFindSet(NOT_NULL(child), union_find_set, memo); | |||
| } | |||
| } | |||
| static void UnionParentParameter(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| if (memo->find(kg.get()) != memo->end()) { | |||
| return; | |||
| } | |||
| memo->insert(kg.get()); | |||
| const std::map<AnfNodePtr, std::set<AnfNodePtr>> &real_inputs = kg->real_inputs(); | |||
| for (auto &iter : real_inputs) { | |||
| auto ¶ = iter.first; | |||
| for (auto &arg : iter.second) { | |||
| if (!arg->isa<Parameter>()) { | |||
| continue; | |||
| } | |||
| union_find_set->Union(arg, para); | |||
| } | |||
| } | |||
| for (auto &child : kg->child_graph_order()) { | |||
| UnionParentParameter(NOT_NULL(child), union_find_set, memo); | |||
| } | |||
| } | |||
| static UnionFindSet<AnfNodePtr> MakeUnionFindSet(NotNull<KernelGraphPtr> root_kg) { | |||
| UnionFindSet<AnfNodePtr> result; | |||
| std::set<KernelGraphPtr> memo; | |||
| InitUnionFindSet(root_kg, NOT_NULL(&result), NOT_NULL(&memo)); | |||
| memo.clear(); | |||
| UnionParentParameter(root_kg, NOT_NULL(&result), NOT_NULL(&memo)); | |||
| return result; | |||
| } | |||
| static void RecursiveReplaceNode(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> main_parameter, | |||
| const std::set<AnfNodePtr> ¶meter_reuse_set, | |||
| const NotNull<std::set<KernelGraphPtr> *> memo) { | |||
| if (parameter_reuse_set.empty()) { | |||
| MS_LOG(EXCEPTION) << "parameter_reuse_set is empty."; | |||
| } | |||
| if (memo->find(kg.get()) != memo->end()) { | |||
| return; | |||
| } | |||
| memo->insert(kg.get()); | |||
| for (auto ¶ : parameter_reuse_set) { | |||
| if (para == main_parameter.get()) { | |||
| continue; | |||
| } | |||
| MS_LOG(INFO) << "Replace " << para->DebugString() << " of graph " << AnfAlgo::GetGraphId(para.get()) << " to " | |||
| << main_parameter->DebugString() << " of graph " << AnfAlgo::GetGraphId(main_parameter.get().get()); | |||
| kg->ReplaceNode(NOT_NULL(para), main_parameter); | |||
| } | |||
| for (auto &child : kg->child_graph_order()) { | |||
| RecursiveReplaceNode(NOT_NULL(child), main_parameter, parameter_reuse_set, memo); | |||
| } | |||
| } | |||
| static void ReuseParameter(NotNull<KernelGraphPtr> root_kg, NotNull<UnionFindSet<AnfNodePtr> *> parameter_set) { | |||
| auto parameter_reuse_sets = parameter_set->GetSets(); | |||
| for (auto &[key, parameter_reuse_set] : parameter_reuse_sets) { | |||
| if (parameter_reuse_set.size() <= 1) { | |||
| continue; | |||
| } | |||
| AnfNodePtr main_parameter = key; | |||
| std::set<AnfNodePtr> root_inputs_set; | |||
| const auto &root_inputs_vector = root_kg->inputs(); | |||
| root_inputs_set.insert(root_inputs_vector.begin(), root_inputs_vector.end()); | |||
| for (auto &node : parameter_reuse_set) { | |||
| if (root_inputs_set.find(node) == root_inputs_set.end()) { | |||
| continue; | |||
| } | |||
| main_parameter = node; | |||
| } | |||
| std::set<KernelGraphPtr> memo; | |||
| RecursiveReplaceNode(root_kg, NOT_NULL(main_parameter), parameter_reuse_set, NOT_NULL(&memo)); | |||
| } | |||
| } | |||
| void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) { | |||
| std::set<KernelGraphPtr> memo; | |||
| ProcessKernelGraph(kg, nullptr, nullptr, NOT_NULL(&memo)); | |||
| @@ -68,6 +173,11 @@ void AscendControlParser::LinkGraph(NotNull<KernelGraphPtr> kg) { | |||
| } | |||
| graph_id_map[g->graph_id()] = g; | |||
| } | |||
| // Make UnionFindSet | |||
| UnionFindSet<AnfNodePtr> parameter_set = MakeUnionFindSet(kg); | |||
| // Reuse Parameter | |||
| ReuseParameter(kg, NOT_NULL(¶meter_set)); | |||
| // Insert Assign | |||
| ChildGraphDataAssign(graph_id_map); | |||
| } | |||
| @@ -324,29 +434,6 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul | |||
| InsertDependToGraph(kg, NOT_NULL(assign_node)); | |||
| } | |||
| void AscendControlParser::LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotNull<KernelGraphPtr> target_graph, | |||
| NotNull<AnfNodePtr> arg, NotNull<AnfNodePtr> param) { | |||
| if (IsPrimitiveCNode(arg, prim::kPrimMakeTuple) && IsPrimitiveCNode(param, prim::kPrimMakeTuple)) { | |||
| MS_LOG(INFO) << "Arg " << arg->DebugString() << " Param " << param->DebugString() << " is a tuple"; | |||
| CNodePtr cnode_arg = arg.get()->cast<CNodePtr>(); | |||
| CNodePtr cnode_param = param.get()->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode_arg); | |||
| MS_EXCEPTION_IF_NULL(cnode_param); | |||
| if (cnode_arg->size() != cnode_param->size()) { | |||
| MS_LOG(EXCEPTION) << "Arg " << arg->DebugString() << " size " << cnode_arg->size() << " but Param " | |||
| << param->DebugString() << " size " << cnode_param->size(); | |||
| } | |||
| for (size_t i = 1; i < cnode_param->size(); ++i) { | |||
| LinkArgsToParam(to_graph, target_graph, NOT_NULL(cnode_arg->input(i)), NOT_NULL(cnode_param->input(i))); | |||
| } | |||
| } else if (arg->isa<CNode>()) { | |||
| InsertAssignToGraph(target_graph, arg, param); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Arg " << arg->DebugString() << " Param " << param->DebugString() << " unknown type."; | |||
| } | |||
| } | |||
| void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) { | |||
| std::set<KernelGraphPtr> memo; | |||
| (void)RecurseGraph(root_graph, NOT_NULL(&memo)); | |||
| @@ -52,9 +52,6 @@ class AscendControlParser { | |||
| const CNodePtr &last_label); | |||
| static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node); | |||
| static void LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotNull<KernelGraphPtr> target_graph, | |||
| NotNull<AnfNodePtr> arg, NotNull<AnfNodePtr> param); | |||
| static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); | |||
| static CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start); | |||
| @@ -224,14 +224,6 @@ static void BindCallArgsWithParameter(const std::vector<AnfNodePtr> ¶meters, | |||
| MS_LOG(INFO) << "Parameter and arg are same"; | |||
| continue; | |||
| } | |||
| // if arg is a parameter ,then reuse this parameter | |||
| if (args[i]->isa<Parameter>()) { | |||
| MS_LOG(INFO) << "Parameter:" << parameters[i]->DebugString() << " of graph:" << child_graph->graph_id() | |||
| << " reuse parameter:" << args[i]->DebugString() | |||
| << " of graph:" << AnfAlgo::GetGraphId(args[i].get()); | |||
| child_graph->ReplaceNode(parameters[i], args[i]); | |||
| continue; | |||
| } | |||
| child_graph->SetRealInput(parameters[i], args[i]); | |||
| } | |||
| } | |||
| @@ -412,7 +404,6 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor:: | |||
| VectorRef *const outputs) { | |||
| MS_LOG(INFO) << "start"; | |||
| auto kernel_graph = GetGraph(graph_id); | |||
| DumpIR("./run_graph.ir", kernel_graph); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| // if none of child graph and no anf output exists | |||
| if (!kernel_graph->executable()) { | |||
| @@ -1134,7 +1125,7 @@ void AscendSession::SetChildGraphParameter(const AnfNodePtr &front_anf, GraphId | |||
| MS_EXCEPTION_IF_NULL(backend_arg); | |||
| MS_LOG(INFO) << "Reuse node [" << backend_arg->DebugString() << "], old node[" << backend_parameter->DebugString() | |||
| << "] will be replaced."; | |||
| to_graph->ReplaceNode(backend_parameter, backend_arg); | |||
| to_graph->ReplaceNode(NOT_NULL(backend_parameter), NOT_NULL(backend_arg)); | |||
| return; | |||
| } | |||
| MS_LOG(INFO) << "Assign of node" << backend_arg->DebugString() << " of graph " << from_graph_id << " to node" | |||
| @@ -587,9 +587,7 @@ bool KernelGraph::RemoveValueNodeFromGraph(const ValueNodePtr &value_node) { | |||
| return false; | |||
| } | |||
| void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf_node) { | |||
| MS_EXCEPTION_IF_NULL(old_anf_node); | |||
| MS_EXCEPTION_IF_NULL(new_anf_node); | |||
| void KernelGraph::ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node) { | |||
| MS_EXCEPTION_IF_NULL(inputs_); | |||
| auto it = node_output_edges_.find(old_anf_node); | |||
| if (it != node_output_edges_.end()) { | |||
| @@ -604,16 +602,16 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf | |||
| continue; | |||
| } | |||
| for (size_t i = 1; i < output_node_inputs.size(); i++) { | |||
| if (output_node_inputs[i] == old_anf_node) { | |||
| if (output_node_inputs[i] == old_anf_node.get()) { | |||
| output_cnode->set_input(i, new_anf_node); | |||
| } | |||
| } | |||
| // update graph inputs | |||
| for (size_t i = 0; i < inputs_->size(); i++) { | |||
| if ((*inputs_)[i] == old_anf_node) { | |||
| if ((*inputs_)[i] == old_anf_node.get()) { | |||
| MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString() | |||
| << ",new graph input:" << new_anf_node->DebugString(); | |||
| (*inputs_)[i] = new_anf_node; | |||
| (*inputs_)[i] = new_anf_node.get(); | |||
| break; | |||
| } | |||
| } | |||
| @@ -621,7 +619,7 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf | |||
| // update front to backend map | |||
| FrontBackendlMapUpdate(old_anf_node, new_anf_node); | |||
| // update output depend relations | |||
| node_output_edges_[new_anf_node] = it->second; | |||
| node_output_edges_[new_anf_node.get()] = it->second; | |||
| (void)node_output_edges_.erase(old_anf_node); | |||
| } | |||
| // update graph inputs in child graph | |||
| @@ -633,7 +631,7 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf | |||
| MS_LOG(WARNING) << new_anf_node->DebugString() << " already exist in real inputs, will be rewrited."; | |||
| iter->second = it_real_inputs->second; | |||
| } else { | |||
| real_inputs_[new_anf_node] = it_real_inputs->second; | |||
| real_inputs_[new_anf_node.get()] = it_real_inputs->second; | |||
| } | |||
| // erase old parameter in map | |||
| real_inputs_.erase(old_anf_node); | |||
| @@ -697,7 +695,6 @@ std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr ¶meter) { | |||
| void KernelGraph::UpdateCallRealInput() { | |||
| MS_LOG(INFO) << "Update graph id: " << graph_id_; | |||
| std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_map; | |||
| std::vector<std::pair<AnfNodePtr, AnfNodePtr>> replace_list; | |||
| for (auto &it : real_inputs_) { | |||
| auto parameter = it.first; | |||
| MS_EXCEPTION_IF_NULL(parameter); | |||
| @@ -722,16 +719,9 @@ void KernelGraph::UpdateCallRealInput() { | |||
| MS_LOG(INFO) << "paramter: " << parameter->DebugString() | |||
| << " insert real input:" << new_real_input->DebugString(); | |||
| (void)real_inputs.insert(new_real_input); | |||
| if (new_real_input->isa<Parameter>()) { | |||
| replace_list.emplace_back(parameter, new_real_input); | |||
| parameter = new_real_input; | |||
| } | |||
| } | |||
| real_inputs_map[parameter] = real_inputs; | |||
| } | |||
| for (auto [parameter, arg] : replace_list) { | |||
| ReplaceNode(parameter, arg); | |||
| } | |||
| real_inputs_ = real_inputs_map; | |||
| } | |||
| @@ -99,7 +99,7 @@ class KernelGraph : public FuncGraph { | |||
| std::vector<bool> *MutableValidInputs() { return &valid_inputs_; } | |||
| std::vector<bool> valid_inputs() const { return valid_inputs_; } | |||
| // replace node in graph | |||
| void ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf_node); | |||
| void ReplaceNode(NotNull<AnfNodePtr> old_anf_node, NotNull<AnfNodePtr> new_anf_node); | |||
| // set stream label of graph | |||
| void set_stream_distinction_label(uint32_t stream_label) { stream_distinction_label_ = stream_label; } | |||
| // get stream label of graph | |||
| @@ -459,6 +459,8 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) | |||
| if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) { | |||
| cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(anf)); | |||
| continue; | |||
| } else if (IsValueNode<FuncGraph>(anf)) { | |||
| continue; | |||
| } | |||
| MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]"; | |||
| } | |||
| @@ -613,6 +615,7 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphP | |||
| if (ExistSummaryNode(graph.get())) { | |||
| graph->set_summary_node_exist(true); | |||
| } | |||
| opt::BackendCommonOptimization(graph); | |||
| return graph; | |||
| } | |||
| @@ -626,7 +629,7 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> ¶ | |||
| auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter); | |||
| if (backend_parameter == nullptr) { | |||
| // for example "def f(x,y,z) {return x + y}", parameter z in unused | |||
| CreateNewParameterFromParameter(parameter, false, graph); | |||
| CreateNewParameterFromParameter(parameter, true, graph); | |||
| MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString(); | |||
| continue; | |||
| } | |||
| @@ -0,0 +1,85 @@ | |||
| /** | |||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||
| * | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_ | |||
| #define MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_ | |||
| #include <map> | |||
| #include <set> | |||
| namespace mindspore { | |||
| template <class T> | |||
| class UnionFindSet { | |||
| public: | |||
| UnionFindSet() : union_find_set_() {} | |||
| void Add(const T &elem) { | |||
| if (union_find_set_.find(elem) != union_find_set_.end()) { | |||
| return; | |||
| } | |||
| union_find_set_[elem] = elem; | |||
| } | |||
| T Find(const T &key) { | |||
| T key_parent = key; | |||
| auto iter = union_find_set_.find(key_parent); | |||
| if (iter == union_find_set_.end()) { | |||
| MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << key_parent; | |||
| } | |||
| while (key_parent != iter->second) { | |||
| key_parent = iter->second; | |||
| iter = union_find_set_.find(key_parent); | |||
| if (iter == union_find_set_.end()) { | |||
| MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << key_parent; | |||
| } | |||
| } | |||
| T tmp = key; | |||
| T tmp_parent; | |||
| while (tmp != key_parent) { | |||
| iter = union_find_set_.find(tmp); | |||
| if (iter == union_find_set_.end()) { | |||
| MS_LOG(EXCEPTION) << "union_find_set_ cannot find key " << tmp; | |||
| } | |||
| tmp_parent = iter->second; | |||
| union_find_set_[tmp] = key_parent; | |||
| tmp = tmp_parent; | |||
| } | |||
| return key_parent; | |||
| } | |||
| void Union(const T &left, const T &right) { union_find_set_[Find(left)] = Find(right); } | |||
| std::map<T, std::set<T>> GetSets() { | |||
| std::map<T, std::set<T>> result; | |||
| for (auto &iter : union_find_set_) { | |||
| (void)Find(iter.first); | |||
| } | |||
| for (auto &iter : union_find_set_) { | |||
| T parent = Find(iter.first); | |||
| result[parent].insert(iter.first); | |||
| } | |||
| return result; | |||
| } | |||
| private: | |||
| std::map<T, T> union_find_set_; | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_UTILS_UNION_FIND_SET_H_ | |||