|
|
|
@@ -78,6 +78,7 @@ PynativeExecutorPtr PynativeExecutor::executor_ = nullptr; |
|
|
|
ForwardExecutorPtr PynativeExecutor::forward_executor_ = nullptr; |
|
|
|
GradExecutorPtr PynativeExecutor::grad_executor_ = nullptr; |
|
|
|
std::mutex PynativeExecutor::instance_lock_; |
|
|
|
constexpr auto implcast = "implcast"; |
|
|
|
|
|
|
|
template <typename T, typename... Args> |
|
|
|
void PynativeExecutorTry(std::function<void(T *ret, const Args &...)> method, T *ret, const Args &... args) { |
|
|
|
@@ -276,33 +277,42 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, |
|
|
|
for (size_t index = 0; index < input_tensors.size(); ++index) { |
|
|
|
MS_EXCEPTION_IF_NULL(input_tensors[index]); |
|
|
|
auto tensor_shape = input_tensors[index]->shape(); |
|
|
|
(void)std::for_each(tensor_shape.begin(), tensor_shape.end(), |
|
|
|
[&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); }); |
|
|
|
(void)graph_info.append(std::to_string(input_tensors[index]->data_type()) + "_"); |
|
|
|
(void)std::for_each(tensor_shape.begin(), tensor_shape.end(), [&](const auto &dim) { |
|
|
|
(void)graph_info.append(std::to_string(dim)); |
|
|
|
graph_info += "_"; |
|
|
|
}); |
|
|
|
(void)graph_info.append(std::to_string(input_tensors[index]->data_type())); |
|
|
|
graph_info += "_"; |
|
|
|
auto tensor_addr = input_tensors[index]->device_address(); |
|
|
|
if (tensor_addr != nullptr) { |
|
|
|
(void)graph_info.append(std::to_string(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr)->type_id()) + |
|
|
|
"_"); |
|
|
|
(void)graph_info.append(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr)->format() + "_"); |
|
|
|
(void)graph_info.append(std::to_string(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr)->type_id())); |
|
|
|
graph_info += "_"; |
|
|
|
(void)graph_info.append(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr)->format()); |
|
|
|
graph_info += "_"; |
|
|
|
} |
|
|
|
if (static_cast<int64_t>(op_exec_info->inputs_mask[index]) == kValueNodeTensorMask) { |
|
|
|
if (input_tensors[index]->Dtype()->type_id() == kNumberTypeInt64) { |
|
|
|
(void)graph_info.append(std::to_string(*reinterpret_cast<int *>(input_tensors[index]->data_c())) + "_"); |
|
|
|
(void)graph_info.append(std::to_string(*reinterpret_cast<int *>(input_tensors[index]->data_c()))); |
|
|
|
graph_info += "_"; |
|
|
|
} else if (input_tensors[index]->Dtype()->type_id() == kNumberTypeFloat32) { |
|
|
|
(void)graph_info.append(std::to_string(*reinterpret_cast<float *>(input_tensors[index]->data_c())) + "_"); |
|
|
|
(void)graph_info.append(std::to_string(*reinterpret_cast<float *>(input_tensors[index]->data_c()))); |
|
|
|
graph_info += "_"; |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "The dtype of the constant input is not int64 or float32!"; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
// get prim and abstract info |
|
|
|
(void)graph_info.append(op_exec_info->op_name + "_"); |
|
|
|
graph_info += (op_exec_info->op_name); |
|
|
|
graph_info += "_"; |
|
|
|
// get attr info |
|
|
|
const auto &op_prim = op_exec_info->py_primitive; |
|
|
|
MS_EXCEPTION_IF_NULL(op_prim); |
|
|
|
const auto &attr_map = op_prim->attrs(); |
|
|
|
(void)std::for_each(attr_map.begin(), attr_map.end(), |
|
|
|
[&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); }); |
|
|
|
(void)std::for_each(attr_map.begin(), attr_map.end(), [&](const auto &element) { |
|
|
|
graph_info += (element.second->ToString()); |
|
|
|
graph_info += "_"; |
|
|
|
}); |
|
|
|
|
|
|
|
// Add output information(shape, type id) of the operator to graph_info to solve the problem of cache missing |
|
|
|
// caused by operators like DropoutGenMask whose output is related to values of input when input shapes are |
|
|
|
@@ -311,10 +321,12 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, |
|
|
|
MS_EXCEPTION_IF_NULL(abstr); |
|
|
|
auto build_shape = abstr->BuildShape(); |
|
|
|
MS_EXCEPTION_IF_NULL(build_shape); |
|
|
|
(void)graph_info.append(build_shape->ToString() + "_"); |
|
|
|
graph_info += (build_shape->ToString()); |
|
|
|
graph_info += "_"; |
|
|
|
auto build_type = abstr->BuildType(); |
|
|
|
MS_EXCEPTION_IF_NULL(build_type); |
|
|
|
(void)graph_info.append(std::to_string(build_type->type_id()) + "_"); |
|
|
|
graph_info += std::to_string(build_type->type_id()); |
|
|
|
graph_info += "_"; |
|
|
|
|
|
|
|
return graph_info; |
|
|
|
} |
|
|
|
@@ -685,6 +697,26 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) { |
|
|
|
return op_exec_info; |
|
|
|
} |
|
|
|
|
|
|
|
// Determine the op mask for the input object `obj`, append it to `op_masks`, and return it.
// Masks are cached in op_mask_map_ keyed by `id`, so the Python-side type check below only runs
// the first time a given id is seen.
bool ForwardExecutor::FindOpMask(py::object obj, std::vector<int64_t> *op_masks, std::string id) {
  // Cache hit: reuse the mask recorded earlier for this id.
  const auto cached = op_mask_map_.find(id);
  if (cached != op_mask_map_.end()) {
    bool known_mask = cached->second;
    op_masks->emplace_back(known_mask);
    return known_mask;
  }

  // Cache miss: the mask is true only when obj is a MetaTensor that reports itself as a parameter.
  bool new_mask = false;
  if (py::isinstance<tensor::MetaTensor>(obj)) {
    const auto tensor_ptr = obj.cast<tensor::MetaTensorPtr>();
    if (tensor_ptr) {
      new_mask = tensor_ptr->is_parameter();
    }
  }
  MS_LOG(DEBUG) << "Gen args op_mask " << new_mask;

  // Remember the result so later calls with the same id take the cache-hit path.
  op_mask_map_[id] = new_mask;
  op_masks->emplace_back(new_mask);
  return new_mask;
}
|
|
|
|
|
|
|
void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, |
|
|
|
std::vector<AnfNodePtr> *inputs, abstract::AbstractBasePtrList *args_spec_list) { |
|
|
|
auto prim = op_exec_info->py_primitive; |
|
|
|
@@ -696,15 +728,8 @@ void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector |
|
|
|
if (it != node_abs_map_.end()) { |
|
|
|
abs = it->second; |
|
|
|
} |
|
|
|
bool op_mask = false; |
|
|
|
if (py::isinstance<tensor::MetaTensor>(obj)) { |
|
|
|
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>(); |
|
|
|
if (meta_tensor) { |
|
|
|
op_mask = meta_tensor->is_parameter(); |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "Gen args i " << i << " op_mask " << op_mask; |
|
|
|
(*op_masks).emplace_back(op_mask); |
|
|
|
// Find the opmask of input obj |
|
|
|
bool op_mask = FindOpMask(obj, op_masks, id); |
|
|
|
|
|
|
|
// Construct grad graph |
|
|
|
if (grad()->need_construct_graph()) { |
|
|
|
@@ -798,16 +823,19 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, |
|
|
|
auto op_name = op_exec_info->op_name; |
|
|
|
auto prim = op_exec_info->py_primitive; |
|
|
|
MS_EXCEPTION_IF_NULL(prim); |
|
|
|
if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) { |
|
|
|
auto abs_list = prim_abs_list_[prim->id()]; |
|
|
|
|
|
|
|
auto temp = prim_abs_list_.find(prim->id()); |
|
|
|
if (temp != prim_abs_list_.end()) { |
|
|
|
MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list); |
|
|
|
if (abs_list.find(args_spec_list) != abs_list.end()) { |
|
|
|
auto iter = temp->second.find(args_spec_list); |
|
|
|
if (iter != temp->second.end()) { |
|
|
|
MS_LOG(DEBUG) << "Match prim ok " << op_name; |
|
|
|
op_exec_info->abstract = abs_list[args_spec_list].abs; |
|
|
|
prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs); |
|
|
|
op_exec_info->abstract = iter->second.abs; |
|
|
|
prim->set_evaluate_added_attrs(iter->second.attrs); |
|
|
|
*is_find = true; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_name) != force_infer_prim.end()) { |
|
|
|
// use python infer method |
|
|
|
if (ignore_infer_prim.find(op_name) == ignore_infer_prim.end()) { |
|
|
|
@@ -826,7 +854,7 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info, |
|
|
|
} |
|
|
|
|
|
|
|
py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, |
|
|
|
size_t index) { |
|
|
|
size_t index, const std::string &obj_id) { |
|
|
|
py::tuple cast_args(3); |
|
|
|
cast_args[PY_PRIM] = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast"); |
|
|
|
cast_args[PY_NAME] = prim::kPrimCast->name(); |
|
|
|
@@ -840,6 +868,10 @@ py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type |
|
|
|
op_exec->is_mixed_precision_cast = true; |
|
|
|
op_exec->next_op_name = op_name; |
|
|
|
op_exec->next_input_index = index; |
|
|
|
// Cache the cast struct |
|
|
|
if (obj_id != implcast) { |
|
|
|
cast_struct_map_[obj_id] = op_exec; |
|
|
|
} |
|
|
|
py::object ret = py::none(); |
|
|
|
RunOpInner(&ret, op_exec); |
|
|
|
return ret; |
|
|
|
@@ -856,7 +888,20 @@ py::object ForwardExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::obj |
|
|
|
if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) { |
|
|
|
MS_LOG(DEBUG) << "Cast to " << cast_type->ToString(); |
|
|
|
*is_cast = true; |
|
|
|
return DoAutoCast(obj, cast_type->type_id(), op_name, index); |
|
|
|
// Get obj id |
|
|
|
auto id = GetId(obj); |
|
|
|
// Find obj id in unordered map |
|
|
|
auto cast_struct_pair = cast_struct_map_.find(id); |
|
|
|
if (cast_struct_pair != cast_struct_map_.end()) { |
|
|
|
// Update input for cast struct |
|
|
|
auto cast_struct = cast_struct_pair->second; |
|
|
|
cast_struct->op_inputs[0] = obj; |
|
|
|
py::object ret = py::none(); |
|
|
|
RunOpInner(&ret, cast_struct); |
|
|
|
return ret; |
|
|
|
} else { |
|
|
|
return DoAutoCast(obj, cast_type->type_id(), op_name, index, id); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return cast_output; |
|
|
|
@@ -937,7 +982,7 @@ void ForwardExecutor::DoSignatrueCast(const PrimitivePyPtr &prim, const std::map |
|
|
|
<< py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is " |
|
|
|
<< py::cast<py::str>(obj) << "."; |
|
|
|
} |
|
|
|
py::object cast_output = DoAutoCast(out_args[i], it->second, op_exec_info->op_name, i); |
|
|
|
py::object cast_output = DoAutoCast(out_args[i], it->second, op_exec_info->op_name, i, implcast); |
|
|
|
out_args[i] = cast_output; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -1474,6 +1519,8 @@ void ForwardExecutor::ClearRes() { |
|
|
|
MS_LOG(DEBUG) << "Clear forward res"; |
|
|
|
prim_abs_list_.clear(); |
|
|
|
node_abs_map_.clear(); |
|
|
|
cast_struct_map_.clear(); |
|
|
|
op_mask_map_.clear(); |
|
|
|
cell_op_index_with_tensor_id_.clear(); |
|
|
|
cell_tensor_id_with_tensor_.clear(); |
|
|
|
} |
|
|
|
|