Browse Source

fix CodeDex warning

pull/15114/head
lizhenyu 4 years ago
parent
commit
70aad9820d
2 changed files with 29 additions and 20 deletions
  1. +3
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc
  2. +26
    -18
      mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc

+ 3
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.cc View File

@@ -74,12 +74,13 @@ bool EmbeddingLookUpPSKernel::Execute(const std::vector<AddressPtr> &inputs, con

void EmbeddingLookUpPSKernel::UpdateEmbeddings(float *embedding_table, const size_t *lookup_ids,
const float *update_vals, size_t ids_size) {
size_t copy_lens = outer_dim_size_ * sizeof(float);
size_t copy_len = outer_dim_size_ * sizeof(float);
size_t dest_len = copy_len;
for (size_t i = 0; i < ids_size; ++i) {
int index = lookup_ids[i] - offset_;
if (index >= 0 && index < SizeToInt(first_dim_size_)) {
auto ret =
memcpy_s(embedding_table + index * outer_dim_size_, copy_lens, update_vals + i * outer_dim_size_, copy_lens);
memcpy_s(embedding_table + index * outer_dim_size_, dest_len, update_vals + i * outer_dim_size_, copy_len);
if (ret != EOK) {
MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed.";
}


+ 26
- 18
mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc View File

@@ -222,9 +222,8 @@ void PsCacheManager::AllocMemForHashTable() {
device_address.addr = addr;

auto &host_address = item.second.host_address;
auto host_address_ptr = new float[host_vocab_cache_size_ * embedding_size];
MS_EXCEPTION_IF_NULL(host_address_ptr);
host_address = std::shared_ptr<float[]>(host_address_ptr, std::default_delete<float[]>());
host_address =
std::shared_ptr<float[]>(new float[host_vocab_cache_size_ * embedding_size], std::default_delete<float[]>());
MS_EXCEPTION_IF_NULL(host_address);

max_embedding_size = (embedding_size > max_embedding_size) ? embedding_size : max_embedding_size;
@@ -387,8 +386,9 @@ bool PsCacheManager::ProcessData() {
RETURN_IF_FALSE(HashSwapServerToHost(key, hash_info));
RETURN_IF_FALSE(HashSwapHostToDevice(hash_info));
}
size_t dest_len = data_size;
// Replace the batch_ids by hash index for getNext-op getting hash index as input.
if (memcpy_s(data, data_size, hash_index.get(), data_size) != EOK) {
if (memcpy_s(data, dest_len, hash_index.get(), data_size) != EOK) {
MS_LOG(ERROR) << "Process data memcpy failed.";
return false;
}
@@ -727,11 +727,13 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in
auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size,
const int *insert_indices, const float *insert_data, float *hash_table_addr) {
auto type_size = sizeof(float);
size_t lens = outer_dim_size * type_size;
size_t copy_len = outer_dim_size * type_size;
size_t dest_len = copy_len;
for (size_t i = 0; i < insert_indices_size; ++i) {
int index = insert_indices[i];
if (index >= 0 && index < SizeToInt(first_dim_size)) {
auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens);
auto ret =
memcpy_s(hash_table_addr + index * outer_dim_size, dest_len, insert_data + i * outer_dim_size, copy_len);
if (ret != EOK) {
MS_LOG(ERROR) << "Insert hash table task memcpy failed.";
running_ = false;
@@ -836,8 +838,9 @@ bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_
RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index,
swap_out_data.data()));

auto copy_len = swap_indices_size * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids, copy_len);
size_t copy_len = swap_indices_size * sizeof(int);
size_t dest_len = copy_len;
auto ret = memcpy_s(lookup_ids.data(), dest_len, host_to_server_ids, copy_len);
if (ret != EOK) {
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;
@@ -858,8 +861,9 @@ bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_
auto embedding_size = hash_info.embedding_size;
std::vector<float> lookup_result(swap_indices_size * embedding_size, 0);
std::vector<int> lookup_ids(swap_indices_size, 0);
auto copy_len = swap_indices_size * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len);
size_t copy_len = swap_indices_size * sizeof(int);
size_t dest_len = copy_len;
auto ret = memcpy_s(lookup_ids.data(), dest_len, server_to_host_ids, copy_len);
if (ret != EOK) {
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;
@@ -912,8 +916,9 @@ bool PsCacheManager::HashSwapDeviceIn(const int *swap_in_ids, const int *swap_in
// Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device).
std::vector<float> lookup_result(swap_in_ids_size * embedding_size, 0);
std::vector<int> lookup_ids(swap_in_ids_size, 0);
auto copy_len = swap_in_ids_size * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len);
size_t copy_len = swap_in_ids_size * sizeof(int);
size_t dest_len = copy_len;
auto ret = memcpy_s(lookup_ids.data(), dest_len, swap_in_ids, copy_len);
if (ret != EOK) {
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;
@@ -940,8 +945,9 @@ bool PsCacheManager::UpdataEmbeddingTable(const std::vector<float> &swap_out_dat
return true;
}
std::vector<int> lookup_ids(swap_out_ids_size, 0);
auto copy_len = swap_out_ids_size * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len);
size_t copy_len = swap_out_ids_size * sizeof(int);
size_t dest_len = copy_len;
auto ret = memcpy_s(lookup_ids.data(), dest_len, swap_out_ids, copy_len);
if (ret != EOK) {
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;
@@ -1000,8 +1006,9 @@ bool PsCacheManager::SyncHostEmbeddingTable() {
RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_lens, host_hash_table_addr,
host_to_server_indices_ptr.get(), swap_out_data.data()));

auto copy_len = swap_indices_lens * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids_ptr.get(), copy_len);
size_t copy_len = swap_indices_lens * sizeof(int);
size_t dest_len = copy_len;
auto ret = memcpy_s(lookup_ids.data(), dest_len, host_to_server_ids_ptr.get(), copy_len);
if (ret != EOK) {
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;
@@ -1052,8 +1059,9 @@ bool PsCacheManager::SyncDeviceEmbeddingTable() {
RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_lens, device_hash_table_addr_tmp.get(),
device_to_server_indices_ptr.get(), swap_out_data.data()));

auto copy_len = swap_indices_lens * sizeof(int);
auto ret = memcpy_s(lookup_ids.data(), copy_len, device_to_server_ids_ptr.get(), copy_len);
size_t copy_len = swap_indices_lens * sizeof(int);
size_t dest_len = copy_len;
auto ret = memcpy_s(lookup_ids.data(), dest_len, device_to_server_ids_ptr.get(), copy_len);
if (ret != EOK) {
MS_LOG(ERROR) << "Lookup id memcpy failed.";
return false;


Loading…
Cancel
Save