| @@ -194,9 +194,12 @@ def get_object_key(obj): | |||
| obj_key = "%s_ID" % (str(obj.__class__.__name__) + str(obj.__name__) + obj.cell_init_args) | |||
| obj_id = "%s_ID%d" % (str(obj.__class__.__name__) + str(obj.__name__), id(obj)) | |||
| else: | |||
| # `<class 'xxxxxxx'>` | |||
| # -> `xxxxxxx` | |||
| tag = str(obj.__class__)[8:-2] | |||
| if hasattr(obj, "cell_init_args"): | |||
| obj_key = "%s_ID" % (str(obj.__class__.__name__) + obj.cell_init_args) | |||
| obj_id = "%s_ID%d" % (str(obj.__class__.__name__), id(obj)) | |||
| obj_key = "%s_ID" % (tag + obj.cell_init_args) | |||
| obj_id = "%s_ID%d" % (tag, id(obj)) | |||
| logger.debug("obj_key %s obj_id = %s", obj_key, obj_id) | |||
| # method has same id of different instance | |||
| @@ -316,7 +316,6 @@ class IncorporateGetitemFromParam : public AnfVisitor { | |||
| } | |||
| } | |||
| // (void)mng->Replace(new_fg_parameters[param_i], new_param); | |||
| new_parameters.push_back(new_param); | |||
| curr_input_idx++; | |||
| } | |||
| @@ -25,6 +25,7 @@ | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "ir/param_info.h" | |||
| #include "ir/cell.h" | |||
| #include "frontend/parallel/costmodel_context.h" | |||
| #include "frontend/parallel/context.h" | |||
| #include "pipeline/jit/pass.h" | |||
| @@ -122,17 +123,29 @@ bool ParseAction(const ResourcePtr &res) { | |||
| parse::python_adapter::set_python_env_flag(true); | |||
| parse::python_adapter::SetPythonPath(dir); | |||
| FuncGraphPtr fg = parse::ConvertToFuncGraph(input); | |||
| if (fg == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Parse error."; | |||
| ValuePtr converted_ret = nullptr; | |||
| bool converted = parse::ConvertData(input, &converted_ret, true); | |||
| if (!converted) { | |||
| MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(input)); | |||
| } | |||
| res->set_func_graph(fg); | |||
| FuncGraphPtr top_graph = nullptr; | |||
| if (py::isinstance<Cell>(input)) { | |||
| top_graph = parse::MakeTopGraph(input, converted_ret); | |||
| } else if (converted_ret->isa<FuncGraph>()) { | |||
| top_graph = converted_ret->cast<FuncGraphPtr>(); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Object to parse " << std::string(py::str(input)) << " is not function or cell."; | |||
| } | |||
| parse::Parser::UpdateTopFuncGraph(top_graph); | |||
| res->set_func_graph(top_graph); | |||
| FuncGraphManagerPtr manager = res->manager(); | |||
| if (manager == nullptr) { | |||
| MS_LOG(EXCEPTION) << "Manager is nullptr."; | |||
| } | |||
| manager->AddFuncGraph(fg); | |||
| manager->AddFuncGraph(top_graph); | |||
| return true; | |||
| } | |||
| @@ -27,6 +27,7 @@ | |||
| #include "frontend/operator/ops.h" | |||
| #include "frontend/operator/composite/composite.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "ir/cell.h" | |||
| #include "utils/symbolic.h" | |||
| #include "utils/ms_context.h" | |||
| @@ -223,7 +224,8 @@ bool ConvertSlice(const py::object &obj, ValuePtr *const data) { | |||
| return true; | |||
| } | |||
| bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) { | |||
| bool ConvertCellObjToFuncGraph(const CellPtr &cell, ValuePtr *const data) { | |||
| auto obj = py::cast(cell); | |||
| FuncGraphPtr func_graph = ConvertToFuncGraph(obj); | |||
| if (func_graph == nullptr) { | |||
| MS_LOG(ERROR) << "Parse resolve function error."; | |||
| @@ -271,10 +273,6 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) { | |||
| if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) { | |||
| // Create the namespace for common class instance | |||
| // When the obj is Cell, default parse the 'construct' | |||
| if (data_converter::IsCellInstance(obj)) { | |||
| return ConvertCellObjToFuncGraph(obj, data); | |||
| } | |||
| py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); | |||
| py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); | |||
| *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); | |||
| @@ -404,6 +402,8 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature | |||
| ret = ConvertTuple(obj, &converted, use_signature); | |||
| } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) { | |||
| ret = ConvertCellList(obj, &converted, use_signature); | |||
| } else if (py::isinstance<Cell>(obj)) { | |||
| return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data); | |||
| } else if (py::isinstance<py::list>(obj)) { | |||
| ret = ConvertList(obj, &converted, use_signature); | |||
| } else if (py::isinstance<py::module>(obj)) { | |||
| @@ -140,34 +140,80 @@ void Parser::CleanParserResource() { | |||
| ScopeManager::GetInstance().ClearScope(); | |||
| } | |||
| FuncGraphPtr Parser::ParseFuncGraph() { | |||
| // get ast FunctionDef node | |||
| py::object node = ast_->GetAstNode(); | |||
| FunctionBlockPtr pFnBlock = ParseFunction(node); | |||
| if (errcode() != PARSE_SUCCESS) { | |||
| MS_LOG(ERROR) << "Parse function error, code is " << errcode(); | |||
| return nullptr; | |||
| AnfNodePtr AppendParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto value = py::cast<tensor::MetaTensorPtr>(obj); | |||
| // parameter object should not be none | |||
| if (value == nullptr || !value->is_parameter()) { | |||
| MS_LOG(EXCEPTION) << "Parameter error: because obj is not Parameter object."; | |||
| } | |||
| // get the parameter name from parameter object | |||
| auto param_name = value->param_info()->name(); | |||
| auto top_graph = func_graph; | |||
| // if the parameter node has been created , return it | |||
| AnfNodePtr para_node = nullptr; | |||
| for (auto param : top_graph->parameters()) { | |||
| auto param_node = dyn_cast<Parameter>(param); | |||
| if (param_node != nullptr && param_node->name() == param_name) { | |||
| para_node = param; | |||
| break; | |||
| } | |||
| } | |||
| if (para_node == nullptr) { | |||
| auto node = top_graph->AddWeightParameter(param_name); | |||
| RemoveUnnecessaryPhis(); | |||
| node->set_default_param(value); | |||
| // set_abstract for parameter | |||
| auto abs = value->ToAbstract(); | |||
| // boarden value | |||
| abs = abs->Broaden(); | |||
| node->set_abstract(abs); | |||
| para_node = node; | |||
| } | |||
| return para_node; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(pFnBlock); | |||
| void UpdataParam(const FuncGraphPtr &top_graph, const py::object &cell) { | |||
| auto params = py::list(cell.attr("get_parameters")()).cast<std::vector<py::object>>(); | |||
| for (size_t i = 0; i < params.size(); i++) { | |||
| (void)AppendParameterObj(top_graph, params[i]); | |||
| } | |||
| } | |||
| void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &ast) { | |||
| // check whether the functions refered by this function and itself are missing 'return' statement | |||
| auto mng = Manage(pFnBlock->func_graph(), false); | |||
| auto mng = Manage(fn, false); | |||
| for (auto func_graph : mng->func_graphs()) { | |||
| if (func_graph->get_return() != nullptr) { | |||
| continue; | |||
| } | |||
| py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); | |||
| py::object node = ast->GetAstNode(); | |||
| py::list ret = ast->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); | |||
| py::str desc = | |||
| python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(), ret[0], ret[1]); | |||
| python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]); | |||
| MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << "."; | |||
| } | |||
| // clear manager info after checking missing return | |||
| for (auto fg : mng->func_graphs()) { | |||
| fg->ClearAllManagerInfo(); | |||
| } | |||
| } | |||
| FuncGraphPtr Parser::ParseFuncGraph() { | |||
| // get ast FunctionDef node | |||
| py::object node = ast_->GetAstNode(); | |||
| FunctionBlockPtr pFnBlock = ParseFunction(node); | |||
| if (errcode() != PARSE_SUCCESS) { | |||
| MS_LOG(ERROR) << "Parse function error, code is " << errcode(); | |||
| return nullptr; | |||
| } | |||
| RemoveUnnecessaryPhis(); | |||
| MS_EXCEPTION_IF_NULL(pFnBlock); | |||
| CheckFuncReturn(pFnBlock->func_graph(), ast_); | |||
| return pFnBlock->func_graph(); | |||
| } | |||
| @@ -591,19 +637,24 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no | |||
| return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack); | |||
| } | |||
| CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_anf_node, | |||
| const std::vector<AnfNodePtr> &packed_arguments) { | |||
| std::vector<AnfNodePtr> unpack_call_nodes; | |||
| auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL)); | |||
| unpack_call_nodes.push_back(unpack_call_op); | |||
| unpack_call_nodes.push_back(call_function_anf_node); | |||
| (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes), | |||
| [](AnfNodePtr node) -> AnfNodePtr { return node; }); | |||
| CNodePtr unpack_call = func_graph->NewCNode(unpack_call_nodes); | |||
| return unpack_call; | |||
| } | |||
| AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node, | |||
| const std::vector<AnfNodePtr> &packed_arguments, | |||
| const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const { | |||
| // if there is keyword arguments or starred, using an unpack_call op to unpack the argument | |||
| if (need_unpack) { | |||
| std::vector<AnfNodePtr> unpack_call_nodes; | |||
| auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL)); | |||
| unpack_call_nodes.push_back(unpack_call_op); | |||
| unpack_call_nodes.push_back(call_function_anf_node); | |||
| (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes), | |||
| [](AnfNodePtr node) -> AnfNodePtr { return node; }); | |||
| CNodePtr unpack_call = block->func_graph()->NewCNode(unpack_call_nodes); | |||
| return unpack_call; | |||
| return MakeUnpackCall(block->func_graph(), call_function_anf_node, packed_arguments); | |||
| } | |||
| // else there is no keyword arguments and starred, parsed as normal arguments without unpack | |||
| std::vector<AnfNodePtr> func_call_nodes; | |||
| @@ -1739,5 +1790,41 @@ bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) { | |||
| return true; | |||
| } | |||
| FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) { | |||
| auto func_graph = std::make_shared<FuncGraph>(); | |||
| func_graph->debug_info()->set_name("top"); | |||
| // def top(*arg, *kwargs): | |||
| auto param_vargs = func_graph->add_parameter(); | |||
| auto args_name = "args"; | |||
| param_vargs->set_name(args_name); | |||
| param_vargs->debug_info()->set_name(args_name); | |||
| auto param_vkwargs = func_graph->add_parameter(); | |||
| args_name = "kwargs"; | |||
| param_vkwargs->set_name(args_name); | |||
| param_vkwargs->debug_info()->set_name(args_name); | |||
| func_graph->set_has_vararg(true); | |||
| func_graph->set_has_kwarg(true); | |||
| func_graph->set_kwonlyargs_count(0); | |||
| // cell_obj | |||
| parse::UpdateFuncGraphFlags(cell, func_graph); | |||
| // top graph's construct flag | |||
| if (py::hasattr(cell, "construct")) { | |||
| parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph); | |||
| } | |||
| UpdataParam(func_graph, cell); | |||
| // ret = cell_obj(*arg, *kwargs) | |||
| auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), {param_vargs, param_vkwargs}); | |||
| // return ret | |||
| func_graph->set_output(call_fn); | |||
| MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell)); | |||
| return func_graph; | |||
| } | |||
| } // namespace parse | |||
| } // namespace mindspore | |||
| @@ -148,6 +148,9 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, | |||
| // Parse the python object to graph | |||
| FuncGraphPtr ParsePythonCode(const py::object &obj, | |||
| const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); | |||
| // add wrap for cell top graph. | |||
| FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr); | |||
| } // namespace parse | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pybind_api/ir/cell_py.h" | |||
| #include <string> | |||
| #include "pybind_api/api_register.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "pipeline/jit/parse/python_adapter.h" | |||
| namespace mindspore { | |||
| void CellPy::AddAttr(CellPtr cell, const std::string &name, const py::object &obj) { | |||
| std::string attr_name = name; | |||
| ValuePtr converted_ret = nullptr; | |||
| if (py::isinstance<py::module>(obj)) { | |||
| MS_LOG(EXCEPTION) << "Cell set_attr failed, attr should not be py::module"; | |||
| } | |||
| bool converted = parse::ConvertData(obj, &converted_ret, true); | |||
| if (!converted) { | |||
| MS_LOG(DEBUG) << "Attribute convert error with type: " << std::string(py::str(obj)); | |||
| } else { | |||
| MS_LOG(DEBUG) << cell->ToString() << " add attr " << attr_name << converted_ret->ToString(); | |||
| cell->AddAttr(attr_name, converted_ret); | |||
| } | |||
| } | |||
| // Define python 'Cell' class. | |||
| REGISTER_PYBIND_DEFINE(Cell, ([](const py::module *m) { | |||
| (void)py::class_<Cell, std::shared_ptr<Cell>>(*m, "Cell_") | |||
| .def(py::init<std::string &>()) | |||
| .def("__str__", &Cell::ToString) | |||
| .def("_add_attr", &CellPy::AddAttr, "Add Cell attr.") | |||
| .def("_del_attr", &Cell::DelAttr, "Delete Cell attr.") | |||
| .def( | |||
| "construct", []() { MS_LOG(EXCEPTION) << "we should define `construct` for all `cell`."; }, | |||
| "construct"); | |||
| })); | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_UTILS_CELL_PY_H_ | |||
| #define MINDSPORE_CCSRC_UTILS_CELL_PY_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "pybind11/pybind11.h" | |||
| #include "pybind11/numpy.h" | |||
| #include "ir/cell.h" | |||
| namespace py = pybind11; | |||
| // brief mindspore namespace. | |||
| // | |||
| // mindspore namespace is the top level namespace of Mindsporeession project. | |||
| // Other namespace should be a sub namespace of mindspore namespace in the ME project. | |||
| namespace mindspore { | |||
| // Cell python wrapper and adapter class. | |||
| class CellPy { | |||
| public: | |||
| static void AddAttr(CellPtr cell, const std::string &name, const py::object &obj); | |||
| }; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_UTILS_CELL_PY_H_ | |||
| @@ -0,0 +1,94 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ir/cell.h" | |||
| #include <utility> | |||
| #include <map> | |||
| #include <algorithm> | |||
| #include "abstract/abstract_value.h" | |||
| namespace mindspore { | |||
| using mindspore::abstract::AbstractFunction; | |||
| abstract::AbstractBasePtr Cell::ToAbstract() { | |||
| /* | |||
| std::vector<abstract::AbstractAttribute> abs_attrs; | |||
| std::transform(attrs_.begin(), attrs_.end(), std::back_inserter(abs_attrs), | |||
| [](std::pair<std::string, ValuePtr> attr) -> abstract::AbstractAttribute { | |||
| return std::make_pair(attr.first, attr.second->ToAbstract()); | |||
| }); | |||
| auto abs = std::make_shared<abstract::AbstractCell>(shared_from_base<Named>(), abs_attrs); | |||
| abs->set_value(shared_from_base<Value>()); | |||
| return abs; | |||
| */ | |||
| return nullptr; | |||
| } | |||
| bool Cell::operator==(const Value &other) const { | |||
| if (other.isa<Cell>()) { | |||
| auto other_prim = static_cast<const Cell &>(other); | |||
| return *this == other_prim; | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| bool Cell::operator==(const Cell &other) const { | |||
| if (name() != other.name()) { | |||
| return false; | |||
| } | |||
| if (attrs_.size() != other.attrs_.size()) { | |||
| return false; | |||
| } | |||
| auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool { | |||
| if (item.second == nullptr) { | |||
| return false; | |||
| } | |||
| auto iter = other.attrs_.find(item.first); | |||
| if (iter == other.attrs_.end()) { | |||
| return false; | |||
| } | |||
| return *item.second == *iter->second; | |||
| }); | |||
| return all; | |||
| } | |||
| std::string Cell::GetAttrString() const { | |||
| std::ostringstream buffer; | |||
| bool begin = true; | |||
| buffer << "{" << std::endl; | |||
| for (auto &attr : attrs_) { | |||
| if (!begin) { | |||
| buffer << ", " << std::endl; | |||
| } else { | |||
| begin = false; | |||
| } | |||
| buffer << attr.first << ":" << attr.second->ToString(); | |||
| } | |||
| buffer << "}"; | |||
| return buffer.str(); | |||
| } | |||
| std::string Cell::ToString() const { | |||
| std::ostringstream buffer; | |||
| buffer << "Cell " << name(); | |||
| return buffer.str(); | |||
| } | |||
| void Cell::DelAttr(const std::string &name) { attrs_.erase(name); } | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,69 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_IR_CELL_H_ | |||
| #define MINDSPORE_CCSRC_IR_CELL_H_ | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <tuple> | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/misc.h" | |||
| namespace mindspore { | |||
| using abstract::AbstractBasePtr; | |||
| using abstract::AbstractBasePtrList; | |||
| // value for Cell | |||
| class Cell : public Named { | |||
| public: | |||
| explicit Cell(const std::string &name) : Named(name) {} | |||
| MS_DECLARE_PARENT(Cell, Named); | |||
| abstract::AbstractBasePtr ToAbstract() override; | |||
| std::string ToString() const override; | |||
| std::string GetAttrString() const; | |||
| const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; } | |||
| void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs_input) { attrs_ = attrs_input; } | |||
| void AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; } | |||
| void DelAttr(const std::string &name); | |||
| ValuePtr GetAttr(const std::string &attr_name) const { | |||
| auto iter = attrs_.find(attr_name); | |||
| return iter == attrs_.cend() ? nullptr : iter->second; | |||
| } | |||
| bool HasAttr(const std::string &attr_name) const { | |||
| auto iter = attrs_.find(attr_name); | |||
| return !(iter == attrs_.cend()); | |||
| } | |||
| bool operator==(const Value &other) const override; | |||
| bool operator==(const Cell &other) const; | |||
| ~Cell() override = default; | |||
| const bool parse_info_ = true; | |||
| private: | |||
| std::unordered_map<std::string, ValuePtr> attrs_; | |||
| }; | |||
| using CellPtr = std::shared_ptr<Cell>; | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_IR_CELL_H_ | |||
| @@ -98,10 +98,11 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, | |||
| MS_LOG(EXCEPTION) << "Function:" << this->ToString() << ", variable_args_count " << variable_args_count | |||
| << " were given."; | |||
| } | |||
| auto varg_name = specialized_graph->GetVariableArgName(); | |||
| // for python variable argument input , there is no upper limit | |||
| for (int i = 0; i < variable_args_count; ++i) { | |||
| ParameterPtr p = std::make_shared<Parameter>(specialized_graph); | |||
| std::string param_name = specialized_graph->GetVariableArgName() + std::to_string(i); | |||
| std::string param_name = varg_name + std::to_string(i); | |||
| p->set_name(param_name); | |||
| MS_EXCEPTION_IF_NULL(p->debug_info()); | |||
| p->debug_info()->set_name(param_name); | |||
| @@ -49,7 +49,8 @@ NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_labe | |||
| while (temp_info != nullptr) { | |||
| if (temp_info->trace_info() != nullptr) { | |||
| if (temp_info->trace_info()->isa<TraceResolve>() || temp_info->trace_info()->isa<TraceExpandJ>() || | |||
| temp_info->trace_info()->isa<TraceGenMetaFuncGraph>()) { | |||
| temp_info->trace_info()->isa<TraceGenMetaFuncGraph>() || | |||
| temp_info->trace_info()->isa<TraceGenerateVarArg>() || temp_info->trace_info()->isa<TraceGenerateKwArg>()) { | |||
| break; | |||
| } | |||
| trace_name.trace_labels.push_back(GetTraceName(temp_info->trace_info(), trace_label)); | |||
| @@ -24,14 +24,14 @@ from ..common import dtype as mstype | |||
| from ..common.api import _executor, _pynative_exec | |||
| from .._checkparam import _check_str_by_regular | |||
| from ..common.parameter import Parameter, ParameterTuple | |||
| from .._c_expression import init_backend | |||
| from .._c_expression import init_backend, Cell_ | |||
| from ..ops.primitive import Primitive | |||
| from ..ops.operations import HookBackward | |||
| from ..ops.functional import cast | |||
| from ..parallel._tensor import _load_tensor_by_layout | |||
| from ..common.tensor import Tensor | |||
| class Cell: | |||
| class Cell(Cell_): | |||
| """ | |||
| Base class for all neural networks. | |||
| @@ -58,14 +58,21 @@ class Cell: | |||
| >>> def construct(self, x): | |||
| >>> return self.relu(x) | |||
| """ | |||
| IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names', | |||
| '_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run', | |||
| '_parameter_layout_dict', '_already_run', '_params_list', '_phase', '_auto_parallel_mode', | |||
| '_backward_hook', '_bprop_debug', '_is_run', '_param_prefix', '_attr_synced', | |||
| 'enable_hook', 'pynative', 'requires_grad', '_auto_parallel_compile_and_run', 'cell_type'] | |||
| def __init__(self, auto_prefix=True, flags=None): | |||
| Cell_.__init__(self, self._cell_tag) | |||
| self._params = OrderedDict() | |||
| self._cells = OrderedDict() | |||
| self._params_list = OrderedDict() | |||
| self.training = False | |||
| self.requires_grad = False | |||
| self.pynative = False | |||
| self._attr_synced = False | |||
| self._param_prefix = '' | |||
| self._auto_prefix = auto_prefix | |||
| self._scope = None | |||
| @@ -92,6 +99,12 @@ class Cell: | |||
| def already_run(self): | |||
| return self._already_run | |||
| @property | |||
| def _cell_tag(self): | |||
| # `<class 'xxxxxxx'>` | |||
| # -> `xxxxxxx` | |||
| return str(self.__class__)[8:-2] | |||
| @already_run.setter | |||
| def already_run(self, value): | |||
| self._already_run = value | |||
| @@ -222,6 +235,7 @@ class Cell: | |||
| del self._cells[name] | |||
| else: | |||
| object.__delattr__(self, name) | |||
| self._attr_synced = False | |||
| def cast_inputs(self, inputs, dst_type): | |||
| res = list() | |||
| @@ -277,6 +291,34 @@ class Cell: | |||
| self._already_run = True | |||
| return output | |||
| def _add_attr(self, name, value): | |||
| if name and name[:2] != '__' and name not in Cell.IGNORE_LIST: | |||
| super(Cell, self)._add_attr(name, value) | |||
| def _sync_attr_for_compile(self): | |||
| """Sync the attr to c++ object.""" | |||
| if self._attr_synced: | |||
| return | |||
| cells = self.__dict__.get('_cells') | |||
| for key in cells: | |||
| cell = cells[key] | |||
| cell._sync_attr_for_compile() | |||
| self._add_attr(key, cell) | |||
| params = self.__dict__.get('_params') | |||
| for key in params: | |||
| if '.' in key: | |||
| continue | |||
| param = params[key] | |||
| self._add_attr(key, param) | |||
| params_list = self.__dict__.get('_params_list') | |||
| for key in params_list: | |||
| params_list_item = params_list[key] | |||
| self._add_attr(key, params_list_item) | |||
| for key in self.__dict__: | |||
| value = self.__dict__[key] | |||
| self._add_attr(key, value) | |||
| self._attr_synced = True | |||
| def __setattr__(self, name, value): | |||
| cells = self.__dict__.get('_cells') | |||
| params = self.__dict__.get('_params') | |||
| @@ -329,6 +371,8 @@ class Cell: | |||
| if isinstance(value, Primitive): | |||
| value.set_prim_instance_name(name) | |||
| object.__setattr__(self, name, value) | |||
| if name not in Cell.IGNORE_LIST: | |||
| self._attr_synced = False | |||
| def extend_repr(self): | |||
| """ | |||
| @@ -451,7 +495,7 @@ class Cell: | |||
| Object, the result of executing. | |||
| """ | |||
| self._auto_parallel_compile_and_run = True | |||
| _executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode) | |||
| self.compile(*inputs) | |||
| if self._auto_parallel_mode: | |||
| if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag: | |||
| @@ -14,7 +14,7 @@ | |||
| # ============================================================================ | |||
| """container""" | |||
| from collections import OrderedDict | |||
| from abc import abstractmethod, ABCMeta | |||
| from abc import abstractmethod | |||
| from ..cell import Cell | |||
| __all__ = ['SequentialCell', 'CellList'] | |||
| @@ -34,7 +34,7 @@ def _valid_cell(cell): | |||
| raise TypeError('Cell {} is not subclass of Cell'.format(cell)) | |||
| class _CellListBase(metaclass=ABCMeta): | |||
| class _CellListBase(): | |||
| """ | |||
| An interface for base the cell as list. | |||
| @@ -51,7 +51,7 @@ def test_get_parameter_layout(): | |||
| exe.compile(net, x, phase='train', auto_parallel_mode=True) | |||
| x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [1, -1] | |||
| weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [0, -1] | |||
| expect_dict = {'x': x_layout, 'w1': weight_layout} | |||
| expect_dict = {'args0': x_layout, 'w1': weight_layout} | |||
| # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut | |||
| assert net.parameter_layout_dict == expect_dict | |||
| @@ -125,7 +125,7 @@ def test_grad_sens_parameter_type(): | |||
| y_layout = [[8, 8], [-1, 0], [32, 8], [0], [1]] | |||
| b_layout = [[8, 8], [0, -1], [8, 64], [0], [1]] | |||
| sens_layout = [[8, 8], [1, -1], [16, 64], [0], [1]] | |||
| expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout} | |||
| expect_dict = {'args0': x_layout, 'args1': y_layout, 'args2': b_layout, 'args3': sens_layout} | |||
| assert net.parameter_layout_dict == expect_dict | |||