From: @zyli2020 Reviewed-by: @limingqi107 Signed-off-by: @limingqi107 tags/v1.1.0
| @@ -304,7 +304,9 @@ Status DeviceQueueOp::PushDataToGPU() { | |||
| // Data prefetch only when PS mode enables cache. | |||
| if (items.size() > 0) { | |||
| ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_); | |||
| if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_)) { | |||
| return Status(StatusCode::kTimeOut, __LINE__, __FILE__, "Failed to prefetch data."); | |||
| } | |||
| } | |||
| while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) { | |||
| BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME); | |||
| @@ -55,7 +55,9 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe | |||
| #if ENABLE_D | |||
| // Data prefetch only when PS mode enables cache. | |||
| if (items.size() > 0) { | |||
| ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_); | |||
| if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_)) { | |||
| return FAILED; | |||
| } | |||
| } | |||
| #endif | |||
| if (tdt::TdtHostPushData(channel_name, items) != 0) { | |||
| @@ -53,6 +53,7 @@ | |||
| #include "ps/util.h" | |||
| #include "ps/worker.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #endif | |||
| #if (ENABLE_GE || ENABLE_D) | |||
| @@ -1083,9 +1084,10 @@ void ClearResAtexit() { | |||
| pynative::ClearPyNativeSession(); | |||
| session::ClearPythonParasMap(); | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::Util::IsParamServerMode()) { | |||
| if (ps::Util::IsRoleOfWorker()) { | |||
| ps::worker.Finalize(); | |||
| if (ps::Util::IsParamServerMode() && ps::Util::IsRoleOfWorker()) { | |||
| ps::worker.Finalize(); | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| ps::ps_cache_instance.Finalize(); | |||
| } | |||
| } | |||
| #endif | |||
| @@ -37,155 +37,178 @@ namespace ps { | |||
| namespace ascend { | |||
| MS_REG_PS_CACHE(kAscendDevice, AscendPsCache); | |||
| namespace { | |||
| void SetProtoInputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type, | |||
| bool SetProtoInputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type, | |||
| mindspore::NodeDef *proto) { | |||
| MS_EXCEPTION_IF_NULL(proto); | |||
| MS_ERROR_IF_NULL(proto); | |||
| if (data_shape.size() != data_type.size()) { | |||
| MS_LOG(EXCEPTION) << "The size of data shape is not equal to the size of data type."; | |||
| MS_LOG(ERROR) << "The size of data shape is not equal to the size of data type."; | |||
| return false; | |||
| } | |||
| for (size_t input_index = 0; input_index < data_shape.size(); input_index++) { | |||
| ::mindspore::Tensor *proto_inputs = proto->add_inputs(); | |||
| MS_EXCEPTION_IF_NULL(proto_inputs); | |||
| MS_ERROR_IF_NULL(proto_inputs); | |||
| auto input_shape = data_shape[input_index]; | |||
| mindspore::TensorShape *tensorShape = proto_inputs->mutable_tensor_shape(); | |||
| MS_EXCEPTION_IF_NULL(tensorShape); | |||
| MS_ERROR_IF_NULL(tensorShape); | |||
| for (auto item : input_shape) { | |||
| mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); | |||
| MS_EXCEPTION_IF_NULL(dim); | |||
| MS_ERROR_IF_NULL(dim); | |||
| dim->set_size((::google::protobuf::int64)item); | |||
| } | |||
| auto input_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[input_index]); | |||
| proto_inputs->set_tensor_type(input_type); | |||
| proto_inputs->set_mem_device("HBM"); | |||
| } | |||
| return true; | |||
| } | |||
| void SetProtoOutputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type, | |||
| bool SetProtoOutputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type, | |||
| mindspore::NodeDef *proto) { | |||
| MS_EXCEPTION_IF_NULL(proto); | |||
| MS_ERROR_IF_NULL(proto); | |||
| if (data_shape.size() != data_type.size()) { | |||
| MS_LOG(EXCEPTION) << "The size of data shape is not equal to the size of data type."; | |||
| MS_LOG(ERROR) << "The size of data shape is not equal to the size of data type."; | |||
| return false; | |||
| } | |||
| for (size_t output_index = 0; output_index < data_shape.size(); output_index++) { | |||
| ::mindspore::Tensor *proto_outputs = proto->add_outputs(); | |||
| MS_EXCEPTION_IF_NULL(proto_outputs); | |||
| MS_ERROR_IF_NULL(proto_outputs); | |||
| auto output_shape = data_shape[output_index]; | |||
| mindspore::TensorShape *tensorShape = proto_outputs->mutable_tensor_shape(); | |||
| MS_EXCEPTION_IF_NULL(tensorShape); | |||
| MS_ERROR_IF_NULL(tensorShape); | |||
| for (auto item : output_shape) { | |||
| mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); | |||
| MS_EXCEPTION_IF_NULL(dim); | |||
| MS_ERROR_IF_NULL(dim); | |||
| dim->set_size((::google::protobuf::int64)item); | |||
| } | |||
| auto output_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[output_index]); | |||
| proto_outputs->set_tensor_type(output_type); | |||
| proto_outputs->set_mem_device("HBM"); | |||
| } | |||
| return true; | |||
| } | |||
| void SetNodedefProto(const std::shared_ptr<KernelNodeInfo> &op_info, | |||
| bool SetNodedefProto(const std::shared_ptr<KernelNodeInfo> &op_info, | |||
| const std::shared_ptr<kernel::AicpuOpKernelMod> &kernel_mod_ptr) { | |||
| MS_EXCEPTION_IF_NULL(op_info); | |||
| MS_EXCEPTION_IF_NULL(kernel_mod_ptr); | |||
| MS_ERROR_IF_NULL(op_info); | |||
| MS_ERROR_IF_NULL(kernel_mod_ptr); | |||
| mindspore::NodeDef proto; | |||
| proto.set_op(op_info->op_name_); | |||
| SetProtoInputs(op_info->input_data_shape_, op_info->input_data_type_, &proto); | |||
| SetProtoOutputs(op_info->output_data_shape_, op_info->output_data_type_, &proto); | |||
| RETURN_IF_FALSE(SetProtoInputs(op_info->input_data_shape_, op_info->input_data_type_, &proto)); | |||
| RETURN_IF_FALSE(SetProtoOutputs(op_info->output_data_shape_, op_info->output_data_type_, &proto)); | |||
| std::string nodeDefStr; | |||
| if (!proto.SerializeToString(&nodeDefStr)) { | |||
| MS_LOG(EXCEPTION) << "Serialize nodeDef to string failed."; | |||
| MS_LOG(ERROR) << "Serialize nodeDef to string failed."; | |||
| return false; | |||
| } | |||
| MS_LOG(DEBUG) << "Set node def proto, node name:" << op_info->op_name_; | |||
| kernel_mod_ptr->SetNodeDef(nodeDefStr); | |||
| return true; | |||
| } | |||
| } // namespace | |||
| void AscendPsCache::InitDevice(uint32_t device_id, const void *context) { | |||
| MS_EXCEPTION_IF_NULL(context); | |||
| bool AscendPsCache::InitDevice(uint32_t device_id, const void *context) { | |||
| MS_ERROR_IF_NULL(context); | |||
| auto ret = rtSetDevice(device_id); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << ret << "]"; | |||
| MS_LOG(ERROR) << "Call rtSetDevice, ret[" << ret << "]"; | |||
| return false; | |||
| } | |||
| auto rt_context = const_cast<rtContext_t>(context); | |||
| ret = rtCtxSetCurrent(rt_context); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]"; | |||
| MS_LOG(ERROR) << "Call rtCtxSetCurrent, ret[" << ret << "]"; | |||
| return false; | |||
| } | |||
| ret = rtStreamCreate(&stream_, 0); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "Call rtStreamCreate, ret[" << ret << "]"; | |||
| MS_LOG(ERROR) << "Call rtStreamCreate, ret[" << ret << "]"; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void *AscendPsCache::MallocMemory(size_t size) { | |||
| return device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(size); | |||
| } | |||
| void AscendPsCache::MallocConstantMemory(size_t constant_value) { | |||
| bool AscendPsCache::MallocConstantMemory(size_t constant_value) { | |||
| offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int))); | |||
| MS_EXCEPTION_IF_NULL(offset_addr_); | |||
| MS_ERROR_IF_NULL(offset_addr_); | |||
| rtMemset(offset_addr_, sizeof(int), 0, sizeof(int)); | |||
| cache_vocab_size_addr_ = | |||
| reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int))); | |||
| MS_EXCEPTION_IF_NULL(cache_vocab_size_addr_); | |||
| MS_ERROR_IF_NULL(cache_vocab_size_addr_); | |||
| rtMemset(cache_vocab_size_addr_, sizeof(int), constant_value, sizeof(int)); | |||
| return true; | |||
| } | |||
| void AscendPsCache::RecordEvent() { | |||
| bool AscendPsCache::RecordEvent() { | |||
| event_.reset(new rtEvent_t()); | |||
| auto ret = rtEventCreate(&(*event_)); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "Create event failed"; | |||
| MS_LOG(ERROR) << "Create event failed"; | |||
| return false; | |||
| } | |||
| ret = rtEventRecord(*event_, stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "Record event failed"; | |||
| MS_LOG(ERROR) << "Record event failed"; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void AscendPsCache::SynchronizeEvent() { | |||
| bool AscendPsCache::SynchronizeEvent() { | |||
| auto ret = rtEventSynchronize(*event_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "tEventSynchronize failed"; | |||
| MS_LOG(ERROR) << "tEventSynchronize failed"; | |||
| return false; | |||
| } | |||
| ret = rtEventDestroy(*event_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "rtEventDestroy failed"; | |||
| MS_LOG(ERROR) << "rtEventDestroy failed"; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void AscendPsCache::SynchronizeStream() { | |||
| bool AscendPsCache::SynchronizeStream() { | |||
| auto ret = rtStreamSynchronize(stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "rtStreamSynchronize failed"; | |||
| MS_LOG(ERROR) << "rtStreamSynchronize failed"; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(dst); | |||
| MS_EXCEPTION_IF_NULL(src); | |||
| bool AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||
| MS_ERROR_IF_NULL(dst); | |||
| MS_ERROR_IF_NULL(src); | |||
| auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_HOST_TO_DEVICE, stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "rtMemcpyAsync failed"; | |||
| MS_LOG(ERROR) << "rtMemcpyAsync failed"; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(dst); | |||
| MS_EXCEPTION_IF_NULL(src); | |||
| bool AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { | |||
| MS_ERROR_IF_NULL(dst); | |||
| MS_ERROR_IF_NULL(src); | |||
| auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_DEVICE_TO_HOST, stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| MS_EXCEPTION(DeviceProcessError) << "rtMemcpyAsync failed"; | |||
| MS_LOG(ERROR) << "rtMemcpyAsync failed"; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, | |||
| bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, | |||
| size_t hash_table_size, size_t embedding_size, size_t swap_out_size) { | |||
| MS_EXCEPTION_IF_NULL(hash_table_addr); | |||
| MS_EXCEPTION_IF_NULL(swap_out_value_addr); | |||
| MS_EXCEPTION_IF_NULL(swap_out_index_addr); | |||
| MS_ERROR_IF_NULL(hash_table_addr); | |||
| MS_ERROR_IF_NULL(swap_out_value_addr); | |||
| MS_ERROR_IF_NULL(swap_out_index_addr); | |||
| auto hash_swap_out_mod = std::make_shared<kernel::AicpuOpKernelMod>(); | |||
| MS_EXCEPTION_IF_NULL(hash_swap_out_mod); | |||
| MS_ERROR_IF_NULL(hash_swap_out_mod); | |||
| hash_swap_out_mod->SetNodeName(kEmbeddingLookupOpName); | |||
| std::vector<std::vector<size_t>> input_shape; | |||
| std::vector<std::vector<size_t>> output_shape; | |||
| @@ -197,7 +220,7 @@ void AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr | |||
| output_shape.push_back({swap_out_size, embedding_size}); | |||
| auto op_info = | |||
| std::make_shared<KernelNodeInfo>(kEmbeddingLookupOpName, input_shape, input_type, output_shape, output_type); | |||
| SetNodedefProto(op_info, hash_swap_out_mod); | |||
| RETURN_IF_FALSE(SetNodedefProto(op_info, hash_swap_out_mod)); | |||
| AddressPtrList kernel_inputs; | |||
| AddressPtrList kernel_outputs = { | |||
| @@ -208,17 +231,19 @@ void AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr | |||
| kernel_inputs.push_back(std::make_shared<Address>(offset_addr_, sizeof(int))); | |||
| auto ret = hash_swap_out_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); | |||
| if (!ret) { | |||
| MS_LOG(EXCEPTION) << "Hash swap out launch failed."; | |||
| MS_LOG(ERROR) << "Hash swap out launch failed."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, | |||
| bool AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, | |||
| size_t hash_table_size, size_t embedding_size, size_t swap_in_size) { | |||
| MS_EXCEPTION_IF_NULL(hash_table_addr); | |||
| MS_EXCEPTION_IF_NULL(swap_in_value_addr); | |||
| MS_EXCEPTION_IF_NULL(swap_in_index_addr); | |||
| MS_ERROR_IF_NULL(hash_table_addr); | |||
| MS_ERROR_IF_NULL(swap_in_value_addr); | |||
| MS_ERROR_IF_NULL(swap_in_index_addr); | |||
| auto hash_swap_in_mod = std::make_shared<kernel::AicpuOpKernelMod>(); | |||
| MS_EXCEPTION_IF_NULL(hash_swap_in_mod); | |||
| MS_ERROR_IF_NULL(hash_swap_in_mod); | |||
| hash_swap_in_mod->SetNodeName(kernel::kUpdateCache); | |||
| std::vector<std::vector<size_t>> input_shape; | |||
| std::vector<std::vector<size_t>> output_shape; | |||
| @@ -245,8 +270,10 @@ void AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, | |||
| kernel_outputs.push_back(std::make_shared<Address>(offset_addr_, sizeof(int))); | |||
| auto ret = hash_swap_in_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); | |||
| if (!ret) { | |||
| MS_LOG(EXCEPTION) << "Hash swap in launch failed."; | |||
| MS_LOG(ERROR) << "Hash swap in launch failed."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace ascend | |||
| } // namespace ps | |||
| @@ -49,17 +49,17 @@ class AscendPsCache : public PsCacheBasic { | |||
| public: | |||
| AscendPsCache() = default; | |||
| ~AscendPsCache() override = default; | |||
| void InitDevice(uint32_t device_id, const void *context) override; | |||
| bool InitDevice(uint32_t device_id, const void *context) override; | |||
| void *MallocMemory(size_t size) override; | |||
| void MallocConstantMemory(size_t constant_value) override; | |||
| void RecordEvent() override; | |||
| void SynchronizeEvent() override; | |||
| void SynchronizeStream() override; | |||
| void CopyHostMemToDevice(void *dst, void *src, size_t size) override; | |||
| void CopyDeviceMemToHost(void *dst, void *src, size_t size) override; | |||
| void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, | |||
| bool MallocConstantMemory(size_t constant_value) override; | |||
| bool RecordEvent() override; | |||
| bool SynchronizeEvent() override; | |||
| bool SynchronizeStream() override; | |||
| bool CopyHostMemToDevice(void *dst, void *src, size_t size) override; | |||
| bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override; | |||
| bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, | |||
| size_t embedding_size, size_t swap_out_size) override; | |||
| void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, | |||
| bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, | |||
| size_t embedding_size, size_t swap_in_size) override; | |||
| private: | |||
| @@ -25,67 +25,75 @@ namespace mindspore { | |||
| namespace ps { | |||
| namespace gpu { | |||
| MS_REG_PS_CACHE(kGPUDevice, GPUPsCache); | |||
| void GPUPsCache::InitDevice(uint32_t device_id, const void *) { | |||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaSetDevice(device_id), "Cuda set device failed") | |||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamCreate(reinterpret_cast<CUstream_st **>(&stream_)), | |||
| "Cuda create stream failed"); | |||
| bool GPUPsCache::InitDevice(uint32_t device_id, const void *) { | |||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaSetDevice(device_id), "Cuda set device failed") | |||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamCreate(reinterpret_cast<CUstream_st **>(&stream_)), | |||
| "Cuda create stream failed"); | |||
| return true; | |||
| } | |||
| void *GPUPsCache::MallocMemory(size_t size) { | |||
| return device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(size); | |||
| } | |||
| void GPUPsCache::RecordEvent() { | |||
| bool GPUPsCache::RecordEvent() { | |||
| event_.reset(new cudaEvent_t()); | |||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventCreate(&(*event_)), "Cuda create event failed"); | |||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventRecord(*event_, reinterpret_cast<cudaStream_t>(stream_)), | |||
| "Cuda record event failed"); | |||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventCreate(&(*event_)), "Cuda create event failed"); | |||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventRecord(*event_, reinterpret_cast<cudaStream_t>(stream_)), | |||
| "Cuda record event failed"); | |||
| return true; | |||
| } | |||
| void GPUPsCache::SynchronizeEvent() { | |||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventSynchronize(*event_), "Cuda sync event failed"); | |||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventDestroy(*event_), "Cuda destroy event failed"); | |||
| bool GPUPsCache::SynchronizeEvent() { | |||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventSynchronize(*event_), "Cuda sync event failed"); | |||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventDestroy(*event_), "Cuda destroy event failed"); | |||
| return true; | |||
| } | |||
| void GPUPsCache::SynchronizeStream() { | |||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_)), | |||
| "Cuda sync stream failed"); | |||
| bool GPUPsCache::SynchronizeStream() { | |||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_)), | |||
| "Cuda sync stream failed"); | |||
| return true; | |||
| } | |||
| void GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(dst); | |||
| MS_EXCEPTION_IF_NULL(src); | |||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( | |||
| bool GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||
| MS_ERROR_IF_NULL(dst); | |||
| MS_ERROR_IF_NULL(src); | |||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE( | |||
| cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_)), | |||
| "Cuda memcpy failed"); | |||
| return true; | |||
| } | |||
| void GPUPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { | |||
| MS_EXCEPTION_IF_NULL(dst); | |||
| MS_EXCEPTION_IF_NULL(src); | |||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( | |||
| bool GPUPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { | |||
| MS_ERROR_IF_NULL(dst); | |||
| MS_ERROR_IF_NULL(src); | |||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE( | |||
| cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_)), | |||
| "Cuda memcpy failed"); | |||
| return true; | |||
| } | |||
| void GPUPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t, | |||
| bool GPUPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t, | |||
| size_t embedding_size, size_t swap_out_size) { | |||
| MS_EXCEPTION_IF_NULL(hash_table_addr); | |||
| MS_EXCEPTION_IF_NULL(swap_out_value_addr); | |||
| MS_EXCEPTION_IF_NULL(swap_out_index_addr); | |||
| MS_ERROR_IF_NULL(hash_table_addr); | |||
| MS_ERROR_IF_NULL(swap_out_value_addr); | |||
| MS_ERROR_IF_NULL(swap_out_index_addr); | |||
| DoHashSwapOut(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_out_value_addr), | |||
| reinterpret_cast<int *>(swap_out_index_addr), swap_out_size, embedding_size, | |||
| reinterpret_cast<cudaStream_t>(stream_)); | |||
| return true; | |||
| } | |||
| void GPUPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t, | |||
| bool GPUPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t, | |||
| size_t embedding_size, size_t swap_in_size) { | |||
| MS_EXCEPTION_IF_NULL(hash_table_addr); | |||
| MS_EXCEPTION_IF_NULL(swap_in_value_addr); | |||
| MS_EXCEPTION_IF_NULL(swap_in_index_addr); | |||
| MS_ERROR_IF_NULL(hash_table_addr); | |||
| MS_ERROR_IF_NULL(swap_in_value_addr); | |||
| MS_ERROR_IF_NULL(swap_in_index_addr); | |||
| DoHashSwapIn(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_in_value_addr), | |||
| reinterpret_cast<int *>(swap_in_index_addr), swap_in_size, embedding_size, | |||
| reinterpret_cast<cudaStream_t>(stream_)); | |||
| return true; | |||
| } | |||
| } // namespace gpu | |||
| } // namespace ps | |||
| @@ -28,16 +28,16 @@ class GPUPsCache : public PsCacheBasic { | |||
| public: | |||
| GPUPsCache() = default; | |||
| ~GPUPsCache() override = default; | |||
| void InitDevice(uint32_t device_id, const void *context) override; | |||
| bool InitDevice(uint32_t device_id, const void *context) override; | |||
| void *MallocMemory(size_t size) override; | |||
| void RecordEvent() override; | |||
| void SynchronizeEvent() override; | |||
| void SynchronizeStream() override; | |||
| void CopyHostMemToDevice(void *dst, void *src, size_t size) override; | |||
| void CopyDeviceMemToHost(void *dst, void *src, size_t size) override; | |||
| void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, | |||
| bool RecordEvent() override; | |||
| bool SynchronizeEvent() override; | |||
| bool SynchronizeStream() override; | |||
| bool CopyHostMemToDevice(void *dst, void *src, size_t size) override; | |||
| bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override; | |||
| bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, | |||
| size_t embedding_size, size_t swap_out_size) override; | |||
| void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, | |||
| bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, | |||
| size_t embedding_size, size_t swap_in_size) override; | |||
| private: | |||
| @@ -21,21 +21,28 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| #define RETURN_IF_FALSE(condition) \ | |||
| do { \ | |||
| if (!(condition)) { \ | |||
| return false; \ | |||
| } \ | |||
| } while (false) | |||
| class PsCacheBasic { | |||
| public: | |||
| PsCacheBasic() = default; | |||
| virtual ~PsCacheBasic() = default; | |||
| virtual void InitDevice(uint32_t device_id, const void *context) = 0; | |||
| virtual bool InitDevice(uint32_t device_id, const void *context) = 0; | |||
| virtual void *MallocMemory(size_t size) = 0; | |||
| virtual void MallocConstantMemory(size_t constant_value) {} | |||
| virtual void RecordEvent() = 0; | |||
| virtual void SynchronizeEvent() = 0; | |||
| virtual void SynchronizeStream() = 0; | |||
| virtual void CopyHostMemToDevice(void *dst, void *src, size_t size) = 0; | |||
| virtual void CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0; | |||
| virtual void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, | |||
| virtual bool MallocConstantMemory(size_t constant_value) { return true; } | |||
| virtual bool RecordEvent() = 0; | |||
| virtual bool SynchronizeEvent() = 0; | |||
| virtual bool SynchronizeStream() = 0; | |||
| virtual bool CopyHostMemToDevice(void *dst, void *src, size_t size) = 0; | |||
| virtual bool CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0; | |||
| virtual bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, | |||
| size_t hash_table_size, size_t embedding_size, size_t swap_out_size) = 0; | |||
| virtual void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, | |||
| virtual bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, | |||
| size_t hash_table_size, size_t embedding_size, size_t swap_in_size) = 0; | |||
| protected: | |||
| @@ -170,8 +170,10 @@ void PsCacheManager::AddEmbeddingTable() const { | |||
| void PsCacheManager::InitParameterServer() { | |||
| MS_LOG(INFO) << "Embedding table init begin:" << finish_insert_init_info_; | |||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||
| insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true; }); | |||
| insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true || running_ == false; }); | |||
| if (!running_) { | |||
| return; | |||
| } | |||
| for (const auto &item : hash_tables_) { | |||
| const auto &param_name = item.first; | |||
| size_t key = worker.SetParamKey(param_name); | |||
| @@ -224,7 +226,9 @@ void PsCacheManager::AllocMemForHashTable() { | |||
| embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast<float *>( | |||
| embedding_device_cache_->cache_->MallocMemory(max_embedding_size * batch_elements_ * sizeof(float))); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_); | |||
| embedding_device_cache_->cache_->MallocConstantMemory(cache_vocab_size_); | |||
| if (!(embedding_device_cache_->cache_->MallocConstantMemory(cache_vocab_size_))) { | |||
| MS_LOG(EXCEPTION) << "MallocConstantMemory failed."; | |||
| } | |||
| } | |||
| void PsCacheManager::SetLocalIdRank() { | |||
| @@ -250,19 +254,25 @@ void PsCacheManager::set_channel_name(const std::string channel_name) { | |||
| channel_name_ = channel_name; | |||
| } | |||
| void PsCacheManager::IncreaseStep() { | |||
| bool PsCacheManager::IncreaseStep() { | |||
| if (data_step_ >= UINT64_MAX) { | |||
| MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t."; | |||
| MS_LOG(ERROR) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t."; | |||
| return false; | |||
| } | |||
| data_step_++; | |||
| set_current_graph_step(); | |||
| if (graph_running_step_ > data_step_) { | |||
| MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") exceed the data step (" << data_step_ | |||
| << ")."; | |||
| MS_LOG(ERROR) << "The graph running step (" << graph_running_step_ << ") exceed the data step (" << data_step_ | |||
| << ")."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { | |||
| if (terminated_) { | |||
| MS_LOG(EXCEPTION) << "ps cache data process thread is terminated."; | |||
| } | |||
| if (graph_step_ >= UINT64_MAX) { | |||
| MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t."; | |||
| } | |||
| @@ -274,7 +284,9 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { | |||
| } | |||
| graph_step_++; | |||
| set_channel_name(channel_name); | |||
| PsDataPrefetch::GetInstance().TryWakeChannel(channel_name); | |||
| if (!PsDataPrefetch::GetInstance().TryWakeChannel(channel_name)) { | |||
| MS_LOG(EXCEPTION) << "TryWakeChannel failed, channel name: " << channel_name; | |||
| } | |||
| data_prase_.notify_one(); | |||
| } | |||
| @@ -284,74 +296,99 @@ void PsCacheManager::DoProcessData(uint32_t device_id, void *context) { | |||
| MS_LOG(EXCEPTION) << "Only the sink_mode of dataset supports embeddingLookup cache in parameter server training " | |||
| "mode, current dataset mode is not sink_mode."; | |||
| } | |||
| auto process_data_thread = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context); | |||
| process_data_thread.detach(); | |||
| process_data_thread_ = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context); | |||
| } | |||
| void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) { | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| embedding_device_cache_->cache_->InitDevice(device_id, context); | |||
| running_ = true; | |||
| bool ret = true; | |||
| InitParameterServer(); | |||
| while (true) { | |||
| ProcessData(); | |||
| while (ret) { | |||
| if (!running_) { | |||
| break; | |||
| } | |||
| ret = ProcessData(); | |||
| } | |||
| if (!ret) { | |||
| terminated_ = true; | |||
| } | |||
| } | |||
| void PsCacheManager::ProcessData() { | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| void PsCacheManager::Finalize() { | |||
| if (running_) { | |||
| running_ = false; | |||
| } | |||
| PsDataPrefetch::GetInstance().NotifyFinalize(); | |||
| insert_init_info_.notify_all(); | |||
| data_prase_.notify_all(); | |||
| if (process_data_thread_.joinable()) { | |||
| process_data_thread_.join(); | |||
| } | |||
| } | |||
| bool PsCacheManager::ProcessData() { | |||
| struct timeval start_time, end_time; | |||
| const uint64_t kUSecondInSecond = 1000000; | |||
| (void)gettimeofday(&start_time, nullptr); | |||
| auto channel = channel_name(); | |||
| if (channel.empty()) { | |||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||
| data_prase_.wait(locker, [this] { return !channel_name_.empty(); }); | |||
| data_prase_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; }); | |||
| if (!running_) { | |||
| return false; | |||
| } | |||
| } | |||
| auto data = PsDataPrefetch::GetInstance().data(channel_name_); | |||
| if (data == nullptr) { | |||
| MS_LOG(INFO) << "No data process, channel name:" << channel_name_; | |||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||
| (void)data_prase_.wait_for(locker, std::chrono::milliseconds(100)); | |||
| return; | |||
| return true; | |||
| } | |||
| IncreaseStep(); | |||
| RETURN_IF_FALSE(IncreaseStep()); | |||
| auto data_size = PsDataPrefetch::GetInstance().data_size(channel_name_); | |||
| if (data_size == 0) { | |||
| MS_LOG(ERROR) << "The data_size can not be zero."; | |||
| return false; | |||
| } | |||
| auto batch_ids = reinterpret_cast<int *>(data); | |||
| auto batch_ids_len = data_size / sizeof(int); | |||
| std::unique_ptr<int[]> hash_index(new int[batch_ids_len]); | |||
| if (memset_s(&statistics_info_, sizeof(statistics_info_), 0, sizeof(statistics_info_))) { | |||
| MS_LOG(EXCEPTION) << "Process data memset failed."; | |||
| MS_LOG(ERROR) << "Process data memset failed."; | |||
| return false; | |||
| } | |||
| // Get hash swap in/out index and ids. | |||
| ParseData(batch_ids, batch_ids_len, hash_index.get()); | |||
| RETURN_IF_FALSE(ParseData(batch_ids, batch_ids_len, hash_index.get())); | |||
| for (const auto &item : hash_tables_) { | |||
| auto key = worker.GetParamKey(item.first); | |||
| auto hash_info = item.second; | |||
| HashSwapHostToServer(key, hash_info); | |||
| HashSwapDeviceToHost(hash_info); | |||
| HashSwapServerToHost(key, hash_info); | |||
| HashSwapHostToDevice(hash_info); | |||
| RETURN_IF_FALSE(HashSwapHostToServer(key, hash_info)); | |||
| RETURN_IF_FALSE(HashSwapDeviceToHost(hash_info)); | |||
| RETURN_IF_FALSE(HashSwapServerToHost(key, hash_info)); | |||
| RETURN_IF_FALSE(HashSwapHostToDevice(hash_info)); | |||
| } | |||
| // Replace the batch_ids by hash index for getNext-op getting hash index as input. | |||
| if (memcpy_s(data, data_size, hash_index.get(), data_size) != EOK) { | |||
| MS_LOG(EXCEPTION) << "Process data memcpy failed."; | |||
| MS_LOG(ERROR) << "Process data memcpy failed."; | |||
| return false; | |||
| } | |||
| embedding_device_cache_->cache_->SynchronizeStream(); | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream()); | |||
| // Finish the data process and notify data prefetch. | |||
| PsDataPrefetch::GetInstance().FinalizeData(channel_name_); | |||
| RETURN_IF_FALSE(PsDataPrefetch::GetInstance().FinalizeData(channel_name_)); | |||
| (void)gettimeofday(&end_time, nullptr); | |||
| uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec); | |||
| cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec); | |||
| MS_LOG(DEBUG) << "Ps cache completes processing data(data step:" << data_step_ | |||
| << ",graph step:" << graph_running_step_ << " channel name:" << channel_name_ | |||
| << ", time cost:" << cost / 1000 << "ms)."; | |||
| return true; | |||
| } | |||
| void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) { | |||
| MS_EXCEPTION_IF_NULL(batch_ids); | |||
| MS_EXCEPTION_IF_NULL(hash_index); | |||
| bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) { | |||
| MS_ERROR_IF_NULL(batch_ids); | |||
| MS_ERROR_IF_NULL(hash_index); | |||
| for (size_t i = 0; i < batch_ids_len; i++) { | |||
| bool need_swap_host_to_device = true; | |||
| bool need_swap_device_to_host = true; | |||
| @@ -360,12 +397,16 @@ void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, | |||
| hash_index[i] = -1; | |||
| continue; | |||
| } | |||
| hash_index[i] = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device); | |||
| auto index = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device); | |||
| if (index == INVALID_INDEX_VALUE) { | |||
| return false; | |||
| } | |||
| hash_index[i] = index; | |||
| if (need_swap_host_to_device) { | |||
| ParseHostDataHostToDevice(id); | |||
| RETURN_IF_FALSE(ParseHostDataHostToDevice(id)); | |||
| } | |||
| if (need_swap_device_to_host) { | |||
| ParseHostDataDeviceToHost(id); | |||
| RETURN_IF_FALSE(ParseHostDataDeviceToHost(id)); | |||
| } | |||
| } | |||
| // Each 1000 step prints ps cache hit rate. | |||
| @@ -374,33 +415,28 @@ void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, | |||
| auto hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_; | |||
| MS_LOG(INFO) << "Ps cache hit rate: " << hit_rate * 100 << "%."; | |||
| } | |||
| return true; | |||
| } | |||
| void PsCacheManager::WaitGraphRun() { | |||
| bool PsCacheManager::WaitGraphRun() { | |||
| MS_LOG(INFO) << "Hash table has no space to insert new data and retries within 2 minutes."; | |||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||
| if (!data_prase_.wait_for(locker, std::chrono::seconds(120), [this] { return graph_step_ > graph_running_step_; })) { | |||
| MS_LOG(EXCEPTION) << "Ps cache data parse timeout, suggest to enlarge the cache size(graph step:" << graph_step_ | |||
| << ", graph running step:" << graph_running_step_ << ")."; | |||
| MS_LOG(ERROR) << "Ps cache data parse timeout, suggest to enlarge the cache size(graph step:" << graph_step_ | |||
| << ", graph running step:" << graph_running_step_ << ")."; | |||
| return false; | |||
| } | |||
| set_current_graph_step(); | |||
| return true; | |||
| } | |||
| int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device) { | |||
| MS_EXCEPTION_IF_NULL(need_swap_device_to_host); | |||
| MS_EXCEPTION_IF_NULL(need_swap_host_to_device); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| int *device_to_host_index = embedding_device_cache_->device_to_host_index.get(); | |||
| int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); | |||
| int *host_to_device_index = embedding_device_cache_->host_to_device_index.get(); | |||
| int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get(); | |||
| MS_EXCEPTION_IF_NULL(device_to_host_index); | |||
| MS_EXCEPTION_IF_NULL(device_to_host_ids); | |||
| MS_EXCEPTION_IF_NULL(host_to_device_index); | |||
| MS_EXCEPTION_IF_NULL(host_to_device_ids); | |||
| auto device_hash_map = embedding_device_cache_->device_hash_map_; | |||
| MS_EXCEPTION_IF_NULL(device_hash_map); | |||
| int index = 0; | |||
| auto iter = device_hash_map->id_iter(id); | |||
| if (device_hash_map->IsIdExist(iter)) { | |||
| @@ -417,7 +453,9 @@ int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, b | |||
| index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_, | |||
| &(statistics_info_.device_to_host_size_)); | |||
| if (index == INVALID_INDEX_VALUE) { | |||
| WaitGraphRun(); | |||
| if (!WaitGraphRun()) { | |||
| return INVALID_INDEX_VALUE; | |||
| } | |||
| continue; | |||
| } | |||
| host_to_device_index[statistics_info_.host_to_device_size_] = index; | |||
| @@ -430,21 +468,20 @@ int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, b | |||
| return index; | |||
| } | |||
| void PsCacheManager::ParseHostDataHostToDevice(size_t id) { | |||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||
| bool PsCacheManager::ParseHostDataHostToDevice(size_t id) { | |||
| int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | |||
| int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | |||
| int *server_to_host_index = embedding_host_cache_->server_to_host_index.get(); | |||
| int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); | |||
| int *host_to_device_index = embedding_host_cache_->host_to_device_index.get(); | |||
| MS_EXCEPTION_IF_NULL(host_to_server_index); | |||
| MS_EXCEPTION_IF_NULL(host_to_server_ids); | |||
| MS_EXCEPTION_IF_NULL(server_to_host_index); | |||
| MS_EXCEPTION_IF_NULL(server_to_host_ids); | |||
| MS_EXCEPTION_IF_NULL(host_to_device_index); | |||
| MS_ERROR_IF_NULL(host_to_server_index); | |||
| MS_ERROR_IF_NULL(host_to_server_ids); | |||
| MS_ERROR_IF_NULL(server_to_host_index); | |||
| MS_ERROR_IF_NULL(server_to_host_ids); | |||
| MS_ERROR_IF_NULL(host_to_device_index); | |||
| auto host_hash_map = embedding_host_cache_->host_hash_map_; | |||
| MS_EXCEPTION_IF_NULL(host_hash_map); | |||
| MS_ERROR_IF_NULL(host_hash_map); | |||
| auto iter = host_hash_map->id_iter(id); | |||
| if (host_hash_map->IsIdExist(iter)) { | |||
| auto index = iter->second; | |||
| @@ -457,7 +494,7 @@ void PsCacheManager::ParseHostDataHostToDevice(size_t id) { | |||
| auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, | |||
| graph_running_step_, &statistics_info_.host_to_server_size_); | |||
| if (index == INVALID_INDEX_VALUE) { | |||
| WaitGraphRun(); | |||
| RETURN_IF_FALSE(WaitGraphRun()); | |||
| continue; | |||
| } | |||
| host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index; | |||
| @@ -466,22 +503,21 @@ void PsCacheManager::ParseHostDataHostToDevice(size_t id) { | |||
| break; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| void PsCacheManager::ParseHostDataDeviceToHost(size_t id) { | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||
| bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) { | |||
| int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); | |||
| int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | |||
| int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | |||
| int *device_to_host_index = embedding_host_cache_->device_to_host_index.get(); | |||
| MS_EXCEPTION_IF_NULL(device_to_host_ids); | |||
| MS_EXCEPTION_IF_NULL(host_to_server_index); | |||
| MS_EXCEPTION_IF_NULL(host_to_server_ids); | |||
| MS_EXCEPTION_IF_NULL(device_to_host_index); | |||
| MS_ERROR_IF_NULL(device_to_host_ids); | |||
| MS_ERROR_IF_NULL(host_to_server_index); | |||
| MS_ERROR_IF_NULL(host_to_server_ids); | |||
| MS_ERROR_IF_NULL(device_to_host_index); | |||
| auto host_hash_map = embedding_host_cache_->host_hash_map_; | |||
| MS_EXCEPTION_IF_NULL(host_hash_map); | |||
| MS_ERROR_IF_NULL(host_hash_map); | |||
| int swap_device_to_host_id = device_to_host_ids[statistics_info_.device_to_host_size_ - 1]; | |||
| auto iter = host_hash_map->id_iter(swap_device_to_host_id); | |||
| if (host_hash_map->IsIdExist(iter)) { | |||
| @@ -495,13 +531,14 @@ void PsCacheManager::ParseHostDataDeviceToHost(size_t id) { | |||
| auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, | |||
| graph_running_step_, &statistics_info_.host_to_server_size_); | |||
| if (index == INVALID_INDEX_VALUE) { | |||
| WaitGraphRun(); | |||
| RETURN_IF_FALSE(WaitGraphRun()); | |||
| continue; | |||
| } | |||
| device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index; | |||
| break; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, | |||
| @@ -514,19 +551,21 @@ void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, | |||
| size_t pos = index * outer_dim_size; | |||
| auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; | |||
| MS_LOG(ERROR) << "LookUpTable task memcpy failed."; | |||
| terminated_ = true; | |||
| } | |||
| } else { | |||
| auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; | |||
| MS_LOG(ERROR) << "LookUpTable task memset failed."; | |||
| terminated_ = true; | |||
| } | |||
| } | |||
| output_addr += outer_dim_size; | |||
| } | |||
| } | |||
| void PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | |||
| bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | |||
| const int *indices_addr, float *output_addr) { | |||
| size_t first_dim_size = host_cache_vocab_size_; | |||
| size_t outer_dim_size = embedding_size; | |||
| @@ -553,9 +592,10 @@ void PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l | |||
| for (size_t j = 0; j < i; j++) { | |||
| threads[j].join(); | |||
| } | |||
| return !terminated_; | |||
| } | |||
| void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, | |||
| bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, | |||
| float *insert_data, float *hash_table_addr) { | |||
| size_t first_dim_size = host_cache_vocab_size_; | |||
| size_t thread_num = insert_indices_size / 10000 + 1; | |||
| @@ -565,8 +605,8 @@ void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in | |||
| size_t i; | |||
| size_t task_offset = 0; | |||
| auto insert_hash_table_task = [](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size, | |||
| int *insert_indices, float *insert_data, float *hash_table_addr) { | |||
| auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size, | |||
| int *insert_indices, float *insert_data, float *hash_table_addr) { | |||
| auto type_size = sizeof(float); | |||
| size_t lens = outer_dim_size * type_size; | |||
| for (size_t i = 0; i < insert_indices_size; ++i) { | |||
| @@ -574,7 +614,8 @@ void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in | |||
| if (index >= 0 && index < SizeToInt(first_dim_size)) { | |||
| auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "Insert hash table task memcpy failed."; | |||
| MS_LOG(ERROR) << "Insert hash table task memcpy failed."; | |||
| terminated_ = true; | |||
| } | |||
| } | |||
| } | |||
| @@ -596,94 +637,101 @@ void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in | |||
| for (size_t j = 0; j < i; j++) { | |||
| threads[j].join(); | |||
| } | |||
| return !terminated_; | |||
| } | |||
| void PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) { | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||
| bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) { | |||
| MS_ERROR_IF_NULL(embedding_device_cache_); | |||
| MS_ERROR_IF_NULL(embedding_device_cache_->cache_); | |||
| MS_ERROR_IF_NULL(embedding_host_cache_); | |||
| auto host_cache_host_to_device_index = embedding_host_cache_->host_to_device_index.get(); | |||
| auto device_cache_host_to_device_index = embedding_device_cache_->host_to_device_index.get(); | |||
| auto swap_indices_size = statistics_info_.host_to_device_size_; | |||
| if (swap_indices_size == 0) { | |||
| return; | |||
| return true; | |||
| } | |||
| auto embedding_size = hash_info.embedding_size; | |||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | |||
| auto hash_table_size = hash_info.device_address.size; | |||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | |||
| auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size); | |||
| LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_cache_host_to_device_index, | |||
| swap_out_data.get()); | |||
| embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_, | |||
| swap_out_data.get(), | |||
| swap_indices_size * embedding_size * sizeof(float)); | |||
| embedding_device_cache_->cache_->CopyHostMemToDevice( | |||
| embedding_device_cache_->hash_swap_index_addr_, device_cache_host_to_device_index, swap_indices_size * sizeof(int)); | |||
| embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, | |||
| embedding_device_cache_->hash_swap_index_addr_, hash_table_size, | |||
| embedding_size, swap_indices_size); | |||
| } | |||
| void PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) { | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||
| RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, | |||
| host_cache_host_to_device_index, swap_out_data.get())); | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice( | |||
| embedding_device_cache_->hash_swap_value_addr_, swap_out_data.get(), | |||
| swap_indices_size * embedding_size * sizeof(float))); | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, | |||
| device_cache_host_to_device_index, | |||
| swap_indices_size * sizeof(int))); | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn( | |||
| hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, | |||
| hash_table_size, embedding_size, swap_indices_size)); | |||
| return true; | |||
| } | |||
| bool PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) { | |||
| MS_ERROR_IF_NULL(embedding_device_cache_); | |||
| MS_ERROR_IF_NULL(embedding_device_cache_->cache_); | |||
| MS_ERROR_IF_NULL(embedding_host_cache_); | |||
| auto swap_indices_size = statistics_info_.device_to_host_size_; | |||
| auto device_cache_device_to_host_index = embedding_device_cache_->device_to_host_index.get(); | |||
| auto host_cache_device_to_host_index = embedding_host_cache_->device_to_host_index.get(); | |||
| if (swap_indices_size == 0) { | |||
| return; | |||
| return true; | |||
| } | |||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | |||
| auto hash_table_size = hash_info.device_address.size; | |||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | |||
| auto embedding_size = hash_info.embedding_size; | |||
| auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size); | |||
| embedding_device_cache_->cache_->CopyHostMemToDevice( | |||
| embedding_device_cache_->hash_swap_index_addr_, device_cache_device_to_host_index, swap_indices_size * sizeof(int)); | |||
| embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, | |||
| embedding_device_cache_->hash_swap_index_addr_, hash_table_size, | |||
| embedding_size, swap_indices_size); | |||
| embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data.get(), | |||
| embedding_device_cache_->hash_swap_value_addr_, | |||
| swap_indices_size * embedding_size * sizeof(float)); | |||
| embedding_device_cache_->cache_->SynchronizeStream(); | |||
| InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index, | |||
| swap_out_data.get(), host_hash_table_addr); | |||
| } | |||
| void PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_info) { | |||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, | |||
| device_cache_device_to_host_index, | |||
| swap_indices_size * sizeof(int))); | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut( | |||
| hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, | |||
| hash_table_size, embedding_size, swap_indices_size)); | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost( | |||
| swap_out_data.get(), embedding_device_cache_->hash_swap_value_addr_, | |||
| swap_indices_size * embedding_size * sizeof(float))); | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream()); | |||
| RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index, | |||
| swap_out_data.get(), host_hash_table_addr)); | |||
| return true; | |||
| } | |||
| bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_info) { | |||
| MS_ERROR_IF_NULL(embedding_host_cache_); | |||
| auto host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | |||
| auto host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | |||
| auto swap_indices_size = statistics_info_.host_to_server_size_; | |||
| if (swap_indices_size == 0) { | |||
| return; | |||
| return true; | |||
| } | |||
| ::ps::SArray<int> lookup_ids(swap_indices_size, 0); | |||
| ::ps::SArray<float> swap_out_data; | |||
| auto embedding_size = hash_info.embedding_size; | |||
| swap_out_data.resize(swap_indices_size * embedding_size); | |||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | |||
| LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index, | |||
| swap_out_data.data()); | |||
| RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index, | |||
| swap_out_data.data())); | |||
| auto copy_len = swap_indices_size * sizeof(int); | |||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids, copy_len); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||
| return false; | |||
| } | |||
| worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | |||
| return true; | |||
| } | |||
| void PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_info) { | |||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||
| bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_info) { | |||
| MS_ERROR_IF_NULL(embedding_host_cache_); | |||
| auto swap_indices_size = statistics_info_.server_to_host_size_; | |||
| auto server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); | |||
| auto server_to_host_index = embedding_host_cache_->server_to_host_index.get(); | |||
| if (swap_indices_size == 0) { | |||
| return; | |||
| return true; | |||
| } | |||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | |||
| auto embedding_size = hash_info.embedding_size; | |||
| @@ -693,47 +741,50 @@ void PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_ | |||
| auto copy_len = swap_indices_size * sizeof(int); | |||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||
| return false; | |||
| } | |||
| worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | |||
| InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index, lookup_result.data(), | |||
| host_hash_table_addr); | |||
| RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index, | |||
| lookup_result.data(), host_hash_table_addr)); | |||
| return true; | |||
| } | |||
| void PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, | |||
| bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, | |||
| const HashTableInfo &hash_info) { | |||
| MS_EXCEPTION_IF_NULL(swap_out_index); | |||
| MS_EXCEPTION_IF_NULL(swap_out_data); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| MS_ERROR_IF_NULL(swap_out_index); | |||
| MS_ERROR_IF_NULL(swap_out_data); | |||
| MS_ERROR_IF_NULL(embedding_device_cache_); | |||
| MS_ERROR_IF_NULL(embedding_device_cache_->cache_); | |||
| auto swap_out_index_size = statistics_info_.device_to_host_size_; | |||
| if (swap_out_index_size == 0) { | |||
| return; | |||
| return true; | |||
| } | |||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | |||
| auto hash_table_size = hash_info.device_address.size; | |||
| auto embedding_size = hash_info.embedding_size; | |||
| swap_out_data->resize(swap_out_index_size * embedding_size); | |||
| embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_out_index, | |||
| swap_out_index_size * sizeof(int)); | |||
| embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, | |||
| embedding_device_cache_->hash_swap_index_addr_, hash_table_size, | |||
| embedding_size, swap_out_index_size); | |||
| embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data->data(), | |||
| embedding_device_cache_->hash_swap_value_addr_, | |||
| swap_out_index_size * embedding_size * sizeof(float)); | |||
| embedding_device_cache_->cache_->RecordEvent(); | |||
| } | |||
| void PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice( | |||
| embedding_device_cache_->hash_swap_index_addr_, swap_out_index, swap_out_index_size * sizeof(int))); | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut( | |||
| hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, | |||
| hash_table_size, embedding_size, swap_out_index_size)); | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost( | |||
| swap_out_data->data(), embedding_device_cache_->hash_swap_value_addr_, | |||
| swap_out_index_size * embedding_size * sizeof(float))); | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->RecordEvent()); | |||
| return true; | |||
| } | |||
| bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, | |||
| size_t key) { | |||
| MS_EXCEPTION_IF_NULL(swap_in_ids); | |||
| MS_EXCEPTION_IF_NULL(swap_in_index); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| MS_ERROR_IF_NULL(swap_in_ids); | |||
| MS_ERROR_IF_NULL(swap_in_index); | |||
| MS_ERROR_IF_NULL(embedding_device_cache_); | |||
| MS_ERROR_IF_NULL(embedding_device_cache_->cache_); | |||
| auto swap_in_ids_size = statistics_info_.host_to_device_size_; | |||
| if (swap_in_ids_size == 0) { | |||
| return; | |||
| return true; | |||
| } | |||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | |||
| auto hash_table_size = hash_info.device_address.size; | |||
| @@ -745,42 +796,44 @@ void PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons | |||
| auto copy_len = swap_in_ids_size * sizeof(int); | |||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||
| return false; | |||
| } | |||
| worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | |||
| // Hash swap-in in device. | |||
| embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_, | |||
| lookup_result.data(), | |||
| swap_in_ids_size * embedding_size * sizeof(float)); | |||
| embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_in_index, | |||
| swap_in_ids_size * sizeof(int)); | |||
| embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, | |||
| embedding_device_cache_->hash_swap_index_addr_, hash_table_size, | |||
| embedding_size, swap_in_ids_size); | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice( | |||
| embedding_device_cache_->hash_swap_value_addr_, lookup_result.data(), | |||
| swap_in_ids_size * embedding_size * sizeof(float))); | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, | |||
| swap_in_index, swap_in_ids_size * sizeof(int))); | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn( | |||
| hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, | |||
| hash_table_size, embedding_size, swap_in_ids_size)); | |||
| return true; | |||
| } | |||
| void PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key) { | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| MS_EXCEPTION_IF_NULL(swap_out_ids); | |||
| bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key) { | |||
| MS_ERROR_IF_NULL(embedding_device_cache_); | |||
| MS_ERROR_IF_NULL(embedding_device_cache_->cache_); | |||
| MS_ERROR_IF_NULL(swap_out_ids); | |||
| auto swap_out_ids_size = statistics_info_.device_to_host_size_; | |||
| if (swap_out_ids_size == 0) { | |||
| return; | |||
| return true; | |||
| } | |||
| ::ps::SArray<int> lookup_ids(swap_out_ids_size, 0); | |||
| auto copy_len = swap_out_ids_size * sizeof(int); | |||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||
| return false; | |||
| } | |||
| // Need synchronize event to ensure that the swap-out in device is completed. | |||
| embedding_device_cache_->cache_->SynchronizeEvent(); | |||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeEvent()); | |||
| worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | |||
| return true; | |||
| } | |||
| void PsCacheManager::DumpHashTables(bool dump_device_tables) const { | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||
| for (const auto &item : hash_tables_) { | |||
| const auto ¶m_name = item.first; | |||
| size_t cache_vocab_size = item.second.cache_vocab_size; | |||
| @@ -126,6 +126,8 @@ class PsCacheManager { | |||
| bool initialized_ps_cache() const { return initialized_ps_cache_; } | |||
| void DoProcessData(uint32_t device_id, void *context); | |||
| void IncreaseGraphStep(const std::string &channel_name); | |||
| bool terminated() const { return terminated_; } | |||
| void Finalize(); | |||
| void DumpHashTables(bool dump_device_tables = false) const; | |||
| private: | |||
| @@ -133,7 +135,7 @@ class PsCacheManager { | |||
| ~PsCacheManager() = default; | |||
| PsCacheManager(const PsCacheManager &) = delete; | |||
| PsCacheManager &operator=(const PsCacheManager &) = delete; | |||
| void IncreaseStep(); | |||
| bool IncreaseStep(); | |||
| void set_current_graph_step() { graph_running_step_ = graph_step_; } | |||
| std::string channel_name(); | |||
| void set_channel_name(const std::string channel_name); | |||
| @@ -141,23 +143,23 @@ class PsCacheManager { | |||
| void AllocMemForHashTable(); | |||
| void SetLocalIdRank(); | |||
| void ProcessDataTask(uint32_t device_id, void *context); | |||
| void ProcessData(); | |||
| void ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index); | |||
| void WaitGraphRun(); | |||
| bool ProcessData(); | |||
| bool ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index); | |||
| bool WaitGraphRun(); | |||
| int ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device); | |||
| void ParseHostDataHostToDevice(size_t id); | |||
| void ParseHostDataDeviceToHost(size_t id); | |||
| void HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info); | |||
| void HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key); | |||
| void HashSwapHostToDevice(const HashTableInfo &hash_info); | |||
| void HashSwapDeviceToHost(const HashTableInfo &hash_info); | |||
| void HashSwapHostToServer(size_t key, const HashTableInfo &hash_info); | |||
| void HashSwapServerToHost(size_t key, const HashTableInfo &hash_info); | |||
| void InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, float *insert_data, | |||
| bool ParseHostDataHostToDevice(size_t id); | |||
| bool ParseHostDataDeviceToHost(size_t id); | |||
| bool HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info); | |||
| bool HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key); | |||
| bool HashSwapHostToDevice(const HashTableInfo &hash_info); | |||
| bool HashSwapDeviceToHost(const HashTableInfo &hash_info); | |||
| bool HashSwapHostToServer(size_t key, const HashTableInfo &hash_info); | |||
| bool HashSwapServerToHost(size_t key, const HashTableInfo &hash_info); | |||
| bool InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, float *insert_data, | |||
| float *hash_table_addr); | |||
| void LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | |||
| bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | |||
| const int *indices_addr, float *output_addr); | |||
| void UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key); | |||
| bool UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key); | |||
| void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr, | |||
| const int *indices_addr, float *output_addr); | |||
| bool CheckFinishInsertInitInfo() const; | |||
| @@ -172,6 +174,7 @@ class PsCacheManager { | |||
| std::mutex data_mutex_; | |||
| std::condition_variable data_prase_; | |||
| std::condition_variable insert_init_info_; | |||
| std::thread process_data_thread_; | |||
| std::map<std::string, HashTableInfo> hash_tables_; | |||
| std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_; | |||
| @@ -185,6 +188,8 @@ class PsCacheManager { | |||
| std::pair<size_t, size_t> range_bound_; | |||
| std::atomic_bool finish_insert_init_info_{false}; | |||
| std::atomic_bool finish_init_parameter_server_{false}; | |||
| std::atomic_bool running_{false}; | |||
| std::atomic_bool terminated_{false}; | |||
| }; | |||
| static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance(); | |||
| @@ -28,11 +28,9 @@ void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t s | |||
| if (iter != ps_data_channel_map_.end()) { | |||
| MS_LOG(WARNING) << "The ps data channel already exists, channel name:" << channel_name; | |||
| auto channel = iter->second; | |||
| MS_EXCEPTION_IF_NULL(channel); | |||
| channel->set_step_num(step_num); | |||
| } else { | |||
| auto channel = std::make_shared<PsDataChannel>(channel_name, step_num); | |||
| MS_EXCEPTION_IF_NULL(channel); | |||
| (void)ps_data_channel_map_.emplace(channel_name, channel); | |||
| } | |||
| } | |||
| @@ -40,71 +38,95 @@ void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t s | |||
| std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string &channel_name) const { | |||
| auto iter = ps_data_channel_map_.find(channel_name); | |||
| if (iter == ps_data_channel_map_.end()) { | |||
| MS_LOG(EXCEPTION) << "The ps data channel does not exist, channel name:" << channel_name; | |||
| MS_LOG(ERROR) << "The ps data channel does not exist, channel name:" << channel_name; | |||
| return nullptr; | |||
| } | |||
| return iter->second; | |||
| } | |||
| void PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) { | |||
| bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) { | |||
| if (cache_enable_ == false) { | |||
| return; | |||
| return true; | |||
| } | |||
| if (data == nullptr) { | |||
| MS_LOG(WARNING) << "No data prefetch."; | |||
| return; | |||
| return true; | |||
| } | |||
| auto channel = ps_data_channel(channel_name); | |||
| MS_EXCEPTION_IF_NULL(channel); | |||
| MS_ERROR_IF_NULL(channel); | |||
| channel->set_data(data, data_size); | |||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||
| data_ready_ = true; | |||
| data_process_.notify_one(); | |||
| if (!need_wait_) { | |||
| return true; | |||
| } | |||
| for (int i = 0; i < 10; i++) { | |||
| if (data_prefetch_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == false; })) { | |||
| return; | |||
| if (data_prefetch_.wait_for(locker, std::chrono::seconds(30), | |||
| [this] { return data_ready_ == false || need_wait_ == false; })) { | |||
| return true; | |||
| } else { | |||
| MS_LOG(INFO) << "Waiting for ps data process, channel name:" << channel_name << "...(" << i << " / 10)"; | |||
| } | |||
| } | |||
| MS_LOG(EXCEPTION) << "Ps cache data process timeout, suggest to enlarge the cache size."; | |||
| MS_LOG(ERROR) << "Ps cache data process timeout, suggest to enlarge the cache size."; | |||
| return false; | |||
| } | |||
| void PsDataPrefetch::FinalizeData(const std::string &channel_name) { | |||
| bool PsDataPrefetch::FinalizeData(const std::string &channel_name) { | |||
| if (cache_enable_ == false) { | |||
| return; | |||
| return true; | |||
| } | |||
| auto channel = ps_data_channel(channel_name); | |||
| MS_EXCEPTION_IF_NULL(channel); | |||
| MS_ERROR_IF_NULL(channel); | |||
| channel->ResetData(); | |||
| std::unique_lock<std::mutex> locker(data_mutex_); | |||
| data_ready_ = false; | |||
| data_prefetch_.notify_one(); | |||
| if (!need_wait_) { | |||
| return true; | |||
| } | |||
| for (int i = 0; i < 10; i++) { | |||
| if (data_process_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == true; })) { | |||
| return; | |||
| if (data_process_.wait_for(locker, std::chrono::seconds(30), | |||
| [this] { return data_ready_ == true || need_wait_ == false; })) { | |||
| return true; | |||
| } else { | |||
| MS_LOG(INFO) << "Waiting for ps data prefetch, channel name:" << channel_name << "...(" << i << " / 10)"; | |||
| } | |||
| } | |||
| MS_LOG(EXCEPTION) << "Ps cache data prefetch timeout."; | |||
| MS_LOG(ERROR) << "Ps cache data prefetch timeout."; | |||
| return false; | |||
| } | |||
| void *PsDataPrefetch::data(const std::string &channel_name) const { | |||
| auto channel = ps_data_channel(channel_name); | |||
| MS_EXCEPTION_IF_NULL(channel); | |||
| if (channel == nullptr) { | |||
| return nullptr; | |||
| } | |||
| return channel->data(); | |||
| } | |||
| size_t PsDataPrefetch::data_size(const std::string &channel_name) const { | |||
| auto channel = ps_data_channel(channel_name); | |||
| MS_EXCEPTION_IF_NULL(channel); | |||
| if (channel == nullptr) { | |||
| return 0; | |||
| } | |||
| return channel->data_size(); | |||
| } | |||
| void PsDataPrefetch::TryWakeChannel(const std::string &channel_name) { | |||
| void PsDataPrefetch::NotifyFinalize() { | |||
| need_wait_ = false; | |||
| data_prefetch_.notify_one(); | |||
| data_process_.notify_one(); | |||
| } | |||
| bool PsDataPrefetch::TryWakeChannel(const std::string &channel_name) { | |||
| auto channel = ps_data_channel(channel_name); | |||
| MS_EXCEPTION_IF_NULL(channel); | |||
| if (channel == nullptr) { | |||
| return false; | |||
| } | |||
| channel->TryWakeChannel(); | |||
| return true; | |||
| } | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -19,6 +19,7 @@ | |||
| #include <map> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <atomic> | |||
| #include <condition_variable> | |||
| #include "ps/ps_cache/ps_data/ps_data_channel.h" | |||
| @@ -36,11 +37,12 @@ class EXPORT PsDataPrefetch { | |||
| EXPORT bool cache_enable() const { return cache_enable_; } | |||
| EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; } | |||
| EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num); | |||
| EXPORT void PrefetchData(const std::string &channel_name, void *data, const size_t data_size); | |||
| EXPORT void FinalizeData(const std::string &channel_name); | |||
| EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size); | |||
| EXPORT bool FinalizeData(const std::string &channel_name); | |||
| EXPORT void NotifyFinalize(); | |||
| EXPORT void *data(const std::string &channel_name) const; | |||
| EXPORT size_t data_size(const std::string &channel_name) const; | |||
| EXPORT void TryWakeChannel(const std::string &channel_name); | |||
| EXPORT bool TryWakeChannel(const std::string &channel_name); | |||
| private: | |||
| PsDataPrefetch() : cache_enable_(false), data_ready_(false) {} | |||
| @@ -54,6 +56,7 @@ class EXPORT PsDataPrefetch { | |||
| std::mutex data_mutex_; | |||
| std::condition_variable data_prefetch_; | |||
| std::condition_variable data_process_; | |||
| std::atomic_bool need_wait_{true}; | |||
| }; | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -17,10 +17,10 @@ | |||
| #include "ps/ps_context.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #include "backend/kernel_compiler/kernel.h" | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| #include "ps/ps_cache/ps_cache_manager.h" | |||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||
| #endif | |||
| namespace mindspore { | |||
| @@ -62,7 +62,12 @@ void PSContext::Reset() { | |||
| is_worker_ = false; | |||
| is_pserver_ = false; | |||
| is_sched_ = false; | |||
| set_cache_enable(false); | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| ps_cache_instance.Finalize(); | |||
| set_cache_enable(false); | |||
| } | |||
| #endif | |||
| } | |||
| std::string PSContext::ms_role() const { | |||
| @@ -62,6 +62,16 @@ namespace gpu { | |||
| } \ | |||
| } | |||
| #define CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(expression, message) \ | |||
| { \ | |||
| cudaError_t status = (expression); \ | |||
| if (status != cudaSuccess) { \ | |||
| MS_LOG(ERROR) << "CUDA Error: " << message << " | Error Number: " << status << " " \ | |||
| << cudaGetErrorString(status); \ | |||
| return false; \ | |||
| } \ | |||
| } | |||
| #define CHECK_CUDA_RET_WITH_EXCEPT(node, expression, message) \ | |||
| { \ | |||
| cudaError_t status = (expression); \ | |||
| @@ -199,6 +199,14 @@ class LogWriter { | |||
| } \ | |||
| } while (0) | |||
| #define MS_ERROR_IF_NULL(ptr) \ | |||
| do { \ | |||
| if ((ptr) == nullptr) { \ | |||
| MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \ | |||
| return false; \ | |||
| } \ | |||
| } while (0) | |||
| #ifdef DEBUG | |||
| #include <cassert> | |||
| #define MS_ASSERT(f) assert(f) | |||