| @@ -1813,6 +1813,11 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) co | |||
| void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| // PS embeddingLookup cache check. | |||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| MS_LOG(EXCEPTION) << "The other parameter can't set ps mode when the embeddingLookup cache is enabled in " | |||
| "parameter server training mode."; | |||
| } | |||
| std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| if (node != nullptr && node->isa<CNode>()) { | |||
| @@ -976,15 +976,12 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | |||
| backend->Link(runner.graph_id); | |||
| } | |||
| // PS mode does not support loop sink. | |||
| ConfigManager::GetInstance().set_iter_num(size); | |||
| // PS cache does not support loop sink. | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| if (ps::Util::IsRoleOfWorker()) { | |||
| if (ps::Util::IsRoleOfWorker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||
| ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size)); | |||
| ConfigManager::GetInstance().set_iter_num(1); | |||
| } else { | |||
| #endif | |||
| ConfigManager::GetInstance().set_iter_num(size); | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| } | |||
| #endif | |||
| @@ -129,11 +129,19 @@ void PsCacheManager::CloneHashTable(const std::string &dest_param_name, const st | |||
| const Address &PsCacheManager::QueryHashTableAddr(const std::string &param_name) const { | |||
| auto iter = hash_tables_.find(param_name); | |||
| if (iter == hash_tables_.end()) { | |||
| MS_LOG(EXCEPTION) << "Can not find device_address of " << param_name; | |||
| MS_LOG(EXCEPTION) << "Can not find device address of " << param_name; | |||
| } | |||
| return iter->second.device_address; | |||
| } | |||
| const size_t &PsCacheManager::QueryHashTableSize(const std::string &param_name) const { | |||
| auto iter = hash_tables_.find(param_name); | |||
| if (iter == hash_tables_.end()) { | |||
| MS_LOG(EXCEPTION) << "Can not find vocab cache size of " << param_name; | |||
| } | |||
| return iter->second.cache_vocab_size; | |||
| } | |||
| void PsCacheManager::Initialize() { | |||
| MS_LOG(INFO) << "PS cache initialize."; | |||
| if (!worker.running()) { | |||
| @@ -244,19 +252,19 @@ void PsCacheManager::set_channel_name(const std::string channel_name) { | |||
| void PsCacheManager::IncreaseStep() { | |||
| if (data_step_ >= UINT64_MAX) { | |||
| MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") << will exceed the maximum value of uint64_t."; | |||
| MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t."; | |||
| } | |||
| data_step_++; | |||
| set_current_graph_step(); | |||
| if (graph_running_step_ > data_step_) { | |||
| MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") << exceed the data step (" | |||
| << data_step_ << ")."; | |||
| MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") exceed the data step (" << data_step_ | |||
| << ")."; | |||
| } | |||
| } | |||
| void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { | |||
| if (graph_step_ >= UINT64_MAX) { | |||
| MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t."; | |||
| MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t."; | |||
| } | |||
| if (graph_step_ == 0) { | |||
| MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_; | |||
| @@ -271,8 +279,10 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { | |||
| } | |||
| void PsCacheManager::DoProcessData(uint32_t device_id, void *context) { | |||
| // PS embeddingLookup cache check. | |||
| if (!initialized_ps_cache_) { | |||
| MS_LOG(EXCEPTION) << "PS cache does not init."; | |||
| MS_LOG(EXCEPTION) << "Only the sink_mode of dataset supports embeddingLookup cache in parameter server training " | |||
| "mode, current dataset mode is not sink_mode."; | |||
| } | |||
| auto process_data_thread = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context); | |||
| process_data_thread.detach(); | |||
| @@ -120,6 +120,7 @@ class PsCacheManager { | |||
| size_t cache_vocab_size, size_t embedding_size); | |||
| void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name); | |||
| const Address &QueryHashTableAddr(const std::string &param_name) const; | |||
| const size_t &QueryHashTableSize(const std::string &param_name) const; | |||
| bool IsHashTable(const std::string &param_name) { return hash_tables_.count(param_name) != 0; } | |||
| void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; } | |||
| bool initialized_ps_cache() const { return initialized_ps_cache_; } | |||
| @@ -325,7 +325,9 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||
| } | |||
| need_alloc_nodes.push_back(item); | |||
| } | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| bool ps_cache_check = false; | |||
| #endif | |||
| for (auto &item : need_alloc_nodes) { | |||
| auto output_size = AnfAlgo::GetOutputTensorNum(item); | |||
| for (size_t index = 0; index < output_size; index++) { | |||
| @@ -339,6 +341,13 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| const std::string &param_name = item->fullname_with_scope(); | |||
| if (ps::ps_cache_instance.IsHashTable(param_name)) { | |||
| MS_LOG(INFO) << "Parameter(" << param_name << ")" | |||
| << " enables the embeddingLookup cache in parameter server training mode."; | |||
| // PS embeddingLookup cache check. | |||
| if (!ps_cache_check) { | |||
| CheckIfSupportPSEmbeddingCache(graph); | |||
| ps_cache_check = true; | |||
| } | |||
| const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name); | |||
| MS_EXCEPTION_IF_NULL(address.addr); | |||
| device_address = | |||
| @@ -1024,5 +1033,83 @@ DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const st | |||
| MS_EXCEPTION_IF_NULL(base_ptr); | |||
| return device_address; | |||
| } | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *first_cache_input_index, | |||
| size_t *first_cache_size) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| for (const auto &kernel : graph->execution_order()) { | |||
| MS_EXCEPTION_IF_NULL(kernel); | |||
| if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") { | |||
| continue; | |||
| } | |||
| auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0); | |||
| auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1); | |||
| MS_EXCEPTION_IF_NULL(input_param.first); | |||
| MS_EXCEPTION_IF_NULL(input_index.first); | |||
| auto param_name = input_param.first->fullname_with_scope(); | |||
| if (!ps::ps_cache_instance.IsHashTable(param_name)) { | |||
| continue; | |||
| } | |||
| auto size = ps::ps_cache_instance.QueryHashTableSize(param_name); | |||
| while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) { | |||
| input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second); | |||
| MS_EXCEPTION_IF_NULL(input_index.first); | |||
| } | |||
| if ((!input_index.first->isa<Parameter>()) && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) { | |||
| MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from " | |||
| << input_index.first->fullname_with_scope(); | |||
| MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in " | |||
| "parameter server training mode."; | |||
| } | |||
| *first_cache_input_index = input_index.first; | |||
| *first_cache_size = size; | |||
| MS_LOG(INFO) << "The input index of the first embeddingLookup cache is from " | |||
| << input_index.first->fullname_with_scope() << ", the cache size is " << size; | |||
| return; | |||
| } | |||
| } | |||
| void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph) { | |||
| MS_EXCEPTION_IF_NULL(graph); | |||
| AnfNodePtr first_cache_input_index = nullptr; | |||
| size_t first_cache_size = 0; | |||
| GetFirstPSEmbeddingCache(graph, &first_cache_input_index, &first_cache_size); | |||
| MS_EXCEPTION_IF_NULL(first_cache_input_index); | |||
| for (const auto &kernel : graph->execution_order()) { | |||
| MS_EXCEPTION_IF_NULL(kernel); | |||
| if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") { | |||
| continue; | |||
| } | |||
| auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0); | |||
| auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1); | |||
| MS_EXCEPTION_IF_NULL(input_param.first); | |||
| MS_EXCEPTION_IF_NULL(input_index.first); | |||
| auto param_name = input_param.first->fullname_with_scope(); | |||
| while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) { | |||
| input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second); | |||
| MS_EXCEPTION_IF_NULL(input_index.first); | |||
| } | |||
| if (input_index.first == first_cache_input_index) { | |||
| if (!ps::ps_cache_instance.IsHashTable(param_name)) { | |||
| MS_LOG(ERROR) << "The embeddingLookup(" << kernel->fullname_with_scope() << ") doesn't enable cache."; | |||
| MS_LOG(EXCEPTION) << "All the embeddingLookups whose input indices are from dataset must enable cache at the " | |||
| "same time when one of them enables cache in parameter server training mode."; | |||
| } | |||
| auto size = ps::ps_cache_instance.QueryHashTableSize(param_name); | |||
| if (size != first_cache_size) { | |||
| MS_LOG(ERROR) << "The cache size(" << size << ") of embeddingLookup(" << kernel->fullname_with_scope() | |||
| << ") is not the same as other embeddingLookup cache size."; | |||
| MS_LOG(EXCEPTION) << "The cache sizes of embeddingLookups are not the same in parameter server training mode."; | |||
| } | |||
| } else if (ps::ps_cache_instance.IsHashTable(param_name)) { | |||
| MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from " | |||
| << input_index.first->fullname_with_scope(); | |||
| MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in " | |||
| "parameter server training mode."; | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| } // namespace device | |||
| } // namespace mindspore | |||
| @@ -131,6 +131,11 @@ class KernelRuntime { | |||
| void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph); | |||
| void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx); | |||
| DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index); | |||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||
| void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *first_cache_input_index, | |||
| size_t *first_cache_size); | |||
| void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph); | |||
| #endif | |||
| protected: | |||
| uint32_t device_id_{0}; | |||
| @@ -14,7 +14,6 @@ | |||
| # ============================================================================ | |||
| """embedding""" | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.context as context | |||
| from mindspore import log as logger | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import operations as P | |||
| @@ -23,8 +22,8 @@ from mindspore.common.parameter import Parameter | |||
| from mindspore.common.initializer import initializer | |||
| from mindspore.communication.management import get_group_size | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.parallel._utils import _get_parallel_mode | |||
| from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker | |||
| from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch | |||
| from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker, _get_ps_context | |||
| from mindspore._checkparam import Rel | |||
| from mindspore._checkparam import Validator as validator | |||
| from mindspore.ops.primitive import constexpr | |||
| @@ -195,11 +194,6 @@ class EmbeddingLookup(Cell): | |||
| + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.') | |||
| if not sparse and target == 'CPU': | |||
| raise ValueError('When target is CPU, embedding_lookup must be sparse.') | |||
| enable_ps = context.get_ps_context("enable_ps") | |||
| if not enable_ps and vocab_cache_size > 0: | |||
| logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server trainning mode, " | |||
| "current mode is not parameter server trainning mode, so it will be ignored.") | |||
| vocab_cache_size = 0 | |||
| if sparse: | |||
| self.gatherv2 = P.SparseGatherV2() | |||
| else: | |||
| @@ -207,22 +201,14 @@ class EmbeddingLookup(Cell): | |||
| self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') | |||
| self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size') | |||
| self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size') | |||
| self._process_vocab_cache(slice_mode) | |||
| self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size') | |||
| parallel_mode = _get_parallel_mode() | |||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||
| self.cache_enable = self.vocab_cache_size > 0 | |||
| if self.cache_enable: | |||
| if is_auto_parallel: | |||
| self.vocab_cache_size = self.vocab_cache_size * get_group_size() | |||
| self.vocab_size = self.vocab_cache_size | |||
| self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]), | |||
| name='embedding_table') | |||
| if self.cache_enable: | |||
| self.embedding_table.cache_enable = True | |||
| _set_cache_enable(True) | |||
| if _is_role_worker(): | |||
| _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size) | |||
| self._set_voacb_cache_enable(vocab_cache_size, embedding_size, vocab_size) | |||
| parallel_mode = _get_parallel_mode() | |||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||
| self.forward_unique = False | |||
| self.gather_revert = P.GatherV2() | |||
| self.unique = P.Unique().shard(((1,),)) | |||
| @@ -241,7 +227,8 @@ class EmbeddingLookup(Cell): | |||
| self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size()))) | |||
| self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size()))) | |||
| elif slice_mode == "table_row_slice" and is_auto_parallel: | |||
| if target == 'DEVICE' and not self.cache_enable: | |||
| full_batch = _get_full_batch() | |||
| if target == 'DEVICE' and not full_batch: | |||
| indices_shape_size = 1 | |||
| self.gather_revert.shard(((1, 1), (get_group_size(),))) | |||
| self.forward_unique = True | |||
| @@ -272,6 +259,39 @@ class EmbeddingLookup(Cell): | |||
| self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name) | |||
| self.max_norm = Tensor(self.max_norm, dtype=mstype.float32) | |||
| def _process_vocab_cache(self, slice_mode): | |||
| """PS embeddingLookup cache check and process.""" | |||
| self.cache_enable = False | |||
| if self.vocab_cache_size > 0: | |||
| if self.target == 'CPU': | |||
| logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, " | |||
| "current target is CPU, so it will be ignored.") | |||
| return | |||
| enable_ps = _get_ps_context("enable_ps") | |||
| if not enable_ps: | |||
| logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server training " | |||
| "mode, current mode is not parameter server training mode, so it will be ignored.") | |||
| return | |||
| parallel_mode = _get_parallel_mode() | |||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||
| if is_auto_parallel: | |||
| device_num = get_group_size() | |||
| full_batch = _get_full_batch() | |||
| if device_num > 1 and not (full_batch and slice_mode == TABLE_ROW_SLICE): | |||
| raise ValueError("The embeddingLookup cache of parameter server parallel can only be used " | |||
| "in 'full_batch' and 'table_row_slice' parallel strategy.") | |||
| self.vocab_cache_size = self.vocab_cache_size * device_num | |||
| self.cache_enable = True | |||
| self.vocab_size = self.vocab_cache_size | |||
| def _set_voacb_cache_enable(self, vocab_cache_size, embedding_size, vocab_size): | |||
| """PS embeddingLookup cache enable set.""" | |||
| self.embedding_table.cache_enable = True | |||
| self.embedding_table.is_param_ps = True | |||
| _set_cache_enable(True) | |||
| if _is_role_worker(): | |||
| _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size) | |||
| def construct(self, indices): | |||
| if self.target == "CPU": | |||
| out = self.embeddinglookup(self.embedding_table, indices, 0) | |||