Browse Source

!10190 ps cache data process thread support exit when exceptions occur

From: @zyli2020
Reviewed-by: @limingqi107
Signed-off-by: @limingqi107
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
0c78a8a9d5
15 changed files with 470 additions and 316 deletions
  1. +3
    -1
      mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc
  2. +3
    -1
      mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc
  3. +5
    -3
      mindspore/ccsrc/pipeline/jit/pipeline.cc
  4. +82
    -55
      mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc
  5. +9
    -9
      mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h
  6. +38
    -30
      mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc
  7. +8
    -8
      mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h
  8. +16
    -9
      mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h
  9. +213
    -160
      mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
  10. +20
    -15
      mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
  11. +42
    -20
      mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc
  12. +6
    -3
      mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h
  13. +7
    -2
      mindspore/ccsrc/ps/ps_context.cc
  14. +10
    -0
      mindspore/ccsrc/runtime/device/gpu/gpu_common.h
  15. +8
    -0
      mindspore/core/utils/log_adapter.h

+ 3
- 1
mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc View File

@@ -304,7 +304,9 @@ Status DeviceQueueOp::PushDataToGPU() {

// Data prefetch only when PS mode enables cache.
if (items.size() > 0) {
ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_);
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_)) {
return Status(StatusCode::kTimeOut, __LINE__, __FILE__, "Failed to prefetch data.");
}
}
while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) {
BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME);


+ 3
- 1
mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc View File

@@ -55,7 +55,9 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe
#if ENABLE_D
// Data prefetch only when PS mode enables cache.
if (items.size() > 0) {
ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_);
if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_)) {
return FAILED;
}
}
#endif
if (tdt::TdtHostPushData(channel_name, items) != 0) {


+ 5
- 3
mindspore/ccsrc/pipeline/jit/pipeline.cc View File

@@ -53,6 +53,7 @@
#include "ps/util.h"
#include "ps/worker.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "ps/ps_cache/ps_cache_manager.h"
#endif

#if (ENABLE_GE || ENABLE_D)
@@ -1083,9 +1084,10 @@ void ClearResAtexit() {
pynative::ClearPyNativeSession();
session::ClearPythonParasMap();
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::Util::IsParamServerMode()) {
if (ps::Util::IsRoleOfWorker()) {
ps::worker.Finalize();
if (ps::Util::IsParamServerMode() && ps::Util::IsRoleOfWorker()) {
ps::worker.Finalize();
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
ps::ps_cache_instance.Finalize();
}
}
#endif


+ 82
- 55
mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc View File

@@ -37,155 +37,178 @@ namespace ps {
namespace ascend {
MS_REG_PS_CACHE(kAscendDevice, AscendPsCache);
namespace {
void SetProtoInputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type,
bool SetProtoInputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type,
mindspore::NodeDef *proto) {
MS_EXCEPTION_IF_NULL(proto);
MS_ERROR_IF_NULL(proto);
if (data_shape.size() != data_type.size()) {
MS_LOG(EXCEPTION) << "The size of data shape is not equal to the size of data type.";
MS_LOG(ERROR) << "The size of data shape is not equal to the size of data type.";
return false;
}
for (size_t input_index = 0; input_index < data_shape.size(); input_index++) {
::mindspore::Tensor *proto_inputs = proto->add_inputs();
MS_EXCEPTION_IF_NULL(proto_inputs);
MS_ERROR_IF_NULL(proto_inputs);
auto input_shape = data_shape[input_index];
mindspore::TensorShape *tensorShape = proto_inputs->mutable_tensor_shape();
MS_EXCEPTION_IF_NULL(tensorShape);
MS_ERROR_IF_NULL(tensorShape);
for (auto item : input_shape) {
mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
MS_EXCEPTION_IF_NULL(dim);
MS_ERROR_IF_NULL(dim);
dim->set_size((::google::protobuf::int64)item);
}
auto input_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[input_index]);
proto_inputs->set_tensor_type(input_type);
proto_inputs->set_mem_device("HBM");
}
return true;
}

void SetProtoOutputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type,
bool SetProtoOutputs(const std::vector<std::vector<size_t>> &data_shape, const std::vector<TypeId> &data_type,
mindspore::NodeDef *proto) {
MS_EXCEPTION_IF_NULL(proto);
MS_ERROR_IF_NULL(proto);
if (data_shape.size() != data_type.size()) {
MS_LOG(EXCEPTION) << "The size of data shape is not equal to the size of data type.";
MS_LOG(ERROR) << "The size of data shape is not equal to the size of data type.";
return false;
}
for (size_t output_index = 0; output_index < data_shape.size(); output_index++) {
::mindspore::Tensor *proto_outputs = proto->add_outputs();
MS_EXCEPTION_IF_NULL(proto_outputs);
MS_ERROR_IF_NULL(proto_outputs);
auto output_shape = data_shape[output_index];
mindspore::TensorShape *tensorShape = proto_outputs->mutable_tensor_shape();
MS_EXCEPTION_IF_NULL(tensorShape);
MS_ERROR_IF_NULL(tensorShape);
for (auto item : output_shape) {
mindspore::TensorShape_Dim *dim = tensorShape->add_dim();
MS_EXCEPTION_IF_NULL(dim);
MS_ERROR_IF_NULL(dim);
dim->set_size((::google::protobuf::int64)item);
}
auto output_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[output_index]);
proto_outputs->set_tensor_type(output_type);
proto_outputs->set_mem_device("HBM");
}
return true;
}

void SetNodedefProto(const std::shared_ptr<KernelNodeInfo> &op_info,
bool SetNodedefProto(const std::shared_ptr<KernelNodeInfo> &op_info,
const std::shared_ptr<kernel::AicpuOpKernelMod> &kernel_mod_ptr) {
MS_EXCEPTION_IF_NULL(op_info);
MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
MS_ERROR_IF_NULL(op_info);
MS_ERROR_IF_NULL(kernel_mod_ptr);
mindspore::NodeDef proto;
proto.set_op(op_info->op_name_);
SetProtoInputs(op_info->input_data_shape_, op_info->input_data_type_, &proto);
SetProtoOutputs(op_info->output_data_shape_, op_info->output_data_type_, &proto);
RETURN_IF_FALSE(SetProtoInputs(op_info->input_data_shape_, op_info->input_data_type_, &proto));
RETURN_IF_FALSE(SetProtoOutputs(op_info->output_data_shape_, op_info->output_data_type_, &proto));
std::string nodeDefStr;
if (!proto.SerializeToString(&nodeDefStr)) {
MS_LOG(EXCEPTION) << "Serialize nodeDef to string failed.";
MS_LOG(ERROR) << "Serialize nodeDef to string failed.";
return false;
}
MS_LOG(DEBUG) << "Set node def proto, node name:" << op_info->op_name_;
kernel_mod_ptr->SetNodeDef(nodeDefStr);
return true;
}
} // namespace

void AscendPsCache::InitDevice(uint32_t device_id, const void *context) {
MS_EXCEPTION_IF_NULL(context);
bool AscendPsCache::InitDevice(uint32_t device_id, const void *context) {
MS_ERROR_IF_NULL(context);
auto ret = rtSetDevice(device_id);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << ret << "]";
MS_LOG(ERROR) << "Call rtSetDevice, ret[" << ret << "]";
return false;
}
auto rt_context = const_cast<rtContext_t>(context);
ret = rtCtxSetCurrent(rt_context);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]";
MS_LOG(ERROR) << "Call rtCtxSetCurrent, ret[" << ret << "]";
return false;
}
ret = rtStreamCreate(&stream_, 0);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtStreamCreate, ret[" << ret << "]";
MS_LOG(ERROR) << "Call rtStreamCreate, ret[" << ret << "]";
return false;
}
return true;
}

