/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/jit/parse/function_block.h"

#include <string>
#include <memory>
#include <algorithm>

#include "pybind11/pybind11.h"
#include "pipeline/jit/parse/resolve.h"
#include "pipeline/jit/parse/parse.h"
#include "pipeline/jit/parse/data_converter.h"
#include "frontend/operator/ops.h"
#include "utils/info.h"
#include "debug/trace.h"
#include "utils/utils.h"

namespace mindspore {
namespace py = pybind11;

namespace parse {
// Build a block that is not matured yet and owns a brand-new, empty FuncGraph.
// `parser_` is bound to the given parser.
FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) {
  matured_ = false;
  func_graph_ = std::make_shared<FuncGraph>();
}

// Register `block` as a predecessor of this block.
// Only the raw pointer (`block.get()`) is stored, so this block does not extend the
// predecessor's lifetime. NOTE(review): lifetime is presumably managed by the parser; confirm.
void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) {
  FunctionBlock *predecessor = block.get();
  prev_blocks_.push_back(predecessor);
}

// Decide whether `node` (bound to `var_name`) must be kept as an isolated node.
// WriteVariable calls this when the binding is overwritten before ever being read.
// Returns:
//   - false for anything that is not a CNode with at least one input;
//   - for a primitive call: true if the primitive has memory/IO side effects or carries
//     the ATTR_NO_ELIMINATE flag;
//   - for a non-primitive call: true unless `var_name` is empty or "_", so code like
//     `_ = func_call()` is dropped even if func_call() has side effects.
static bool CanBeIsolatedNode(const std::string &var_name, const AnfNodePtr &node) {
  auto call = dyn_cast<CNode>(node);
  const bool is_valid_call = (call != nullptr) && !call->inputs().empty();
  if (!is_valid_call) {
    // Not a valid cnode, can not be isolate node.
    return false;
  }

  auto primitive = GetValueNode<PrimitivePtr>(call->inputs().at(0));
  if (primitive != nullptr) {
    // Primitive cnode with side effects can be isolate nodes.
    const auto effects = GetPrimEffectInfo(primitive);
    if (effects.memory || effects.io) {
      return true;
    }
    // Primitive cnode with 'no_eliminate' flag can be isolate nodes.
    return GetPrimitiveFlag(primitive, ATTR_NO_ELIMINATE);
  }

  // Not a primitive cnode: it may or may not have side effects; keep it unless it is
  // bound to an empty name or to '_'.
  return !(var_name.empty() || var_name == "_");
}

// Bind `var_name` to `node` in this block.
// If the name was already bound to a node that was never read, and that node qualifies
// under CanBeIsolatedNode, the old node is registered via AddIsolatedNode so it is not
// silently lost. For example:
//   a = print(x)
//   a = print(y)   // 'print(x)' becomes an isolated node here
// When fallback is enabled, the local Python parameter table is kept in sync
// (AddLocalPyParam for a new name, UpdateLocalPyParam for a rebinding).
void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " write var `" << var_name
                << "` with node " << node->DebugString();
  // The fallback feature is enabled in default.
  // Not support change the flag during the process is alive.
  static const auto use_fallback = (parser_.support_fallback() != "0");

  auto [slot, inserted] = assigned_vars_.emplace(var_name, std::make_pair(node, false));
  if (inserted) {
    if (use_fallback) {
      AddLocalPyParam(var_name, node);
    }
    return;
  }

  // The name already existed: inspect the binding that is about to be shadowed.
  const auto previous_node = slot->second.first;
  const auto previous_read = slot->second.second;
  const auto isolate_previous = CanBeIsolatedNode(var_name, previous_node);
  if (isolate_previous && !previous_read) {
    MS_EXCEPTION_IF_NULL(previous_node);
    MS_LOG(INFO) << "Isolated node found(Hidden), hidden_node: " << previous_node->DebugString(2) << " is hidden by "
                 << node->DebugString(2) << " with the same name, var_name: " << var_name << ", block: " << this
                 << "/" << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << ", Line: "
                 << trace::GetDebugInfo(previous_node->debug_info(), "", kSourceLineTipDiscard);
    AddIsolatedNode(previous_node);
  }
  // Rebind to the new node, which has not been read yet.
  slot->second = std::make_pair(node, false);
  if (use_fallback) {
    UpdateLocalPyParam(var_name, node);
  }
}

