| @@ -1813,6 +1813,11 @@ void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) co | |||||
| void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { | void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| // PS embeddingLookup cache check. | |||||
| if (ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||||
| MS_LOG(EXCEPTION) << "The other parameter cann't set ps mode when the embeddingLookup cache is enabled in " | |||||
| "parameter server training mode."; | |||||
| } | |||||
| std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return()); | std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return()); | ||||
| for (auto &node : node_list) { | for (auto &node : node_list) { | ||||
| if (node != nullptr && node->isa<CNode>()) { | if (node != nullptr && node->isa<CNode>()) { | ||||
| @@ -976,15 +976,12 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc | |||||
| if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) { | ||||
| backend->Link(runner.graph_id); | backend->Link(runner.graph_id); | ||||
| } | } | ||||
| // PS mode does not support loop sink. | |||||
| ConfigManager::GetInstance().set_iter_num(size); | |||||
| // PS cache does not support loop sink. | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | ||||
| if (ps::Util::IsRoleOfWorker()) { | |||||
| if (ps::Util::IsRoleOfWorker() && ps::PsDataPrefetch::GetInstance().cache_enable()) { | |||||
| ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size)); | ps::PsDataPrefetch::GetInstance().CreateDataChannel(queue_name, LongToSize(size)); | ||||
| ConfigManager::GetInstance().set_iter_num(1); | ConfigManager::GetInstance().set_iter_num(1); | ||||
| } else { | |||||
| #endif | |||||
| ConfigManager::GetInstance().set_iter_num(size); | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -129,11 +129,19 @@ void PsCacheManager::CloneHashTable(const std::string &dest_param_name, const st | |||||
| const Address &PsCacheManager::QueryHashTableAddr(const std::string ¶m_name) const { | const Address &PsCacheManager::QueryHashTableAddr(const std::string ¶m_name) const { | ||||
| auto iter = hash_tables_.find(param_name); | auto iter = hash_tables_.find(param_name); | ||||
| if (iter == hash_tables_.end()) { | if (iter == hash_tables_.end()) { | ||||
| MS_LOG(EXCEPTION) << "Can not find device_address of " << param_name; | |||||
| MS_LOG(EXCEPTION) << "Can not find device address of " << param_name; | |||||
| } | } | ||||
| return iter->second.device_address; | return iter->second.device_address; | ||||
| } | } | ||||
| const size_t &PsCacheManager::QueryHashTableSize(const std::string ¶m_name) const { | |||||
| auto iter = hash_tables_.find(param_name); | |||||
| if (iter == hash_tables_.end()) { | |||||
| MS_LOG(EXCEPTION) << "Can not find vocab cache size of " << param_name; | |||||
| } | |||||
| return iter->second.cache_vocab_size; | |||||
| } | |||||
| void PsCacheManager::Initialize() { | void PsCacheManager::Initialize() { | ||||
| MS_LOG(INFO) << "PS cache initialize."; | MS_LOG(INFO) << "PS cache initialize."; | ||||
| if (!worker.running()) { | if (!worker.running()) { | ||||
| @@ -244,19 +252,19 @@ void PsCacheManager::set_channel_name(const std::string channel_name) { | |||||
| void PsCacheManager::IncreaseStep() { | void PsCacheManager::IncreaseStep() { | ||||
| if (data_step_ >= UINT64_MAX) { | if (data_step_ >= UINT64_MAX) { | ||||
| MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") << will exceed the maximum value of uint64_t."; | |||||
| MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t."; | |||||
| } | } | ||||
| data_step_++; | data_step_++; | ||||
| set_current_graph_step(); | set_current_graph_step(); | ||||
| if (graph_running_step_ > data_step_) { | if (graph_running_step_ > data_step_) { | ||||
| MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") << exceed the data step (" | |||||
| << data_step_ << ")."; | |||||
| MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") exceed the data step (" << data_step_ | |||||
| << ")."; | |||||
| } | } | ||||
| } | } | ||||
| void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { | void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { | ||||
| if (graph_step_ >= UINT64_MAX) { | if (graph_step_ >= UINT64_MAX) { | ||||
| MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t."; | |||||
| MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t."; | |||||
| } | } | ||||
| if (graph_step_ == 0) { | if (graph_step_ == 0) { | ||||
| MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_; | MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_; | ||||
| @@ -271,8 +279,10 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { | |||||
| } | } | ||||
| void PsCacheManager::DoProcessData(uint32_t device_id, void *context) { | void PsCacheManager::DoProcessData(uint32_t device_id, void *context) { | ||||
| // PS embeddingLookup cache check. | |||||
| if (!initialized_ps_cache_) { | if (!initialized_ps_cache_) { | ||||
| MS_LOG(EXCEPTION) << "PS cache does not init."; | |||||
| MS_LOG(EXCEPTION) << "Only the sink_mode of dataset supports embeddingLookup cache in parameter server training " | |||||
| "mode, current dataset mode is not sink_mode."; | |||||
| } | } | ||||
| auto process_data_thread = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context); | auto process_data_thread = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context); | ||||
| process_data_thread.detach(); | process_data_thread.detach(); | ||||
| @@ -120,6 +120,7 @@ class PsCacheManager { | |||||
| size_t cache_vocab_size, size_t embedding_size); | size_t cache_vocab_size, size_t embedding_size); | ||||
| void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name); | void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name); | ||||
| const Address &QueryHashTableAddr(const std::string ¶m_name) const; | const Address &QueryHashTableAddr(const std::string ¶m_name) const; | ||||
| const size_t &QueryHashTableSize(const std::string ¶m_name) const; | |||||
| bool IsHashTable(const std::string ¶m_name) { return hash_tables_.count(param_name) != 0; } | bool IsHashTable(const std::string ¶m_name) { return hash_tables_.count(param_name) != 0; } | ||||
| void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; } | void set_batch_elements(size_t batch_elements) { batch_elements_ = batch_elements; } | ||||
| bool initialized_ps_cache() const { return initialized_ps_cache_; } | bool initialized_ps_cache() const { return initialized_ps_cache_; } | ||||
| @@ -325,7 +325,9 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||||
| } | } | ||||
| need_alloc_nodes.push_back(item); | need_alloc_nodes.push_back(item); | ||||
| } | } | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| bool ps_cache_check = false; | |||||
| #endif | |||||
| for (auto &item : need_alloc_nodes) { | for (auto &item : need_alloc_nodes) { | ||||
| auto output_size = AnfAlgo::GetOutputTensorNum(item); | auto output_size = AnfAlgo::GetOutputTensorNum(item); | ||||
| for (size_t index = 0; index < output_size; index++) { | for (size_t index = 0; index < output_size; index++) { | ||||
| @@ -339,6 +341,13 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { | |||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | ||||
| const std::string ¶m_name = item->fullname_with_scope(); | const std::string ¶m_name = item->fullname_with_scope(); | ||||
| if (ps::ps_cache_instance.IsHashTable(param_name)) { | if (ps::ps_cache_instance.IsHashTable(param_name)) { | ||||
| MS_LOG(INFO) << "Parameter(" << param_name << ")" | |||||
| << " enables the embeddingLookup cache in parameter server training mode."; | |||||
| // PS embeddingLookup cache check. | |||||
| if (!ps_cache_check) { | |||||
| CheckIfSupportPSEmbeddingCache(graph); | |||||
| ps_cache_check = true; | |||||
| } | |||||
| const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name); | const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name); | ||||
| MS_EXCEPTION_IF_NULL(address.addr); | MS_EXCEPTION_IF_NULL(address.addr); | ||||
| device_address = | device_address = | ||||
| @@ -1024,5 +1033,83 @@ DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const st | |||||
| MS_EXCEPTION_IF_NULL(base_ptr); | MS_EXCEPTION_IF_NULL(base_ptr); | ||||
| return device_address; | return device_address; | ||||
| } | } | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *first_cache_input_index, | |||||
| size_t *first_cache_size) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| for (const auto &kernel : graph->execution_order()) { | |||||
| MS_EXCEPTION_IF_NULL(kernel); | |||||
| if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") { | |||||
| continue; | |||||
| } | |||||
| auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0); | |||||
| auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1); | |||||
| MS_EXCEPTION_IF_NULL(input_param.first); | |||||
| MS_EXCEPTION_IF_NULL(input_index.first); | |||||
| auto param_name = input_param.first->fullname_with_scope(); | |||||
| if (!ps::ps_cache_instance.IsHashTable(param_name)) { | |||||
| continue; | |||||
| } | |||||
| auto size = ps::ps_cache_instance.QueryHashTableSize(param_name); | |||||
| while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) { | |||||
| input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second); | |||||
| MS_EXCEPTION_IF_NULL(input_index.first); | |||||
| } | |||||
| if ((!input_index.first->isa<Parameter>()) && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) { | |||||
| MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from " | |||||
| << input_index.first->fullname_with_scope(); | |||||
| MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in " | |||||
| "parameter server training mode."; | |||||
| } | |||||
| *first_cache_input_index = input_index.first; | |||||
| *first_cache_size = size; | |||||
| MS_LOG(INFO) << "The input index of the first embeddingLookup cache is from " | |||||
| << input_index.first->fullname_with_scope() << ", the cache size is " << size; | |||||
| return; | |||||
| } | |||||
| } | |||||
| void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| AnfNodePtr first_cache_input_index = nullptr; | |||||
| size_t first_cache_size = 0; | |||||
| GetFirstPSEmbeddingCache(graph, &first_cache_input_index, &first_cache_size); | |||||
| MS_EXCEPTION_IF_NULL(first_cache_input_index); | |||||
| for (const auto &kernel : graph->execution_order()) { | |||||
| MS_EXCEPTION_IF_NULL(kernel); | |||||
| if (AnfAlgo::GetCNodeName(kernel) != "GatherV2") { | |||||
| continue; | |||||
| } | |||||
| auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0); | |||||
| auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1); | |||||
| MS_EXCEPTION_IF_NULL(input_param.first); | |||||
| MS_EXCEPTION_IF_NULL(input_index.first); | |||||
| auto param_name = input_param.first->fullname_with_scope(); | |||||
| while ((AnfAlgo::GetCNodeName(input_index.first) == "Cast") || opt::IsNopNode(input_index.first)) { | |||||
| input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, input_index.second); | |||||
| MS_EXCEPTION_IF_NULL(input_index.first); | |||||
| } | |||||
| if (input_index.first == first_cache_input_index) { | |||||
| if (!ps::ps_cache_instance.IsHashTable(param_name)) { | |||||
| MS_LOG(ERROR) << "The embeddingLookup(" << kernel->fullname_with_scope() << ") doesn't enable cache."; | |||||
| MS_LOG(EXCEPTION) << "All the embeddingLookups whose input indices are from dataset must enable cache at the " | |||||
| "same time when one of them enables cache in parameter server training mode."; | |||||
| } | |||||
| auto size = ps::ps_cache_instance.QueryHashTableSize(param_name); | |||||
| if (size != first_cache_size) { | |||||
| MS_LOG(ERROR) << "The cache size(" << size << ") of embeddingLookup(" << kernel->fullname_with_scope() | |||||
| << ") is not the same as other embeddingLookup cache size."; | |||||
| MS_LOG(EXCEPTION) << "The cache sizes of embeddingLookups are not the same in parameter server training mode."; | |||||
| } | |||||
| } else if (ps::ps_cache_instance.IsHashTable(param_name)) { | |||||
| MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from " | |||||
| << input_index.first->fullname_with_scope(); | |||||
| MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in " | |||||
| "parameter server training mode."; | |||||
| } | |||||
| } | |||||
| } | |||||
| #endif | |||||
| } // namespace device | } // namespace device | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -131,6 +131,11 @@ class KernelRuntime { | |||||
| void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph); | void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph); | ||||
| void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx); | void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx); | ||||
| DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index); | DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index); | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| void GetFirstPSEmbeddingCache(const session::KernelGraph *graph, AnfNodePtr *first_cache_input_index, | |||||
| size_t *first_cache_size); | |||||
| void CheckIfSupportPSEmbeddingCache(const session::KernelGraph *graph); | |||||
| #endif | |||||
| protected: | protected: | ||||
| uint32_t device_id_{0}; | uint32_t device_id_{0}; | ||||
| @@ -14,7 +14,6 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """embedding""" | """embedding""" | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| import mindspore.context as context | |||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| @@ -23,8 +22,8 @@ from mindspore.common.parameter import Parameter | |||||
| from mindspore.common.initializer import initializer | from mindspore.common.initializer import initializer | ||||
| from mindspore.communication.management import get_group_size | from mindspore.communication.management import get_group_size | ||||
| from mindspore.context import ParallelMode | from mindspore.context import ParallelMode | ||||
| from mindspore.parallel._utils import _get_parallel_mode | |||||
| from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker | |||||
| from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch | |||||
| from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker, _get_ps_context | |||||
| from mindspore._checkparam import Rel | from mindspore._checkparam import Rel | ||||
| from mindspore._checkparam import Validator as validator | from mindspore._checkparam import Validator as validator | ||||
| from mindspore.ops.primitive import constexpr | from mindspore.ops.primitive import constexpr | ||||
| @@ -195,11 +194,6 @@ class EmbeddingLookup(Cell): | |||||
| + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.') | + str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.') | ||||
| if not sparse and target == 'CPU': | if not sparse and target == 'CPU': | ||||
| raise ValueError('When target is CPU, embedding_lookup must be sparse.') | raise ValueError('When target is CPU, embedding_lookup must be sparse.') | ||||
| enable_ps = context.get_ps_context("enable_ps") | |||||
| if not enable_ps and vocab_cache_size > 0: | |||||
| logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server trainning mode, " | |||||
| "current mode is not parameter server trainning mode, so it will be ignored.") | |||||
| vocab_cache_size = 0 | |||||
| if sparse: | if sparse: | ||||
| self.gatherv2 = P.SparseGatherV2() | self.gatherv2 = P.SparseGatherV2() | ||||
| else: | else: | ||||
| @@ -207,22 +201,14 @@ class EmbeddingLookup(Cell): | |||||
| self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') | self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU') | ||||
| self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size') | self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size') | ||||
| self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size') | self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size') | ||||
| self._process_vocab_cache(slice_mode) | |||||
| self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size') | self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size') | ||||
| parallel_mode = _get_parallel_mode() | |||||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||||
| self.cache_enable = self.vocab_cache_size > 0 | |||||
| if self.cache_enable: | |||||
| if is_auto_parallel: | |||||
| self.vocab_cache_size = self.vocab_cache_size * get_group_size() | |||||
| self.vocab_size = self.vocab_cache_size | |||||
| self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]), | self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]), | ||||
| name='embedding_table') | name='embedding_table') | ||||
| if self.cache_enable: | if self.cache_enable: | ||||
| self.embedding_table.cache_enable = True | |||||
| _set_cache_enable(True) | |||||
| if _is_role_worker(): | |||||
| _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size) | |||||
| self._set_voacb_cache_enable(vocab_cache_size, embedding_size, vocab_size) | |||||
| parallel_mode = _get_parallel_mode() | |||||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||||
| self.forward_unique = False | self.forward_unique = False | ||||
| self.gather_revert = P.GatherV2() | self.gather_revert = P.GatherV2() | ||||
| self.unique = P.Unique().shard(((1,),)) | self.unique = P.Unique().shard(((1,),)) | ||||
| @@ -241,7 +227,8 @@ class EmbeddingLookup(Cell): | |||||
| self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size()))) | self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size()))) | ||||
| self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size()))) | self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size()))) | ||||
| elif slice_mode == "table_row_slice" and is_auto_parallel: | elif slice_mode == "table_row_slice" and is_auto_parallel: | ||||
| if target == 'DEVICE' and not self.cache_enable: | |||||
| full_batch = _get_full_batch() | |||||
| if target == 'DEVICE' and not full_batch: | |||||
| indices_shape_size = 1 | indices_shape_size = 1 | ||||
| self.gather_revert.shard(((1, 1), (get_group_size(),))) | self.gather_revert.shard(((1, 1), (get_group_size(),))) | ||||
| self.forward_unique = True | self.forward_unique = True | ||||
| @@ -272,6 +259,39 @@ class EmbeddingLookup(Cell): | |||||
| self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name) | self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name) | ||||
| self.max_norm = Tensor(self.max_norm, dtype=mstype.float32) | self.max_norm = Tensor(self.max_norm, dtype=mstype.float32) | ||||
| def _process_vocab_cache(self, slice_mode): | |||||
| """PS embeddingLookup cache check and process.""" | |||||
| self.cache_enable = False | |||||
| if self.vocab_cache_size > 0: | |||||
| if self.target == 'CPU': | |||||
| logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, " | |||||
| "current target is CPU, so it will be ignored.") | |||||
| return | |||||
| enable_ps = _get_ps_context("enable_ps") | |||||
| if not enable_ps: | |||||
| logger.warning("The configuration of 'vocab_cache_size' is valid only in parameter server trainning " | |||||
| "mode, current mode is not parameter server trainning mode, so it will be ignored.") | |||||
| return | |||||
| parallel_mode = _get_parallel_mode() | |||||
| is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) | |||||
| if is_auto_parallel: | |||||
| device_num = get_group_size() | |||||
| full_batch = _get_full_batch() | |||||
| if device_num > 1 and not (full_batch and slice_mode == TABLE_ROW_SLICE): | |||||
| raise ValueError("The embeddingLookup cache of parameter server parallel only be used " | |||||
| "in 'full_batch' and 'table_row_slice' parallel strategy.") | |||||
| self.vocab_cache_size = self.vocab_cache_size * device_num | |||||
| self.cache_enable = True | |||||
| self.vocab_size = self.vocab_cache_size | |||||
| def _set_voacb_cache_enable(self, vocab_cache_size, embedding_size, vocab_size): | |||||
| """PS embeddingLookup cache enable set.""" | |||||
| self.embedding_table.cache_enable = True | |||||
| self.embedding_table.is_param_ps = True | |||||
| _set_cache_enable(True) | |||||
| if _is_role_worker(): | |||||
| _insert_hash_table_size(self.embedding_table.name, vocab_cache_size, embedding_size, vocab_size) | |||||
| def construct(self, indices): | def construct(self, indices): | ||||
| if self.target == "CPU": | if self.target == "CPU": | ||||
| out = self.embeddinglookup(self.embedding_table, indices, 0) | out = self.embeddinglookup(self.embedding_table, indices, 0) | ||||