| @@ -21,7 +21,6 @@ | |||
| #include "frontend/optimizer/irpass/convert.h" | |||
| #include "frontend/optimizer/irpass/environ_eliminate.h" | |||
| #include "frontend/optimizer/irpass/grad_var_prepare.h" | |||
| #include "frontend/optimizer/irpass/gradient_eliminate.h" | |||
| #include "frontend/optimizer/irpass/inline.h" | |||
| #include "frontend/optimizer/irpass/updatestate_eliminate.h" | |||
| #include "frontend/optimizer/irpass/load_eliminate.h" | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -36,20 +36,6 @@ AnfNodePtr ExpandJPrimitive(const ValueNodePtr &vnode, const pipeline::ResourceB | |||
| return nullptr; | |||
| } | |||
| bool CheckIfEmbedJ(const CNodePtr &j_node) { | |||
| auto &value_node = j_node->input(1); | |||
| if (IsValueNode<Primitive>(value_node)) { | |||
| return false; | |||
| } | |||
| auto func_graph = GetValueNode<FuncGraphPtr>(value_node); | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Unexpected J node, input func graph should not be null, node: " << j_node->DebugString(); | |||
| } | |||
| auto func_graph_manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(func_graph_manager); | |||
| return func_graph_manager->func_graph_j_total(func_graph); | |||
| } | |||
| bool IsSideEffectOp(const AnfNodePtr &node) { | |||
| if (!node->isa<CNode>()) { | |||
| return false; | |||
| @@ -78,25 +64,17 @@ AnfNodePtr ExpandJ(const ValueNodePtr &vnode, const OptimizerPtr &optimizer) { | |||
| } // namespace internal | |||
| bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) { | |||
| // Search all j nodes. | |||
| GetJPrim(func_graph); | |||
| // Get j nodes that don't have embed j nodes. | |||
| std::vector<CNodePtr> todo; | |||
| // If graph also contains J(FuncGraph) or J(Primitive), then ignore this graph. | |||
| // ExpandJ innermost graph or primitive first. | |||
| std::copy_if(j_nodes_.begin(), j_nodes_.end(), std::back_inserter(todo), | |||
| [](const CNodePtr &j_node) { return !internal::CheckIfEmbedJ(j_node); }); | |||
| // Check whether need to eliminate forward cnodes in pynative mode. | |||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) { | |||
| auto pynative_exec = pynative::PynativeExecutor::GetInstance(); | |||
| auto grad_exec = pynative_exec->grad_executor(); | |||
| bool eliminate_forward = grad_exec->eliminate_forward(); | |||
| grad_exec->set_eliminate_forward(eliminate_forward && todo.empty()); | |||
| grad_exec->set_eliminate_forward(eliminate_forward && prim_nodes_.empty()); | |||
| } | |||
| // Expand j nodes that don't have embed j nodes. | |||
| bool change = false; | |||
| auto manager = optimizer->manager(); | |||
| for (auto &j_node : todo) { | |||
| for (auto &j_node : prim_nodes_) { | |||
| auto expanded_j = internal::ExpandJ(j_node->input(1)->cast<ValueNodePtr>(), optimizer); | |||
| manager->Replace(j_node, expanded_j); | |||
| if (j_node->func_graph()->has_flag(FUNC_GRAPH_FLAG_K_GRAPH)) { | |||
| @@ -107,18 +85,6 @@ bool ExpandJPrim::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr | |||
| } | |||
| return change; | |||
| } | |||
| void ExpandJPrim::GetJPrim(const FuncGraphPtr &func_graph) { | |||
| j_nodes_.clear(); | |||
| AnfNodePtr ret = func_graph->get_return(); | |||
| MS_EXCEPTION_IF_NULL(ret); | |||
| std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret); | |||
| for (auto &node : all_nodes) { | |||
| if (IsPrimitiveCNode(node, prim::kPrimJ)) { | |||
| j_nodes_.push_back(node->cast<CNodePtr>()); | |||
| } | |||
| } | |||
| } | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -27,21 +27,19 @@ | |||
| #include "utils/ms_utils.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "frontend/optimizer/ad/grad.h" | |||
| #include "frontend/optimizer/irpass/meta_fg_prim_eliminate.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| // {prim::kPrimJ, C} | |||
| class ExpandJPrim { | |||
| class ExpandJPrim : public ExpandMetaFGPrim { | |||
| public: | |||
| ExpandJPrim() = default; | |||
| ExpandJPrim() { prim_ = prim::kPrimJ; } | |||
| virtual ~ExpandJPrim() = default; | |||
| bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer); | |||
| void GetJPrim(const FuncGraphPtr &func_graph); | |||
| private: | |||
| std::vector<CNodePtr> j_nodes_; | |||
| bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) override; | |||
| }; | |||
| using ExpandJPrimPtr = std::shared_ptr<ExpandJPrim>; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,43 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "frontend/optimizer/irpass/meta_fg_eliminate.h" | |||
| #include "frontend/optimizer/irpass/gradient_eliminate.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| bool ExpandMetaFg::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) { | |||
| AnfNodePtr return_node = func_graph->get_return(); | |||
| MS_EXCEPTION_IF_NULL(return_node); | |||
| std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(return_node); | |||
| // The expanding of meta fg may change the number of outer layer meta fgs. | |||
| // So, find all kinds of candidate meta fgs together and then expands them. | |||
| for (auto expand_meta_fg_element : expand_meta_fg_list_) { | |||
| expand_meta_fg_element->GetMetaFGPrim(all_nodes); | |||
| } | |||
| bool ret = false; | |||
| for (auto expand_meta_fg_element : expand_meta_fg_list_) { | |||
| auto prim_nodes = expand_meta_fg_element->prim_nodes(); | |||
| if (prim_nodes.size() != 0) { | |||
| ret = ret || (*expand_meta_fg_element)(func_graph, optimizer); | |||
| } | |||
| } | |||
| return ret; | |||
| } | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_ELIMINATE_H_ | |||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_ELIMINATE_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "base/core_ops.h" | |||
| #include "frontend/optimizer/irpass/gradient_eliminate.h" | |||
| #include "frontend/optimizer/irpass/meta_fg_prim_eliminate.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| class ExpandMetaFg { | |||
| public: | |||
| ExpandMetaFg() { (void)expand_meta_fg_list_.emplace_back(std::make_shared<ExpandJPrim>()); } | |||
| virtual ~ExpandMetaFg() = default; | |||
| bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer); | |||
| private: | |||
| std::vector<ExpandMetaFGPrimPtr> expand_meta_fg_list_; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_ELIMINATE_H_ | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "frontend/optimizer/irpass/meta_fg_prim_eliminate.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| bool ExpandMetaFGPrim::CheckIfEmbedMetaFGPrim(const CNodePtr &node) const { | |||
| auto &value_node = node->input(1); | |||
| if (IsValueNode<Primitive>(value_node)) { | |||
| return false; | |||
| } | |||
| auto func_graph = GetValueNode<FuncGraphPtr>(value_node); | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Unexpected meta function graph node:" << node->DebugString(); | |||
| } | |||
| auto func_graph_manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(func_graph_manager); | |||
| return func_graph_manager->func_graph_j_total(func_graph); | |||
| } | |||
| void ExpandMetaFGPrim::GetMetaFGPrim(const std::vector<AnfNodePtr> &all_nodes) { | |||
| MS_EXCEPTION_IF_NULL(prim_); | |||
| prim_nodes_.clear(); | |||
| for (auto &node : all_nodes) { | |||
| if (IsPrimitiveCNode(node, prim_) && !CheckIfEmbedMetaFGPrim(node->cast<CNodePtr>())) { | |||
| prim_nodes_.push_back(node->cast<CNodePtr>()); | |||
| } | |||
| } | |||
| } | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,52 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_PRIM_ELIMINATE_H_ | |||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_PRIM_ELIMINATE_H_ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include "frontend/optimizer/optimizer.h" | |||
| #include "frontend/optimizer/irpass.h" | |||
| #include "frontend/optimizer/anf_visitor.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "frontend/optimizer/ad/grad.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| // {prim::kPrimJ, C} | |||
| class ExpandMetaFGPrim { | |||
| public: | |||
| ExpandMetaFGPrim() = default; | |||
| virtual ~ExpandMetaFGPrim() = default; | |||
| bool CheckIfEmbedMetaFGPrim(const CNodePtr &node) const; | |||
| const std::vector<CNodePtr> &prim_nodes() const { return prim_nodes_; } | |||
| virtual bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) = 0; | |||
| void GetMetaFGPrim(const std::vector<AnfNodePtr> &all_nodes); | |||
| protected: | |||
| std::vector<CNodePtr> prim_nodes_; | |||
| PrimitivePtr prim_{nullptr}; | |||
| }; | |||
| using ExpandMetaFGPrimPtr = std::shared_ptr<ExpandMetaFGPrim>; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_META_FG_PRIM_ELIMINATE_H_ | |||
| @@ -48,7 +48,7 @@ | |||
| #include "pipeline/pynative/pynative_execute.h" | |||
| #include "pipeline/jit/static_analysis/auto_monad.h" | |||
| #include "frontend/optimizer/irpass/branch_culling.h" | |||
| #include "frontend/optimizer/irpass/gradient_eliminate.h" | |||
| #include "frontend/optimizer/irpass/meta_fg_eliminate.h" | |||
| #include "frontend/optimizer/irpass/parameter_eliminate.h" | |||
| #include "frontend/optimizer/irpass/updatestate_eliminate.h" | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| @@ -232,7 +232,7 @@ bool parallel_mode() { | |||
| void AddParallelRenormalize(OptPassGroupMap *map_a) { | |||
| if (parallel_mode()) { | |||
| auto parallel_end_opt = | |||
| find_if(map_a->begin(), map_a->end(), [](auto opt_pair) { return opt_pair.first == "grad"; }); | |||
| find_if(map_a->begin(), map_a->end(), [](auto opt_pair) { return opt_pair.first == "meta_fg_expand"; }); | |||
| if (parallel_end_opt != map_a->end()) { | |||
| (void)map_a->insert(parallel_end_opt, {"parallel_renormalize", opt::OptPassConfig::Renormalize()}); | |||
| } | |||
| @@ -357,7 +357,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)}, | |||
| {"virtual_dataset", virtual_dataset}, | |||
| {"virtual_output", opt::OptPassConfig({irpass.virtual_output_eliminate_})}, | |||
| {"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())}, | |||
| {"meta_fg_expand", opt::OptPassConfig(opt::irpass::ExpandMetaFg())}, | |||
| {"after_resolve", after_resolve_pass}, | |||
| {"a_after_grad", a_after_grad}, | |||
| {"renormalize", opt::OptPassConfig::Renormalize()}, | |||
| @@ -159,6 +159,9 @@ constexpr auto kCSRReduceSum = "CSRReduceSum"; | |||
| constexpr auto kCSRMV = "CSRMV"; | |||
| constexpr auto kCSRMul = "CSRMul"; | |||
| // Meta Function Graph | |||
| constexpr auto kJ = "J"; | |||
| // Others | |||
| constexpr auto kMakeTuple = "MakeTuple"; | |||
| constexpr auto kAssign = "Assign"; | |||
| @@ -822,7 +825,7 @@ MS_CORE_API inline const PrimitivePtr kPrimPyInterpret = std::make_shared<Primit | |||
| // Other primitive not used by backend but used in core; | |||
| MS_CORE_API inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem"); | |||
| MS_CORE_API inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J", kSideEffectPropagate); | |||
| MS_CORE_API inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>(kJ, kSideEffectPropagate); | |||
| MS_CORE_API inline const PrimitivePtr kPrimShard = std::make_shared<Primitive>("Shard", kSideEffectPropagate); | |||
| // Used to build graph which have keyword arguments | |||