Browse Source

!13033 [Pynative]Using Cache and Code Optimization to Improve Performance of AMP

From: @chenyijie6
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
e54f6e4094
3 changed files with 86 additions and 34 deletions
  1. +77
    -30
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  2. +7
    -1
      mindspore/ccsrc/pipeline/pynative/pynative_execute.h
  3. +2
    -3
      mindspore/core/utils/check_convert_utils.cc

+ 77
- 30
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -78,6 +78,7 @@ PynativeExecutorPtr PynativeExecutor::executor_ = nullptr;
ForwardExecutorPtr PynativeExecutor::forward_executor_ = nullptr; ForwardExecutorPtr PynativeExecutor::forward_executor_ = nullptr;
GradExecutorPtr PynativeExecutor::grad_executor_ = nullptr; GradExecutorPtr PynativeExecutor::grad_executor_ = nullptr;
std::mutex PynativeExecutor::instance_lock_; std::mutex PynativeExecutor::instance_lock_;
constexpr auto implcast = "implcast";


template <typename T, typename... Args> template <typename T, typename... Args>
void PynativeExecutorTry(std::function<void(T *ret, const Args &...)> method, T *ret, const Args &... args) { void PynativeExecutorTry(std::function<void(T *ret, const Args &...)> method, T *ret, const Args &... args) {
@@ -276,33 +277,42 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
for (size_t index = 0; index < input_tensors.size(); ++index) { for (size_t index = 0; index < input_tensors.size(); ++index) {
MS_EXCEPTION_IF_NULL(input_tensors[index]); MS_EXCEPTION_IF_NULL(input_tensors[index]);
auto tensor_shape = input_tensors[index]->shape(); auto tensor_shape = input_tensors[index]->shape();
(void)std::for_each(tensor_shape.begin(), tensor_shape.end(),
[&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); });
(void)graph_info.append(std::to_string(input_tensors[index]->data_type()) + "_");
(void)std::for_each(tensor_shape.begin(), tensor_shape.end(), [&](const auto &dim) {
(void)graph_info.append(std::to_string(dim));
graph_info += "_";
});
(void)graph_info.append(std::to_string(input_tensors[index]->data_type()));
graph_info += "_";
auto tensor_addr = input_tensors[index]->device_address(); auto tensor_addr = input_tensors[index]->device_address();
if (tensor_addr != nullptr) { if (tensor_addr != nullptr) {
(void)graph_info.append(std::to_string(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr)->type_id()) +
"_");
(void)graph_info.append(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr)->format() + "_");
(void)graph_info.append(std::to_string(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr)->type_id()));
graph_info += "_";
(void)graph_info.append(std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr)->format());
graph_info += "_";
} }
if (static_cast<int64_t>(op_exec_info->inputs_mask[index]) == kValueNodeTensorMask) { if (static_cast<int64_t>(op_exec_info->inputs_mask[index]) == kValueNodeTensorMask) {
if (input_tensors[index]->Dtype()->type_id() == kNumberTypeInt64) { if (input_tensors[index]->Dtype()->type_id() == kNumberTypeInt64) {
(void)graph_info.append(std::to_string(*reinterpret_cast<int *>(input_tensors[index]->data_c())) + "_");
(void)graph_info.append(std::to_string(*reinterpret_cast<int *>(input_tensors[index]->data_c())));
graph_info += "_";
} else if (input_tensors[index]->Dtype()->type_id() == kNumberTypeFloat32) { } else if (input_tensors[index]->Dtype()->type_id() == kNumberTypeFloat32) {
(void)graph_info.append(std::to_string(*reinterpret_cast<float *>(input_tensors[index]->data_c())) + "_");
(void)graph_info.append(std::to_string(*reinterpret_cast<float *>(input_tensors[index]->data_c())));
graph_info += "_";
} else { } else {
MS_LOG(EXCEPTION) << "The dtype of the constant input is not int64 or float32!"; MS_LOG(EXCEPTION) << "The dtype of the constant input is not int64 or float32!";
} }
} }
} }
// get prim and abstract info // get prim and abstract info
(void)graph_info.append(op_exec_info->op_name + "_");
graph_info += (op_exec_info->op_name);
graph_info += "_";
// get attr info // get attr info
const auto &op_prim = op_exec_info->py_primitive; const auto &op_prim = op_exec_info->py_primitive;
MS_EXCEPTION_IF_NULL(op_prim); MS_EXCEPTION_IF_NULL(op_prim);
const auto &attr_map = op_prim->attrs(); const auto &attr_map = op_prim->attrs();
(void)std::for_each(attr_map.begin(), attr_map.end(),
[&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); });
(void)std::for_each(attr_map.begin(), attr_map.end(), [&](const auto &element) {
graph_info += (element.second->ToString());
graph_info += "_";
});