// Allocates `size` bytes through the Ascend memory pool singleton and returns whatever pointer the pool's
// AllocTensorMem hands back. The result is forwarded unchecked; callers validate it themselves.
void *AscendPsCache::MallocMemory(size_t size) {
  auto &&memory_pool = device::ascend::AscendMemoryPool::GetInstance();
  void *device_addr = memory_pool.AllocTensorMem(size);
  return device_addr;
}

void AscendPsCache::MallocConstantMemory(size_t constant_value) {
bool AscendPsCache::MallocConstantMemory(size_t constant_value) {
offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
MS_EXCEPTION_IF_NULL(offset_addr_);
MS_ERROR_IF_NULL(offset_addr_);
rtMemset(offset_addr_, sizeof(int), 0, sizeof(int));
cache_vocab_size_addr_ =
reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int)));
MS_EXCEPTION_IF_NULL(cache_vocab_size_addr_);
MS_ERROR_IF_NULL(cache_vocab_size_addr_);
rtMemset(cache_vocab_size_addr_, sizeof(int), constant_value, sizeof(int));
return true;
}

void AscendPsCache::RecordEvent() {
bool AscendPsCache::RecordEvent() {
event_.reset(new rtEvent_t());
auto ret = rtEventCreate(&(*event_));
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Create event failed";
MS_LOG(ERROR) << "Create event failed";
return false;
}
ret = rtEventRecord(*event_, stream_);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Record event failed";
MS_LOG(ERROR) << "Record event failed";
return false;
}
return true;
}

void AscendPsCache::SynchronizeEvent() {
bool AscendPsCache::SynchronizeEvent() {
auto ret = rtEventSynchronize(*event_);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "tEventSynchronize failed";
MS_LOG(ERROR) << "tEventSynchronize failed";
return false;
}
ret = rtEventDestroy(*event_);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtEventDestroy failed";
MS_LOG(ERROR) << "rtEventDestroy failed";
return false;
}
return true;
}

void AscendPsCache::SynchronizeStream() {
bool AscendPsCache::SynchronizeStream() {
auto ret = rtStreamSynchronize(stream_);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtStreamSynchronize failed";
MS_LOG(ERROR) << "rtStreamSynchronize failed";
return false;
}
return true;
}

void AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) {
MS_EXCEPTION_IF_NULL(dst);
MS_EXCEPTION_IF_NULL(src);
bool AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) {
MS_ERROR_IF_NULL(dst);
MS_ERROR_IF_NULL(src);
auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_HOST_TO_DEVICE, stream_);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtMemcpyAsync failed";
MS_LOG(ERROR) << "rtMemcpyAsync failed";
return false;
}
return true;
}

void AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) {
MS_EXCEPTION_IF_NULL(dst);
MS_EXCEPTION_IF_NULL(src);
bool AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) {
MS_ERROR_IF_NULL(dst);
MS_ERROR_IF_NULL(src);
auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_DEVICE_TO_HOST, stream_);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtMemcpyAsync failed";
MS_LOG(ERROR) << "rtMemcpyAsync failed";
return false;
}
return true;
}

void AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
size_t hash_table_size, size_t embedding_size, size_t swap_out_size) {
MS_EXCEPTION_IF_NULL(hash_table_addr);
MS_EXCEPTION_IF_NULL(swap_out_value_addr);
MS_EXCEPTION_IF_NULL(swap_out_index_addr);
MS_ERROR_IF_NULL(hash_table_addr);
MS_ERROR_IF_NULL(swap_out_value_addr);
MS_ERROR_IF_NULL(swap_out_index_addr);
auto hash_swap_out_mod = std::make_shared<kernel::AicpuOpKernelMod>();
MS_EXCEPTION_IF_NULL(hash_swap_out_mod);
MS_ERROR_IF_NULL(hash_swap_out_mod);
hash_swap_out_mod->SetNodeName(kEmbeddingLookupOpName);
std::vector<std::vector<size_t>> input_shape;
std::vector<std::vector<size_t>> output_shape;
@@ -197,7 +220,7 @@ void AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr
output_shape.push_back({swap_out_size, embedding_size});
auto op_info =
std::make_shared<KernelNodeInfo>(kEmbeddingLookupOpName, input_shape, input_type, output_shape, output_type);
SetNodedefProto(op_info, hash_swap_out_mod);
RETURN_IF_FALSE(SetNodedefProto(op_info, hash_swap_out_mod));

AddressPtrList kernel_inputs;
AddressPtrList kernel_outputs = {
@@ -208,17 +231,19 @@ void AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr
kernel_inputs.push_back(std::make_shared<Address>(offset_addr_, sizeof(int)));
auto ret = hash_swap_out_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
if (!ret) {
MS_LOG(EXCEPTION) << "Hash swap out launch failed.";
MS_LOG(ERROR) << "Hash swap out launch failed.";
return false;
}
return true;
}

void AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
bool AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
size_t hash_table_size, size_t embedding_size, size_t swap_in_size) {
MS_EXCEPTION_IF_NULL(hash_table_addr);
MS_EXCEPTION_IF_NULL(swap_in_value_addr);
MS_EXCEPTION_IF_NULL(swap_in_index_addr);
MS_ERROR_IF_NULL(hash_table_addr);
MS_ERROR_IF_NULL(swap_in_value_addr);
MS_ERROR_IF_NULL(swap_in_index_addr);
auto hash_swap_in_mod = std::make_shared<kernel::AicpuOpKernelMod>();
MS_EXCEPTION_IF_NULL(hash_swap_in_mod);
MS_ERROR_IF_NULL(hash_swap_in_mod);
hash_swap_in_mod->SetNodeName(kernel::kUpdateCache);
std::vector<std::vector<size_t>> input_shape;
std::vector<std::vector<size_t>> output_shape;
@@ -245,8 +270,10 @@ void AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr,
kernel_outputs.push_back(std::make_shared<Address>(offset_addr_, sizeof(int)));
auto ret = hash_swap_in_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
if (!ret) {
MS_LOG(EXCEPTION) << "Hash swap in launch failed.";
MS_LOG(ERROR) << "Hash swap in launch failed.";
return false;
}
return true;
}
} // namespace ascend
} // namespace ps


+ 9
- 9
mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h View File

