Merge pull request !32465 from Margaret_wangrui/fallback_control_flow_2pull/1/head
| @@ -128,6 +128,9 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) { | |||
| if (node != nullptr) { | |||
| return node; | |||
| } | |||
| // The fallback feature is enabled in default. | |||
| // Not support change the flag during the process is alive. | |||
| static const auto use_fallback = (parser_.support_fallback() != "0"); | |||
| // Get var from predecessor block, if can't get then make a resolve node to it | |||
| if (matured_) { | |||
| // If only one predecessor block, read the definition of var from it. | |||
| @@ -135,10 +138,6 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) { | |||
| auto block = prev_blocks_[0]; | |||
| MS_EXCEPTION_IF_NULL(block); | |||
| auto res = block->ReadVariable(var_name); | |||
| // The fallback feature is enabled in default. | |||
| // Not support change the flag during the process is alive. | |||
| static const auto use_fallback = (parser_.support_fallback() != "0"); | |||
| if (use_fallback) { | |||
| MS_LOG(DEBUG) << "Update global params of block: " << ToString() | |||
| << ", with previous block: " << block->ToString() | |||
| @@ -166,6 +165,11 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) { | |||
| ParameterPtr phi_param = std::make_shared<Parameter>(func_graph()); | |||
| MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " generate phi node " | |||
| << phi_param->ToString() << " for " << var_name; | |||
| // If information transform by phi, need remove the var in interpret dict in fallback feature. | |||
| if (use_fallback) { | |||
| EraseLocalPyParam(var_name); | |||
| } | |||
| func_graph()->add_parameter(phi_param); | |||
| phi_nodes_[phi_param] = var_name; | |||
| WriteVariable(var_name, phi_param); | |||
| @@ -100,49 +100,52 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> { | |||
| } | |||
| } | |||
| std::tuple<std::vector<AnfNodePtr>, std::vector<AnfNodePtr>> local_py_params() { | |||
| std::tuple<std::map<std::string, AnfNodePtr>, std::map<std::string, AnfNodePtr>> local_py_params() { | |||
| return {local_py_params_keys_, local_py_params_values_}; | |||
| } | |||
| void AddLocalPyParam(const std::string &name, const AnfNodePtr &node) { | |||
| MS_LOG(DEBUG) << "Add '" << name << "', " << node->DebugString(); | |||
| local_py_params_keys_.emplace_back(NewValueNode(name)); | |||
| local_py_params_values_.emplace_back(node); | |||
| (void)local_py_params_keys_.insert(std::pair<std::string, AnfNodePtr>(name, NewValueNode(name))); | |||
| (void)local_py_params_values_.insert(std::pair<std::string, AnfNodePtr>(name, node)); | |||
| } | |||
| // Call this methon only if you need update a variable. Usually variable override. | |||
| void UpdateLocalPyParam(const std::string &name, const AnfNodePtr &node) { | |||
| auto iter = std::find_if(local_py_params_keys_.cbegin(), local_py_params_keys_.cend(), | |||
| [&name](const AnfNodePtr node) -> bool { | |||
| const auto value_node = dyn_cast<ValueNode>(node); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| const StringImmPtr &str_imm = dyn_cast<StringImm>(value_node->value()); | |||
| MS_EXCEPTION_IF_NULL(str_imm); | |||
| return name == str_imm->value(); | |||
| }); | |||
| if (iter == local_py_params_keys_.cend()) { | |||
| MS_LOG(EXCEPTION) << "Only for updating. Should not call this method if 'name' not exist."; | |||
| auto key_iter = local_py_params_keys_.find(name); | |||
| if (key_iter == local_py_params_keys_.end()) { | |||
| MS_LOG(EXCEPTION) << "Only for updating. Should not call this method if '" << name << "' not exist."; | |||
| } | |||
| // Find the same position in 'values', and update the node. | |||
| auto distance = std::distance(local_py_params_keys_.cbegin(), iter); | |||
| auto values_pos_iter = local_py_params_values_.begin() + distance; | |||
| MS_LOG(DEBUG) << "Update '" << name << "', " << (*values_pos_iter)->DebugString() << " -> " << node->DebugString(); | |||
| *values_pos_iter = node; | |||
| MS_LOG(DEBUG) << "Update '" << name << "', " << local_py_params_values_[name]->DebugString() << " -> " | |||
| << node->DebugString(); | |||
| local_py_params_values_[name] = node; | |||
| } | |||
| void EraseLocalPyParam(const std::string &name) { | |||
| auto key_iter = local_py_params_keys_.find(name); | |||
| auto value_iter = local_py_params_values_.find(name); | |||
| if (key_iter != local_py_params_keys_.end() && value_iter != local_py_params_values_.end()) { | |||
| MS_LOG(DEBUG) << "Erase '" << name | |||
| << "' from local_py_params, the key node:" << local_py_params_keys_[name]->DebugString() | |||
| << ", the value node:" << local_py_params_values_[name]->DebugString(); | |||
| local_py_params_keys_.erase(key_iter); | |||
| local_py_params_values_.erase(value_iter); | |||
| } | |||
| } | |||
| void UpdateLocalPyParam(const std::vector<AnfNodePtr> &keys, const std::vector<AnfNodePtr> &values) { | |||
| void UpdateLocalPyParam(const std::map<std::string, AnfNodePtr> &keys, std::map<std::string, AnfNodePtr> values) { | |||
| if (keys.size() != values.size()) { | |||
| MS_LOG(EXCEPTION) << "keys size should be equal to values size."; | |||
| } | |||
| for (size_t index = 0; index < keys.size(); ++index) { | |||
| auto iter = std::find(local_py_params_keys_.cbegin(), local_py_params_keys_.cend(), keys[index]); | |||
| if (iter == local_py_params_keys_.cend()) { | |||
| local_py_params_keys_.emplace_back(keys[index]); | |||
| local_py_params_values_.emplace_back(values[index]); | |||
| MS_LOG(DEBUG) << "Add '" << keys[index]->DebugString() << "', " << values[index]->DebugString(); | |||
| for (auto iter = keys.begin(); iter != keys.end(); ++iter) { | |||
| const std::string &cur_key_name = iter->first; | |||
| if (local_py_params_keys_.find(cur_key_name) == local_py_params_keys_.end()) { | |||
| (void)local_py_params_keys_.insert(std::pair<std::string, AnfNodePtr>(cur_key_name, iter->second)); | |||
| (void)local_py_params_values_.insert(std::pair<std::string, AnfNodePtr>(cur_key_name, values[cur_key_name])); | |||
| MS_LOG(DEBUG) << "Add '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString(); | |||
| } else { | |||
| auto distance = std::distance(local_py_params_keys_.cbegin(), iter); | |||
| auto values_pos_iter = local_py_params_values_.begin() + distance; | |||
| MS_LOG(DEBUG) << "Update '" << keys[index]->DebugString() << "', " << values[index]->DebugString(); | |||
| *values_pos_iter = values[index]; | |||
| MS_LOG(DEBUG) << "Update '" << iter->second->DebugString() << "', " << values[cur_key_name]->DebugString(); | |||
| local_py_params_values_[cur_key_name] = values[cur_key_name]; | |||
| } | |||
| } | |||
| if (local_py_params_keys_.size() != local_py_params_values_.size()) { | |||
| @@ -186,8 +189,8 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> { | |||
| // Collect all python symbols in the block. | |||
| // We treat both global symbols and local symbols declared previously as global symbols. | |||
| py::dict global_py_params_; | |||
| std::vector<AnfNodePtr> local_py_params_keys_; | |||
| std::vector<AnfNodePtr> local_py_params_values_; | |||
| std::map<std::string, AnfNodePtr> local_py_params_keys_; | |||
| std::map<std::string, AnfNodePtr> local_py_params_values_; | |||
| // Isolated nodes. | |||
| OrderedSet<AnfNodePtr> isolated_nodes_; | |||
| @@ -1613,7 +1613,7 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object | |||
| MS_EXCEPTION_IF_NULL(after_block->func_graph()); | |||
| after_block->func_graph()->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true); | |||
| } | |||
| static const auto use_fallback = (support_fallback() != "0"); | |||
| // Process the if-true branch | |||
| std::pair<FunctionBlockPtr, FunctionBlockPtr> true_branch_graphs; | |||
| py::object bodyNode = python_adapter::GetPyObjAttr(node, "body"); | |||
| @@ -1630,6 +1630,10 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object | |||
| } | |||
| MS_LOG(DEBUG) << "The true_end block jump to after, true_block: " << true_block->ToString() | |||
| << ", true_end: " << true_end->ToString(); | |||
| if (use_fallback) { | |||
| UpdateBlockPyParams(after_block, true_end); | |||
| } | |||
| } | |||
| // Process the orelse branch | |||
| @@ -1648,6 +1652,9 @@ FunctionBlockPtr Parser::ParseIf(const FunctionBlockPtr &block, const py::object | |||
| } | |||
| MS_LOG(DEBUG) << "The false_end block jump to after, false_block: " << false_block->ToString() | |||
| << ", false_end: " << false_end->ToString(); | |||
| if (use_fallback) { | |||
| UpdateBlockPyParams(after_block, false_end); | |||
| } | |||
| } | |||
| auto switch_app = block->ConditionalJump(bool_node, true_block, false_block); | |||
| @@ -2364,7 +2371,7 @@ void Parser::UpdateInterpretForUserNode(const AnfNodePtr &user_node, const std:: | |||
| } | |||
| bool Parser::IsScriptInParams(const std::string &script_text, const py::dict &global_dict, | |||
| const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) { | |||
| const std::map<std::string, AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| // Check global parameters. | |||
| if (global_dict.contains(script_text)) { | |||
| @@ -2373,14 +2380,7 @@ bool Parser::IsScriptInParams(const std::string &script_text, const py::dict &gl | |||
| } | |||
| // Check local parameters. | |||
| auto in_local_params = std::any_of(local_keys.begin(), local_keys.end(), [&script_text](const AnfNodePtr &node) { | |||
| const auto value_node = dyn_cast<ValueNode>(node); | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| const StringImmPtr &str_imm = dyn_cast<StringImm>(value_node->value()); | |||
| MS_EXCEPTION_IF_NULL(str_imm); | |||
| return script_text == str_imm->value(); | |||
| }); | |||
| if (in_local_params) { | |||
| if (local_keys.find(script_text) != local_keys.end()) { | |||
| MS_LOG(DEBUG) << "[" << func_graph->ToString() << "] Found `" << script_text << "` in local params."; | |||
| return true; | |||
| } | |||
| @@ -2414,7 +2414,7 @@ AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNod | |||
| MS_EXCEPTION_IF_NULL(value_node); | |||
| // Check if script_text is in global/local params. | |||
| py::dict global_dict = block->global_py_params(); | |||
| const auto &[keys, values] = block->local_py_params(); | |||
| auto [keys, values] = block->local_py_params(); | |||
| if (IsTensorType(value_node, script_text)) { | |||
| return value_node; | |||
| } | |||
| @@ -2434,13 +2434,14 @@ AnfNodePtr Parser::MakeInterpretNode(const FunctionBlockPtr &block, const AnfNod | |||
| auto current_fg = value_node->func_graph(); | |||
| std::vector<AnfNodePtr> filter_keys; | |||
| std::vector<AnfNodePtr> filter_values; | |||
| for (size_t index = 0; index < values.size(); ++index) { | |||
| auto value = values[index]; | |||
| for (auto iter = values.begin(); iter != values.end(); ++iter) { | |||
| auto value = iter->second; | |||
| auto fg = GetValueNode<FuncGraphPtr>(value); | |||
| if (fg == current_fg) { | |||
| continue; | |||
| } | |||
| (void)filter_keys.emplace_back(keys[index]); | |||
| const std::string &name = iter->first; | |||
| (void)filter_keys.emplace_back(keys[name]); | |||
| (void)filter_values.emplace_back(value); | |||
| } | |||
| auto local_dict_node = ParseDictByKeysAndValues(block, filter_keys, filter_values); | |||
| @@ -219,7 +219,7 @@ class Parser { | |||
| // Check if script_text is in global/local params. | |||
| bool IsScriptInParams(const std::string &script_text, const py::dict &global_dict, | |||
| const std::vector<AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph); | |||
| const std::map<std::string, AnfNodePtr> &local_keys, const FuncGraphPtr &func_graph); | |||
| // Set the interpret flag for the node calling the interpret node. | |||
| void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const AnfNodePtr &node); | |||
| void UpdateInterpretForUserNode(const AnfNodePtr &user_node, const std::vector<AnfNodePtr> &nodes); | |||
| @@ -0,0 +1,65 @@ | |||
| # Copyright 2022 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ test graph fallback control flow.""" | |||
| import numpy as np | |||
| from mindspore import context | |||
| from mindspore.nn import Cell | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| def test_single_if_no_else_type(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test fallback with control flow. | |||
| Expectation: No exception. | |||
| """ | |||
| class FalseNet(Cell): | |||
| def __init__(self): | |||
| super(FalseNet, self).__init__() | |||
| self.cond = False | |||
| def construct(self): | |||
| x = np.array(1) | |||
| if self.cond: | |||
| return type(2).mro() | |||
| return type(x).mro() | |||
| test_net = FalseNet() | |||
| res = test_net() | |||
| assert str(res) == "(<class 'numpy.ndarray'>, <class 'object'>)" | |||
| def test_single_if_no_else_type_2(): | |||
| """ | |||
| Feature: JIT Fallback | |||
| Description: Test fallback with control flow. | |||
| Expectation: No exception. | |||
| """ | |||
| class TrueNet(Cell): | |||
| def __init__(self): | |||
| super(TrueNet, self).__init__() | |||
| self.cond = True | |||
| def construct(self): | |||
| x = np.array(2) | |||
| y = 2 | |||
| if self.cond: | |||
| return type(y).mro() | |||
| return type(x).mro() | |||
| test_net = TrueNet() | |||
| res = test_net() | |||
| assert str(res) == "(<class 'int'>, <class 'object'>)" | |||