// Read the node currently bound to `var_name`, consulting predecessors when needed.
// Lookup order:
//   1. A binding written in this block: it is marked as used, and an entry in
//      resolve_to_removable_phis_ for that node takes precedence over the node itself.
//   2. Matured block with exactly one predecessor: read from that predecessor (and, with
//      fallback enabled, merge its global Python params into this block's).
//   3. Matured block with no predecessor: a Resolve node built by MakeResolveSymbol,
//      cached in var_to_resolve_ so repeated reads return the same node.
//   4. Otherwise (not matured yet, or several predecessors): a new phi parameter is added
//      to the graph and written as the variable's binding. If the block is already matured,
//      SetPhiArgument is called on it right away.
AnfNodePtr FunctionBlock::ReadVariable(const std::string &var_name) {
  MS_LOG(DEBUG) << "Read begin, var: " << var_name << ", block: " << ToString();

  auto local = assigned_vars_.find(var_name);
  if (local != assigned_vars_.end()) {
    auto &binding = local->second;
    auto &bound_node = binding.first;
    MS_EXCEPTION_IF_NULL(bound_node);
    // Mark the variable as used.
    binding.second = true;
    auto replacement = resolve_to_removable_phis_.find(bound_node);
    if (replacement != resolve_to_removable_phis_.end()) {
      return replacement->second;
    }
    return bound_node;
  }

  if (matured_ && prev_blocks_.size() == 1) {
    // Single predecessor: the definition comes straight from it.
    auto predecessor = prev_blocks_[0];
    MS_EXCEPTION_IF_NULL(predecessor);
    auto inherited = predecessor->ReadVariable(var_name);
    // The fallback feature is enabled in default.
    // Not support change the flag during the process is alive.
    static const auto use_fallback = (parser_.support_fallback() != "0");
    if (use_fallback) {
      MS_LOG(DEBUG) << "Update global params of block: " << ToString()
                    << ", with previous block: " << predecessor->ToString()
                    << ",\nCurrent: " << py::str(global_py_params())
                    << "\nInsert: " << py::str(predecessor->global_py_params());
      UpdateGlobalPyParam(predecessor->global_py_params());
    }
    return inherited;
  }

  if (matured_ && prev_blocks_.empty()) {
    // No predecessor: resolve the name as a symbol, reusing a previously built node.
    auto cached = var_to_resolve_.find(var_name);
    if (cached != var_to_resolve_.end()) {
      return cached->second;
    }
    MS_LOG(DEBUG) << "var: " << var_name;
    auto resolved = MakeResolveSymbol(var_name);
    var_to_resolve_[var_name] = resolved;
    return resolved;
  }

  // Several predecessors, or the block is not matured yet: introduce a phi parameter.
  auto phi_debug_info = std::make_shared<NodeDebugInfo>();
  phi_debug_info->set_name(var_name);
  TraceGuard guard(std::make_shared<TracePhi>(phi_debug_info));
  ParameterPtr phi = std::make_shared<Parameter>(func_graph());
  MS_LOG(DEBUG) << (func_graph_ ? func_graph_->ToString() : "FG(Null)") << " generate phi node "
                << phi->ToString() << " for " << var_name;
  func_graph()->add_parameter(phi);
  phi_nodes_[phi] = var_name;
  WriteVariable(var_name, phi);
  if (matured_) {
    SetPhiArgument(phi);
  }
  return phi;
}

