| @@ -183,7 +183,7 @@ bool AscendPsCache::SynchronizeStream() { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||||
| bool AscendPsCache::CopyHostMemToDevice(void *dst, const void *src, size_t size) { | |||||
| MS_ERROR_IF_NULL(dst); | MS_ERROR_IF_NULL(dst); | ||||
| MS_ERROR_IF_NULL(src); | MS_ERROR_IF_NULL(src); | ||||
| auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_HOST_TO_DEVICE, stream_); | auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_HOST_TO_DEVICE, stream_); | ||||
| @@ -194,7 +194,7 @@ bool AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { | |||||
| bool AscendPsCache::CopyDeviceMemToHost(void *dst, const void *src, size_t size) { | |||||
| MS_ERROR_IF_NULL(dst); | MS_ERROR_IF_NULL(dst); | ||||
| MS_ERROR_IF_NULL(src); | MS_ERROR_IF_NULL(src); | ||||
| auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_DEVICE_TO_HOST, stream_); | auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_DEVICE_TO_HOST, stream_); | ||||
| @@ -55,8 +55,8 @@ class AscendPsCache : public PsCacheBasic { | |||||
| bool RecordEvent() override; | bool RecordEvent() override; | ||||
| bool SynchronizeEvent() override; | bool SynchronizeEvent() override; | ||||
| bool SynchronizeStream() override; | bool SynchronizeStream() override; | ||||
| bool CopyHostMemToDevice(void *dst, void *src, size_t size) override; | |||||
| bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override; | |||||
| bool CopyHostMemToDevice(void *dst, const void *src, size_t size) override; | |||||
| bool CopyDeviceMemToHost(void *dst, const void *src, size_t size) override; | |||||
| bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t cache_vocab_size, | bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t cache_vocab_size, | ||||
| size_t embedding_size, size_t swap_out_size) override; | size_t embedding_size, size_t swap_out_size) override; | ||||
| bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t cache_vocab_size, | bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t cache_vocab_size, | ||||
| @@ -61,7 +61,7 @@ bool GPUPsCache::SynchronizeStream() { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||||
| bool GPUPsCache::CopyHostMemToDevice(void *dst, const void *src, size_t size) { | |||||
| MS_ERROR_IF_NULL(dst); | MS_ERROR_IF_NULL(dst); | ||||
| MS_ERROR_IF_NULL(src); | MS_ERROR_IF_NULL(src); | ||||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE( | CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE( | ||||
| @@ -70,7 +70,7 @@ bool GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool GPUPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { | |||||
| bool GPUPsCache::CopyDeviceMemToHost(void *dst, const void *src, size_t size) { | |||||
| MS_ERROR_IF_NULL(dst); | MS_ERROR_IF_NULL(dst); | ||||
| MS_ERROR_IF_NULL(src); | MS_ERROR_IF_NULL(src); | ||||
| CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE( | CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE( | ||||
| @@ -33,8 +33,8 @@ class GPUPsCache : public PsCacheBasic { | |||||
| bool RecordEvent() override; | bool RecordEvent() override; | ||||
| bool SynchronizeEvent() override; | bool SynchronizeEvent() override; | ||||
| bool SynchronizeStream() override; | bool SynchronizeStream() override; | ||||
| bool CopyHostMemToDevice(void *dst, void *src, size_t size) override; | |||||
| bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override; | |||||
| bool CopyHostMemToDevice(void *dst, const void *src, size_t size) override; | |||||
| bool CopyDeviceMemToHost(void *dst, const void *src, size_t size) override; | |||||
| bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t cache_vocab_size, | bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t cache_vocab_size, | ||||
| size_t embedding_size, size_t swap_out_size) override; | size_t embedding_size, size_t swap_out_size) override; | ||||
| bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t cache_vocab_size, | bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t cache_vocab_size, | ||||
| @@ -38,8 +38,8 @@ class PsCacheBasic { | |||||
| virtual bool RecordEvent() = 0; | virtual bool RecordEvent() = 0; | ||||
| virtual bool SynchronizeEvent() = 0; | virtual bool SynchronizeEvent() = 0; | ||||
| virtual bool SynchronizeStream() = 0; | virtual bool SynchronizeStream() = 0; | ||||
| virtual bool CopyHostMemToDevice(void *dst, void *src, size_t size) = 0; | |||||
| virtual bool CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0; | |||||
| virtual bool CopyHostMemToDevice(void *dst, const void *src, size_t size) = 0; | |||||
| virtual bool CopyDeviceMemToHost(void *dst, const void *src, size_t size) = 0; | |||||
| virtual bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, | virtual bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, | ||||
| size_t cache_vocab_size, size_t embedding_size, size_t swap_out_size) = 0; | size_t cache_vocab_size, size_t embedding_size, size_t swap_out_size) = 0; | ||||
| virtual bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, | virtual bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, | ||||
| @@ -309,7 +309,7 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { | |||||
| data_prase_.notify_one(); | data_prase_.notify_one(); | ||||
| } | } | ||||
| void PsCacheManager::DoProcessData(uint32_t device_id, void *context) { | |||||
| void PsCacheManager::DoProcessData(uint32_t device_id, const void *context) { | |||||
| // PS embeddingLookup cache check. | // PS embeddingLookup cache check. | ||||
| if (!initialized_ps_cache_) { | if (!initialized_ps_cache_) { | ||||
| MS_LOG(EXCEPTION) << "Only the sink_mode of dataset supports embeddingLookup cache in parameter server training " | MS_LOG(EXCEPTION) << "Only the sink_mode of dataset supports embeddingLookup cache in parameter server training " | ||||
| @@ -318,7 +318,7 @@ void PsCacheManager::DoProcessData(uint32_t device_id, void *context) { | |||||
| process_data_thread_ = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context); | process_data_thread_ = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context); | ||||
| } | } | ||||
| void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) { | |||||
| void PsCacheManager::ProcessDataTask(uint32_t device_id, const void *context) { | |||||
| MS_LOG(INFO) << "PS embedding cache process data task begin."; | MS_LOG(INFO) << "PS embedding cache process data task begin."; | ||||
| running_ = true; | running_ = true; | ||||
| embedding_device_cache_->cache_->InitDevice(device_id, context); | embedding_device_cache_->cache_->InitDevice(device_id, context); | ||||
| @@ -670,12 +670,14 @@ void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, | |||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| MS_LOG(ERROR) << "LookUpTable task memcpy failed."; | MS_LOG(ERROR) << "LookUpTable task memcpy failed."; | ||||
| running_ = false; | running_ = false; | ||||
| return; | |||||
| } | } | ||||
| } else { | } else { | ||||
| auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens); | auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| MS_LOG(ERROR) << "LookUpTable task memset failed."; | MS_LOG(ERROR) << "LookUpTable task memset failed."; | ||||
| running_ = false; | running_ = false; | ||||
| return; | |||||
| } | } | ||||
| } | } | ||||
| output_addr += outer_dim_size; | output_addr += outer_dim_size; | ||||
| @@ -712,8 +714,8 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l | |||||
| return running_; | return running_; | ||||
| } | } | ||||
| bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, | |||||
| float *insert_data, float *hash_table_addr) { | |||||
| bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, const int *insert_indices, | |||||
| const float *insert_data, float *hash_table_addr) { | |||||
| size_t first_dim_size = host_vocab_cache_size_; | size_t first_dim_size = host_vocab_cache_size_; | ||||
| size_t thread_num = insert_indices_size / 10000 + 1; | size_t thread_num = insert_indices_size / 10000 + 1; | ||||
| thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num; | thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num; | ||||
| @@ -723,7 +725,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in | |||||
| size_t task_offset = 0; | size_t task_offset = 0; | ||||
| auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size, | auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size, | ||||
| int *insert_indices, float *insert_data, float *hash_table_addr) { | |||||
| const int *insert_indices, const float *insert_data, float *hash_table_addr) { | |||||
| auto type_size = sizeof(float); | auto type_size = sizeof(float); | ||||
| size_t lens = outer_dim_size * type_size; | size_t lens = outer_dim_size * type_size; | ||||
| for (size_t i = 0; i < insert_indices_size; ++i) { | for (size_t i = 0; i < insert_indices_size; ++i) { | ||||
| @@ -733,6 +735,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in | |||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| MS_LOG(ERROR) << "Insert hash table task memcpy failed."; | MS_LOG(ERROR) << "Insert hash table task memcpy failed."; | ||||
| running_ = false; | running_ = false; | ||||
| return; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -893,7 +896,7 @@ bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, std::vector<float> * | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, | |||||
| bool PsCacheManager::HashSwapDeviceIn(const int *swap_in_ids, const int *swap_in_index, const HashTableInfo &hash_info, | |||||
| size_t key) { | size_t key) { | ||||
| MS_ERROR_IF_NULL(swap_in_ids); | MS_ERROR_IF_NULL(swap_in_ids); | ||||
| MS_ERROR_IF_NULL(swap_in_index); | MS_ERROR_IF_NULL(swap_in_index); | ||||
| @@ -129,7 +129,7 @@ class PsCacheManager { | |||||
| bool initialized_ps_cache() const { return initialized_ps_cache_; } | bool initialized_ps_cache() const { return initialized_ps_cache_; } | ||||
| size_t vocab_cache_size() const { return vocab_cache_size_; } | size_t vocab_cache_size() const { return vocab_cache_size_; } | ||||
| int cache_indices_lower_bound() const; | int cache_indices_lower_bound() const; | ||||
| void DoProcessData(uint32_t device_id, void *context); | |||||
| void DoProcessData(uint32_t device_id, const void *context); | |||||
| void IncreaseGraphStep(const std::string &channel_name); | void IncreaseGraphStep(const std::string &channel_name); | ||||
| void SyncEmbeddingTable(); | void SyncEmbeddingTable(); | ||||
| void Finalize(); | void Finalize(); | ||||
| @@ -148,7 +148,7 @@ class PsCacheManager { | |||||
| void InitDataChannel(); | void InitDataChannel(); | ||||
| void AllocMemForHashTable(); | void AllocMemForHashTable(); | ||||
| void SetLocalIdRank(); | void SetLocalIdRank(); | ||||
| void ProcessDataTask(uint32_t device_id, void *context); | |||||
| void ProcessDataTask(uint32_t device_id, const void *context); | |||||
| bool ProcessData(); | bool ProcessData(); | ||||
| bool ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index); | bool ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index); | ||||
| bool WaitGraphRun(); | bool WaitGraphRun(); | ||||
| @@ -156,13 +156,13 @@ class PsCacheManager { | |||||
| bool ParseHostDataHostToDevice(size_t id); | bool ParseHostDataHostToDevice(size_t id); | ||||
| bool ParseHostDataDeviceToHost(); | bool ParseHostDataDeviceToHost(); | ||||
| bool HashSwapDeviceOut(int *swap_out_index, std::vector<float> *swap_out_data, const HashTableInfo &hash_info); | bool HashSwapDeviceOut(int *swap_out_index, std::vector<float> *swap_out_data, const HashTableInfo &hash_info); | ||||
| bool HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key); | |||||
| bool HashSwapDeviceIn(const int *swap_in_ids, const int *swap_in_index, const HashTableInfo &hash_info, size_t key); | |||||
| bool HashSwapHostToDevice(const HashTableInfo &hash_info); | bool HashSwapHostToDevice(const HashTableInfo &hash_info); | ||||
| bool HashSwapDeviceToHost(const HashTableInfo &hash_info); | bool HashSwapDeviceToHost(const HashTableInfo &hash_info); | ||||
| bool HashSwapHostToServer(size_t key, const HashTableInfo &hash_info); | bool HashSwapHostToServer(size_t key, const HashTableInfo &hash_info); | ||||
| bool HashSwapServerToHost(size_t key, const HashTableInfo &hash_info); | bool HashSwapServerToHost(size_t key, const HashTableInfo &hash_info); | ||||
| bool InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, float *insert_data, | |||||
| float *hash_table_addr); | |||||
| bool InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, const int *insert_indices, | |||||
| const float *insert_data, float *hash_table_addr); | |||||
| bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, | ||||
| const int *indices_addr, float *output_addr); | const int *indices_addr, float *output_addr); | ||||
| bool UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *swap_out_ids, size_t key); | bool UpdataEmbeddingTable(const std::vector<float> &swap_out_data, int *swap_out_ids, size_t key); | ||||
| @@ -54,7 +54,7 @@ class AscendKernelRuntime : public KernelRuntime { | |||||
| bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) override; | bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) override; | ||||
| void SetContext() override; | void SetContext() override; | ||||
| void CreateContext() override; | void CreateContext() override; | ||||
| void *context() const override { return rt_context_; } | |||||
| const void *context() const override { return rt_context_; } | |||||
| void PreInit() override; | void PreInit() override; | ||||
| uint64_t GetAvailableMemMaxSize() const override; | uint64_t GetAvailableMemMaxSize() const override; | ||||
| DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kAscend; }; | DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kAscend; }; | ||||
| @@ -81,7 +81,7 @@ class KernelRuntime { | |||||
| virtual void ClearGlobalIdleMem() {} | virtual void ClearGlobalIdleMem() {} | ||||
| virtual void CreateContext() {} | virtual void CreateContext() {} | ||||
| virtual void SetContext() {} | virtual void SetContext() {} | ||||
| virtual void *context() const { return nullptr; } | |||||
| virtual const void *context() const { return nullptr; } | |||||
| uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) { | uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) { | ||||
| return mem_manager_->MallocMem(type, size, address); | return mem_manager_->MallocMem(type, size, address); | ||||
| } | } | ||||