| @@ -183,7 +183,7 @@ bool AscendPsCache::SynchronizeStream() { | |||
| return true; | |||
| } | |||
| bool AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||
| bool AscendPsCache::CopyHostMemToDevice(void *dst, const void *src, size_t size) { | |||
| MS_ERROR_IF_NULL(dst); | |||
| MS_ERROR_IF_NULL(src); | |||
| auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_HOST_TO_DEVICE, stream_); | |||
| @@ -194,7 +194,7 @@ bool AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||
| return true; | |||
| } | |||
| bool AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { | |||
| bool AscendPsCache::CopyDeviceMemToHost(void *dst, const void *src, size_t size) { | |||
| MS_ERROR_IF_NULL(dst); | |||
| MS_ERROR_IF_NULL(src); | |||
| auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_DEVICE_TO_HOST, stream_); | |||
| @@ -55,8 +55,8 @@ class AscendPsCache : public PsCacheBasic { | |||
| bool RecordEvent() override; | |||
| bool SynchronizeEvent() override; | |||
| bool SynchronizeStream() override; | |||
| bool CopyHostMemToDevice(void *dst, void *src, size_t size) override; | |||
| bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override; | |||
| bool CopyHostMemToDevice(void *dst, const void *src, size_t size) override; | |||
| bool CopyDeviceMemToHost(void *dst, const void *src, size_t size) override; | |||
| bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t cache_vocab_size, | |||
| size_t embedding_size, size_t swap_out_size) override; | |||
| bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t cache_vocab_size, | |||
| @@ -61,7 +61,7 @@ bool GPUPsCache::SynchronizeStream() { | |||
| return true; | |||
| } | |||
| bool GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||
| bool GPUPsCache::CopyHostMemToDevice(void *dst, const void *src, size_t size) { | |||
| MS_ERROR_IF_NULL(dst); | |||
| MS_ERROR_IF_NULL(src); | |||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE( | |||
| @@ -70,7 +70,7 @@ bool GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||
| return true; | |||
| } | |||
| bool GPUPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { | |||
| bool GPUPsCache::CopyDeviceMemToHost(void *dst, const void *src, size_t size) { | |||
| MS_ERROR_IF_NULL(dst); | |||
| MS_ERROR_IF_NULL(src); | |||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE( | |||
| @@ -33,8 +33,8 @@ class GPUPsCache : public PsCacheBasic { | |||
| bool RecordEvent() override; | |||
| bool SynchronizeEvent() override; | |||
| bool SynchronizeStream() override; | |||
| bool CopyHostMemToDevice(void *dst, void *src, size_t size) override; | |||
| bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override; | |||
| bool CopyHostMemToDevice(void *dst, const void *src, size_t size) override; | |||
| bool CopyDeviceMemToHost(void *dst, const void *src, size_t size) override; | |||
| bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t cache_vocab_size, | |||
| size_t embedding_size, size_t swap_out_size) override; | |||
| bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t cache_vocab_size, | |||
| @@ -38,8 +38,8 @@ class PsCacheBasic { | |||
| virtual bool RecordEvent() = 0; | |||
| virtual bool SynchronizeEvent() = 0; | |||
| virtual bool SynchronizeStream() = 0; | |||
| virtual bool CopyHostMemToDevice(void *dst, void *src, size_t size) = 0; | |||
| virtual bool CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0; | |||
| virtual bool CopyHostMemToDevice(void *dst, const void *src, size_t size) = 0; | |||
| virtual bool CopyDeviceMemToHost(void *dst, const void *src, size_t size) = 0; | |||
| virtual bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, | |||
| size_t cache_vocab_size, size_t embedding_size, size_t swap_out_size) = 0; | |||
| virtual bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, | |||
| @@ -309,7 +309,7 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { | |||
| data_prase_.notify_one(); | |||
| } | |||
| void PsCacheManager::DoProcessData(uint32_t device_id, void *context) { | |||
| void PsCacheManager::DoProcessData(uint32_t device_id, const void *context) { | |||
| // PS embeddingLookup cache check. | |||
| if (!initialized_ps_cache_) { | |||
| MS_LOG(EXCEPTION) << "Only the sink_mode of dataset supports embeddingLookup cache in parameter server training " | |||
| @@ -318,7 +318,7 @@ void PsCacheManager::DoProcessData(uint32_t device_id, void *context) { | |||
| process_data_thread_ = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context); | |||
| } | |||
| void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) { | |||
| void PsCacheManager::ProcessDataTask(uint32_t device_id, const void *context) { | |||
| MS_LOG(INFO) << "PS embedding cache process data task begin."; | |||
| running_ = true; | |||
| embedding_device_cache_->cache_->InitDevice(device_id, context); | |||
| @@ -670,12 +670,14 @@ void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "LookUpTable task memcpy failed."; | |||
| running_ = false; | |||
| return; | |||
| } | |||
| } else { | |||
| auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "LookUpTable task memset failed."; | |||
| running_ = false; | |||
| return; | |||
| } | |||
| } | |||
| output_addr += outer_dim_size; | |||
| @@ -712,8 +714,8 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l | |||
| return running_; | |||
| } | |||
| bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, | |||
| float *insert_data, float *hash_table_addr) { | |||
| bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, const int *insert_indices, | |||
| const float *insert_data, float *hash_table_addr) { | |||
| size_t first_dim_size = host_vocab_cache_size_; | |||
| size_t thread_num = insert_indices_size / 10000 + 1; | |||
| thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num; | |||
| @@ -723,7 +725,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in | |||
| size_t task_offset = 0; | |||
| auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size, | |||
| int *insert_indices, float *insert_data, float *hash_table_addr) { | |||
| const int *insert_indices, const float *insert_data, float *hash_table_addr) { | |||
| auto type_size = sizeof(float); | |||
| size_t lens = outer_dim_size * type_size; | |||
| for (size_t i = 0; i < insert_indices_size; ++i) { | |||
| @@ -733,6 +735,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "Insert hash table task memcpy failed."; | |||
| running_ = false; | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| @@ -893,7 +896,7 @@ bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, std::vector<float> * | |||
| return true; | |||
| } | |||
| bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, | |||
| bool PsCacheManager::HashSwapDeviceIn(const int *swap_in_ids, const int *swap_in_index, const HashTableInfo &hash_info, | |||
| size_t key) { | |||
| MS_ERROR_IF_NULL(swap_in_ids); | |||
| MS_ERROR_IF_NULL(swap_in_index); | |||
| @@ -129,7 +129,7 @@ class PsCacheManager { | |||
// Read-only getter: returns the current value of the initialized_ps_cache_ member.
| bool initialized_ps_cache() const { return initialized_ps_cache_; } | |||
// Read-only getter: returns the vocab_cache_size_ member.
| size_t vocab_cache_size() const { return vocab_cache_size_; } | |||
| int cache_indices_lower_bound() const; | |||
| void DoProcessData(uint32_t device_id, void *context); | |||
| void DoProcessData(uint32_t device_id, const void *context); | |||
| void IncreaseGraphStep(const std::string &channel_name); | |||
| void SyncEmbeddingTable(); | |||
| void Finalize(); | |||
| @@ -148,7 +148,7 @@ class PsCacheManager { | |||
| void InitDataChannel(); | |||
| void AllocMemForHashTable(); | |||
| void SetLocalIdRank(); | |||
| void ProcessDataTask(uint32_t device_id, void *context); | |||
| void ProcessDataTask(uint32_t device_id, const void *context); | |||
| bool ProcessData(); | |||
| bool ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index); | |||
| bool WaitGraphRun(); | |||
| @@ -156,13 +156,13 @@ class PsCacheManager { | |||
| bool ParseHostDataHostToDevice(size_t id); | |||
| bool ParseHostDataDeviceToHost(); | |||
| bool HashSwapDeviceOut(int *swap_out_index, std::vector<float> *swap_out_data, const HashTableInfo &hash_info); | |||
| bool HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key); | |||
| bool HashSwapDeviceIn(const int *swap_in_ids, const int *swap_in_index, const HashTableInfo &hash_info, size_t key); | |||
| bool HashSwapHostToDevice(const HashTableInfo &hash_info); | |||
| bool HashSwapDeviceToHost(const HashTableInfo &hash_info); | |||
| bool HashSwapHostToServer(size_t key, const HashTableInfo &hash_info); | |||
| bool HashSwapServerToHost(size_t key, const HashTableInfo &hash_info); | |||
| bool InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, float *insert_data, | |||
| float *hash_table_addr); | |||
| bool InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, const int *insert_indices, | |||
| const float *insert_data, float *hash_table_addr); | |||
| bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | |||
| const int *indices_addr, float *output_addr); | |||
| bool UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *swap_out_ids, size_t key); | |||
| @@ -54,7 +54,7 @@ class AscendKernelRuntime : public KernelRuntime { | |||
| bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) override; | |||
| void SetContext() override; | |||
| void CreateContext() override; | |||
| void *context() const override { return rt_context_; } | |||
// Overrides the base-class context() accessor: hands back rt_context_ as a pointer-to-const.
| const void *context() const override { return rt_context_; } | |||
| void PreInit() override; | |||
| uint64_t GetAvailableMemMaxSize() const override; | |||
| DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kAscend; }; | |||
| @@ -81,7 +81,7 @@ class KernelRuntime { | |||
// Empty default body; declared virtual so a derived runtime can supply its own behaviour.
| virtual void ClearGlobalIdleMem() {} | |||
// Empty default body; declared virtual so a derived runtime can supply its own behaviour.
| virtual void CreateContext() {} | |||
// Empty default body; declared virtual so a derived runtime can supply its own behaviour.
| virtual void SetContext() {} | |||
| virtual void *context() const { return nullptr; } | |||
// Default context accessor: always returns nullptr. Declared virtual so a derived runtime can return its own pointer.
| virtual const void *context() const { return nullptr; } | |||
// Thin wrapper: passes type, size and address unchanged to mem_manager_->MallocMem and returns its result.
// NOTE(review): mem_manager_ is dereferenced without a null check here — confirm it is always set before this is called.
| uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) { | |||
| return mem_manager_->MallocMem(type, size, address); | |||
| } | |||