// Add output information(shape, type id) of the operator to graph_info to solve the problem of cache missing // Add output information(shape, type id) of the operator to graph_info to solve the problem of cache missing
// caused by operators like DropoutGenMask whose output is related to values of input when input shapes are // caused by operators like DropoutGenMask whose output is related to values of input when input shapes are
@@ -311,10 +321,12 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
MS_EXCEPTION_IF_NULL(abstr); MS_EXCEPTION_IF_NULL(abstr);
auto build_shape = abstr->BuildShape(); auto build_shape = abstr->BuildShape();
MS_EXCEPTION_IF_NULL(build_shape); MS_EXCEPTION_IF_NULL(build_shape);
(void)graph_info.append(build_shape->ToString() + "_");
graph_info += (build_shape->ToString());
graph_info += "_";
auto build_type = abstr->BuildType(); auto build_type = abstr->BuildType();
MS_EXCEPTION_IF_NULL(build_type); MS_EXCEPTION_IF_NULL(build_type);
(void)graph_info.append(std::to_string(build_type->type_id()) + "_");
graph_info += std::to_string(build_type->type_id());
graph_info += "_";


return graph_info; return graph_info;
} }
@@ -685,6 +697,26 @@ OpExecInfoPtr ForwardExecutor::GenerateOpExecInfo(const py::args &args) {
return op_exec_info; return op_exec_info;
} }


// Return the "op mask" of input `obj` (true when it is a Parameter), caching the
// result per object id so repeated inputs avoid the pybind11 isinstance/cast cost.
// The mask is also appended to `op_masks` for the caller's per-input mask list.
bool ForwardExecutor::FindOpMask(py::object obj, std::vector<int64_t> *op_masks, std::string id) {
  MS_EXCEPTION_IF_NULL(op_masks);
  bool op_mask = false;
  const auto iter = op_mask_map_.find(id);
  if (iter != op_mask_map_.end()) {
    // Cache hit: reuse the mask computed on a previous call.
    op_mask = iter->second;
  } else {
    // Cache miss: a MetaTensor input is masked when it is a Parameter.
    if (py::isinstance<tensor::MetaTensor>(obj)) {
      auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
      if (meta_tensor != nullptr) {
        op_mask = meta_tensor->is_parameter();
      }
    }
    MS_LOG(DEBUG) << "Gen args op_mask " << op_mask;
    op_mask_map_[id] = op_mask;
  }
  op_masks->emplace_back(op_mask);
  return op_mask;
}

