|
|
|
@@ -49,18 +49,23 @@ void *TensorRTAllocator::MallocDeviceMem(const std::string &name, size_t size, D |
|
|
|
if (cuda_tensor_map_[name].data != nullptr) { |
|
|
|
cuda_ret = cudaFree(cuda_tensor_map_[name].data); |
|
|
|
if (cuda_ret != cudaSuccess && cuda_ret != cudaErrorCudartUnloading) { |
|
|
|
MS_LOG(ERROR) << "free cuda failed for " << cudaGetErrorName(cuda_ret); |
|
|
|
MS_LOG(ERROR) << "free old cuda device_ptr failed for " << cudaGetErrorName(cuda_ret); |
|
|
|
cuda_ret = cudaFree(device_ptr); |
|
|
|
if (cuda_ret != cudaSuccess) { |
|
|
|
MS_LOG(ERROR) << "free new cuda device_ptr failed for " << cudaGetErrorName(cuda_ret); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
} |
|
|
|
cuda_tensor_map_[name].data = device_ptr; |
|
|
|
cuda_tensor_map_[name].isValidMem = false; |
|
|
|
cuda_tensor_map_[name].is_valid_mem = false; |
|
|
|
cuda_tensor_map_[name].size = size; |
|
|
|
return device_ptr; |
|
|
|
} |
|
|
|
|
|
|
|
// Stores `isValid` as the validity flag of the cuda_tensor_map_ entry keyed by `name`.
//
// name:    key used to index cuda_tensor_map_ via operator[].
//          NOTE(review): if cuda_tensor_map_ is a std::map-like container, an unknown
//          name would insert a default-constructed entry here -- confirm that is intended.
// isValid: value written into the entry's validity flag.
//
// NOTE(review): two assignments appear below, one to `isValidMem` and one to
// `is_valid_mem`. This looks like a rendered diff showing the old and the renamed
// field side by side -- confirm which member the struct actually declares.
void TensorRTAllocator::MarkMemValid(const std::string &name, bool isValid) { |
|
|
|
cuda_tensor_map_[name].isValidMem = isValid; |
|
|
|
cuda_tensor_map_[name].is_valid_mem = isValid; |
|
|
|
return; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -83,8 +88,8 @@ int TensorRTAllocator::SyncMemInHostAndDevice(mindspore::MSTensor host_tensor, c |
|
|
|
} |
|
|
|
CudaTensorParam &current_cuda_tensor = cuda_tensor_map_.find(device_tensor_name)->second; |
|
|
|
// is memcpy from device to host, the host mem is valid, change tag for mem pool. |
|
|
|
current_cuda_tensor.isValidMem = is_host2device ? current_cuda_tensor.isValidMem : true; |
|
|
|
if (is_host2device && current_cuda_tensor.isValidMem) { |
|
|
|
current_cuda_tensor.is_valid_mem = is_host2device ? current_cuda_tensor.is_valid_mem : true; |
|
|
|
if (is_host2device && current_cuda_tensor.is_valid_mem) { |
|
|
|
MS_LOG(INFO) << "no need memcpy for: " << device_tensor_name; |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
@@ -108,7 +113,7 @@ int TensorRTAllocator::ClearDeviceMem() { |
|
|
|
MS_LOG(WARNING) << "free cuda failed for " << cudaGetErrorName(cuda_ret); |
|
|
|
} |
|
|
|
iter.second.data = nullptr; |
|
|
|
iter.second.isValidMem = false; |
|
|
|
iter.second.is_valid_mem = false; |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|