@@ -49,17 +49,17 @@ class AscendPsCache : public PsCacheBasic {
public:
AscendPsCache() = default;
~AscendPsCache() override = default;
void InitDevice(uint32_t device_id, const void *context) override;
bool InitDevice(uint32_t device_id, const void *context) override;
void *MallocMemory(size_t size) override;
void MallocConstantMemory(size_t constant_value) override;
void RecordEvent() override;
void SynchronizeEvent() override;
void SynchronizeStream() override;
void CopyHostMemToDevice(void *dst, void *src, size_t size) override;
void CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size,
bool MallocConstantMemory(size_t constant_value) override;
bool RecordEvent() override;
bool SynchronizeEvent() override;
bool SynchronizeStream() override;
bool CopyHostMemToDevice(void *dst, void *src, size_t size) override;
bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size,
size_t embedding_size, size_t swap_out_size) override;
void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size,
bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size,
size_t embedding_size, size_t swap_in_size) override;

private:


+ 38
- 30
mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc View File

@@ -25,67 +25,75 @@ namespace mindspore {
namespace ps {
namespace gpu {
MS_REG_PS_CACHE(kGPUDevice, GPUPsCache);
void GPUPsCache::InitDevice(uint32_t device_id, const void *) {
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaSetDevice(device_id), "Cuda set device failed")
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamCreate(reinterpret_cast<CUstream_st **>(&stream_)),
"Cuda create stream failed");
bool GPUPsCache::InitDevice(uint32_t device_id, const void *) {
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaSetDevice(device_id), "Cuda set device failed")
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamCreate(reinterpret_cast<CUstream_st **>(&stream_)),
"Cuda create stream failed");
return true;
}

// Allocates `size` bytes through the GPU memory allocator singleton and returns whatever pointer its
// AllocTensorMem hands back. The result is forwarded unchecked; callers validate it themselves.
void *GPUPsCache::MallocMemory(size_t size) {
  auto &&allocator = device::gpu::GPUMemoryAllocator::GetInstance();
  void *device_addr = allocator.AllocTensorMem(size);
  return device_addr;
}

void GPUPsCache::RecordEvent() {
bool GPUPsCache::RecordEvent() {
event_.reset(new cudaEvent_t());
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventCreate(&(*event_)), "Cuda create event failed");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventRecord(*event_, reinterpret_cast<cudaStream_t>(stream_)),
"Cuda record event failed");
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventCreate(&(*event_)), "Cuda create event failed");
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventRecord(*event_, reinterpret_cast<cudaStream_t>(stream_)),
"Cuda record event failed");
return true;
}

void GPUPsCache::SynchronizeEvent() {
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventSynchronize(*event_), "Cuda sync event failed");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventDestroy(*event_), "Cuda destroy event failed");
bool GPUPsCache::SynchronizeEvent() {
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventSynchronize(*event_), "Cuda sync event failed");
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventDestroy(*event_), "Cuda destroy event failed");
return true;
}

void GPUPsCache::SynchronizeStream() {
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_)),
"Cuda sync stream failed");
bool GPUPsCache::SynchronizeStream() {
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream_)),
"Cuda sync stream failed");
return true;
}

void GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) {
MS_EXCEPTION_IF_NULL(dst);
MS_EXCEPTION_IF_NULL(src);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
bool GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) {
MS_ERROR_IF_NULL(dst);
MS_ERROR_IF_NULL(src);
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(
cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_)),
"Cuda memcpy failed");
return true;
}

void GPUPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) {
MS_EXCEPTION_IF_NULL(dst);
MS_EXCEPTION_IF_NULL(src);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
bool GPUPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) {
MS_ERROR_IF_NULL(dst);
MS_ERROR_IF_NULL(src);
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(
cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, reinterpret_cast<cudaStream_t>(stream_)),
"Cuda memcpy failed");
return true;
}

void GPUPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t,
bool GPUPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t,
size_t embedding_size, size_t swap_out_size) {
MS_EXCEPTION_IF_NULL(hash_table_addr);
MS_EXCEPTION_IF_NULL(swap_out_value_addr);
MS_EXCEPTION_IF_NULL(swap_out_index_addr);
MS_ERROR_IF_NULL(hash_table_addr);
MS_ERROR_IF_NULL(swap_out_value_addr);
MS_ERROR_IF_NULL(swap_out_index_addr);
DoHashSwapOut(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_out_value_addr),
reinterpret_cast<int *>(swap_out_index_addr), swap_out_size, embedding_size,
reinterpret_cast<cudaStream_t>(stream_));
return true;
}

void GPUPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t,
bool GPUPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t,
size_t embedding_size, size_t swap_in_size) {
MS_EXCEPTION_IF_NULL(hash_table_addr);
MS_EXCEPTION_IF_NULL(swap_in_value_addr);
MS_EXCEPTION_IF_NULL(swap_in_index_addr);
MS_ERROR_IF_NULL(hash_table_addr);
MS_ERROR_IF_NULL(swap_in_value_addr);
MS_ERROR_IF_NULL(swap_in_index_addr);
DoHashSwapIn(reinterpret_cast<float *>(hash_table_addr), reinterpret_cast<float *>(swap_in_value_addr),
reinterpret_cast<int *>(swap_in_index_addr), swap_in_size, embedding_size,
reinterpret_cast<cudaStream_t>(stream_));
return true;
}
} // namespace gpu
} // namespace ps


+ 8
- 8
mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h View File

@@ -28,16 +28,16 @@ class GPUPsCache : public PsCacheBasic {
public:
GPUPsCache() = default;
~GPUPsCache() override = default;
void InitDevice(uint32_t device_id, const void *context) override;
bool InitDevice(uint32_t device_id, const void *context) override;
void *MallocMemory(size_t size) override;
void RecordEvent() override;
void SynchronizeEvent() override;
void SynchronizeStream() override;
void CopyHostMemToDevice(void *dst, void *src, size_t size) override;
void CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size,
bool RecordEvent() override;
bool SynchronizeEvent() override;
bool SynchronizeStream() override;
bool CopyHostMemToDevice(void *dst, void *src, size_t size) override;
bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override;
bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size,
size_t embedding_size, size_t swap_out_size) override;
void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size,
bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size,
size_t embedding_size, size_t swap_in_size) override;

private:


+ 16
- 9
mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h View File

@@ -21,21 +21,28 @@

