|
|
|
@@ -179,16 +179,16 @@ void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const |
|
|
|
if (AnfAlgo::IsControlOpExecInBackend(kernel)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
auto output_size = AnfAlgo::GetOutputTensorNum(kernel); |
|
|
|
for (size_t i = 0; i < output_size; ++i) { |
|
|
|
auto kernel_mod = AnfAlgo::GetKernelMod(kernel); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_mod); |
|
|
|
auto output_sizes = kernel_mod->GetOutputSizeList(); |
|
|
|
for (size_t i = 0; i < output_sizes.size(); ++i) { |
|
|
|
if (AnfAlgo::OutputAddrExist(kernel, i)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto output_format = AnfAlgo::GetOutputFormat(kernel, i); |
|
|
|
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); |
|
|
|
auto address_size = AnfAlgo::GetOutputTensorMemSize(kernel, i); |
|
|
|
auto device_address = device_context->CreateDeviceAddress(nullptr, address_size, output_format, output_type); |
|
|
|
auto device_address = device_context->CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); |
|
|
|
MS_LOG(DEBUG) << "Create addr for node:" << AnfAlgo::GetNodeDebugString(kernel) << " addr:" << device_address; |
|
|
|
AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); |
|
|
|
} |
|
|
|
|