From a844d52b42a6d312b2a137b057a335af5e5f360d Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Tue, 15 Dec 2020 10:10:35 +0800 Subject: [PATCH] add ps cache check --- .../ccsrc/backend/session/session_basic.cc | 5 ++ mindspore/ccsrc/pipeline/jit/pipeline.cc | 9 +- .../ccsrc/ps/ps_cache/ps_cache_manager.cc | 22 +++-- .../ccsrc/ps/ps_cache/ps_cache_manager.h | 1 + .../ccsrc/runtime/device/kernel_runtime.cc | 89 ++++++++++++++++++- .../ccsrc/runtime/device/kernel_runtime.h | 5 ++ mindspore/nn/layer/embedding.py | 62 ++++++++----- 7 files changed, 159 insertions(+), 34 deletions(-) diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 2d284efe78..6f3595a0d3 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1813,6 +1813,11 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) co void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { MS_EXCEPTION_IF_NULL(kernel_graph); + // PS embeddingLookup cache check. + if (ps::PsDataPrefetch::GetInstance().cache_enable()) { + MS_LOG(EXCEPTION) << "The other parameter cann't set ps mode when the embeddingLookup cache is enabled in " + "parameter server training mode."; + } std::vector node_list = TopoSort(kernel_graph->get_return()); for (auto &node : node_list) { if (node != nullptr && node->isa()) { diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 08c465417b..875c57eb52 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -976,15 +976,12 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode) { backend->Link(runner.graph_id); } - // PS mode does not support loop sink. 
+ ConfigManager::GetInstance().set_iter_num(size); + // PS cache does not support loop sink. #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) - if (ps::Util::IsRoleOfWorker()) { + if (ps::Util::IsRoleOfWorker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size)); ConfigManager::GetInstance().set_iter_num(1); - } else { -#endif - ConfigManager::GetInstance().set_iter_num(size); -#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) } #endif diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc index 0d725ddaf8..d30db0cf25 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc @@ -129,11 +129,19 @@ void PsCacheManager::CloneHashTable(const std::string &dest_param_name, const st const Address &PsCacheManager::QueryHashTableAddr(const std::string &param_name) const { auto iter = hash_tables_.find(param_name); if (iter == hash_tables_.end()) { - MS_LOG(EXCEPTION) << "Can not find device_address of " << param_name; + MS_LOG(EXCEPTION) << "Can not find device address of " << param_name; } return iter->second.device_address; } +const size_t &PsCacheManager::QueryHashTableSize(const std::string &param_name) const { + auto iter = hash_tables_.find(param_name); + if (iter == hash_tables_.end()) { + MS_LOG(EXCEPTION) << "Can not find vocab cache size of " << param_name; + } + return iter->second.cache_vocab_size; +} + void PsCacheManager::Initialize() { MS_LOG(INFO) << "PS cache initialize."; if (!worker.running()) { @@ -244,19 +252,19 @@ void PsCacheManager::set_channel_name(const std::string channel_name) { void PsCacheManager::IncreaseStep() { if (data_step_ >= UINT64_MAX) { - MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") << will exceed the maximum value of uint64_t."; + MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t."; } 
data_step_++; set_current_graph_step(); if (graph_running_step_ > data_step_) { - MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") << exceed the data step (" - << data_step_ << ")."; + MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") exceed the data step (" << data_step_ + << ")."; } } void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { if (graph_step_ >= UINT64_MAX) { - MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t."; + MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t."; } if (graph_step_ == 0) { MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_; @@ -271,8 +279,10 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { } void PsCacheManager::DoProcessData(uint32_t device_id, void *context) { + // PS embeddingLookup cache check. if (!initialized_ps_cache_) { - MS_LOG(EXCEPTION) << "PS cache does not init."; + MS_LOG(EXCEPTION) << "Only the sink_mode of dataset supports embeddingLookup cache in parameter server training " + "mode, current dataset mode is not sink_mode."; } auto process_data_thread = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context); process_data_thread.detach(); diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h index 07c7949d7b..db8b18bbb0 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h @@ -120,6 +120,7 @@ class PsCacheManager { size_t cache_vocab_size, size_t embedding_size); void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name); const Address &QueryHashTableAddr(const std::string &param_name) const; + const size_t &QueryHashTableSize(const std::string &param_name) const; bool IsHashTable(const std::string &param_name) { 
return hash_tables_.count(param_name) != 0; } void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; } bool initialized_ps_cache() const { return initialized_ps_cache_; } diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index 98d94a86a0..89583a3aa3 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -325,7 +325,9 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { } need_alloc_nodes.push_back(item); } - +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + bool ps_cache_check = false; +#endif for (auto &item : need_alloc_nodes) { auto output_size = AnfAlgo::GetOutputTensorNum(item); for (size_t index = 0; index < output_size; index++) { @@ -339,6 +341,13 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) const std::string &param_name = item->fullname_with_scope(); if (ps::ps_cache_instance.IsHashTable(param_name)) { + MS_LOG(INFO) << "Parameter(" << param_name << ")" + << " enables the embeddingLookup cache in parameter server training mode."; + // PS embeddingLookup cache check. 
+ if (!ps_cache_check) { + CheckIfSupportPSEmbeddingCache(graph); + ps_cache_check = true; + } const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name); MS_EXCEPTION_IF_NULL(address.addr); device_address = @@ -1024,5 +1033,83 @@ DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const st MS_EXCEPTION_IF_NULL(base_ptr); return device_address; } + +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) +void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *first_cache_input_index, + size_t *first_cache_size) { + MS_EXCEPTION_IF_NULL(graph); + for (const auto &kernel : graph->execution_order()) { + MS_EXCEPTION_IF_NULL(kernel); + if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") { + continue; + } + auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0); + auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1); + MS_EXCEPTION_IF_NULL(input_param.first); + MS_EXCEPTION_IF_NULL(input_index.first); + auto param_name = input_param.first->fullname_with_scope(); + if (!ps::ps_cache_instance.IsHashTable(param_name)) { + continue; + } + auto size = ps::ps_cache_instance.QueryHashTableSize(param_name); + while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) { + input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second); + MS_EXCEPTION_IF_NULL(input_index.first); + } + if ((!input_index.first->isa()) && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) { + MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from " + << input_index.first->fullname_with_scope(); + MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in " + "parameter server training mode."; + } + *first_cache_input_index = input_index.first; + *first_cache_size = size; + MS_LOG(INFO) << "The input index of the first embeddingLookup cache is from " + << 
input_index.first->fullname_with_scope() << ", the cache size is " << size; + return; + } +} + +void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + AnfNodePtr first_cache_input_index = nullptr; + size_t first_cache_size = 0; + GetFirstPSEmbeddingCache(graph, &first_cache_input_index, &first_cache_size); + MS_EXCEPTION_IF_NULL(first_cache_input_index); + for (const auto &kernel : graph->execution_order()) { + MS_EXCEPTION_IF_NULL(kernel); + if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") { + continue; + } + auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0); + auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1); + MS_EXCEPTION_IF_NULL(input_param.first); + MS_EXCEPTION_IF_NULL(input_index.first); + auto param_name = input_param.first->fullname_with_scope(); + while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) { + input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second); + MS_EXCEPTION_IF_NULL(input_index.first); + } + if (input_index.first == first_cache_input_index) { + if (!ps::ps_cache_instance.IsHashTable(param_name)) { + MS_LOG(ERROR) << "The embeddingLookup(" << kernel->fullname_with_scope() << ") doesn't enable cache."; + MS_LOG(EXCEPTION) << "All the embeddingLookups whose input indices are from dataset must enable cache at the " + "same time when one of them enables cache in parameter server training mode."; + } + auto size = ps::ps_cache_instance.QueryHashTableSize(param_name); + if (size != first_cache_size) { + MS_LOG(ERROR) << "The cache size(" << size << ") of embeddingLookup(" << kernel->fullname_with_scope() + << ") is not the same as other embeddingLookup cache size."; + MS_LOG(EXCEPTION) << "The cache sizes of embeddingLookups are not the same in parameter server training mode."; + } + } else if (ps::ps_cache_instance.IsHashTable(param_name)) { + MS_LOG(ERROR) << "The input index of the 
embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from " + << input_index.first->fullname_with_scope(); + MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in " + "parameter server training mode."; + } + } +} +#endif } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index 17ccaae282..e9a563ef1d 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -131,6 +131,11 @@ class KernelRuntime { void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph); void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx); DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index); +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *first_cache_input_index, + size_t *first_cache_size); + void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph); +#endif protected: uint32_t device_id_{0}; diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index d50cc912ab..ffeaf56bea 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -14,7 +14,6 @@ # ============================================================================ """embedding""" import mindspore.common.dtype as mstype -import mindspore.context as context from mindspore import log as logger from mindspore.common.tensor import Tensor from mindspore.ops import operations as P @@ -23,8 +22,8 @@ from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer from mindspore.communication.management import get_group_size from mindspore.context import ParallelMode -from mindspore.parallel._utils import _get_parallel_mode 
-from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker +from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch +from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker, _get_ps_context from mindspore._checkparam import Rel from mindspore._checkparam import Validator as validator from mindspore.ops.primitive import constexpr @@ -195,11 +194,6 @@ class EmbeddingLookup(Cell): + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.') if not sparse and target == 'CPU': raise ValueError('When target is CPU, embedding_lookup must be sparse.') - enable_ps = context.get_ps_context("enable_ps") - if not enable_ps and vocab_cache_size > 0: - logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server trainning mode, " - "current mode is not parameter server trainning mode, so it will be ignored.") - vocab_cache_size = 0 if sparse: self.gatherv2 = P.SparseGatherV2() else: @@ -207,22 +201,14 @@ class EmbeddingLookup(Cell): self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size') self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size') + self._process_vocab_cache(slice_mode) self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size') - parallel_mode = _get_parallel_mode() - is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) - self.cache_enable = self.vocab_cache_size > 0 - if self.cache_enable: - if is_auto_parallel: - self.vocab_cache_size = self.vocab_cache_size * get_group_size() - self.vocab_size = self.vocab_cache_size - self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]), name='embedding_table') if self.cache_enable: - self.embedding_table.cache_enable = True - 
_set_cache_enable(True) - if _is_role_worker(): - _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size) + self._set_voacb_cache_enable(vocab_cache_size, embedding_size, vocab_size) + parallel_mode = _get_parallel_mode() + is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) self.forward_unique = False self.gather_revert = P.GatherV2() self.unique = P.Unique().shard(((1,),)) @@ -241,7 +227,8 @@ class EmbeddingLookup(Cell): self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size()))) self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size()))) elif slice_mode == "table_row_slice" and is_auto_parallel: - if target == 'DEVICE' and not self.cache_enable: + full_batch = _get_full_batch() + if target == 'DEVICE' and not full_batch: indices_shape_size = 1 self.gather_revert.shard(((1, 1), (get_group_size(),))) self.forward_unique = True @@ -272,6 +259,39 @@ class EmbeddingLookup(Cell): self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name) self.max_norm = Tensor(self.max_norm, dtype=mstype.float32) + def _process_vocab_cache(self, slice_mode): + """PS embeddingLookup cache check and process.""" + self.cache_enable = False + if self.vocab_cache_size > 0: + if self.target == 'CPU': + logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, " + "current target is CPU, so it will be ignored.") + return + enable_ps = _get_ps_context("enable_ps") + if not enable_ps: + logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server trainning " + "mode, current mode is not parameter server trainning mode, so it will be ignored.") + return + parallel_mode = _get_parallel_mode() + is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) + if is_auto_parallel: + device_num = get_group_size() + full_batch = _get_full_batch() + if device_num > 1 
and not (full_batch and slice_mode == TABLE_ROW_SLICE): + raise ValueError("The embeddingLookup cache of parameter server parallel only be used " + "in 'full_batch' and 'table_row_slice' parallel strategy.") + self.vocab_cache_size = self.vocab_cache_size * device_num + self.cache_enable = True + self.vocab_size = self.vocab_cache_size + + def _set_voacb_cache_enable(self, vocab_cache_size, embedding_size, vocab_size): + """PS embeddingLookup cache enable set.""" + self.embedding_table.cache_enable = True + self.embedding_table.is_param_ps = True + _set_cache_enable(True) + if _is_role_worker(): + _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size) + def construct(self, indices): if self.target == "CPU": out = self.embeddinglookup(self.embedding_table, indices, 0)