Merge pull request !766 from gongchen/nest_looptags/v0.3.0-alpha
| @@ -27,14 +27,13 @@ | |||
| #include <utility> | |||
| #include <initializer_list> | |||
| #ifdef DEBUG | |||
| #include "debug/draw.h" | |||
| #include "debug/anf_ir_dump.h" | |||
| #endif | |||
| #include "debug/trace.h" | |||
| #include "optimizer/opt.h" | |||
| #include "pipeline/resource.h" | |||
| #include "pipeline/action.h" | |||
| #include "debug/trace.h" | |||
| #include "utils/context/ms_context.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -133,7 +132,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) { | |||
| // Optimizer step counter; | |||
| int counter = 1; | |||
| int counter = -1; | |||
| bool changes = true; | |||
| while (changes) { | |||
| @@ -170,13 +169,14 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||
| } | |||
| }; | |||
| use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); | |||
| #ifdef DEBUG | |||
| MS_LOG(DEBUG) << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; | |||
| auto fg_name = name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; | |||
| func_graph->DumpFuncGraph(fg_name); | |||
| DumpIR(fg_name + ".ir", func_graph); | |||
| MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph."; | |||
| #endif | |||
| if (IS_OUTPUT_ON(mindspore::DEBUG) && MsContext::GetInstance()->save_graphs_flag()) { | |||
| MS_LOG(DEBUG) << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; | |||
| auto fg_name = | |||
| "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; | |||
| func_graph->DumpFuncGraph(fg_name); | |||
| DumpIR(fg_name + ".ir", func_graph); | |||
| MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph."; | |||
| } | |||
| } | |||
| }; | |||
| use_profile ? (WITH(MsProfile::GetProfile()->Lap(counter++)) run_runc) : run_runc(); | |||
| @@ -32,6 +32,7 @@ | |||
| #include "pipeline/static_analysis/static_analysis.h" | |||
| #include "pipeline/static_analysis/program_specialize.h" | |||
| #include "pipeline/resource.h" | |||
| #include "utils/context/ms_context.h" | |||
| #include "pipeline/remove_value_node_dup.h" | |||
| #include "optimizer/optimizer.h" | |||
| #include "vm/transform.h" | |||
| @@ -240,13 +241,23 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { | |||
| } | |||
| bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes) { | |||
| size_t counter = 0; | |||
| for (auto &pass : passes) { | |||
| WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res]() { | |||
| WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res, &counter]() { | |||
| MS_LOG(DEBUG) << "Pass " << pass.first << " start ..."; | |||
| auto result = pass.second(res); | |||
| if (!result) { | |||
| MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first; | |||
| } | |||
| if (MsContext::GetInstance()->save_graphs_flag() && res->func_graph() != nullptr) { | |||
| auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first; | |||
| auto func_graph = res->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| func_graph->DumpFuncGraph(fg_name); | |||
| DumpIR(fg_name + ".ir", func_graph); | |||
| MS_LOG(DEBUG) << "Dump " << fg_name << " func graph."; | |||
| } | |||
| counter++; | |||
| MS_LOG(DEBUG) << "Pass " << pass.first << " end."; | |||
| }; | |||
| } | |||
| @@ -55,6 +55,7 @@ void InferFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList & | |||
| AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &engine, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| AbstractBasePtrList normalized_args_spec_list = NormalizeArgs(args_spec_list); | |||
| normalized_args_spec_list = BroadenUndeterminedArgs(normalized_args_spec_list); | |||
| FuncGraphPtr fg = GetFuncGraph(engine, normalized_args_spec_list); | |||
| MS_EXCEPTION_IF_NULL(parent_context_); | |||
| AnalysisContextPtr context = parent_context_->NewFuncGraphContext(fg, normalized_args_spec_list); | |||
| @@ -140,7 +141,14 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList | |||
| << ", broaded: " << mindspore::ToString(broaded_list); | |||
| return broaded_list; | |||
| } | |||
| return args_spec_list; | |||
| } | |||
| AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) { | |||
| MS_EXCEPTION_IF_NULL(func_graph_); | |||
| if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { | |||
| return args_spec_list; | |||
| } | |||
| if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) { | |||
| if (parent_context_) { | |||
| MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString() | |||
| @@ -160,6 +168,21 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList | |||
| return joined_args_spec_list; | |||
| } | |||
| } | |||
| if (trace_.size() != 0) { | |||
| MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); | |||
| MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(trace_.back()); | |||
| // Join the last eval arguments and current arguments to check if there are loop variant. | |||
| auto joined_args_spec_list = AbstractJoin(args_spec_list, trace_.back()); | |||
| // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. | |||
| if (!(joined_args_spec_list == args_spec_list)) { | |||
| trace_.push_back(joined_args_spec_list); | |||
| func_graph_->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||
| } | |||
| MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); | |||
| return joined_args_spec_list; | |||
| } else { | |||
| trace_.push_back(args_spec_list); | |||
| } | |||
| } | |||
| return args_spec_list; | |||
| } | |||
| @@ -224,6 +247,7 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar | |||
| return conf->GetEvaluatedValue(); | |||
| }); | |||
| args_spec_list = NormalizeArgs(args_spec_list); | |||
| args_spec_list = BroadenUndeterminedArgs(args_spec_list); | |||
| trace::TraceGraphInferEnter(shared_from_base<Evaluator>(), out_conf); | |||
| InferEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf); | |||
| MS_EXCEPTION_IF_NULL(cache_); | |||
| @@ -47,6 +47,10 @@ class Evaluator : public Base { | |||
| virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } | |||
| virtual AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) { | |||
| return args_spec_list; | |||
| } | |||
| std::string ToString() const override { return identifier_; } | |||
| virtual AnfNodePtr bound_node() const { return bound_node_.lock(); } | |||
| @@ -181,12 +185,14 @@ class FuncGraphEvaluator : public BaseFuncGraphEvaluator { | |||
| FuncGraphPtr func_graph() { return func_graph_; } | |||
| AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override; | |||
| AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) override; | |||
| std::string ToString() const override { return identifier_ + "_" + func_graph_->ToString(); } | |||
| private: | |||
| FuncGraphPtr func_graph_; | |||
| std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual> | |||
| func_graph_cache_; | |||
| std::vector<AbstractBasePtrList> trace_; | |||
| }; | |||
| using FuncGraphEvaluatorPtr = std::shared_ptr<FuncGraphEvaluator>; | |||
| @@ -19,6 +19,7 @@ | |||
| #include "pipeline/static_analysis/static_analysis.h" | |||
| #include <algorithm> | |||
| #include <set> | |||
| #include "pipeline/static_analysis/utils.h" | |||
| #include "pipeline/static_analysis/prim.h" | |||
| @@ -239,7 +240,6 @@ AbstractBasePtr AnalysisEngine::InferCNode(const CNodePtr &cnode, const AnfNodeC | |||
| for (std::size_t i = 1; i < inputs.size(); i++) { | |||
| const AnfNodePtr &node = inputs[i]; | |||
| args_conf_list.push_back(MakeConfig(node, context)); | |||
| MS_LOG(DEBUG) << "Current CNode args_conf_list[" << i << "] node: " << node->DebugString(); | |||
| } | |||
| std::vector<EvaluatorPtr> infs; | |||
| @@ -469,6 +469,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval | |||
| const AnfNodeConfigPtr &out_conf, | |||
| const ConfigPtrList &args_conf_list) { | |||
| AbstractBasePtrList out_specs; | |||
| if (!multi_poss_.count(evaluators[0])) { | |||
| multi_poss_[evaluators[0]] = evaluators[1]; | |||
| multi_poss_[evaluators[1]] = evaluators[0]; | |||
| } | |||
| AbstractBasePtrList args_spec_list; | |||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | |||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | |||
| @@ -478,28 +482,81 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval | |||
| for (auto eval : evaluators) { | |||
| auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>(); | |||
| if (fg_eval) { | |||
| auto undetermined_fgs = fg_eval->func_graph()->recursive_graphs(); | |||
| auto fg = fg_eval->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(fg); | |||
| auto undetermined_fgs = fg->recursive_graphs(); | |||
| if (undetermined_fgs) { | |||
| for (auto undetermined_fg : *undetermined_fgs) { | |||
| MS_LOG(DEBUG) << "Set graph undetermined: " << undetermined_fg->ToString(); | |||
| // As the current evaluator has multiple possibles, all the func_graphs which | |||
| // are recursive with the current func_graph are undetermined in control flow. | |||
| undetermined_fg->set_flags(kFuncGraphFlagUndetermined, true); | |||
| } | |||
| auto fg_parent = fg->parent(); | |||
| MS_EXCEPTION_IF_NULL(fg_parent); | |||
| fg_parent->set_flags(kFuncGraphFlagUndetermined, true); | |||
| MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); | |||
| } | |||
| } | |||
| auto current_inf = std::make_pair(eval, args_spec_list); | |||
| MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); | |||
| // If current evaluator is under tracing, then skip current evaluator to avoid recursively inferring. | |||
| auto it = std::find(eval_trace_.begin(), eval_trace_.end(), current_inf); | |||
| if (it == eval_trace_.end()) { | |||
| auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf); | |||
| if (it == eval_trace_.rend()) { | |||
| eval_trace_.push_back(current_inf); | |||
| MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get(); | |||
| MS_EXCEPTION_IF_NULL(eval); | |||
| auto out_spec = eval->Run(shared_from_this(), args_conf_list, out_conf); | |||
| MS_EXCEPTION_IF_NULL(out_spec); | |||
| MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << out_spec->ToString(); | |||
| out_specs.push_back(out_spec); | |||
| MS_LOG(DEBUG) << "Pop Evaluator " << eval->ToString(); | |||
| eval_trace_.pop_back(); | |||
| if (eval_trace_.empty()) { | |||
| multi_poss_.clear(); | |||
| } | |||
| } else if (it != eval_trace_.rbegin()) { | |||
| // Find latest entry function to handle nested recursion. | |||
| EvaluatorPtr latest_entry = eval; | |||
| auto latest_entry_iter = eval_trace_.rbegin(); | |||
| for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) { | |||
| auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first); | |||
| if (it_temp != evaluators.end()) { | |||
| latest_entry = *it_temp; | |||
| latest_entry_iter = r_it; | |||
| break; | |||
| } | |||
| latest_entry_iter = ++r_it; | |||
| } | |||
| if (latest_entry != eval) { | |||
| MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString(); | |||
| continue; | |||
| } | |||
| bool has_undetermined = false; | |||
| // Check whether sub loop has untraced undetermined evaluator. | |||
| std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> undetermined_evals; | |||
| for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) { | |||
| undetermined_evals.insert(*r_it); | |||
| } | |||
| MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size(); | |||
| for (auto u_eval : undetermined_evals) { | |||
| MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined."; | |||
| if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) { | |||
| MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined."; | |||
| has_undetermined = true; | |||
| break; | |||
| } | |||
| } | |||
| if (has_undetermined == false) { | |||
| MS_LOG(DEBUG) << eval->ToString() << " has no undetermined."; | |||
| continue; | |||
| } | |||
| // Try to travel the latest undetermined. | |||
| if (latest_entry != eval_trace_.rbegin()->first) { | |||
| MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString(); | |||
| auto out_spec = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); | |||
| MS_EXCEPTION_IF_NULL(out_spec); | |||
| MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() << " return out_spec: " << out_spec->ToString(); | |||
| return out_spec; | |||
| } | |||
| } | |||
| } | |||
| if (out_specs.size() == 0) { | |||
| @@ -25,6 +25,7 @@ | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include <map> | |||
| #ifdef DEBUG | |||
| #include <stack> | |||
| @@ -206,6 +207,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| AnfNodeConfigMap anfnode_config_map_; | |||
| // Use a list to trace multiple evaluators. | |||
| std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_; | |||
| std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_; | |||
| AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, | |||
| const ConfigPtrList &args_conf_list); | |||
| @@ -39,7 +39,6 @@ from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer | |||
| def setup_module(module): | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| @ms_function | |||
| def while_upper_bound(upper): | |||
| rval = 2 | |||
| @@ -392,6 +391,58 @@ def test_grad_factorial(): | |||
| res = C.grad(factorial)(3) | |||
| assert res == 11 | |||
| @ms_function | |||
| def factorial2(n): | |||
| """ factorial """ | |||
| if n != 0: | |||
| return n * factorial2(n-1) | |||
| elif n == 1: | |||
| return 1 * factorial2(n-1) | |||
| else: | |||
| return 1 | |||
| def test_factorial2(): | |||
| res = factorial2(3) | |||
| assert res == 6 | |||
| @ms_function | |||
| def foo(n): | |||
| if n <= 1: | |||
| if n == 1: | |||
| return foo(n-1) | |||
| else: | |||
| return 1 | |||
| else: | |||
| return foo(n-1) | |||
| def test_foo(): | |||
| res = foo(5) | |||
| assert res == 1 | |||
| @ms_function | |||
| def double_nested_loop(x): | |||
| i = 0 | |||
| s = 0 | |||
| while(i < x): | |||
| j = 0 | |||
| i = i + 1 | |||
| while(j < 3): | |||
| j = j + 1 | |||
| s = s + j | |||
| return s | |||
| def test_nested_loop(): | |||
| res = double_nested_loop(3) | |||
| assert res == 18 | |||
| @ms_function | |||
| def double_nested_loop2(x): | |||
| s = 0 | |||
| for i in range(x): | |||
| for j in range(3): | |||
| s = s + j | |||
| return s | |||
| def test_nested_loop2(): | |||
| res = double_nested_loop(1) | |||
| assert res == 6 | |||
| def _for(x): | |||
| """ _for """ | |||
| ret = x * x | |||
| @@ -24,7 +24,7 @@ from mindspore.ops import operations as P | |||
| def setup_module(module): | |||
| context.set_context(mode = context.PYNATIVE_MODE, save_graphs = True, device_target = "Ascend") | |||
| context.set_context(mode = context.PYNATIVE_MODE, save_graphs = False, device_target = "Ascend") | |||
| context.set_context(enable_task_sink = True, device_id = 0) | |||
| @@ -86,7 +86,17 @@ def while_by_while(x, y, z): | |||
| x = x + 1 | |||
| x = x + 1 | |||
| return x | |||
| @ms_function | |||
| def while_in_while(x, y, z): | |||
| out = c4 | |||
| while x < y: | |||
| z = c4 + c4 | |||
| while z < y: | |||
| z = z + 1 | |||
| out = out + z | |||
| x = x + 1 | |||
| out = out + x | |||
| return out | |||
| def test_simple_if(): | |||
| output = simple_if(c1, c2, c3) | |||
| @@ -117,3 +127,7 @@ def test_while_by_while(): | |||
| expect = Tensor([28], mstype.int32) | |||
| assert output == expect | |||
| def test_while_in_while(): | |||
| output = while_in_while(c1, c2, c3) | |||
| expect = Tensor([1274], mstype.int32) | |||
| assert output == expect | |||