|
|
|
@@ -48,6 +48,13 @@ bool MemCpyAsyncKernel::Launch(const std::vector<AddressPtr> &inputs, const std: |
|
|
|
MS_LOG(INFO) << "input addr is same with output addr , no need exe memcpy async"; |
|
|
|
return true; |
|
|
|
} |
|
|
|
if (outputs[0]->size < inputs[0]->size) { |
|
|
|
MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size"; |
|
|
|
} |
|
|
|
// input x -> memcpy_async -> AllReduce |
|
|
|
if (outputs[0]->size > inputs[0]->size) { |
|
|
|
MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size"; |
|
|
|
} |
|
|
|
rtError_t status = rtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size, |
|
|
|
RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr); |
|
|
|
if (status != RT_ERROR_NONE) { |
|
|
|
@@ -70,7 +77,7 @@ void MemCpyAsyncKernel::GetInputOutputDataType(const AnfNodePtr &anf_node) { |
|
|
|
if (input_size != 1) { |
|
|
|
MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1"; |
|
|
|
} |
|
|
|
input_type_id_ = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, 0); |
|
|
|
input_type_id_ = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, 0); |
|
|
|
} |
|
|
|
|
|
|
|
void MemCpyAsyncKernel::GetInputOutputTotalCount(const AnfNodePtr &anf_node) { |
|
|
|
@@ -102,6 +109,14 @@ std::vector<TaskInfoPtr> MemCpyAsyncKernel::GenTask(const std::vector<AddressPtr |
|
|
|
MS_LOG(EXCEPTION) << "MemCpyAsync op output is not one"; |
|
|
|
} |
|
|
|
|
|
|
|
if (outputs[0]->size < inputs[0]->size) { |
|
|
|
MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size"; |
|
|
|
} |
|
|
|
// input x -> memcpy_async -> AllReduce |
|
|
|
if (outputs[0]->size > inputs[0]->size) { |
|
|
|
MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size"; |
|
|
|
} |
|
|
|
|
|
|
|
stream_id_ = stream_id; |
|
|
|
std::shared_ptr<MemcpyAsyncTaskInfo> task_info_ptr = std::make_shared<MemcpyAsyncTaskInfo>( |
|
|
|
stream_id, outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size, RT_MEMCPY_DEVICE_TO_DEVICE); |
|
|
|
|