From: @zh_qh Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -31,7 +31,6 @@ | |||
| #include "backend/kernel_compiler/kernel_build_info.h" | |||
| #include "common/trans.h" | |||
| #include "abstract/param_validator.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| #include "pipeline/jit/static_analysis/static_analysis.h" | |||
| #include "utils/trace_base.h" | |||
| @@ -1806,14 +1805,7 @@ void AnfRuntimeAlgorithm::InferShape(const CNodePtr &node) { | |||
| args_spec_list.emplace_back(real_input->abstract()); | |||
| } | |||
| } | |||
| auto &prim_eval_implement_map = abstract::GetPrimitiveToEvalImplMap(); | |||
| auto ret = prim_eval_implement_map.find(primitive); | |||
| if (ret == prim_eval_implement_map.end()) { | |||
| MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << primitive->name() | |||
| << " primitive type:" << primitive->type_name(); | |||
| } | |||
| auto eval_result = ret->second.impl_(nullptr, primitive, args_spec_list); | |||
| auto eval_result = abstract::CppInferShape(primitive, args_spec_list); | |||
| node->set_abstract(eval_result); | |||
| } | |||
| } // namespace session | |||
| @@ -230,6 +230,8 @@ ResolveIRPassLib::ResolveIRPassLib() { | |||
| {prim::kPrimGetAttr, prim::kPrimResolve}); | |||
| resolver_resolve_ = MakeSubstitution(std::make_shared<ResolverResolve>(), "resolver_resolve", prim::kPrimResolve); | |||
| resolver_getattr_ = MakeSubstitution(std::make_shared<ResolverGetAttr>(), "resolver_getattr", prim::kPrimGetAttr); | |||
| resolver_getattr_resolve_ = | |||
| MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "resolver_getattr_resolve", prim::kPrimGetAttr); | |||
| } | |||
| InferenceOptPrepareLib::InferenceOptPrepareLib() { | |||
| @@ -154,6 +154,7 @@ class ResolveIRPassLib { | |||
| SubstitutionPtr resolver_resolve_and_getattr_; | |||
| SubstitutionPtr resolver_resolve_; | |||
| SubstitutionPtr resolver_getattr_; | |||
| SubstitutionPtr resolver_getattr_resolve_; | |||
| }; | |||
| class InferenceOptPrepareLib { | |||
| @@ -71,7 +71,7 @@ AnfNodePtr EliminateUpdateStateOnlyUsedNode(const CNodePtr &update_state, const | |||
| return nullptr; | |||
| } | |||
| // Replace UpdateState with the input monad. | |||
| return update_state->inputs().at(kInputIndex); | |||
| return update_state->input(kInputIndex); | |||
| } | |||
| // Eliminate UpdateState that attaches a pure (no-side-effect) node. | |||
| @@ -100,7 +100,7 @@ AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const A | |||
| } | |||
| } | |||
| // Remove UpdateState by replace it with its input monad. | |||
| return update_state->inputs().at(kInputIndex); | |||
| return update_state->input(kInputIndex); | |||
| } | |||
| // Eliminate redundant UpdateState/Depend pair nodes caused by inline. | |||
| @@ -118,7 +118,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN | |||
| // Skip if Depend attach input is not a monad. | |||
| return nullptr; | |||
| } | |||
| auto update_monad = update_state->inputs().at(kInputIndex); | |||
| auto update_monad = update_state->input(kInputIndex); | |||
| if (!HasAbstractMonad(update_monad)) { | |||
| // Skip if UpdateState input is not a monad. | |||
| MS_LOG(WARNING) << "Not a monad input: " << update_state->DebugString(); | |||
| @@ -139,7 +139,7 @@ AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CN | |||
| } | |||
| // Replace Depend with its input. | |||
| if (depend->size() == kMinDependSize) { | |||
| auto depend_input = depend->inputs().at(kInputIndex); | |||
| auto depend_input = depend->input(kInputIndex); | |||
| mgr->Replace(depend, depend_input); | |||
| } else { | |||
| auto inputs = depend->inputs(); | |||
| @@ -163,7 +163,7 @@ AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CN | |||
| if (make_tuple->size() != kMakeTupleSize) { | |||
| return nullptr; | |||
| } | |||
| auto &node = make_tuple->inputs().at(kAttachIndex); | |||
| auto &node = make_tuple->input(kAttachIndex); | |||
| auto node_abs = node->abstract(); | |||
| if (node_abs == nullptr || !node_abs->isa<abstract::AbstractError>()) { | |||
| return nullptr; | |||
| @@ -173,7 +173,7 @@ AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CN | |||
| return nullptr; | |||
| } | |||
| // Create a new UpdateState to replace the old one. | |||
| const auto &attach = make_tuple->inputs().at(kInputIndex); | |||
| const auto &attach = make_tuple->input(kInputIndex); | |||
| auto new_update_state = fg->NewCNode({update_state->input(0), update_state->input(1), attach}); | |||
| new_update_state->set_abstract(update_state->abstract()); | |||
| new_update_state->set_scope(update_state->scope()); | |||
| @@ -206,42 +206,47 @@ AnfNodePtr EliminateUpdateStateWithMakeTupleFunc(const CNodePtr &update_state, c | |||
| if (make_tuple->size() != kMakeTupleSize) { | |||
| return nullptr; | |||
| } | |||
| auto &first_input = make_tuple->inputs().at(kInputIndex); | |||
| auto &first_input = make_tuple->input(kInputIndex); | |||
| if (IsValueNode<FuncGraph>(first_input) && OnlyMakeTupleUseFunc(make_tuple, first_input)) { | |||
| return update_state->input(1); | |||
| } | |||
| auto &second_input = make_tuple->inputs().at(kAttachIndex); | |||
| auto &second_input = make_tuple->input(kAttachIndex); | |||
| if (IsValueNode<FuncGraph>(second_input) && OnlyMakeTupleUseFunc(make_tuple, second_input)) { | |||
| return update_state->input(1); | |||
| } | |||
| return nullptr; | |||
| } | |||
| size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *loads); | |||
| size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *loads); | |||
| size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads); | |||
| size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, | |||
| std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads); | |||
| // Search consecutive load nodes from UpdateState node. | |||
| size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *loads) { | |||
| auto &attach = update_state->inputs().at(kAttachIndex); | |||
| size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *update_states, | |||
| std::vector<CNodePtr> *loads) { | |||
| auto &attach = update_state->input(kAttachIndex); | |||
| if (IsPrimitiveCNode(attach, prim::kPrimLoad)) { | |||
| return GetLoadsFollowLoad(attach->cast<CNodePtr>(), loads); | |||
| update_states->emplace_back(update_state); | |||
| return GetLoadsFollowLoad(attach->cast<CNodePtr>(), update_states, loads); | |||
| } | |||
| if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) { | |||
| return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), loads); | |||
| update_states->emplace_back(update_state); | |||
| return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), update_states, loads); | |||
| } | |||
| return 0; | |||
| } | |||
| size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *loads) { | |||
| loads->push_back(load); | |||
| auto &load_attach = load->inputs().at(kAttachIndex); | |||
| size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) { | |||
| loads->emplace_back(load); | |||
| auto &load_attach = load->input(kAttachIndex); | |||
| if (IsPrimitiveCNode(load_attach, prim::kPrimUpdateState)) { | |||
| return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), loads) + 1; | |||
| return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), update_states, loads) + 1; | |||
| } | |||
| return 1; | |||
| } | |||
| size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *loads) { | |||
| size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, | |||
| std::vector<CNodePtr> *update_states, std::vector<CNodePtr> *loads) { | |||
| if (!OnlyUpdateStateUse(update_state, make_tuple)) { | |||
| // UpdateState should be the only user of | |||
| return 0; | |||
| @@ -256,12 +261,12 @@ size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tu | |||
| // Add load nodes from tuple elements. | |||
| for (size_t i = 1; i < inputs.size(); ++i) { | |||
| auto &element = inputs.at(i); | |||
| loads->push_back(element->cast<CNodePtr>()); | |||
| loads->emplace_back(element->cast<CNodePtr>()); | |||
| } | |||
| // Follow prev update state if found. | |||
| auto prev_node = update_state->inputs().at(kInputIndex); | |||
| auto prev_node = update_state->input(kInputIndex); | |||
| if (IsPrimitiveCNode(prev_node, prim::kPrimUpdateState)) { | |||
| return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), loads) + 1; | |||
| return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), update_states, loads) + 1; | |||
| } | |||
| return 1; | |||
| } | |||
| @@ -301,7 +306,8 @@ AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_upd | |||
| // xN = Load(xN, u) | |||
| // t = make_tuple(x1, x2, ... , xN) | |||
| // u1 = UpdateState(u, t) | |||
| AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &loads) { | |||
| AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector<CNodePtr> &update_states, | |||
| const std::vector<CNodePtr> &loads) { | |||
| auto fg = old_update_state->func_graph(); | |||
| if (fg == nullptr) { | |||
| return nullptr; | |||
| @@ -315,20 +321,24 @@ AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const | |||
| std::set<AnfNodePtr> loaded_para_set; | |||
| make_tuple_inputs.reserve(loads.size() + 1); | |||
| make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| auto input_monad = loads.back()->inputs().at(kAttachIndex); | |||
| auto input_monad = loads.back()->input(kAttachIndex); | |||
| for (auto iter = loads.rbegin(); iter != loads.rend(); ++iter) { | |||
| auto &load = *iter; | |||
| auto result = loaded_para_set.emplace(load->inputs().at(kInputIndex)); | |||
| auto result = loaded_para_set.emplace(load->input(kInputIndex)); | |||
| const bool is_new_load = result.second; | |||
| if (is_new_load) { | |||
| // Put Load node as a tuple element, if the parameter is not loaded by other Load. | |||
| make_tuple_inputs.emplace_back(load); | |||
| } | |||
| if (load->inputs().at(kAttachIndex) != input_monad) { | |||
| if (load->input(kAttachIndex) != input_monad) { | |||
| // Set all load use same input monad. | |||
| mgr->SetEdge(load, kAttachIndex, input_monad); | |||
| } | |||
| } | |||
| for (auto i = update_states.size() - 1; i > 0; i--) { | |||
| auto &us = update_states[i]; | |||
| mgr->Replace(us, us->input(kInputIndex)); | |||
| } | |||
| if (make_tuple_inputs.size() == 1) { | |||
| // This should not happen. | |||
| MS_LOG(WARNING) << "No loads for " << old_update_state->DebugString(2); | |||
| @@ -538,7 +548,7 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode | |||
| MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString(); | |||
| return nullptr; | |||
| } | |||
| auto &attach = update_state_node->inputs().at(kAttachIndex); | |||
| auto &attach = update_state_node->input(kAttachIndex); | |||
| if (IsPrimitiveCNode(attach, prim::kPrimDepend)) { | |||
| return EliminateUpdateStateWithDepend(update_state_node, attach->cast<CNodePtr>()); | |||
| } | |||
| @@ -586,9 +596,10 @@ AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNode | |||
| return new_node; | |||
| } | |||
| } | |||
| std::vector<CNodePtr> update_states; | |||
| std::vector<CNodePtr> loads; | |||
| if (GetLoadsFromUpdateState(update_state_node, &loads) > 1 && loads.size() > 1) { | |||
| return EliminateUpdateStateForLoads(update_state_node, loads); | |||
| if (GetLoadsFromUpdateState(update_state_node, &update_states, &loads) > 1 && loads.size() > 1) { | |||
| return EliminateUpdateStateForLoads(update_state_node, update_states, loads); | |||
| } | |||
| // Eliminate UpdateStates that attaches a no-side-effect node. | |||
| if (!attach_is_load && !attach_is_tuple) { | |||
| @@ -103,8 +103,62 @@ static bool isTraversable(const AnfNodePtr &node) { | |||
| return false; | |||
| } | |||
| bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, | |||
| const SubstitutionPtr &transform) const { | |||
| static inline AnfNodePtr DoTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node, | |||
| const SubstitutionPtr &substitution) { | |||
| auto manager = optimizer->manager(); | |||
| bool is_match = substitution->predicate_(node); | |||
| if (is_match) { | |||
| TraceGuard trace_guard(std::make_shared<TraceOpt>(node->debug_info())); | |||
| auto res = (*substitution)(optimizer, node); | |||
| if (res != nullptr && res != node) { | |||
| #ifdef ENABLE_PROFILE | |||
| double t = GetTime(); | |||
| #endif | |||
| MS_LOG(DEBUG) << "Replace " << node->DebugString() << " with " << res->DebugString() << ", by " | |||
| << substitution->name_; | |||
| (void)manager->Replace(node, res); | |||
| #ifdef ENABLE_PROFILE | |||
| MsProfile::StatTime("replace." + substitution->name_, GetTime() - t); | |||
| #endif | |||
| return res; | |||
| } | |||
| } | |||
| return nullptr; | |||
| } | |||
| static inline void UpdateTransformingList(const OptimizerPtr &optimizer, const AnfNodePtr &node, | |||
| std::deque<AnfNodePtr> *todo, bool change, size_t seen) { | |||
| if (IsValueNode<FuncGraph>(node)) { | |||
| (*todo).emplace_back(GetValueNode<FuncGraphPtr>(node)->output()); | |||
| } | |||
| if (node->isa<CNode>()) { | |||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||
| (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(*todo)); | |||
| } | |||
| if (!change) { | |||
| return; | |||
| } | |||
| auto manager = optimizer->manager(); | |||
| auto &node_users = manager->node_users(); | |||
| auto users_iterator = node_users.find(node); | |||
| if (users_iterator == node_users.end()) { | |||
| return; | |||
| } | |||
| auto users = users_iterator->second; | |||
| for (auto &use : users) { | |||
| auto use_node = use.first; | |||
| if (use_node == nullptr) { | |||
| continue; | |||
| } | |||
| (*todo).emplace_back(use_node); | |||
| if (use_node->seen_ == seen) { | |||
| use_node->seen_--; | |||
| } | |||
| } | |||
| } | |||
| bool SubstitutionList::ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const { | |||
| #ifdef ENABLE_PROFILE | |||
| double start = GetTime(); | |||
| #endif | |||
| @@ -113,7 +167,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo | |||
| // 1024 is for the initial capacity of deque | |||
| std::deque<AnfNodePtr> todo(1024); | |||
| todo.clear(); | |||
| todo.push_back(root_node); | |||
| todo.emplace_back(func_graph->output()); | |||
| bool changes = false; | |||
| auto &all_nodes = manager->all_nodes(); | |||
| @@ -121,59 +175,61 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo | |||
| AnfNodePtr node = todo.front(); | |||
| todo.pop_front(); | |||
| // check whether this node has been matched. | |||
| if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) { | |||
| continue; | |||
| } | |||
| node->seen_ = seen; | |||
| // select nodes that this transform can be applied. | |||
| bool is_match = transform->predicate_(node); | |||
| // apply transform on this node | |||
| bool change = false; | |||
| if (is_match) { | |||
| TraceGuard trace_guard(std::make_shared<TraceOpt>(node->debug_info())); | |||
| auto ret = (*transform)(optimizer, node); | |||
| if (ret != nullptr && ret != node) { | |||
| for (auto &substitution : list_) { | |||
| auto res = DoTransform(optimizer, node, substitution); | |||
| if (res != nullptr) { | |||
| change = true; | |||
| changes = true; | |||
| node = res; | |||
| todo.emplace_back(res); | |||
| break; | |||
| } | |||
| } | |||
| UpdateTransformingList(optimizer, node, &todo, change, seen); | |||
| } | |||
| #ifdef ENABLE_PROFILE | |||
| double t = GetTime(); | |||
| MsProfile::StatTime("opt.transforms." + optimizer->name(), GetTime() - start); | |||
| #endif | |||
| MS_LOG(DEBUG) << "transform: " << transform->name_ << " will replace: " << node->DebugString() | |||
| << " with: " << ret->DebugString(); | |||
| (void)manager->Replace(node, ret); | |||
| return changes; | |||
| } | |||
| bool SubstitutionList::ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, | |||
| const SubstitutionPtr &substitution) const { | |||
| #ifdef ENABLE_PROFILE | |||
| MsProfile::StatTime("replace." + transform->name_, GetTime() - t); | |||
| double start = GetTime(); | |||
| #endif | |||
| node = ret; | |||
| } | |||
| } | |||
| FuncGraphManagerPtr manager = optimizer->manager(); | |||
| auto seen = NewSeenGeneration(); | |||
| // 1024 is for the initial capacity of deque | |||
| std::deque<AnfNodePtr> todo(1024); | |||
| todo.clear(); | |||
| todo.emplace_back(root_node); | |||
| bool changes = false; | |||
| // find success, and add them to todo list | |||
| if (IsValueNode<FuncGraph>(node)) { | |||
| todo.push_back(GetValueNode<FuncGraphPtr>(node)->output()); | |||
| } | |||
| auto &all_nodes = manager->all_nodes(); | |||
| while (!todo.empty()) { | |||
| AnfNodePtr node = todo.front(); | |||
| todo.pop_front(); | |||
| if (node->isa<CNode>()) { | |||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||
| (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); | |||
| if (node == nullptr || node->seen_ == seen || !isTraversable(node) || !all_nodes.contains(node)) { | |||
| continue; | |||
| } | |||
| node->seen_ = seen; | |||
| auto &node_users = manager->node_users(); | |||
| if (change && node_users.find(node) != node_users.end()) { | |||
| for (auto &use : node_users[node]) { | |||
| auto use_node = use.first; | |||
| if (use_node == nullptr) { | |||
| continue; | |||
| } | |||
| todo.push_back(use_node); | |||
| if (use_node->seen_ == seen) { | |||
| use_node->seen_--; | |||
| } | |||
| } | |||
| bool change = false; | |||
| auto res = DoTransform(optimizer, node, substitution); | |||
| if (res != nullptr) { | |||
| change = true; | |||
| changes = true; | |||
| node = res; | |||
| } | |||
| UpdateTransformingList(optimizer, node, &todo, change, seen); | |||
| } | |||
| #ifdef ENABLE_PROFILE | |||
| @@ -182,13 +238,29 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo | |||
| return changes; | |||
| } | |||
| bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { | |||
| MS_EXCEPTION_IF_NULL(optimizer); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| FuncGraphManagerPtr manager = optimizer->manager(); | |||
| manager->AddFuncGraph(func_graph); | |||
| bool SubstitutionList::ApplySubstitutionsToIRForIsolate(const OptimizerPtr &optimizer) const { | |||
| const auto &manager = optimizer->manager(); | |||
| const auto &nodes = manager->isolate_nodes(); | |||
| bool changes = false; | |||
| bool loop = true; | |||
| while (loop) { | |||
| loop = false; | |||
| std::for_each(list_.cbegin(), list_.cend(), [&](const auto &substitution) { | |||
| std::for_each(nodes.cbegin(), nodes.cend(), [&](const auto &node) { | |||
| bool change = ApplySubstitutionToIR(optimizer, node, substitution); | |||
| changes = changes || change; | |||
| loop = loop || change; | |||
| }); | |||
| }); | |||
| if (is_once_) { | |||
| break; | |||
| } | |||
| } | |||
| return changes; | |||
| } | |||
| // for transform status counting | |||
| bool SubstitutionList::ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const { | |||
| // Add for substitution status counting | |||
| size_t space = 0; | |||
| std::unordered_map<std::string, std::vector<bool>> status; | |||
| if (optimizer->is_on_debug_) { | |||
| @@ -197,47 +269,39 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize | |||
| } | |||
| } | |||
| bool loop = false; | |||
| bool changes = false; | |||
| do { | |||
| bool loop = true; | |||
| while (loop) { | |||
| loop = false; | |||
| for (size_t i = 0; i < list_.size(); i++) { | |||
| auto change = ApplyTransform(optimizer, func_graph->output(), list_[i]); | |||
| const auto &substitution = list_[i]; | |||
| bool change = ApplySubstitutionToIR(optimizer, func_graph->output(), substitution); | |||
| changes = changes || change; | |||
| loop = loop || change; | |||
| // apply transform on isolate nodes. | |||
| auto &isolate_nodes = manager->isolate_nodes(); | |||
| for (auto &node : isolate_nodes) { | |||
| change = ApplyTransform(optimizer, node, list_[i]); | |||
| changes = changes || change; | |||
| loop = loop || change; | |||
| } | |||
| // record the status of each transform | |||
| static const auto enable_dump_pass_ir = (common::GetEnv("ENV_DUMP_PASS_IR") == "1"); | |||
| if (enable_dump_pass_ir && MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { | |||
| auto fg_name = optimizer->name() + "_" + std::to_string(optimizer->CurPass_.counter) + "_" + | |||
| optimizer->CurPass_.name + "_" + list_[i]->name_; | |||
| auto fg_name = optimizer->name() + "_r" + std::to_string(optimizer->CurPass_.counter) + "_" + | |||
| optimizer->CurPass_.name + "_" + substitution->name_; | |||
| DumpIR(fg_name + ".ir", func_graph); | |||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||
| func_graph->DumpFuncGraph(fg_name); | |||
| ExportIR(fg_name + ".dat", "", func_graph); | |||
| } | |||
| } | |||
| // Record the status of each substitution | |||
| if (optimizer->is_on_debug_) { | |||
| status[list_[i]->name_ + std::to_string(i)].push_back(change); | |||
| space = std::max(list_[i]->name_.size(), space); | |||
| status[substitution->name_ + std::to_string(i)].push_back(change); | |||
| space = std::max(substitution->name_.size(), space); | |||
| } | |||
| } | |||
| if (is_once_) { | |||
| break; | |||
| } | |||
| } while (loop); | |||
| } | |||
| // display the status of each transform | |||
| // Display the status of each substitution | |||
| if (optimizer->is_on_debug_) { | |||
| std::stringstream ss; | |||
| ss << std::endl | |||
| @@ -253,7 +317,37 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize | |||
| } | |||
| MS_LOG(DEBUG) << ss.str(); | |||
| } | |||
| return changes; | |||
| } | |||
| bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { | |||
| MS_EXCEPTION_IF_NULL(optimizer); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| FuncGraphManagerPtr manager = optimizer->manager(); | |||
| manager->AddFuncGraph(func_graph); | |||
| bool changes = false; | |||
| static const auto traverse_mode = | |||
| (common::GetEnv("ENV_TRAVERSE_SUBSTITUTIONS_MODE") != "1" ? kOptTraverseFromIRToSubstitutions | |||
| : kOptTraverseFromSubstitutionsToIR); | |||
| if (traverse_mode == kOptTraverseFromIRToSubstitutions && | |||
| MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode && | |||
| optimizer->traverse_nodes_first()) { | |||
| changes = ApplyIRToSubstitutions(optimizer, func_graph); | |||
| } else { | |||
| changes = ApplySubstitutionsToIR(optimizer, func_graph); | |||
| } | |||
| bool has_isolate = !manager->isolate_nodes().empty(); | |||
| if (has_isolate) { | |||
| #ifdef ENABLE_PROFILE | |||
| double t = GetTime(); | |||
| #endif | |||
| bool change = ApplySubstitutionsToIRForIsolate(optimizer); | |||
| changes = changes || change; | |||
| #ifdef ENABLE_PROFILE | |||
| MsProfile::StatTime("opt.isolate.transform." + optimizer->name(), GetTime() - t); | |||
| #endif | |||
| } | |||
| return changes; | |||
| } | |||
| } // namespace opt | |||
| @@ -59,6 +59,8 @@ SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std: | |||
| SubstitutionPtr MakeSubstitution(const OptimizerCallerPtr &transform, const std::string &name, | |||
| const PredicateFuncType &predicate, const RenormAction &action_renorm = CHECK_RENORM); | |||
| enum OptTraverseSubstitutionsMode { kOptTraverseFromIRToSubstitutions = 0, kOptTraverseFromSubstitutionsToIR }; | |||
| class SubstitutionList { | |||
| public: | |||
| explicit SubstitutionList(const std::vector<SubstitutionPtr> &patterns, bool is_once = false) | |||
| @@ -68,7 +70,10 @@ class SubstitutionList { | |||
| bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const; | |||
| private: | |||
| bool ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &transform) const; | |||
| bool ApplyIRToSubstitutions(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const; | |||
| bool ApplySubstitutionToIR(const OptimizerPtr &optimizer, const AnfNodePtr &node, const SubstitutionPtr &sub) const; | |||
| bool ApplySubstitutionsToIR(const OptimizerPtr &optimizer, const FuncGraphPtr &func_graph) const; | |||
| bool ApplySubstitutionsToIRForIsolate(const OptimizerPtr &optimizer) const; | |||
| std::vector<SubstitutionPtr> list_; | |||
| // a flag to mark this list of Substitution can only be executed only once | |||
| bool is_once_; | |||
| @@ -88,13 +88,14 @@ using OptPassGroupMap = std::vector<std::pair<std::string, OptPassConfig>>; | |||
| class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| public: | |||
| Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr) | |||
| Optimizer(const std::string &name, const pipeline::ResourceBasePtr &resource_ptr, bool traverse_nodes_first = true) | |||
| : name_(name), | |||
| resource_(resource_ptr), | |||
| run_only_once_(false), | |||
| is_watch_renormalize_(false), | |||
| is_enable_(true), | |||
| is_untyped_generated_(false) {} | |||
| is_untyped_generated_(false), | |||
| traverse_nodes_first_(traverse_nodes_first) {} | |||
| virtual ~Optimizer() = default; | |||
| void Init(const OptPassGroupMap &passes, bool run_only_once) { | |||
| @@ -129,8 +130,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| static std::shared_ptr<Optimizer> MakeOptimizer(const std::string &name, const pipeline::ResourceBasePtr resource_ptr, | |||
| const OptPassGroupMap &passes, bool run_only_once = false, | |||
| bool watch_renormalize = false) { | |||
| OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr); | |||
| bool watch_renormalize = false, bool traverse_nodes_first = true) { | |||
| OptimizerPtr optimizer = std::make_shared<Optimizer>(name, resource_ptr, traverse_nodes_first); | |||
| optimizer->Init(passes, run_only_once); | |||
| if (watch_renormalize) { | |||
| optimizer->enable_watch_renormalize(); | |||
| @@ -223,6 +224,8 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| bool is_watch_renormalize() { return is_watch_renormalize_; } | |||
| void set_enable(bool enable) { is_enable_ = enable; } | |||
| bool traverse_nodes_first() { return traverse_nodes_first_; } | |||
| struct { | |||
| int64_t counter; | |||
| std::string name; | |||
| @@ -239,6 +242,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| bool is_watch_renormalize_; | |||
| bool is_enable_; | |||
| bool is_untyped_generated_; | |||
| bool traverse_nodes_first_; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -308,7 +308,8 @@ bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBa | |||
| return false; | |||
| } | |||
| opt::irpass::ResolveIRPassLib irpass; | |||
| opt::OptimizerPtr opt_resolve = opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass)); | |||
| opt::OptimizerPtr opt_resolve = | |||
| opt::Optimizer::MakeOptimizer("opt_resolve", res, GetOptResolvePasses(irpass), false, false, false); | |||
| (void)parse::python_adapter::set_python_scoped(); | |||
| @@ -246,7 +246,7 @@ class CNode : public AnfNode, public EffectInfoHolder { | |||
| bool IsApply(const PrimitivePtr &) const; | |||
| const size_t size() const { return inputs_.size(); } | |||
| const AnfNodePtr input(size_t i) const { return inputs_[i]; } | |||
| const AnfNodePtr &input(size_t i) const { return inputs_.at(i); } | |||
| const std::vector<AnfNodePtr> &inputs() const { return inputs_; } | |||
| void add_input(const AnfNodePtr &input); | |||
| void set_input(size_t i, const AnfNodePtr &input); | |||