Browse Source

!17817 Fix insufficient memory when running dynamic shapes in PyNative mode

Merge pull request !17817 from JoyLvliang/fix_memory_not_enough_in_pynative_mode
tags/v1.3.0
i-robot Gitee 4 years ago
parent
commit
032b03b4b0
3 changed files with 54 additions and 35 deletions
  1. +8
    -8
      mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc
  2. +43
    -25
      mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
  3. +3
    -2
      mindspore/ccsrc/pipeline/pynative/pynative_execute.h

+ 8
- 8
mindspore/ccsrc/frontend/optimizer/ad/kpynative.cc View File

@@ -449,7 +449,7 @@ void KPynativeCellImpl::UpdateOutputNodeOfTopCell(const AnfNodePtr &output_node)
anfnode_to_adjoin_.insert(std::make_pair(output_node, v_node_pynative_adjoint));
return;
}
MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input: " << last_node_->ToString();
MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input: " << last_node_->DebugString();
}
}

@@ -668,13 +668,13 @@ bool KPynativeCellImpl::BackPropagate(const CNodePtr &cnode_primal, const CNodeP
}
auto cnode_input = input->cast<CNodePtr>();
if (cnode_input != nullptr && cnode_input->stop_gradient()) {
MS_LOG(DEBUG) << "Bypass accumulate dout to cnode with stop_gradient flag, cnode: " << input->ToString();
MS_LOG(DEBUG) << "Bypass accumulate dout to cnode with stop_gradient flag, cnode: " << input->DebugString();
continue;
}
// Backprop sens wrt inputs.
auto input_adjoint_iter = anfnode_to_adjoin_.find(input);
if (input_adjoint_iter == anfnode_to_adjoin_.end()) {
MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist input[" << i << "] " << input->ToString();
MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist input[" << i << "] " << input->DebugString();
}
AnfNodePtr din;
if (abstract_tuple != nullptr) {
@@ -774,7 +774,7 @@ bool KPynativeCellImpl::BackPropagateOneCNodeWithBPropFuncGraph(const CNodePtr &
bool KPynativeCellImpl::BackPropagateOneCNodeWithFPropFuncGraph(const CNodePtr &cnode,
const PynativeAdjointPtr &adjoint,
const FuncGraphPtr &fprop_fg, bool by_value) {
MS_LOG(DEBUG) << "BackPropagate for CNode: " << cnode->ToString();
MS_LOG(DEBUG) << "BackPropagate for CNode: " << cnode->DebugString();

AnfNodePtrList node_list;
CNodePtr bprop_cnode;
@@ -821,10 +821,10 @@ bool KPynativeCellImpl::BackPropagate(bool by_value) {
}
auto cnode = iter->first->cast<CNodePtr>();
if (cnode->stop_gradient()) {
MS_LOG(DEBUG) << "Bypass backpropagate for cnode with stop_gradient flag: " << cnode->ToString();
MS_LOG(DEBUG) << "Bypass backpropagate for cnode with stop_gradient flag: " << cnode->DebugString();
continue;
}
MS_LOG(DEBUG) << "BackPropagate for CNode: " << cnode->ToString();
MS_LOG(DEBUG) << "BackPropagate for CNode: " << cnode->DebugString();
auto fg = iter->second->fg();
auto fg_type = iter->second->fg_type();

@@ -867,7 +867,7 @@ void KPynativeCellImpl::PropagateStopGradient() {
// Cut off the cnode only when it's not referred any more
if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode, prim::kPrimUpdateState) ||
AllReferencesStopped(cnode)) {
MS_LOG(DEBUG) << "Set stop_gradient flag for " << cnode->ToString();
MS_LOG(DEBUG) << "Set stop_gradient flag for " << cnode->DebugString();
cnode->set_stop_gradient(true);
}
}
@@ -955,7 +955,7 @@ void KPynativeCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool ha
MS_LOG(DEBUG) << "Last node info " << last_node_->DebugString();
auto last_node_adjoint_iter = anfnode_to_adjoin_.find(last_node_);
if (last_node_adjoint_iter == anfnode_to_adjoin_.end()) {
MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input: " << last_node_->ToString();
MS_LOG(EXCEPTION) << "BackPropagate adjoint does not exist for input: " << last_node_->DebugString();
}
// Add sens parameter
if (has_sens_arg) {


+ 43
- 25
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc View File

@@ -625,7 +625,10 @@ void UpdateTensorInfo(const tensor::TensorPtr &new_tensor, const std::vector<ten
pre_tensor->set_data_type(new_tensor->data_type());
if (device_target != kCPUDevice) {
pre_tensor->set_device_address(new_tensor->device_address());
} else {
continue;
}
// Replace data in device address when run in CPU device.
if (pre_tensor->device_address() != nullptr) {
auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(pre_tensor->device_address());
auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
auto old_ptr = old_device_address->GetMutablePtr();
@@ -708,7 +711,24 @@ bool TopCellInfo::IsSubCell(const std::string &cell_id) const {
return false;
}

void TopCellInfo::clear() {
void TopCellInfo::ClearDeviceMemory() {
MS_LOG(DEBUG) << "Clear device memory in value nodes of bprop graph, top cell: " << cell_id_;
auto device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (device_target == kCPUDevice) {
MS_LOG(DEBUG) << "No need to clear device address when run in CPU device.";
return;
}

k_pynative_cell_ptr_ = nullptr;
for (const auto &elem : tensor_id_with_tensor_object_) {
std::for_each(elem.second.begin(), elem.second.end(), [](const tensor::TensorPtr &tensor) {
MS_EXCEPTION_IF_NULL(tensor);
tensor->set_device_address(nullptr);
});
}
}

void TopCellInfo::Clear() {
MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_;
op_num_ = 0;
is_dynamic_ = false;
@@ -1491,21 +1511,11 @@ void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_e
}
}

void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) {
void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const {
MS_EXCEPTION_IF_NULL(resource);
// Get all tensors id belong to forward op
std::unordered_set<std::string> forward_op_tensor_id;
const auto &op_info_with_tensor_id = top_cell()->op_info_with_tensor_id();
for (const auto &e : op_info_with_tensor_id) {
std::for_each(e.second.begin(), e.second.end(),
[&forward_op_tensor_id](const std::string &tensor_id) { forward_op_tensor_id.emplace(tensor_id); });
}
auto &tensor_id_with_tensor_object_ = top_cell()->tensor_id_with_tensor_object();
if (!tensor_id_with_tensor_object_.empty()) {
MS_LOG(EXCEPTION) << "When compile a new graph, the map tensor_id_with_tensor_object should be empty. Top cell "
<< top_cell()->cell_id();
}
// Get all tensors obj in value node of bprop graph
const auto &bprop_graph = resource->func_graph();
MS_EXCEPTION_IF_NULL(bprop_graph);
const auto &value_node_list = bprop_graph->value_nodes();
std::vector<tensor::TensorPtr> tensors_in_bprop_graph;
for (const auto &elem : value_node_list) {
@@ -1513,14 +1523,18 @@ void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr
MS_EXCEPTION_IF_NULL(value_node);
TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph);
}
// Save tensor info in bprop graph

auto &tensor_id_with_tensor_object = top_cell()->tensor_id_with_tensor_object();
if (!tensor_id_with_tensor_object.empty()) {
MS_LOG(EXCEPTION) << "When compile a top graph, the tensor_id_with_tensor_object map should be empty. Top cell: "
<< top_cell()->cell_id();
}
// Save tensor in value node of bprop graph
for (const auto &tensor : tensors_in_bprop_graph) {
if (tensor->device_address() == nullptr || forward_op_tensor_id.find(tensor->id()) == forward_op_tensor_id.end()) {
continue;
}
tensor_id_with_tensor_object_[tensor->id()].emplace_back(tensor);
MS_EXCEPTION_IF_NULL(tensor);
tensor_id_with_tensor_object[tensor->id()].emplace_back(tensor);
MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id()
<< " device address: " << tensor->device_address()->GetMutablePtr() << " shape and dtype "
<< " device address: " << tensor->device_address() << " shape and dtype "
<< tensor->GetShapeAndDataTypeInfo();
}
}
@@ -1825,7 +1839,7 @@ void GradExecutor::ClearCellRes(const std::string &cell_id) {
MS_LOG(DEBUG) << "Clear all cell resources";
clear_all_cell_res = true;
for (const auto &iter : top_cell_list_) {
iter->clear();
iter->Clear();
}
top_cell_list_.clear();
already_run_top_cell_.clear();
@@ -1840,7 +1854,7 @@ void GradExecutor::ClearCellRes(const std::string &cell_id) {
for (auto it = top_cell_list_.begin(); it != top_cell_list_.end();) {
auto top_cell_id = (*it)->cell_id();
if (IsCellObjIdEq(cell_id, top_cell_id)) {
(*it)->clear();
(*it)->Clear();
it = top_cell_list_.erase(it);
if (already_run_top_cell_.find(top_cell_id) != already_run_top_cell_.end()) {
(void)already_run_top_cell_.erase(top_cell_id);
@@ -2476,13 +2490,13 @@ void GradExecutor::CheckNeedCompileGraph() {
if (pre_all_op_info != new_all_op_info) {
MS_LOG(DEBUG) << "The op info has been changed, need to compile graph again";
EraseTopCellFromTopCellList(pre_top_cell);
pre_top_cell->clear();
pre_top_cell->Clear();
already_run_top_cell_[top_cell_id] = new_top_cell;
} else {
MS_LOG(DEBUG) << "The op info has not been changed, no need to compile graph again";
pre_top_cell->set_input_args_id(new_top_cell->input_args_id());
EraseTopCellFromTopCellList(new_top_cell);
new_top_cell->clear();
new_top_cell->Clear();
pre_top_cell->set_forward_already_run(true);
set_top_cell(pre_top_cell);
}
@@ -2521,6 +2535,10 @@ void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const p
grad_is_running_ = false;
MS_LOG(DEBUG) << "Eval run end " << value.ToString();
*ret = BaseRefToPyData(value);

if (GetHighOrderStackSize() == 1) {
top_cell()->ClearDeviceMemory();
}
if (top_cell()->vm_compiled()) {
MakeNestedCnode(cell, cell_id, forward_args, resource, *ret);
} else if (GetHighOrderStackSize() >= 2) {


+ 3
- 2
mindspore/ccsrc/pipeline/pynative/pynative_execute.h View File

@@ -107,7 +107,8 @@ class TopCellInfo {
const FuncGraphPtr &grad_graph) {
ms_function_grad_cache_[graph_phase] = std::make_pair(func_graph, grad_graph);
}
void clear();
void ClearDeviceMemory();
void Clear();

private:
bool is_topest_{false};
@@ -190,7 +191,7 @@ class GradExecutor {
const OpExecInfoPtr &op_exec_info, ValuePtrList *input_values,
CNodePtr *ms_function_cnode);
void UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_exec_info, const py::object &out_real);
void SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource);
void SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const;
py::object CheckGraph(const py::object &cell, const py::args &args);
void RunGradGraph(py::object *ret, const py::object &cell, const py::tuple &args);
void EraseTopCellFromTopCellList(const TopCellInfoPtr &top_cell);


Loading…
Cancel
Save