|
|
|
@@ -110,7 +110,40 @@ py::object GetTupleObj(const py::object &obj) { |
|
|
|
return obj_tuple; |
|
|
|
} |
|
|
|
|
|
|
|
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) { |
|
|
|
std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) { |
|
|
|
std::map<SignatureEnumDType, std::vector<size_t>> type_indexes; |
|
|
|
for (size_t i = 0; i < dtypes.size(); ++i) { |
|
|
|
auto it = type_indexes.find(dtypes[i]); |
|
|
|
if (it == type_indexes.end()) { |
|
|
|
(void)type_indexes.insert(std::make_pair(dtypes[i], std::vector<size_t>{i})); |
|
|
|
} else { |
|
|
|
it->second.push_back(i); |
|
|
|
} |
|
|
|
} |
|
|
|
return type_indexes; |
|
|
|
} |
|
|
|
|
|
|
|
std::map<SignatureEnumDType, size_t> GetDstType(const py::tuple &py_args, |
|
|
|
const std::map<SignatureEnumDType, std::vector<size_t>> &type_indexes) { |
|
|
|
std::map<SignatureEnumDType, size_t> dst_type; |
|
|
|
for (auto it = type_indexes.begin(); it != type_indexes.end(); (void)++it) { |
|
|
|
auto type = it->first; |
|
|
|
auto indexes = it->second; |
|
|
|
if (indexes.size() < 2) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
size_t m_index = indexes[0]; |
|
|
|
for (size_t i = 1; i < indexes.size(); ++i) { |
|
|
|
if (py::isinstance<tensor::Tensor>(py_args[indexes[i]])) { |
|
|
|
m_index = indexes[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
(void)dst_type.insert(std::make_pair(type, m_index)); |
|
|
|
} |
|
|
|
return dst_type; |
|
|
|
} |
|
|
|
|
|
|
|
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *const out_args) { |
|
|
|
auto &py_args = *out_args; |
|
|
|
py::tuple input_mask(args.size()); |
|
|
|
for (size_t i = 0; i < args.size(); ++i) { |
|
|
|
@@ -129,30 +162,8 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu |
|
|
|
if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) { |
|
|
|
return input_mask; |
|
|
|
} |
|
|
|
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs; |
|
|
|
for (size_t i = 0; i < dtypes.size(); ++i) { |
|
|
|
auto it = type_indexs.find(dtypes[i]); |
|
|
|
if (it == type_indexs.end()) { |
|
|
|
(void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i})); |
|
|
|
} else { |
|
|
|
it->second.push_back(i); |
|
|
|
} |
|
|
|
} |
|
|
|
std::map<SignatureEnumDType, size_t> dst_type; |
|
|
|
for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) { |
|
|
|
auto type = it->first; |
|
|
|
auto indexs = it->second; |
|
|
|
if (indexs.size() < 2) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
size_t m_index = indexs[0]; |
|
|
|
for (size_t i = 1; i < indexs.size(); ++i) { |
|
|
|
if (py::isinstance<tensor::Tensor>(py_args[indexs[i]])) { |
|
|
|
m_index = indexs[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
(void)dst_type.insert(std::make_pair(type, m_index)); |
|
|
|
} |
|
|
|
auto type_indexes = GetTypeIndex(dtypes); |
|
|
|
auto dst_type = GetDstType(py_args, type_indexes); |
|
|
|
for (size_t i = 0; i < py_args.size(); ++i) { |
|
|
|
auto it = dst_type.find(dtypes[i]); |
|
|
|
if (it != dst_type.end() && it->second != i && |
|
|
|
@@ -542,28 +553,7 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) { |
|
|
|
return curr_g_->NewCNode(tuple_get_item_inputs); |
|
|
|
} |
|
|
|
|
|
|
|
py::tuple RunOp(const py::args &args) { |
|
|
|
MS_LOG(DEBUG) << "RunOp start" << args.size(); |
|
|
|
py::object result; |
|
|
|
// returns a null py::tuple on error |
|
|
|
py::tuple err_ret(0); |
|
|
|
PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE; |
|
|
|
|
|
|
|
OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args); |
|
|
|
MS_EXCEPTION_IF_NULL(op_exec_info); |
|
|
|
if (op_exec_info->abstract != nullptr) { |
|
|
|
py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); |
|
|
|
if (!output["value"].is_none()) { |
|
|
|
py::tuple value_ret(1); |
|
|
|
value_ret[0] = output["value"]; |
|
|
|
return value_ret; |
|
|
|
} |
|
|
|
if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) { |
|
|
|
py::tuple value_ret(1); |
|
|
|
value_ret[0] = ""; |
|
|
|
return value_ret; |
|
|
|
} |
|
|
|
} |
|
|
|
py::tuple RunOp(const OpExecInfoPtr &op_exec_info, const py::args &args) { |
|
|
|
MS_LOG(INFO) << "RunOp start, op name is: " << op_exec_info->op_name; |
|
|
|
mindspore::parse::python_adapter::set_python_env_flag(true); |
|
|
|
MsBackendPolicy backend_policy; |
|
|
|
@@ -584,7 +574,10 @@ py::tuple RunOp(const py::args &args) { |
|
|
|
if (vm_operators.find(op_exec_info->op_name) != vm_operators.end()) { |
|
|
|
backend_policy = kMsBackendVmOnly; |
|
|
|
} |
|
|
|
result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status); |
|
|
|
PynativeStatusCode status = PYNATIVE_UNKNOWN_STATE; |
|
|
|
// returns a null py::tuple on error |
|
|
|
py::tuple err_ret(0); |
|
|
|
py::object result = RunOpWithBackendPolicy(backend_policy, op_exec_info, &status); |
|
|
|
if (status != PYNATIVE_SUCCESS) { |
|
|
|
MS_LOG(ERROR) << "Failed to run " << op_exec_info->op_name; |
|
|
|
return err_ret; |
|
|
|
@@ -599,6 +592,26 @@ py::tuple RunOp(const py::args &args) { |
|
|
|
return result; |
|
|
|
} |
|
|
|
|
|
|
|
py::tuple RunOp(const py::args &args) { |
|
|
|
MS_LOG(DEBUG) << "RunOp start" << args.size(); |
|
|
|
OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args); |
|
|
|
MS_EXCEPTION_IF_NULL(op_exec_info); |
|
|
|
if (op_exec_info->abstract != nullptr) { |
|
|
|
py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); |
|
|
|
if (!output["value"].is_none()) { |
|
|
|
py::tuple value_ret(1); |
|
|
|
value_ret[0] = output["value"]; |
|
|
|
return value_ret; |
|
|
|
} |
|
|
|
if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) { |
|
|
|
py::tuple value_ret(1); |
|
|
|
value_ret[0] = ""; |
|
|
|
return value_ret; |
|
|
|
} |
|
|
|
} |
|
|
|
return RunOp(op_exec_info, args); |
|
|
|
} |
|
|
|
|
|
|
|
void ClearPyNativeSession() { session = nullptr; } |
|
|
|
|
|
|
|
PynativeExecutor::~PynativeExecutor() { ClearRes(); } |
|
|
|
@@ -732,7 +745,11 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c |
|
|
|
return; |
|
|
|
} |
|
|
|
} |
|
|
|
EndGraphByOutId(out_id, cell, out, args); |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, |
|
|
|
const py::args &args) { |
|
|
|
AnfNodePtr output_node; |
|
|
|
if (graph_info_map_[curr_g_].param_map.count(out_id)) { |
|
|
|
output_node = graph_info_map_[curr_g_].param_map[out_id]; |
|
|
|
@@ -776,27 +793,7 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, |
|
|
|
const py::args &args) { |
|
|
|
MS_LOG(INFO) << "GradNet start" << args.size(); |
|
|
|
|
|
|
|
std::size_t size = args.size(); |
|
|
|
auto cell_id = GetId(cell); |
|
|
|
if (graph_map_.count(cell_id) != 0) { |
|
|
|
MS_LOG(DEBUG) << "GradNet already compiled"; |
|
|
|
return; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "GradNet first compiled"; |
|
|
|
std::vector<AnfNodePtr> new_params; |
|
|
|
for (size_t i = 0; i < size; i++) { |
|
|
|
ParameterPtr p = std::make_shared<Parameter>(df_builder_); |
|
|
|
new_params.push_back(p); |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size(); |
|
|
|
new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end()); |
|
|
|
df_builder_->set_parameters(new_params); |
|
|
|
resource_->manager()->SetParameters(df_builder_, new_params); |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weights) { |
|
|
|
std::vector<AnfNodePtr> w_args; |
|
|
|
if (py::hasattr(weights, "__parameter_tuple__")) { |
|
|
|
auto tuple = weights.cast<py::tuple>(); |
|
|
|
@@ -821,12 +818,12 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "training not paramter_tuple"; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(resource_->func_graph()); |
|
|
|
auto g = GradGraph(resource_->func_graph(), grad, w_args, size); |
|
|
|
resource_->set_func_graph(g); |
|
|
|
return w_args; |
|
|
|
} |
|
|
|
|
|
|
|
// get the parameters items and add the value to args_spec |
|
|
|
abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args) { |
|
|
|
abstract::AbstractBasePtrList args_spec; |
|
|
|
std::size_t size = args.size(); |
|
|
|
for (std::size_t i = 0; i < size; i++) { |
|
|
|
ValuePtr converted = nullptr; |
|
|
|
bool succ = parse::ConvertData(args[i], &converted); |
|
|
|
@@ -852,6 +849,38 @@ void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &c |
|
|
|
param_node->set_abstract(ptr); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return args_spec; |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, |
|
|
|
const py::args &args) { |
|
|
|
MS_LOG(INFO) << "GradNet start" << args.size(); |
|
|
|
|
|
|
|
std::size_t size = args.size(); |
|
|
|
auto cell_id = GetId(cell); |
|
|
|
if (graph_map_.count(cell_id) != 0) { |
|
|
|
MS_LOG(DEBUG) << "GradNet already compiled"; |
|
|
|
return; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "GradNet first compiled"; |
|
|
|
std::vector<AnfNodePtr> new_params; |
|
|
|
for (size_t i = 0; i < size; i++) { |
|
|
|
ParameterPtr p = std::make_shared<Parameter>(df_builder_); |
|
|
|
new_params.push_back(p); |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "GradNet start weight size" << df_builder_->parameters().size(); |
|
|
|
new_params.insert(new_params.end(), df_builder_->parameters().begin(), df_builder_->parameters().end()); |
|
|
|
df_builder_->set_parameters(new_params); |
|
|
|
resource_->manager()->SetParameters(df_builder_, new_params); |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> w_args = GetWeightsArgs(weights); |
|
|
|
MS_EXCEPTION_IF_NULL(resource_->func_graph()); |
|
|
|
auto g = GradGraph(resource_->func_graph(), grad, w_args, size); |
|
|
|
resource_->set_func_graph(g); |
|
|
|
|
|
|
|
// get the parameters items and add the value to args_spec |
|
|
|
abstract::AbstractBasePtrList args_spec = GetArgsSpec(args); |
|
|
|
MS_LOG(DEBUG) << "Args_spec size" << args_spec.size(); |
|
|
|
|
|
|
|
resource_->set_args_spec(args_spec); |
|
|
|
|