|
|
|
@@ -705,52 +705,6 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { |
|
|
|
MS_LOG(INFO) << "AssignStaticMemoryValueNode end"; |
|
|
|
} |
|
|
|
|
|
|
|
void KernelRuntime::SyncValueNodeDeviceAddr(session::KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_LOG(INFO) << "SyncValueNodeDeviceAddr start"; |
|
|
|
for (auto &value_node : graph->graph_value_nodes()) { |
|
|
|
MS_EXCEPTION_IF_NULL(value_node); |
|
|
|
auto &node_value = value_node->value(); |
|
|
|
MS_EXCEPTION_IF_NULL(node_value); |
|
|
|
if (!node_value->isa<Tensor>() && !node_value->isa<ValueTuple>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
std::vector<tensor::TensorPtr> tensors; |
|
|
|
TensorValueToTensor(node_value, &tensors); |
|
|
|
for (size_t index = 0; index < tensors.size(); index += 1) { |
|
|
|
const auto &tensor = tensors[index]; |
|
|
|
if (tensor->device_address() != nullptr) { |
|
|
|
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), index, |
|
|
|
value_node.get()); |
|
|
|
} else { |
|
|
|
MS_LOG(INFO) << "Tensor of ValueNode[" << value_node->fullname_with_scope() << "]'s device address is nullptr."; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "SyncValueNodeDeviceAddr end"; |
|
|
|
} |
|
|
|
|
|
|
|
void KernelRuntime::CleanValueNodeDeviceAddr(session::KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_LOG(INFO) << "CleanValueNodeDeviceAddr start"; |
|
|
|
for (auto &value_node : graph->graph_value_nodes()) { |
|
|
|
MS_EXCEPTION_IF_NULL(value_node); |
|
|
|
auto &node_value = value_node->value(); |
|
|
|
MS_EXCEPTION_IF_NULL(node_value); |
|
|
|
if (!node_value->isa<Tensor>() && !node_value->isa<ValueTuple>()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
std::vector<tensor::TensorPtr> tensors; |
|
|
|
TensorValueToTensor(node_value, &tensors); |
|
|
|
for (size_t index = 0; index < tensors.size(); index += 1) { |
|
|
|
if (tensors[index]->device_address() != nullptr) { |
|
|
|
AnfAlgo::SetOutputAddr(nullptr, index, value_node.get()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "CleanValueNodeDeviceAddr end"; |
|
|
|
} |
|
|
|
|
|
|
|
void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
|
MS_EXCEPTION_IF_NULL(mem_manager_); |
|
|
|
|