namespace mindspore {
namespace ps {
// Early-exit helper for functions that report failure through a bool return value.
// Evaluates `condition` exactly once; when it is false the enclosing function returns false,
// otherwise execution continues with the statement after the macro.
// The do/while(false) wrapper makes the expansion a single statement that needs a trailing ';'.
// The `break` only leaves that wrapper, never a loop or switch surrounding the macro use.
#define RETURN_IF_FALSE(condition) \
  do {                             \
    if ((condition)) {             \
      break;                       \
    }                              \
    return false;                  \
  } while (false)

class PsCacheBasic {
public:
PsCacheBasic() = default;
virtual ~PsCacheBasic() = default;
virtual void InitDevice(uint32_t device_id, const void *context) = 0;
virtual bool InitDevice(uint32_t device_id, const void *context) = 0;
virtual void *MallocMemory(size_t size) = 0;
virtual void MallocConstantMemory(size_t constant_value) {}
virtual void RecordEvent() = 0;
virtual void SynchronizeEvent() = 0;
virtual void SynchronizeStream() = 0;
virtual void CopyHostMemToDevice(void *dst, void *src, size_t size) = 0;
virtual void CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0;
virtual void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
virtual bool MallocConstantMemory(size_t constant_value) { return true; }
virtual bool RecordEvent() = 0;
virtual bool SynchronizeEvent() = 0;
virtual bool SynchronizeStream() = 0;
virtual bool CopyHostMemToDevice(void *dst, void *src, size_t size) = 0;
virtual bool CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0;
virtual bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr,
size_t hash_table_size, size_t embedding_size, size_t swap_out_size) = 0;
virtual void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
virtual bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr,
size_t hash_table_size, size_t embedding_size, size_t swap_in_size) = 0;

protected:


+ 213
- 160
mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc View File

@@ -170,8 +170,10 @@ void PsCacheManager::AddEmbeddingTable() const {
void PsCacheManager::InitParameterServer() {
MS_LOG(INFO) << "Embedding table init begin:" << finish_insert_init_info_;
std::unique_lock<std::mutex> locker(data_mutex_);
insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true; });

insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true || running_ == false; });
if (!running_) {
return;
}
for (const auto &item : hash_tables_) {
const auto &param_name = item.first;
size_t key = worker.SetParamKey(param_name);
@@ -224,7 +226,9 @@ void PsCacheManager::AllocMemForHashTable() {
embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast<float *>(
embedding_device_cache_->cache_->MallocMemory(max_embedding_size * batch_elements_ * sizeof(float)));
MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_);
embedding_device_cache_->cache_->MallocConstantMemory(cache_vocab_size_);
if (!(embedding_device_cache_->cache_->MallocConstantMemory(cache_vocab_size_))) {
MS_LOG(EXCEPTION) << "MallocConstantMemory failed.";
}
}

void PsCacheManager::SetLocalIdRank() {
@@ -250,19 +254,25 @@ void PsCacheManager::set_channel_name(const std::string channel_name) {
channel_name_ = channel_name;
}

void PsCacheManager::IncreaseStep() {
bool PsCacheManager::IncreaseStep() {
if (data_step_ >= UINT64_MAX) {
MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t.";
MS_LOG(ERROR) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t.";
return false;
}
data_step_++;
set_current_graph_step();
if (graph_running_step_ > data_step_) {
MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") exceed the data step (" << data_step_
<< ").";
MS_LOG(ERROR) << "The graph running step (" << graph_running_step_ << ") exceed the data step (" << data_step_
<< ").";
return false;
}
return true;
}

void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
if (terminated_) {
MS_LOG(EXCEPTION) << "ps cache data process thread is terminated.";
}
if (graph_step_ >= UINT64_MAX) {
MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t.";
}
@@ -274,7 +284,9 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) {
}
graph_step_++;
set_channel_name(channel_name);
PsDataPrefetch::GetInstance().TryWakeChannel(channel_name);
if (!PsDataPrefetch::GetInstance().TryWakeChannel(channel_name)) {
MS_LOG(EXCEPTION) << "TryWakeChannel failed, channel name: " << channel_name;
}
data_prase_.notify_one();
}

@@ -284,74 +296,99 @@ void PsCacheManager::DoProcessData(uint32_t device_id, void *context) {
MS_LOG(EXCEPTION) << "Only the sink_mode of dataset supports embeddingLookup cache in parameter server training "
"mode, current dataset mode is not sink_mode.";
}
auto process_data_thread = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context);
process_data_thread.detach();
process_data_thread_ = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context);
}

void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
embedding_device_cache_->cache_->InitDevice(device_id, context);
running_ = true;
bool ret = true;
InitParameterServer();
while (true) {
ProcessData();
while (ret) {
if (!running_) {
break;
}
ret = ProcessData();
}
if (!ret) {
terminated_ = true;
}
}

void PsCacheManager::ProcessData() {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
void PsCacheManager::Finalize() {
if (running_) {
running_ = false;
}
PsDataPrefetch::GetInstance().NotifyFinalize();
insert_init_info_.notify_all();
data_prase_.notify_all();
if (process_data_thread_.joinable()) {
process_data_thread_.join();
}
}

bool PsCacheManager::ProcessData() {
struct timeval start_time, end_time;
const uint64_t kUSecondInSecond = 1000000;
(void)gettimeofday(&start_time, nullptr);
auto channel = channel_name();
if (channel.empty()) {
std::unique_lock<std::mutex> locker(data_mutex_);
data_prase_.wait(locker, [this] { return !channel_name_.empty(); });
data_prase_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; });
if (!running_) {
return false;
}
}
auto data = PsDataPrefetch::GetInstance().data(channel_name_);
if (data == nullptr) {
MS_LOG(INFO) << "No data process, channel name:" << channel_name_;
std::unique_lock<std::mutex> locker(data_mutex_);
(void)data_prase_.wait_for(locker, std::chrono::milliseconds(100));
return;
return true;
}
IncreaseStep();
RETURN_IF_FALSE(IncreaseStep());
auto data_size = PsDataPrefetch::GetInstance().data_size(channel_name_);
if (data_size == 0) {
MS_LOG(ERROR) << "The data_size can not be zero.";
return false;
}
auto batch_ids = reinterpret_cast<int *>(data);
auto batch_ids_len = data_size / sizeof(int);
std::unique_ptr<int[]> hash_index(new int[batch_ids_len]);
if (memset_s(&statistics_info_, sizeof(statistics_info_), 0, sizeof(statistics_info_))) {
MS_LOG(EXCEPTION) << "Process data memset failed.";
MS_LOG(ERROR) << "Process data memset failed.";
return false;
}
// Get hash swap in/out index and ids.
ParseData(batch_ids, batch_ids_len, hash_index.get());
RETURN_IF_FALSE(ParseData(batch_ids, batch_ids_len, hash_index.get()));
for (const auto &item : hash_tables_) {
auto key = worker.GetParamKey(item.first);
auto hash_info = item.second;
HashSwapHostToServer(key, hash_info);
HashSwapDeviceToHost(hash_info);
HashSwapServerToHost(key, hash_info);
HashSwapHostToDevice(hash_info);
RETURN_IF_FALSE(HashSwapHostToServer(key, hash_info));
RETURN_IF_FALSE(HashSwapDeviceToHost(hash_info));
RETURN_IF_FALSE(HashSwapServerToHost(key, hash_info));
RETURN_IF_FALSE(HashSwapHostToDevice(hash_info));
}
// Replace the batch_ids by hash index for getNext-op getting hash index as input.
if (memcpy_s(data, data_size, hash_index.get(), data_size) != EOK) {
MS_LOG(EXCEPTION) << "Process data memcpy failed.";
MS_LOG(ERROR) << "Process data memcpy failed.";
return false;
}
embedding_device_cache_->cache_->SynchronizeStream();
RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
// Finish the data process and notify data prefetch.
PsDataPrefetch::GetInstance().FinalizeData(channel_name_);
RETURN_IF_FALSE(PsDataPrefetch::GetInstance().FinalizeData(channel_name_));
(void)gettimeofday(&end_time, nullptr);
uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
MS_LOG(DEBUG) << "Ps cache completes processing data(data step:" << data_step_
<< ",graph step:" << graph_running_step_ << " channel name:" << channel_name_
<< ", time cost:" << cost / 1000 << "ms).";
return true;
}

