Merge pull request !19289 from 张清华/cconv_opt0tags/v1.3.0
| @@ -101,12 +101,15 @@ void DumpInferStack(std::ostringstream &oss) { | |||
| infer_vec.clear(); | |||
| break; | |||
| } | |||
| auto graph_context = graph_infer->context(); | |||
| auto graph_context = graph_infer->parent_context(); | |||
| if (graph_context == nullptr) { | |||
| MS_LOG(INFO) << "Null context continue"; | |||
| continue; | |||
| } | |||
| auto graph = graph_context->func_graph(); | |||
| if (graph == nullptr) { | |||
| continue; | |||
| } | |||
| auto args_spec_list = graph_context->args_spec_list(); | |||
| oss << " #" << index++ << " " << GetGraphParamString(graph, args_spec_list); | |||
| } | |||
| @@ -264,7 +267,7 @@ std::vector<AnalysisContextPtr> AnalyzedFuncGraphExporter::ProcessFuncGraphCall( | |||
| } | |||
| auto base_fg_evaluator = dyn_cast<abstract::BaseFuncGraphEvaluator>(evaluator); | |||
| auto ctx = base_fg_evaluator->context(); | |||
| auto ctx = base_fg_evaluator->parent_context(); | |||
| if (ctx != nullptr && context_map_.insert({ctx, false}).second) { | |||
| MS_LOG(DEBUG) << "Add new context, ctx.addr = " << ctx.get() << "ctx = " << ctx->ToString(); | |||
| context_vec_.push_back(ctx); | |||
| @@ -506,7 +509,7 @@ void GetEvalStackInfo(std::ostringstream &oss) { | |||
| return; | |||
| } | |||
| static int fileNumber = 0; | |||
| string file_name = "analyze_fail" + std::to_string(fileNumber++) + ".dat"; | |||
| string file_name = "analyze_fail_" + std::to_string(fileNumber++) + ".dat"; | |||
| auto ms_om_path = common::GetEnv("MS_OM_PATH"); | |||
| if (!ms_om_path.empty()) { | |||
| auto path = ms_om_path + "/" + file_name; | |||
| @@ -37,6 +37,8 @@ namespace ad { | |||
| std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_; | |||
| std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_; | |||
| bool lift_fv_before_grad = true; | |||
| DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources) | |||
| : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) { | |||
| { | |||
| @@ -73,6 +75,10 @@ void DFunctor::Clear() { | |||
| } | |||
| void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) { | |||
| if (lift_fv_before_grad) { | |||
| MS_LOG(EXCEPTION) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " | |||
| << fv->func_graph()->ToString() << " " << fv->ToString() << "."; | |||
| } | |||
| auto fv_adjoint = anfnode_to_adjoin_.find(fv); | |||
| if (fv_adjoint == anfnode_to_adjoin_.end()) { | |||
| MS_LOG(DEBUG) << "BackPropagateFv can not find adjoint in anfnode_to_adjoin_ fv " << fv->func_graph()->ToString() | |||
| @@ -437,6 +443,13 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) { | |||
| AnfNodePtr new_grad_fv = grad_fv; | |||
| // Add grads wrt fv. | |||
| const auto &free_variables_nodes = primal_graph_->free_variables_nodes(); | |||
| if (!is_top_ && free_variables_nodes.size() != 0) { | |||
| if (lift_fv_before_grad) { | |||
| MS_LOG(EXCEPTION) << "direct fv size is: " << free_variables_nodes.size() << " in " << primal_graph_->ToString() | |||
| << "."; | |||
| } | |||
| } | |||
| for (auto &fv : free_variables_nodes) { | |||
| auto fv_adjoint = anfnode_to_adjoin_.find(fv); | |||
| if (fv_adjoint == anfnode_to_adjoin_.end()) { | |||
| @@ -460,6 +473,10 @@ AnfNodePtr DFunctor::AttachFvDoutToTape(const AnfNodePtr &grad_fv) { | |||
| } | |||
| AnfNodePtr DFunctor::AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv) { | |||
| if (lift_fv_before_grad) { | |||
| MS_LOG(EXCEPTION) << "Lift free variable case: AttachIndirectFvDoutToTape backprop indirect fv " | |||
| << grad_fv->ToString() << " " << primal_graph_->ToString() << "."; | |||
| } | |||
| AnfNodePtr new_grad_fv = grad_fv; | |||
| // Add indirect fv bprop. | |||
| for (auto &fv_adjoint : anfnode_to_adjoin_indirect_fv_) { | |||
| @@ -497,7 +514,12 @@ void DFunctor::MapMorphism() { | |||
| output_adjoint->second->AccumulateDout(dout_); | |||
| // Set output for tape closure. | |||
| auto grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv))); | |||
| AnfNodePtr grad_fv; | |||
| if (lift_fv_before_grad) { | |||
| grad_fv = AttachFvDoutToTape(NewValueNode(newenv)); | |||
| } else { | |||
| grad_fv = AttachIndirectFvDoutToTape(AttachFvDoutToTape(NewValueNode(newenv))); | |||
| } | |||
| std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple), grad_fv}; | |||
| // Add grads wrt inputs. | |||
| @@ -43,6 +43,9 @@ extern KPrim g_k_prims; | |||
| class DFunctor; | |||
| using DFunctorPtr = std::shared_ptr<DFunctor>; | |||
| // Flag to control if fv should be lifted before grad. If this lift_fv feature is mature, then this flag can be removed. | |||
| extern bool lift_fv_before_grad; | |||
| // D Functor's rules to map closure object and morphisms. | |||
| class DFunctor : public std::enable_shared_from_this<DFunctor> { | |||
| public: | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -16,12 +16,52 @@ | |||
| #include "frontend/optimizer/ad/grad.h" | |||
| #include "frontend/optimizer/ad/dfunctor.h" | |||
| #include "frontend/optimizer/irpass.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "utils/ms_context.h" | |||
| #include "utils/symbolic.h" | |||
| namespace mindspore { | |||
| namespace ad { | |||
| namespace { | |||
| FuncGraphPtr PartialEliminateOptPass(const ResourcePtr &resource, const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(resource); | |||
| opt::irpass::OptimizeIRPassLib irpass; | |||
| opt::OptPassConfig partial_eliminate_opt_ = opt::OptPassConfig( | |||
| {irpass.partial_eliminate_, irpass.switch_partial_eliminater_, irpass.switch_layer_partial_eliminater_}); | |||
| opt::OptPassGroupMap map({{"partial_eliminate_", partial_eliminate_opt_}}); | |||
| auto after_lift_opt = opt::Optimizer::MakeOptimizer("partial_eliminate", resource, map); | |||
| FuncGraphPtr opt_fg = nullptr; | |||
| WITH(MsProfile::GetProfile()->Step("partial_eliminate_before_grad"))[&after_lift_opt, func_graph, &opt_fg]() { | |||
| opt_fg = after_lift_opt->step(func_graph, true); | |||
| }; | |||
| return opt_fg; | |||
| } | |||
| FuncGraphPtr LiftFv(const pipeline::ResourceBasePtr &resource, const FuncGraphPtr &func_graph) { | |||
| bool save_graphs_flag = MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG); | |||
| if (save_graphs_flag) { | |||
| DumpIR("before_lift_" + func_graph->ToString() + ".ir", func_graph); | |||
| } | |||
| FuncGraphPtr new_fg = LiftingClone(func_graph); | |||
| if (save_graphs_flag) { | |||
| DumpIR("after_lift_" + new_fg->ToString() + ".ir", new_fg); | |||
| } | |||
| auto new_res = std::dynamic_pointer_cast<pipeline::Resource>(resource); | |||
| if (new_res == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Parameter resources is not a pipeline::Resource"; | |||
| } | |||
| auto opt_fg = PartialEliminateOptPass(new_res, new_fg); | |||
| if (save_graphs_flag) { | |||
| DumpIR("after_opt_" + opt_fg->ToString() + ".ir", opt_fg); | |||
| } | |||
| return opt_fg; | |||
| } | |||
| } // namespace | |||
| FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &resources, bool is_top) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto gradkv = func_graph->transforms().find("grad"); | |||
| @@ -33,6 +73,11 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt | |||
| MS_EXCEPTION_IF_NULL(manager_ptr); | |||
| manager_ptr->AddFuncGraph(func_graph); | |||
| FuncGraphPtr grad_fg = func_graph; | |||
| lift_fv_before_grad = (common::GetEnv("ENV_DONT_LIFT_FV_BEFORE_GRAD") != "1"); | |||
| if (func_graph->func_graphs_used().size() != 0 && lift_fv_before_grad) { | |||
| grad_fg = LiftFv(resources, func_graph); | |||
| } | |||
| auto multi_graph_sink = [&func_graph](const FuncGraphPtr &f) { | |||
| if (MsContext::GetInstance()->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) { | |||
| if (func_graph->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { | |||
| @@ -41,8 +86,8 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt | |||
| } | |||
| }; | |||
| auto f = std::make_shared<DFunctor>(func_graph, resources); | |||
| auto user_defined = f->KUserDefined(func_graph); | |||
| auto f = std::make_shared<DFunctor>(grad_fg, resources); | |||
| auto user_defined = f->KUserDefined(grad_fg); | |||
| if (user_defined != nullptr) { | |||
| multi_graph_sink(user_defined); | |||
| if (is_top) { | |||
| @@ -62,6 +107,9 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt | |||
| } | |||
| multi_graph_sink(res); | |||
| if (func_graph != grad_fg) { | |||
| (void)func_graph->transforms().insert(std::make_pair("grad", FuncGraphTransform(res))); | |||
| } | |||
| return res; | |||
| } | |||
| @@ -327,6 +327,11 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_ | |||
| << " prim bprop function to J expanded func graph. NodeInfo: " | |||
| << trace::GetDebugInfo(bprop_fg->debug_info()); | |||
| } | |||
| if (lift_fv_before_grad && IsPrimitiveEquals(prim, prim::kPrimSwitch)) { | |||
| // Inline fprop_switch before renormalize; | |||
| expanded_fg->set_flag(FUNC_GRAPH_FLAG_FORCE_INLINE, true); | |||
| MS_LOG(DEBUG) << "set force_inline for fg: " << expanded_fg->ToString(); | |||
| } | |||
| return expanded_fg; | |||
| } | |||
| @@ -162,6 +162,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| exchange_switch_depend_value_ = | |||
| MakeSubstitution(std::make_shared<ExchangeSwitchDependValue>(), "exchange_switch_depend_value", prim::kPrimSwitch); | |||
| switch_partial_eliminater_ = | |||
| MakeSubstitution(std::make_shared<SwitchPartialEliminater>(), "eliminate_switch_partial_", IsCNodeDup); | |||
| switch_layer_partial_eliminater_ = | |||
| MakeSubstitution(std::make_shared<SwitchLayerPartialEliminater>(), "eliminate_switch_layer_partial_", IsCNodeDup); | |||
| // Addn | |||
| merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN); | |||
| addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN); | |||
| @@ -86,6 +86,9 @@ class OptimizeIRPassLib { | |||
| SubstitutionPtr convert_switch_replacement_; | |||
| SubstitutionPtr exchange_switch_depend_value_; | |||
| SubstitutionPtr switch_partial_eliminater_; | |||
| SubstitutionPtr switch_layer_partial_eliminater_; | |||
| // AddN | |||
| SubstitutionPtr merge_addn_; | |||
| SubstitutionPtr addn_zero_filter_; | |||
| @@ -20,6 +20,7 @@ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include "frontend/optimizer/irpass.h" | |||
| @@ -38,29 +39,36 @@ class CallOutputTransform { | |||
| CallOutputTransform() : cache_() {} | |||
| ~CallOutputTransform() = default; | |||
| FuncGraphPtr operator()(const FuncGraphPtr &fg, size_t nargs) { | |||
| FuncGraphPtr operator()(const FuncGraphPtr &fg, size_t nargs, bool xs_first) { | |||
| if (cache_.find(fg) == cache_.end()) { | |||
| cache_[fg] = {}; | |||
| } | |||
| auto &cache = cache_[fg]; | |||
| if (cache.find(nargs) == cache.end()) { | |||
| auto key = std::make_pair(nargs, xs_first); | |||
| if (cache.find(key) == cache.end()) { | |||
| FuncGraphPtr new_fg = TransformableClone(fg, std::make_shared<TraceTransform>("call")); | |||
| std::vector<AnfNodePtr> new_items; | |||
| new_items.push_back(new_fg->output()); | |||
| for (size_t i = 0; i < nargs; i++) { | |||
| new_items.push_back(new_fg->add_parameter()); | |||
| if (xs_first) { | |||
| for (size_t i = 0; i < nargs; i++) { | |||
| new_items.push_back(new_fg->add_parameter()); | |||
| } | |||
| } else { | |||
| for (size_t i = 0; i < nargs; i++) { | |||
| new_items.push_back(new_fg->InsertFrontParameter()); | |||
| } | |||
| } | |||
| new_fg->set_output(new_fg->NewCNode(new_items)); | |||
| cache[nargs] = new_fg; | |||
| cache[key] = new_fg; | |||
| } | |||
| return cache[nargs]; | |||
| return cache[key]; | |||
| } | |||
| private: | |||
| std::unordered_map<FuncGraphPtr, std::unordered_map<size_t, FuncGraphPtr>> cache_; | |||
| std::unordered_map<FuncGraphPtr, std::unordered_map<std::pair<size_t, bool>, FuncGraphPtr, PairHasher>> cache_; | |||
| }; | |||
| } // namespace internal | |||
| @@ -88,20 +96,35 @@ class IncorporateCall : public AnfVisitor { | |||
| auto xs_size = Xs_.size(); | |||
| auto ys_size = inputs.size() - 1; | |||
| auto new_fg = call_output_transform_(fg_, ys_size); | |||
| bool xs_first = true; | |||
| if ((xs_size > 0) && (Xs_[xs_size - 1]->abstract() != nullptr) && | |||
| (Xs_[xs_size - 1]->abstract()->isa<abstract::AbstractMonad>())) { | |||
| xs_first = false; | |||
| } | |||
| auto new_fg = call_output_transform_(fg_, ys_size, xs_first); | |||
| std::vector<AnfNodePtr> args; | |||
| args.push_back(NewValueNode(new_fg)); | |||
| if (xs_size > 0) { | |||
| (void)args.insert(args.end(), Xs_.begin(), Xs_.end()); | |||
| } | |||
| if (ys_size > 0) { | |||
| (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); | |||
| if (xs_first) { | |||
| if (xs_size > 0) { | |||
| (void)args.insert(args.end(), Xs_.begin(), Xs_.end()); | |||
| } | |||
| if (ys_size > 0) { | |||
| (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); | |||
| } | |||
| } else { | |||
| if (ys_size > 0) { | |||
| (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); | |||
| } | |||
| if (xs_size > 0) { | |||
| (void)args.insert(args.end(), Xs_.begin(), Xs_.end()); | |||
| } | |||
| } | |||
| return node->func_graph()->NewCNode(args); | |||
| auto new_node = node->func_graph()->NewCNode(args); | |||
| new_node->set_abstract(node->abstract()); | |||
| return new_node; | |||
| } | |||
| void Visit(const CNodePtr &cnode) override { | |||
| @@ -159,19 +182,35 @@ class IncorporateCallSwitch : public AnfVisitor { | |||
| auto fg = node->func_graph(); | |||
| auto xs_size = inputs_x.size() - 1; | |||
| auto ys_size = inputs.size() - 1; | |||
| auto new_g1 = call_output_transform_(g1_, ys_size); | |||
| auto new_g2 = call_output_transform_(g2_, ys_size); | |||
| bool xs_first = true; | |||
| if ((xs_size > 0) && (inputs_x[xs_size - 1]->abstract() != nullptr) && | |||
| (inputs_x[xs_size - 1]->abstract()->isa<abstract::AbstractMonad>())) { | |||
| xs_first = false; | |||
| } | |||
| auto new_g1 = call_output_transform_(g1_, ys_size, xs_first); | |||
| auto new_g2 = call_output_transform_(g2_, ys_size, xs_first); | |||
| auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)}); | |||
| std::vector<AnfNodePtr> args{sw_node}; | |||
| if (xs_size > 0) { | |||
| (void)args.insert(args.end(), inputs_x.begin() + 1, inputs_x.end()); | |||
| } | |||
| if (ys_size > 0) { | |||
| (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); | |||
| if (xs_first) { | |||
| if (xs_size > 0) { | |||
| (void)args.insert(args.end(), inputs_x.begin() + 1, inputs_x.end()); | |||
| } | |||
| if (ys_size > 0) { | |||
| (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); | |||
| } | |||
| } else { | |||
| if (ys_size > 0) { | |||
| (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); | |||
| } | |||
| if (xs_size > 0) { | |||
| (void)args.insert(args.end(), inputs_x.begin() + 1, inputs_x.end()); | |||
| } | |||
| } | |||
| return fg->NewCNode(args); | |||
| auto new_node = fg->NewCNode(args); | |||
| new_node->set_abstract(node->abstract()); | |||
| return new_node; | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { | |||
| @@ -164,6 +164,7 @@ class IncorporateGetitem : public AnfVisitor { | |||
| } | |||
| } | |||
| } | |||
| new_node->set_abstract(node->abstract()); | |||
| return new_node; | |||
| } | |||
| @@ -228,6 +229,7 @@ class IncorporateGetitemDepend : public AnfVisitor { | |||
| new_depend_cnode = | |||
| node->func_graph()->NewCNode({NewValueNode(prim::kPrimDepend), new_fg_cnode, depend_2nd_input_}); | |||
| } | |||
| new_depend_cnode->set_abstract(node->abstract()); | |||
| return new_depend_cnode; | |||
| } | |||
| @@ -294,8 +296,7 @@ class IncorporateGetitemSwitch : public AnfVisitor { | |||
| is_in_get_ = false; | |||
| auto fg = node->func_graph(); | |||
| if (idx_ == -1 || switch_ == nullptr || fg == nullptr || | |||
| (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) && !ExistEnvNode(fg))) { | |||
| if (idx_ == -1 || switch_ == nullptr || fg == nullptr) { | |||
| return nullptr; | |||
| } | |||
| @@ -308,10 +309,27 @@ class IncorporateGetitemSwitch : public AnfVisitor { | |||
| } | |||
| auto tuple_getitem = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(tuple_getitem); | |||
| bool has_env_type = false; | |||
| if (tuple_getitem->input(1)->abstract() && tuple_getitem->input(1)->abstract()->isa<abstract::AbstractTuple>()) { | |||
| const auto &abs_tuple = *(tuple_getitem->input(1)->abstract()->cast<abstract::AbstractTuplePtr>()); | |||
| // eliminate (envinstance, value1, value2, ...) built by bprop func_graph() | |||
| if (abs_tuple.size() >= 1) { | |||
| // Value maybe kAnyValue, so check the type track; | |||
| if (abs_tuple[0]->isa<abstract::AbstractScalar>() && abs_tuple[0]->GetTypeTrack()->isa<EnvType>()) { | |||
| has_env_type = true; | |||
| } | |||
| } | |||
| // eliminate (value, bprop_func) built by fprop func_graph | |||
| if (abs_tuple.size() >= 2) { | |||
| if (abs_tuple[1]->isa<abstract::AbstractFunction>()) { | |||
| has_env_type = true; | |||
| } | |||
| } | |||
| } | |||
| // If exist env_getitem/env_setitem in this funcgraph or | |||
| // if g1_/g2_ is fprop func_graph and the corresponding bprop funcgraph has any env_getitem or env_setitem; | |||
| if (MultipleUseOfSwitch(tuple_getitem->input(1), fg) && !ExistEnvNode(fg) && !ExistEnvNodeInTupleItem(g1_) && | |||
| !ExistEnvNodeInTupleItem(g2_)) { | |||
| !ExistEnvNodeInTupleItem(g2_) && !has_env_type) { | |||
| return nullptr; | |||
| } | |||
| auto new_g1 = getitem_transform_(g1_, idx_); | |||
| @@ -319,7 +337,9 @@ class IncorporateGetitemSwitch : public AnfVisitor { | |||
| auto sw_node = fg->NewCNode({NewValueNode(prim::kPrimSwitch), x_, NewValueNode(new_g1), NewValueNode(new_g2)}); | |||
| (void)args_.insert(args_.begin(), sw_node); | |||
| return fg->NewCNode(args_); | |||
| auto new_node = fg->NewCNode(args_); | |||
| new_node->set_abstract(node->abstract()); | |||
| return new_node; | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { | |||
| @@ -398,8 +418,8 @@ class IncorporateGetitemSwitch : public AnfVisitor { | |||
| } | |||
| const auto &cnode = output->cast<CNodePtr>(); | |||
| const auto &inputs = cnode->inputs(); | |||
| return std::any_of(inputs.cbegin() + 1, inputs.cend(), [](const auto &node) { | |||
| auto sub_fg = GetValueNode<FuncGraphPtr>(node); | |||
| return std::any_of(inputs.cbegin() + 1, inputs.cend(), [](const auto &input) { | |||
| auto sub_fg = GetValueNode<FuncGraphPtr>(input); | |||
| if (sub_fg != nullptr && ExistEnvNode(sub_fg)) { | |||
| return true; | |||
| } | |||
| @@ -91,6 +91,9 @@ bool IsInside(InlinerBase *, const FuncGraphPtr &, const AnfNodePtr &node); | |||
| bool IsCore(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &); | |||
| bool IsDirectParentCall(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &node); | |||
| bool IsNotRecursive(InlinerBase *inliner, const FuncGraphPtr &fg, const AnfNodePtr &); | |||
| bool IsForceInline(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) { | |||
| return fg->has_flag(FUNC_GRAPH_FLAG_FORCE_INLINE); | |||
| } | |||
| // {G, Xs} | |||
| class InlinerBase : public AnfVisitor { | |||
| @@ -138,6 +141,14 @@ class InlinerBase : public AnfVisitor { | |||
| return nullptr; | |||
| } | |||
| if (IsForceInline(this, fg, node)) { | |||
| if (IsUniqueUse(nullptr, fg, nullptr)) { | |||
| return InlineMove(node, fg, args, inputs); | |||
| } else { | |||
| return InlineClone(fg, node->func_graph(), args, inputs[0]->scope()); | |||
| } | |||
| } | |||
| if (IsUniqueUse(nullptr, fg, nullptr)) { | |||
| // For the single used fg, including non-after and after not matched above, | |||
| // we move the whole fg nodes. | |||
| @@ -162,15 +173,20 @@ class InlinerBase : public AnfVisitor { | |||
| return InlineClone(fg, node->func_graph(), args, inputs[0]->scope()); | |||
| } | |||
| AnfNodePtr InlineMove(const AnfNodePtr &node, const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &args, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| auto mng = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| ReplaceParams(mng, args, fg); | |||
| auto out_node = fg->output(); | |||
| mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope()); | |||
| return out_node; | |||
| } | |||
| AnfNodePtr InlineForUniqueUse(const AnfNodePtr &node, const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &args, | |||
| const std::vector<AnfNodePtr> &inputs) { | |||
| if (use_move_) { | |||
| auto mng = fg->manager(); | |||
| MS_EXCEPTION_IF_NULL(mng); | |||
| ReplaceParams(mng, args, fg); | |||
| auto out_node = fg->output(); | |||
| mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope()); | |||
| return out_node; | |||
| return InlineMove(node, fg, args, inputs); | |||
| } | |||
| // The other branch calling the last after block. | |||
| @@ -377,6 +393,7 @@ class DirectInliner : public InlinerBase { | |||
| : InlinerBase( | |||
| // Supports AND conditions in one criterion, Ex. {IsUniqueUse, IsNotRecursive}. | |||
| { | |||
| {IsForceInline}, | |||
| {IsDirectParentCall}, | |||
| }, | |||
| use_move) {} | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -17,9 +17,11 @@ | |||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ | |||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "frontend/optimizer/irpass.h" | |||
| #include "frontend/optimizer/optimizer.h" | |||
| @@ -72,6 +74,341 @@ class PartialEliminater : public AnfVisitor { | |||
| private: | |||
| std::vector<AnfNodePtr> Xs_{}; | |||
| }; | |||
| class ChoicePartialEliminater : public AnfVisitor { | |||
| public: | |||
| virtual ~ChoicePartialEliminater() = default; | |||
| protected: | |||
| AnfNodePtrList fg_list_{}; | |||
| std::vector<AnfNodePtrList> args_list_{}; | |||
| void Visit(const AnfNodePtr &node) override { | |||
| if (!IsPrimitiveCNode(node, prim::kPrimPartial)) { | |||
| if (IsValueNode<FuncGraph>(node)) { | |||
| fg_list_.push_back(node); | |||
| args_list_.push_back(AnfNodePtrList{}); | |||
| } | |||
| return; | |||
| } | |||
| auto &inputs = node->cast<CNodePtr>()->inputs(); | |||
| // {prim::kPrimPartial, G, Xs} | |||
| if (inputs.size() < 3) { | |||
| MS_LOG(EXCEPTION) << "Node should be Partial CNode, but: " << node->DebugString(); | |||
| return; | |||
| } | |||
| if (IsValueNode<FuncGraph>(inputs[1])) { | |||
| fg_list_.push_back(inputs[1]); | |||
| AnfNodePtrList args; | |||
| (void)std::copy(inputs.begin() + 2, inputs.end(), std::back_inserter(args)); | |||
| args_list_.push_back(args); | |||
| } | |||
| return; | |||
| } | |||
| // return value: true -- continue replace; false -- return nullptr; | |||
| bool CheckFuncGraphAndArgs() { | |||
| // Either one should be {Partial, G, X} | |||
| auto has_partial_args = | |||
| std::any_of(args_list_.cbegin(), args_list_.cend(), [](auto &args) { return args.size() != 0; }); | |||
| if (!has_partial_args) { | |||
| return false; | |||
| } | |||
| // check funcgraph should be used once only. | |||
| for (auto &fg_node : fg_list_) { | |||
| auto fg = GetValueNode<FuncGraphPtr>(fg_node); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| if (fg->func_graph_cnodes_index().size() != 1) { | |||
| for (auto iter : fg->func_graph_cnodes_index()) { | |||
| MS_LOG(ERROR) << "fg user: " << iter.first->first->DebugString(1) << ", index: " << iter.first->second; | |||
| } | |||
| MS_LOG(EXCEPTION) << "fg is used multiple times: " << fg->ToString(); | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| // f(x1, x2, x3, z1, z2) | |||
| // g(x4, x2, z1, z2) | |||
| // h(x5, x2, x7, x8, z1, z2) | |||
| // --> anchor_fg = h | |||
| // h(x5, x2, x7, x8, x1, x3, x4, z1, z2) | |||
| // f(x5, x2, x7, x8, x1, x3, x4, z1, z2) | |||
| // g(x5, x2, x7, x8, x1, x3, x4, z1, z2) | |||
| // as z1, z2 maybe U or IO monad. | |||
| AnfNodePtrList UnifyParameters(const size_t &anchor_index, const AnfNodePtrList &fg_list, | |||
| const std::vector<AnfNodePtrList> args_list) { | |||
| std::vector<size_t> inputs_index_list[args_list.size()]; | |||
| size_t extra_input_counter = 0; | |||
| AnfNodePtrList extra_inputs; | |||
| const auto &anchor_args = args_list[anchor_index]; | |||
| size_t anchor_args_size = anchor_args.size(); | |||
| auto anchor_fg = GetValueNode<FuncGraphPtr>(fg_list[anchor_index]); | |||
| MS_EXCEPTION_IF_NULL(anchor_fg); | |||
| // Find the new location of the old_inputs except Zs; | |||
| for (size_t i = 0; i < args_list.size(); ++i) { | |||
| if (i == anchor_index) { | |||
| continue; | |||
| } | |||
| const auto &another_args = args_list[i]; | |||
| auto &curr_inputs_index = inputs_index_list[i]; | |||
| for (size_t j = 0; j < another_args.size(); ++j) { | |||
| size_t k; | |||
| for (k = 0; k < anchor_args_size; ++k) { | |||
| if (another_args[j] == anchor_args[k]) { | |||
| curr_inputs_index.push_back(k); | |||
| break; | |||
| } | |||
| } | |||
| if (k == anchor_args_size) { | |||
| // check if used by another func_graph; | |||
| for (k = 0; k < extra_input_counter; ++k) { | |||
| if (another_args[j] == extra_inputs[k]) { | |||
| curr_inputs_index.push_back(anchor_args_size + k); | |||
| break; | |||
| } | |||
| } | |||
| if (k == extra_input_counter) { | |||
| extra_inputs.push_back(another_args[j]); | |||
| curr_inputs_index.push_back(anchor_args_size + extra_input_counter); | |||
| extra_input_counter++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| auto manager = anchor_fg->manager(); | |||
| auto txn = manager->Transact(); | |||
| size_t anchor_params_size = anchor_fg->parameters().size(); | |||
| const auto &anchor_fg_params = anchor_fg->parameters(); | |||
| for (size_t i = 0; i < args_list.size(); ++i) { | |||
| if (i == anchor_index) { | |||
| continue; | |||
| } | |||
| AnfNodePtrList new_params; | |||
| new_params.resize(anchor_params_size + extra_input_counter); | |||
| const auto &curr_inputs_index = inputs_index_list[i]; | |||
| auto another_fg = GetValueNode<FuncGraphPtr>(fg_list[i]); | |||
| MS_EXCEPTION_IF_NULL(another_fg); | |||
| const auto &old_params = another_fg->parameters(); | |||
| const auto &old_args = args_list[i]; | |||
| for (size_t j = 0; j < old_args.size(); j++) { | |||
| new_params[curr_inputs_index[j]] = old_params[j]; | |||
| } | |||
| // Zs_ | |||
| for (size_t j = old_args.size(), k = 0; j < old_params.size(); ++j, ++k) { | |||
| new_params[anchor_args_size + extra_input_counter + k] = old_params[j]; | |||
| } | |||
| // unused inputs | |||
| for (size_t j = 0; j < anchor_args_size; ++j) { | |||
| if (new_params[j] == nullptr) { | |||
| TraceGuard guard(std::make_shared<TraceCopy>(anchor_fg_params[j]->debug_info())); | |||
| ParameterPtr param = std::make_shared<Parameter>(another_fg); | |||
| new_params[j] = param; | |||
| } | |||
| } | |||
| // extra inputs used by another func_graph; | |||
| for (size_t j = 0; j < extra_inputs.size(); ++j) { | |||
| if (new_params[anchor_args_size + j] == nullptr) { | |||
| TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[j]->debug_info())); | |||
| ParameterPtr param = std::make_shared<Parameter>(another_fg); | |||
| new_params[anchor_args_size + j] = param; | |||
| } | |||
| } | |||
| // set the parameter for another_fg and replace it's parameters; | |||
| txn.SetParameters(another_fg, new_params); | |||
| } | |||
| // Reorder Zs_ and add extra parameters for anchor_fg; | |||
| // add extra parameter for anchor_fg; | |||
| AnfNodePtrList new_params; | |||
| new_params.reserve(anchor_params_size + extra_input_counter); | |||
| // reuse parameters for anchor_args; | |||
| std::copy(anchor_fg_params.cbegin(), anchor_fg_params.cbegin() + anchor_args_size, std::back_inserter(new_params)); | |||
| // Extra parameters; | |||
| for (size_t i = 0; i < extra_inputs.size(); ++i) { | |||
| TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[i]->debug_info())); | |||
| ParameterPtr param = std::make_shared<Parameter>(anchor_fg); | |||
| new_params.push_back(param); | |||
| } | |||
| // Reorder Zs_ to last; | |||
| for (size_t i = anchor_args_size; i < anchor_params_size; ++i) { | |||
| new_params.push_back(anchor_fg_params[i]); | |||
| } | |||
| txn.SetParameters(anchor_fg, new_params); | |||
| txn.Commit(); | |||
| return extra_inputs; | |||
| } | |||
| }; | |||
| // {{prim::kPrimSwitch, cond, {prim::kPrimPartial, G1, Xs}, {prim::kPrimPartial, G2, Ys}}, Zs} -> | |||
| // {{prim::kPrimSwitch, cond, G1, G2}, Xs Union Ys Union Zs} | |||
| // {{prim::kPrimSwitch, cond, {G1}, {prim::kPrimPartial, G2, Ys}}, Zs} -> {{prim::kPrimSwitch, cond, G1, G2}, Ys Union | |||
| // Zs} | |||
| // {{prim::kPrimSwitch, cond, {prim::kPrimPartial, G1, Xs}, {G2}}, Zs} -> {{prim::kPrimSwitch, cond, G1, G2}, Xs Union | |||
| // Zs} | |||
| class SwitchPartialEliminater : public ChoicePartialEliminater { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| if (!node->isa<CNode>() || node->func_graph() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (!IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch)) { | |||
| return nullptr; | |||
| } | |||
| auto input0_cnode = cnode->input(0)->cast<CNodePtr>(); | |||
| if (input0_cnode->size() != 4) { | |||
| return nullptr; | |||
| } | |||
| fg_list_.clear(); | |||
| args_list_.clear(); | |||
| auto &maybe_partial_1 = input0_cnode->input(2); | |||
| Visit(maybe_partial_1); | |||
| auto &maybe_partial_2 = input0_cnode->input(3); | |||
| Visit(maybe_partial_2); | |||
| // Either one should be {Partial, G, X} | |||
| if (fg_list_.size() != 2 && args_list_.size() != 2) { | |||
| return nullptr; | |||
| } | |||
| // Should not continue; | |||
| if (!CheckFuncGraphAndArgs()) { | |||
| return nullptr; | |||
| } | |||
| if (args_list_[0] == args_list_[1]) { | |||
| auto new_node = | |||
| BuildNewSwitchNode(cnode, input0_cnode, fg_list_[0], fg_list_[1], args_list_[0], AnfNodePtrList{}); | |||
| return new_node; | |||
| } else { | |||
| // find partial funcgraph with the longest args as anchor; | |||
| size_t max_args_pos = 0; | |||
| if (args_list_[0].size() > args_list_[1].size()) { | |||
| max_args_pos = 0; | |||
| } else { | |||
| max_args_pos = 1; | |||
| } | |||
| auto extra_inputs = UnifyParameters(max_args_pos, fg_list_, args_list_); | |||
| auto new_node = | |||
| BuildNewSwitchNode(cnode, input0_cnode, fg_list_[0], fg_list_[1], args_list_[max_args_pos], extra_inputs); | |||
| return new_node; | |||
| } | |||
| } | |||
| private: | |||
| AnfNodePtr BuildNewSwitchNode(const CNodePtr &old_cnode, const CNodePtr input0_cnode, const AnfNodePtr &G1, | |||
| const AnfNodePtr &G2, const AnfNodePtrList &partial_args, | |||
| const AnfNodePtrList &extra_args) { | |||
| TraceGuard guard1(std::make_shared<TraceCopy>(input0_cnode->debug_info())); | |||
| // {Switch, cond, G1, G2} | |||
| auto switch_cnode = old_cnode->func_graph()->NewCNode({input0_cnode->input(0), input0_cnode->input(1), G1, G2}); | |||
| AnfNodePtrList args{switch_cnode}; | |||
| (void)std::copy(partial_args.begin(), partial_args.end(), std::back_inserter(args)); | |||
| (void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args)); | |||
| // Zs | |||
| if (old_cnode->size() >= 2) { | |||
| (void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args)); | |||
| } | |||
| TraceGuard guard2(std::make_shared<TraceCopy>(old_cnode->debug_info())); | |||
| auto new_node = old_cnode->func_graph()->NewCNode(args); | |||
| return new_node; | |||
| } | |||
| }; | |||
| // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{prim::kPrimPartial, G1, Xs}, {prim::kPrimPartial, G2, Ys}}}, Zs} -> | |||
| // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}, Xs Union Ys Union Zs} | |||
| // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{G1}, {prim::kPrimPartial, G2, Ys}}}, Zs} -> | |||
| // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}}, Ys Union Zs} | |||
| // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{prim::kPrimPartial, G1, Xs}, {G2}}{}, Zs} -> | |||
| // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}}, Xs Union Zs} | |||
| class SwitchLayerPartialEliminater : public ChoicePartialEliminater { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| if (!node->isa<CNode>() || node->func_graph() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| // {SwitchLayer{}, Zs} | |||
| if (!IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitchLayer)) { | |||
| return nullptr; | |||
| } | |||
| auto switch_layer_cnode = cnode->input(0)->cast<CNodePtr>(); | |||
| // {SwitchLayer, cond, MakeTuple{}} | |||
| if (switch_layer_cnode->size() != 3) { | |||
| return nullptr; | |||
| } | |||
| if (!IsPrimitiveCNode(switch_layer_cnode->input(2), prim::kPrimMakeTuple)) { | |||
| return nullptr; | |||
| } | |||
| auto make_tuple_cnode = switch_layer_cnode->input(2)->cast<CNodePtr>(); | |||
| if (make_tuple_cnode->size() < 2) { | |||
| return nullptr; | |||
| } | |||
| fg_list_.clear(); | |||
| args_list_.clear(); | |||
| // Build funcgraph list and args list; | |||
| for (size_t i = 1; i < make_tuple_cnode->size(); ++i) { | |||
| Visit(make_tuple_cnode->input(i)); | |||
| } | |||
| if (!CheckFuncGraphAndArgs()) { | |||
| return nullptr; | |||
| } | |||
| // All have the same args; | |||
| auto args_equal = | |||
| std::all_of(args_list_.cbegin() + 1, args_list_.cend(), [this](auto &args) { return args == args_list_[0]; }); | |||
| if (args_equal) { | |||
| auto new_node = BuildNewSwitchLayerNode(cnode, switch_layer_cnode, args_list_[0], AnfNodePtrList{}); | |||
| return new_node; | |||
| } else { | |||
| // find partial funcgraph with the longest args as anchor; | |||
| size_t max_args_pos = 0, max_args_len = 0; | |||
| for (size_t i = 0; i < args_list_.size(); ++i) { | |||
| if (max_args_len < args_list_[i].size()) { | |||
| max_args_len = args_list_[i].size(); | |||
| max_args_pos = i; | |||
| } | |||
| } | |||
| auto extra_inputs = UnifyParameters(max_args_pos, fg_list_, args_list_); | |||
| auto new_node = BuildNewSwitchLayerNode(cnode, switch_layer_cnode, args_list_[max_args_pos], extra_inputs); | |||
| return new_node; | |||
| } | |||
| } | |||
| private: | |||
| AnfNodePtr BuildNewSwitchLayerNode(const CNodePtr &old_cnode, const CNodePtr switch_layer_cnode, | |||
| const AnfNodePtrList &anchor_partial_args, const AnfNodePtrList &extra_args) { | |||
| auto make_tuple_cnode = switch_layer_cnode->input(2)->cast<CNodePtr>(); | |||
| AnfNodePtrList make_tuple_args{make_tuple_cnode->input(0)}; | |||
| make_tuple_args.insert(make_tuple_args.end(), fg_list_.begin(), fg_list_.end()); | |||
| TraceGuard guard1(std::make_shared<TraceCopy>(make_tuple_cnode->debug_info())); | |||
| // {MakeTuple, G1, G2, ...} | |||
| auto new_make_tuple_cnode = old_cnode->func_graph()->NewCNode(make_tuple_args); | |||
| TraceGuard guard2(std::make_shared<TraceCopy>(switch_layer_cnode->debug_info())); | |||
| // {SwitchLayer, cond, MakeTuple{}} | |||
| auto new_switch_layer_cnode = old_cnode->func_graph()->NewCNode( | |||
| {switch_layer_cnode->input(0), switch_layer_cnode->input(1), new_make_tuple_cnode}); | |||
| AnfNodePtrList args{new_switch_layer_cnode}; | |||
| (void)std::copy(anchor_partial_args.begin(), anchor_partial_args.end(), std::back_inserter(args)); | |||
| (void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args)); | |||
| // Zs | |||
| if (old_cnode->size() >= 2) { | |||
| (void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args)); | |||
| } | |||
| TraceGuard guard3(std::make_shared<TraceCopy>(old_cnode->debug_info())); | |||
| auto new_node = old_cnode->func_graph()->NewCNode(args); | |||
| return new_node; | |||
| } | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -109,11 +109,12 @@ void BaseFuncGraphEvaluator::LeaveStackFrame(const AnalysisEnginePtr &engine, | |||
| } | |||
| // Start running stack frames in a Evaluator. | |||
| AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg) { | |||
| AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg, | |||
| const AnalysisContextPtr &context) { | |||
| EvalResultPtr eval_result = nullptr; | |||
| AbstractBasePtr res_base = nullptr; | |||
| std::stack<StackFramePtr> stack_frames; | |||
| auto current_stack_frame = std::make_shared<StackFrame>(shared_from_base<Evaluator>(), fg, context_, parent_context_); | |||
| auto current_stack_frame = std::make_shared<StackFrame>(shared_from_base<Evaluator>(), fg, context, parent_context_); | |||
| MS_LOG(DEBUG) << "[" << this << "/StackFrame] Start at func graph, " << current_stack_frame; | |||
| stack_frames.push(current_stack_frame); | |||
| while (true) { | |||
| @@ -155,7 +156,8 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchStackFrame(const AnalysisEnginePtr | |||
| return res_base; | |||
| } | |||
| AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg) { | |||
| AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg, | |||
| const AnalysisContextPtr &context) { | |||
| const AnfNodePtr &func_node = fg->get_return(); | |||
| const auto &all_nodes = TopoSort(func_node, SuccIncoming, [](const AnfNodePtr &node) -> IncludeType { | |||
| if (node->isa<ValueNode>() || node->isa<Parameter>()) { | |||
| @@ -165,7 +167,7 @@ AbstractBasePtr BaseFuncGraphEvaluator::LaunchRecursiveEval(const AnalysisEngine | |||
| }); | |||
| AbstractBasePtr res_base = nullptr; | |||
| for (const auto &node : all_nodes) { | |||
| AnfNodeConfigPtr node_conf = engine->MakeConfig(node, context_); | |||
| AnfNodeConfigPtr node_conf = engine->MakeConfig(node, context); | |||
| MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg << "/" << fg->ToString() | |||
| << ", node_conf: " << node_conf->ToString(); | |||
| auto node_eval_result = engine->ObtainEvalResultWithCache(node_conf); | |||
| @@ -214,24 +216,30 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr | |||
| << parent_context_->func_graph()->ToString() << "()->" << AnalysisResultCacheMgr::GetThreadid() << ":" | |||
| << fg->ToString() << "();"; | |||
| } | |||
| context_ = parent_context_->NewFuncGraphContext(fg, args_abs_list); | |||
| auto context = parent_context_->NewFuncGraphContext(fg, args_abs_list); | |||
| auto func_graph_evaluator = dyn_cast<FuncGraphEvaluator>(shared_from_base<BaseFuncGraphEvaluator>()); | |||
| if (func_graph_evaluator != nullptr) { | |||
| if (engine->root_func_graph() == func_graph_evaluator->func_graph()) { | |||
| engine->set_root_context(context); | |||
| } | |||
| } | |||
| const auto ¶meters = fg->parameters(); | |||
| for (size_t i = 0; i < nargs; i++) { | |||
| const auto &arg = args_abs_list[i]; | |||
| const auto &node = parameters[i]; | |||
| AnfNodeConfigPtr conf = engine->MakeConfig(node, context_); | |||
| AnfNodeConfigPtr conf = engine->MakeConfig(node, context); | |||
| engine->SaveEvalResultInCache(conf, std::make_shared<EvalResult>(arg, nullptr)); | |||
| MS_LOG(DEBUG) << GetInferThread() << "Set Param: " << conf->ToString() << " = " << arg->ToString(); | |||
| } | |||
| MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg << "/" << fg->ToString() | |||
| << ", context: " << context_->ToString() << ", return node: " << fg->get_return()->DebugString() | |||
| << ", context: " << context->ToString() << ", return node: " << fg->get_return()->DebugString() | |||
| << ", parent: " << (parent_context_->func_graph() ? parent_context_->func_graph()->ToString() : "NULL") | |||
| << ", current function call depth: " << engine->function_call_depth(); | |||
| AbstractBasePtr res_base = nullptr; | |||
| if (engine->enable_recursive_eval()) { | |||
| res_base = LaunchRecursiveEval(engine, fg); | |||
| res_base = LaunchRecursiveEval(engine, fg, context); | |||
| } else { | |||
| res_base = LaunchStackFrame(engine, fg); | |||
| res_base = LaunchStackFrame(engine, fg, context); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(res_base); | |||
| @@ -250,28 +258,27 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr | |||
| return res; | |||
| } | |||
| void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList *broaded_args) { | |||
| (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(*broaded_args), | |||
| [](const AbstractBasePtr &arg) -> AbstractBasePtr { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| // Only broaden scalar that data type is number, such as float16,int32 and so on. | |||
| auto type = arg->BuildType()->type_id(); | |||
| if (arg->isa<AbstractScalar>() && type > kNumberTypeBegin && type < kNumberTypeEnd) { | |||
| auto config = abstract::AbstractBase::kBroadenScalarParameterOnly; | |||
| return arg->Broaden(config); | |||
| } else if (arg->GetValueTrack() != kAnyValue) { | |||
| return arg->Broaden(); | |||
| } | |||
| return arg; | |||
| }); | |||
| } | |||
| AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph_); | |||
| if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { | |||
| AbstractBasePtrList broaded_list; | |||
| (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_list), | |||
| [](const AbstractBasePtr &arg) -> AbstractBasePtr { | |||
| MS_EXCEPTION_IF_NULL(arg); | |||
| // Only broaden scalar that data type is number, such as float16,int32 and so on. | |||
| auto type = arg->BuildType()->type_id(); | |||
| if (arg->isa<AbstractScalar>() && type > kNumberTypeBegin && type < kNumberTypeEnd) { | |||
| auto config = abstract::AbstractBase::kBroadenScalarParameterOnly; | |||
| return arg->Broaden(config); | |||
| } else if (arg->GetValueTrack() != kAnyValue) { | |||
| return arg->Broaden(); | |||
| } | |||
| return arg; | |||
| }); | |||
| if (func_graph_->joined_shapes_.size() == broaded_list.size()) { | |||
| for (size_t i = 0; i < broaded_list.size(); ++i) { | |||
| broaded_list[i]->set_shape(func_graph_->joined_shapes_[i]); | |||
| } | |||
| } | |||
| BroadenArgs(args_spec_list, &broaded_list); | |||
| MS_LOG(DEBUG) << func_graph_->ToString() << " original: " << mindspore::ToString(args_spec_list) | |||
| << ", broaded: " << mindspore::ToString(broaded_list); | |||
| @@ -287,55 +294,11 @@ AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBa | |||
| } | |||
| if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) { | |||
| if (parent_context_) { | |||
| MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString() | |||
| << ", context: " << parent_context_->ToString(); | |||
| auto last_context = parent_context_->FindParentContext(func_graph_); | |||
| if (last_context && last_context->func_graph() == func_graph_) { | |||
| MS_LOG(DEBUG) << "Find last eval context: " << last_context->ToString(); | |||
| MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); | |||
| MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(last_context->args_spec_list()); | |||
| // Join the last eval arguments and current arguments to check if there are loop variant. | |||
| auto joined_args_spec_list_1 = AbstractJoin(args_spec_list, last_context->args_spec_list()); | |||
| MS_LOG(DEBUG) << "Joined args: " << ::mindspore::ToString(joined_args_spec_list_1); | |||
| // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. | |||
| if (!(joined_args_spec_list_1 == args_spec_list)) { | |||
| func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||
| func_graph_->joined_shapes_.clear(); | |||
| std::transform(joined_args_spec_list_1.begin(), joined_args_spec_list_1.end(), | |||
| std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) { | |||
| MS_EXCEPTION_IF_NULL(arg_spec); | |||
| return arg_spec->GetShapeTrack(); | |||
| }); | |||
| joined_args_spec_list_1 = NormalizeArgs(joined_args_spec_list_1); | |||
| MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; | |||
| } | |||
| return joined_args_spec_list_1; | |||
| } | |||
| } | |||
| if (!trace_.empty()) { | |||
| MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); | |||
| MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(trace_.back()); | |||
| // Join the last eval arguments and current arguments to check if there are loop variant. | |||
| auto joined_args_spec_list_2 = AbstractJoin(args_spec_list, trace_.back()); | |||
| // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. | |||
| if (!(joined_args_spec_list_2 == args_spec_list)) { | |||
| trace_.push_back(joined_args_spec_list_2); | |||
| func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||
| func_graph_->joined_shapes_.clear(); | |||
| std::transform(joined_args_spec_list_2.begin(), joined_args_spec_list_2.end(), | |||
| std::back_inserter(func_graph_->joined_shapes_), [](const AbstractBasePtr &arg_spec) { | |||
| MS_EXCEPTION_IF_NULL(arg_spec); | |||
| return arg_spec->GetShapeTrack(); | |||
| }); | |||
| joined_args_spec_list_2 = NormalizeArgs(joined_args_spec_list_2); | |||
| MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; | |||
| } | |||
| MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list_2); | |||
| return joined_args_spec_list_2; | |||
| } else { | |||
| trace_.push_back(args_spec_list); | |||
| } | |||
| func_graph_->set_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||
| auto normalized_args_spec_list = NormalizeArgs(args_spec_list); | |||
| MS_LOG(DEBUG) << "Set " << func_graph_->ToString() << " with IGNORE_VALUES flag."; | |||
| MS_LOG(DEBUG) << "Normalized args " << mindspore::ToString(normalized_args_spec_list); | |||
| return normalized_args_spec_list; | |||
| } | |||
| return args_spec_list; | |||
| } | |||
| @@ -388,16 +351,6 @@ FuncGraphPtr MetaFuncGraphEvaluator::GetFuncGraph(AnalysisEnginePtr engine, cons | |||
| EvalResultPtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list, | |||
| const AnfNodeConfigPtr &out_conf) { | |||
| // The evaluator can't reenter at sametime. | |||
| std::unique_lock<std::recursive_timed_mutex> eval_lock(eval_lock_, std::try_to_lock); | |||
| if (!eval_lock.owns_lock()) { | |||
| // Release GIL | |||
| pybind11::gil_scoped_release infer_gil_release; | |||
| // Check if enter endless loop | |||
| HealthPointScopedDrop health_point_check; | |||
| eval_lock.lock(); | |||
| } | |||
| AbstractBasePtrList args_spec_list; | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | |||
| @@ -210,8 +210,6 @@ class BaseFuncGraphEvaluator : public Evaluator { | |||
| AnalysisContextPtr MakeContext(const AnalysisEnginePtr &engine, const AbstractBasePtrList &args_spec_list); | |||
| AnalysisContextPtr context() const { return context_; } | |||
| void set_context(const AnalysisContextPtr &context) { context_ = context; } | |||
| AnalysisContextPtr parent_context() const { return parent_context_; } | |||
| void set_parent_context(const AnalysisContextPtr &parent_context) { parent_context_ = parent_context; } | |||
| @@ -219,14 +217,14 @@ class BaseFuncGraphEvaluator : public Evaluator { | |||
| AnalysisContextPtr parent_context_; | |||
| private: | |||
| AbstractBasePtr LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg); | |||
| AbstractBasePtr LaunchRecursiveEval(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg, | |||
| const AnalysisContextPtr &context); | |||
| // Add functions for stack frame routine. | |||
| AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg); | |||
| AbstractBasePtr LaunchStackFrame(const AnalysisEnginePtr &engine, const FuncGraphPtr &fg, | |||
| const AnalysisContextPtr &context); | |||
| static void EnterStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame, | |||
| const StackFramePtr &new_stack_frame); | |||
| static void LeaveStackFrame(const AnalysisEnginePtr &engine, const StackFramePtr ¤t_stack_frame); | |||
| AnalysisContextPtr context_; | |||
| }; | |||
| class FuncGraphEvaluator : public BaseFuncGraphEvaluator { | |||
| @@ -353,6 +351,8 @@ class JEvaluator : public Evaluator { | |||
| EvaluatorPtr evaluator_; | |||
| AbstractFunctionPtr orig_func_; | |||
| }; | |||
| void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList *broaded_args); | |||
| } // namespace abstract | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PIPELINE_JIT_STATIC_ANALYSIS_EVALUATOR_H_ | |||
| @@ -23,6 +23,7 @@ | |||
| #include "frontend/operator/ops.h" | |||
| #include "frontend/operator/composite/do_signature.h" | |||
| #include "abstract/abstract_function.h" | |||
| #include "abstract/utils.h" | |||
| #include "ir/graph_utils.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "debug/trace.h" | |||
| @@ -545,32 +546,37 @@ std::pair<AbstractBasePtrList, AbstractBasePtr> FuncGraphSpecializer::BuildFromB | |||
| std::unordered_set<AbstractBasePtrList, AbstractBasePtrListHasher, AbstractBasePtrListEqual> choices; | |||
| EvalResultPtr ret = nullptr; | |||
| AbstractBasePtrList broaded_argvals; | |||
| EvalResultCache &cache = evalcaches_[eval]->GetCache(); | |||
| for (auto &argvals_map : cache) { | |||
| std::vector<AbstractBasePtrList> args_vector; | |||
| auto &origin_eval_cache = evalcaches_[eval]->GetCache(); | |||
| for (auto &argvals_map : origin_eval_cache) { | |||
| auto argvals = argvals_map.first; | |||
| args_vector.push_back(argvals); | |||
| broaded_argvals.clear(); | |||
| (void)std::transform(argvals.begin(), argvals.end(), std::back_inserter(broaded_argvals), | |||
| [](const AbstractBasePtr &arg) -> AbstractBasePtr { return arg->Broaden(); }); | |||
| BroadenArgs(argvals, &broaded_argvals); | |||
| (void)choices.insert(broaded_argvals); | |||
| MS_LOG(DEBUG) << "Broaded_argvals: " << broaded_argvals.size() << ", " << ::mindspore::ToString(broaded_argvals); | |||
| } | |||
| if (choices.size() == 1) { | |||
| ConfigPtrList args_conf_list; | |||
| (void)std::transform(broaded_argvals.begin(), broaded_argvals.end(), std::back_inserter(args_conf_list), | |||
| [](AbstractBasePtr v) -> ConfigPtr { return std::make_shared<VirtualConfig>(v); }); | |||
| // If broaden return null | |||
| ret = eval->SingleRun(engine_, args_conf_list, nullptr); | |||
| if (args_vector.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "Should have 2 more choices, but: " << args_vector.size(); | |||
| } | |||
| AbstractBasePtrList joined_argvals = args_vector[0]; | |||
| for (size_t i = 1; i < args_vector.size(); ++i) { | |||
| joined_argvals = abstract::AbstractJoin(joined_argvals, args_vector[i]); | |||
| } | |||
| MS_LOG(DEBUG) << "Joined argvals: " << joined_argvals.size() << ", " << ::mindspore::ToString(joined_argvals); | |||
| EvaluatorCacheMgrPtr real = std::make_shared<EvaluatorCacheMgr>(); | |||
| real->SetValue(broaded_argvals, ret); | |||
| evalcaches_[eval] = real; | |||
| return std::make_pair(broaded_argvals, ret->abstract()); | |||
| } else { | |||
| MS_LOG(DEBUG) << "Choices.size: " << choices.size(); | |||
| return std::make_pair(AbstractBasePtrList(), nullptr); | |||
| auto joined_eval_result = origin_eval_cache.get(joined_argvals); | |||
| if (joined_eval_result != nullptr) { | |||
| MS_LOG(DEBUG) << "Find unique Choices in original eval cache, so use it: " << joined_eval_result->ToString(); | |||
| real->SetValue(joined_argvals, joined_eval_result); | |||
| evalcaches_[eval] = real; | |||
| return std::make_pair(joined_argvals, joined_eval_result->abstract()); | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "Choices.size: " << choices.size(); | |||
| return std::make_pair(AbstractBasePtrList(), nullptr); | |||
| } | |||
| void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { | |||
| @@ -96,7 +96,6 @@ StackFramePtr StackFrame::DoJump(const AnalysisEnginePtr &engine, const CNodePtr | |||
| } | |||
| // Create a new stack frame and set arguments for it. | |||
| fg_evaluator->set_context(new_context); | |||
| auto new_stack_frame = std::make_shared<StackFrame>(fg_evaluator, fg, new_context, parent_context); | |||
| new_stack_frame->set_args_abs_list(std::move(args_abs_list)); | |||
| return new_stack_frame; | |||
| @@ -80,6 +80,8 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac | |||
| MS_EXCEPTION_IF_NULL(func_graph_manager_); | |||
| func_graph_manager_->AddFuncGraph(func_graph); | |||
| root_func_graph_ = func_graph; | |||
| AnalysisContextPtr empty_context = AnalysisContext::DummyContext(); | |||
| // Running the analyzer. | |||
| @@ -105,7 +107,7 @@ AnalysisContextPtr AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Ana | |||
| const ConfigPtrList &args_conf_list) { | |||
| std::shared_ptr<FuncGraphEvaluator> eval = std::make_shared<FuncGraphEvaluator>(func_graph, context); | |||
| (void)eval->Run(shared_from_this(), args_conf_list, nullptr); | |||
| return eval->context(); | |||
| return root_context_; | |||
| } | |||
| void AnalysisEngine::SaveEvalResultInCache(const AnfNodeConfigPtr &conf, const EvalResultPtr &result) { | |||
| @@ -213,11 +215,12 @@ void AnalysisEngine::CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf) { | |||
| return; | |||
| } | |||
| auto top_evaluator = infer_stack.top().first; | |||
| if (!top_evaluator->isa<BaseFuncGraphEvaluator>()) { | |||
| // Top or root func_graph must be FuncGraph other than MetaFuncGraph; | |||
| if (!top_evaluator->isa<FuncGraphEvaluator>()) { | |||
| MS_LOG(EXCEPTION) << "Top evaluator is " << top_evaluator->ToString(); | |||
| } | |||
| auto top_fg_evaluator = dyn_cast<BaseFuncGraphEvaluator>(top_evaluator); | |||
| auto top_context_fg = top_fg_evaluator->context()->func_graph(); | |||
| auto top_fg_evaluator = dyn_cast<FuncGraphEvaluator>(top_evaluator); | |||
| auto top_context_fg = top_fg_evaluator->func_graph(); | |||
| if (current_cnode_fg != top_context_fg) { // Ignore FV call. | |||
| return; | |||
| } | |||
| @@ -339,6 +342,8 @@ void AnalysisEngine::Clear() { | |||
| evaluators_.clear(); | |||
| constructors_app_.clear(); | |||
| continued_evals_.clear(); | |||
| root_func_graph_ = nullptr; | |||
| root_context_ = nullptr; | |||
| } | |||
| namespace { | |||
| @@ -578,7 +583,7 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr> | |||
| #endif | |||
| } | |||
| bool AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) { | |||
| bool AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg) { | |||
| static std::mutex fg_lock; | |||
| std::lock_guard<std::mutex> infer_lock(fg_lock); | |||
| auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>(); | |||
| @@ -591,10 +596,17 @@ bool AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) { | |||
| auto undetermined_fgs = fg->recursive(); | |||
| if (undetermined_fgs) { | |||
| auto fg_parent = fg->parent(); | |||
| MS_EXCEPTION_IF_NULL(fg_parent); | |||
| fg_parent->set_flag(kFuncGraphFlagUndetermined, true); | |||
| MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); | |||
| return true; | |||
| if (fg_parent != nullptr) { | |||
| fg_parent->set_flag(kFuncGraphFlagUndetermined, true); | |||
| MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString() << " for fg: " << fg->ToString(); | |||
| return true; | |||
| } else if (possible_parent_fg != nullptr) { | |||
| possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, true); | |||
| MS_LOG(DEBUG) << "Set graph undetermined: " << possible_parent_fg->ToString() << " for fg: " << fg->ToString(); | |||
| return true; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "cannot find parent for fg: " << fg->ToString(); | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| @@ -810,8 +822,11 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve | |||
| AsyncAbstractResultPtr asyncResult1 = std::make_shared<AsyncAbstractResult>(); | |||
| AsyncAbstractResultPtr asyncFirstRunResult = std::make_shared<AsyncAbstractResult>(); | |||
| bool firstRun = !SetUndeterminedFlag(evaluators[0]); | |||
| (void)SetUndeterminedFlag(evaluators[1]); | |||
| MS_EXCEPTION_IF_NULL(out_conf); | |||
| MS_EXCEPTION_IF_NULL(out_conf->node()); | |||
| auto possible_parent_fg = out_conf->node()->func_graph(); | |||
| bool firstRun = !SetUndeterminedFlag(evaluators[0], possible_parent_fg); | |||
| (void)SetUndeterminedFlag(evaluators[1], possible_parent_fg); | |||
| std::string threadId = AnalysisResultCacheMgr::GetThreadid(); | |||
| MS_LOG(DEBUG) << GetInferThread() << "async : " << evaluators[0]->ToString(); | |||
| @@ -891,8 +906,12 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua | |||
| MS_EXCEPTION_IF_NULL(conf); | |||
| return conf->ObtainEvalResult()->abstract(); | |||
| }); | |||
| for (const auto &eval : evaluators) { | |||
| (void)SetUndeterminedFlag(eval); | |||
| MS_EXCEPTION_IF_NULL(out_conf); | |||
| MS_EXCEPTION_IF_NULL(out_conf->node()); | |||
| auto possible_parent_fg = out_conf->node()->func_graph(); | |||
| for (auto eval : evaluators) { | |||
| (void)SetUndeterminedFlag(eval, possible_parent_fg); | |||
| const auto current_inf = EvaluatorArgs(eval, args_spec_list); | |||
| MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); | |||
| // If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating. | |||
| @@ -248,6 +248,10 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf); | |||
| const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; } | |||
| FuncGraphPtr root_func_graph() const { return root_func_graph_; } | |||
| AnalysisContextPtr root_context() const { return root_context_; } | |||
| void set_root_context(const AnalysisContextPtr &context) { root_context_ = context; } | |||
| std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_; | |||
| void ResetFunctionCallDepth() { | |||
| @@ -292,7 +296,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| static EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node); | |||
| private: | |||
| bool SetUndeterminedFlag(const EvaluatorPtr &evaluator); | |||
| bool SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg); | |||
| EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval, | |||
| const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it, | |||
| bool *continue_flag); | |||
| @@ -308,6 +312,9 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| std::list<EvaluatorArgs> eval_trace_; | |||
| std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_; | |||
| std::unordered_set<EvaluatorArgs, EvaluatorArgsHasher, EvaluatorArgsEqual> continued_evals_; | |||
| // root or top func_graph for static analysis; | |||
| FuncGraphPtr root_func_graph_{nullptr}; | |||
| AnalysisContextPtr root_context_{nullptr}; | |||
| AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, | |||
| const ConfigPtrList &args_conf_list); | |||
| @@ -98,6 +98,21 @@ void FuncGraph::add_parameter(const ParameterPtr &p) { | |||
| } | |||
| } | |||
| ParameterPtr FuncGraph::InsertFrontParameter() { | |||
| FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>(); | |||
| ParameterPtr p = std::make_shared<Parameter>(this_func_graph); | |||
| InsertFrontParameter(p); | |||
| return p; | |||
| } | |||
| void FuncGraph::InsertFrontParameter(const ParameterPtr &p) { | |||
| if (manager_.lock()) { | |||
| manager_.lock()->InsertFrontParameter(shared_from_base<FuncGraph>(), p); | |||
| } else { | |||
| PrependParameter(p); | |||
| } | |||
| } | |||
| ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { | |||
| FuncGraphPtr this_graph = shared_from_base<FuncGraph>(); | |||
| ParameterPtr p = std::make_shared<Parameter>(this_graph); | |||
| @@ -83,6 +83,7 @@ const char FUNC_GRAPH_FLAG_CORE[] = "core"; | |||
| const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel"; | |||
| const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param"; | |||
| const char FUNC_GRAPH_OUTPUT_NO_RECOMPUTE[] = "output_no_recompute"; | |||
| const char FUNC_GRAPH_FLAG_FORCE_INLINE[] = "force_inline"; | |||
| const char kFuncGraphFlagUndetermined[] = "Undeterminate"; | |||
| const char kFuncGraphFlagBackPropEntry[] = "BackPropEntry"; | |||
| @@ -169,9 +170,14 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { | |||
| void set_output(const AnfNodePtr &value, bool force_new_ret = false); | |||
| const std::vector<AnfNodePtr> ¶meters() const { return parameters_; } | |||
| // Append | |||
| virtual ParameterPtr add_parameter(); | |||
| void add_parameter(const ParameterPtr &p); | |||
| void append_parameter(const ParameterPtr &p) { parameters_.push_back(p); } | |||
| // Prepend | |||
| virtual ParameterPtr InsertFrontParameter(); | |||
| void InsertFrontParameter(const ParameterPtr &p); | |||
| void PrependParameter(const ParameterPtr &p) { parameters_.insert(parameters_.begin(), p); } | |||
| void set_parameters(const std::vector<AnfNodePtr> ¶ms) { parameters_ = params; } | |||
| // Add a weight parameter with specific name. | |||
| ParameterPtr AddWeightParameter(const std::string &name); | |||
| @@ -354,7 +360,6 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder { | |||
| void add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); } | |||
| std::unordered_map<std::string, ValuePtr> attrs_; | |||
| std::vector<BaseShapePtr> joined_shapes_; | |||
| std::unordered_map<std::string, FuncGraphTransform> transforms_; | |||
| // Parameter default value. | |||
| std::map<std::string, AnfNodePtr> parameter_default_value_; | |||
| @@ -224,7 +224,6 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons | |||
| TraceGuard trace_guard(func_graph->debug_info(), target_relation_); | |||
| *target_func_graph = std::make_shared<FuncGraph>(); | |||
| (*target_func_graph)->set_attrs(func_graph->attrs()); | |||
| (*target_func_graph)->joined_shapes_ = func_graph->joined_shapes_; | |||
| (*target_func_graph)->set_transforms(func_graph->transforms()); | |||
| (*target_func_graph)->set_has_vararg(func_graph->has_vararg()); | |||
| (*target_func_graph)->set_has_kwarg(func_graph->has_kwarg()); | |||
| @@ -254,9 +253,24 @@ void Cloner::GenParameters(const FuncGraphPtr &func_graph) { | |||
| return; | |||
| } | |||
| CloneInfo item = todo_.back(); | |||
| auto lift_top_func_graph = item.origin; | |||
| for (auto &fv_map : iter->second) { | |||
| auto &free_var = fv_map.first; | |||
| if (utils::isa<AnfNodePtr>(free_var)) { | |||
| auto free_var_node = utils::cast<AnfNodePtr>(free_var); | |||
| // Don't lift weight parameter to top func_graph. | |||
| if (func_graph == lift_top_func_graph) { | |||
| if (free_var_node->isa<Parameter>()) { | |||
| auto free_var_param = free_var_node->cast<ParameterPtr>(); | |||
| if (free_var_param->has_default()) { | |||
| MS_LOG(DEBUG) << "Bypass weight param: " << free_var_param->ToString() | |||
| << " for top_func_graph: " << lift_top_func_graph->ToString(); | |||
| continue; | |||
| } | |||
| } | |||
| } | |||
| MS_LOG(DEBUG) << "Gen param: " << free_var_node->ToString() << " for func_graph: " << func_graph->ToString(); | |||
| repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast<AnfNodePtr>(free_var))); | |||
| } | |||
| } | |||
| @@ -301,6 +315,8 @@ void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList | |||
| } | |||
| } | |||
| AnfNodePtr new_param = nullptr; | |||
| CloneInfo item = todo_.back(); | |||
| auto lift_top_func_graph = item.origin; | |||
| for (auto &param : params) { | |||
| auto old_param = repl_node_[param]; | |||
| if (old_param->isa<CNode>() && old_param->func_graph() == func_graph) { | |||
| @@ -314,6 +330,14 @@ void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList | |||
| input_params->push_back(new_param); | |||
| continue; | |||
| } | |||
| if (lift_top_func_graph == func_graph) { | |||
| // Don't lift parameter from used_graphs to my parameter if I am the top; | |||
| repl_node_[old_param] = old_param; | |||
| input_params->push_back(old_param); | |||
| MS_LOG(DEBUG) << "Bypass param: " << old_param->ToString() | |||
| << " for top_func_graph: " << lift_top_func_graph->ToString(); | |||
| continue; | |||
| } | |||
| new_param = AddParameter(func_graph, old_param, false); | |||
| parameters.push_back(new_param); | |||
| lift_params->push_back(new_param); | |||
| @@ -445,6 +445,12 @@ void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &pa | |||
| tr.Commit(); | |||
| } | |||
| void FuncGraphManager::InsertFrontParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter) { | |||
| auto tr = Transact(); | |||
| tr.InsertFrontParameter(fg, parameter); | |||
| tr.Commit(); | |||
| } | |||
| bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { | |||
| auto func_graph = old_node->func_graph(); | |||
| auto tr = Transact(); | |||
| @@ -599,6 +605,13 @@ void FuncGraphManager::ParseChanges(const std::vector<Change> &changes, EdgeTupl | |||
| auto param_node = param.param->cast<ParameterPtr>(); | |||
| param.func_graph->append_parameter(param_node); | |||
| } break; | |||
| case Change::kTxInsertFrontParam: { | |||
| auto param = args.cast<ArgsOfInsertFrontParam>(); | |||
| MS_EXCEPTION_IF_NULL(param.func_graph); | |||
| (*adds)[param.param] += 1; | |||
| auto param_node = param.param->cast<ParameterPtr>(); | |||
| param.func_graph->PrependParameter(param_node); | |||
| } break; | |||
| default: | |||
| break; | |||
| } | |||
| @@ -665,6 +678,10 @@ void FuncGraphTransaction::AddParameter(FuncGraphPtr fg, const AnfNodePtr &param | |||
| changes_.emplace_back(Change::kTxAddParam, ArgsOfAddParam{fg, param}); | |||
| } | |||
| void FuncGraphTransaction::InsertFrontParameter(FuncGraphPtr fg, const AnfNodePtr &param) { | |||
| changes_.emplace_back(Change::kTxInsertFrontParam, ArgsOfInsertFrontParam{fg, param}); | |||
| } | |||
| bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { | |||
| MS_EXCEPTION_IF_NULL(old_node); | |||
| MS_EXCEPTION_IF_NULL(new_node); | |||
| @@ -312,6 +312,7 @@ class FuncGraphManager : public std::enable_shared_from_this<FuncGraphManager> { | |||
| void RemoveRoots(); | |||
| void SetParameters(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &parameters); | |||
| void AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter); | |||
| void InsertFrontParameter(const FuncGraphPtr &fg, const AnfNodePtr &parameter); | |||
| void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false); | |||
| bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); | |||
| void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value); | |||
| @@ -406,6 +407,7 @@ class FuncGraphTransaction { | |||
| // set parameters of a func graph | |||
| void SetParameters(FuncGraphPtr fg, const std::vector<AnfNodePtr> &params); | |||
| void AddParameter(FuncGraphPtr fg, const AnfNodePtr &param); | |||
| void InsertFrontParameter(FuncGraphPtr fg, const AnfNodePtr &param); | |||
| // replace old_node with new_node | |||
| bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); | |||
| @@ -447,6 +449,18 @@ struct ArgsOfAddParam { | |||
| } | |||
| }; | |||
| // args for InsertFront param | |||
| struct ArgsOfInsertFrontParam { | |||
| FuncGraphPtr func_graph; | |||
| AnfNodePtr param; | |||
| bool operator==(const ArgsOfInsertFrontParam &other) const { return &other == this; } | |||
| friend std::ostream &operator<<(std::ostream &os, const ArgsOfInsertFrontParam &) { | |||
| os << "[ArgsOfInsertFrontParam]"; | |||
| return os; | |||
| } | |||
| }; | |||
| // args for set edge | |||
| struct ArgsOfSetEdge { | |||
| CNodePtr root_node; | |||
| @@ -473,7 +487,7 @@ struct ArgsOfAddEdge { | |||
| }; | |||
| struct Change { | |||
| enum OpName { kTxSetParams, kTxSetEdge, kTxAddParam, kTxAddEdge }; | |||
| enum OpName { kTxSetParams, kTxSetEdge, kTxAddParam, kTxInsertFrontParam, kTxAddEdge }; | |||
| OpName op; | |||
| Any args; | |||
| Change(OpName name, const Any &para) : op(name), args(para) {} | |||
| @@ -40,6 +40,16 @@ abstract::ShapePtr AddNInferShape(const PrimitivePtr &primitive, const std::vect | |||
| auto element0_shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(shape_0); | |||
| for (size_t i = 0; i < elements.size(); ++i) { | |||
| auto shape = elements[i]->BuildShape(); | |||
| if (shape->isa<abstract::Shape>() && shape_0->isa<abstract::Shape>()) { | |||
| const auto &shape_vec = shape->cast<abstract::ShapePtr>()->shape(); | |||
| const auto &shape_0_vec = shape_0->cast<abstract::ShapePtr>()->shape(); | |||
| if ((shape_vec == ShapeVector({1}) && shape_0_vec == ShapeVector()) || | |||
| (shape_vec == ShapeVector() && shape_0_vec == ShapeVector({1}))) { | |||
| MS_LOG(DEBUG) << primitive->name() << "Shape of input[" << i << "]: " << shape->ToString() | |||
| << " are consistent with the shape of input[0]" << shape_0->ToString(); | |||
| continue; | |||
| } | |||
| } | |||
| if (*shape != *shape_0) { | |||
| MS_EXCEPTION(ValueError) << primitive->name() << "Shape of input[" << i << "]: " << shape->ToString() | |||
| << " are not consistent with the shape of input[0]" << shape_0->ToString(); | |||
| @@ -676,7 +676,7 @@ class SideEffectControlFlowAssignDependWhileNet(Cell): | |||
| return grad_out | |||
| @pytest.mark.level0 | |||
| @pytest.mark.level1 | |||
| @pytest.mark.platform_arm_ascend_training | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @@ -37,11 +37,13 @@ class TestAD : public UT::Common { | |||
| public: | |||
| UT::PyFuncGraphFetcher getPyFun; | |||
| pipeline::ResourceBasePtr resourcePtr = std::make_shared<pipeline::ResourceBase>(); | |||
| pipeline::ResourcePtr resourcePtr = std::make_shared<pipeline::Resource>(); | |||
| protected: | |||
| void AssertExpect(const std::string& testCase) { | |||
| FuncGraphPtr g = getPyFun(testCase); | |||
| resourcePtr->manager()->RemoveRoots(); | |||
| resourcePtr->manager()->AddFuncGraph(g, true); | |||
| FuncGraphPtr dg = Grad(g, resourcePtr); | |||
| AssertExpect(testCase, dg); | |||
| } | |||
| @@ -20,9 +20,10 @@ from mindspore import context | |||
| from mindspore import Tensor | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common.parameter import Parameter, ParameterTuple | |||
| grad_all = C.GradOperation(get_all=True) | |||
| grad_by_list = C.GradOperation(get_by_list=True) | |||
| class CropAndResizeNet(nn.Cell): | |||
| def __init__(self, crop_size): | |||
| @@ -138,3 +139,38 @@ def test_ad_fv_cnode_order(): | |||
| net.add_flags_recursive(defer_inline=True) | |||
| grad_net = grad_all(net) | |||
| grad_net(input_x, input_y) | |||
| # True and False branch of switch have different number of parameters. | |||
| def test_if_branch_with_different_params(): | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.weight1 = Parameter(Tensor(np.array([1.0], dtype=np.float32)), name="weight1") | |||
| self.weight2 = Parameter(Tensor(np.array([2.0], dtype=np.float32)), name="weight2") | |||
| def construct(self, idx, end, x): | |||
| out = x | |||
| if idx < end: | |||
| out = out + self.weight1 * self.weight2 | |||
| else: | |||
| out = out + self.weight1 | |||
| return out | |||
| class GradNet(nn.Cell): | |||
| def __init__(self, net): | |||
| super(GradNet, self).__init__() | |||
| self.net = net | |||
| self.weights = ParameterTuple(net.trainable_params()) | |||
| def construct(self, idx, end, x): | |||
| return grad_by_list(self.net, self.weights)(idx, end, x) | |||
| idx = Tensor(np.array((0), dtype=np.int32)) | |||
| end = Tensor(np.array((3), dtype=np.int32)) | |||
| x = Tensor(np.array([2.0], dtype=np.float32)) | |||
| net = Net() | |||
| grad_net = GradNet(net) | |||
| grad_net(idx, end, x) | |||