// Build a Resolve node for the Python AST operator object `op`.
// The Python parse module returns a (namespace, symbol-name) pair for `op`. Any other
// tuple size raises an exception.
AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) {
  auto ast = parser_.ast();
  MS_EXCEPTION_IF_NULL(ast);
  TraceGuard trace_guard(parser_.GetLocation(op));
  py::tuple ns_and_symbol = ast->CallParseModFunction(PYTHON_PARSE_GET_AST_NAMESPACE_SYMBOL, op);
  constexpr size_t expected_size = 2;
  if (ns_and_symbol.size() != expected_size) {
    MS_LOG(EXCEPTION) << "Resolve ast op failed, get namespace tuple size=" << ns_and_symbol.size();
  }
  NameSpacePtr ast_namespace = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_AST, ns_and_symbol[0]);
  SymbolPtr op_symbol = std::make_shared<Symbol>(ns_and_symbol[1].cast<std::string>());
  MS_LOG(DEBUG) << "name_space: " << ast_namespace->ToString() << ", symbol: " << op_symbol->ToString();
  return MakeResolve(ast_namespace, op_symbol);
}

// Resolve class member, two possible:
// method, member variable. `attr` is looked up in the class-member namespace that the
// Python parse module (PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL) returns for the parsed object.
AnfNodePtr FunctionBlock::MakeResolveClassMember(const std::string &attr) {
  auto ast = parser_.ast();
  MS_EXCEPTION_IF_NULL(ast);
  py::object namespace_var = ast->CallParseModFunction(PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, ast->obj());
  NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
  SymbolPtr symbol = std::make_shared<Symbol>(attr);
  MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
  return MakeResolve(name_space, symbol);
}

// Build a Resolve node from a namespace info tuple.
// info[0] is the namespace object and info[1] the symbol name.
// Callers validate the tuple size before calling.
AnfNodePtr FunctionBlock::GetResolveNode(const py::tuple &info) {
  constexpr size_t namespace_index = 0;
  constexpr size_t symbol_index = 1;
  NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_SYMBOL_STR, info[namespace_index]);
  SymbolPtr symbol = std::make_shared<Symbol>(info[symbol_index].cast<std::string>());
  return MakeResolve(name_space, symbol);
}

// Handle the (namespace, symbol) tuple returned by the non-fallback namespace lookup.
// Raises NameError if the tuple size is not 2, or if the namespace is None (undefined name).
AnfNodePtr FunctionBlock::HandleNamespaceInfo(const py::tuple &info) {
  constexpr size_t namespace_index = 0;
  constexpr size_t symbol_index = 1;
  constexpr size_t namespace_info_size = 2;
  if (info.size() != namespace_info_size) {
    MS_EXCEPTION(NameError) << "namespace info size should be 2, but got " << info.size();
  }
  // If namespace is None, the symbol is an undefined name.
  if (info[namespace_index].is_none()) {
    MS_EXCEPTION(NameError) << info[symbol_index].cast<std::string>();
  }
  return GetResolveNode(info);
}

// Handle the tuple returned by the fallback-aware (builtin) namespace lookup.
//   size 2: closure namespace info (namespace, symbol); NameError if the namespace is None.
//   size 3: global symbol not supported by the graph compiler: the Resolve node is marked
//           for interpretation, and info[2] is recorded as a global Python param.
//   size 4: supported global symbol: info[2] is recorded as a global Python param.
// Any other size raises NameError.
AnfNodePtr FunctionBlock::HandleBuiltinNamespaceInfo(const py::tuple &info) {
  constexpr size_t closure_info_size = 2;
  constexpr size_t unsupported_info_size = 3;
  constexpr size_t supported_info_size = 4;
  constexpr size_t namespace_index = 0;
  constexpr size_t symbol_index = 1;
  constexpr size_t value_index = 2;
  if (info.size() < closure_info_size || info.size() > supported_info_size) {
    MS_EXCEPTION(NameError) << "namespace info size should be 2, 3 or 4, but got " << info.size();
  }

  // Handle closure namespace info.
  if (info.size() == closure_info_size) {
    // If namespace is None, the symbol is an undefined name.
    if (info[namespace_index].is_none()) {
      MS_EXCEPTION(NameError) << info[symbol_index].cast<std::string>();
    }
    return GetResolveNode(info);
  }

  // Handle global namespace info.
  auto resolved_node = GetResolveNode(info);
  if (info.size() == unsupported_info_size) {
    resolved_node->set_interpret(true);
  }
  SymbolPtr symbol = std::make_shared<Symbol>(info[symbol_index].cast<std::string>());
  py::object py_obj = info[value_index];
  AddGlobalPyParam(symbol->name(), py_obj);
  MS_LOG(INFO) << "[" << func_graph()->ToString() << "] Added global python symbol: {" << symbol->name() << " : "
               << py::str(py_obj) << "}";
  return resolved_node;
}

