From: @zyli2020 Reviewed-by: @limingqi107 Signed-off-by: @limingqi107tags/v1.1.0
| @@ -304,7 +304,9 @@ Status DeviceQueueOp::PushDataToGPU() { | |||||
| // Data prefetch only when PS mode enables cache. | // Data prefetch only when PS mode enables cache. | ||||
| if (items.size() > 0) { | if (items.size() > 0) { | ||||
| ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_); | |||||
| if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_)) { | |||||
| return Status(StatusCode::kTimeOut, __LINE__, __FILE__, "Failed to prefetch data."); | |||||
| } | |||||
| } | } | ||||
| while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) { | while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) { | ||||
| BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME); | BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME); | ||||
| @@ -55,7 +55,9 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe | |||||
| #if ENABLE_D | #if ENABLE_D | ||||
| // Data prefetch only when PS mode enables cache. | // Data prefetch only when PS mode enables cache. | ||||
| if (items.size() > 0) { | if (items.size() > 0) { | ||||
| ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_); | |||||
| if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_)) { | |||||
| return FAILED; | |||||
| } | |||||
| } | } | ||||
| #endif | #endif | ||||
| if (tdt::TdtHostPushData(channel_name, items) != 0) { | if (tdt::TdtHostPushData(channel_name, items) != 0) { | ||||
| @@ -53,6 +53,7 @@ | |||||
| #include "ps/util.h" | #include "ps/util.h" | ||||
| #include "ps/worker.h" | #include "ps/worker.h" | ||||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | ||||
| #include "ps/ps_cache/ps_cache_manager.h" | |||||
| #endif | #endif | ||||
| #if (ENABLE_GE || ENABLE_D) | #if (ENABLE_GE || ENABLE_D) | ||||
| @@ -1083,9 +1084,10 @@ void ClearResAtexit() { | |||||
| pynative::ClearPyNativeSession(); | pynative::ClearPyNativeSession(); | ||||
| session::ClearPythonParasMap(); | session::ClearPythonParasMap(); | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | ||||
| if (ps::Util::IsParamServerMode()) { | |||||
| if (ps::Util::IsRoleOfWorker()) { | |||||
| ps::worker.Finalize(); | |||||
| if (ps::Util::IsParamServerMode() && ps::Util::IsRoleOfWorker()) { | |||||
| ps::worker.Finalize(); | |||||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||||
| ps::ps_cache_instance.Finalize(); | |||||
| } | } | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -37,155 +37,178 @@ namespace ps { | |||||
| namespace ascend { | namespace ascend { | ||||
| MS_REG_PS_CACHE(kAscendDevice, AscendPsCache); | MS_REG_PS_CACHE(kAscendDevice, AscendPsCache); | ||||
| namespace { | namespace { | ||||
| void SetProtoInputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type, | |||||
| bool SetProtoInputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type, | |||||
| mindspore::NodeDef *proto) { | mindspore::NodeDef *proto) { | ||||
| MS_EXCEPTION_IF_NULL(proto); | |||||
| MS_ERROR_IF_NULL(proto); | |||||
| if (data_shape.size() != data_type.size()) { | if (data_shape.size() != data_type.size()) { | ||||
| MS_LOG(EXCEPTION) << "The size of data shape is not equal to the size of data type."; | |||||
| MS_LOG(ERROR) << "The size of data shape is not equal to the size of data type."; | |||||
| return false; | |||||
| } | } | ||||
| for (size_t input_index = 0; input_index < data_shape.size(); input_index++) { | for (size_t input_index = 0; input_index < data_shape.size(); input_index++) { | ||||
| ::mindspore::Tensor *proto_inputs = proto->add_inputs(); | ::mindspore::Tensor *proto_inputs = proto->add_inputs(); | ||||
| MS_EXCEPTION_IF_NULL(proto_inputs); | |||||
| MS_ERROR_IF_NULL(proto_inputs); | |||||
| auto input_shape = data_shape[input_index]; | auto input_shape = data_shape[input_index]; | ||||
| mindspore::TensorShape *tensorShape = proto_inputs->mutable_tensor_shape(); | mindspore::TensorShape *tensorShape = proto_inputs->mutable_tensor_shape(); | ||||
| MS_EXCEPTION_IF_NULL(tensorShape); | |||||
| MS_ERROR_IF_NULL(tensorShape); | |||||
| for (auto item : input_shape) { | for (auto item : input_shape) { | ||||
| mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); | mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); | ||||
| MS_EXCEPTION_IF_NULL(dim); | |||||
| MS_ERROR_IF_NULL(dim); | |||||
| dim->set_size((::google::protobuf::int64)item); | dim->set_size((::google::protobuf::int64)item); | ||||
| } | } | ||||
| auto input_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[input_index]); | auto input_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[input_index]); | ||||
| proto_inputs->set_tensor_type(input_type); | proto_inputs->set_tensor_type(input_type); | ||||
| proto_inputs->set_mem_device("HBM"); | proto_inputs->set_mem_device("HBM"); | ||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| void SetProtoOutputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type, | |||||
| bool SetProtoOutputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type, | |||||
| mindspore::NodeDef *proto) { | mindspore::NodeDef *proto) { | ||||
| MS_EXCEPTION_IF_NULL(proto); | |||||
| MS_ERROR_IF_NULL(proto); | |||||
| if (data_shape.size() != data_type.size()) { | if (data_shape.size() != data_type.size()) { | ||||
| MS_LOG(EXCEPTION) << "The size of data shape is not equal to the size of data type."; | |||||
| MS_LOG(ERROR) << "The size of data shape is not equal to the size of data type."; | |||||
| return false; | |||||
| } | } | ||||
| for (size_t output_index = 0; output_index < data_shape.size(); output_index++) { | for (size_t output_index = 0; output_index < data_shape.size(); output_index++) { | ||||
| ::mindspore::Tensor *proto_outputs = proto->add_outputs(); | ::mindspore::Tensor *proto_outputs = proto->add_outputs(); | ||||
| MS_EXCEPTION_IF_NULL(proto_outputs); | |||||
| MS_ERROR_IF_NULL(proto_outputs); | |||||
| auto output_shape = data_shape[output_index]; | auto output_shape = data_shape[output_index]; | ||||
| mindspore::TensorShape *tensorShape = proto_outputs->mutable_tensor_shape(); | mindspore::TensorShape *tensorShape = proto_outputs->mutable_tensor_shape(); | ||||
| MS_EXCEPTION_IF_NULL(tensorShape); | |||||
| MS_ERROR_IF_NULL(tensorShape); | |||||
| for (auto item : output_shape) { | for (auto item : output_shape) { | ||||
| mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); | mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); | ||||
| MS_EXCEPTION_IF_NULL(dim); | |||||
| MS_ERROR_IF_NULL(dim); | |||||
| dim->set_size((::google::protobuf::int64)item); | dim->set_size((::google::protobuf::int64)item); | ||||
| } | } | ||||
| auto output_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[output_index]); | auto output_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[output_index]); | ||||
| proto_outputs->set_tensor_type(output_type); | proto_outputs->set_tensor_type(output_type); | ||||
| proto_outputs->set_mem_device("HBM"); | proto_outputs->set_mem_device("HBM"); | ||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| void SetNodedefProto(const std::shared_ptr<KernelNodeInfo> &op_info, | |||||
| bool SetNodedefProto(const std::shared_ptr<KernelNodeInfo> &op_info, | |||||
| const std::shared_ptr<kernel::AicpuOpKernelMod> &kernel_mod_ptr) { | const std::shared_ptr<kernel::AicpuOpKernelMod> &kernel_mod_ptr) { | ||||
| MS_EXCEPTION_IF_NULL(op_info); | |||||
| MS_EXCEPTION_IF_NULL(kernel_mod_ptr); | |||||
| MS_ERROR_IF_NULL(op_info); | |||||
| MS_ERROR_IF_NULL(kernel_mod_ptr); | |||||
| mindspore::NodeDef proto; | mindspore::NodeDef proto; | ||||
| proto.set_op(op_info->op_name_); | proto.set_op(op_info->op_name_); | ||||
| SetProtoInputs(op_info->input_data_shape_, op_info->input_data_type_, &proto); | |||||
| SetProtoOutputs(op_info->output_data_shape_, op_info->output_data_type_, &proto); | |||||
| RETURN_IF_FALSE(SetProtoInputs(op_info->input_data_shape_, op_info->input_data_type_, &proto)); | |||||
| RETURN_IF_FALSE(SetProtoOutputs(op_info->output_data_shape_, op_info->output_data_type_, &proto)); | |||||
| std::string nodeDefStr; | std::string nodeDefStr; | ||||
| if (!proto.SerializeToString(&nodeDefStr)) { | if (!proto.SerializeToString(&nodeDefStr)) { | ||||
| MS_LOG(EXCEPTION) << "Serialize nodeDef to string failed."; | |||||
| MS_LOG(ERROR) << "Serialize nodeDef to string failed."; | |||||
| return false; | |||||
| } | } | ||||
| MS_LOG(DEBUG) << "Set node def proto, node name:" << op_info->op_name_; | MS_LOG(DEBUG) << "Set node def proto, node name:" << op_info->op_name_; | ||||
| kernel_mod_ptr->SetNodeDef(nodeDefStr); | kernel_mod_ptr->SetNodeDef(nodeDefStr); | ||||
| return true; | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| void AscendPsCache::InitDevice(uint32_t device_id, const void *context) { | |||||
| MS_EXCEPTION_IF_NULL(context); | |||||
| bool AscendPsCache::InitDevice(uint32_t device_id, const void *context) { | |||||
| MS_ERROR_IF_NULL(context); | |||||
| auto ret = rtSetDevice(device_id); | auto ret = rtSetDevice(device_id); | ||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << ret << "]"; | |||||
| MS_LOG(ERROR) << "Call rtSetDevice, ret[" << ret << "]"; | |||||
| return false; | |||||
| } | } | ||||
| auto rt_context = const_cast<rtContext_t>(context); | auto rt_context = const_cast<rtContext_t>(context); | ||||
| ret = rtCtxSetCurrent(rt_context); | ret = rtCtxSetCurrent(rt_context); | ||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]"; | |||||
| MS_LOG(ERROR) << "Call rtCtxSetCurrent, ret[" << ret << "]"; | |||||
| return false; | |||||
| } | } | ||||
| ret = rtStreamCreate(&stream_, 0); | ret = rtStreamCreate(&stream_, 0); | ||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| MS_EXCEPTION(DeviceProcessError) << "Call rtStreamCreate, ret[" << ret << "]"; | |||||
| MS_LOG(ERROR) << "Call rtStreamCreate, ret[" << ret << "]"; | |||||
| return false; | |||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| void *AscendPsCache::MallocMemory(size_t size) { | void *AscendPsCache::MallocMemory(size_t size) { | ||||
| return device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(size); | return device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(size); | ||||
| } | } | ||||
| void AscendPsCache::MallocConstantMemory(size_t constant_value) { | |||||
| bool AscendPsCache::MallocConstantMemory(size_t constant_value) { | |||||
| offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int))); | offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int))); | ||||
| MS_EXCEPTION_IF_NULL(offset_addr_); | |||||
| MS_ERROR_IF_NULL(offset_addr_); | |||||
| rtMemset(offset_addr_, sizeof(int), 0, sizeof(int)); | rtMemset(offset_addr_, sizeof(int), 0, sizeof(int)); | ||||
| cache_vocab_size_addr_ = | cache_vocab_size_addr_ = | ||||
| reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int))); | reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int))); | ||||
| MS_EXCEPTION_IF_NULL(cache_vocab_size_addr_); | |||||
| MS_ERROR_IF_NULL(cache_vocab_size_addr_); | |||||
| rtMemset(cache_vocab_size_addr_, sizeof(int), constant_value, sizeof(int)); | rtMemset(cache_vocab_size_addr_, sizeof(int), constant_value, sizeof(int)); | ||||
| return true; | |||||
| } | } | ||||
| void AscendPsCache::RecordEvent() { | |||||
| bool AscendPsCache::RecordEvent() { | |||||
| event_.reset(new rtEvent_t()); | event_.reset(new rtEvent_t()); | ||||
| auto ret = rtEventCreate(&(*event_)); | auto ret = rtEventCreate(&(*event_)); | ||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| MS_EXCEPTION(DeviceProcessError) << "Create event failed"; | |||||
| MS_LOG(ERROR) << "Create event failed"; | |||||
| return false; | |||||
| } | } | ||||
| ret = rtEventRecord(*event_, stream_); | ret = rtEventRecord(*event_, stream_); | ||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| MS_EXCEPTION(DeviceProcessError) << "Record event failed"; | |||||
| MS_LOG(ERROR) << "Record event failed"; | |||||
| return false; | |||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| void AscendPsCache::SynchronizeEvent() { | |||||
| bool AscendPsCache::SynchronizeEvent() { | |||||
| auto ret = rtEventSynchronize(*event_); | auto ret = rtEventSynchronize(*event_); | ||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| MS_EXCEPTION(DeviceProcessError) << "tEventSynchronize failed"; | |||||
| MS_LOG(ERROR) << "tEventSynchronize failed"; | |||||
| return false; | |||||
| } | } | ||||
| ret = rtEventDestroy(*event_); | ret = rtEventDestroy(*event_); | ||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| MS_EXCEPTION(DeviceProcessError) << "rtEventDestroy failed"; | |||||
| MS_LOG(ERROR) << "rtEventDestroy failed"; | |||||
| return false; | |||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| void AscendPsCache::SynchronizeStream() { | |||||
| bool AscendPsCache::SynchronizeStream() { | |||||
| auto ret = rtStreamSynchronize(stream_); | auto ret = rtStreamSynchronize(stream_); | ||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| MS_EXCEPTION(DeviceProcessError) << "rtStreamSynchronize failed"; | |||||
| MS_LOG(ERROR) << "rtStreamSynchronize failed"; | |||||
| return false; | |||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| void AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||||
| MS_EXCEPTION_IF_NULL(dst); | |||||
| MS_EXCEPTION_IF_NULL(src); | |||||
| bool AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||||
| MS_ERROR_IF_NULL(dst); | |||||
| MS_ERROR_IF_NULL(src); | |||||
| auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_HOST_TO_DEVICE, stream_); | auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_HOST_TO_DEVICE, stream_); | ||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| MS_EXCEPTION(DeviceProcessError) << "rtMemcpyAsync failed"; | |||||
| MS_LOG(ERROR) << "rtMemcpyAsync failed"; | |||||
| return false; | |||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| void AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { | |||||
| MS_EXCEPTION_IF_NULL(dst); | |||||
| MS_EXCEPTION_IF_NULL(src); | |||||
| bool AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { | |||||
| MS_ERROR_IF_NULL(dst); | |||||
| MS_ERROR_IF_NULL(src); | |||||
| auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_DEVICE_TO_HOST, stream_); | auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_DEVICE_TO_HOST, stream_); | ||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| MS_EXCEPTION(DeviceProcessError) << "rtMemcpyAsync failed"; | |||||
| MS_LOG(ERROR) << "rtMemcpyAsync failed"; | |||||
| return false; | |||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| void AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, | |||||
| bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, | |||||
| size_t hash_table_size, size_t embedding_size, size_t swap_out_size) { | size_t hash_table_size, size_t embedding_size, size_t swap_out_size) { | ||||
| MS_EXCEPTION_IF_NULL(hash_table_addr); | |||||
| MS_EXCEPTION_IF_NULL(swap_out_value_addr); | |||||
| MS_EXCEPTION_IF_NULL(swap_out_index_addr); | |||||
| MS_ERROR_IF_NULL(hash_table_addr); | |||||
| MS_ERROR_IF_NULL(swap_out_value_addr); | |||||
| MS_ERROR_IF_NULL(swap_out_index_addr); | |||||
| auto hash_swap_out_mod = std::make_shared<kernel::AicpuOpKernelMod>(); | auto hash_swap_out_mod = std::make_shared<kernel::AicpuOpKernelMod>(); | ||||
| MS_EXCEPTION_IF_NULL(hash_swap_out_mod); | |||||
| MS_ERROR_IF_NULL(hash_swap_out_mod); | |||||
| hash_swap_out_mod->SetNodeName(kEmbeddingLookupOpName); | hash_swap_out_mod->SetNodeName(kEmbeddingLookupOpName); | ||||
| std::vector<std::vector<size_t>> input_shape; | std::vector<std::vector<size_t>> input_shape; | ||||
| std::vector<std::vector<size_t>> output_shape; | std::vector<std::vector<size_t>> output_shape; | ||||
| @@ -197,7 +220,7 @@ void AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr | |||||
| output_shape.push_back({swap_out_size, embedding_size}); | output_shape.push_back({swap_out_size, embedding_size}); | ||||
| auto op_info = | auto op_info = | ||||
| std::make_shared<KernelNodeInfo>(kEmbeddingLookupOpName, input_shape, input_type, output_shape, output_type); | std::make_shared<KernelNodeInfo>(kEmbeddingLookupOpName, input_shape, input_type, output_shape, output_type); | ||||
| SetNodedefProto(op_info, hash_swap_out_mod); | |||||
| RETURN_IF_FALSE(SetNodedefProto(op_info, hash_swap_out_mod)); | |||||
| AddressPtrList kernel_inputs; | AddressPtrList kernel_inputs; | ||||
| AddressPtrList kernel_outputs = { | AddressPtrList kernel_outputs = { | ||||
| @@ -208,17 +231,19 @@ void AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr | |||||
| kernel_inputs.push_back(std::make_shared<Address>(offset_addr_, sizeof(int))); | kernel_inputs.push_back(std::make_shared<Address>(offset_addr_, sizeof(int))); | ||||
| auto ret = hash_swap_out_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); | auto ret = hash_swap_out_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); | ||||
| if (!ret) { | if (!ret) { | ||||
| MS_LOG(EXCEPTION) << "Hash swap out launch failed."; | |||||
| MS_LOG(ERROR) << "Hash swap out launch failed."; | |||||
| return false; | |||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| void AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, | |||||
| bool AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, | |||||
| size_t hash_table_size, size_t embedding_size, size_t swap_in_size) { | size_t hash_table_size, size_t embedding_size, size_t swap_in_size) { | ||||
| MS_EXCEPTION_IF_NULL(hash_table_addr); | |||||
| MS_EXCEPTION_IF_NULL(swap_in_value_addr); | |||||
| MS_EXCEPTION_IF_NULL(swap_in_index_addr); | |||||
| MS_ERROR_IF_NULL(hash_table_addr); | |||||
| MS_ERROR_IF_NULL(swap_in_value_addr); | |||||
| MS_ERROR_IF_NULL(swap_in_index_addr); | |||||
| auto hash_swap_in_mod = std::make_shared<kernel::AicpuOpKernelMod>(); | auto hash_swap_in_mod = std::make_shared<kernel::AicpuOpKernelMod>(); | ||||
| MS_EXCEPTION_IF_NULL(hash_swap_in_mod); | |||||
| MS_ERROR_IF_NULL(hash_swap_in_mod); | |||||
| hash_swap_in_mod->SetNodeName(kernel::kUpdateCache); | hash_swap_in_mod->SetNodeName(kernel::kUpdateCache); | ||||
| std::vector<std::vector<size_t>> input_shape; | std::vector<std::vector<size_t>> input_shape; | ||||
| std::vector<std::vector<size_t>> output_shape; | std::vector<std::vector<size_t>> output_shape; | ||||
| @@ -245,8 +270,10 @@ void AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, | |||||
| kernel_outputs.push_back(std::make_shared<Address>(offset_addr_, sizeof(int))); | kernel_outputs.push_back(std::make_shared<Address>(offset_addr_, sizeof(int))); | ||||
| auto ret = hash_swap_in_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); | auto ret = hash_swap_in_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); | ||||
| if (!ret) { | if (!ret) { | ||||
| MS_LOG(EXCEPTION) << "Hash swap in launch failed."; | |||||
| MS_LOG(ERROR) << "Hash swap in launch failed."; | |||||
| return false; | |||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| } // namespace ascend | } // namespace ascend | ||||
| } // namespace ps | } // namespace ps | ||||
| @@ -49,17 +49,17 @@ class AscendPsCache : public PsCacheBasic { | |||||
| public: | public: | ||||
| AscendPsCache() = default; | AscendPsCache() = default; | ||||
| ~AscendPsCache() override = default; | ~AscendPsCache() override = default; | ||||
| void InitDevice(uint32_t device_id, const void *context) override; | |||||
| bool InitDevice(uint32_t device_id, const void *context) override; | |||||
| void *MallocMemory(size_t size) override; | void *MallocMemory(size_t size) override; | ||||
| void MallocConstantMemory(size_t constant_value) override; | |||||
| void RecordEvent() override; | |||||
| void SynchronizeEvent() override; | |||||
| void SynchronizeStream() override; | |||||
| void CopyHostMemToDevice(void *dst, void *src, size_t size) override; | |||||
| void CopyDeviceMemToHost(void *dst, void *src, size_t size) override; | |||||
| void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, | |||||
| bool MallocConstantMemory(size_t constant_value) override; | |||||
| bool RecordEvent() override; | |||||
| bool SynchronizeEvent() override; | |||||
| bool SynchronizeStream() override; | |||||
| bool CopyHostMemToDevice(void *dst, void *src, size_t size) override; | |||||
| bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override; | |||||
| bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, | |||||
| size_t embedding_size, size_t swap_out_size) override; | size_t embedding_size, size_t swap_out_size) override; | ||||
| void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, | |||||
| bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, | |||||
| size_t embedding_size, size_t swap_in_size) override; | size_t embedding_size, size_t swap_in_size) override; | ||||
| private: | private: | ||||
| @@ -25,67 +25,75 @@ namespace mindspore { | |||||
| namespace ps { | namespace ps { | ||||
| namespace gpu { | namespace gpu { | ||||
| MS_REG_PS_CACHE(kGPUDevice, GPUPsCache); | MS_REG_PS_CACHE(kGPUDevice, GPUPsCache); | ||||
| void GPUPsCache::InitDevice(uint32_t device_id, const void *) { | |||||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaSetDevice(device_id), "Cuda set device failed") | |||||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamCreate(reinterpret_cast<CUstream_st **>(&stream_)), | |||||
| "Cuda create stream failed"); | |||||
| bool GPUPsCache::InitDevice(uint32_t device_id, const void *) { | |||||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaSetDevice(device_id), "Cuda set device failed") | |||||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamCreate(reinterpret_cast<CUstream_st **>(&stream_)), | |||||
| "Cuda create stream failed"); | |||||
| return true; | |||||
| } | } | ||||
| void *GPUPsCache::MallocMemory(size_t size) { | void *GPUPsCache::MallocMemory(size_t size) { | ||||
| return device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(size); | return device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(size); | ||||
| } | } | ||||
| void GPUPsCache::RecordEvent() { | |||||
| bool GPUPsCache::RecordEvent() { | |||||
| event_.reset(new cudaEvent_t()); | event_.reset(new cudaEvent_t()); | ||||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventCreate(&(*event_)), "Cuda create event failed"); | |||||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventRecord(*event_, reinterpret_cast<cudaStream_t>(stream_)), | |||||
| "Cuda record event failed"); | |||||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventCreate(&(*event_)), "Cuda create event failed"); | |||||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventRecord(*event_, reinterpret_cast<cudaStream_t>(stream_)), | |||||
| "Cuda record event failed"); | |||||
| return true; | |||||
| } | } | ||||
| void GPUPsCache::SynchronizeEvent() { | |||||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventSynchronize(*event_), "Cuda sync event failed"); | |||||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventDestroy(*event_), "Cuda destroy event failed"); | |||||
| bool GPUPsCache::SynchronizeEvent() { | |||||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventSynchronize(*event_), "Cuda sync event failed"); | |||||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventDestroy(*event_), "Cuda destroy event failed"); | |||||
| return true; | |||||
| } | } | ||||
| void GPUPsCache::SynchronizeStream() { | |||||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_)), | |||||
| "Cuda sync stream failed"); | |||||
| bool GPUPsCache::SynchronizeStream() { | |||||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_)), | |||||
| "Cuda sync stream failed"); | |||||
| return true; | |||||
| } | } | ||||
| void GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||||
| MS_EXCEPTION_IF_NULL(dst); | |||||
| MS_EXCEPTION_IF_NULL(src); | |||||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( | |||||
| bool GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||||
| MS_ERROR_IF_NULL(dst); | |||||
| MS_ERROR_IF_NULL(src); | |||||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE( | |||||
| cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_)), | cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_)), | ||||
| "Cuda memcpy failed"); | "Cuda memcpy failed"); | ||||
| return true; | |||||
| } | } | ||||
| void GPUPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { | |||||
| MS_EXCEPTION_IF_NULL(dst); | |||||
| MS_EXCEPTION_IF_NULL(src); | |||||
| CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( | |||||
| bool GPUPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { | |||||
| MS_ERROR_IF_NULL(dst); | |||||
| MS_ERROR_IF_NULL(src); | |||||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE( | |||||
| cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_)), | cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_)), | ||||
| "Cuda memcpy failed"); | "Cuda memcpy failed"); | ||||
| return true; | |||||
| } | } | ||||
| void GPUPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t, | |||||
| bool GPUPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t, | |||||
| size_t embedding_size, size_t swap_out_size) { | size_t embedding_size, size_t swap_out_size) { | ||||
| MS_EXCEPTION_IF_NULL(hash_table_addr); | |||||
| MS_EXCEPTION_IF_NULL(swap_out_value_addr); | |||||
| MS_EXCEPTION_IF_NULL(swap_out_index_addr); | |||||
| MS_ERROR_IF_NULL(hash_table_addr); | |||||
| MS_ERROR_IF_NULL(swap_out_value_addr); | |||||
| MS_ERROR_IF_NULL(swap_out_index_addr); | |||||
| DoHashSwapOut(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_out_value_addr), | DoHashSwapOut(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_out_value_addr), | ||||
| reinterpret_cast<int *>(swap_out_index_addr), swap_out_size, embedding_size, | reinterpret_cast<int *>(swap_out_index_addr), swap_out_size, embedding_size, | ||||
| reinterpret_cast<cudaStream_t>(stream_)); | reinterpret_cast<cudaStream_t>(stream_)); | ||||
| return true; | |||||
| } | } | ||||
| void GPUPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t, | |||||
| bool GPUPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t, | |||||
| size_t embedding_size, size_t swap_in_size) { | size_t embedding_size, size_t swap_in_size) { | ||||
| MS_EXCEPTION_IF_NULL(hash_table_addr); | |||||
| MS_EXCEPTION_IF_NULL(swap_in_value_addr); | |||||
| MS_EXCEPTION_IF_NULL(swap_in_index_addr); | |||||
| MS_ERROR_IF_NULL(hash_table_addr); | |||||
| MS_ERROR_IF_NULL(swap_in_value_addr); | |||||
| MS_ERROR_IF_NULL(swap_in_index_addr); | |||||
| DoHashSwapIn(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_in_value_addr), | DoHashSwapIn(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_in_value_addr), | ||||
| reinterpret_cast<int *>(swap_in_index_addr), swap_in_size, embedding_size, | reinterpret_cast<int *>(swap_in_index_addr), swap_in_size, embedding_size, | ||||
| reinterpret_cast<cudaStream_t>(stream_)); | reinterpret_cast<cudaStream_t>(stream_)); | ||||
| return true; | |||||
| } | } | ||||
| } // namespace gpu | } // namespace gpu | ||||
| } // namespace ps | } // namespace ps | ||||
| @@ -28,16 +28,16 @@ class GPUPsCache : public PsCacheBasic { | |||||
| public: | public: | ||||
| GPUPsCache() = default; | GPUPsCache() = default; | ||||
| ~GPUPsCache() override = default; | ~GPUPsCache() override = default; | ||||
| void InitDevice(uint32_t device_id, const void *context) override; | |||||
| bool InitDevice(uint32_t device_id, const void *context) override; | |||||
| void *MallocMemory(size_t size) override; | void *MallocMemory(size_t size) override; | ||||
| void RecordEvent() override; | |||||
| void SynchronizeEvent() override; | |||||
| void SynchronizeStream() override; | |||||
| void CopyHostMemToDevice(void *dst, void *src, size_t size) override; | |||||
| void CopyDeviceMemToHost(void *dst, void *src, size_t size) override; | |||||
| void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, | |||||
| bool RecordEvent() override; | |||||
| bool SynchronizeEvent() override; | |||||
| bool SynchronizeStream() override; | |||||
| bool CopyHostMemToDevice(void *dst, void *src, size_t size) override; | |||||
| bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override; | |||||
| bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, | |||||
| size_t embedding_size, size_t swap_out_size) override; | size_t embedding_size, size_t swap_out_size) override; | ||||
| void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, | |||||
| bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, | |||||
| size_t embedding_size, size_t swap_in_size) override; | size_t embedding_size, size_t swap_in_size) override; | ||||
| private: | private: | ||||
| @@ -21,21 +21,28 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| #define RETURN_IF_FALSE(condition) \ | |||||
| do { \ | |||||
| if (!(condition)) { \ | |||||
| return false; \ | |||||
| } \ | |||||
| } while (false) | |||||
| class PsCacheBasic { | class PsCacheBasic { | ||||
| public: | public: | ||||
| PsCacheBasic() = default; | PsCacheBasic() = default; | ||||
| virtual ~PsCacheBasic() = default; | virtual ~PsCacheBasic() = default; | ||||
| virtual void InitDevice(uint32_t device_id, const void *context) = 0; | |||||
| virtual bool InitDevice(uint32_t device_id, const void *context) = 0; | |||||
| virtual void *MallocMemory(size_t size) = 0; | virtual void *MallocMemory(size_t size) = 0; | ||||
| virtual void MallocConstantMemory(size_t constant_value) {} | |||||
| virtual void RecordEvent() = 0; | |||||
| virtual void SynchronizeEvent() = 0; | |||||
| virtual void SynchronizeStream() = 0; | |||||
| virtual void CopyHostMemToDevice(void *dst, void *src, size_t size) = 0; | |||||
| virtual void CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0; | |||||
| virtual void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, | |||||
| virtual bool MallocConstantMemory(size_t constant_value) { return true; } | |||||
| virtual bool RecordEvent() = 0; | |||||
| virtual bool SynchronizeEvent() = 0; | |||||
| virtual bool SynchronizeStream() = 0; | |||||
| virtual bool CopyHostMemToDevice(void *dst, void *src, size_t size) = 0; | |||||
| virtual bool CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0; | |||||
| virtual bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, | |||||
| size_t hash_table_size, size_t embedding_size, size_t swap_out_size) = 0; | size_t hash_table_size, size_t embedding_size, size_t swap_out_size) = 0; | ||||
| virtual void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, | |||||
| virtual bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, | |||||
| size_t hash_table_size, size_t embedding_size, size_t swap_in_size) = 0; | size_t hash_table_size, size_t embedding_size, size_t swap_in_size) = 0; | ||||
| protected: | protected: | ||||
| @@ -170,8 +170,10 @@ void PsCacheManager::AddEmbeddingTable() const { | |||||
| void PsCacheManager::InitParameterServer() { | void PsCacheManager::InitParameterServer() { | ||||
| MS_LOG(INFO) << "Embedding table init begin:" << finish_insert_init_info_; | MS_LOG(INFO) << "Embedding table init begin:" << finish_insert_init_info_; | ||||
| std::unique_lock<std::mutex> locker(data_mutex_); | std::unique_lock<std::mutex> locker(data_mutex_); | ||||
| insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true; }); | |||||
| insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true || running_ == false; }); | |||||
| if (!running_) { | |||||
| return; | |||||
| } | |||||
| for (const auto &item : hash_tables_) { | for (const auto &item : hash_tables_) { | ||||
| const auto ¶m_name = item.first; | const auto ¶m_name = item.first; | ||||
| size_t key = worker.SetParamKey(param_name); | size_t key = worker.SetParamKey(param_name); | ||||
| @@ -224,7 +226,9 @@ void PsCacheManager::AllocMemForHashTable() { | |||||
| embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast<float *>( | embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast<float *>( | ||||
| embedding_device_cache_->cache_->MallocMemory(max_embedding_size * batch_elements_ * sizeof(float))); | embedding_device_cache_->cache_->MallocMemory(max_embedding_size * batch_elements_ * sizeof(float))); | ||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_); | MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_); | ||||
| embedding_device_cache_->cache_->MallocConstantMemory(cache_vocab_size_); | |||||
| if (!(embedding_device_cache_->cache_->MallocConstantMemory(cache_vocab_size_))) { | |||||
| MS_LOG(EXCEPTION) << "MallocConstantMemory failed."; | |||||
| } | |||||
| } | } | ||||
| void PsCacheManager::SetLocalIdRank() { | void PsCacheManager::SetLocalIdRank() { | ||||
| @@ -250,19 +254,25 @@ void PsCacheManager::set_channel_name(const std::string channel_name) { | |||||
| channel_name_ = channel_name; | channel_name_ = channel_name; | ||||
| } | } | ||||
| void PsCacheManager::IncreaseStep() { | |||||
| bool PsCacheManager::IncreaseStep() { | |||||
| if (data_step_ >= UINT64_MAX) { | if (data_step_ >= UINT64_MAX) { | ||||
| MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t."; | |||||
| MS_LOG(ERROR) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t."; | |||||
| return false; | |||||
| } | } | ||||
| data_step_++; | data_step_++; | ||||
| set_current_graph_step(); | set_current_graph_step(); | ||||
| if (graph_running_step_ > data_step_) { | if (graph_running_step_ > data_step_) { | ||||
| MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") exceed the data step (" << data_step_ | |||||
| << ")."; | |||||
| MS_LOG(ERROR) << "The graph running step (" << graph_running_step_ << ") exceed the data step (" << data_step_ | |||||
| << ")."; | |||||
| return false; | |||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { | void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { | ||||
| if (terminated_) { | |||||
| MS_LOG(EXCEPTION) << "ps cache data process thread is terminated."; | |||||
| } | |||||
| if (graph_step_ >= UINT64_MAX) { | if (graph_step_ >= UINT64_MAX) { | ||||
| MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t."; | MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t."; | ||||
| } | } | ||||
| @@ -274,7 +284,9 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { | |||||
| } | } | ||||
| graph_step_++; | graph_step_++; | ||||
| set_channel_name(channel_name); | set_channel_name(channel_name); | ||||
| PsDataPrefetch::GetInstance().TryWakeChannel(channel_name); | |||||
| if (!PsDataPrefetch::GetInstance().TryWakeChannel(channel_name)) { | |||||
| MS_LOG(EXCEPTION) << "TryWakeChannel failed, channel name: " << channel_name; | |||||
| } | |||||
| data_prase_.notify_one(); | data_prase_.notify_one(); | ||||
| } | } | ||||
| @@ -284,74 +296,99 @@ void PsCacheManager::DoProcessData(uint32_t device_id, void *context) { | |||||
| MS_LOG(EXCEPTION) << "Only the sink_mode of dataset supports embeddingLookup cache in parameter server training " | MS_LOG(EXCEPTION) << "Only the sink_mode of dataset supports embeddingLookup cache in parameter server training " | ||||
| "mode, current dataset mode is not sink_mode."; | "mode, current dataset mode is not sink_mode."; | ||||
| } | } | ||||
| auto process_data_thread = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context); | |||||
| process_data_thread.detach(); | |||||
| process_data_thread_ = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context); | |||||
| } | } | ||||
| void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) { | void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) { | ||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||||
| embedding_device_cache_->cache_->InitDevice(device_id, context); | embedding_device_cache_->cache_->InitDevice(device_id, context); | ||||
| running_ = true; | |||||
| bool ret = true; | |||||
| InitParameterServer(); | InitParameterServer(); | ||||
| while (true) { | |||||
| ProcessData(); | |||||
| while (ret) { | |||||
| if (!running_) { | |||||
| break; | |||||
| } | |||||
| ret = ProcessData(); | |||||
| } | |||||
| if (!ret) { | |||||
| terminated_ = true; | |||||
| } | } | ||||
| } | } | ||||
| void PsCacheManager::ProcessData() { | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||||
| void PsCacheManager::Finalize() { | |||||
| if (running_) { | |||||
| running_ = false; | |||||
| } | |||||
| PsDataPrefetch::GetInstance().NotifyFinalize(); | |||||
| insert_init_info_.notify_all(); | |||||
| data_prase_.notify_all(); | |||||
| if (process_data_thread_.joinable()) { | |||||
| process_data_thread_.join(); | |||||
| } | |||||
| } | |||||
| bool PsCacheManager::ProcessData() { | |||||
| struct timeval start_time, end_time; | struct timeval start_time, end_time; | ||||
| const uint64_t kUSecondInSecond = 1000000; | const uint64_t kUSecondInSecond = 1000000; | ||||
| (void)gettimeofday(&start_time, nullptr); | (void)gettimeofday(&start_time, nullptr); | ||||
| auto channel = channel_name(); | auto channel = channel_name(); | ||||
| if (channel.empty()) { | if (channel.empty()) { | ||||
| std::unique_lock<std::mutex> locker(data_mutex_); | std::unique_lock<std::mutex> locker(data_mutex_); | ||||
| data_prase_.wait(locker, [this] { return !channel_name_.empty(); }); | |||||
| data_prase_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; }); | |||||
| if (!running_) { | |||||
| return false; | |||||
| } | |||||
| } | } | ||||
| auto data = PsDataPrefetch::GetInstance().data(channel_name_); | auto data = PsDataPrefetch::GetInstance().data(channel_name_); | ||||
| if (data == nullptr) { | if (data == nullptr) { | ||||
| MS_LOG(INFO) << "No data process, channel name:" << channel_name_; | MS_LOG(INFO) << "No data process, channel name:" << channel_name_; | ||||
| std::unique_lock<std::mutex> locker(data_mutex_); | std::unique_lock<std::mutex> locker(data_mutex_); | ||||
| (void)data_prase_.wait_for(locker, std::chrono::milliseconds(100)); | (void)data_prase_.wait_for(locker, std::chrono::milliseconds(100)); | ||||
| return; | |||||
| return true; | |||||
| } | } | ||||
| IncreaseStep(); | |||||
| RETURN_IF_FALSE(IncreaseStep()); | |||||
| auto data_size = PsDataPrefetch::GetInstance().data_size(channel_name_); | auto data_size = PsDataPrefetch::GetInstance().data_size(channel_name_); | ||||
| if (data_size == 0) { | |||||
| MS_LOG(ERROR) << "The data_size can not be zero."; | |||||
| return false; | |||||
| } | |||||
| auto batch_ids = reinterpret_cast<int *>(data); | auto batch_ids = reinterpret_cast<int *>(data); | ||||
| auto batch_ids_len = data_size / sizeof(int); | auto batch_ids_len = data_size / sizeof(int); | ||||
| std::unique_ptr<int[]> hash_index(new int[batch_ids_len]); | std::unique_ptr<int[]> hash_index(new int[batch_ids_len]); | ||||
| if (memset_s(&statistics_info_, sizeof(statistics_info_), 0, sizeof(statistics_info_))) { | if (memset_s(&statistics_info_, sizeof(statistics_info_), 0, sizeof(statistics_info_))) { | ||||
| MS_LOG(EXCEPTION) << "Process data memset failed."; | |||||
| MS_LOG(ERROR) << "Process data memset failed."; | |||||
| return false; | |||||
| } | } | ||||
| // Get hash swap in/out index and ids. | // Get hash swap in/out index and ids. | ||||
| ParseData(batch_ids, batch_ids_len, hash_index.get()); | |||||
| RETURN_IF_FALSE(ParseData(batch_ids, batch_ids_len, hash_index.get())); | |||||
| for (const auto &item : hash_tables_) { | for (const auto &item : hash_tables_) { | ||||
| auto key = worker.GetParamKey(item.first); | auto key = worker.GetParamKey(item.first); | ||||
| auto hash_info = item.second; | auto hash_info = item.second; | ||||
| HashSwapHostToServer(key, hash_info); | |||||
| HashSwapDeviceToHost(hash_info); | |||||
| HashSwapServerToHost(key, hash_info); | |||||
| HashSwapHostToDevice(hash_info); | |||||
| RETURN_IF_FALSE(HashSwapHostToServer(key, hash_info)); | |||||
| RETURN_IF_FALSE(HashSwapDeviceToHost(hash_info)); | |||||
| RETURN_IF_FALSE(HashSwapServerToHost(key, hash_info)); | |||||
| RETURN_IF_FALSE(HashSwapHostToDevice(hash_info)); | |||||
| } | } | ||||
| // Replace the batch_ids by hash index for getNext-op getting hash index as input. | // Replace the batch_ids by hash index for getNext-op getting hash index as input. | ||||
| if (memcpy_s(data, data_size, hash_index.get(), data_size) != EOK) { | if (memcpy_s(data, data_size, hash_index.get(), data_size) != EOK) { | ||||
| MS_LOG(EXCEPTION) << "Process data memcpy failed."; | |||||
| MS_LOG(ERROR) << "Process data memcpy failed."; | |||||
| return false; | |||||
| } | } | ||||
| embedding_device_cache_->cache_->SynchronizeStream(); | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream()); | |||||
| // Finish the data process and notify data prefetch. | // Finish the data process and notify data prefetch. | ||||
| PsDataPrefetch::GetInstance().FinalizeData(channel_name_); | |||||
| RETURN_IF_FALSE(PsDataPrefetch::GetInstance().FinalizeData(channel_name_)); | |||||
| (void)gettimeofday(&end_time, nullptr); | (void)gettimeofday(&end_time, nullptr); | ||||
| uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec); | uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec); | ||||
| cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec); | cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec); | ||||
| MS_LOG(DEBUG) << "Ps cache completes processing data(data step:" << data_step_ | MS_LOG(DEBUG) << "Ps cache completes processing data(data step:" << data_step_ | ||||
| << ",graph step:" << graph_running_step_ << " channel name:" << channel_name_ | << ",graph step:" << graph_running_step_ << " channel name:" << channel_name_ | ||||
| << ", time cost:" << cost / 1000 << "ms)."; | << ", time cost:" << cost / 1000 << "ms)."; | ||||
| return true; | |||||
| } | } | ||||
| void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) { | |||||
| MS_EXCEPTION_IF_NULL(batch_ids); | |||||
| MS_EXCEPTION_IF_NULL(hash_index); | |||||
| bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) { | |||||
| MS_ERROR_IF_NULL(batch_ids); | |||||
| MS_ERROR_IF_NULL(hash_index); | |||||
| for (size_t i = 0; i < batch_ids_len; i++) { | for (size_t i = 0; i < batch_ids_len; i++) { | ||||
| bool need_swap_host_to_device = true; | bool need_swap_host_to_device = true; | ||||
| bool need_swap_device_to_host = true; | bool need_swap_device_to_host = true; | ||||
| @@ -360,12 +397,16 @@ void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, | |||||
| hash_index[i] = -1; | hash_index[i] = -1; | ||||
| continue; | continue; | ||||
| } | } | ||||
| hash_index[i] = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device); | |||||
| auto index = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device); | |||||
| if (index == INVALID_INDEX_VALUE) { | |||||
| return false; | |||||
| } | |||||
| hash_index[i] = index; | |||||
| if (need_swap_host_to_device) { | if (need_swap_host_to_device) { | ||||
| ParseHostDataHostToDevice(id); | |||||
| RETURN_IF_FALSE(ParseHostDataHostToDevice(id)); | |||||
| } | } | ||||
| if (need_swap_device_to_host) { | if (need_swap_device_to_host) { | ||||
| ParseHostDataDeviceToHost(id); | |||||
| RETURN_IF_FALSE(ParseHostDataDeviceToHost(id)); | |||||
| } | } | ||||
| } | } | ||||
| // Each 1000 step prints ps cache hit rate. | // Each 1000 step prints ps cache hit rate. | ||||
| @@ -374,33 +415,28 @@ void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, | |||||
| auto hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_; | auto hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_; | ||||
| MS_LOG(INFO) << "Ps cache hit rate: " << hit_rate * 100 << "%."; | MS_LOG(INFO) << "Ps cache hit rate: " << hit_rate * 100 << "%."; | ||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| void PsCacheManager::WaitGraphRun() { | |||||
| bool PsCacheManager::WaitGraphRun() { | |||||
| MS_LOG(INFO) << "Hash table has no space to insert new data and retries within 2 minutes."; | MS_LOG(INFO) << "Hash table has no space to insert new data and retries within 2 minutes."; | ||||
| std::unique_lock<std::mutex> locker(data_mutex_); | std::unique_lock<std::mutex> locker(data_mutex_); | ||||
| if (!data_prase_.wait_for(locker, std::chrono::seconds(120), [this] { return graph_step_ > graph_running_step_; })) { | if (!data_prase_.wait_for(locker, std::chrono::seconds(120), [this] { return graph_step_ > graph_running_step_; })) { | ||||
| MS_LOG(EXCEPTION) << "Ps cache data parse timeout, suggest to enlarge the cache size(graph step:" << graph_step_ | |||||
| << ", graph running step:" << graph_running_step_ << ")."; | |||||
| MS_LOG(ERROR) << "Ps cache data parse timeout, suggest to enlarge the cache size(graph step:" << graph_step_ | |||||
| << ", graph running step:" << graph_running_step_ << ")."; | |||||
| return false; | |||||
| } | } | ||||
| set_current_graph_step(); | set_current_graph_step(); | ||||
| return true; | |||||
| } | } | ||||
| int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device) { | int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device) { | ||||
| MS_EXCEPTION_IF_NULL(need_swap_device_to_host); | |||||
| MS_EXCEPTION_IF_NULL(need_swap_host_to_device); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| int *device_to_host_index = embedding_device_cache_->device_to_host_index.get(); | int *device_to_host_index = embedding_device_cache_->device_to_host_index.get(); | ||||
| int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); | int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); | ||||
| int *host_to_device_index = embedding_device_cache_->host_to_device_index.get(); | int *host_to_device_index = embedding_device_cache_->host_to_device_index.get(); | ||||
| int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get(); | int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get(); | ||||
| MS_EXCEPTION_IF_NULL(device_to_host_index); | |||||
| MS_EXCEPTION_IF_NULL(device_to_host_ids); | |||||
| MS_EXCEPTION_IF_NULL(host_to_device_index); | |||||
| MS_EXCEPTION_IF_NULL(host_to_device_ids); | |||||
| auto device_hash_map = embedding_device_cache_->device_hash_map_; | auto device_hash_map = embedding_device_cache_->device_hash_map_; | ||||
| MS_EXCEPTION_IF_NULL(device_hash_map); | |||||
| int index = 0; | int index = 0; | ||||
| auto iter = device_hash_map->id_iter(id); | auto iter = device_hash_map->id_iter(id); | ||||
| if (device_hash_map->IsIdExist(iter)) { | if (device_hash_map->IsIdExist(iter)) { | ||||
| @@ -417,7 +453,9 @@ int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, b | |||||
| index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_, | index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_, | ||||
| &(statistics_info_.device_to_host_size_)); | &(statistics_info_.device_to_host_size_)); | ||||
| if (index == INVALID_INDEX_VALUE) { | if (index == INVALID_INDEX_VALUE) { | ||||
| WaitGraphRun(); | |||||
| if (!WaitGraphRun()) { | |||||
| return INVALID_INDEX_VALUE; | |||||
| } | |||||
| continue; | continue; | ||||
| } | } | ||||
| host_to_device_index[statistics_info_.host_to_device_size_] = index; | host_to_device_index[statistics_info_.host_to_device_size_] = index; | ||||
| @@ -430,21 +468,20 @@ int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, b | |||||
| return index; | return index; | ||||
| } | } | ||||
| void PsCacheManager::ParseHostDataHostToDevice(size_t id) { | |||||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||||
| bool PsCacheManager::ParseHostDataHostToDevice(size_t id) { | |||||
| int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | ||||
| int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | ||||
| int *server_to_host_index = embedding_host_cache_->server_to_host_index.get(); | int *server_to_host_index = embedding_host_cache_->server_to_host_index.get(); | ||||
| int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); | int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); | ||||
| int *host_to_device_index = embedding_host_cache_->host_to_device_index.get(); | int *host_to_device_index = embedding_host_cache_->host_to_device_index.get(); | ||||
| MS_EXCEPTION_IF_NULL(host_to_server_index); | |||||
| MS_EXCEPTION_IF_NULL(host_to_server_ids); | |||||
| MS_EXCEPTION_IF_NULL(server_to_host_index); | |||||
| MS_EXCEPTION_IF_NULL(server_to_host_ids); | |||||
| MS_EXCEPTION_IF_NULL(host_to_device_index); | |||||
| MS_ERROR_IF_NULL(host_to_server_index); | |||||
| MS_ERROR_IF_NULL(host_to_server_ids); | |||||
| MS_ERROR_IF_NULL(server_to_host_index); | |||||
| MS_ERROR_IF_NULL(server_to_host_ids); | |||||
| MS_ERROR_IF_NULL(host_to_device_index); | |||||
| auto host_hash_map = embedding_host_cache_->host_hash_map_; | auto host_hash_map = embedding_host_cache_->host_hash_map_; | ||||
| MS_EXCEPTION_IF_NULL(host_hash_map); | |||||
| MS_ERROR_IF_NULL(host_hash_map); | |||||
| auto iter = host_hash_map->id_iter(id); | auto iter = host_hash_map->id_iter(id); | ||||
| if (host_hash_map->IsIdExist(iter)) { | if (host_hash_map->IsIdExist(iter)) { | ||||
| auto index = iter->second; | auto index = iter->second; | ||||
| @@ -457,7 +494,7 @@ void PsCacheManager::ParseHostDataHostToDevice(size_t id) { | |||||
| auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, | auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, | ||||
| graph_running_step_, &statistics_info_.host_to_server_size_); | graph_running_step_, &statistics_info_.host_to_server_size_); | ||||
| if (index == INVALID_INDEX_VALUE) { | if (index == INVALID_INDEX_VALUE) { | ||||
| WaitGraphRun(); | |||||
| RETURN_IF_FALSE(WaitGraphRun()); | |||||
| continue; | continue; | ||||
| } | } | ||||
| host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index; | host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index; | ||||
| @@ -466,22 +503,21 @@ void PsCacheManager::ParseHostDataHostToDevice(size_t id) { | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| void PsCacheManager::ParseHostDataDeviceToHost(size_t id) { | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||||
| bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) { | |||||
| int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); | int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); | ||||
| int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | ||||
| int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | ||||
| int *device_to_host_index = embedding_host_cache_->device_to_host_index.get(); | int *device_to_host_index = embedding_host_cache_->device_to_host_index.get(); | ||||
| MS_EXCEPTION_IF_NULL(device_to_host_ids); | |||||
| MS_EXCEPTION_IF_NULL(host_to_server_index); | |||||
| MS_EXCEPTION_IF_NULL(host_to_server_ids); | |||||
| MS_EXCEPTION_IF_NULL(device_to_host_index); | |||||
| MS_ERROR_IF_NULL(device_to_host_ids); | |||||
| MS_ERROR_IF_NULL(host_to_server_index); | |||||
| MS_ERROR_IF_NULL(host_to_server_ids); | |||||
| MS_ERROR_IF_NULL(device_to_host_index); | |||||
| auto host_hash_map = embedding_host_cache_->host_hash_map_; | auto host_hash_map = embedding_host_cache_->host_hash_map_; | ||||
| MS_EXCEPTION_IF_NULL(host_hash_map); | |||||
| MS_ERROR_IF_NULL(host_hash_map); | |||||
| int swap_device_to_host_id = device_to_host_ids[statistics_info_.device_to_host_size_ - 1]; | int swap_device_to_host_id = device_to_host_ids[statistics_info_.device_to_host_size_ - 1]; | ||||
| auto iter = host_hash_map->id_iter(swap_device_to_host_id); | auto iter = host_hash_map->id_iter(swap_device_to_host_id); | ||||
| if (host_hash_map->IsIdExist(iter)) { | if (host_hash_map->IsIdExist(iter)) { | ||||
| @@ -495,13 +531,14 @@ void PsCacheManager::ParseHostDataDeviceToHost(size_t id) { | |||||
| auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, | auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, | ||||
| graph_running_step_, &statistics_info_.host_to_server_size_); | graph_running_step_, &statistics_info_.host_to_server_size_); | ||||
| if (index == INVALID_INDEX_VALUE) { | if (index == INVALID_INDEX_VALUE) { | ||||
| WaitGraphRun(); | |||||
| RETURN_IF_FALSE(WaitGraphRun()); | |||||
| continue; | continue; | ||||
| } | } | ||||
| device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index; | device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index; | ||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| return true; | |||||
| } | } | ||||
| void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, | void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, | ||||
| @@ -514,19 +551,21 @@ void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, | |||||
| size_t pos = index * outer_dim_size; | size_t pos = index * outer_dim_size; | ||||
| auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens); | auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; | |||||
| MS_LOG(ERROR) << "LookUpTable task memcpy failed."; | |||||
| terminated_ = true; | |||||
| } | } | ||||
| } else { | } else { | ||||
| auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens); | auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; | |||||
| MS_LOG(ERROR) << "LookUpTable task memset failed."; | |||||
| terminated_ = true; | |||||
| } | } | ||||
| } | } | ||||
| output_addr += outer_dim_size; | output_addr += outer_dim_size; | ||||
| } | } | ||||
| } | } | ||||
| void PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | |||||
| bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | |||||
| const int *indices_addr, float *output_addr) { | const int *indices_addr, float *output_addr) { | ||||
| size_t first_dim_size = host_cache_vocab_size_; | size_t first_dim_size = host_cache_vocab_size_; | ||||
| size_t outer_dim_size = embedding_size; | size_t outer_dim_size = embedding_size; | ||||
| @@ -553,9 +592,10 @@ void PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l | |||||
| for (size_t j = 0; j < i; j++) { | for (size_t j = 0; j < i; j++) { | ||||
| threads[j].join(); | threads[j].join(); | ||||
| } | } | ||||
| return !terminated_; | |||||
| } | } | ||||
| void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, | |||||
| bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, | |||||
| float *insert_data, float *hash_table_addr) { | float *insert_data, float *hash_table_addr) { | ||||
| size_t first_dim_size = host_cache_vocab_size_; | size_t first_dim_size = host_cache_vocab_size_; | ||||
| size_t thread_num = insert_indices_size / 10000 + 1; | size_t thread_num = insert_indices_size / 10000 + 1; | ||||
| @@ -565,8 +605,8 @@ void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in | |||||
| size_t i; | size_t i; | ||||
| size_t task_offset = 0; | size_t task_offset = 0; | ||||
| auto insert_hash_table_task = [](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size, | |||||
| int *insert_indices, float *insert_data, float *hash_table_addr) { | |||||
| auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size, | |||||
| int *insert_indices, float *insert_data, float *hash_table_addr) { | |||||
| auto type_size = sizeof(float); | auto type_size = sizeof(float); | ||||
| size_t lens = outer_dim_size * type_size; | size_t lens = outer_dim_size * type_size; | ||||
| for (size_t i = 0; i < insert_indices_size; ++i) { | for (size_t i = 0; i < insert_indices_size; ++i) { | ||||
| @@ -574,7 +614,8 @@ void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in | |||||
| if (index >= 0 && index < SizeToInt(first_dim_size)) { | if (index >= 0 && index < SizeToInt(first_dim_size)) { | ||||
| auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens); | auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| MS_LOG(EXCEPTION) << "Insert hash table task memcpy failed."; | |||||
| MS_LOG(ERROR) << "Insert hash table task memcpy failed."; | |||||
| terminated_ = true; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -596,94 +637,101 @@ void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in | |||||
| for (size_t j = 0; j < i; j++) { | for (size_t j = 0; j < i; j++) { | ||||
| threads[j].join(); | threads[j].join(); | ||||
| } | } | ||||
| return !terminated_; | |||||
| } | } | ||||
| void PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) { | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||||
| bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) { | |||||
| MS_ERROR_IF_NULL(embedding_device_cache_); | |||||
| MS_ERROR_IF_NULL(embedding_device_cache_->cache_); | |||||
| MS_ERROR_IF_NULL(embedding_host_cache_); | |||||
| auto host_cache_host_to_device_index = embedding_host_cache_->host_to_device_index.get(); | auto host_cache_host_to_device_index = embedding_host_cache_->host_to_device_index.get(); | ||||
| auto device_cache_host_to_device_index = embedding_device_cache_->host_to_device_index.get(); | auto device_cache_host_to_device_index = embedding_device_cache_->host_to_device_index.get(); | ||||
| auto swap_indices_size = statistics_info_.host_to_device_size_; | auto swap_indices_size = statistics_info_.host_to_device_size_; | ||||
| if (swap_indices_size == 0) { | if (swap_indices_size == 0) { | ||||
| return; | |||||
| return true; | |||||
| } | } | ||||
| auto embedding_size = hash_info.embedding_size; | auto embedding_size = hash_info.embedding_size; | ||||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | ||||
| auto hash_table_size = hash_info.device_address.size; | auto hash_table_size = hash_info.device_address.size; | ||||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | ||||
| auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size); | auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size); | ||||
| LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_cache_host_to_device_index, | |||||
| swap_out_data.get()); | |||||
| embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_, | |||||
| swap_out_data.get(), | |||||
| swap_indices_size * embedding_size * sizeof(float)); | |||||
| embedding_device_cache_->cache_->CopyHostMemToDevice( | |||||
| embedding_device_cache_->hash_swap_index_addr_, device_cache_host_to_device_index, swap_indices_size * sizeof(int)); | |||||
| embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, | |||||
| embedding_device_cache_->hash_swap_index_addr_, hash_table_size, | |||||
| embedding_size, swap_indices_size); | |||||
| } | |||||
| void PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) { | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||||
| RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, | |||||
| host_cache_host_to_device_index, swap_out_data.get())); | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice( | |||||
| embedding_device_cache_->hash_swap_value_addr_, swap_out_data.get(), | |||||
| swap_indices_size * embedding_size * sizeof(float))); | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, | |||||
| device_cache_host_to_device_index, | |||||
| swap_indices_size * sizeof(int))); | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn( | |||||
| hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, | |||||
| hash_table_size, embedding_size, swap_indices_size)); | |||||
| return true; | |||||
| } | |||||
| bool PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) { | |||||
| MS_ERROR_IF_NULL(embedding_device_cache_); | |||||
| MS_ERROR_IF_NULL(embedding_device_cache_->cache_); | |||||
| MS_ERROR_IF_NULL(embedding_host_cache_); | |||||
| auto swap_indices_size = statistics_info_.device_to_host_size_; | auto swap_indices_size = statistics_info_.device_to_host_size_; | ||||
| auto device_cache_device_to_host_index = embedding_device_cache_->device_to_host_index.get(); | auto device_cache_device_to_host_index = embedding_device_cache_->device_to_host_index.get(); | ||||
| auto host_cache_device_to_host_index = embedding_host_cache_->device_to_host_index.get(); | auto host_cache_device_to_host_index = embedding_host_cache_->device_to_host_index.get(); | ||||
| if (swap_indices_size == 0) { | if (swap_indices_size == 0) { | ||||
| return; | |||||
| return true; | |||||
| } | } | ||||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | ||||
| auto hash_table_size = hash_info.device_address.size; | auto hash_table_size = hash_info.device_address.size; | ||||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | ||||
| auto embedding_size = hash_info.embedding_size; | auto embedding_size = hash_info.embedding_size; | ||||
| auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size); | auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size); | ||||
| embedding_device_cache_->cache_->CopyHostMemToDevice( | |||||
| embedding_device_cache_->hash_swap_index_addr_, device_cache_device_to_host_index, swap_indices_size * sizeof(int)); | |||||
| embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, | |||||
| embedding_device_cache_->hash_swap_index_addr_, hash_table_size, | |||||
| embedding_size, swap_indices_size); | |||||
| embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data.get(), | |||||
| embedding_device_cache_->hash_swap_value_addr_, | |||||
| swap_indices_size * embedding_size * sizeof(float)); | |||||
| embedding_device_cache_->cache_->SynchronizeStream(); | |||||
| InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index, | |||||
| swap_out_data.get(), host_hash_table_addr); | |||||
| } | |||||
| void PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_info) { | |||||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, | |||||
| device_cache_device_to_host_index, | |||||
| swap_indices_size * sizeof(int))); | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut( | |||||
| hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, | |||||
| hash_table_size, embedding_size, swap_indices_size)); | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost( | |||||
| swap_out_data.get(), embedding_device_cache_->hash_swap_value_addr_, | |||||
| swap_indices_size * embedding_size * sizeof(float))); | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream()); | |||||
| RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index, | |||||
| swap_out_data.get(), host_hash_table_addr)); | |||||
| return true; | |||||
| } | |||||
| bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_info) { | |||||
| MS_ERROR_IF_NULL(embedding_host_cache_); | |||||
| auto host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | auto host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | ||||
| auto host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | auto host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | ||||
| auto swap_indices_size = statistics_info_.host_to_server_size_; | auto swap_indices_size = statistics_info_.host_to_server_size_; | ||||
| if (swap_indices_size == 0) { | if (swap_indices_size == 0) { | ||||
| return; | |||||
| return true; | |||||
| } | } | ||||
| ::ps::SArray<int> lookup_ids(swap_indices_size, 0); | ::ps::SArray<int> lookup_ids(swap_indices_size, 0); | ||||
| ::ps::SArray<float> swap_out_data; | ::ps::SArray<float> swap_out_data; | ||||
| auto embedding_size = hash_info.embedding_size; | auto embedding_size = hash_info.embedding_size; | ||||
| swap_out_data.resize(swap_indices_size * embedding_size); | swap_out_data.resize(swap_indices_size * embedding_size); | ||||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | ||||
| LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index, | |||||
| swap_out_data.data()); | |||||
| RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index, | |||||
| swap_out_data.data())); | |||||
| auto copy_len = swap_indices_size * sizeof(int); | auto copy_len = swap_indices_size * sizeof(int); | ||||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids, copy_len); | auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids, copy_len); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||||
| return false; | |||||
| } | } | ||||
| worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | ||||
| return true; | |||||
| } | } | ||||
| void PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_info) { | |||||
| MS_EXCEPTION_IF_NULL(embedding_host_cache_); | |||||
| bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_info) { | |||||
| MS_ERROR_IF_NULL(embedding_host_cache_); | |||||
| auto swap_indices_size = statistics_info_.server_to_host_size_; | auto swap_indices_size = statistics_info_.server_to_host_size_; | ||||
| auto server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); | auto server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); | ||||
| auto server_to_host_index = embedding_host_cache_->server_to_host_index.get(); | auto server_to_host_index = embedding_host_cache_->server_to_host_index.get(); | ||||
| if (swap_indices_size == 0) { | if (swap_indices_size == 0) { | ||||
| return; | |||||
| return true; | |||||
| } | } | ||||
| auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get()); | ||||
| auto embedding_size = hash_info.embedding_size; | auto embedding_size = hash_info.embedding_size; | ||||
| @@ -693,47 +741,50 @@ void PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_ | |||||
| auto copy_len = swap_indices_size * sizeof(int); | auto copy_len = swap_indices_size * sizeof(int); | ||||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len); | auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||||
| return false; | |||||
| } | } | ||||
| worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | ||||
| InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index, lookup_result.data(), | |||||
| host_hash_table_addr); | |||||
| RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index, | |||||
| lookup_result.data(), host_hash_table_addr)); | |||||
| return true; | |||||
| } | } | ||||
| void PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, | |||||
| bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, | |||||
| const HashTableInfo &hash_info) { | const HashTableInfo &hash_info) { | ||||
| MS_EXCEPTION_IF_NULL(swap_out_index); | |||||
| MS_EXCEPTION_IF_NULL(swap_out_data); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||||
| MS_ERROR_IF_NULL(swap_out_index); | |||||
| MS_ERROR_IF_NULL(swap_out_data); | |||||
| MS_ERROR_IF_NULL(embedding_device_cache_); | |||||
| MS_ERROR_IF_NULL(embedding_device_cache_->cache_); | |||||
| auto swap_out_index_size = statistics_info_.device_to_host_size_; | auto swap_out_index_size = statistics_info_.device_to_host_size_; | ||||
| if (swap_out_index_size == 0) { | if (swap_out_index_size == 0) { | ||||
| return; | |||||
| return true; | |||||
| } | } | ||||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | ||||
| auto hash_table_size = hash_info.device_address.size; | auto hash_table_size = hash_info.device_address.size; | ||||
| auto embedding_size = hash_info.embedding_size; | auto embedding_size = hash_info.embedding_size; | ||||
| swap_out_data->resize(swap_out_index_size * embedding_size); | swap_out_data->resize(swap_out_index_size * embedding_size); | ||||
| embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_out_index, | |||||
| swap_out_index_size * sizeof(int)); | |||||
| embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, | |||||
| embedding_device_cache_->hash_swap_index_addr_, hash_table_size, | |||||
| embedding_size, swap_out_index_size); | |||||
| embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data->data(), | |||||
| embedding_device_cache_->hash_swap_value_addr_, | |||||
| swap_out_index_size * embedding_size * sizeof(float)); | |||||
| embedding_device_cache_->cache_->RecordEvent(); | |||||
| } | |||||
| void PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice( | |||||
| embedding_device_cache_->hash_swap_index_addr_, swap_out_index, swap_out_index_size * sizeof(int))); | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut( | |||||
| hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, | |||||
| hash_table_size, embedding_size, swap_out_index_size)); | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost( | |||||
| swap_out_data->data(), embedding_device_cache_->hash_swap_value_addr_, | |||||
| swap_out_index_size * embedding_size * sizeof(float))); | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->RecordEvent()); | |||||
| return true; | |||||
| } | |||||
| bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, | |||||
| size_t key) { | size_t key) { | ||||
| MS_EXCEPTION_IF_NULL(swap_in_ids); | |||||
| MS_EXCEPTION_IF_NULL(swap_in_index); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||||
| MS_ERROR_IF_NULL(swap_in_ids); | |||||
| MS_ERROR_IF_NULL(swap_in_index); | |||||
| MS_ERROR_IF_NULL(embedding_device_cache_); | |||||
| MS_ERROR_IF_NULL(embedding_device_cache_->cache_); | |||||
| auto swap_in_ids_size = statistics_info_.host_to_device_size_; | auto swap_in_ids_size = statistics_info_.host_to_device_size_; | ||||
| if (swap_in_ids_size == 0) { | if (swap_in_ids_size == 0) { | ||||
| return; | |||||
| return true; | |||||
| } | } | ||||
| auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr); | ||||
| auto hash_table_size = hash_info.device_address.size; | auto hash_table_size = hash_info.device_address.size; | ||||
| @@ -745,42 +796,44 @@ void PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons | |||||
| auto copy_len = swap_in_ids_size * sizeof(int); | auto copy_len = swap_in_ids_size * sizeof(int); | ||||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len); | auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||||
| return false; | |||||
| } | } | ||||
| worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); | ||||
| // Hash swap-in in device. | // Hash swap-in in device. | ||||
| embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_, | |||||
| lookup_result.data(), | |||||
| swap_in_ids_size * embedding_size * sizeof(float)); | |||||
| embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_in_index, | |||||
| swap_in_ids_size * sizeof(int)); | |||||
| embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, | |||||
| embedding_device_cache_->hash_swap_index_addr_, hash_table_size, | |||||
| embedding_size, swap_in_ids_size); | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice( | |||||
| embedding_device_cache_->hash_swap_value_addr_, lookup_result.data(), | |||||
| swap_in_ids_size * embedding_size * sizeof(float))); | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, | |||||
| swap_in_index, swap_in_ids_size * sizeof(int))); | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn( | |||||
| hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, | |||||
| hash_table_size, embedding_size, swap_in_ids_size)); | |||||
| return true; | |||||
| } | } | ||||
| void PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key) { | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||||
| MS_EXCEPTION_IF_NULL(swap_out_ids); | |||||
| bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key) { | |||||
| MS_ERROR_IF_NULL(embedding_device_cache_); | |||||
| MS_ERROR_IF_NULL(embedding_device_cache_->cache_); | |||||
| MS_ERROR_IF_NULL(swap_out_ids); | |||||
| auto swap_out_ids_size = statistics_info_.device_to_host_size_; | auto swap_out_ids_size = statistics_info_.device_to_host_size_; | ||||
| if (swap_out_ids_size == 0) { | if (swap_out_ids_size == 0) { | ||||
| return; | |||||
| return true; | |||||
| } | } | ||||
| ::ps::SArray<int> lookup_ids(swap_out_ids_size, 0); | ::ps::SArray<int> lookup_ids(swap_out_ids_size, 0); | ||||
| auto copy_len = swap_out_ids_size * sizeof(int); | auto copy_len = swap_out_ids_size * sizeof(int); | ||||
| auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len); | auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; | |||||
| MS_LOG(ERROR) << "Lookup id memcpy failed."; | |||||
| return false; | |||||
| } | } | ||||
| // Need synchronize event to ensure that the swap-out in device is completed. | // Need synchronize event to ensure that the swap-out in device is completed. | ||||
| embedding_device_cache_->cache_->SynchronizeEvent(); | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeEvent()); | |||||
| worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); | ||||
| return true; | |||||
| } | } | ||||
| void PsCacheManager::DumpHashTables(bool dump_device_tables) const { | void PsCacheManager::DumpHashTables(bool dump_device_tables) const { | ||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_); | |||||
| MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); | |||||
| for (const auto &item : hash_tables_) { | for (const auto &item : hash_tables_) { | ||||
| const auto ¶m_name = item.first; | const auto ¶m_name = item.first; | ||||
| size_t cache_vocab_size = item.second.cache_vocab_size; | size_t cache_vocab_size = item.second.cache_vocab_size; | ||||
| @@ -126,6 +126,8 @@ class PsCacheManager { | |||||
| bool initialized_ps_cache() const { return initialized_ps_cache_; } | bool initialized_ps_cache() const { return initialized_ps_cache_; } | ||||
| void DoProcessData(uint32_t device_id, void *context); | void DoProcessData(uint32_t device_id, void *context); | ||||
| void IncreaseGraphStep(const std::string &channel_name); | void IncreaseGraphStep(const std::string &channel_name); | ||||
| bool terminated() const { return terminated_; } | |||||
| void Finalize(); | |||||
| void DumpHashTables(bool dump_device_tables = false) const; | void DumpHashTables(bool dump_device_tables = false) const; | ||||
| private: | private: | ||||
| @@ -133,7 +135,7 @@ class PsCacheManager { | |||||
| ~PsCacheManager() = default; | ~PsCacheManager() = default; | ||||
| PsCacheManager(const PsCacheManager &) = delete; | PsCacheManager(const PsCacheManager &) = delete; | ||||
| PsCacheManager &operator=(const PsCacheManager &) = delete; | PsCacheManager &operator=(const PsCacheManager &) = delete; | ||||
| void IncreaseStep(); | |||||
| bool IncreaseStep(); | |||||
| void set_current_graph_step() { graph_running_step_ = graph_step_; } | void set_current_graph_step() { graph_running_step_ = graph_step_; } | ||||
| std::string channel_name(); | std::string channel_name(); | ||||
| void set_channel_name(const std::string channel_name); | void set_channel_name(const std::string channel_name); | ||||
| @@ -141,23 +143,23 @@ class PsCacheManager { | |||||
| void AllocMemForHashTable(); | void AllocMemForHashTable(); | ||||
| void SetLocalIdRank(); | void SetLocalIdRank(); | ||||
| void ProcessDataTask(uint32_t device_id, void *context); | void ProcessDataTask(uint32_t device_id, void *context); | ||||
| void ProcessData(); | |||||
| void ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index); | |||||
| void WaitGraphRun(); | |||||
| bool ProcessData(); | |||||
| bool ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index); | |||||
| bool WaitGraphRun(); | |||||
| int ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device); | int ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device); | ||||
| void ParseHostDataHostToDevice(size_t id); | |||||
| void ParseHostDataDeviceToHost(size_t id); | |||||
| void HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info); | |||||
| void HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key); | |||||
| void HashSwapHostToDevice(const HashTableInfo &hash_info); | |||||
| void HashSwapDeviceToHost(const HashTableInfo &hash_info); | |||||
| void HashSwapHostToServer(size_t key, const HashTableInfo &hash_info); | |||||
| void HashSwapServerToHost(size_t key, const HashTableInfo &hash_info); | |||||
| void InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, float *insert_data, | |||||
| bool ParseHostDataHostToDevice(size_t id); | |||||
| bool ParseHostDataDeviceToHost(size_t id); | |||||
| bool HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info); | |||||
| bool HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key); | |||||
| bool HashSwapHostToDevice(const HashTableInfo &hash_info); | |||||
| bool HashSwapDeviceToHost(const HashTableInfo &hash_info); | |||||
| bool HashSwapHostToServer(size_t key, const HashTableInfo &hash_info); | |||||
| bool HashSwapServerToHost(size_t key, const HashTableInfo &hash_info); | |||||
| bool InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, float *insert_data, | |||||
| float *hash_table_addr); | float *hash_table_addr); | ||||
| void LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | |||||
| bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | |||||
| const int *indices_addr, float *output_addr); | const int *indices_addr, float *output_addr); | ||||
| void UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key); | |||||
| bool UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key); | |||||
| void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr, | void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr, | ||||
| const int *indices_addr, float *output_addr); | const int *indices_addr, float *output_addr); | ||||
| bool CheckFinishInsertInitInfo() const; | bool CheckFinishInsertInitInfo() const; | ||||
| @@ -172,6 +174,7 @@ class PsCacheManager { | |||||
| std::mutex data_mutex_; | std::mutex data_mutex_; | ||||
| std::condition_variable data_prase_; | std::condition_variable data_prase_; | ||||
| std::condition_variable insert_init_info_; | std::condition_variable insert_init_info_; | ||||
| std::thread process_data_thread_; | |||||
| std::map<std::string, HashTableInfo> hash_tables_; | std::map<std::string, HashTableInfo> hash_tables_; | ||||
| std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_; | std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_; | ||||
| @@ -185,6 +188,8 @@ class PsCacheManager { | |||||
| std::pair<size_t, size_t> range_bound_; | std::pair<size_t, size_t> range_bound_; | ||||
| std::atomic_bool finish_insert_init_info_{false}; | std::atomic_bool finish_insert_init_info_{false}; | ||||
| std::atomic_bool finish_init_parameter_server_{false}; | std::atomic_bool finish_init_parameter_server_{false}; | ||||
| std::atomic_bool running_{false}; | |||||
| std::atomic_bool terminated_{false}; | |||||
| }; | }; | ||||
| static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance(); | static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance(); | ||||
| @@ -28,11 +28,9 @@ void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t s | |||||
| if (iter != ps_data_channel_map_.end()) { | if (iter != ps_data_channel_map_.end()) { | ||||
| MS_LOG(WARNING) << "The ps data channel already exists, channel name:" << channel_name; | MS_LOG(WARNING) << "The ps data channel already exists, channel name:" << channel_name; | ||||
| auto channel = iter->second; | auto channel = iter->second; | ||||
| MS_EXCEPTION_IF_NULL(channel); | |||||
| channel->set_step_num(step_num); | channel->set_step_num(step_num); | ||||
| } else { | } else { | ||||
| auto channel = std::make_shared<PsDataChannel>(channel_name, step_num); | auto channel = std::make_shared<PsDataChannel>(channel_name, step_num); | ||||
| MS_EXCEPTION_IF_NULL(channel); | |||||
| (void)ps_data_channel_map_.emplace(channel_name, channel); | (void)ps_data_channel_map_.emplace(channel_name, channel); | ||||
| } | } | ||||
| } | } | ||||
| @@ -40,71 +38,95 @@ void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t s | |||||
| std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string &channel_name) const { | std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string &channel_name) const { | ||||
| auto iter = ps_data_channel_map_.find(channel_name); | auto iter = ps_data_channel_map_.find(channel_name); | ||||
| if (iter == ps_data_channel_map_.end()) { | if (iter == ps_data_channel_map_.end()) { | ||||
| MS_LOG(EXCEPTION) << "The ps data channel does not exist, channel name:" << channel_name; | |||||
| MS_LOG(ERROR) << "The ps data channel does not exist, channel name:" << channel_name; | |||||
| return nullptr; | |||||
| } | } | ||||
| return iter->second; | return iter->second; | ||||
| } | } | ||||
| void PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) { | |||||
| bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) { | |||||
| if (cache_enable_ == false) { | if (cache_enable_ == false) { | ||||
| return; | |||||
| return true; | |||||
| } | } | ||||
| if (data == nullptr) { | if (data == nullptr) { | ||||
| MS_LOG(WARNING) << "No data prefetch."; | MS_LOG(WARNING) << "No data prefetch."; | ||||
| return; | |||||
| return true; | |||||
| } | } | ||||
| auto channel = ps_data_channel(channel_name); | auto channel = ps_data_channel(channel_name); | ||||
| MS_EXCEPTION_IF_NULL(channel); | |||||
| MS_ERROR_IF_NULL(channel); | |||||
| channel->set_data(data, data_size); | channel->set_data(data, data_size); | ||||
| std::unique_lock<std::mutex> locker(data_mutex_); | std::unique_lock<std::mutex> locker(data_mutex_); | ||||
| data_ready_ = true; | data_ready_ = true; | ||||
| data_process_.notify_one(); | data_process_.notify_one(); | ||||
| if (!need_wait_) { | |||||
| return true; | |||||
| } | |||||
| for (int i = 0; i < 10; i++) { | for (int i = 0; i < 10; i++) { | ||||
| if (data_prefetch_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == false; })) { | |||||
| return; | |||||
| if (data_prefetch_.wait_for(locker, std::chrono::seconds(30), | |||||
| [this] { return data_ready_ == false || need_wait_ == false; })) { | |||||
| return true; | |||||
| } else { | } else { | ||||
| MS_LOG(INFO) << "Waiting for ps data process, channel name:" << channel_name << "...(" << i << " / 10)"; | MS_LOG(INFO) << "Waiting for ps data process, channel name:" << channel_name << "...(" << i << " / 10)"; | ||||
| } | } | ||||
| } | } | ||||
| MS_LOG(EXCEPTION) << "Ps cache data process timeout, suggest to enlarge the cache size."; | |||||
| MS_LOG(ERROR) << "Ps cache data process timeout, suggest to enlarge the cache size."; | |||||
| return false; | |||||
| } | } | ||||
| void PsDataPrefetch::FinalizeData(const std::string &channel_name) { | |||||
| bool PsDataPrefetch::FinalizeData(const std::string &channel_name) { | |||||
| if (cache_enable_ == false) { | if (cache_enable_ == false) { | ||||
| return; | |||||
| return true; | |||||
| } | } | ||||
| auto channel = ps_data_channel(channel_name); | auto channel = ps_data_channel(channel_name); | ||||
| MS_EXCEPTION_IF_NULL(channel); | |||||
| MS_ERROR_IF_NULL(channel); | |||||
| channel->ResetData(); | channel->ResetData(); | ||||
| std::unique_lock<std::mutex> locker(data_mutex_); | std::unique_lock<std::mutex> locker(data_mutex_); | ||||
| data_ready_ = false; | data_ready_ = false; | ||||
| data_prefetch_.notify_one(); | data_prefetch_.notify_one(); | ||||
| if (!need_wait_) { | |||||
| return true; | |||||
| } | |||||
| for (int i = 0; i < 10; i++) { | for (int i = 0; i < 10; i++) { | ||||
| if (data_process_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == true; })) { | |||||
| return; | |||||
| if (data_process_.wait_for(locker, std::chrono::seconds(30), | |||||
| [this] { return data_ready_ == true || need_wait_ == false; })) { | |||||
| return true; | |||||
| } else { | } else { | ||||
| MS_LOG(INFO) << "Waiting for ps data prefetch, channel name:" << channel_name << "...(" << i << " / 10)"; | MS_LOG(INFO) << "Waiting for ps data prefetch, channel name:" << channel_name << "...(" << i << " / 10)"; | ||||
| } | } | ||||
| } | } | ||||
| MS_LOG(EXCEPTION) << "Ps cache data prefetch timeout."; | |||||
| MS_LOG(ERROR) << "Ps cache data prefetch timeout."; | |||||
| return false; | |||||
| } | } | ||||
| void *PsDataPrefetch::data(const std::string &channel_name) const { | void *PsDataPrefetch::data(const std::string &channel_name) const { | ||||
| auto channel = ps_data_channel(channel_name); | auto channel = ps_data_channel(channel_name); | ||||
| MS_EXCEPTION_IF_NULL(channel); | |||||
| if (channel == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| return channel->data(); | return channel->data(); | ||||
| } | } | ||||
| size_t PsDataPrefetch::data_size(const std::string &channel_name) const { | size_t PsDataPrefetch::data_size(const std::string &channel_name) const { | ||||
| auto channel = ps_data_channel(channel_name); | auto channel = ps_data_channel(channel_name); | ||||
| MS_EXCEPTION_IF_NULL(channel); | |||||
| if (channel == nullptr) { | |||||
| return 0; | |||||
| } | |||||
| return channel->data_size(); | return channel->data_size(); | ||||
| } | } | ||||
| void PsDataPrefetch::TryWakeChannel(const std::string &channel_name) { | |||||
| void PsDataPrefetch::NotifyFinalize() { | |||||
| need_wait_ = false; | |||||
| data_prefetch_.notify_one(); | |||||
| data_process_.notify_one(); | |||||
| } | |||||
| bool PsDataPrefetch::TryWakeChannel(const std::string &channel_name) { | |||||
| auto channel = ps_data_channel(channel_name); | auto channel = ps_data_channel(channel_name); | ||||
| MS_EXCEPTION_IF_NULL(channel); | |||||
| if (channel == nullptr) { | |||||
| return false; | |||||
| } | |||||
| channel->TryWakeChannel(); | channel->TryWakeChannel(); | ||||
| return true; | |||||
| } | } | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include <atomic> | |||||
| #include <condition_variable> | #include <condition_variable> | ||||
| #include "ps/ps_cache/ps_data/ps_data_channel.h" | #include "ps/ps_cache/ps_data/ps_data_channel.h" | ||||
| @@ -36,11 +37,12 @@ class EXPORT PsDataPrefetch { | |||||
| EXPORT bool cache_enable() const { return cache_enable_; } | EXPORT bool cache_enable() const { return cache_enable_; } | ||||
| EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; } | EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; } | ||||
| EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num); | EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num); | ||||
| EXPORT void PrefetchData(const std::string &channel_name, void *data, const size_t data_size); | |||||
| EXPORT void FinalizeData(const std::string &channel_name); | |||||
| EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size); | |||||
| EXPORT bool FinalizeData(const std::string &channel_name); | |||||
| EXPORT void NotifyFinalize(); | |||||
| EXPORT void *data(const std::string &channel_name) const; | EXPORT void *data(const std::string &channel_name) const; | ||||
| EXPORT size_t data_size(const std::string &channel_name) const; | EXPORT size_t data_size(const std::string &channel_name) const; | ||||
| EXPORT void TryWakeChannel(const std::string &channel_name); | |||||
| EXPORT bool TryWakeChannel(const std::string &channel_name); | |||||
| private: | private: | ||||
| PsDataPrefetch() : cache_enable_(false), data_ready_(false) {} | PsDataPrefetch() : cache_enable_(false), data_ready_(false) {} | ||||
| @@ -54,6 +56,7 @@ class EXPORT PsDataPrefetch { | |||||
| std::mutex data_mutex_; | std::mutex data_mutex_; | ||||
| std::condition_variable data_prefetch_; | std::condition_variable data_prefetch_; | ||||
| std::condition_variable data_process_; | std::condition_variable data_process_; | ||||
| std::atomic_bool need_wait_{true}; | |||||
| }; | }; | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -17,10 +17,10 @@ | |||||
| #include "ps/ps_context.h" | #include "ps/ps_context.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||||
| #include "backend/kernel_compiler/kernel.h" | #include "backend/kernel_compiler/kernel.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | ||||
| #include "ps/ps_cache/ps_cache_manager.h" | #include "ps/ps_cache/ps_cache_manager.h" | ||||
| #include "ps/ps_cache/ps_data/ps_data_prefetch.h" | |||||
| #endif | #endif | ||||
| namespace mindspore { | namespace mindspore { | ||||
| @@ -62,7 +62,12 @@ void PSContext::Reset() { | |||||
| is_worker_ = false; | is_worker_ = false; | ||||
| is_pserver_ = false; | is_pserver_ = false; | ||||
| is_sched_ = false; | is_sched_ = false; | ||||
| set_cache_enable(false); | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||||
| ps_cache_instance.Finalize(); | |||||
| set_cache_enable(false); | |||||
| } | |||||
| #endif | |||||
| } | } | ||||
| std::string PSContext::ms_role() const { | std::string PSContext::ms_role() const { | ||||
| @@ -62,6 +62,16 @@ namespace gpu { | |||||
| } \ | } \ | ||||
| } | } | ||||
| #define CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(expression, message) \ | |||||
| { \ | |||||
| cudaError_t status = (expression); \ | |||||
| if (status != cudaSuccess) { \ | |||||
| MS_LOG(ERROR) << "CUDA Error: " << message << " | Error Number: " << status << " " \ | |||||
| << cudaGetErrorString(status); \ | |||||
| return false; \ | |||||
| } \ | |||||
| } | |||||
| #define CHECK_CUDA_RET_WITH_EXCEPT(node, expression, message) \ | #define CHECK_CUDA_RET_WITH_EXCEPT(node, expression, message) \ | ||||
| { \ | { \ | ||||
| cudaError_t status = (expression); \ | cudaError_t status = (expression); \ | ||||
| @@ -199,6 +199,14 @@ class LogWriter { | |||||
| } \ | } \ | ||||
| } while (0) | } while (0) | ||||
| #define MS_ERROR_IF_NULL(ptr) \ | |||||
| do { \ | |||||
| if ((ptr) == nullptr) { \ | |||||
| MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \ | |||||
| return false; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #ifdef DEBUG | #ifdef DEBUG | ||||
| #include <cassert> | #include <cassert> | ||||
| #define MS_ASSERT(f) assert(f) | #define MS_ASSERT(f) assert(f) | ||||