From: @zh_qh Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -31,7 +31,6 @@ | |||||
| #include "backend/kernel_compiler/kernel_build_info.h" | #include "backend/kernel_compiler/kernel_build_info.h" | ||||
| #include "common/trans.h" | #include "common/trans.h" | ||||
| #include "abstract/param_validator.h" | #include "abstract/param_validator.h" | ||||
| #include "abstract/primitive_infer_map.h" | |||||
| #include "pipeline/jit/static_analysis/static_analysis.h" | #include "pipeline/jit/static_analysis/static_analysis.h" | ||||
| #include "utils/trace_base.h" | #include "utils/trace_base.h" | ||||
| @@ -1806,14 +1805,7 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) { | |||||
| args_spec_list.emplace_back(real_input->abstract()); | args_spec_list.emplace_back(real_input->abstract()); | ||||
| } | } | ||||
| } | } | ||||
| auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap(); | |||||
| auto ret = prim_eval_implement_map.find(primitive); | |||||
| if (ret == prim_eval_implement_map.end()) { | |||||
| MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << primitive->name() | |||||
| << " primitive type:" << primitive->type_name(); | |||||
| } | |||||
| auto eval_result = ret->second.impl_(nullptr, primitive, args_spec_list); | |||||
| auto eval_result = abstract::CppInferShape(primitive, args_spec_list); | |||||
| node->set_abstract(eval_result); | node->set_abstract(eval_result); | ||||
| } | } | ||||
| } // namespace session | } // namespace session | ||||
| @@ -230,6 +230,8 @@ ResolveIRPassLib::ResolveIRPassLib() { | |||||
| {prim::kPrimGetAttr, prim::kPrimResolve}); | {prim::kPrimGetAttr, prim::kPrimResolve}); | ||||
| resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve); | resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve); | ||||
| resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetAttr>(), "resolver_getattr", prim::kPrimGetAttr); | resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetAttr>(), "resolver_getattr", prim::kPrimGetAttr); | ||||
| resolver_getattr_resolve_ = | |||||
| MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "resolver_getattr_resolve", prim::kPrimGetAttr); | |||||
| } | } | ||||
| InferenceOptPrepareLib::InferenceOptPrepareLib() { | InferenceOptPrepareLib::InferenceOptPrepareLib() { | ||||
| @@ -154,6 +154,7 @@ class ResolveIRPassLib { | |||||
| SubstitutionPtr resolver_resolve_and_getattr_; | SubstitutionPtr resolver_resolve_and_getattr_; | ||||
| SubstitutionPtr resolver_resolve_; | SubstitutionPtr resolver_resolve_; | ||||
| SubstitutionPtr resolver_getattr_; | SubstitutionPtr resolver_getattr_; | ||||
| SubstitutionPtr resolver_getattr_resolve_; | |||||
| }; | }; | ||||
| class InferenceOptPrepareLib { | class InferenceOptPrepareLib { | ||||
| @@ -71,7 +71,7 @@ AnfNodePtr EliminateUpdateStateOnlyUsedNode(const CNodePtr &update_state, const | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // Replace UpdateState with the input monad. | // Replace UpdateState with the input monad. | ||||
| return update_state->inputs().at(kInputIndex); | |||||
| return update_state->input(kInputIndex); | |||||
| } | } | ||||
| // Eliminate UpdateState that attaches a pure (no-side-effect) node. | // Eliminate UpdateState that attaches a pure (no-side-effect) node. | ||||
| @@ -100,7 +100,7 @@ AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const A | |||||
| } | } | ||||
| } | } | ||||
| // Remove UpdateState by replace it with its input monad. | // Remove UpdateState by replace it with its input monad. | ||||
| return update_state->inputs().at(kInputIndex); | |||||
| return update_state->input(kInputIndex); | |||||
| } | } | ||||
| // Eliminate redundant UpdateState/Depend pair nodes caused by inline. | // Eliminate redundant UpdateState/Depend pair nodes caused by inline. | ||||
| @@ -118,7 +118,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN | |||||
| // Skip if Depend attach input is not a monad. | // Skip if Depend attach input is not a monad. | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto update_monad = update_state->inputs().at(kInputIndex); | |||||
| auto update_monad = update_state->input(kInputIndex); | |||||
| if (!HasAbstractMonad(update_monad)) { | if (!HasAbstractMonad(update_monad)) { | ||||
| // Skip if UpdateState input is not a monad. | // Skip if UpdateState input is not a monad. | ||||
| MS_LOG(WARNING) << "Not a monad input: " << update_state->DebugString(); | MS_LOG(WARNING) << "Not a monad input: " << update_state->DebugString(); | ||||
| @@ -139,7 +139,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN | |||||
| } | } | ||||
| // Replace Depend with its input. | // Replace Depend with its input. | ||||
| if (depend->size() == kMinDependSize) { | if (depend->size() == kMinDependSize) { | ||||
| auto depend_input = depend->inputs().at(kInputIndex); | |||||
| auto depend_input = depend->input(kInputIndex); | |||||
| mgr->Replace(depend, depend_input); | mgr->Replace(depend, depend_input); | ||||
| } else { | } else { | ||||
| auto inputs = depend->inputs(); | auto inputs = depend->inputs(); | ||||
| @@ -163,7 +163,7 @@ AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CN | |||||
| if (make_tuple->size() != kMakeTupleSize) { | if (make_tuple->size() != kMakeTupleSize) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto &node = make_tuple->inputs().at(kAttachIndex); | |||||
| auto &node = make_tuple->input(kAttachIndex); | |||||
| auto node_abs = node->abstract(); | auto node_abs = node->abstract(); | ||||
| if (node_abs == nullptr || !node_abs->isa<abstract::AbstractError>()) { | if (node_abs == nullptr || !node_abs->isa<abstract::AbstractError>()) { | ||||
| return nullptr; | return nullptr; | ||||
| @@ -173,7 +173,7 @@ AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CN | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // Create a new UpdateState to replace the old one. | // Create a new UpdateState to replace the old one. | ||||
| const auto &attach = make_tuple->inputs().at(kInputIndex); | |||||
| const auto &attach = make_tuple->input(kInputIndex); | |||||
| auto new_update_state = fg->NewCNode({update_state->input(0), update_state->input(1), attach}); | auto new_update_state = fg->NewCNode({update_state->input(0), update_state->input(1), attach}); | ||||
| new_update_state->set_abstract(update_state->abstract()); | new_update_state->set_abstract(update_state->abstract()); | ||||
| new_update_state->set_scope(update_state->scope()); | new_update_state->set_scope(update_state->scope()); | ||||
| @@ -206,42 +206,47 @@ AnfNodePtr EliminateUpdateStateWithMakeTupleFunc(const CNodePtr &update_state, c | |||||
| if (make_tuple->size() != kMakeTupleSize) { | if (make_tuple->size() != kMakeTupleSize) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto &first_input = make_tuple->inputs().at(kInputIndex); | |||||
| auto &first_input = make_tuple->input(kInputIndex); | |||||
| if (IsValueNode<FuncGraph>(first_input) && OnlyMakeTupleUseFunc(make_tuple, first_input)) { | if (IsValueNode<FuncGraph>(first_input) && OnlyMakeTupleUseFunc(make_tuple, first_input)) { | ||||
| return update_state->input(1); | return update_state->input(1); | ||||
| } | } | ||||
| auto &second_input = make_tuple->inputs().at(kAttachIndex); | |||||
| auto &second_input = make_tuple->input(kAttachIndex); | |||||
| if (IsValueNode<FuncGraph>(second_input) && OnlyMakeTupleUseFunc(make_tuple, second_input)) { | if (IsValueNode<FuncGraph>(second_input) && OnlyMakeTupleUseFunc(make_tuple, second_input)) { | ||||
| return update_state->input(1); | return update_state->input(1); | ||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *loads); | |||||
| size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *loads); | |||||
| size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads); | |||||
| size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, | |||||
| std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads); | |||||
| // Search consecutive load nodes from UpdateState node. | // Search consecutive load nodes from UpdateState node. | ||||
| size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *loads) { | |||||
| auto &attach = update_state->inputs().at(kAttachIndex); | |||||
| size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *update_states, | |||||
| std::vector<CNodePtr> *loads) { | |||||
| auto &attach = update_state->input(kAttachIndex); | |||||
| if (IsPrimitiveCNode(attach, prim::kPrimLoad)) { | if (IsPrimitiveCNode(attach, prim::kPrimLoad)) { | ||||
| return GetLoadsFollowLoad(attach->cast<CNodePtr>(), loads); | |||||
| update_states->emplace_back(update_state); | |||||
| return GetLoadsFollowLoad(attach->cast<CNodePtr>(), update_states, loads); | |||||
| } | } | ||||
| if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) { | if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) { | ||||
| return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), loads); | |||||
| update_states->emplace_back(update_state); | |||||
| return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), update_states, loads); | |||||
| } | } | ||||
| return 0; | return 0; | ||||
| } | } | ||||
| size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *loads) { | |||||
| loads->push_back(load); | |||||
| auto &load_attach = load->inputs().at(kAttachIndex); | |||||
| size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) { | |||||
| loads->emplace_back(load); | |||||
| auto &load_attach = load->input(kAttachIndex); | |||||
| if (IsPrimitiveCNode(load_attach, prim::kPrimUpdateState)) { | if (IsPrimitiveCNode(load_attach, prim::kPrimUpdateState)) { | ||||
| return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), loads) + 1; | |||||
| return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), update_states, loads) + 1; | |||||
| } | } | ||||
| return 1; | return 1; | ||||
| } | } | ||||
| size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *loads) { | |||||
| size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, | |||||
| std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) { | |||||
| if (!OnlyUpdateStateUse(update_state, make_tuple)) { | if (!OnlyUpdateStateUse(update_state, make_tuple)) { | ||||
| // UpdateState should be the only user of | // UpdateState should be the only user of | ||||
| return 0; | return 0; | ||||
| @@ -256,12 +261,12 @@ size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tu | |||||
| // Add load nodes from tuple elements. | // Add load nodes from tuple elements. | ||||
| for (size_t i = 1; i < inputs.size(); ++i) { | for (size_t i = 1; i < inputs.size(); ++i) { | ||||
| auto &element = inputs.at(i); | auto &element = inputs.at(i); | ||||
| loads->push_back(element->cast<CNodePtr>()); | |||||
| loads->emplace_back(element->cast<CNodePtr>()); | |||||
| } | } | ||||
| // Follow prev update state if found. | // Follow prev update state if found. | ||||
| auto prev_node = update_state->inputs().at(kInputIndex); | |||||
| auto prev_node = update_state->input(kInputIndex); | |||||
| if (IsPrimitiveCNode(prev_node, prim::kPrimUpdateState)) { | if (IsPrimitiveCNode(prev_node, prim::kPrimUpdateState)) { | ||||
| return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), loads) + 1; | |||||
| return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), update_states, loads) + 1; | |||||
| } | } | ||||
| return 1; | return 1; | ||||
| } | } | ||||
| @@ -301,7 +306,8 @@ AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_upd | |||||
| // xN = Load(xN, u) | // xN = Load(xN, u) | ||||
| // t = make_tuple(x1, x2, ... , xN) | // t = make_tuple(x1, x2, ... , xN) | ||||
| // u1 = UpdateState(u, t) | // u1 = UpdateState(u, t) | ||||
| AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &loads) { | |||||
| AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &update_states, | |||||
| const std::vector<CNodePtr> &loads) { | |||||
| auto fg = old_update_state->func_graph(); | auto fg = old_update_state->func_graph(); | ||||
| if (fg == nullptr) { | if (fg == nullptr) { | ||||
| return nullptr; | return nullptr; | ||||
| @@ -315,20 +321,24 @@ AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const | |||||
| std::set<AnfNodePtr> loaded_para_set; | std::set<AnfNodePtr> loaded_para_set; | ||||
| make_tuple_inputs.reserve(loads.size() + 1); | make_tuple_inputs.reserve(loads.size() + 1); | ||||
| make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); | make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); | ||||
| auto input_monad = loads.back()->inputs().at(kAttachIndex); | |||||
| auto input_monad = loads.back()->input(kAttachIndex); | |||||
| for (auto iter = loads.rbegin(); iter != loads.rend(); ++iter) { | for (auto iter = loads.rbegin(); iter != loads.rend(); ++iter) { | ||||
| auto &load = *iter; | auto &load = *iter; | ||||
| auto result = loaded_para_set.emplace(load->inputs().at(kInputIndex)); | |||||
| auto result = loaded_para_set.emplace(load->input(kInputIndex)); | |||||
| const bool is_new_load = result.second; | const bool is_new_load = result.second; | ||||
| if (is_new_load) { | if (is_new_load) { | ||||
| // Put Load node as a tuple element, if the parameter is not loaded by other Load. | // Put Load node as a tuple element, if the parameter is not loaded by other Load. | ||||
| make_tuple_inputs.emplace_back(load); | make_tuple_inputs.emplace_back(load); | ||||
| } | } | ||||
| if (load->inputs().at(kAttachIndex) != input_monad) { | |||||
| if (load->input(kAttachIndex) != input_monad) { | |||||
| // Set all load use same input monad. | // Set all load use same input monad. | ||||
| mgr->SetEdge(load, kAttachIndex, input_monad); | mgr->SetEdge(load, kAttachIndex, input_monad); | ||||
| } | } | ||||
| } | } | ||||
| for (auto i = update_states.size() - 1; i > 0; i--) { | |||||
| auto &us = update_states[i]; | |||||
| mgr->Replace(us, us->input(kInputIndex)); | |||||
| } | |||||
| if (make_tuple_inputs.size() == 1) { | if (make_tuple_inputs.size() == 1) { | ||||
| // This should not happen. | // This should not happen. | ||||
| MS_LOG(WARNING) << "No loads for " << old_update_state->DebugString(2); | MS_LOG(WARNING) << "No loads for " << old_update_state->DebugString(2); | ||||
| @@ -538,7 +548,7 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode | |||||
| MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString(); | MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString(); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto &attach = update_state_node->inputs().at(kAttachIndex); | |||||
| auto &attach = update_state_node->input(kAttachIndex); | |||||
| if (IsPrimitiveCNode(attach, prim::kPrimDepend)) { | if (IsPrimitiveCNode(attach, prim::kPrimDepend)) { | ||||
| return EliminateUpdateStateWithDepend(update_state_node, attach->cast<CNodePtr>()); | return EliminateUpdateStateWithDepend(update_state_node, attach->cast<CNodePtr>()); | ||||
| } | } | ||||
| @@ -586,9 +596,10 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode | |||||
| return new_node; | return new_node; | ||||
| } | } | ||||
| } | } | ||||
| std::vector<CNodePtr> update_states; | |||||
| std::vector<CNodePtr> loads; | std::vector<CNodePtr> loads; | ||||
| if (GetLoadsFromUpdateState(update_state_node, &loads) > 1 && loads.size() > 1) { | |||||
| return EliminateUpdateStateForLoads(update_state_node, loads); | |||||
| if (GetLoadsFromUpdateState(update_state_node, &update_states, &loads) > 1 && loads.size() > 1) { | |||||
| return EliminateUpdateStateForLoads(update_state_node, update_states, loads); | |||||
| } | } | ||||
| // Eliminate UpdateStates that attaches a no-side-effect node. | // Eliminate UpdateStates that attaches a no-side-effect node. | ||||
| if (!attach_is_load && !attach_is_tuple) { | if (!attach_is_load && !attach_is_tuple) { | ||||
| @@ -103,8 +103,62 @@ static bool isTraversable(const AnfNodePtr &node) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, | |||||
| const SubstitutionPtr &transform) const { | |||||
| static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node, | |||||
| const SubstitutionPtr &substitution) { | |||||
| auto manager = optimizer->manager(); | |||||
| bool is_match = substitution->predicate_(node); | |||||
| if (is_match) { | |||||
| TraceGuard trace_guard(std::make_shared<TraceOpt>(node->debug_info())); | |||||
| auto res = (*substitution)(optimizer, node); | |||||
| if (res != nullptr && res != node) { | |||||
| #ifdef ENABLE_PROFILE | |||||
| double t = GetTime(); | |||||
| #endif | |||||
| MS_LOG(DEBUG) << "Replace " << node->DebugString() << " with " << res->DebugString() << ", by " | |||||
| << substitution->name_; | |||||
| (void)manager->Replace(node, res); | |||||
| #ifdef ENABLE_PROFILE | |||||
| MsProfile::StatTime("replace." + substitution->name_, GetTime() - t); | |||||
| #endif | |||||
| return res; | |||||
| } | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| static inline void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node, | |||||
| std::deque<AnfNodePtr> *todo, bool change, size_t seen) { | |||||
| if (IsValueNode<FuncGraph>(node)) { | |||||
| (*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output()); | |||||
| } | |||||
| if (node->isa<CNode>()) { | |||||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||||
| (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo)); | |||||
| } | |||||
| if (!change) { | |||||
| return; | |||||
| } | |||||
| auto manager = optimizer->manager(); | |||||
| auto &node_users = manager->node_users(); | |||||
| auto users_iterator = node_users.find(node); | |||||
| if (users_iterator == node_users.end()) { | |||||
| return; | |||||
| } | |||||
| auto users = users_iterator->second; | |||||
| for (auto &use : users) { | |||||
| auto use_node = use.first; | |||||
| if (use_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| (*todo).emplace_back(use_node); | |||||
| if (use_node->seen_ == seen) { | |||||
| use_node->seen_--; | |||||
| } | |||||
| } | |||||
| } | |||||
| bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const { | |||||
| #ifdef ENABLE_PROFILE | #ifdef ENABLE_PROFILE | ||||
| double start = GetTime(); | double start = GetTime(); | ||||
| #endif | #endif | ||||
| @@ -113,7 +167,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo | |||||
| // 1024 is for the initial capacity of deque | // 1024 is for the initial capacity of deque | ||||
| std::deque<AnfNodePtr> todo(1024); | std::deque<AnfNodePtr> todo(1024); | ||||
| todo.clear(); | todo.clear(); | ||||
| todo.push_back(root_node); | |||||
| todo.emplace_back(func_graph->output()); | |||||
| bool changes = false; | bool changes = false; | ||||
| auto &all_nodes = manager->all_nodes(); | auto &all_nodes = manager->all_nodes(); | ||||
| @@ -121,59 +175,61 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo | |||||
| AnfNodePtr node = todo.front(); | AnfNodePtr node = todo.front(); | ||||
| todo.pop_front(); | todo.pop_front(); | ||||
| // check whether this node has been matched. | |||||
| if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) { | if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| node->seen_ = seen; | node->seen_ = seen; | ||||
| // select nodes that this transform can be applied. | |||||
| bool is_match = transform->predicate_(node); | |||||
| // apply transform on this node | |||||
| bool change = false; | bool change = false; | ||||
| if (is_match) { | |||||
| TraceGuard trace_guard(std::make_shared<TraceOpt>(node->debug_info())); | |||||
| auto ret = (*transform)(optimizer, node); | |||||
| if (ret != nullptr && ret != node) { | |||||
| for (auto &substitution : list_) { | |||||
| auto res = DoTransform(optimizer, node, substitution); | |||||
| if (res != nullptr) { | |||||
| change = true; | change = true; | ||||
| changes = true; | changes = true; | ||||
| node = res; | |||||
| todo.emplace_back(res); | |||||
| break; | |||||
| } | |||||
| } | |||||
| UpdateTransformingList(optimizer, node, &todo, change, seen); | |||||
| } | |||||
| #ifdef ENABLE_PROFILE | #ifdef ENABLE_PROFILE | ||||
| double t = GetTime(); | |||||
| MsProfile::StatTime("opt.transforms." + optimizer->name(), GetTime() - start); | |||||
| #endif | #endif | ||||
| MS_LOG(DEBUG) << "transform: " << transform->name_ << " will replace: " << node->DebugString() | |||||
| << " with: " << ret->DebugString(); | |||||
| (void)manager->Replace(node, ret); | |||||
| return changes; | |||||
| } | |||||
| bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, | |||||
| const SubstitutionPtr &substitution) const { | |||||
| #ifdef ENABLE_PROFILE | #ifdef ENABLE_PROFILE | ||||
| MsProfile::StatTime("replace." + transform->name_, GetTime() - t); | |||||
| double start = GetTime(); | |||||
| #endif | #endif | ||||
| node = ret; | |||||
| } | |||||
| } | |||||
| FuncGraphManagerPtr manager = optimizer->manager(); | |||||
| auto seen = NewSeenGeneration(); | |||||
| // 1024 is for the initial capacity of deque | |||||
| std::deque<AnfNodePtr> todo(1024); | |||||
| todo.clear(); | |||||
| todo.emplace_back(root_node); | |||||
| bool changes = false; | |||||
| // find success, and add them to todo list | |||||
| if (IsValueNode<FuncGraph>(node)) { | |||||
| todo.push_back(GetValueNode<FuncGraphPtr>(node)->output()); | |||||
| } | |||||
| auto &all_nodes = manager->all_nodes(); | |||||
| while (!todo.empty()) { | |||||
| AnfNodePtr node = todo.front(); | |||||
| todo.pop_front(); | |||||
| if (node->isa<CNode>()) { | |||||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||||
| (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); | |||||
| if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) { | |||||
| continue; | |||||
| } | } | ||||
| node->seen_ = seen; | |||||
| auto &node_users = manager->node_users(); | |||||
| if (change && node_users.find(node) != node_users.end()) { | |||||
| for (auto &use : node_users[node]) { | |||||
| auto use_node = use.first; | |||||
| if (use_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| todo.push_back(use_node); | |||||
| if (use_node->seen_ == seen) { | |||||
| use_node->seen_--; | |||||
| } | |||||
| } | |||||
| bool change = false; | |||||
| auto res = DoTransform(optimizer, node, substitution); | |||||
| if (res != nullptr) { | |||||
| change = true; | |||||
| changes = true; | |||||
| node = res; | |||||
| } | } | ||||
| UpdateTransformingList(optimizer, node, &todo, change, seen); | |||||
| } | } | ||||
| #ifdef ENABLE_PROFILE | #ifdef ENABLE_PROFILE | ||||
| @@ -182,13 +238,29 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo | |||||
| return changes; | return changes; | ||||
| } | } | ||||
| bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { | |||||
| MS_EXCEPTION_IF_NULL(optimizer); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| FuncGraphManagerPtr manager = optimizer->manager(); | |||||
| manager->AddFuncGraph(func_graph); | |||||
| bool SubstitutionList::ApplySubstitutionsToIRForIsolate(const OptimizerPtr &optimizer) const { | |||||
| const auto &manager = optimizer->manager(); | |||||
| const auto &nodes = manager->isolate_nodes(); | |||||
| bool changes = false; | |||||
| bool loop = true; | |||||
| while (loop) { | |||||
| loop = false; | |||||
| std::for_each(list_.cbegin(), list_.cend(), [&](const auto &substitution) { | |||||
| std::for_each(nodes.cbegin(), nodes.cend(), [&](const auto &node) { | |||||
| bool change = ApplySubstitutionToIR(optimizer, node, substitution); | |||||
| changes = changes || change; | |||||
| loop = loop || change; | |||||
| }); | |||||
| }); | |||||
| if (is_once_) { | |||||
| break; | |||||
| } | |||||
| } | |||||
| return changes; | |||||
| } | |||||
| // for transform status counting | |||||
| bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const { | |||||
| // Add for substitution status counting | |||||
| size_t space = 0; | size_t space = 0; | ||||
| std::unordered_map<std::string, std::vector<bool>> status; | std::unordered_map<std::string, std::vector<bool>> status; | ||||
| if (optimizer->is_on_debug_) { | if (optimizer->is_on_debug_) { | ||||
| @@ -197,47 +269,39 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize | |||||
| } | } | ||||
| } | } | ||||
| bool loop = false; | |||||
| bool changes = false; | bool changes = false; | ||||
| do { | |||||
| bool loop = true; | |||||
| while (loop) { | |||||
| loop = false; | loop = false; | ||||
| for (size_t i = 0; i < list_.size(); i++) { | for (size_t i = 0; i < list_.size(); i++) { | ||||
| auto change = ApplyTransform(optimizer, func_graph->output(), list_[i]); | |||||
| const auto &substitution = list_[i]; | |||||
| bool change = ApplySubstitutionToIR(optimizer, func_graph->output(), substitution); | |||||
| changes = changes || change; | changes = changes || change; | ||||
| loop = loop || change; | loop = loop || change; | ||||
| // apply transform on isolate nodes. | |||||
| auto &isolate_nodes = manager->isolate_nodes(); | |||||
| for (auto &node : isolate_nodes) { | |||||
| change = ApplyTransform(optimizer, node, list_[i]); | |||||
| changes = changes || change; | |||||
| loop = loop || change; | |||||
| } | |||||
| // record the status of each transform | |||||
| static const auto enable_dump_pass_ir = (common::GetEnv("ENV_DUMP_PASS_IR") == "1"); | static const auto enable_dump_pass_ir = (common::GetEnv("ENV_DUMP_PASS_IR") == "1"); | ||||
| if (enable_dump_pass_ir && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { | if (enable_dump_pass_ir && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { | ||||
| auto fg_name = optimizer->name() + "_" + std::to_string(optimizer->CurPass_.counter) + "_" + | |||||
| optimizer->CurPass_.name + "_" + list_[i]->name_; | |||||
| auto fg_name = optimizer->name() + "_r" + std::to_string(optimizer->CurPass_.counter) + "_" + | |||||
| optimizer->CurPass_.name + "_" + substitution->name_; | |||||
| DumpIR(fg_name + ".ir", func_graph); | DumpIR(fg_name + ".ir", func_graph); | ||||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | ||||
| func_graph->DumpFuncGraph(fg_name); | func_graph->DumpFuncGraph(fg_name); | ||||
| ExportIR(fg_name + ".dat", "", func_graph); | ExportIR(fg_name + ".dat", "", func_graph); | ||||
| } | } | ||||
| } | } | ||||
| // Record the status of each substitution | |||||
| if (optimizer->is_on_debug_) { | if (optimizer->is_on_debug_) { | ||||
| status[list_[i]->name_ + std::to_string(i)].push_back(change); | |||||
| space = std::max(list_[i]->name_.size(), space); | |||||
| status[substitution->name_ + std::to_string(i)].push_back(change); | |||||
| space = std::max(substitution->name_.size(), space); | |||||
| } | } | ||||
| } | } | ||||
| if (is_once_) { | if (is_once_) { | ||||
| break; | break; | ||||
| } | } | ||||
| } while (loop); | |||||
| } | |||||
| // display the status of each transform | |||||
| // Display the status of each substitution | |||||
| if (optimizer->is_on_debug_) { | if (optimizer->is_on_debug_) { | ||||
| std::stringstream ss; | std::stringstream ss; | ||||
| ss << std::endl | ss << std::endl | ||||
| @@ -253,7 +317,37 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize | |||||
| } | } | ||||
| MS_LOG(DEBUG) << ss.str(); | MS_LOG(DEBUG) << ss.str(); | ||||
| } | } | ||||
| return changes; | |||||
| } | |||||
| bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { | |||||
| MS_EXCEPTION_IF_NULL(optimizer); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| FuncGraphManagerPtr manager = optimizer->manager(); | |||||
| manager->AddFuncGraph(func_graph); | |||||
| bool changes = false; | |||||
| static const auto traverse_mode = | |||||
| (common::GetEnv("ENV_TRAVERSE_SUBSTITUTIONS_MODE") != "1" ? kOptTraverseFromIRToSubstitutions | |||||
| : kOptTraverseFromSubstitutionsToIR); | |||||
| if (traverse_mode == kOptTraverseFromIRToSubstitutions && | |||||
| MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && | |||||
| optimizer->traverse_nodes_first()) { | |||||
| changes = ApplyIRToSubstitutions(optimizer, func_graph); | |||||
| } else { | |||||
| changes = ApplySubstitutionsToIR(optimizer, func_graph); | |||||
| } | |||||
| bool has_isolate = !manager->isolate_nodes().empty(); | |||||
| if (has_isolate) { | |||||
| #ifdef ENABLE_PROFILE | |||||
| double t = GetTime(); | |||||
| #endif | |||||
| bool change = ApplySubstitutionsToIRForIsolate(optimizer); | |||||
| changes = changes || change; | |||||
| #ifdef ENABLE_PROFILE | |||||
| MsProfile::StatTime("opt.isolate.transform." + optimizer->name(), GetTime() - t); | |||||
| #endif | |||||
| } | |||||
| return changes; | return changes; | ||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -59,6 +59,8 @@ SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std: | |||||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, | SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, | ||||
| const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM); | const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM); | ||||
| enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptTraverseFromSubstitutionsToIR }; | |||||
| class SubstitutionList { | class SubstitutionList { | ||||
| public: | public: | ||||
| explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false) | explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false) | ||||
| @@ -68,7 +70,10 @@ class SubstitutionList { | |||||
| bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const; | bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const; | ||||
| private: | private: | ||||
| bool ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &transform) const; | |||||
| bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const; | |||||
| bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &sub) const; | |||||
| bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const; | |||||
| bool ApplySubstitutionsToIRForIsolate(const OptimizerPtr &optimizer) const; | |||||
| std::vector<SubstitutionPtr> list_; | std::vector<SubstitutionPtr> list_; | ||||
| // a flag to mark this list of Substitution can only be executed only once | // a flag to mark this list of Substitution can only be executed only once | ||||
| bool is_once_; | bool is_once_; | ||||
| @@ -88,13 +88,14 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>; | |||||
| class Optimizer : public std::enable_shared_from_this<Optimizer> { | class Optimizer : public std::enable_shared_from_this<Optimizer> { | ||||
| public: | public: | ||||
| Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) | |||||
| Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr, bool traverse_nodes_first = true) | |||||
| : name_(name), | : name_(name), | ||||
| resource_(resource_ptr), | resource_(resource_ptr), | ||||
| run_only_once_(false), | run_only_once_(false), | ||||
| is_watch_renormalize_(false), | is_watch_renormalize_(false), | ||||
| is_enable_(true), | is_enable_(true), | ||||
| is_untyped_generated_(false) {} | |||||
| is_untyped_generated_(false), | |||||
| traverse_nodes_first_(traverse_nodes_first) {} | |||||
| virtual ~Optimizer() = default; | virtual ~Optimizer() = default; | ||||
| void Init(const OptPassGroupMap &passes, bool run_only_once) { | void Init(const OptPassGroupMap &passes, bool run_only_once) { | ||||
| @@ -129,8 +130,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||||
| static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr, | static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr, | ||||
| const OptPassGroupMap &passes, bool run_only_once = false, | const OptPassGroupMap &passes, bool run_only_once = false, | ||||
| bool watch_renormalize = false) { | |||||
| OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr); | |||||
| bool watch_renormalize = false, bool traverse_nodes_first = true) { | |||||
| OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr, traverse_nodes_first); | |||||
| optimizer->Init(passes, run_only_once); | optimizer->Init(passes, run_only_once); | ||||
| if (watch_renormalize) { | if (watch_renormalize) { | ||||
| optimizer->enable_watch_renormalize(); | optimizer->enable_watch_renormalize(); | ||||
| @@ -223,6 +224,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||||
| bool is_watch_renormalize() { return is_watch_renormalize_; } | bool is_watch_renormalize() { return is_watch_renormalize_; } | ||||
| void set_enable(bool enable) { is_enable_ = enable; } | void set_enable(bool enable) { is_enable_ = enable; } | ||||
| bool traverse_nodes_first() { return traverse_nodes_first_; } | |||||
| struct { | struct { | ||||
| int64_t counter; | int64_t counter; | ||||
| std::string name; | std::string name; | ||||
| @@ -239,6 +242,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||||
| bool is_watch_renormalize_; | bool is_watch_renormalize_; | ||||
| bool is_enable_; | bool is_enable_; | ||||
| bool is_untyped_generated_; | bool is_untyped_generated_; | ||||
| bool traverse_nodes_first_; | |||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -308,7 +308,8 @@ bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBa | |||||
| return false; | return false; | ||||
| } | } | ||||
| opt::irpass::ResolveIRPassLib irpass; | opt::irpass::ResolveIRPassLib irpass; | ||||
| opt::OptimizerPtr opt_resolve = opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass)); | |||||
| opt::OptimizerPtr opt_resolve = | |||||
| opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass), false, false, false); | |||||
| (void)parse::python_adapter::set_python_scoped(); | (void)parse::python_adapter::set_python_scoped(); | ||||
| @@ -246,7 +246,7 @@ class CNode : public AnfNode, public EffectInfoHolder { | |||||
| bool IsApply(const PrimitivePtr &) const; | bool IsApply(const PrimitivePtr &) const; | ||||
| const size_t size() const { return inputs_.size(); } | const size_t size() const { return inputs_.size(); } | ||||
| const AnfNodePtr input(size_t i) const { return inputs_[i]; } | |||||
| const AnfNodePtr &input(size_t i) const { return inputs_.at(i); } | |||||
| const std::vector<AnfNodePtr> &inputs() const { return inputs_; } | const std::vector<AnfNodePtr> &inputs() const { return inputs_; } | ||||
| void add_input(const AnfNodePtr &input); | void add_input(const AnfNodePtr &input); | ||||
| void set_input(size_t i, const AnfNodePtr &input); | void set_input(size_t i, const AnfNodePtr &input); | ||||