Merge pull request !4898 from fary86/fix_call_depth_too_largetags/v0.7.0-beta
| @@ -113,6 +113,8 @@ PYBIND11_MODULE(_c_expression, m) { | |||
| .def("set_device_target", &mindspore::MsContext::set_device_target, "Set device target.") | |||
| .def("get_device_id", &mindspore::MsContext::device_id, "Get device id.") | |||
| .def("set_device_id", &mindspore::MsContext::set_device_id, "Set device id.") | |||
| .def("get_max_call_depth", &mindspore::MsContext::max_call_depth, "Get max call depth.") | |||
| .def("set_max_call_depth", &mindspore::MsContext::set_max_call_depth, "Set max call depth.") | |||
| .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.") | |||
| .def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.") | |||
| .def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag, | |||
| @@ -114,8 +114,13 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr | |||
| const AnfNodePtr &func_node = fg->get_return(); | |||
| MS_LOG(DEBUG) << "Analysis FuncGraph begin, func graph: " << fg.get() << fg->ToString() | |||
| << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString(); | |||
| << ", context: " << graph_context_->ToString() << ", return node: " << func_node->DebugString() | |||
| << ", current function call depth: " << engine->function_call_depth(); | |||
| AbstractBasePtr ret_base = nullptr; | |||
| engine->IncreaseFunctionCallDepth(); | |||
| if (engine->function_call_depth() > MsContext::GetInstance()->max_call_depth()) { | |||
| MS_LOG(EXCEPTION) << "Exceed function call depth limit " << MsContext::GetInstance()->max_call_depth() << "."; | |||
| } | |||
| std::vector<AnfNodePtr> nodes = FastShadowSort(func_node); | |||
| for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { | |||
| const auto &node = *it; | |||
| @@ -126,6 +131,7 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr | |||
| MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString() | |||
| << ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString(); | |||
| } | |||
| engine->DecreaseFunctionCallDepth(); | |||
| MS_EXCEPTION_IF_NULL(ret_base); | |||
| MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString() | |||
| @@ -119,6 +119,7 @@ AnalysisResult AnalysisEngine::Run(const FuncGraphPtr &func_graph, const Abstrac | |||
| AnalysisContextPtr empty_context = AnalysisContext::DummyContext(); | |||
| // Running the analyzer. | |||
| ResetFunctionCallDepth(); | |||
| AnalysisContextPtr root_context = Run(func_graph, empty_context, args_conf_list); | |||
| MS_EXCEPTION_IF_NULL(root_context); | |||
| MS_EXCEPTION_IF_NULL(root_context->func_graph()); | |||
| @@ -185,7 +185,9 @@ struct PartialAppHasher { | |||
| class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| public: | |||
| AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) | |||
| : cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) {} | |||
| : cache_(AnalysisCache()), prim_constructors_(prim_evaluator_map), func_graph_manager_(func_graph_manager) { | |||
| function_call_depth_ = 0; | |||
| } | |||
| ~AnalysisEngine() = default; | |||
| // func_graph: The func_graph to analyze. | |||
| @@ -231,6 +233,19 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| AnalysisCache cache_; | |||
| std::unordered_map<PrimitivePyPtr, EvaluatorPtr> prim_py_evaluators_; | |||
| void ResetFunctionCallDepth() { function_call_depth_ = 0; } | |||
| void IncreaseFunctionCallDepth() { function_call_depth_++; } | |||
| void DecreaseFunctionCallDepth() { | |||
| if (function_call_depth_ == 0) { | |||
| MS_LOG(EXCEPTION) << "Current function call depth is already 0, can not decrease it."; | |||
| } | |||
| function_call_depth_--; | |||
| } | |||
| unsigned int function_call_depth() { return function_call_depth_; } | |||
| private: | |||
| void SetUndeterminedFlag(const EvaluatorPtr &evaluator); | |||
| EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval, | |||
| @@ -257,6 +272,8 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> { | |||
| const ConfigPtrList &args_conf_list); | |||
| EvalResultPtr ExecuteMultipleEvaluators(const std::vector<EvaluatorPtr> &evaluators, const AnfNodeConfigPtr &out_conf, | |||
| const ConfigPtrList &args_conf_list); | |||
| // record current depth of function call statck | |||
| unsigned int function_call_depth_; | |||
| #ifdef DEBUG | |||
| std::vector<AnfNodePtr> compute_conf_stack_; | |||
| @@ -234,6 +234,17 @@ class _Context: | |||
| if not success: | |||
| raise RuntimeError("Device id set failed!!!") | |||
| @property | |||
| def max_call_depth(self): | |||
| return self._context_handle.get_max_call_depth() | |||
| @max_call_depth.setter | |||
| def max_call_depth(self, max_call_depth): | |||
| if max_call_depth <= 0: | |||
| raise ValueError( | |||
| "Max call depth must be greater than 0, but got {}".format(max_call_depth)) | |||
| self._context_handle.set_max_call_depth(max_call_depth) | |||
| @property | |||
| def enable_auto_mixed_precision(self): | |||
| return self._context_handle.get_auto_mixed_precision_flag() | |||
| @@ -475,6 +486,7 @@ def set_auto_parallel_context(**kwargs): | |||
| full_batch (bool): Whether to load the whole batch on each device. Default: False. | |||
| enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in | |||
| data parallel training in the benefit of time and memory saving. | |||
| max_call_depth(int): Specify the function call depth limit. Default: 1000. | |||
| Raises: | |||
| @@ -490,6 +502,7 @@ def set_auto_parallel_context(**kwargs): | |||
| >>> context.set_auto_parallel_context(parameter_broadcast=False) | |||
| >>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt") | |||
| >>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt") | |||
| >>> context.set_auto_parallel_context(max_call_depth=80) | |||
| """ | |||
| _set_auto_parallel_context(**kwargs) | |||
| @@ -532,7 +545,7 @@ def reset_auto_parallel_context(): | |||
| save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, | |||
| enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool, | |||
| enable_graph_kernel=bool, check_bprop=bool, max_device_memory=str, print_file_path=str, | |||
| enable_sparse=bool) | |||
| enable_sparse=bool, max_call_depth=int) | |||
| def set_context(**kwargs): | |||
| """ | |||
| Sets context for running environment. | |||
| @@ -47,6 +47,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { | |||
| } else { | |||
| device_id_ = 0; | |||
| } | |||
| max_call_depth_ = MAX_CALL_DEPTH_DEFAULT; | |||
| backend_policy_ = policy_map_[policy]; | |||
| device_target_ = target; | |||
| execution_mode_ = kPynativeMode; | |||
| @@ -43,6 +43,8 @@ const char kAscendDevice[] = "Ascend"; | |||
| const char kDavinciInferenceDevice[] = "AscendInference"; | |||
| const char kDavinciDevice[] = "Davinci"; | |||
| const char KNpuLog[] = "_npu_log"; | |||
| const unsigned int MAX_CALL_DEPTH_DEFAULT = 1000; | |||
| const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, kDavinciDevice}; | |||
| // The default max available device memory is 1024GB. | |||
| const float kDefaultMaxDeviceMemory = 1024; | |||
| @@ -80,6 +82,13 @@ class MsContext { | |||
| uint32_t device_id() const { return device_id_; } | |||
| bool set_device_id(uint32_t device_id); | |||
| // uint32_t max_call_depth_ | |||
| uint32_t max_call_depth() const { return max_call_depth_; } | |||
| inline bool set_max_call_depth(uint32_t max_call_depth) { | |||
| max_call_depth_ = max_call_depth; | |||
| return true; | |||
| } | |||
| bool save_graphs_flag() const { return save_graphs_flag_; } | |||
| void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; } | |||
| @@ -171,6 +180,7 @@ class MsContext { | |||
| MsBackendPolicy backend_policy_; | |||
| std::string device_target_; | |||
| uint32_t device_id_; | |||
| uint32_t max_call_depth_; | |||
| int execution_mode_; | |||
| bool enable_pynative_infer_; | |||
| bool enable_pynative_hook_; | |||
| @@ -795,9 +795,12 @@ def test_large_for_loop_with_continue_break(): | |||
| x = self.flatten(x + elem1) | |||
| return x | |||
| old_max_call_depth = context.get_context('max_call_depth') | |||
| context.set_context(max_call_depth=2000) | |||
| t = Tensor(np.ones([2, 3], dtype=np.float32)) | |||
| net = Net() | |||
| net(t) | |||
| context.set_context(max_call_depth=old_max_call_depth) | |||
| def test_mixed_precision_cast(): | |||
| @@ -873,3 +876,38 @@ def test_parser_switch_layer_func_primitive(): | |||
| with pytest.raises(ValueError): | |||
| net(i, input1) | |||
| def test_recursive_call(): | |||
| class Net(nn.Cell): | |||
| """ Net definition """ | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| self.fc = nn.Dense(10, 10) # padding=0 | |||
| #self.net2 = Net2() | |||
| def construct(self, x): | |||
| net2 = Net2() | |||
| x = net2(x) | |||
| out = self.fc(x) | |||
| return out | |||
| class Net2(nn.Cell): | |||
| def __init__(self): | |||
| super(Net2, self).__init__() | |||
| self.net = Net() | |||
| self.fc = nn.Dense(10, 10) | |||
| def construct(self, x): | |||
| x = self.net(x) | |||
| out = self.fc(x) | |||
| return out | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=False) | |||
| old_max_call_depth = context.get_context('max_call_depth') | |||
| context.set_context(max_call_depth=80) | |||
| input_data = Tensor(np.identity(10).astype(np.float32)) | |||
| net = Net2() | |||
| with pytest.raises(RuntimeError): | |||
| net(input_data) | |||
| context.set_context(max_call_depth=old_max_call_depth) | |||