|
|
|
@@ -163,6 +163,25 @@ std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector |
|
|
|
return type_indexes; |
|
|
|
} |
|
|
|
|
|
|
|
TypeId JudgeMaxType(TypeId max_type, bool has_scalar_float32, bool has_scalar_int64, bool has_tensor_int8) { |
|
|
|
if (max_type == TypeId::kNumberTypeBool) { |
|
|
|
if (has_scalar_int64) { |
|
|
|
max_type = TypeId::kNumberTypeInt64; |
|
|
|
} |
|
|
|
if (has_scalar_float32) { |
|
|
|
max_type = TypeId::kNumberTypeFloat32; |
|
|
|
} |
|
|
|
} |
|
|
|
if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 && |
|
|
|
max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) { |
|
|
|
max_type = TypeId::kNumberTypeFloat32; |
|
|
|
} |
|
|
|
if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) { |
|
|
|
max_type = TypeId::kNumberTypeInt16; |
|
|
|
} |
|
|
|
return max_type; |
|
|
|
} |
|
|
|
|
|
|
|
std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args, |
|
|
|
const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) { |
|
|
|
std::map<SignatureEnumDType, TypeId> dst_type; |
|
|
|
@@ -178,14 +197,13 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args, |
|
|
|
bool has_scalar_int64 = false; |
|
|
|
bool has_tensor_int8 = false; |
|
|
|
for (size_t index : indexes) { |
|
|
|
if (!has_scalar_float32 && py::isinstance<py::float_>(py_args[index])) { |
|
|
|
auto obj = py_args[index]; |
|
|
|
if (py::isinstance<py::float_>(obj)) { |
|
|
|
has_scalar_float32 = true; |
|
|
|
} |
|
|
|
if (!has_scalar_int64 && !py::isinstance<py::bool_>(py_args[index]) && py::isinstance<py::int_>(py_args[index])) { |
|
|
|
if (!py::isinstance<py::bool_>(obj) && py::isinstance<py::int_>(obj)) { |
|
|
|
has_scalar_int64 = true; |
|
|
|
} |
|
|
|
|
|
|
|
auto obj = py_args[index]; |
|
|
|
if (py::isinstance<tensor::Tensor>(obj)) { |
|
|
|
auto arg = py::cast<tensor::TensorPtr>(obj); |
|
|
|
TypeId arg_type_id = arg->data_type(); |
|
|
|
@@ -202,21 +220,7 @@ std::map<SignatureEnumDType, TypeId> GetDstType(const py::tuple &py_args, |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
if (max_type == TypeId::kNumberTypeBool) { |
|
|
|
if (has_scalar_int64) { |
|
|
|
max_type = TypeId::kNumberTypeInt64; |
|
|
|
} |
|
|
|
if (has_scalar_float32) { |
|
|
|
max_type = TypeId::kNumberTypeFloat32; |
|
|
|
} |
|
|
|
} |
|
|
|
if (max_type != TypeId::kNumberTypeFloat16 && max_type != TypeId::kNumberTypeFloat32 && |
|
|
|
max_type != TypeId::kNumberTypeFloat64 && max_type != TypeId::kTypeUnknown && has_scalar_float32) { |
|
|
|
max_type = TypeId::kNumberTypeFloat32; |
|
|
|
} |
|
|
|
if (max_type == TypeId::kNumberTypeUInt8 && has_tensor_int8) { |
|
|
|
max_type = TypeId::kNumberTypeInt16; |
|
|
|
} |
|
|
|
max_type = JudgeMaxType(max_type, has_scalar_float32, has_scalar_int64, has_tensor_int8); |
|
|
|
(void)dst_type.emplace(std::make_pair(type, max_type)); |
|
|
|
} |
|
|
|
return dst_type; |
|
|
|
@@ -274,11 +278,11 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, |
|
|
|
} |
|
|
|
} |
|
|
|
// get prim and abstract info |
|
|
|
(void)graph_info.append(op_exec_info->prim_id + "_"); |
|
|
|
(void)graph_info.append(op_exec_info->op_name + "_"); |
|
|
|
// get attr info |
|
|
|
const auto &op_prim = op_exec_info->py_primitive; |
|
|
|
MS_EXCEPTION_IF_NULL(op_prim); |
|
|
|
const auto &attr_map = op_prim->evaluate_added_attrs(); |
|
|
|
const auto &attr_map = op_prim->attrs(); |
|
|
|
(void)std::for_each(attr_map.begin(), attr_map.end(), |
|
|
|
[&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); }); |
|
|
|
|
|
|
|
@@ -648,7 +652,6 @@ OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) { |
|
|
|
if (!prim->HasPyObj()) { |
|
|
|
MS_LOG(EXCEPTION) << "Pyobj is empty"; |
|
|
|
} |
|
|
|
op_exec_info->prim_id = GetId(prim->GetPyObj()); |
|
|
|
op_exec_info->py_primitive = prim; |
|
|
|
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); |
|
|
|
op_exec_info->op_inputs = args[PY_INPUTS]; |
|
|
|
@@ -701,10 +704,10 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v |
|
|
|
input_node = GetInput(obj, op_mask); |
|
|
|
} |
|
|
|
// update abstract |
|
|
|
if (input_node != nullptr && input_node->abstract() != nullptr) { |
|
|
|
abs = input_node->abstract(); |
|
|
|
} |
|
|
|
if (input_node != nullptr) { |
|
|
|
if (input_node->abstract() != nullptr) { |
|
|
|
abs = input_node->abstract(); |
|
|
|
} |
|
|
|
inputs.emplace_back(input_node); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -2169,8 +2172,8 @@ void PynativeExecutor::UpdateCellDynamic(const std::string &cell_id) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, |
|
|
|
bool need_cloned, bool is_grad) { |
|
|
|
void PynativeExecutor::UpdateBpropCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, |
|
|
|
bool need_cloned, bool is_grad) { |
|
|
|
auto update_in_endgraph = need_cloned && !is_grad; |
|
|
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { |
|
|
|
// Bprop just save backward graph |
|
|
|
@@ -2197,7 +2200,12 @@ void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPt |
|
|
|
} |
|
|
|
return; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::UpdateCellGraph(const py::object &cell, const FuncGraphPtr &g, const std::string &cell_id, |
|
|
|
bool need_cloned, bool is_grad) { |
|
|
|
auto update_in_endgraph = need_cloned && !is_grad; |
|
|
|
UpdateBpropCellGraph(cell, g, cell_id, need_cloned, is_grad); |
|
|
|
FuncGraphPtr tmp = g; |
|
|
|
if (!IsFirstGradStep(top_cell_id_) && CheckDynamicCell(cell_id) && !CheckRealDynamicCell(cell_id)) { |
|
|
|
MS_LOG(DEBUG) << "No need cloned"; |
|
|
|
|