void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) {
MS_EXCEPTION_IF_NULL(batch_ids);
MS_EXCEPTION_IF_NULL(hash_index);
bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) {
MS_ERROR_IF_NULL(batch_ids);
MS_ERROR_IF_NULL(hash_index);
for (size_t i = 0; i < batch_ids_len; i++) {
bool need_swap_host_to_device = true;
bool need_swap_device_to_host = true;
@@ -360,12 +397,16 @@ void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len,
hash_index[i] = -1;
continue;
}
hash_index[i] = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device);
auto index = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device);
if (index == INVALID_INDEX_VALUE) {
return false;
}
hash_index[i] = index;
if (need_swap_host_to_device) {
ParseHostDataHostToDevice(id);
RETURN_IF_FALSE(ParseHostDataHostToDevice(id));
}
if (need_swap_device_to_host) {
ParseHostDataDeviceToHost(id);
RETURN_IF_FALSE(ParseHostDataDeviceToHost(id));
}
}
// Each 1000 step prints ps cache hit rate.
@@ -374,33 +415,28 @@ void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len,
auto hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_;
MS_LOG(INFO) << "Ps cache hit rate: " << hit_rate * 100 << "%.";
}
return true;
}

void PsCacheManager::WaitGraphRun() {
bool PsCacheManager::WaitGraphRun() {
MS_LOG(INFO) << "Hash table has no space to insert new data and retries within 2 minutes.";
std::unique_lock<std::mutex> locker(data_mutex_);
if (!data_prase_.wait_for(locker, std::chrono::seconds(120), [this] { return graph_step_ > graph_running_step_; })) {
MS_LOG(EXCEPTION) << "Ps cache data parse timeout, suggest to enlarge the cache size(graph step:" << graph_step_
<< ", graph running step:" << graph_running_step_ << ").";
MS_LOG(ERROR) << "Ps cache data parse timeout, suggest to enlarge the cache size(graph step:" << graph_step_
<< ", graph running step:" << graph_running_step_ << ").";
return false;
}
set_current_graph_step();
return true;
}

int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device) {
MS_EXCEPTION_IF_NULL(need_swap_device_to_host);
MS_EXCEPTION_IF_NULL(need_swap_host_to_device);
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
int *device_to_host_index = embedding_device_cache_->device_to_host_index.get();
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
int *host_to_device_index = embedding_device_cache_->host_to_device_index.get();
int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get();
MS_EXCEPTION_IF_NULL(device_to_host_index);
MS_EXCEPTION_IF_NULL(device_to_host_ids);
MS_EXCEPTION_IF_NULL(host_to_device_index);
MS_EXCEPTION_IF_NULL(host_to_device_ids);

auto device_hash_map = embedding_device_cache_->device_hash_map_;
MS_EXCEPTION_IF_NULL(device_hash_map);
int index = 0;
auto iter = device_hash_map->id_iter(id);
if (device_hash_map->IsIdExist(iter)) {
@@ -417,7 +453,9 @@ int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, b
index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_,
&(statistics_info_.device_to_host_size_));
if (index == INVALID_INDEX_VALUE) {
WaitGraphRun();
if (!WaitGraphRun()) {
return INVALID_INDEX_VALUE;
}
continue;
}
host_to_device_index[statistics_info_.host_to_device_size_] = index;
@@ -430,21 +468,20 @@ int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, b
return index;
}

void PsCacheManager::ParseHostDataHostToDevice(size_t id) {
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
bool PsCacheManager::ParseHostDataHostToDevice(size_t id) {
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
int *server_to_host_index = embedding_host_cache_->server_to_host_index.get();
int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get();
int *host_to_device_index = embedding_host_cache_->host_to_device_index.get();
MS_EXCEPTION_IF_NULL(host_to_server_index);
MS_EXCEPTION_IF_NULL(host_to_server_ids);
MS_EXCEPTION_IF_NULL(server_to_host_index);
MS_EXCEPTION_IF_NULL(server_to_host_ids);
MS_EXCEPTION_IF_NULL(host_to_device_index);
MS_ERROR_IF_NULL(host_to_server_index);
MS_ERROR_IF_NULL(host_to_server_ids);
MS_ERROR_IF_NULL(server_to_host_index);
MS_ERROR_IF_NULL(server_to_host_ids);
MS_ERROR_IF_NULL(host_to_device_index);

auto host_hash_map = embedding_host_cache_->host_hash_map_;
MS_EXCEPTION_IF_NULL(host_hash_map);
MS_ERROR_IF_NULL(host_hash_map);
auto iter = host_hash_map->id_iter(id);
if (host_hash_map->IsIdExist(iter)) {
auto index = iter->second;
@@ -457,7 +494,7 @@ void PsCacheManager::ParseHostDataHostToDevice(size_t id) {
auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_,
graph_running_step_, &statistics_info_.host_to_server_size_);
if (index == INVALID_INDEX_VALUE) {
WaitGraphRun();
RETURN_IF_FALSE(WaitGraphRun());
continue;
}
host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index;
@@ -466,22 +503,21 @@ void PsCacheManager::ParseHostDataHostToDevice(size_t id) {
break;
}
}
return true;
}

void PsCacheManager::ParseHostDataDeviceToHost(size_t id) {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) {
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
int *device_to_host_index = embedding_host_cache_->device_to_host_index.get();
MS_EXCEPTION_IF_NULL(device_to_host_ids);
MS_EXCEPTION_IF_NULL(host_to_server_index);
MS_EXCEPTION_IF_NULL(host_to_server_ids);
MS_EXCEPTION_IF_NULL(device_to_host_index);
MS_ERROR_IF_NULL(device_to_host_ids);
MS_ERROR_IF_NULL(host_to_server_index);
MS_ERROR_IF_NULL(host_to_server_ids);
MS_ERROR_IF_NULL(device_to_host_index);

auto host_hash_map = embedding_host_cache_->host_hash_map_;
MS_EXCEPTION_IF_NULL(host_hash_map);
MS_ERROR_IF_NULL(host_hash_map);
int swap_device_to_host_id = device_to_host_ids[statistics_info_.device_to_host_size_ - 1];
auto iter = host_hash_map->id_iter(swap_device_to_host_id);
if (host_hash_map->IsIdExist(iter)) {
@@ -495,13 +531,14 @@ void PsCacheManager::ParseHostDataDeviceToHost(size_t id) {
auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_,
graph_running_step_, &statistics_info_.host_to_server_size_);
if (index == INVALID_INDEX_VALUE) {
WaitGraphRun();
RETURN_IF_FALSE(WaitGraphRun());
continue;
}
device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index;
break;
}
}
return true;
}

