From e375c86d0c395db967d43b86b7eeef7ac1ba461b Mon Sep 17 00:00:00 2001
From: simson
Date: Tue, 26 Jan 2021 09:39:56 +0800
Subject: [PATCH] allow list as parameter input & store op info using op_name
 instead of primitive id

---
 mindspore/ccsrc/pipeline/pynative/base.h      |  1 -
 .../pipeline/pynative/pynative_execute.cc     | 62 +++++++++++--------
 .../pipeline/pynative/pynative_execute.h      |  2 +
 mindspore/common/parameter.py                 |  2 +-
 4 files changed, 38 insertions(+), 29 deletions(-)

diff --git a/mindspore/ccsrc/pipeline/pynative/base.h b/mindspore/ccsrc/pipeline/pynative/base.h
index a55d2a4d86..b87b169796 100644
--- a/mindspore/ccsrc/pipeline/pynative/base.h
+++ b/mindspore/ccsrc/pipeline/pynative/base.h
@@ -52,7 +52,6 @@ enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
 struct OpExecInfo {
   std::string op_name;
   std::string op_index;
-  std::string prim_id;
 
   PrimitivePyPtr py_primitive;
   AbstractBasePtr abstract;
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index ec123bfc46..ad3e6d3207 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -163,6 +163,25 @@ std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector
   return type_indexes;
 }
 
+TypeId JudgeMaxType(TypeId max_type, bool has_scalar_float32, bool has_scalar_int64, bool has_tensor_int8) {
+  if (max_type == TypeId::kNumberTypeBool) {
+    if (has_scalar_int64) {
+      max_type = TypeId::kNumberTypeInt64;
+    }
+    if (has_scalar_float32) {
+      max_type = TypeId::kNumberTypeFloat32;
+    }
+  }
+  if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
+      max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) {
+    max_type = TypeId::kNumberTypeFloat32;
+  }
+  if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) {
+    max_type = TypeId::kNumberTypeInt16;
+  }
+  return max_type;
+}
+
 std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
                                                 const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) {
   std::map<SignatureEnumDType, TypeId> dst_type;
@@ -178,14 +197,13 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
     bool has_scalar_int64 = false;
     bool has_tensor_int8 = false;
     for (size_t index : indexes) {
-      if (!has_scalar_float32 && py::isinstance<py::float_>(py_args[index])) {
+      auto obj = py_args[index];
+      if (py::isinstance<py::float_>(obj)) {
         has_scalar_float32 = true;
       }
-      if (!has_scalar_int64 && !py::isinstance<py::bool_>(py_args[index]) && py::isinstance<py::int_>(py_args[index])) {
+      if (!py::isinstance<py::bool_>(obj) && py::isinstance<py::int_>(obj)) {
         has_scalar_int64 = true;
       }
-
-      auto obj = py_args[index];
       if (py::isinstance<tensor::Tensor>(obj)) {
         auto arg = py::cast<tensor::TensorPtr>(obj);
         TypeId arg_type_id = arg->data_type();
@@ -202,21 +220,7 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
         }
       }
     }
-    if (max_type == TypeId::kNumberTypeBool) {
-      if (has_scalar_int64) {
-        max_type = TypeId::kNumberTypeInt64;
-      }
-      if (has_scalar_float32) {
-        max_type = TypeId::kNumberTypeFloat32;
-      }
-    }
-    if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
-        max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) {
-      max_type = TypeId::kNumberTypeFloat32;
-    }
-    if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) {
-      max_type = TypeId::kNumberTypeInt16;
-    }
+    max_type = JudgeMaxType(max_type, has_scalar_float32, has_scalar_int64, has_tensor_int8);
     (void)dst_type.emplace(std::make_pair(type, max_type));
   }
   return dst_type;
@@ -274,11 +278,11 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
     }
   }
   // get prim and abstract info
-  (void)graph_info.append(op_exec_info->prim_id + "_");
+  (void)graph_info.append(op_exec_info->op_name + "_");
   // get attr info
   const auto &op_prim = op_exec_info->py_primitive;
   MS_EXCEPTION_IF_NULL(op_prim);
-  const auto &attr_map = op_prim->evaluate_added_attrs();
+  const auto &attr_map = op_prim->attrs();
   (void)std::for_each(attr_map.begin(), attr_map.end(), [&](const auto &element) {
     (void)graph_info.append(element.second->ToString() + "_");
   });
@@ -648,7 +652,6 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
   if (!prim->HasPyObj()) {
     MS_LOG(EXCEPTION) << "Pyobj is empty";
   }
-  op_exec_info->prim_id = GetId(prim->GetPyObj());
   op_exec_info->py_primitive = prim;
   op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
   op_exec_info->op_inputs = args[PY_INPUTS];
@@ -701,10 +704,10 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
       input_node = GetInput(obj, op_mask);
     }
     // update abstract
-    if (input_node != nullptr && input_node->abstract() != nullptr) {
-      abs = input_node->abstract();
-    }
     if (input_node != nullptr) {
+      if (input_node->abstract() != nullptr) {
+        abs = input_node->abstract();
+      }
       inputs.emplace_back(input_node);
     }
   }
@@ -2169,8 +2172,8 @@ void PynativeExecutor::UpdateCellDynamic(const std::string &cell_id) {
   }
 }
 
-void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
-                                       bool need_cloned, bool is_grad) {
+void PynativeExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
+                                            bool need_cloned, bool is_grad) {
   auto update_in_endgraph = need_cloned && !is_grad;
   if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
     // Bprop just save backward graph
@@ -2197,7 +2200,12 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt
     }
     return;
   }
+}
 
+void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
+                                       bool need_cloned, bool is_grad) {
+  auto update_in_endgraph = need_cloned && !is_grad;
+  UpdateBpropCellGraph(cell, g, cell_id, need_cloned, is_grad);
   FuncGraphPtr tmp = g;
   if (!IsFirstGradStep(top_cell_id_) && CheckDynamicCell(cell_id) && !CheckRealDynamicCell(cell_id)) {
     MS_LOG(DEBUG) << "No need cloned";
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
index 8507d1c8b6..b087c8a31f 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
@@ -241,6 +241,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   bool CheckCellGraph(const std::string &cell_id, bool is_grad = false);
   bool CheckDynamicCell(const std::string &cell_id);
   bool CheckRealDynamicCell(const std::string &cell_id);
+  void UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned,
+                            bool is_grad);
   void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
                        bool need_cloned = false, bool is_grad = false);
   void ClearCnodeRes(const AnfNodePtr &node, std::unordered_set<AnfNodePtr> *node_set);
diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index cba3d9232b..e1929de643 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -159,7 +159,7 @@ class Parameter(Tensor_):
             Tensor_.__init__(self, mstype.int64, ())
         elif isinstance(default_input, float):
             Tensor_.__init__(self, mstype.float32, ())
-        elif isinstance(default_input, np.ndarray):
+        elif isinstance(default_input, (np.ndarray, list)):
            Tensor_.__init__(self, default_input)
         else:
             raise TypeError(f"Parameter input must be [`Tensor`, `Number`].")
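
The JudgeMaxType helper factored out of GetDstType above centralizes the
implicit type-promotion rules for mixed scalar/tensor operands. As a reading
aid, here is a Python paraphrase of those rules; the function name and the
string stand-ins for the TypeId enum are illustrative only, not part of the
patch:

    def judge_max_type(max_type, has_scalar_float32, has_scalar_int64, has_tensor_int8):
        # A bool result type is widened by any scalar int64 or float32 operand.
        if max_type == "bool":
            if has_scalar_int64:
                max_type = "int64"
            if has_scalar_float32:
                max_type = "float32"
        # A non-float result type is promoted to float32 when a scalar float is mixed in.
        if max_type not in ("float16", "float32", "float64", "unknown") and has_scalar_float32:
            max_type = "float32"
        # uint8 mixed with an int8 tensor widens to int16, which holds both ranges.
        if max_type == "uint8" and has_tensor_int8:
            max_type = "int16"
        return max_type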
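Switching the single-op graph key built in GetSingleOpGraphInfo from prim_id
to op_name replaces a per-object identity with the operator's registered
name, so equivalent primitive instances can share one cached graph. A rough
sketch of why the old key was unstable, using only the public Python API
(the cache itself is internal and is not inspected here):

    from mindspore.ops import operations as P

    relu_a = P.ReLU()
    relu_b = P.ReLU()

    # prim_id was derived from the identity of the primitive's Python object,
    # so two equivalent instances produced two distinct cache keys:
    print(id(relu_a) == id(relu_b))    # False
    # op_name is the registered operator name, shared by all instances:
    print(relu_a.name == relu_b.name)  # True, both are "ReLU"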
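On the Python side, the parameter.py change lets a plain list initialize a
Parameter the same way an np.ndarray does. A minimal usage sketch under that
assumption (names and shapes here are arbitrary):

    import numpy as np
    from mindspore import Parameter

    # Accepted before and after this patch:
    w1 = Parameter(np.ones((2, 2), dtype=np.float32), name="w1")
    # Newly accepted: the nested list is handed to Tensor_.__init__ just like
    # an ndarray, which infers shape and dtype from the data:
    w2 = Parameter([[1.0, 1.0], [1.0, 1.0]], name="w2")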