void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
std::vector<AnfNodePtr> *inputs, abstract::AbstractBasePtrList *args_spec_list) { std::vector<AnfNodePtr> *inputs, abstract::AbstractBasePtrList *args_spec_list) {
auto prim = op_exec_info->py_primitive; auto prim = op_exec_info->py_primitive;
@@ -696,15 +728,8 @@ void ForwardExecutor::GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector
if (it != node_abs_map_.end()) { if (it != node_abs_map_.end()) {
abs = it->second; abs = it->second;
} }
bool op_mask = false;
if (py::isinstance<tensor::MetaTensor>(obj)) {
auto meta_tensor = obj.cast<tensor::MetaTensorPtr>();
if (meta_tensor) {
op_mask = meta_tensor->is_parameter();
}
}
MS_LOG(DEBUG) << "Gen args i " << i << " op_mask " << op_mask;
(*op_masks).emplace_back(op_mask);
// Find the op_mask of the input obj
bool op_mask = FindOpMask(obj, op_masks, id);


// Construct grad graph // Construct grad graph
if (grad()->need_construct_graph()) { if (grad()->need_construct_graph()) {
@@ -798,16 +823,19 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
auto op_name = op_exec_info->op_name; auto op_name = op_exec_info->op_name;
auto prim = op_exec_info->py_primitive; auto prim = op_exec_info->py_primitive;
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) {
auto abs_list = prim_abs_list_[prim->id()];

auto temp = prim_abs_list_.find(prim->id());
if (temp != prim_abs_list_.end()) {
MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list); MS_LOG(DEBUG) << "Match prim input args " << op_name << mindspore::ToString(args_spec_list);
if (abs_list.find(args_spec_list) != abs_list.end()) {
auto iter = temp->second.find(args_spec_list);
if (iter != temp->second.end()) {
MS_LOG(DEBUG) << "Match prim ok " << op_name; MS_LOG(DEBUG) << "Match prim ok " << op_name;
op_exec_info->abstract = abs_list[args_spec_list].abs;
prim->set_evaluate_added_attrs(abs_list[args_spec_list].attrs);
op_exec_info->abstract = iter->second.abs;
prim->set_evaluate_added_attrs(iter->second.attrs);
*is_find = true; *is_find = true;
} }
} }

if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_name) != force_infer_prim.end()) { if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_name) != force_infer_prim.end()) {
// use python infer method // use python infer method
if (ignore_infer_prim.find(op_name) == ignore_infer_prim.end()) { if (ignore_infer_prim.find(op_name) == ignore_infer_prim.end()) {
@@ -826,7 +854,7 @@ void ForwardExecutor::GetOpOutputAbstract(const OpExecInfoPtr &op_exec_info,
} }


py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name,
size_t index) {
size_t index, const std::string &obj_id) {
py::tuple cast_args(3); py::tuple cast_args(3);
cast_args[PY_PRIM] = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast"); cast_args[PY_PRIM] = parse::python_adapter::GetPyFn(kOpsFunctionModelName, "cast");
cast_args[PY_NAME] = prim::kPrimCast->name(); cast_args[PY_NAME] = prim::kPrimCast->name();
@@ -840,6 +868,10 @@ py::object ForwardExecutor::DoAutoCast(const py::object &arg, const TypeId &type
op_exec->is_mixed_precision_cast = true; op_exec->is_mixed_precision_cast = true;
op_exec->next_op_name = op_name; op_exec->next_op_name = op_name;
op_exec->next_input_index = index; op_exec->next_input_index = index;
// Cache the cast struct
if (obj_id != implcast) {
cast_struct_map_[obj_id] = op_exec;
}
py::object ret = py::none(); py::object ret = py::none();
RunOpInner(&ret, op_exec); RunOpInner(&ret, op_exec);
return ret; return ret;
@@ -856,7 +888,20 @@ py::object ForwardExecutor::DoParamMixPrecisionCast(bool *is_cast, const py::obj
if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) { if (source_element != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) {
MS_LOG(DEBUG) << "Cast to " << cast_type->ToString(); MS_LOG(DEBUG) << "Cast to " << cast_type->ToString();
*is_cast = true; *is_cast = true;
return DoAutoCast(obj, cast_type->type_id(), op_name, index);
// Get obj id
auto id = GetId(obj);
// Look up the obj id in the unordered cast-struct cache
auto cast_struct_pair = cast_struct_map_.find(id);
if (cast_struct_pair != cast_struct_map_.end()) {
// Update input for cast struct
auto cast_struct = cast_struct_pair->second;
cast_struct->op_inputs[0] = obj;
py::object ret = py::none();
RunOpInner(&ret, cast_struct);
return ret;
} else {
return DoAutoCast(obj, cast_type->type_id(), op_name, index, id);
}
} }
} }
return cast_output; return cast_output;
@@ -937,7 +982,7 @@ void ForwardExecutor::DoSignatrueCast(const PrimitivePyPtr &prim, const std::map
<< py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is " << py::cast<std::string>(obj.attr("__class__").attr("__name__")) << ", and the value is "
<< py::cast<py::str>(obj) << "."; << py::cast<py::str>(obj) << ".";
} }
py::object cast_output = DoAutoCast(out_args[i], it->second, op_exec_info->op_name, i);
py::object cast_output = DoAutoCast(out_args[i], it->second, op_exec_info->op_name, i, implcast);
out_args[i] = cast_output; out_args[i] = cast_output;
} }
} }
@@ -1474,6 +1519,8 @@ void ForwardExecutor::ClearRes() {
MS_LOG(DEBUG) << "Clear forward res"; MS_LOG(DEBUG) << "Clear forward res";
prim_abs_list_.clear(); prim_abs_list_.clear();
node_abs_map_.clear(); node_abs_map_.clear();
cast_struct_map_.clear();
op_mask_map_.clear();
cell_op_index_with_tensor_id_.clear(); cell_op_index_with_tensor_id_.clear();
cell_tensor_id_with_tensor_.clear(); cell_tensor_id_with_tensor_.clear();
} }