void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size,
@@ -514,19 +551,21 @@ void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size,
size_t pos = index * outer_dim_size;
auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed.";
MS_LOG(ERROR) << "LookUpTable task memcpy failed.";
terminated_ = true;
}
} else {
auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "LookUpTable task memset failed.";
MS_LOG(ERROR) << "LookUpTable task memset failed.";
terminated_ = true;
}
}
output_addr += outer_dim_size;
}
}

void PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
const int *indices_addr, float *output_addr) {
size_t first_dim_size = host_cache_vocab_size_;
size_t outer_dim_size = embedding_size;
@@ -553,9 +592,10 @@ void PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l
for (size_t j = 0; j < i; j++) {
threads[j].join();
}
return !terminated_;
}

void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices,
bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices,
float *insert_data, float *hash_table_addr) {
size_t first_dim_size = host_cache_vocab_size_;
size_t thread_num = insert_indices_size / 10000 + 1;
@@ -565,8 +605,8 @@ void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in
size_t i;
size_t task_offset = 0;

auto insert_hash_table_task = [](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size,
int *insert_indices, float *insert_data, float *hash_table_addr) {
auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size,
int *insert_indices, float *insert_data, float *hash_table_addr) {
auto type_size = sizeof(float);
size_t lens = outer_dim_size * type_size;
for (size_t i = 0; i < insert_indices_size; ++i) {
@@ -574,7 +614,8 @@ void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in
if (index >= 0 && index < SizeToInt(first_dim_size)) {
auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Insert hash table task memcpy failed.";
MS_LOG(ERROR) << "Insert hash table task memcpy failed.";
terminated_ = true;
}
}
}
@@ -596,94 +637,101 @@ void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in
for (size_t j = 0; j < i; j++) {
threads[j].join();
}
return !terminated_;
}

void PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) {
MS_ERROR_IF_NULL(embedding_device_cache_);
MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
MS_ERROR_IF_NULL(embedding_host_cache_);
auto host_cache_host_to_device_index = embedding_host_cache_->host_to_device_index.get();
auto device_cache_host_to_device_index = embedding_device_cache_->host_to_device_index.get();
auto swap_indices_size = statistics_info_.host_to_device_size_;
if (swap_indices_size == 0) {
return;
return true;
}
auto embedding_size = hash_info.embedding_size;
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
auto hash_table_size = hash_info.device_address.size;
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_cache_host_to_device_index,
swap_out_data.get());
embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_,
swap_out_data.get(),
swap_indices_size * embedding_size * sizeof(float));
embedding_device_cache_->cache_->CopyHostMemToDevice(
embedding_device_cache_->hash_swap_index_addr_, device_cache_host_to_device_index, swap_indices_size * sizeof(int));
embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_,
embedding_device_cache_->hash_swap_index_addr_, hash_table_size,
embedding_size, swap_indices_size);
}

void PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr,
host_cache_host_to_device_index, swap_out_data.get()));
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(
embedding_device_cache_->hash_swap_value_addr_, swap_out_data.get(),
swap_indices_size * embedding_size * sizeof(float)));
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_,
device_cache_host_to_device_index,
swap_indices_size * sizeof(int)));
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn(
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
hash_table_size, embedding_size, swap_indices_size));
return true;
}

bool PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) {
MS_ERROR_IF_NULL(embedding_device_cache_);
MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
MS_ERROR_IF_NULL(embedding_host_cache_);
auto swap_indices_size = statistics_info_.device_to_host_size_;
auto device_cache_device_to_host_index = embedding_device_cache_->device_to_host_index.get();
auto host_cache_device_to_host_index = embedding_host_cache_->device_to_host_index.get();
if (swap_indices_size == 0) {
return;
return true;
}
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
auto hash_table_size = hash_info.device_address.size;
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
auto embedding_size = hash_info.embedding_size;
auto swap_out_data = std::make_unique<float[]>(swap_indices_size * embedding_size);
embedding_device_cache_->cache_->CopyHostMemToDevice(
embedding_device_cache_->hash_swap_index_addr_, device_cache_device_to_host_index, swap_indices_size * sizeof(int));
embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_,
embedding_device_cache_->hash_swap_index_addr_, hash_table_size,
embedding_size, swap_indices_size);
embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data.get(),
embedding_device_cache_->hash_swap_value_addr_,
swap_indices_size * embedding_size * sizeof(float));
embedding_device_cache_->cache_->SynchronizeStream();
InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index,
swap_out_data.get(), host_hash_table_addr);
}

void PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_info) {
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_,
device_cache_device_to_host_index,
swap_indices_size * sizeof(int)));
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut(
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
hash_table_size, embedding_size, swap_indices_size));
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(
swap_out_data.get(), embedding_device_cache_->hash_swap_value_addr_,
swap_indices_size * embedding_size * sizeof(float)));
RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream());
RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index,
swap_out_data.get(), host_hash_table_addr));
return true;
}

bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_info) {
MS_ERROR_IF_NULL(embedding_host_cache_);
auto host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
auto host_to_server_index = embedding_host_cache_->host_to_server_index.get();
auto swap_indices_size = statistics_info_.host_to_server_size_;
if (swap_indices_size == 0) {
return;
return true;
}
::ps::SArray<int> lookup_ids(swap_indices_size, 0);
::ps::SArray<float> swap_out_data;
auto embedding_size = hash_info.embedding_size;
swap_out_data.resize(swap_indices_size * embedding_size);
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index,
swap_out_data.data());
RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index,
swap_out_data.data()));

auto copy_len = swap_indices_size * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids, copy_len);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;
}
worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
return true;
}

void PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_info) {
MS_EXCEPTION_IF_NULL(embedding_host_cache_);
bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_info) {
MS_ERROR_IF_NULL(embedding_host_cache_);
auto swap_indices_size = statistics_info_.server_to_host_size_;
auto server_to_host_ids = embedding_host_cache_->server_to_host_ids.get();
auto server_to_host_index = embedding_host_cache_->server_to_host_index.get();
if (swap_indices_size == 0) {
return;
return true;
}
auto host_hash_table_addr = reinterpret_cast<float *>(hash_info.host_address.get());
auto embedding_size = hash_info.embedding_size;
@@ -693,47 +741,50 @@ void PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_
auto copy_len = swap_indices_size * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;
}
worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index, lookup_result.data(),
host_hash_table_addr);
RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index,
lookup_result.data(), host_hash_table_addr));
return true;
}

void PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data,
bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data,
const HashTableInfo &hash_info) {
MS_EXCEPTION_IF_NULL(swap_out_index);
MS_EXCEPTION_IF_NULL(swap_out_data);
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
MS_ERROR_IF_NULL(swap_out_index);
MS_ERROR_IF_NULL(swap_out_data);
MS_ERROR_IF_NULL(embedding_device_cache_);
MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
auto swap_out_index_size = statistics_info_.device_to_host_size_;
if (swap_out_index_size == 0) {
return;
return true;
}
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
auto hash_table_size = hash_info.device_address.size;
auto embedding_size = hash_info.embedding_size;
swap_out_data->resize(swap_out_index_size * embedding_size);
embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_out_index,
swap_out_index_size * sizeof(int));
embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_,
embedding_device_cache_->hash_swap_index_addr_, hash_table_size,
embedding_size, swap_out_index_size);
embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data->data(),
embedding_device_cache_->hash_swap_value_addr_,
swap_out_index_size * embedding_size * sizeof(float));
embedding_device_cache_->cache_->RecordEvent();
}

void PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info,
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(
embedding_device_cache_->hash_swap_index_addr_, swap_out_index, swap_out_index_size * sizeof(int)));
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut(
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
hash_table_size, embedding_size, swap_out_index_size));
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost(
swap_out_data->data(), embedding_device_cache_->hash_swap_value_addr_,
swap_out_index_size * embedding_size * sizeof(float)));
RETURN_IF_FALSE(embedding_device_cache_->cache_->RecordEvent());
return true;
}

bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info,
size_t key) {
MS_EXCEPTION_IF_NULL(swap_in_ids);
MS_EXCEPTION_IF_NULL(swap_in_index);
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
MS_ERROR_IF_NULL(swap_in_ids);
MS_ERROR_IF_NULL(swap_in_index);
MS_ERROR_IF_NULL(embedding_device_cache_);
MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
auto swap_in_ids_size = statistics_info_.host_to_device_size_;
if (swap_in_ids_size == 0) {
return;
return true;
}
auto hash_table_addr = reinterpret_cast<float *>(hash_info.device_address.addr);
auto hash_table_size = hash_info.device_address.size;
@@ -745,42 +796,44 @@ void PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons
auto copy_len = swap_in_ids_size * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;
}
worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd);
// Hash swap-in in device.
embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_,
lookup_result.data(),
swap_in_ids_size * embedding_size * sizeof(float));
embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_in_index,
swap_in_ids_size * sizeof(int));
embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_,
embedding_device_cache_->hash_swap_index_addr_, hash_table_size,
embedding_size, swap_in_ids_size);
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(
embedding_device_cache_->hash_swap_value_addr_, lookup_result.data(),
swap_in_ids_size * embedding_size * sizeof(float)));
RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_,
swap_in_index, swap_in_ids_size * sizeof(int)));
RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn(
hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_,
hash_table_size, embedding_size, swap_in_ids_size));
return true;
}

void PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key) {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
MS_EXCEPTION_IF_NULL(swap_out_ids);
bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key) {
MS_ERROR_IF_NULL(embedding_device_cache_);
MS_ERROR_IF_NULL(embedding_device_cache_->cache_);
MS_ERROR_IF_NULL(swap_out_ids);
auto swap_out_ids_size = statistics_info_.device_to_host_size_;
if (swap_out_ids_size == 0) {
return;
return true;
}
::ps::SArray<int> lookup_ids(swap_out_ids_size, 0);
auto copy_len = swap_out_ids_size * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "Lookup id memcpy failed.";
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;
}
// Need synchronize event to ensure that the swap-out in device is completed.
embedding_device_cache_->cache_->SynchronizeEvent();
RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeEvent());
worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data);
return true;
}

void PsCacheManager::DumpHashTables(bool dump_device_tables) const {
MS_EXCEPTION_IF_NULL(embedding_device_cache_);
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_);
for (const auto &item : hash_tables_) {
const auto &param_name = item.first;
size_t cache_vocab_size = item.second.cache_vocab_size;


+ 20
- 15
mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h View File

@@ -126,6 +126,8 @@ class PsCacheManager {
bool initialized_ps_cache() const { return initialized_ps_cache_; }
void DoProcessData(uint32_t device_id, void *context);
void IncreaseGraphStep(const std::string &channel_name);
bool terminated() const { return terminated_; }
void Finalize();
void DumpHashTables(bool dump_device_tables = false) const;

private:
@@ -133,7 +135,7 @@ class PsCacheManager {
~PsCacheManager() = default;
PsCacheManager(const PsCacheManager &) = delete;
PsCacheManager &operator=(const PsCacheManager &) = delete;
void IncreaseStep();
bool IncreaseStep();
void set_current_graph_step() { graph_running_step_ = graph_step_; }
std::string channel_name();
void set_channel_name(const std::string channel_name);
@@ -141,23 +143,23 @@ class PsCacheManager {
void AllocMemForHashTable();
void SetLocalIdRank();
void ProcessDataTask(uint32_t device_id, void *context);
void ProcessData();
void ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index);
void WaitGraphRun();
bool ProcessData();
bool ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index);
bool WaitGraphRun();
int ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device);
void ParseHostDataHostToDevice(size_t id);
void ParseHostDataDeviceToHost(size_t id);
void HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info);
void HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key);
void HashSwapHostToDevice(const HashTableInfo &hash_info);
void HashSwapDeviceToHost(const HashTableInfo &hash_info);
void HashSwapHostToServer(size_t key, const HashTableInfo &hash_info);
void HashSwapServerToHost(size_t key, const HashTableInfo &hash_info);
void InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, float *insert_data,
bool ParseHostDataHostToDevice(size_t id);
bool ParseHostDataDeviceToHost(size_t id);
bool HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info);
bool HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key);
bool HashSwapHostToDevice(const HashTableInfo &hash_info);
bool HashSwapDeviceToHost(const HashTableInfo &hash_info);
bool HashSwapHostToServer(size_t key, const HashTableInfo &hash_info);
bool HashSwapServerToHost(size_t key, const HashTableInfo &hash_info);
bool InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, float *insert_data,
float *hash_table_addr);
void LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr,
const int *indices_addr, float *output_addr);
void UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key);
bool UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_data, int *swap_out_ids, size_t key);
void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr,
const int *indices_addr, float *output_addr);
bool CheckFinishInsertInitInfo() const;
@@ -172,6 +174,7 @@ class PsCacheManager {
std::mutex data_mutex_;
std::condition_variable data_prase_;
std::condition_variable insert_init_info_;
std::thread process_data_thread_;

std::map<std::string, HashTableInfo> hash_tables_;
std::shared_ptr<EmbeddingDeviceCache> embedding_device_cache_;
@@ -185,6 +188,8 @@ class PsCacheManager {
std::pair<size_t, size_t> range_bound_;
std::atomic_bool finish_insert_init_info_{false};
std::atomic_bool finish_init_parameter_server_{false};
std::atomic_bool running_{false};
std::atomic_bool terminated_{false};
};

static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance();


+ 42
- 20
mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc View File

@@ -28,11 +28,9 @@ void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t s
if (iter != ps_data_channel_map_.end()) {
MS_LOG(WARNING) << "The ps data channel already exists, channel name:" << channel_name;
auto channel = iter->second;
MS_EXCEPTION_IF_NULL(channel);
channel->set_step_num(step_num);
} else {
auto channel = std::make_shared<PsDataChannel>(channel_name, step_num);
MS_EXCEPTION_IF_NULL(channel);
(void)ps_data_channel_map_.emplace(channel_name, channel);
}
}
@@ -40,71 +38,95 @@ void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t s
std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string &channel_name) const {
auto iter = ps_data_channel_map_.find(channel_name);
if (iter == ps_data_channel_map_.end()) {
MS_LOG(EXCEPTION) << "The ps data channel does not exist, channel name:" << channel_name;
MS_LOG(ERROR) << "The ps data channel does not exist, channel name:" << channel_name;
return nullptr;
}
return iter->second;
}

void PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) {
bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) {
if (cache_enable_ == false) {
return;
return true;
}
if (data == nullptr) {
MS_LOG(WARNING) << "No data prefetch.";
return;
return true;
}
auto channel = ps_data_channel(channel_name);
MS_EXCEPTION_IF_NULL(channel);
MS_ERROR_IF_NULL(channel);
channel->set_data(data, data_size);
std::unique_lock<std::mutex> locker(data_mutex_);
data_ready_ = true;
data_process_.notify_one();
if (!need_wait_) {
return true;
}
for (int i = 0; i < 10; i++) {
if (data_prefetch_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == false; })) {
return;
if (data_prefetch_.wait_for(locker, std::chrono::seconds(30),
[this] { return data_ready_ == false || need_wait_ == false; })) {
return true;
} else {
MS_LOG(INFO) << "Waiting for ps data process, channel name:" << channel_name << "...(" << i << " / 10)";
}
}
MS_LOG(EXCEPTION) << "Ps cache data process timeout, suggest to enlarge the cache size.";
MS_LOG(ERROR) << "Ps cache data process timeout, suggest to enlarge the cache size.";
return false;
}

void PsDataPrefetch::FinalizeData(const std::string &channel_name) {
bool PsDataPrefetch::FinalizeData(const std::string &channel_name) {
if (cache_enable_ == false) {
return;
return true;
}
auto channel = ps_data_channel(channel_name);
MS_EXCEPTION_IF_NULL(channel);
MS_ERROR_IF_NULL(channel);
channel->ResetData();
std::unique_lock<std::mutex> locker(data_mutex_);
data_ready_ = false;
data_prefetch_.notify_one();
if (!need_wait_) {
return true;
}
for (int i = 0; i < 10; i++) {
if (data_process_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == true; })) {
return;
if (data_process_.wait_for(locker, std::chrono::seconds(30),
[this] { return data_ready_ == true || need_wait_ == false; })) {
return true;
} else {
MS_LOG(INFO) << "Waiting for ps data prefetch, channel name:" << channel_name << "...(" << i << " / 10)";
}
}
MS_LOG(EXCEPTION) << "Ps cache data prefetch timeout.";
MS_LOG(ERROR) << "Ps cache data prefetch timeout.";
return false;
}

void *PsDataPrefetch::data(const std::string &channel_name) const {
auto channel = ps_data_channel(channel_name);
MS_EXCEPTION_IF_NULL(channel);
if (channel == nullptr) {
return nullptr;
}
return channel->data();
}

size_t PsDataPrefetch::data_size(const std::string &channel_name) const {
auto channel = ps_data_channel(channel_name);
MS_EXCEPTION_IF_NULL(channel);
if (channel == nullptr) {
return 0;
}
return channel->data_size();
}

void PsDataPrefetch::TryWakeChannel(const std::string &channel_name) {
void PsDataPrefetch::NotifyFinalize() {
need_wait_ = false;
data_prefetch_.notify_one();
data_process_.notify_one();
}

bool PsDataPrefetch::TryWakeChannel(const std::string &channel_name) {
auto channel = ps_data_channel(channel_name);
MS_EXCEPTION_IF_NULL(channel);
if (channel == nullptr) {
return false;
}
channel->TryWakeChannel();
return true;
}
} // namespace ps
} // namespace mindspore

+ 6
- 3
mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h View File

@@ -19,6 +19,7 @@
#include <map>
#include <string>
#include <memory>
#include <atomic>
#include <condition_variable>
#include "ps/ps_cache/ps_data/ps_data_channel.h"

@@ -36,11 +37,12 @@ class EXPORT PsDataPrefetch {
EXPORT bool cache_enable() const { return cache_enable_; }
EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }
EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num);
EXPORT void PrefetchData(const std::string &channel_name, void *data, const size_t data_size);
EXPORT void FinalizeData(const std::string &channel_name);
EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size);
EXPORT bool FinalizeData(const std::string &channel_name);
EXPORT void NotifyFinalize();
EXPORT void *data(const std::string &channel_name) const;
EXPORT size_t data_size(const std::string &channel_name) const;
EXPORT void TryWakeChannel(const std::string &channel_name);
EXPORT bool TryWakeChannel(const std::string &channel_name);

private:
PsDataPrefetch() : cache_enable_(false), data_ready_(false) {}
@@ -54,6 +56,7 @@ class EXPORT PsDataPrefetch {
std::mutex data_mutex_;
std::condition_variable data_prefetch_;
std::condition_variable data_process_;
std::atomic_bool need_wait_{true};
};
} // namespace ps
} // namespace mindspore


+ 7
- 2
mindspore/ccsrc/ps/ps_context.cc View File

@@ -17,10 +17,10 @@
#include "ps/ps_context.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "backend/kernel_compiler/kernel.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#endif

namespace mindspore {
@@ -62,7 +62,12 @@ void PSContext::Reset() {
is_worker_ = false;
is_pserver_ = false;
is_sched_ = false;
set_cache_enable(false);
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
ps_cache_instance.Finalize();
set_cache_enable(false);
}
#endif
}

std::string PSContext::ms_role() const {


+ 10
- 0
mindspore/ccsrc/runtime/device/gpu/gpu_common.h View File

@@ -62,6 +62,16 @@ namespace gpu {
} \
}

// Evaluates `expression` once as a cudaError_t. If the result is not cudaSuccess, logs `message`, the
// numeric status and cudaGetErrorString(status) through MS_LOG(ERROR) and then executes `return false`
// in the function where the macro is expanded; on success it falls through.
// Only usable inside functions for which `return false` is a valid return statement.
// NOTE(review): this expands to a bare brace block, not do { ... } while (0), so `MACRO(...);` used as the
// body of an unbraced if that has an else branch will not compile as expected - wrap such uses in braces.
#define CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(expression, message) \
{ \
cudaError_t status = (expression); \
if (status != cudaSuccess) { \
MS_LOG(ERROR) << "CUDA Error: " << message << " | Error Number: " << status << " " \
<< cudaGetErrorString(status); \
return false; \
} \
}

#define CHECK_CUDA_RET_WITH_EXCEPT(node, expression, message) \
{ \
cudaError_t status = (expression); \


+ 8
- 0
mindspore/core/utils/log_adapter.h View File

@@ -199,6 +199,14 @@ class LogWriter {
} \
} while (0)

// Null-pointer guard that reports failure by return value: if `ptr` compares equal to nullptr, logs the
// stringified pointer expression through MS_LOG(ERROR) and executes `return false` in the function where
// the macro is expanded; otherwise it does nothing.
// Only usable inside functions for which `return false` is a valid return statement. The do { ... } while (0)
// wrapper makes the expansion behave as a single statement that takes a trailing semicolon.
#define MS_ERROR_IF_NULL(ptr) \
do { \
if ((ptr) == nullptr) { \
MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \
return false; \
} \
} while (0)

#ifdef DEBUG
#include <cassert>
#define MS_ASSERT(f) assert(f)


Loading…
Cancel
Save