Merge pull request !4898 from fary86/fix_call_depth_too_largetags/v0.7.0-beta
| @@ -113,6 +113,8 @@ PYBIND11_MODULE(_c_expression, m) { | |||||
| .def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.") | .def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.") | ||||
| .def("get_device_id", &mindspore::MsContext::device_id, "Get device id.") | .def("get_device_id", &mindspore::MsContext::device_id, "Get device id.") | ||||
| .def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.") | .def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.") | ||||
| .def("get_max_call_depth", &mindspore::MsContext::max_call_depth, "Get max call depth.") | |||||
| .def("set_max_call_depth", &mindspore::MsContext::set_max_call_depth, "Set max call depth.") | |||||
| .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.") | .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.") | ||||
| .def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.") | .def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.") | ||||
| .def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag, | .def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag, | ||||
| @@ -114,8 +114,13 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr | |||||
| const AnfNodePtr &func_node = fg->get_return(); | const AnfNodePtr &func_node = fg->get_return(); | ||||
| MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString() | MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString() | ||||
| << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString(); | |||||
| << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString() | |||||
| << ", current function call depth: " << engine->function_call_depth(); | |||||
| AbstractBasePtr ret_base = nullptr; | AbstractBasePtr ret_base = nullptr; | ||||
| engine->IncreaseFunctionCallDepth(); | |||||
| if (engine->function_call_depth() > MsContext::GetInstance()->max_call_depth()) { | |||||
| MS_LOG(EXCEPTION) << "Exceed function call depth limit " << MsContext::GetInstance()->max_call_depth() << "."; | |||||
| } | |||||
| std::vector<AnfNodePtr> nodes = FastShadowSort(func_node); | std::vector<AnfNodePtr> nodes = FastShadowSort(func_node); | ||||
| for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { | for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { | ||||
| const auto &node = *it; | const auto &node = *it; | ||||
| @@ -126,6 +131,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr | |||||
| MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString() | MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString() | ||||
| << ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString(); | << ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString(); | ||||
| } | } | ||||
| engine->DecreaseFunctionCallDepth(); | |||||
| MS_EXCEPTION_IF_NULL(ret_base); | MS_EXCEPTION_IF_NULL(ret_base); | ||||
| MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString() | MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString() | ||||
| @@ -119,6 +119,7 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac | |||||
| AnalysisContextPtr empty_context = AnalysisContext::DummyContext(); | AnalysisContextPtr empty_context = AnalysisContext::DummyContext(); | ||||
| // Running the analyzer. | // Running the analyzer. | ||||
| ResetFunctionCallDepth(); | |||||
| AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list); | AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list); | ||||
| MS_EXCEPTION_IF_NULL(root_context); | MS_EXCEPTION_IF_NULL(root_context); | ||||
| MS_EXCEPTION_IF_NULL(root_context->func_graph()); | MS_EXCEPTION_IF_NULL(root_context->func_graph()); | ||||
| @@ -185,7 +185,9 @@ struct PartialAppHasher { | |||||
| class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | ||||
| public: | public: | ||||
| AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) | AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) | ||||
| : cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {} | |||||
| : cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) { | |||||
| function_call_depth_ = 0; | |||||
| } | |||||
| ~AnalysisEngine() = default; | ~AnalysisEngine() = default; | ||||
| // func_graph: The func_graph to analyze. | // func_graph: The func_graph to analyze. | ||||
| @@ -231,6 +233,19 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||||
| AnalysisCache cache_; | AnalysisCache cache_; | ||||
| std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_; | std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_; | ||||
| void ResetFunctionCallDepth() { function_call_depth_ = 0; } | |||||
| void IncreaseFunctionCallDepth() { function_call_depth_++; } | |||||
| void DecreaseFunctionCallDepth() { | |||||
| if (function_call_depth_ == 0) { | |||||
| MS_LOG(EXCEPTION) << "Current function call depth is already 0, can not decrease it."; | |||||
| } | |||||
| function_call_depth_--; | |||||
| } | |||||
| unsigned int function_call_depth() { return function_call_depth_; } | |||||
| private: | private: | ||||
| void SetUndeterminedFlag(const EvaluatorPtr &evaluator); | void SetUndeterminedFlag(const EvaluatorPtr &evaluator); | ||||
| EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval, | EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval, | ||||
| @@ -257,6 +272,8 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||||
| const ConfigPtrList &args_conf_list); | const ConfigPtrList &args_conf_list); | ||||
| EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf, | EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf, | ||||
| const ConfigPtrList &args_conf_list); | const ConfigPtrList &args_conf_list); | ||||
| // record current depth of function call statck | |||||
| unsigned int function_call_depth_; | |||||
| #ifdef DEBUG | #ifdef DEBUG | ||||
| std::vector<AnfNodePtr> compute_conf_stack_; | std::vector<AnfNodePtr> compute_conf_stack_; | ||||
| @@ -234,6 +234,17 @@ class _Context: | |||||
| if not success: | if not success: | ||||
| raise RuntimeError("Device id set failed!!!") | raise RuntimeError("Device id set failed!!!") | ||||
| @property | |||||
| def max_call_depth(self): | |||||
| return self._context_handle.get_max_call_depth() | |||||
| @max_call_depth.setter | |||||
| def max_call_depth(self, max_call_depth): | |||||
| if max_call_depth <= 0: | |||||
| raise ValueError( | |||||
| "Max call depth must be greater than 0, but got {}".format(max_call_depth)) | |||||
| self._context_handle.set_max_call_depth(max_call_depth) | |||||
| @property | @property | ||||
| def enable_auto_mixed_precision(self): | def enable_auto_mixed_precision(self): | ||||
| return self._context_handle.get_auto_mixed_precision_flag() | return self._context_handle.get_auto_mixed_precision_flag() | ||||
| @@ -475,6 +486,7 @@ def set_auto_parallel_context(**kwargs): | |||||
| full_batch (bool): Whether to load the whole batch on each device. Default: False. | full_batch (bool): Whether to load the whole batch on each device. Default: False. | ||||
| enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in | enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in | ||||
| data parallel training in the benefit of time and memory saving. | data parallel training in the benefit of time and memory saving. | ||||
| max_call_depth(int): Specify the function call depth limit. Default: 1000. | |||||
| Raises: | Raises: | ||||
| @@ -490,6 +502,7 @@ def set_auto_parallel_context(**kwargs): | |||||
| >>> context.set_auto_parallel_context(parameter_broadcast=False) | >>> context.set_auto_parallel_context(parameter_broadcast=False) | ||||
| >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt") | >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt") | ||||
| >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt") | >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt") | ||||
| >>> context.set_auto_parallel_context(max_call_depth=80) | |||||
| """ | """ | ||||
| _set_auto_parallel_context(**kwargs) | _set_auto_parallel_context(**kwargs) | ||||
| @@ -532,7 +545,7 @@ def reset_auto_parallel_context(): | |||||
| save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, | save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, | ||||
| enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, | enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, | ||||
| enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str, | enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str, | ||||
| enable_sparse=bool) | |||||
| enable_sparse=bool, max_call_depth=int) | |||||
| def set_context(**kwargs): | def set_context(**kwargs): | ||||
| """ | """ | ||||
| Sets context for running environment. | Sets context for running environment. | ||||
| @@ -47,6 +47,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { | |||||
| } else { | } else { | ||||
| device_id_ = 0; | device_id_ = 0; | ||||
| } | } | ||||
| max_call_depth_ = MAX_CALL_DEPTH_DEFAULT; | |||||
| backend_policy_ = policy_map_[policy]; | backend_policy_ = policy_map_[policy]; | ||||
| device_target_ = target; | device_target_ = target; | ||||
| execution_mode_ = kPynativeMode; | execution_mode_ = kPynativeMode; | ||||
| @@ -43,6 +43,8 @@ const char kAscendDevice[] = "Ascend"; | |||||
| const char kDavinciInferenceDevice[] = "AscendInference"; | const char kDavinciInferenceDevice[] = "AscendInference"; | ||||
| const char kDavinciDevice[] = "Davinci"; | const char kDavinciDevice[] = "Davinci"; | ||||
| const char KNpuLog[] = "_npu_log"; | const char KNpuLog[] = "_npu_log"; | ||||
| const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000; | |||||
| const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice}; | const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice}; | ||||
| // The default max available device memory is 1024GB. | // The default max available device memory is 1024GB. | ||||
| const float kDefaultMaxDeviceMemory = 1024; | const float kDefaultMaxDeviceMemory = 1024; | ||||
| @@ -80,6 +82,13 @@ class MsContext { | |||||
| uint32_t device_id() const { return device_id_; } | uint32_t device_id() const { return device_id_; } | ||||
| bool set_device_id(uint32_t device_id); | bool set_device_id(uint32_t device_id); | ||||
| // uint32_t max_call_depth_ | |||||
| uint32_t max_call_depth() const { return max_call_depth_; } | |||||
| inline bool set_max_call_depth(uint32_t max_call_depth) { | |||||
| max_call_depth_ = max_call_depth; | |||||
| return true; | |||||
| } | |||||
| bool save_graphs_flag() const { return save_graphs_flag_; } | bool save_graphs_flag() const { return save_graphs_flag_; } | ||||
| void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; } | void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; } | ||||
| @@ -171,6 +180,7 @@ class MsContext { | |||||
| MsBackendPolicy backend_policy_; | MsBackendPolicy backend_policy_; | ||||
| std::string device_target_; | std::string device_target_; | ||||
| uint32_t device_id_; | uint32_t device_id_; | ||||
| uint32_t max_call_depth_; | |||||
| int execution_mode_; | int execution_mode_; | ||||
| bool enable_pynative_infer_; | bool enable_pynative_infer_; | ||||
| bool enable_pynative_hook_; | bool enable_pynative_hook_; | ||||
| @@ -795,9 +795,12 @@ def test_large_for_loop_with_continue_break(): | |||||
| x = self.flatten(x + elem1) | x = self.flatten(x + elem1) | ||||
| return x | return x | ||||
| old_max_call_depth = context.get_context('max_call_depth') | |||||
| context.set_context(max_call_depth=2000) | |||||
| t = Tensor(np.ones([2, 3], dtype=np.float32)) | t = Tensor(np.ones([2, 3], dtype=np.float32)) | ||||
| net = Net() | net = Net() | ||||
| net(t) | net(t) | ||||
| context.set_context(max_call_depth=old_max_call_depth) | |||||
| def test_mixed_precision_cast(): | def test_mixed_precision_cast(): | ||||
| @@ -873,3 +876,38 @@ def test_parser_switch_layer_func_primitive(): | |||||
| with pytest.raises(ValueError): | with pytest.raises(ValueError): | ||||
| net(i, input1) | net(i, input1) | ||||
| def test_recursive_call(): | |||||
| class Net(nn.Cell): | |||||
| """ Net definition """ | |||||
| def __init__(self): | |||||
| super(Net, self).__init__() | |||||
| self.fc = nn.Dense(10, 10) # padding=0 | |||||
| #self.net2 = Net2() | |||||
| def construct(self, x): | |||||
| net2 = Net2() | |||||
| x = net2(x) | |||||
| out = self.fc(x) | |||||
| return out | |||||
| class Net2(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Net2, self).__init__() | |||||
| self.net = Net() | |||||
| self.fc = nn.Dense(10, 10) | |||||
| def construct(self, x): | |||||
| x = self.net(x) | |||||
| out = self.fc(x) | |||||
| return out | |||||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=False) | |||||
| old_max_call_depth = context.get_context('max_call_depth') | |||||
| context.set_context(max_call_depth=80) | |||||
| input_data = Tensor(np.identity(10).astype(np.float32)) | |||||
| net = Net2() | |||||
| with pytest.raises(RuntimeError): | |||||
| net(input_data) | |||||
| context.set_context(max_call_depth=old_max_call_depth) | |||||