Browse Source

!10456 ps cache sync before release res

From: @limingqi107
Reviewed-by: @cristoval,@zhoufeng54
Signed-off-by: @cristoval
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
ca30af83f7
3 changed files with 22 additions and 6 deletions
  1. +14
    -6
      mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
  2. +2
    -0
      mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
  3. +6
    -0
      mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc

+ 14
- 6
mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc View File

@@ -331,12 +331,7 @@ void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) {

void PsCacheManager::Finalize() {
if (running_) {
if (!SyncHostEmbeddingTable()) {
MS_LOG(ERROR) << "SyncHostEmbeddingTable failed.";
}
if (!SyncDeviceEmbeddingTable()) {
MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed.";
}
SyncEmbeddingTable();
}
running_ = false;
PsDataPrefetch::GetInstance().NotifyFinalize();
@@ -846,6 +841,19 @@ bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_da
return true;
}

void PsCacheManager::SyncEmbeddingTable() {
if (finish_embedding_table_sync_) {
return;
}
if (!SyncHostEmbeddingTable()) {
MS_LOG(ERROR) << "SyncHostEmbeddingTable failed.";
}
if (!SyncDeviceEmbeddingTable()) {
MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed.";
}
finish_embedding_table_sync_ = true;
}

bool PsCacheManager::SyncHostEmbeddingTable() {
MS_ERROR_IF_NULL(embedding_host_cache_);
const auto &hash_id_to_index = embedding_host_cache_->host_hash_map_->hash_id_to_index();


+ 2
- 0
mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h View File

@@ -127,6 +127,7 @@ class PsCacheManager {
bool initialized_ps_cache() const { return initialized_ps_cache_; }
void DoProcessData(uint32_t device_id, void *context);
void IncreaseGraphStep(const std::string &channel_name);
void SyncEmbeddingTable();
void Finalize();
void DumpHashTables(bool dump_device_tables = false) const;

@@ -193,6 +194,7 @@ class PsCacheManager {
std::atomic_bool finish_insert_init_info_{false};
std::atomic_bool finish_init_parameter_server_{false};
std::atomic_bool running_{false};
bool finish_embedding_table_sync_{false};
};

static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance();


+ 6
- 0
mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc View File

@@ -16,10 +16,16 @@

#include "runtime/device/kernel_runtime_manager.h"
#include "utils/log_adapter.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/ps_cache/ps_cache_manager.h"
#endif

namespace mindspore {
namespace device {
void KernelRuntimeManager::ClearRuntimeResource() {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
ps::ps_cache_instance.SyncEmbeddingTable();
#endif
std::lock_guard<std::mutex> guard(lock_);
for (auto &iter : runtime_map_) {
MS_LOG(INFO) << "Release device " << iter.first;


Loading…
Cancel
Save