+ 7
- 1
mindspore/ccsrc/pipeline/pynative/pynative_execute.h View File

@@ -406,6 +406,7 @@ class ForwardExecutor {
PynativeStatusCode *status); PynativeStatusCode *status);
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks,
abstract::AbstractBasePtrList *args_spec_list); abstract::AbstractBasePtrList *args_spec_list);
bool FindOpMask(py::object obj, std::vector<int64_t> *op_masks, std::string id);
void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs, void GetArgsSpec(const OpExecInfoPtr &op_exec_info, std::vector<int64_t> *op_masks, std::vector<AnfNodePtr> *inputs,
abstract::AbstractBasePtrList *args_spec_list); abstract::AbstractBasePtrList *args_spec_list);
abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj, abstract::AbstractBasePtr CheckConstValue(const PrimitivePyPtr &prim, const py::object &obj,
@@ -420,7 +421,8 @@ class ForwardExecutor {
py::object DoParamMixPrecisionCast(bool *is_cast, const py::object &obj, const std::string &op_name, size_t index); py::object DoParamMixPrecisionCast(bool *is_cast, const py::object &obj, const std::string &op_name, size_t index);
py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple &tuple, const std::string &op_name, py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple &tuple, const std::string &op_name,
size_t index); size_t index);
py::object DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, size_t index);
py::object DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, size_t index,
const std::string &obj_id);
void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type, void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info); const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info);


@@ -431,6 +433,10 @@ class ForwardExecutor {
// Used for runop and replace forward result of grad graph // Used for runop and replace forward result of grad graph
std::unordered_map<std::string, OpIndexWithTensorId> cell_op_index_with_tensor_id_; std::unordered_map<std::string, OpIndexWithTensorId> cell_op_index_with_tensor_id_;
std::unordered_map<std::string, TensorIdWithTensor> cell_tensor_id_with_tensor_; std::unordered_map<std::string, TensorIdWithTensor> cell_tensor_id_with_tensor_;
// Used to cache cast struct
std::unordered_map<std::string, OpExecInfoPtr> cast_struct_map_;
// Used to cache op_mask
std::unordered_map<std::string, int64_t> op_mask_map_;
}; };


class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {


+ 2
- 3
mindspore/core/utils/check_convert_utils.cc View File

@@ -184,9 +184,8 @@ AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op
if (op_attr_map_it == PrimAttrConvertMap.end()) { if (op_attr_map_it == PrimAttrConvertMap.end()) {
return attr_pair; return attr_pair;
} }
auto op_attr_map = op_attr_map_it->second;
auto attr_pair_it = op_attr_map.find(attr_name);
if (attr_pair_it == op_attr_map.end()) {
auto attr_pair_it = op_attr_map_it->second.find(attr_name);
if (attr_pair_it == op_attr_map_it->second.end()) {
return attr_pair; return attr_pair;
} }




Loading…
Cancel
Save