|
|
|
@@ -235,58 +235,6 @@ std::string TypeIdToMsTypeStr(const TypeId &type_id) { |
|
|
|
return type_name->second; |
|
|
|
} |
|
|
|
|
|
|
|
py::object DoAutoCast(const py::object &arg, const TypeId &type_id) { |
|
|
|
py::tuple args(3); |
|
|
|
std::string module_name = "mindspore.ops.functional"; |
|
|
|
std::string op_name = "cast"; |
|
|
|
args[0] = parse::python_adapter::GetPyFn(module_name, op_name); |
|
|
|
args[1] = "Cast"; |
|
|
|
|
|
|
|
std::string dst_type_str = TypeIdToMsTypeStr(type_id); |
|
|
|
module_name = "mindspore.common.dtype"; |
|
|
|
py::object dst_type = parse::python_adapter::GetPyFn(module_name, dst_type_str); |
|
|
|
py::tuple inputs(2); |
|
|
|
inputs[0] = arg; |
|
|
|
inputs[1] = dst_type; |
|
|
|
args[2] = inputs; |
|
|
|
|
|
|
|
return RunOp(args)[0]; |
|
|
|
} |
|
|
|
|
|
|
|
py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj) { |
|
|
|
MS_EXCEPTION_IF_NULL(is_cast); |
|
|
|
auto tensor = py::cast<tensor::TensorPtr>(obj); |
|
|
|
auto cast_type = tensor->cast_dtype(); |
|
|
|
py::object cast_output = obj; |
|
|
|
if (cast_type != nullptr) { |
|
|
|
auto source_element = tensor->Dtype(); |
|
|
|
if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) { |
|
|
|
MS_LOG(DEBUG) << "Cast to " << cast_type->ToString(); |
|
|
|
cast_output = DoAutoCast(obj, cast_type->type_id()); |
|
|
|
*is_cast = true; |
|
|
|
} |
|
|
|
} |
|
|
|
return cast_output; |
|
|
|
} |
|
|
|
|
|
|
|
py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple) { |
|
|
|
MS_EXCEPTION_IF_NULL(is_cast); |
|
|
|
auto tuple_size = static_cast<int64_t>(tuple.size()); |
|
|
|
py::tuple result(tuple_size); |
|
|
|
|
|
|
|
for (int64_t i = 0; i < tuple_size; i++) { |
|
|
|
if (py::isinstance<tensor::MetaTensor>(tuple[i])) { |
|
|
|
MS_LOG(DEBUG) << "Call cast for item " << i; |
|
|
|
result[i] = DoParamMixPrecisionCast(is_cast, tuple[i]); |
|
|
|
} else if (py::isinstance<py::tuple>(tuple[i]) || py::isinstance<py::list>(tuple[i])) { |
|
|
|
result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i]); |
|
|
|
} else { |
|
|
|
result[i] = tuple[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
return std::move(result); |
|
|
|
} |
|
|
|
|
|
|
|
bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType> *dtypes) { |
|
|
|
MS_EXCEPTION_IF_NULL(dtypes); |
|
|
|
auto signature = prim->signatures(); |
|
|
|
@@ -302,69 +250,6 @@ bool GetSignatureType(const PrimitivePyPtr &prim, std::vector<SignatureEnumDType |
|
|
|
return has_sig_dtype; |
|
|
|
} |
|
|
|
|
|
|
|
void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type, |
|
|
|
const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info) { |
|
|
|
const auto &signature = prim->signatures(); |
|
|
|
auto &out_args = op_exec_info->op_inputs; |
|
|
|
bool has_dtype_sig = !dtypes.empty(); |
|
|
|
for (size_t i = 0; i < out_args.size(); ++i) { |
|
|
|
MS_LOG(DEBUG) << "Check inputs " << i; |
|
|
|
auto obj = out_args[i]; |
|
|
|
auto sig = SignatureEnumRW::kRWDefault; |
|
|
|
if (!signature.empty()) { |
|
|
|
sig = signature[i].rw; |
|
|
|
} |
|
|
|
bool is_parameter = false; |
|
|
|
TypeId arg_type_id = kTypeUnknown; |
|
|
|
if (py::isinstance<tensor::MetaTensor>(obj)) { |
|
|
|
auto arg = py::cast<tensor::MetaTensorPtr>(obj); |
|
|
|
if (arg->is_parameter()) { |
|
|
|
is_parameter = true; |
|
|
|
MS_LOG(DEBUG) << "Parameter is read " << i; |
|
|
|
} |
|
|
|
arg_type_id = arg->data_type(); |
|
|
|
} |
|
|
|
|
|
|
|
// No need to implicit cast if no dtype. |
|
|
|
if (!has_dtype_sig || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto it = dst_type.find(dtypes[i]); |
|
|
|
if (it == dst_type.end() || it->second == kTypeUnknown) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
// implicit cast |
|
|
|
bool is_same_type = false; |
|
|
|
bool is_sig_write = (sig == SignatureEnumRW::kRWWrite); |
|
|
|
if (arg_type_id != 0) { |
|
|
|
is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second); |
|
|
|
} |
|
|
|
if (is_sig_write) { |
|
|
|
if (!is_parameter) { |
|
|
|
prim::RaiseExceptionForCheckParameter(prim->name(), i, "not"); |
|
|
|
} |
|
|
|
if (arg_type_id != 0) { |
|
|
|
if (!is_same_type) { |
|
|
|
prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id), |
|
|
|
TypeIdToMsTypeStr(it->second)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
if (is_same_type) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if (!py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj)) { |
|
|
|
MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i |
|
|
|
<< "th input is a not support implicit conversion type: " |
|
|
|
<< py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is " |
|
|
|
<< py::cast<py::str>(obj) << "."; |
|
|
|
} |
|
|
|
py::object cast_output = DoAutoCast(out_args[i], it->second); |
|
|
|
out_args[i] = cast_output; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info, |
|
|
|
const abstract::AbstractBasePtrList &args_spec_list) { |
|
|
|
MS_LOG(DEBUG) << "Prim " << prim->name() << " input infer " << mindspore::ToString(args_spec_list); |
|
|
|
@@ -694,8 +579,10 @@ PynativeExecutor::~PynativeExecutor() { ClearRes(); } |
|
|
|
py::tuple RunOp(const py::args &args) { |
|
|
|
auto executor = PynativeExecutor::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(executor); |
|
|
|
MS_LOG(DEBUG) << "RunOp start " << args.size(); |
|
|
|
OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args); |
|
|
|
try { |
|
|
|
return executor->RunOpInner(args); |
|
|
|
return executor->RunOpInner(op_exec_info); |
|
|
|
} catch (const py::error_already_set &ex) { |
|
|
|
executor->Clean(); |
|
|
|
// re-throw this exception to Python interpreter to handle it |
|
|
|
@@ -720,12 +607,9 @@ py::tuple RunOp(const py::args &args) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
py::tuple PynativeExecutor::RunOpInner(const py::args &args) { |
|
|
|
MS_LOG(DEBUG) << "RunOp start " << args.size(); |
|
|
|
OpExecInfoPtr op_exec_info = nullptr; |
|
|
|
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]); |
|
|
|
auto name = py::cast<std::string>(args[PY_NAME]); |
|
|
|
op_exec_info = GenerateOpExecInfo(args); |
|
|
|
py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { |
|
|
|
auto prim = op_exec_info->py_primitive; |
|
|
|
auto name = op_exec_info->op_name; |
|
|
|
if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) { |
|
|
|
return RunOpWithInitBackendPolicy(op_exec_info); |
|
|
|
} |
|
|
|
@@ -828,8 +712,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v |
|
|
|
MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires " |
|
|
|
<< "inputs size " << sig_size; |
|
|
|
} |
|
|
|
bool is_cast_op = (op_exec_info->op_name == "Cast"); |
|
|
|
if (!is_cast_op) { |
|
|
|
if (op_exec_info->op_name != prim::kPrimCast->name()) { |
|
|
|
RunParameterAutoMixPrecisionCast(op_exec_info); |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Make cnode for " << op_exec_info->op_name; |
|
|
|
@@ -846,7 +729,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v |
|
|
|
MS_LOG(DEBUG) << "Gen args i " << i << " " << op_exec_info->op_name << " op mask " << op_mask << " grad_flag_ " |
|
|
|
<< grad_flag_; |
|
|
|
|
|
|
|
AnfNodePtr node = nullptr; |
|
|
|
AnfNodePtr input_node = nullptr; |
|
|
|
abstract::AbstractBasePtr abs = nullptr; |
|
|
|
auto id = GetId(obj); |
|
|
|
auto it = node_abs_map_.find(id); |
|
|
|
@@ -854,11 +737,11 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v |
|
|
|
abs = it->second; |
|
|
|
} |
|
|
|
if (!graph_info_map_.empty()) { |
|
|
|
node = GetInput(obj, op_mask); |
|
|
|
input_node = GetInput(obj, op_mask); |
|
|
|
} |
|
|
|
// update abstract |
|
|
|
if (node != nullptr && node->abstract() != nullptr) { |
|
|
|
abs = node->abstract(); |
|
|
|
if (input_node != nullptr && input_node->abstract() != nullptr) { |
|
|
|
abs = input_node->abstract(); |
|
|
|
} |
|
|
|
|
|
|
|
auto const_input_index = prim->get_const_input_indexes(); |
|
|
|
@@ -880,8 +763,8 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v |
|
|
|
node_abs_map_[id] = abs; |
|
|
|
} |
|
|
|
(*args_spec_list).emplace_back(abs); |
|
|
|
if (node != nullptr) { |
|
|
|
inputs.emplace_back(node); |
|
|
|
if (input_node != nullptr) { |
|
|
|
inputs.emplace_back(input_node); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -893,9 +776,125 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v |
|
|
|
return cnode; |
|
|
|
} |
|
|
|
|
|
|
|
py::object PynativeExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, |
|
|
|
size_t index) { |
|
|
|
py::tuple cast_args(3); |
|
|
|
cast_args[PY_PRIM] = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast"); |
|
|
|
cast_args[PY_NAME] = prim::kPrimCast->name(); |
|
|
|
std::string dst_type_str = TypeIdToMsTypeStr(type_id); |
|
|
|
py::object dst_type = parse::python_adapter::GetPyFn(kMSDtypeModelName, dst_type_str); |
|
|
|
py::tuple inputs(2); |
|
|
|
inputs[0] = arg; |
|
|
|
inputs[1] = dst_type; |
|
|
|
cast_args[PY_INPUTS] = inputs; |
|
|
|
auto op_exec = GenerateOpExecInfo(cast_args); |
|
|
|
op_exec->is_mixed_precision_cast = true; |
|
|
|
op_exec->next_op_name = op_name; |
|
|
|
op_exec->next_input_index = index; |
|
|
|
return RunOpInner(op_exec)[0]; |
|
|
|
} |
|
|
|
|
|
|
|
py::object PynativeExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::object obj, const std::string &op_name, |
|
|
|
size_t index) { |
|
|
|
MS_EXCEPTION_IF_NULL(is_cast); |
|
|
|
auto tensor = py::cast<tensor::TensorPtr>(obj); |
|
|
|
auto cast_type = tensor->cast_dtype(); |
|
|
|
py::object cast_output = obj; |
|
|
|
if (cast_type != nullptr) { |
|
|
|
auto source_element = tensor->Dtype(); |
|
|
|
if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) { |
|
|
|
MS_LOG(DEBUG) << "Cast to " << cast_type->ToString(); |
|
|
|
*is_cast = true; |
|
|
|
return DoAutoCast(obj, cast_type->type_id(), op_name, index); |
|
|
|
} |
|
|
|
} |
|
|
|
return cast_output; |
|
|
|
} |
|
|
|
|
|
|
|
py::object PynativeExecutor::DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple, |
|
|
|
const std::string &op_name, size_t index) { |
|
|
|
MS_EXCEPTION_IF_NULL(is_cast); |
|
|
|
auto tuple_size = static_cast<int64_t>(tuple.size()); |
|
|
|
py::tuple result(tuple_size); |
|
|
|
|
|
|
|
for (int64_t i = 0; i < tuple_size; i++) { |
|
|
|
if (py::isinstance<tensor::MetaTensor>(tuple[i])) { |
|
|
|
MS_LOG(DEBUG) << "Call cast for item " << i; |
|
|
|
result[i] = DoParamMixPrecisionCast(is_cast, tuple[i], op_name, index); |
|
|
|
} else if (py::isinstance<py::tuple>(tuple[i]) || py::isinstance<py::list>(tuple[i])) { |
|
|
|
result[i] = DoParamMixPrecisionCastTuple(is_cast, tuple[i], op_name, index); |
|
|
|
} else { |
|
|
|
result[i] = tuple[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
return std::move(result); |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type, |
|
|
|
const std::vector<SignatureEnumDType> &dtypes, |
|
|
|
const OpExecInfoPtr &op_exec_info) { |
|
|
|
const auto &signature = prim->signatures(); |
|
|
|
auto &out_args = op_exec_info->op_inputs; |
|
|
|
for (size_t i = 0; i < out_args.size(); ++i) { |
|
|
|
// No need to implicit cast if no dtype. |
|
|
|
if (dtypes.empty() || dtypes[i] == SignatureEnumDType::kDTypeEmptyDefaultValue) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto it = dst_type.find(dtypes[i]); |
|
|
|
if (it == dst_type.end() || it->second == kTypeUnknown) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Check inputs " << i; |
|
|
|
auto obj = out_args[i]; |
|
|
|
auto sig = SignatureEnumRW::kRWDefault; |
|
|
|
if (!signature.empty()) { |
|
|
|
sig = signature[i].rw; |
|
|
|
} |
|
|
|
bool is_parameter = false; |
|
|
|
TypeId arg_type_id = kTypeUnknown; |
|
|
|
if (py::isinstance<tensor::MetaTensor>(obj)) { |
|
|
|
auto arg = py::cast<tensor::MetaTensorPtr>(obj); |
|
|
|
if (arg->is_parameter()) { |
|
|
|
is_parameter = true; |
|
|
|
MS_LOG(DEBUG) << "Parameter is read " << i; |
|
|
|
} |
|
|
|
arg_type_id = arg->data_type(); |
|
|
|
} |
|
|
|
// implicit cast |
|
|
|
bool is_same_type = false; |
|
|
|
if (arg_type_id != kTypeUnknown) { |
|
|
|
is_same_type = (prim::type_map.find(arg_type_id) == prim::type_map.end() || arg_type_id == it->second); |
|
|
|
} |
|
|
|
if (sig == SignatureEnumRW::kRWWrite) { |
|
|
|
if (!is_parameter) { |
|
|
|
prim::RaiseExceptionForCheckParameter(prim->name(), i, "not"); |
|
|
|
} |
|
|
|
if (arg_type_id != kTypeUnknown) { |
|
|
|
if (!is_same_type) { |
|
|
|
prim::RaiseExceptionForConvertRefDtype(prim->name(), TypeIdToMsTypeStr(arg_type_id), |
|
|
|
TypeIdToMsTypeStr(it->second)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
if (is_same_type) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
if (!py::isinstance<tensor::Tensor>(obj) && !py::isinstance<py::int_>(obj) && !py::isinstance<py::float_>(obj)) { |
|
|
|
MS_EXCEPTION(TypeError) << "For '" << prim->name() << "', the " << i |
|
|
|
<< "th input is a not support implicit conversion type: " |
|
|
|
<< py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is " |
|
|
|
<< py::cast<py::str>(obj) << "."; |
|
|
|
} |
|
|
|
py::object cast_output = DoAutoCast(out_args[i], it->second, op_exec_info->op_name, i); |
|
|
|
out_args[i] = cast_output; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info) { |
|
|
|
size_t size = op_exec_info->op_inputs.size(); |
|
|
|
auto prim = op_exec_info->py_primitive; |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
const auto &signature = prim->signatures(); |
|
|
|
for (size_t i = 0; i < size; i++) { |
|
|
|
auto obj = op_exec_info->op_inputs[i]; |
|
|
|
@@ -916,10 +915,10 @@ void PynativeExecutor::RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_ |
|
|
|
} |
|
|
|
} |
|
|
|
// redundant cast call if the tensor is a const Tensor. |
|
|
|
cast_output = DoParamMixPrecisionCast(&is_cast, obj); |
|
|
|
cast_output = DoParamMixPrecisionCast(&is_cast, obj, prim->name(), i); |
|
|
|
} else if (py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) { |
|
|
|
// mix precision for tuple inputs |
|
|
|
cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj); |
|
|
|
cast_output = DoParamMixPrecisionCastTuple(&is_cast, obj, prim->name(), i); |
|
|
|
} |
|
|
|
if (is_cast) { |
|
|
|
op_exec_info->op_inputs[i] = cast_output; |
|
|
|
@@ -958,7 +957,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { |
|
|
|
free_param->set_default_param(value); |
|
|
|
MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id; |
|
|
|
graph_info_map_[df_builder_].params.emplace(obj_id); |
|
|
|
set_node_map(df_builder_, obj_id, free_param); |
|
|
|
SetNodeMapInGraphInfoMap(df_builder_, obj_id, free_param); |
|
|
|
return free_param; |
|
|
|
} |
|
|
|
return graph_info_map_[df_builder_].node_map[obj_id].first; |
|
|
|
@@ -969,7 +968,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { |
|
|
|
// out = op(op1(x, y)) |
|
|
|
// out = op(cell1(x, y)) |
|
|
|
// out = op(cell1(x, y)[0]) |
|
|
|
node = GetObjNode(obj, obj_id); |
|
|
|
return GetObjNode(obj, obj_id); |
|
|
|
} else if (py::isinstance<py::tuple>(obj)) { |
|
|
|
// out = op((x, y)) |
|
|
|
// out = cell((x, y)) |
|
|
|
@@ -985,7 +984,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { |
|
|
|
args.emplace_back(GetInput(tuple[i], false)); |
|
|
|
} |
|
|
|
auto cnode = curr_g_->NewCNode(args); |
|
|
|
set_node_map(curr_g_, GetId(obj), cnode); |
|
|
|
SetNodeMapInGraphInfoMap(curr_g_, GetId(obj), cnode); |
|
|
|
node = cnode; |
|
|
|
} else { |
|
|
|
node = MakeValueNode(obj, obj_id); |
|
|
|
@@ -1048,7 +1047,7 @@ AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::str |
|
|
|
ValuePtr converted_ret = nullptr; |
|
|
|
parse::ConvertData(obj, &converted_ret); |
|
|
|
auto node = NewValueNode(converted_ret); |
|
|
|
set_node_map(curr_g_, obj_id, node); |
|
|
|
SetNodeMapInGraphInfoMap(curr_g_, obj_id, node); |
|
|
|
return node; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1083,12 +1082,12 @@ void PynativeExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::ob |
|
|
|
if (size > 1) { |
|
|
|
for (int64_t i = 0; i < size; ++i) { |
|
|
|
auto value_id = GetId(value[i]); |
|
|
|
set_node_map(curr_g_, value_id, cnode, i); |
|
|
|
SetNodeMapInGraphInfoMap(curr_g_, value_id, cnode, i); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
set_node_map(curr_g_, obj_id, cnode); |
|
|
|
set_pyobj(curr_g_, obj_id); |
|
|
|
SetNodeMapInGraphInfoMap(curr_g_, obj_id, cnode); |
|
|
|
SetPyObjInGraphInfoMap(curr_g_, obj_id); |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) { |
|
|
|
@@ -1305,8 +1304,10 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati |
|
|
|
ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors); |
|
|
|
// get graph info for checking it whether existing in the cache |
|
|
|
std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); |
|
|
|
session::OpRunInfo op_run_info = {op_exec_info->op_name, op_exec_info->py_primitive, op_exec_info->abstract, |
|
|
|
op_exec_info->value, op_exec_info->is_dynamic_shape}; |
|
|
|
session::OpRunInfo op_run_info = {op_exec_info->op_name, op_exec_info->py_primitive, |
|
|
|
op_exec_info->abstract, op_exec_info->value, |
|
|
|
op_exec_info->is_dynamic_shape, op_exec_info->is_mixed_precision_cast, |
|
|
|
op_exec_info->next_op_name, op_exec_info->next_input_index}; |
|
|
|
session->BuildOp(&op_run_info, graph_info, input_tensors, tensors_mask); |
|
|
|
EraseValueNodeTensor(tensors_mask, &input_tensors); |
|
|
|
VectorRef outputs; |
|
|
|
@@ -1318,15 +1319,15 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati |
|
|
|
return result; |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::Pushp() { graph_context_.push(curr_g_); } |
|
|
|
void PynativeExecutor::PushCurrentGraphToStack() { graph_stack_.push(curr_g_); } |
|
|
|
|
|
|
|
void PynativeExecutor::Popp() { |
|
|
|
if (graph_context_.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Stack graph_context_ is empty"; |
|
|
|
void PynativeExecutor::PopGraphStack() { |
|
|
|
if (graph_stack_.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Stack graph_stack_ is empty"; |
|
|
|
} |
|
|
|
graph_context_.pop(); |
|
|
|
if (!graph_context_.empty()) { |
|
|
|
curr_g_ = graph_context_.top(); |
|
|
|
graph_stack_.pop(); |
|
|
|
if (!graph_stack_.empty()) { |
|
|
|
curr_g_ = graph_stack_.top(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1468,7 +1469,7 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg |
|
|
|
auto cell_id = GetCellId(cell, args); |
|
|
|
MS_LOG(DEBUG) << "NewGraphInner start, args size: " << args.size() << ", cell id: " << cell_id; |
|
|
|
// check whether cell needed to construct grad graph |
|
|
|
if (graph_context_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end() && !dynamic_cell_) { |
|
|
|
if (graph_stack_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end() && !dynamic_cell_) { |
|
|
|
auto it = cell_resource_map_.find(cell_id); |
|
|
|
if (it != cell_resource_map_.end()) { |
|
|
|
resource_ = it->second; |
|
|
|
@@ -1479,22 +1480,21 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg |
|
|
|
} |
|
|
|
// init resource for constructing forward graph and grad graph |
|
|
|
auto g = std::make_shared<FuncGraph>(); |
|
|
|
if (graph_context_.empty()) { |
|
|
|
if (graph_stack_.empty()) { |
|
|
|
MakeNewTopGraph(cell_id, args, g); |
|
|
|
} else { |
|
|
|
MS_EXCEPTION_IF_NULL(df_builder_); |
|
|
|
curr_g_ = g; |
|
|
|
} |
|
|
|
Pushp(); |
|
|
|
MS_EXCEPTION_IF_NULL(df_builder_); |
|
|
|
curr_g_ = g; |
|
|
|
PushCurrentGraphToStack(); |
|
|
|
if (graph_info_map_.find(curr_g_) == graph_info_map_.end()) { |
|
|
|
graph_info_map_.emplace(curr_g_, GraphInfo()); |
|
|
|
} |
|
|
|
for (size_t i = 0; i < args.size(); ++i) { |
|
|
|
auto param = args[i]; |
|
|
|
auto new_param = g->add_parameter(); |
|
|
|
std::string param_obj = GetId(param); |
|
|
|
set_node_map(curr_g_, param, new_param, true); |
|
|
|
set_node_map(curr_g_, param_obj, new_param); |
|
|
|
std::string param_id = GetId(param); |
|
|
|
SetTupleArgsToGraphInfoMap(curr_g_, param, new_param, true); |
|
|
|
SetNodeMapInGraphInfoMap(curr_g_, param_id, new_param); |
|
|
|
} |
|
|
|
// check whether the constrcut of cell will be changed |
|
|
|
if (!dynamic_cell_) { |
|
|
|
@@ -1525,46 +1525,47 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar |
|
|
|
top_graph_cells_.emplace(cell_id); |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::set_node_map(const FuncGraphPtr &g, const py::object &node, const AnfNodePtr &cnode, |
|
|
|
bool is_param) { |
|
|
|
if (!py::isinstance<py::tuple>(node) && !py::isinstance<py::list>(node)) { |
|
|
|
void PynativeExecutor::SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node, |
|
|
|
bool is_param) { |
|
|
|
if (!py::isinstance<py::tuple>(args) && !py::isinstance<py::list>(args)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto tuple = node.cast<py::tuple>(); |
|
|
|
auto tuple = args.cast<py::tuple>(); |
|
|
|
auto tuple_size = static_cast<int64_t>(tuple.size()); |
|
|
|
for (int64_t i = 0; i < tuple_size; ++i) { |
|
|
|
auto id = GetId(tuple[i]); |
|
|
|
if (is_param) { |
|
|
|
graph_info_map_[g].params.emplace(id); |
|
|
|
} |
|
|
|
set_node_map(g, id, cnode, i); |
|
|
|
set_tuple_node_map(g, tuple[i], cnode, std::vector<int64_t>{i}, is_param); |
|
|
|
SetNodeMapInGraphInfoMap(g, id, node, i); |
|
|
|
SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, std::vector<int64_t>{i}, is_param); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::set_tuple_node_map(const FuncGraphPtr &g, const py::object &node, const AnfNodePtr &cnode, |
|
|
|
const std::vector<int64_t> &idx, bool is_param) { |
|
|
|
if (!py::isinstance<py::tuple>(node) && !py::isinstance<py::list>(node)) { |
|
|
|
void PynativeExecutor::SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, |
|
|
|
const AnfNodePtr &node, |
|
|
|
const std::vector<int64_t> &index_sequence, bool is_param) { |
|
|
|
if (!py::isinstance<py::tuple>(args) && !py::isinstance<py::list>(args)) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto tuple = node.cast<py::tuple>(); |
|
|
|
auto tuple = args.cast<py::tuple>(); |
|
|
|
auto tuple_size = static_cast<int64_t>(tuple.size()); |
|
|
|
for (int64_t i = 0; i < tuple_size; ++i) { |
|
|
|
std::vector<int64_t> tmp = idx; |
|
|
|
std::vector<int64_t> tmp = index_sequence; |
|
|
|
tmp.emplace_back(i); |
|
|
|
auto id = GetId(tuple[i]); |
|
|
|
if (is_param) { |
|
|
|
graph_info_map_[g].params.emplace(id); |
|
|
|
} |
|
|
|
set_node_map(g, id, cnode, tmp); |
|
|
|
set_tuple_node_map(g, tuple[i], cnode, tmp, is_param); |
|
|
|
SetNodeMapInGraphInfoMap(g, id, node, tmp); |
|
|
|
SetTupleItemArgsToGraphInfoMap(g, tuple[i], node, tmp, is_param); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &out, const py::args &args) { |
|
|
|
auto cell_id = GetCellId(cell, args); |
|
|
|
MS_LOG(DEBUG) << "EndGraphInner start " << args.size() << " " << cell_id; |
|
|
|
if (graph_context_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end() && !dynamic_cell_) { |
|
|
|
if (graph_stack_.empty() && cell_graph_map_.find(cell_id) != cell_graph_map_.end() && !dynamic_cell_) { |
|
|
|
MS_LOG(DEBUG) << "Endgraph already compiled"; |
|
|
|
return; |
|
|
|
} |
|
|
|
@@ -1582,8 +1583,8 @@ void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &o |
|
|
|
inputs.emplace_back(GetInput(tuple[i], false)); |
|
|
|
} |
|
|
|
auto cnode = curr_g_->NewCNode(inputs); |
|
|
|
set_node_map(curr_g_, out, cnode); |
|
|
|
set_node_map(curr_g_, out_id, cnode); |
|
|
|
SetTupleArgsToGraphInfoMap(curr_g_, out, cnode); |
|
|
|
SetNodeMapInGraphInfoMap(curr_g_, out_id, cnode); |
|
|
|
} else { |
|
|
|
MS_LOG(DEBUG) << "Set ValueNode as output for graph, out id: " << out_id; |
|
|
|
MakeValueNode(out, out_id); |
|
|
|
@@ -1601,21 +1602,21 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje |
|
|
|
|
|
|
|
auto newfg = MakeGradGraph(cell, args); |
|
|
|
|
|
|
|
if (graph_context_.size() > 1) { |
|
|
|
if (graph_stack_.size() > 1) { |
|
|
|
std::vector<AnfNodePtr> inputs; |
|
|
|
inputs.emplace_back(NewValueNode(curr_g_)); |
|
|
|
|
|
|
|
Popp(); |
|
|
|
PopGraphStack(); |
|
|
|
// connect the previous graph to the inside graph |
|
|
|
auto graph_prev = graph_context_.top(); |
|
|
|
auto graph_prev = graph_stack_.top(); |
|
|
|
for (size_t i = 0; i < args.size(); i++) { |
|
|
|
auto input = GetInput(args[i], false); |
|
|
|
inputs.emplace_back(input); |
|
|
|
} |
|
|
|
auto out_cnode = graph_prev->NewCNode(inputs); |
|
|
|
set_pyobj(graph_prev, GetCellId(cell, args)); |
|
|
|
set_node_map(graph_prev, out, out_cnode); |
|
|
|
set_node_map(graph_prev, GetId(out), out_cnode); |
|
|
|
SetPyObjInGraphInfoMap(graph_prev, GetCellId(cell, args)); |
|
|
|
SetTupleArgsToGraphInfoMap(graph_prev, out, out_cnode); |
|
|
|
SetNodeMapInGraphInfoMap(graph_prev, GetId(out), out_cnode); |
|
|
|
} else { |
|
|
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) { |
|
|
|
DumpIR("before_resolve.ir", newfg); |
|
|
|
@@ -1625,7 +1626,7 @@ void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::obje |
|
|
|
DumpIR("after_resolve.ir", newfg); |
|
|
|
} |
|
|
|
resource_->set_func_graph(newfg); |
|
|
|
Popp(); |
|
|
|
PopGraphStack(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1647,7 +1648,7 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const py::a |
|
|
|
} |
|
|
|
} |
|
|
|
// Obtain grad graph |
|
|
|
auto newfg = ad::Grad(curr_g_, resource_, graph_context_.size() == 1); |
|
|
|
auto newfg = ad::Grad(curr_g_, resource_, graph_stack_.size() == 1); |
|
|
|
graph_info_map_.erase(curr_g_); |
|
|
|
|
|
|
|
if (need_replace_param) { |
|
|
|
@@ -1986,7 +1987,7 @@ void PynativeExecutor::Clear(const std::string &flag) { |
|
|
|
op_id_map_.clear(); |
|
|
|
obj_to_forward_id_.clear(); |
|
|
|
node_abs_map_.clear(); |
|
|
|
std::stack<FuncGraphPtr>().swap(graph_context_); |
|
|
|
std::stack<FuncGraphPtr>().swap(graph_stack_); |
|
|
|
ConfigManager::GetInstance().ResetIterNum(); |
|
|
|
} |
|
|
|
|
|
|
|
|