| @@ -131,15 +131,18 @@ void *AscendPsCache::MallocMemory(size_t size) { | |||||
| return device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(size); | return device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(size); | ||||
| } | } | ||||
| bool AscendPsCache::MallocConstantMemory(size_t constant_value) { | |||||
| bool AscendPsCache::MallocConstantMemory(size_t cache_vocab_size) { | |||||
| offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int))); | offset_addr_ = reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int))); | ||||
| MS_ERROR_IF_NULL(offset_addr_); | MS_ERROR_IF_NULL(offset_addr_); | ||||
| rtMemset(offset_addr_, sizeof(int), 0, sizeof(int)); | rtMemset(offset_addr_, sizeof(int), 0, sizeof(int)); | ||||
| cache_vocab_size_addr_ = | cache_vocab_size_addr_ = | ||||
| reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int))); | reinterpret_cast<int *>(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int))); | ||||
| MS_ERROR_IF_NULL(cache_vocab_size_addr_); | MS_ERROR_IF_NULL(cache_vocab_size_addr_); | ||||
| rtMemset(cache_vocab_size_addr_, sizeof(int), constant_value, sizeof(int)); | |||||
| return true; | |||||
| int copy_value = SizeToInt(cache_vocab_size); | |||||
| if (!CopyHostMemToDevice(cache_vocab_size_addr_, &copy_value, sizeof(int))) { | |||||
| return false; | |||||
| } | |||||
| return SynchronizeStream(); | |||||
| } | } | ||||
| bool AscendPsCache::RecordEvent() { | bool AscendPsCache::RecordEvent() { | ||||
| @@ -51,7 +51,7 @@ class AscendPsCache : public PsCacheBasic { | |||||
| ~AscendPsCache() override = default; | ~AscendPsCache() override = default; | ||||
| bool InitDevice(uint32_t device_id, const void *context) override; | bool InitDevice(uint32_t device_id, const void *context) override; | ||||
| void *MallocMemory(size_t size) override; | void *MallocMemory(size_t size) override; | ||||
| bool MallocConstantMemory(size_t constant_value) override; | |||||
| bool MallocConstantMemory(size_t cache_vocab_size) override; | |||||
| bool RecordEvent() override; | bool RecordEvent() override; | ||||
| bool SynchronizeEvent() override; | bool SynchronizeEvent() override; | ||||
| bool SynchronizeStream() override; | bool SynchronizeStream() override; | ||||
| @@ -34,7 +34,7 @@ class PsCacheBasic { | |||||
| virtual ~PsCacheBasic() = default; | virtual ~PsCacheBasic() = default; | ||||
| virtual bool InitDevice(uint32_t device_id, const void *context) = 0; | virtual bool InitDevice(uint32_t device_id, const void *context) = 0; | ||||
| virtual void *MallocMemory(size_t size) = 0; | virtual void *MallocMemory(size_t size) = 0; | ||||
| virtual bool MallocConstantMemory(size_t constant_value) { return true; } | |||||
| virtual bool MallocConstantMemory(size_t cache_vocab_size) { return true; } | |||||
| virtual bool RecordEvent() = 0; | virtual bool RecordEvent() = 0; | ||||
| virtual bool SynchronizeEvent() = 0; | virtual bool SynchronizeEvent() = 0; | ||||
| virtual bool SynchronizeStream() = 0; | virtual bool SynchronizeStream() = 0; | ||||
| @@ -674,6 +674,7 @@ bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) { | |||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn( | RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn( | ||||
| hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, | hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, | ||||
| hash_table_size, embedding_size, swap_indices_size)); | hash_table_size, embedding_size, swap_indices_size)); | ||||
| RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream()); | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -168,7 +168,10 @@ class EmbeddingLookup(Cell): | |||||
| max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32 | max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32 | ||||
| or None. Default: None | or None. Default: None | ||||
| sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True. | sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True. | ||||
| vocab_cache_size (int): Cache size of the dictionary of embeddings. | |||||
| vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in | |||||
| parameter server training mode and 'DEVICE' target. The moment parameter of the corresponding | |||||
| optimizer will also be set to the cache size. In addition, note that the cache consumes 'DEVICE' | |||||
| memory, so it is recommended to set a reasonable value to avoid running out of memory. | |||||
| Inputs: | Inputs: | ||||
| - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. | ||||