|
|
|
@@ -625,7 +625,10 @@ void UpdateTensorInfo(const tensor::TensorPtr &new_tensor, const std::vector<ten |
|
|
|
pre_tensor->set_data_type(new_tensor->data_type()); |
|
|
|
if (device_target != kCPUDevice) { |
|
|
|
pre_tensor->set_device_address(new_tensor->device_address()); |
|
|
|
} else { |
|
|
|
continue; |
|
|
|
} |
|
|
|
// Replace data in device address when run in CPU device. |
|
|
|
if (pre_tensor->device_address() != nullptr) { |
|
|
|
auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(pre_tensor->device_address()); |
|
|
|
auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address()); |
|
|
|
auto old_ptr = old_device_address->GetMutablePtr(); |
|
|
|
@@ -708,7 +711,24 @@ bool TopCellInfo::IsSubCell(const std::string &cell_id) const { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
void TopCellInfo::clear() { |
|
|
|
void TopCellInfo::ClearDeviceMemory() { |
|
|
|
MS_LOG(DEBUG) << "Clear device memory in value nodes of bprop graph, top cell: " << cell_id_; |
|
|
|
auto device_target = MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET); |
|
|
|
if (device_target == kCPUDevice) { |
|
|
|
MS_LOG(DEBUG) << "No need to clear device address when run in CPU device."; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
k_pynative_cell_ptr_ = nullptr; |
|
|
|
for (const auto &elem : tensor_id_with_tensor_object_) { |
|
|
|
std::for_each(elem.second.begin(), elem.second.end(), [](const tensor::TensorPtr &tensor) { |
|
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
|
tensor->set_device_address(nullptr); |
|
|
|
}); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void TopCellInfo::Clear() { |
|
|
|
MS_LOG(DEBUG) << "Clear top cell info. Cell id " << cell_id_; |
|
|
|
op_num_ = 0; |
|
|
|
is_dynamic_ = false; |
|
|
|
@@ -1491,21 +1511,11 @@ void GradExecutor::UpdateForwardTensorInfoInBpropGraph(const OpExecInfoPtr &op_e |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) { |
|
|
|
void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr &resource) const { |
|
|
|
MS_EXCEPTION_IF_NULL(resource); |
|
|
|
// Get all tensors id belong to forward op |
|
|
|
std::unordered_set<std::string> forward_op_tensor_id; |
|
|
|
const auto &op_info_with_tensor_id = top_cell()->op_info_with_tensor_id(); |
|
|
|
for (const auto &e : op_info_with_tensor_id) { |
|
|
|
std::for_each(e.second.begin(), e.second.end(), |
|
|
|
[&forward_op_tensor_id](const std::string &tensor_id) { forward_op_tensor_id.emplace(tensor_id); }); |
|
|
|
} |
|
|
|
auto &tensor_id_with_tensor_object_ = top_cell()->tensor_id_with_tensor_object(); |
|
|
|
if (!tensor_id_with_tensor_object_.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "When compile a new graph, the map tensor_id_with_tensor_object should be empty. Top cell " |
|
|
|
<< top_cell()->cell_id(); |
|
|
|
} |
|
|
|
// Get all tensors obj in value node of bprop graph |
|
|
|
const auto &bprop_graph = resource->func_graph(); |
|
|
|
MS_EXCEPTION_IF_NULL(bprop_graph); |
|
|
|
const auto &value_node_list = bprop_graph->value_nodes(); |
|
|
|
std::vector<tensor::TensorPtr> tensors_in_bprop_graph; |
|
|
|
for (const auto &elem : value_node_list) { |
|
|
|
@@ -1513,14 +1523,18 @@ void GradExecutor::SaveForwardTensorInfoInBpropGraph(const pipeline::ResourcePtr |
|
|
|
MS_EXCEPTION_IF_NULL(value_node); |
|
|
|
TensorValueToTensor(value_node->value(), &tensors_in_bprop_graph); |
|
|
|
} |
|
|
|
// Save tensor info in bprop graph |
|
|
|
|
|
|
|
auto &tensor_id_with_tensor_object = top_cell()->tensor_id_with_tensor_object(); |
|
|
|
if (!tensor_id_with_tensor_object.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "When compile a top graph, the tensor_id_with_tensor_object map should be empty. Top cell: " |
|
|
|
<< top_cell()->cell_id(); |
|
|
|
} |
|
|
|
// Save tensor in value node of bprop graph |
|
|
|
for (const auto &tensor : tensors_in_bprop_graph) { |
|
|
|
if (tensor->device_address() == nullptr || forward_op_tensor_id.find(tensor->id()) == forward_op_tensor_id.end()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
tensor_id_with_tensor_object_[tensor->id()].emplace_back(tensor); |
|
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
|
tensor_id_with_tensor_object[tensor->id()].emplace_back(tensor); |
|
|
|
MS_LOG(DEBUG) << "Save forward tensor " << tensor.get() << " id " << tensor->id() |
|
|
|
<< " device address: " << tensor->device_address()->GetMutablePtr() << " shape and dtype " |
|
|
|
<< " device address: " << tensor->device_address() << " shape and dtype " |
|
|
|
<< tensor->GetShapeAndDataTypeInfo(); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -1825,7 +1839,7 @@ void GradExecutor::ClearCellRes(const std::string &cell_id) { |
|
|
|
MS_LOG(DEBUG) << "Clear all cell resources"; |
|
|
|
clear_all_cell_res = true; |
|
|
|
for (const auto &iter : top_cell_list_) { |
|
|
|
iter->clear(); |
|
|
|
iter->Clear(); |
|
|
|
} |
|
|
|
top_cell_list_.clear(); |
|
|
|
already_run_top_cell_.clear(); |
|
|
|
@@ -1840,7 +1854,7 @@ void GradExecutor::ClearCellRes(const std::string &cell_id) { |
|
|
|
for (auto it = top_cell_list_.begin(); it != top_cell_list_.end();) { |
|
|
|
auto top_cell_id = (*it)->cell_id(); |
|
|
|
if (IsCellObjIdEq(cell_id, top_cell_id)) { |
|
|
|
(*it)->clear(); |
|
|
|
(*it)->Clear(); |
|
|
|
it = top_cell_list_.erase(it); |
|
|
|
if (already_run_top_cell_.find(top_cell_id) != already_run_top_cell_.end()) { |
|
|
|
(void)already_run_top_cell_.erase(top_cell_id); |
|
|
|
@@ -2476,13 +2490,13 @@ void GradExecutor::CheckNeedCompileGraph() { |
|
|
|
if (pre_all_op_info != new_all_op_info) { |
|
|
|
MS_LOG(DEBUG) << "The op info has been changed, need to compile graph again"; |
|
|
|
EraseTopCellFromTopCellList(pre_top_cell); |
|
|
|
pre_top_cell->clear(); |
|
|
|
pre_top_cell->Clear(); |
|
|
|
already_run_top_cell_[top_cell_id] = new_top_cell; |
|
|
|
} else { |
|
|
|
MS_LOG(DEBUG) << "The op info has not been changed, no need to compile graph again"; |
|
|
|
pre_top_cell->set_input_args_id(new_top_cell->input_args_id()); |
|
|
|
EraseTopCellFromTopCellList(new_top_cell); |
|
|
|
new_top_cell->clear(); |
|
|
|
new_top_cell->Clear(); |
|
|
|
pre_top_cell->set_forward_already_run(true); |
|
|
|
set_top_cell(pre_top_cell); |
|
|
|
} |
|
|
|
@@ -2521,6 +2535,10 @@ void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const p |
|
|
|
grad_is_running_ = false; |
|
|
|
MS_LOG(DEBUG) << "Eval run end " << value.ToString(); |
|
|
|
*ret = BaseRefToPyData(value); |
|
|
|
|
|
|
|
if (GetHighOrderStackSize() == 1) { |
|
|
|
top_cell()->ClearDeviceMemory(); |
|
|
|
} |
|
|
|
if (top_cell()->vm_compiled()) { |
|
|
|
MakeNestedCnode(cell, cell_id, forward_args, resource, *ret); |
|
|
|
} else if (GetHighOrderStackSize() >= 2) { |
|
|
|
|