|
|
|
@@ -523,8 +523,9 @@ PynativeExecutor::~PynativeExecutor() { |
|
|
|
py::tuple RunOp(const py::args &args) { |
|
|
|
auto executor = PynativeExecutor::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(executor); |
|
|
|
MS_LOG(DEBUG) << "RunOp start " << args.size(); |
|
|
|
OpExecInfoPtr op_exec_info = executor->GenerateOpExecInfo(args); |
|
|
|
MS_EXCEPTION_IF_NULL(op_exec_info); |
|
|
|
MS_LOG(DEBUG) << "RunOp name: " << op_exec_info->op_name << " start, args: " << args.size(); |
|
|
|
try { |
|
|
|
return executor->RunOpInner(op_exec_info); |
|
|
|
} catch (const py::error_already_set &ex) { |
|
|
|
@@ -708,6 +709,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v |
|
|
|
} |
|
|
|
node_abs_map_[id] = abs; |
|
|
|
} |
|
|
|
|
|
|
|
(*args_spec_list).emplace_back(abs); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1014,6 +1016,10 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex |
|
|
|
auto &new_tensor = output_tensors[i]; |
|
|
|
auto &tensors_in_value_node = tensor_id_with_tensor_[tensor_id]; |
|
|
|
std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) { |
|
|
|
MS_LOG(DEBUG) << "Debug address: Replace forward old tensor obj " << tensor.get() << ", tensor id " |
|
|
|
<< tensor->id() << ", device address " << tensor->device_address().get() |
|
|
|
<< " with New tensor obj " << new_tensor.get() << ", tensor id " << new_tensor->id() |
|
|
|
<< ", device address " << new_tensor->device_address().get(); |
|
|
|
tensor->set_shape(new_tensor->shape()); |
|
|
|
tensor->set_data_type(new_tensor->data_type()); |
|
|
|
if (target != kCPUDevice) { |
|
|
|
@@ -1048,6 +1054,8 @@ void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) { |
|
|
|
for (const auto &tensor : tensors) { |
|
|
|
if (tensor->device_address() != nullptr) { |
|
|
|
tensor_id_with_tensor_[tensor->id()].emplace_back(tensor); |
|
|
|
MS_LOG(DEBUG) << "Debug address: Save forward tensor obj " << tensor.get() << ", tensor id " << tensor->id() |
|
|
|
<< ", device address " << tensor->device_address().get(); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|