| @@ -194,9 +194,12 @@ def get_object_key(obj): | |||||
| obj_key = "%s_ID" % (str(obj.__class__.__name__) + str(obj.__name__) + obj.cell_init_args) | obj_key = "%s_ID" % (str(obj.__class__.__name__) + str(obj.__name__) + obj.cell_init_args) | ||||
| obj_id = "%s_ID%d" % (str(obj.__class__.__name__) + str(obj.__name__), id(obj)) | obj_id = "%s_ID%d" % (str(obj.__class__.__name__) + str(obj.__name__), id(obj)) | ||||
| else: | else: | ||||
| # `<class 'xxxxxxx'>` | |||||
| # -> `xxxxxxx` | |||||
| tag = str(obj.__class__)[8:-2] | |||||
| if hasattr(obj, "cell_init_args"): | if hasattr(obj, "cell_init_args"): | ||||
| obj_key = "%s_ID" % (str(obj.__class__.__name__) + obj.cell_init_args) | |||||
| obj_id = "%s_ID%d" % (str(obj.__class__.__name__), id(obj)) | |||||
| obj_key = "%s_ID" % (tag + obj.cell_init_args) | |||||
| obj_id = "%s_ID%d" % (tag, id(obj)) | |||||
| logger.debug("obj_key %s obj_id = %s", obj_key, obj_id) | logger.debug("obj_key %s obj_id = %s", obj_key, obj_id) | ||||
| # method has same id of different instance | # method has same id of different instance | ||||
| @@ -316,7 +316,6 @@ class IncorporateGetitemFromParam : public AnfVisitor { | |||||
| } | } | ||||
| } | } | ||||
| // (void)mng->Replace(new_fg_parameters[param_i], new_param); | |||||
| new_parameters.push_back(new_param); | new_parameters.push_back(new_param); | ||||
| curr_input_idx++; | curr_input_idx++; | ||||
| } | } | ||||
| @@ -25,6 +25,7 @@ | |||||
| #include "ir/func_graph_cloner.h" | #include "ir/func_graph_cloner.h" | ||||
| #include "ir/param_info.h" | #include "ir/param_info.h" | ||||
| #include "ir/cell.h" | |||||
| #include "frontend/parallel/costmodel_context.h" | #include "frontend/parallel/costmodel_context.h" | ||||
| #include "frontend/parallel/context.h" | #include "frontend/parallel/context.h" | ||||
| #include "pipeline/jit/pass.h" | #include "pipeline/jit/pass.h" | ||||
| @@ -122,17 +123,29 @@ bool ParseAction(const ResourcePtr &res) { | |||||
| parse::python_adapter::set_python_env_flag(true); | parse::python_adapter::set_python_env_flag(true); | ||||
| parse::python_adapter::SetPythonPath(dir); | parse::python_adapter::SetPythonPath(dir); | ||||
| FuncGraphPtr fg = parse::ConvertToFuncGraph(input); | |||||
| if (fg == nullptr) { | |||||
| MS_LOG(EXCEPTION) << "Parse error."; | |||||
| ValuePtr converted_ret = nullptr; | |||||
| bool converted = parse::ConvertData(input, &converted_ret, true); | |||||
| if (!converted) { | |||||
| MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(input)); | |||||
| } | } | ||||
| res->set_func_graph(fg); | |||||
| FuncGraphPtr top_graph = nullptr; | |||||
| if (py::isinstance<Cell>(input)) { | |||||
| top_graph = parse::MakeTopGraph(input, converted_ret); | |||||
| } else if (converted_ret->isa<FuncGraph>()) { | |||||
| top_graph = converted_ret->cast<FuncGraphPtr>(); | |||||
| } else { | |||||
| MS_LOG(EXCEPTION) << "Object to parse " << std::string(py::str(input)) << " is not function or cell."; | |||||
| } | |||||
| parse::Parser::UpdateTopFuncGraph(top_graph); | |||||
| res->set_func_graph(top_graph); | |||||
| FuncGraphManagerPtr manager = res->manager(); | FuncGraphManagerPtr manager = res->manager(); | ||||
| if (manager == nullptr) { | if (manager == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "Manager is nullptr."; | MS_LOG(EXCEPTION) << "Manager is nullptr."; | ||||
| } | } | ||||
| manager->AddFuncGraph(fg); | |||||
| manager->AddFuncGraph(top_graph); | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -27,6 +27,7 @@ | |||||
| #include "frontend/operator/ops.h" | #include "frontend/operator/ops.h" | ||||
| #include "frontend/operator/composite/composite.h" | #include "frontend/operator/composite/composite.h" | ||||
| #include "ir/func_graph_cloner.h" | #include "ir/func_graph_cloner.h" | ||||
| #include "ir/cell.h" | |||||
| #include "utils/symbolic.h" | #include "utils/symbolic.h" | ||||
| #include "utils/ms_context.h" | #include "utils/ms_context.h" | ||||
| @@ -223,7 +224,8 @@ bool ConvertSlice(const py::object &obj, ValuePtr *const data) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) { | |||||
| bool ConvertCellObjToFuncGraph(const CellPtr &cell, ValuePtr *const data) { | |||||
| auto obj = py::cast(cell); | |||||
| FuncGraphPtr func_graph = ConvertToFuncGraph(obj); | FuncGraphPtr func_graph = ConvertToFuncGraph(obj); | ||||
| if (func_graph == nullptr) { | if (func_graph == nullptr) { | ||||
| MS_LOG(ERROR) << "Parse resolve function error."; | MS_LOG(ERROR) << "Parse resolve function error."; | ||||
| @@ -271,10 +273,6 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) { | |||||
| if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) { | if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) { | ||||
| // Create the namespace for common class instance | // Create the namespace for common class instance | ||||
| // When the obj is Cell, default parse the 'construct' | // When the obj is Cell, default parse the 'construct' | ||||
| if (data_converter::IsCellInstance(obj)) { | |||||
| return ConvertCellObjToFuncGraph(obj, data); | |||||
| } | |||||
| py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); | py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); | ||||
| py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); | py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); | ||||
| *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); | *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); | ||||
| @@ -404,6 +402,8 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature | |||||
| ret = ConvertTuple(obj, &converted, use_signature); | ret = ConvertTuple(obj, &converted, use_signature); | ||||
| } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) { | } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) { | ||||
| ret = ConvertCellList(obj, &converted, use_signature); | ret = ConvertCellList(obj, &converted, use_signature); | ||||
| } else if (py::isinstance<Cell>(obj)) { | |||||
| return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data); | |||||
| } else if (py::isinstance<py::list>(obj)) { | } else if (py::isinstance<py::list>(obj)) { | ||||
| ret = ConvertList(obj, &converted, use_signature); | ret = ConvertList(obj, &converted, use_signature); | ||||
| } else if (py::isinstance<py::module>(obj)) { | } else if (py::isinstance<py::module>(obj)) { | ||||
| @@ -140,34 +140,80 @@ void Parser::CleanParserResource() { | |||||
| ScopeManager::GetInstance().ClearScope(); | ScopeManager::GetInstance().ClearScope(); | ||||
| } | } | ||||
| FuncGraphPtr Parser::ParseFuncGraph() { | |||||
| // get ast FunctionDef node | |||||
| py::object node = ast_->GetAstNode(); | |||||
| FunctionBlockPtr pFnBlock = ParseFunction(node); | |||||
| if (errcode() != PARSE_SUCCESS) { | |||||
| MS_LOG(ERROR) << "Parse function error, code is " << errcode(); | |||||
| return nullptr; | |||||
| AnfNodePtr AppendParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| auto value = py::cast<tensor::MetaTensorPtr>(obj); | |||||
| // parameter object should not be none | |||||
| if (value == nullptr || !value->is_parameter()) { | |||||
| MS_LOG(EXCEPTION) << "Parameter error: because obj is not Parameter object."; | |||||
| } | |||||
| // get the parameter name from parameter object | |||||
| auto param_name = value->param_info()->name(); | |||||
| auto top_graph = func_graph; | |||||
| // if the parameter node has been created , return it | |||||
| AnfNodePtr para_node = nullptr; | |||||
| for (auto param : top_graph->parameters()) { | |||||
| auto param_node = dyn_cast<Parameter>(param); | |||||
| if (param_node != nullptr && param_node->name() == param_name) { | |||||
| para_node = param; | |||||
| break; | |||||
| } | |||||
| } | } | ||||
| if (para_node == nullptr) { | |||||
| auto node = top_graph->AddWeightParameter(param_name); | |||||
| RemoveUnnecessaryPhis(); | |||||
| node->set_default_param(value); | |||||
| // set_abstract for parameter | |||||
| auto abs = value->ToAbstract(); | |||||
| // boarden value | |||||
| abs = abs->Broaden(); | |||||
| node->set_abstract(abs); | |||||
| para_node = node; | |||||
| } | |||||
| return para_node; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(pFnBlock); | |||||
| void UpdataParam(const FuncGraphPtr &top_graph, const py::object &cell) { | |||||
| auto params = py::list(cell.attr("get_parameters")()).cast<std::vector<py::object>>(); | |||||
| for (size_t i = 0; i < params.size(); i++) { | |||||
| (void)AppendParameterObj(top_graph, params[i]); | |||||
| } | |||||
| } | |||||
| void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &ast) { | |||||
| // check whether the functions refered by this function and itself are missing 'return' statement | // check whether the functions refered by this function and itself are missing 'return' statement | ||||
| auto mng = Manage(pFnBlock->func_graph(), false); | |||||
| auto mng = Manage(fn, false); | |||||
| for (auto func_graph : mng->func_graphs()) { | for (auto func_graph : mng->func_graphs()) { | ||||
| if (func_graph->get_return() != nullptr) { | if (func_graph->get_return() != nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); | |||||
| py::object node = ast->GetAstNode(); | |||||
| py::list ret = ast->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); | |||||
| py::str desc = | py::str desc = | ||||
| python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(), ret[0], ret[1]); | |||||
| python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]); | |||||
| MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << "."; | MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << "."; | ||||
| } | } | ||||
| // clear manager info after checking missing return | // clear manager info after checking missing return | ||||
| for (auto fg : mng->func_graphs()) { | for (auto fg : mng->func_graphs()) { | ||||
| fg->ClearAllManagerInfo(); | fg->ClearAllManagerInfo(); | ||||
| } | } | ||||
| } | |||||
| FuncGraphPtr Parser::ParseFuncGraph() { | |||||
| // get ast FunctionDef node | |||||
| py::object node = ast_->GetAstNode(); | |||||
| FunctionBlockPtr pFnBlock = ParseFunction(node); | |||||
| if (errcode() != PARSE_SUCCESS) { | |||||
| MS_LOG(ERROR) << "Parse function error, code is " << errcode(); | |||||
| return nullptr; | |||||
| } | |||||
| RemoveUnnecessaryPhis(); | |||||
| MS_EXCEPTION_IF_NULL(pFnBlock); | |||||
| CheckFuncReturn(pFnBlock->func_graph(), ast_); | |||||
| return pFnBlock->func_graph(); | return pFnBlock->func_graph(); | ||||
| } | } | ||||
| @@ -591,19 +637,24 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no | |||||
| return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack); | return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack); | ||||
| } | } | ||||
| CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_anf_node, | |||||
| const std::vector<AnfNodePtr> &packed_arguments) { | |||||
| std::vector<AnfNodePtr> unpack_call_nodes; | |||||
| auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL)); | |||||
| unpack_call_nodes.push_back(unpack_call_op); | |||||
| unpack_call_nodes.push_back(call_function_anf_node); | |||||
| (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes), | |||||
| [](AnfNodePtr node) -> AnfNodePtr { return node; }); | |||||
| CNodePtr unpack_call = func_graph->NewCNode(unpack_call_nodes); | |||||
| return unpack_call; | |||||
| } | |||||
| AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node, | AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node, | ||||
| const std::vector<AnfNodePtr> &packed_arguments, | const std::vector<AnfNodePtr> &packed_arguments, | ||||
| const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const { | const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const { | ||||
| // if there is keyword arguments or starred, using an unpack_call op to unpack the argument | // if there is keyword arguments or starred, using an unpack_call op to unpack the argument | ||||
| if (need_unpack) { | if (need_unpack) { | ||||
| std::vector<AnfNodePtr> unpack_call_nodes; | |||||
| auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL)); | |||||
| unpack_call_nodes.push_back(unpack_call_op); | |||||
| unpack_call_nodes.push_back(call_function_anf_node); | |||||
| (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes), | |||||
| [](AnfNodePtr node) -> AnfNodePtr { return node; }); | |||||
| CNodePtr unpack_call = block->func_graph()->NewCNode(unpack_call_nodes); | |||||
| return unpack_call; | |||||
| return MakeUnpackCall(block->func_graph(), call_function_anf_node, packed_arguments); | |||||
| } | } | ||||
| // else there is no keyword arguments and starred, parsed as normal arguments without unpack | // else there is no keyword arguments and starred, parsed as normal arguments without unpack | ||||
| std::vector<AnfNodePtr> func_call_nodes; | std::vector<AnfNodePtr> func_call_nodes; | ||||
| @@ -1739,5 +1790,41 @@ bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) { | |||||
| auto func_graph = std::make_shared<FuncGraph>(); | |||||
| func_graph->debug_info()->set_name("top"); | |||||
| // def top(*arg, *kwargs): | |||||
| auto param_vargs = func_graph->add_parameter(); | |||||
| auto args_name = "args"; | |||||
| param_vargs->set_name(args_name); | |||||
| param_vargs->debug_info()->set_name(args_name); | |||||
| auto param_vkwargs = func_graph->add_parameter(); | |||||
| args_name = "kwargs"; | |||||
| param_vkwargs->set_name(args_name); | |||||
| param_vkwargs->debug_info()->set_name(args_name); | |||||
| func_graph->set_has_vararg(true); | |||||
| func_graph->set_has_kwarg(true); | |||||
| func_graph->set_kwonlyargs_count(0); | |||||
| // cell_obj | |||||
| parse::UpdateFuncGraphFlags(cell, func_graph); | |||||
| // top graph's construct flag | |||||
| if (py::hasattr(cell, "construct")) { | |||||
| parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph); | |||||
| } | |||||
| UpdataParam(func_graph, cell); | |||||
| // ret = cell_obj(*arg, *kwargs) | |||||
| auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), {param_vargs, param_vkwargs}); | |||||
| // return ret | |||||
| func_graph->set_output(call_fn); | |||||
| MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell)); | |||||
| return func_graph; | |||||
| } | |||||
| } // namespace parse | } // namespace parse | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -148,6 +148,9 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, | |||||
| // Parse the python object to graph | // Parse the python object to graph | ||||
| FuncGraphPtr ParsePythonCode(const py::object &obj, | FuncGraphPtr ParsePythonCode(const py::object &obj, | ||||
| const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); | const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); | ||||
| // add wrap for cell top graph. | |||||
| FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr); | |||||
| } // namespace parse | } // namespace parse | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -0,0 +1,50 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "pybind_api/ir/cell_py.h" | |||||
| #include <string> | |||||
| #include "pybind_api/api_register.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "pipeline/jit/parse/python_adapter.h" | |||||
| namespace mindspore { | |||||
| void CellPy::AddAttr(CellPtr cell, const std::string &name, const py::object &obj) { | |||||
| std::string attr_name = name; | |||||
| ValuePtr converted_ret = nullptr; | |||||
| if (py::isinstance<py::module>(obj)) { | |||||
| MS_LOG(EXCEPTION) << "Cell set_attr failed, attr should not be py::module"; | |||||
| } | |||||
| bool converted = parse::ConvertData(obj, &converted_ret, true); | |||||
| if (!converted) { | |||||
| MS_LOG(DEBUG) << "Attribute convert error with type: " << std::string(py::str(obj)); | |||||
| } else { | |||||
| MS_LOG(DEBUG) << cell->ToString() << " add attr " << attr_name << converted_ret->ToString(); | |||||
| cell->AddAttr(attr_name, converted_ret); | |||||
| } | |||||
| } | |||||
| // Define python 'Cell' class. | |||||
| REGISTER_PYBIND_DEFINE(Cell, ([](const py::module *m) { | |||||
| (void)py::class_<Cell, std::shared_ptr<Cell>>(*m, "Cell_") | |||||
| .def(py::init<std::string &>()) | |||||
| .def("__str__", &Cell::ToString) | |||||
| .def("_add_attr", &CellPy::AddAttr, "Add Cell attr.") | |||||
| .def("_del_attr", &Cell::DelAttr, "Delete Cell attr.") | |||||
| .def( | |||||
| "construct", []() { MS_LOG(EXCEPTION) << "we should define `construct` for all `cell`."; }, | |||||
| "construct"); | |||||
| })); | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_UTILS_CELL_PY_H_ | |||||
| #define MINDSPORE_CCSRC_UTILS_CELL_PY_H_ | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "pybind11/pybind11.h" | |||||
| #include "pybind11/numpy.h" | |||||
| #include "ir/cell.h" | |||||
| namespace py = pybind11; | |||||
| // brief mindspore namespace. | |||||
| // | |||||
| // mindspore namespace is the top level namespace of Mindsporeession project. | |||||
| // Other namespace should be a sub namespace of mindspore namespace in the ME project. | |||||
| namespace mindspore { | |||||
| // Cell python wrapper and adapter class. | |||||
| class CellPy { | |||||
| public: | |||||
| static void AddAttr(CellPtr cell, const std::string &name, const py::object &obj); | |||||
| }; | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_UTILS_CELL_PY_H_ | |||||
| @@ -0,0 +1,94 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ir/cell.h" | |||||
| #include <utility> | |||||
| #include <map> | |||||
| #include <algorithm> | |||||
| #include "abstract/abstract_value.h" | |||||
| namespace mindspore { | |||||
| using mindspore::abstract::AbstractFunction; | |||||
| abstract::AbstractBasePtr Cell::ToAbstract() { | |||||
| /* | |||||
| std::vector<abstract::AbstractAttribute> abs_attrs; | |||||
| std::transform(attrs_.begin(), attrs_.end(), std::back_inserter(abs_attrs), | |||||
| [](std::pair<std::string, ValuePtr> attr) -> abstract::AbstractAttribute { | |||||
| return std::make_pair(attr.first, attr.second->ToAbstract()); | |||||
| }); | |||||
| auto abs = std::make_shared<abstract::AbstractCell>(shared_from_base<Named>(), abs_attrs); | |||||
| abs->set_value(shared_from_base<Value>()); | |||||
| return abs; | |||||
| */ | |||||
| return nullptr; | |||||
| } | |||||
| bool Cell::operator==(const Value &other) const { | |||||
| if (other.isa<Cell>()) { | |||||
| auto other_prim = static_cast<const Cell &>(other); | |||||
| return *this == other_prim; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| } | |||||
| bool Cell::operator==(const Cell &other) const { | |||||
| if (name() != other.name()) { | |||||
| return false; | |||||
| } | |||||
| if (attrs_.size() != other.attrs_.size()) { | |||||
| return false; | |||||
| } | |||||
| auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool { | |||||
| if (item.second == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto iter = other.attrs_.find(item.first); | |||||
| if (iter == other.attrs_.end()) { | |||||
| return false; | |||||
| } | |||||
| return *item.second == *iter->second; | |||||
| }); | |||||
| return all; | |||||
| } | |||||
| std::string Cell::GetAttrString() const { | |||||
| std::ostringstream buffer; | |||||
| bool begin = true; | |||||
| buffer << "{" << std::endl; | |||||
| for (auto &attr : attrs_) { | |||||
| if (!begin) { | |||||
| buffer << ", " << std::endl; | |||||
| } else { | |||||
| begin = false; | |||||
| } | |||||
| buffer << attr.first << ":" << attr.second->ToString(); | |||||
| } | |||||
| buffer << "}"; | |||||
| return buffer.str(); | |||||
| } | |||||
| std::string Cell::ToString() const { | |||||
| std::ostringstream buffer; | |||||
| buffer << "Cell " << name(); | |||||
| return buffer.str(); | |||||
| } | |||||
| void Cell::DelAttr(const std::string &name) { attrs_.erase(name); } | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,69 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_IR_CELL_H_ | |||||
| #define MINDSPORE_CCSRC_IR_CELL_H_ | |||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <tuple> | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "utils/misc.h" | |||||
| namespace mindspore { | |||||
| using abstract::AbstractBasePtr; | |||||
| using abstract::AbstractBasePtrList; | |||||
| // value for Cell | |||||
| class Cell : public Named { | |||||
| public: | |||||
| explicit Cell(const std::string &name) : Named(name) {} | |||||
| MS_DECLARE_PARENT(Cell, Named); | |||||
| abstract::AbstractBasePtr ToAbstract() override; | |||||
| std::string ToString() const override; | |||||
| std::string GetAttrString() const; | |||||
| const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | |||||
| void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs_input) { attrs_ = attrs_input; } | |||||
| void AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; } | |||||
| void DelAttr(const std::string &name); | |||||
| ValuePtr GetAttr(const std::string &attr_name) const { | |||||
| auto iter = attrs_.find(attr_name); | |||||
| return iter == attrs_.cend() ? nullptr : iter->second; | |||||
| } | |||||
| bool HasAttr(const std::string &attr_name) const { | |||||
| auto iter = attrs_.find(attr_name); | |||||
| return !(iter == attrs_.cend()); | |||||
| } | |||||
| bool operator==(const Value &other) const override; | |||||
| bool operator==(const Cell &other) const; | |||||
| ~Cell() override = default; | |||||
| const bool parse_info_ = true; | |||||
| private: | |||||
| std::unordered_map<std::string, ValuePtr> attrs_; | |||||
| }; | |||||
| using CellPtr = std::shared_ptr<Cell>; | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_IR_CELL_H_ | |||||
| @@ -98,10 +98,11 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, | |||||
| MS_LOG(EXCEPTION) << "Function:" << this->ToString() << ", variable_args_count " << variable_args_count | MS_LOG(EXCEPTION) << "Function:" << this->ToString() << ", variable_args_count " << variable_args_count | ||||
| << " were given."; | << " were given."; | ||||
| } | } | ||||
| auto varg_name = specialized_graph->GetVariableArgName(); | |||||
| // for python variable argument input , there is no upper limit | // for python variable argument input , there is no upper limit | ||||
| for (int i = 0; i < variable_args_count; ++i) { | for (int i = 0; i < variable_args_count; ++i) { | ||||
| ParameterPtr p = std::make_shared<Parameter>(specialized_graph); | ParameterPtr p = std::make_shared<Parameter>(specialized_graph); | ||||
| std::string param_name = specialized_graph->GetVariableArgName() + std::to_string(i); | |||||
| std::string param_name = varg_name + std::to_string(i); | |||||
| p->set_name(param_name); | p->set_name(param_name); | ||||
| MS_EXCEPTION_IF_NULL(p->debug_info()); | MS_EXCEPTION_IF_NULL(p->debug_info()); | ||||
| p->debug_info()->set_name(param_name); | p->debug_info()->set_name(param_name); | ||||
| @@ -49,7 +49,8 @@ NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_labe | |||||
| while (temp_info != nullptr) { | while (temp_info != nullptr) { | ||||
| if (temp_info->trace_info() != nullptr) { | if (temp_info->trace_info() != nullptr) { | ||||
| if (temp_info->trace_info()->isa<TraceResolve>() || temp_info->trace_info()->isa<TraceExpandJ>() || | if (temp_info->trace_info()->isa<TraceResolve>() || temp_info->trace_info()->isa<TraceExpandJ>() || | ||||
| temp_info->trace_info()->isa<TraceGenMetaFuncGraph>()) { | |||||
| temp_info->trace_info()->isa<TraceGenMetaFuncGraph>() || | |||||
| temp_info->trace_info()->isa<TraceGenerateVarArg>() || temp_info->trace_info()->isa<TraceGenerateKwArg>()) { | |||||
| break; | break; | ||||
| } | } | ||||
| trace_name.trace_labels.push_back(GetTraceName(temp_info->trace_info(), trace_label)); | trace_name.trace_labels.push_back(GetTraceName(temp_info->trace_info(), trace_label)); | ||||
| @@ -24,14 +24,14 @@ from ..common import dtype as mstype | |||||
| from ..common.api import _executor, _pynative_exec | from ..common.api import _executor, _pynative_exec | ||||
| from .._checkparam import _check_str_by_regular | from .._checkparam import _check_str_by_regular | ||||
| from ..common.parameter import Parameter, ParameterTuple | from ..common.parameter import Parameter, ParameterTuple | ||||
| from .._c_expression import init_backend | |||||
| from .._c_expression import init_backend, Cell_ | |||||
| from ..ops.primitive import Primitive | from ..ops.primitive import Primitive | ||||
| from ..ops.operations import HookBackward | from ..ops.operations import HookBackward | ||||
| from ..ops.functional import cast | from ..ops.functional import cast | ||||
| from ..parallel._tensor import _load_tensor_by_layout | from ..parallel._tensor import _load_tensor_by_layout | ||||
| from ..common.tensor import Tensor | from ..common.tensor import Tensor | ||||
| class Cell: | |||||
| class Cell(Cell_): | |||||
| """ | """ | ||||
| Base class for all neural networks. | Base class for all neural networks. | ||||
| @@ -58,14 +58,21 @@ class Cell: | |||||
| >>> def construct(self, x): | >>> def construct(self, x): | ||||
| >>> return self.relu(x) | >>> return self.relu(x) | ||||
| """ | """ | ||||
| IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names', | |||||
| '_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run', | |||||
| '_parameter_layout_dict', '_already_run', '_params_list', '_phase', '_auto_parallel_mode', | |||||
| '_backward_hook', '_bprop_debug', '_is_run', '_param_prefix', '_attr_synced', | |||||
| 'enable_hook', 'pynative', 'requires_grad', '_auto_parallel_compile_and_run', 'cell_type'] | |||||
| def __init__(self, auto_prefix=True, flags=None): | def __init__(self, auto_prefix=True, flags=None): | ||||
| Cell_.__init__(self, self._cell_tag) | |||||
| self._params = OrderedDict() | self._params = OrderedDict() | ||||
| self._cells = OrderedDict() | self._cells = OrderedDict() | ||||
| self._params_list = OrderedDict() | self._params_list = OrderedDict() | ||||
| self.training = False | self.training = False | ||||
| self.requires_grad = False | self.requires_grad = False | ||||
| self.pynative = False | self.pynative = False | ||||
| self._attr_synced = False | |||||
| self._param_prefix = '' | self._param_prefix = '' | ||||
| self._auto_prefix = auto_prefix | self._auto_prefix = auto_prefix | ||||
| self._scope = None | self._scope = None | ||||
| @@ -92,6 +99,12 @@ class Cell: | |||||
| def already_run(self): | def already_run(self): | ||||
| return self._already_run | return self._already_run | ||||
| @property | |||||
| def _cell_tag(self): | |||||
| # `<class 'xxxxxxx'>` | |||||
| # -> `xxxxxxx` | |||||
| return str(self.__class__)[8:-2] | |||||
| @already_run.setter | @already_run.setter | ||||
| def already_run(self, value): | def already_run(self, value): | ||||
| self._already_run = value | self._already_run = value | ||||
| @@ -222,6 +235,7 @@ class Cell: | |||||
| del self._cells[name] | del self._cells[name] | ||||
| else: | else: | ||||
| object.__delattr__(self, name) | object.__delattr__(self, name) | ||||
| self._attr_synced = False | |||||
| def cast_inputs(self, inputs, dst_type): | def cast_inputs(self, inputs, dst_type): | ||||
| res = list() | res = list() | ||||
| @@ -277,6 +291,34 @@ class Cell: | |||||
| self._already_run = True | self._already_run = True | ||||
| return output | return output | ||||
| def _add_attr(self, name, value): | |||||
| if name and name[:2] != '__' and name not in Cell.IGNORE_LIST: | |||||
| super(Cell, self)._add_attr(name, value) | |||||
| def _sync_attr_for_compile(self): | |||||
| """Sync the attr to c++ object.""" | |||||
| if self._attr_synced: | |||||
| return | |||||
| cells = self.__dict__.get('_cells') | |||||
| for key in cells: | |||||
| cell = cells[key] | |||||
| cell._sync_attr_for_compile() | |||||
| self._add_attr(key, cell) | |||||
| params = self.__dict__.get('_params') | |||||
| for key in params: | |||||
| if '.' in key: | |||||
| continue | |||||
| param = params[key] | |||||
| self._add_attr(key, param) | |||||
| params_list = self.__dict__.get('_params_list') | |||||
| for key in params_list: | |||||
| params_list_item = params_list[key] | |||||
| self._add_attr(key, params_list_item) | |||||
| for key in self.__dict__: | |||||
| value = self.__dict__[key] | |||||
| self._add_attr(key, value) | |||||
| self._attr_synced = True | |||||
| def __setattr__(self, name, value): | def __setattr__(self, name, value): | ||||
| cells = self.__dict__.get('_cells') | cells = self.__dict__.get('_cells') | ||||
| params = self.__dict__.get('_params') | params = self.__dict__.get('_params') | ||||
| @@ -329,6 +371,8 @@ class Cell: | |||||
| if isinstance(value, Primitive): | if isinstance(value, Primitive): | ||||
| value.set_prim_instance_name(name) | value.set_prim_instance_name(name) | ||||
| object.__setattr__(self, name, value) | object.__setattr__(self, name, value) | ||||
| if name not in Cell.IGNORE_LIST: | |||||
| self._attr_synced = False | |||||
| def extend_repr(self): | def extend_repr(self): | ||||
| """ | """ | ||||
| @@ -451,7 +495,7 @@ class Cell: | |||||
| Object, the result of executing. | Object, the result of executing. | ||||
| """ | """ | ||||
| self._auto_parallel_compile_and_run = True | self._auto_parallel_compile_and_run = True | ||||
| _executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode) | |||||
| self.compile(*inputs) | |||||
| if self._auto_parallel_mode: | if self._auto_parallel_mode: | ||||
| if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag: | if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag: | ||||
| @@ -14,7 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """container""" | """container""" | ||||
| from collections import OrderedDict | from collections import OrderedDict | ||||
| from abc import abstractmethod, ABCMeta | |||||
| from abc import abstractmethod | |||||
| from ..cell import Cell | from ..cell import Cell | ||||
| __all__ = ['SequentialCell', 'CellList'] | __all__ = ['SequentialCell', 'CellList'] | ||||
| @@ -34,7 +34,7 @@ def _valid_cell(cell): | |||||
| raise TypeError('Cell {} is not subclass of Cell'.format(cell)) | raise TypeError('Cell {} is not subclass of Cell'.format(cell)) | ||||
| class _CellListBase(metaclass=ABCMeta): | |||||
| class _CellListBase(): | |||||
| """ | """ | ||||
| An interface for base the cell as list. | An interface for base the cell as list. | ||||
| @@ -51,7 +51,7 @@ def test_get_parameter_layout(): | |||||
| exe.compile(net, x, phase='train', auto_parallel_mode=True) | exe.compile(net, x, phase='train', auto_parallel_mode=True) | ||||
| x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [1, -1] | x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [1, -1] | ||||
| weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [0, -1] | weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [0, -1] | ||||
| expect_dict = {'x': x_layout, 'w1': weight_layout} | |||||
| expect_dict = {'args0': x_layout, 'w1': weight_layout} | |||||
| # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut | # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut | ||||
| assert net.parameter_layout_dict == expect_dict | assert net.parameter_layout_dict == expect_dict | ||||
| @@ -125,7 +125,7 @@ def test_grad_sens_parameter_type(): | |||||
| y_layout = [[8, 8], [-1, 0], [32, 8], [0], [1]] | y_layout = [[8, 8], [-1, 0], [32, 8], [0], [1]] | ||||
| b_layout = [[8, 8], [0, -1], [8, 64], [0], [1]] | b_layout = [[8, 8], [0, -1], [8, 64], [0], [1]] | ||||
| sens_layout = [[8, 8], [1, -1], [16, 64], [0], [1]] | sens_layout = [[8, 8], [1, -1], [16, 64], [0], [1]] | ||||
| expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout} | |||||
| expect_dict = {'args0': x_layout, 'args1': y_layout, 'args2': b_layout, 'args3': sens_layout} | |||||
| assert net.parameter_layout_dict == expect_dict | assert net.parameter_layout_dict == expect_dict | ||||