|
|
|
@@ -101,12 +101,22 @@ void SyncMemory(void *dst, const void *src, uint64_t size, rtMemcpyKind_t kind) |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(ms_context); |
|
|
|
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID); |
|
|
|
auto execution_mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE); |
|
|
|
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); |
|
|
|
MS_EXCEPTION_IF_NULL(runtime_instance); |
|
|
|
runtime_instance->SetContext(); |
|
|
|
auto ret_rt_memcpy = rtMemcpy(dst, size, src, size, kind); |
|
|
|
if (ret_rt_memcpy != RT_ERROR_NONE) { |
|
|
|
MS_EXCEPTION(DeviceProcessError) << "rtMemcpy failed"; |
|
|
|
|
|
|
|
// Only apply asynchronous copy in Pynative && RT_MEMCPY_HOST_TO_DEVICE mode |
|
|
|
if (execution_mode != kPynativeMode || kind != RT_MEMCPY_HOST_TO_DEVICE) { |
|
|
|
auto ret_rt_memcpy = rtMemcpy(dst, size, src, size, kind); |
|
|
|
if (ret_rt_memcpy != RT_ERROR_NONE) { |
|
|
|
MS_EXCEPTION(DeviceProcessError) << "rtMemcpy failed"; |
|
|
|
} |
|
|
|
} else { |
|
|
|
auto ret = runtime_instance->MemcpyAsync(dst, src, size, static_cast<int32_t>(kind)); |
|
|
|
if (!ret) { |
|
|
|
MS_EXCEPTION(DeviceProcessError) << "MemcpyAsync failed"; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -527,7 +537,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size |
|
|
|
if (type_id_ > kMonadTypeBegin && type_id_ < kMonadTypeEnd) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
SyncStream(); |
|
|
|
|
|
|
|
bool sync_ok = false; |
|
|
|
std::vector<size_t> host_shape; |
|
|
|
(void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), LongToSize); |
|
|
|
|