diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc index 97694e452a..5212f1a72e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc @@ -74,12 +74,13 @@ bool EmbeddingLookUpPSKernel::Execute(const std::vector &inputs, con void EmbeddingLookUpPSKernel::UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids, const float *update_vals, size_t ids_size) { - size_t copy_lens = outer_dim_size_ * sizeof(float); + size_t copy_len = outer_dim_size_ * sizeof(float); + size_t dest_len = copy_len; for (size_t i = 0; i < ids_size; ++i) { int index = lookup_ids[i] - offset_; if (index >= 0 && index < SizeToInt(first_dim_size_)) { auto ret = - memcpy_s(embedding_table + index * outer_dim_size_, copy_lens, update_vals + i * outer_dim_size_, copy_lens); + memcpy_s(embedding_table + index * outer_dim_size_, dest_len, update_vals + i * outer_dim_size_, copy_len); if (ret != EOK) { MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; } diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc index b23ad5b5f3..b96a1c2d36 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc @@ -222,9 +222,8 @@ void PsCacheManager::AllocMemForHashTable() { device_address.addr = addr; auto &host_address = item.second.host_address; - auto host_address_ptr = new float[host_vocab_cache_size_ * embedding_size]; - MS_EXCEPTION_IF_NULL(host_address_ptr); - host_address = std::shared_ptr(host_address_ptr, std::default_delete()); + host_address = + std::shared_ptr(new float[host_vocab_cache_size_ * embedding_size], std::default_delete()); MS_EXCEPTION_IF_NULL(host_address); max_embedding_size = (embedding_size > max_embedding_size) ? embedding_size : max_embedding_size; @@ -387,8 +386,9 @@ bool PsCacheManager::ProcessData() { RETURN_IF_FALSE(HashSwapServerToHost(key, hash_info)); RETURN_IF_FALSE(HashSwapHostToDevice(hash_info)); } + size_t dest_len = data_size; // Replace the batch_ids by hash index for getNext-op getting hash index as input. - if (memcpy_s(data, data_size, hash_index.get(), data_size) != EOK) { + if (memcpy_s(data, dest_len, hash_index.get(), data_size) != EOK) { MS_LOG(ERROR) << "Process data memcpy failed."; return false; } @@ -727,11 +727,13 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size, const int *insert_indices, const float *insert_data, float *hash_table_addr) { auto type_size = sizeof(float); - size_t lens = outer_dim_size * type_size; + size_t copy_len = outer_dim_size * type_size; + size_t dest_len = copy_len; for (size_t i = 0; i < insert_indices_size; ++i) { int index = insert_indices[i]; if (index >= 0 && index < SizeToInt(first_dim_size)) { - auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens); + auto ret = + memcpy_s(hash_table_addr + index * outer_dim_size, dest_len, insert_data + i * outer_dim_size, copy_len); if (ret != EOK) { MS_LOG(ERROR) << "Insert hash table task memcpy failed."; running_ = false; @@ -836,8 +838,9 @@ bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_ RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index, swap_out_data.data())); - auto copy_len = swap_indices_size * sizeof(int); - auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids, copy_len); + size_t copy_len = swap_indices_size * sizeof(int); + size_t dest_len = copy_len; + auto ret = memcpy_s(lookup_ids.data(), dest_len, host_to_server_ids, copy_len); if (ret != EOK) { MS_LOG(ERROR) << "Lookup id memcpy failed."; return false; @@ -858,8 +861,9 @@ bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_ auto embedding_size = hash_info.embedding_size; std::vector lookup_result(swap_indices_size * embedding_size, 0); std::vector lookup_ids(swap_indices_size, 0); - auto copy_len = swap_indices_size * sizeof(int); - auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len); + size_t copy_len = swap_indices_size * sizeof(int); + size_t dest_len = copy_len; + auto ret = memcpy_s(lookup_ids.data(), dest_len, server_to_host_ids, copy_len); if (ret != EOK) { MS_LOG(ERROR) << "Lookup id memcpy failed."; return false; @@ -912,8 +916,9 @@ bool PsCacheManager::HashSwapDeviceIn(const int *swap_in_ids, const int *swap_in // Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device). std::vector lookup_result(swap_in_ids_size * embedding_size, 0); std::vector lookup_ids(swap_in_ids_size, 0); - auto copy_len = swap_in_ids_size * sizeof(int); - auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len); + size_t copy_len = swap_in_ids_size * sizeof(int); + size_t dest_len = copy_len; + auto ret = memcpy_s(lookup_ids.data(), dest_len, swap_in_ids, copy_len); if (ret != EOK) { MS_LOG(ERROR) << "Lookup id memcpy failed."; return false; @@ -940,8 +945,9 @@ bool PsCacheManager::UpdataEmbeddingTable(const std::vector &swap_out_dat return true; } std::vector lookup_ids(swap_out_ids_size, 0); - auto copy_len = swap_out_ids_size * sizeof(int); - auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len); + size_t copy_len = swap_out_ids_size * sizeof(int); + size_t dest_len = copy_len; + auto ret = memcpy_s(lookup_ids.data(), dest_len, swap_out_ids, copy_len); if (ret != EOK) { MS_LOG(ERROR) << "Lookup id memcpy failed."; return false; @@ -1000,8 +1006,9 @@ bool PsCacheManager::SyncHostEmbeddingTable() { RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_lens, host_hash_table_addr, host_to_server_indices_ptr.get(), swap_out_data.data())); - auto copy_len = swap_indices_lens * sizeof(int); - auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids_ptr.get(), copy_len); + size_t copy_len = swap_indices_lens * sizeof(int); + size_t dest_len = copy_len; + auto ret = memcpy_s(lookup_ids.data(), dest_len, host_to_server_ids_ptr.get(), copy_len); if (ret != EOK) { MS_LOG(ERROR) << "Lookup id memcpy failed."; return false; @@ -1052,8 +1059,9 @@ bool PsCacheManager::SyncDeviceEmbeddingTable() { RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_lens, device_hash_table_addr_tmp.get(), device_to_server_indices_ptr.get(), swap_out_data.data())); - auto copy_len = swap_indices_lens * sizeof(int); - auto ret = memcpy_s(lookup_ids.data(), copy_len, device_to_server_ids_ptr.get(), copy_len); + size_t copy_len = swap_indices_lens * sizeof(int); + size_t dest_len = copy_len; + auto ret = memcpy_s(lookup_ids.data(), dest_len, device_to_server_ids_ptr.get(), copy_len); if (ret != EOK) { MS_LOG(ERROR) << "Lookup id memcpy failed."; return false;