|
|
|
@@ -458,7 +458,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode |
|
|
|
auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type); |
|
|
|
MS_EXCEPTION_IF_NULL(address); |
|
|
|
if (output_ptr == nullptr) { |
|
|
|
output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address); |
|
|
|
output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address, true); |
|
|
|
MS_EXCEPTION_IF_NULL(output_ptr); |
|
|
|
} else { |
|
|
|
address->set_ptr(output_ptr); |
|
|
|
@@ -515,8 +515,17 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP |
|
|
|
MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s input."; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, addr_size[0].first); |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(cnode); |
|
|
|
if (cnode->inputs().size() < 2) { |
|
|
|
// communication node's input should contain itself and at least on input |
|
|
|
MS_LOG(ERROR) << "No inputs for " << cnode->fullname_with_scope(); |
|
|
|
return; |
|
|
|
} |
|
|
|
auto first_input_node = cnode->input(1); |
|
|
|
auto prenode_index = AnfAlgo::VisitKernelWithReturnType(first_input_node, 0, true); |
|
|
|
uint8_t *input_ptr = mem_manager_->MallocOutputMem(prenode_index.first, prenode_index.second, type, total_size, |
|
|
|
addr_size[0].first, true); |
|
|
|
for (const auto &iter : addr_size) { |
|
|
|
MS_EXCEPTION_IF_NULL(iter.first); |
|
|
|
iter.first->set_ptr(input_ptr); |
|
|
|
@@ -568,7 +577,7 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in |
|
|
|
auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i); |
|
|
|
auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); |
|
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
|
uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address); |
|
|
|
uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address, false); |
|
|
|
MS_EXCEPTION_IF_NULL(ptr); |
|
|
|
device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i)); |
|
|
|
AnfAlgo::SetOutputAddr(device_address, i, node.get()); |
|
|
|
|