Merge pull request !766 from gongchen/nest_looptags/v0.3.0-alpha
| @@ -27,14 +27,13 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <initializer_list> | #include <initializer_list> | ||||
| #ifdef DEBUG | |||||
| #include "debug/draw.h" | #include "debug/draw.h" | ||||
| #include "debug/anf_ir_dump.h" | #include "debug/anf_ir_dump.h" | ||||
| #endif | |||||
| #include "debug/trace.h" | |||||
| #include "optimizer/opt.h" | #include "optimizer/opt.h" | ||||
| #include "pipeline/resource.h" | #include "pipeline/resource.h" | ||||
| #include "pipeline/action.h" | #include "pipeline/action.h" | ||||
| #include "debug/trace.h" | |||||
| #include "utils/context/ms_context.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| @@ -133,7 +132,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||||
| FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) { | FuncGraphPtr step(FuncGraphPtr func_graph, bool use_profile = true) { | ||||
| // Optimizer step counter; | // Optimizer step counter; | ||||
| int counter = 1; | |||||
| int counter = -1; | |||||
| bool changes = true; | bool changes = true; | ||||
| while (changes) { | while (changes) { | ||||
| @@ -170,13 +169,14 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> { | |||||
| } | } | ||||
| }; | }; | ||||
| use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); | use_profile ? (WITH(MsProfile::GetProfile()->Step(pass_names_[i])) opt_func) : opt_func(); | ||||
| #ifdef DEBUG | |||||
| MS_LOG(DEBUG) << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; | |||||
| auto fg_name = name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; | |||||
| func_graph->DumpFuncGraph(fg_name); | |||||
| DumpIR(fg_name + ".ir", func_graph); | |||||
| MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph."; | |||||
| #endif | |||||
| if (IS_OUTPUT_ON(mindspore::DEBUG) && MsContext::GetInstance()->save_graphs_flag()) { | |||||
| MS_LOG(DEBUG) << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end."; | |||||
| auto fg_name = | |||||
| "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i]; | |||||
| func_graph->DumpFuncGraph(fg_name); | |||||
| DumpIR(fg_name + ".ir", func_graph); | |||||
| MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph."; | |||||
| } | |||||
| } | } | ||||
| }; | }; | ||||
| use_profile ? (WITH(MsProfile::GetProfile()->Lap(counter++)) run_runc) : run_runc(); | use_profile ? (WITH(MsProfile::GetProfile()->Lap(counter++)) run_runc) : run_runc(); | ||||
| @@ -32,6 +32,7 @@ | |||||
| #include "pipeline/static_analysis/static_analysis.h" | #include "pipeline/static_analysis/static_analysis.h" | ||||
| #include "pipeline/static_analysis/program_specialize.h" | #include "pipeline/static_analysis/program_specialize.h" | ||||
| #include "pipeline/resource.h" | #include "pipeline/resource.h" | ||||
| #include "utils/context/ms_context.h" | |||||
| #include "pipeline/remove_value_node_dup.h" | #include "pipeline/remove_value_node_dup.h" | ||||
| #include "optimizer/optimizer.h" | #include "optimizer/optimizer.h" | ||||
| #include "vm/transform.h" | #include "vm/transform.h" | ||||
| @@ -240,13 +241,23 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { | |||||
| } | } | ||||
| bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes) { | bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes) { | ||||
| size_t counter = 0; | |||||
| for (auto &pass : passes) { | for (auto &pass : passes) { | ||||
| WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res]() { | |||||
| WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res, &counter]() { | |||||
| MS_LOG(DEBUG) << "Pass " << pass.first << " start ..."; | MS_LOG(DEBUG) << "Pass " << pass.first << " start ..."; | ||||
| auto result = pass.second(res); | auto result = pass.second(res); | ||||
| if (!result) { | if (!result) { | ||||
| MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first; | MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first; | ||||
| } | } | ||||
| if (MsContext::GetInstance()->save_graphs_flag() && res->func_graph() != nullptr) { | |||||
| auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first; | |||||
| auto func_graph = res->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| func_graph->DumpFuncGraph(fg_name); | |||||
| DumpIR(fg_name + ".ir", func_graph); | |||||
| MS_LOG(DEBUG) << "Dump " << fg_name << " func graph."; | |||||
| } | |||||
| counter++; | |||||
| MS_LOG(DEBUG) << "Pass " << pass.first << " end."; | MS_LOG(DEBUG) << "Pass " << pass.first << " end."; | ||||
| }; | }; | ||||
| } | } | ||||
| @@ -55,6 +55,7 @@ void InferFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList & | |||||
| AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &engine, | AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr &engine, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| AbstractBasePtrList normalized_args_spec_list = NormalizeArgs(args_spec_list); | AbstractBasePtrList normalized_args_spec_list = NormalizeArgs(args_spec_list); | ||||
| normalized_args_spec_list = BroadenUndeterminedArgs(normalized_args_spec_list); | |||||
| FuncGraphPtr fg = GetFuncGraph(engine, normalized_args_spec_list); | FuncGraphPtr fg = GetFuncGraph(engine, normalized_args_spec_list); | ||||
| MS_EXCEPTION_IF_NULL(parent_context_); | MS_EXCEPTION_IF_NULL(parent_context_); | ||||
| AnalysisContextPtr context = parent_context_->NewFuncGraphContext(fg, normalized_args_spec_list); | AnalysisContextPtr context = parent_context_->NewFuncGraphContext(fg, normalized_args_spec_list); | ||||
| @@ -140,7 +141,14 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList | |||||
| << ", broaded: " << mindspore::ToString(broaded_list); | << ", broaded: " << mindspore::ToString(broaded_list); | ||||
| return broaded_list; | return broaded_list; | ||||
| } | } | ||||
| return args_spec_list; | |||||
| } | |||||
| AbstractBasePtrList FuncGraphEvaluator::BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph_); | |||||
| if (func_graph_->has_flag(FUNC_GRAPH_FLAG_IGNORE_VALUES)) { | |||||
| return args_spec_list; | |||||
| } | |||||
| if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) { | if (func_graph_->has_flag(kFuncGraphFlagUndetermined)) { | ||||
| if (parent_context_) { | if (parent_context_) { | ||||
| MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString() | MS_LOG(DEBUG) << "Undeterminate FuncGraphEvaluator " << ToString() | ||||
| @@ -160,6 +168,21 @@ AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList | |||||
| return joined_args_spec_list; | return joined_args_spec_list; | ||||
| } | } | ||||
| } | } | ||||
| if (trace_.size() != 0) { | |||||
| MS_LOG(DEBUG) << "Current eval args: " << ::mindspore::ToString(args_spec_list); | |||||
| MS_LOG(DEBUG) << "Last eval args: " << ::mindspore::ToString(trace_.back()); | |||||
| // Join the last eval arguments and current arguments to check if there are loop variant. | |||||
| auto joined_args_spec_list = AbstractJoin(args_spec_list, trace_.back()); | |||||
| // If there is loop variant, all arguments need to be broaden to avoid wrong constant propagation. | |||||
| if (!(joined_args_spec_list == args_spec_list)) { | |||||
| trace_.push_back(joined_args_spec_list); | |||||
| func_graph_->set_flags(FUNC_GRAPH_FLAG_IGNORE_VALUES, true); | |||||
| } | |||||
| MS_LOG(DEBUG) << "Joined eval args: " << ::mindspore::ToString(joined_args_spec_list); | |||||
| return joined_args_spec_list; | |||||
| } else { | |||||
| trace_.push_back(args_spec_list); | |||||
| } | |||||
| } | } | ||||
| return args_spec_list; | return args_spec_list; | ||||
| } | } | ||||
| @@ -224,6 +247,7 @@ AbstractBasePtr Evaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &ar | |||||
| return conf->GetEvaluatedValue(); | return conf->GetEvaluatedValue(); | ||||
| }); | }); | ||||
| args_spec_list = NormalizeArgs(args_spec_list); | args_spec_list = NormalizeArgs(args_spec_list); | ||||
| args_spec_list = BroadenUndeterminedArgs(args_spec_list); | |||||
| trace::TraceGraphInferEnter(shared_from_base<Evaluator>(), out_conf); | trace::TraceGraphInferEnter(shared_from_base<Evaluator>(), out_conf); | ||||
| InferEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf); | InferEntryLogging(shared_from_base<Evaluator>(), args_spec_list, out_conf); | ||||
| MS_EXCEPTION_IF_NULL(cache_); | MS_EXCEPTION_IF_NULL(cache_); | ||||
| @@ -47,6 +47,10 @@ class Evaluator : public Base { | |||||
| virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } | virtual AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { return args_spec_list; } | ||||
| virtual AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) { | |||||
| return args_spec_list; | |||||
| } | |||||
| std::string ToString() const override { return identifier_; } | std::string ToString() const override { return identifier_; } | ||||
| virtual AnfNodePtr bound_node() const { return bound_node_.lock(); } | virtual AnfNodePtr bound_node() const { return bound_node_.lock(); } | ||||
| @@ -181,12 +185,14 @@ class FuncGraphEvaluator : public BaseFuncGraphEvaluator { | |||||
| FuncGraphPtr func_graph() { return func_graph_; } | FuncGraphPtr func_graph() { return func_graph_; } | ||||
| AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override; | AbstractBasePtrList NormalizeArgs(const AbstractBasePtrList &args_spec_list) const override; | ||||
| AbstractBasePtrList BroadenUndeterminedArgs(const AbstractBasePtrList &args_spec_list) override; | |||||
| std::string ToString() const override { return identifier_ + "_" + func_graph_->ToString(); } | std::string ToString() const override { return identifier_ + "_" + func_graph_->ToString(); } | ||||
| private: | private: | ||||
| FuncGraphPtr func_graph_; | FuncGraphPtr func_graph_; | ||||
| std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual> | std::unordered_map<AbstractBasePtrList, FuncGraphPtr, AbstractBasePtrListHasher, AbstractBasePtrListEqual> | ||||
| func_graph_cache_; | func_graph_cache_; | ||||
| std::vector<AbstractBasePtrList> trace_; | |||||
| }; | }; | ||||
| using FuncGraphEvaluatorPtr = std::shared_ptr<FuncGraphEvaluator>; | using FuncGraphEvaluatorPtr = std::shared_ptr<FuncGraphEvaluator>; | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include "pipeline/static_analysis/static_analysis.h" | #include "pipeline/static_analysis/static_analysis.h" | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <set> | |||||
| #include "pipeline/static_analysis/utils.h" | #include "pipeline/static_analysis/utils.h" | ||||
| #include "pipeline/static_analysis/prim.h" | #include "pipeline/static_analysis/prim.h" | ||||
| @@ -239,7 +240,6 @@ AbstractBasePtr AnalysisEngine::InferCNode(const CNodePtr &cnode, const AnfNodeC | |||||
| for (std::size_t i = 1; i < inputs.size(); i++) { | for (std::size_t i = 1; i < inputs.size(); i++) { | ||||
| const AnfNodePtr &node = inputs[i]; | const AnfNodePtr &node = inputs[i]; | ||||
| args_conf_list.push_back(MakeConfig(node, context)); | args_conf_list.push_back(MakeConfig(node, context)); | ||||
| MS_LOG(DEBUG) << "Current CNode args_conf_list[" << i << "] node: " << node->DebugString(); | |||||
| } | } | ||||
| std::vector<EvaluatorPtr> infs; | std::vector<EvaluatorPtr> infs; | ||||
| @@ -469,6 +469,10 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval | |||||
| const AnfNodeConfigPtr &out_conf, | const AnfNodeConfigPtr &out_conf, | ||||
| const ConfigPtrList &args_conf_list) { | const ConfigPtrList &args_conf_list) { | ||||
| AbstractBasePtrList out_specs; | AbstractBasePtrList out_specs; | ||||
| if (!multi_poss_.count(evaluators[0])) { | |||||
| multi_poss_[evaluators[0]] = evaluators[1]; | |||||
| multi_poss_[evaluators[1]] = evaluators[0]; | |||||
| } | |||||
| AbstractBasePtrList args_spec_list; | AbstractBasePtrList args_spec_list; | ||||
| (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list), | ||||
| [](const ConfigPtr &conf) -> AbstractBasePtr { | [](const ConfigPtr &conf) -> AbstractBasePtr { | ||||
| @@ -478,28 +482,81 @@ AbstractBasePtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Eval | |||||
| for (auto eval : evaluators) { | for (auto eval : evaluators) { | ||||
| auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>(); | auto fg_eval = eval->cast<FuncGraphEvaluatorPtr>(); | ||||
| if (fg_eval) { | if (fg_eval) { | ||||
| auto undetermined_fgs = fg_eval->func_graph()->recursive_graphs(); | |||||
| auto fg = fg_eval->func_graph(); | |||||
| MS_EXCEPTION_IF_NULL(fg); | |||||
| auto undetermined_fgs = fg->recursive_graphs(); | |||||
| if (undetermined_fgs) { | if (undetermined_fgs) { | ||||
| for (auto undetermined_fg : *undetermined_fgs) { | |||||
| MS_LOG(DEBUG) << "Set graph undetermined: " << undetermined_fg->ToString(); | |||||
| // As the current evaluator has multiple possibles, all the func_graphs which | |||||
| // are recursive with the current func_graph are undetermined in control flow. | |||||
| undetermined_fg->set_flags(kFuncGraphFlagUndetermined, true); | |||||
| } | |||||
| auto fg_parent = fg->parent(); | |||||
| MS_EXCEPTION_IF_NULL(fg_parent); | |||||
| fg_parent->set_flags(kFuncGraphFlagUndetermined, true); | |||||
| MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); | |||||
| } | } | ||||
| } | } | ||||
| auto current_inf = std::make_pair(eval, args_spec_list); | auto current_inf = std::make_pair(eval, args_spec_list); | ||||
| MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); | |||||
| // If current evaluator is under tracing, then skip current evaluator to avoid recursively inferring. | // If current evaluator is under tracing, then skip current evaluator to avoid recursively inferring. | ||||
| auto it = std::find(eval_trace_.begin(), eval_trace_.end(), current_inf); | |||||
| if (it == eval_trace_.end()) { | |||||
| auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf); | |||||
| if (it == eval_trace_.rend()) { | |||||
| eval_trace_.push_back(current_inf); | eval_trace_.push_back(current_inf); | ||||
| MS_LOG(DEBUG) << "Trace Evaluator " << eval->ToString() << " ptr: " << eval.get(); | |||||
| MS_EXCEPTION_IF_NULL(eval); | MS_EXCEPTION_IF_NULL(eval); | ||||
| auto out_spec = eval->Run(shared_from_this(), args_conf_list, out_conf); | auto out_spec = eval->Run(shared_from_this(), args_conf_list, out_conf); | ||||
| MS_EXCEPTION_IF_NULL(out_spec); | MS_EXCEPTION_IF_NULL(out_spec); | ||||
| MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << out_spec->ToString(); | MS_LOG(DEBUG) << "Evaluator " << eval->ToString() << " return out_spec: " << out_spec->ToString(); | ||||
| out_specs.push_back(out_spec); | out_specs.push_back(out_spec); | ||||
| MS_LOG(DEBUG) << "Pop Evaluator " << eval->ToString(); | |||||
| eval_trace_.pop_back(); | eval_trace_.pop_back(); | ||||
| if (eval_trace_.empty()) { | |||||
| multi_poss_.clear(); | |||||
| } | |||||
| } else if (it != eval_trace_.rbegin()) { | |||||
| // Find latest entry function to handle nested recursion. | |||||
| EvaluatorPtr latest_entry = eval; | |||||
| auto latest_entry_iter = eval_trace_.rbegin(); | |||||
| for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) { | |||||
| auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first); | |||||
| if (it_temp != evaluators.end()) { | |||||
| latest_entry = *it_temp; | |||||
| latest_entry_iter = r_it; | |||||
| break; | |||||
| } | |||||
| latest_entry_iter = ++r_it; | |||||
| } | |||||
| if (latest_entry != eval) { | |||||
| MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString(); | |||||
| continue; | |||||
| } | |||||
| bool has_undetermined = false; | |||||
| // Check whether sub loop has untraced undetermined evaluator. | |||||
| std::set<std::pair<EvaluatorPtr, AbstractBasePtrList>> undetermined_evals; | |||||
| for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) { | |||||
| undetermined_evals.insert(*r_it); | |||||
| } | |||||
| MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size(); | |||||
| for (auto u_eval : undetermined_evals) { | |||||
| MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined."; | |||||
| if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) { | |||||
| MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined."; | |||||
| has_undetermined = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (has_undetermined == false) { | |||||
| MS_LOG(DEBUG) << eval->ToString() << " has no undetermined."; | |||||
| continue; | |||||
| } | |||||
| // Try to travel the latest undetermined. | |||||
| if (latest_entry != eval_trace_.rbegin()->first) { | |||||
| MS_LOG(DEBUG) << "Direct Run Evaluator " << eval->ToString(); | |||||
| auto out_spec = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); | |||||
| MS_EXCEPTION_IF_NULL(out_spec); | |||||
| MS_LOG(DEBUG) << "Evaluator " << latest_entry->ToString() << " return out_spec: " << out_spec->ToString(); | |||||
| return out_spec; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| if (out_specs.size() == 0) { | if (out_specs.size() == 0) { | ||||
| @@ -25,6 +25,7 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <vector> | #include <vector> | ||||
| #include <utility> | #include <utility> | ||||
| #include <map> | |||||
| #ifdef DEBUG | #ifdef DEBUG | ||||
| #include <stack> | #include <stack> | ||||
| @@ -206,6 +207,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||||
| AnfNodeConfigMap anfnode_config_map_; | AnfNodeConfigMap anfnode_config_map_; | ||||
| // Use a list to trace multiple evaluators. | // Use a list to trace multiple evaluators. | ||||
| std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_; | std::list<std::pair<EvaluatorPtr, AbstractBasePtrList>> eval_trace_; | ||||
| std::map<EvaluatorPtr, EvaluatorPtr> multi_poss_; | |||||
| AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, | AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, | ||||
| const ConfigPtrList &args_conf_list); | const ConfigPtrList &args_conf_list); | ||||
| @@ -39,7 +39,6 @@ from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer | |||||
| def setup_module(module): | def setup_module(module): | ||||
| context.set_context(mode=context.PYNATIVE_MODE) | context.set_context(mode=context.PYNATIVE_MODE) | ||||
| @ms_function | @ms_function | ||||
| def while_upper_bound(upper): | def while_upper_bound(upper): | ||||
| rval = 2 | rval = 2 | ||||
| @@ -392,6 +391,58 @@ def test_grad_factorial(): | |||||
| res = C.grad(factorial)(3) | res = C.grad(factorial)(3) | ||||
| assert res == 11 | assert res == 11 | ||||
| @ms_function | |||||
| def factorial2(n): | |||||
| """ factorial """ | |||||
| if n != 0: | |||||
| return n * factorial2(n-1) | |||||
| elif n == 1: | |||||
| return 1 * factorial2(n-1) | |||||
| else: | |||||
| return 1 | |||||
| def test_factorial2(): | |||||
| res = factorial2(3) | |||||
| assert res == 6 | |||||
| @ms_function | |||||
| def foo(n): | |||||
| if n <= 1: | |||||
| if n == 1: | |||||
| return foo(n-1) | |||||
| else: | |||||
| return 1 | |||||
| else: | |||||
| return foo(n-1) | |||||
| def test_foo(): | |||||
| res = foo(5) | |||||
| assert res == 1 | |||||
| @ms_function | |||||
| def double_nested_loop(x): | |||||
| i = 0 | |||||
| s = 0 | |||||
| while(i < x): | |||||
| j = 0 | |||||
| i = i + 1 | |||||
| while(j < 3): | |||||
| j = j + 1 | |||||
| s = s + j | |||||
| return s | |||||
| def test_nested_loop(): | |||||
| res = double_nested_loop(3) | |||||
| assert res == 18 | |||||
| @ms_function | |||||
| def double_nested_loop2(x): | |||||
| s = 0 | |||||
| for i in range(x): | |||||
| for j in range(3): | |||||
| s = s + j | |||||
| return s | |||||
| def test_nested_loop2(): | |||||
| res = double_nested_loop(1) | |||||
| assert res == 6 | |||||
| def _for(x): | def _for(x): | ||||
| """ _for """ | """ _for """ | ||||
| ret = x * x | ret = x * x | ||||
| @@ -24,7 +24,7 @@ from mindspore.ops import operations as P | |||||
| def setup_module(module): | def setup_module(module): | ||||
| context.set_context(mode = context.PYNATIVE_MODE, save_graphs = True, device_target = "Ascend") | |||||
| context.set_context(mode = context.PYNATIVE_MODE, save_graphs = False, device_target = "Ascend") | |||||
| context.set_context(enable_task_sink = True, device_id = 0) | context.set_context(enable_task_sink = True, device_id = 0) | ||||
| @@ -86,7 +86,17 @@ def while_by_while(x, y, z): | |||||
| x = x + 1 | x = x + 1 | ||||
| x = x + 1 | x = x + 1 | ||||
| return x | return x | ||||
| @ms_function | |||||
| def while_in_while(x, y, z): | |||||
| out = c4 | |||||
| while x < y: | |||||
| z = c4 + c4 | |||||
| while z < y: | |||||
| z = z + 1 | |||||
| out = out + z | |||||
| x = x + 1 | |||||
| out = out + x | |||||
| return out | |||||
| def test_simple_if(): | def test_simple_if(): | ||||
| output = simple_if(c1, c2, c3) | output = simple_if(c1, c2, c3) | ||||
| @@ -117,3 +127,7 @@ def test_while_by_while(): | |||||
| expect = Tensor([28], mstype.int32) | expect = Tensor([28], mstype.int32) | ||||
| assert output == expect | assert output == expect | ||||
| def test_while_in_while(): | |||||
| output = while_in_while(c1, c2, c3) | |||||
| expect = Tensor([1274], mstype.int32) | |||||
| assert output == expect | |||||