// Make a resolve node for symbol string.
// `self.<attr>` is resolved as the class member `<attr>`, and bare `self` keeps its previous
// handling as the member named "self". Every other name, including names that merely
// start with the letters "self" such as `self_check` or `selfish`, goes through the Python
// parser's namespace lookup. That lookup is the builtin/fallback-aware one when the fallback
// feature is enabled.
// Returns nullptr (after logging an error) for the malformed symbol "self.".
AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) {
  MS_LOG(DEBUG) << "value: " << value;
  constexpr char kSelf[] = "self";
  constexpr size_t kSelfLen = sizeof(kSelf) - 1;
  // Only `self` itself or a `self.` prefix names the object's members. The previous plain
  // prefix test also matched `self_check`; find_first_of('.') then returned npos, `start`
  // wrapped to 0, and the whole name was wrongly resolved as a class member.
  const bool refers_to_self = value.compare(0, kSelfLen, kSelf) == 0 &&
                              (value.size() == kSelfLen || value[kSelfLen] == '.');
  if (refers_to_self) {
    if (value.size() == kSelfLen) {
      // Bare `self`: unchanged from the previous behavior.
      return MakeResolveClassMember(value);
    }
    const size_t start = kSelfLen + 1;  // Skip "self.".
    if (start >= value.size()) {
      MS_LOG(ERROR) << "Find invalid resolve symbol str: " << value;
      return nullptr;
    }
    return MakeResolveClassMember(value.substr(start));
  }

  auto ast = parser_.ast();
  MS_EXCEPTION_IF_NULL(ast);
  // The fallback feature is enabled in default.
  // Not support change the flag during the process is alive.
  static const auto use_fallback = (parser_.support_fallback() != "0");
  if (!use_fallback) {
    py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_NAMESPACE_SYMBOL, value);
    return HandleNamespaceInfo(namespace_info);
  }
  py::tuple namespace_info = ast->CallParserObjMethod(PYTHON_PARSE_GET_BUILTIN_NAMESPACE_SYMBOL, value);
  return HandleBuiltinNamespaceInfo(namespace_info);
}

// Build a Resolve node for the operation named `value` in the common-ops namespace.
// Raises NameError if the Python parse module returns fewer than 2 elements.
AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) {
  auto ast = parser_.ast();
  MS_EXCEPTION_IF_NULL(ast);
  py::tuple namespace_var = ast->CallParseModFunction(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value);
  const size_t namespace_var_size = 2;
  if (namespace_var.size() < namespace_var_size) {
    MS_EXCEPTION(NameError) << "namespace_var is less than 2";
  }
  NameSpacePtr name_space = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]);
  SymbolPtr symbol = std::make_shared<Symbol>(namespace_var[1].cast<std::string>());
  MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
  return MakeResolve(name_space, symbol);
}

// Append the CNode `Resolve(name_space, resolve_symbol)` to this block's graph, in order,
// and return it. Null arguments are tolerated in the debug log only.
AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) {
  MS_LOG(DEBUG) << "MakeResolve for "
                << (name_space ? static_cast<std::string>(py::str(name_space->obj())) : "null namespace") << " , "
                << (resolve_symbol ? static_cast<std::string>(resolve_symbol->symbol()) : "null resolve symbol.");
  // Fail with a diagnostic instead of dereferencing a null graph.
  MS_EXCEPTION_IF_NULL(func_graph_);
  ValueNodePtr module_node = NewValueNode(name_space);
  ValueNodePtr symbol_node = NewValueNode(resolve_symbol);
  auto node = func_graph_->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
  return node;
}

AnfNodePtr FunctionBlock::MakeInterpret(const std::string &script_text, const AnfNodePtr &global_dict_node,
                                        const AnfNodePtr &local_dict_node, const AnfNodePtr &orig_node) {
  MS_LOG(DEBUG) << "MakeInterpret for " << script_text;
  ScriptPtr script = std::make_shared