
!11803 allow list as parameter input & store op info using op_name instead of primitive id

From: @simson_wu
Reviewed-by: 
Signed-off-by:
Tags: v1.2.0-rc1
mindspore-ci-bot committed 4 years ago · commit 5eddd9dbc0
4 changed files with 38 additions and 29 deletions:
  1. mindspore/ccsrc/pipeline/pynative/base.h (+0, -1)
  2. mindspore/ccsrc/pipeline/pynative/pynative_execute.cc (+35, -27)
  3. mindspore/ccsrc/pipeline/pynative/pynative_execute.h (+2, -0)
  4. mindspore/common/parameter.py (+1, -1)

+0 -1  mindspore/ccsrc/pipeline/pynative/base.h

@@ -52,7 +52,6 @@ enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
 struct OpExecInfo {
   std::string op_name;
   std::string op_index;
-  std::string prim_id;
   PrimitivePyPtr py_primitive;
   AbstractBasePtr abstract;



+35 -27  mindspore/ccsrc/pipeline/pynative/pynative_execute.cc

@@ -163,6 +163,25 @@ std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector
   return type_indexes;
 }
 
+TypeId JudgeMaxType(TypeId max_type, bool has_scalar_float32, bool has_scalar_int64, bool has_tensor_int8) {
+  if (max_type == TypeId::kNumberTypeBool) {
+    if (has_scalar_int64) {
+      max_type = TypeId::kNumberTypeInt64;
+    }
+    if (has_scalar_float32) {
+      max_type = TypeId::kNumberTypeFloat32;
+    }
+  }
+  if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
+      max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) {
+    max_type = TypeId::kNumberTypeFloat32;
+  }
+  if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) {
+    max_type = TypeId::kNumberTypeInt16;
+  }
+  return max_type;
+}
+
 std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
                                                 const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) {
   std::map<SignatureEnumDType, TypeId> dst_type;
@@ -178,14 +197,13 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
     bool has_scalar_int64 = false;
     bool has_tensor_int8 = false;
     for (size_t index : indexes) {
-      if (!has_scalar_float32 && py::isinstance<py::float_>(py_args[index])) {
+      auto obj = py_args[index];
+      if (py::isinstance<py::float_>(obj)) {
         has_scalar_float32 = true;
       }
-      if (!has_scalar_int64 && !py::isinstance<py::bool_>(py_args[index]) && py::isinstance<py::int_>(py_args[index])) {
+      if (!py::isinstance<py::bool_>(obj) && py::isinstance<py::int_>(obj)) {
         has_scalar_int64 = true;
       }
-
-      auto obj = py_args[index];
       if (py::isinstance<tensor::Tensor>(obj)) {
         auto arg = py::cast<tensor::TensorPtr>(obj);
         TypeId arg_type_id = arg->data_type();
@@ -202,21 +220,7 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args,
         }
       }
     }
-    if (max_type == TypeId::kNumberTypeBool) {
-      if (has_scalar_int64) {
-        max_type = TypeId::kNumberTypeInt64;
-      }
-      if (has_scalar_float32) {
-        max_type = TypeId::kNumberTypeFloat32;
-      }
-    }
-    if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 &&
-        max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) {
-      max_type = TypeId::kNumberTypeFloat32;
-    }
-    if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) {
-      max_type = TypeId::kNumberTypeInt16;
-    }
+    max_type = JudgeMaxType(max_type, has_scalar_float32, has_scalar_int64, has_tensor_int8);
     (void)dst_type.emplace(std::make_pair(type, max_type));
   }
   return dst_type;
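
Note: the added JudgeMaxType helper simply hoists the promotion rules that were previously inlined here. A minimal Python model of the same decision table, useful for sanity-checking the rules (the string type names stand in for the C++ TypeId enum and are not MindSpore API):

    # Illustrative model of JudgeMaxType's promotion rules.
    def judge_max_type(max_type, has_scalar_float32, has_scalar_int64, has_tensor_int8):
        # A bool result is lifted by the strongest scalar seen (float beats int).
        if max_type == "Bool":
            if has_scalar_int64:
                max_type = "Int64"
            if has_scalar_float32:
                max_type = "Float32"
        # Any non-float result mixed with a float scalar promotes to Float32.
        if max_type not in ("Float16", "Float32", "Float64", "Unknown") and has_scalar_float32:
            max_type = "Float32"
        # UInt8 mixed with an Int8 tensor widens to Int16.
        if max_type == "UInt8" and has_tensor_int8:
            max_type = "Int16"
        return max_type

    assert judge_max_type("Bool", True, True, False) == "Float32"
    assert judge_max_type("UInt8", False, False, True) == "Int16"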
@@ -274,11 +278,11 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
     }
   }
   // get prim and abstract info
-  (void)graph_info.append(op_exec_info->prim_id + "_");
+  (void)graph_info.append(op_exec_info->op_name + "_");
   // get attr info
   const auto &op_prim = op_exec_info->py_primitive;
   MS_EXCEPTION_IF_NULL(op_prim);
-  const auto &attr_map = op_prim->evaluate_added_attrs();
+  const auto &attr_map = op_prim->attrs();
   (void)std::for_each(attr_map.begin(), attr_map.end(),
                       [&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); });
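
Note: keying the single-op graph info on op_name plus the primitive's full attrs(), instead of on prim_id (the Python object id of the Primitive), means the cache key no longer depends on which Primitive instance issued the op, so equal ops from different instances can share a cached graph. A rough Python sketch of the keying idea (a hypothetical helper, not the actual C++ string building):

    # Build a cache key from stable op properties, never from id(prim).
    def single_op_graph_key(op_name, attrs, shape_and_dtype_parts):
        parts = list(shape_and_dtype_parts)   # per-input shape/dtype fragments
        parts.append(op_name)                 # was str(id(prim)) before this commit
        parts.extend(str(v) for _, v in sorted(attrs.items()))
        return "_".join(parts)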

@@ -648,7 +652,6 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) {
   if (!prim->HasPyObj()) {
     MS_LOG(EXCEPTION) << "Pyobj is empty";
   }
-  op_exec_info->prim_id = GetId(prim->GetPyObj());
   op_exec_info->py_primitive = prim;
   op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
   op_exec_info->op_inputs = args[PY_INPUTS];
@@ -701,10 +704,10 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v
       input_node = GetInput(obj, op_mask);
     }
     // update abstract
-    if (input_node != nullptr && input_node->abstract() != nullptr) {
-      abs = input_node->abstract();
-    }
+    if (input_node != nullptr) {
+      if (input_node->abstract() != nullptr) {
+        abs = input_node->abstract();
+      }
       inputs.emplace_back(input_node);
     }
   }
@@ -2169,8 +2172,8 @@ void PynativeExecutor::UpdateCellDynamic(const std::string &cell_id) {
   }
 }
 
-void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
-                                       bool need_cloned, bool is_grad) {
+void PynativeExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
+                                            bool need_cloned, bool is_grad) {
   auto update_in_endgraph = need_cloned && !is_grad;
   if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
     // Bprop just save backward graph
@@ -2197,7 +2200,12 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt
     }
     return;
   }
+}
+
+void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
+                                       bool need_cloned, bool is_grad) {
+  auto update_in_endgraph = need_cloned && !is_grad;
+  UpdateBpropCellGraph(cell, g, cell_id, need_cloned, is_grad);
   FuncGraphPtr tmp = g;
   if (!IsFirstGradStep(top_cell_id_) && CheckDynamicCell(cell_id) && !CheckRealDynamicCell(cell_id)) {
     MS_LOG(DEBUG) << "No need cloned";
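
Note: this is an extract-method refactor: the custom-bprop branch of UpdateCellGraph moves into the new UpdateBpropCellGraph, which UpdateCellGraph now calls before its ordinary clone/update logic. The resulting control flow, sketched in Python (names mirror the C++ methods; bodies elided):

    def update_bprop_cell_graph(cell, g, cell_id, need_cloned, is_grad):
        if hasattr(cell, "custom_bprop"):  # parse::CUSTOM_BPROP_NAME in the C++ code
            ...  # bprop cells only save the backward graph, then return

    def update_cell_graph(cell, g, cell_id, need_cloned=False, is_grad=False):
        update_bprop_cell_graph(cell, g, cell_id, need_cloned, is_grad)
        ...  # cloning/update logic for ordinary cells continues here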


+2 -0  mindspore/ccsrc/pipeline/pynative/pynative_execute.h

@@ -241,6 +241,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   bool CheckCellGraph(const std::string &cell_id, bool is_grad = false);
   bool CheckDynamicCell(const std::string &cell_id);
   bool CheckRealDynamicCell(const std::string &cell_id);
+  void UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, bool need_cloned,
+                            bool is_grad);
   void UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id,
                        bool need_cloned = false, bool is_grad = false);
   void ClearCnodeRes(const AnfNodePtr &node, std::unordered_set<AnfNodePtr> *node_set);


+1 -1  mindspore/common/parameter.py

@@ -159,7 +159,7 @@ class Parameter(Tensor_):
             Tensor_.__init__(self, mstype.int64, ())
         elif isinstance(default_input, float):
             Tensor_.__init__(self, mstype.float32, ())
-        elif isinstance(default_input, np.ndarray):
+        elif isinstance(default_input, (np.ndarray, list)):
            Tensor_.__init__(self, default_input)
         else:
             raise TypeError(f"Parameter input must be [`Tensor`, `Number`]."
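
With this change a plain Python list is accepted as default_input and forwarded to the Tensor_ initializer, which is the "allow list as parameter input" half of the commit title. A minimal usage check, assuming a MindSpore build that contains this commit (v1.2.0-rc1 or later):

    from mindspore import Parameter

    w = Parameter([1.0, 2.0, 3.0], name="w")  # raised TypeError before this commit
    print(w.shape)  # (3,)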

