|
|
|
@@ -81,11 +81,9 @@ void MapCacheIdxCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node); |
|
|
|
node_ = kernel_node; |
|
|
|
auto hashmap_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); |
|
|
|
|
|
|
|
if (hashmap_shape.size() != 2) { |
|
|
|
MS_LOG(EXCEPTION) << "Dimension of HashMap must be 2, (n, 4)"; |
|
|
|
} |
|
|
|
|
|
|
|
hashmap_length_ = hashmap_shape[0]; |
|
|
|
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); |
|
|
|
} |
|
|
|
@@ -121,7 +119,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
auto output_old_emb_idx = reinterpret_cast<T *>(outputs[1]->addr); |
|
|
|
auto output_miss_emb_idx = reinterpret_cast<T *>(outputs[2]->addr); |
|
|
|
auto output_swap_cache_idx = reinterpret_cast<T *>(outputs[3]->addr); |
|
|
|
|
|
|
|
std::vector<T> miss_idx; |
|
|
|
size_t miss_count = 0; |
|
|
|
float total_count = 0; |
|
|
|
@@ -134,9 +131,7 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
output_cache_idx[i] = -1; |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
T tmp_entry = HashFunc(key, hashmap_length_); |
|
|
|
|
|
|
|
size_t count = 1; |
|
|
|
count_size += 1; |
|
|
|
while ((!hashmap[tmp_entry].IsEmpty() && !hashmap[tmp_entry].IsKey(key))) { |
|
|
|
@@ -147,7 +142,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
} |
|
|
|
count += 1; |
|
|
|
} |
|
|
|
|
|
|
|
total_count += count; |
|
|
|
if (hashmap[tmp_entry].IsEmpty()) { |
|
|
|
miss_idx.emplace_back(i); |
|
|
|
@@ -163,10 +157,8 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
MS_LOG(INFO) << "Miss count: " << miss_count; |
|
|
|
MS_LOG(INFO) << "Avg search count: " << total_count / count_size; |
|
|
|
MS_LOG(INFO) << "Cache hit rate: " << hit_count / count_size; |
|
|
|
|
|
|
|
float total_insert_count = 0; |
|
|
|
float total_delete_count = 0; |
|
|
|
|
|
|
|
// swap hash map |
|
|
|
for (size_t i = 0; i < miss_count; ++i) { |
|
|
|
T emb_idx = output_miss_emb_idx[i]; |
|
|
|
@@ -180,11 +172,9 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
} |
|
|
|
tag_count++; |
|
|
|
} |
|
|
|
|
|
|
|
hashmap[entry].key = emb_idx; |
|
|
|
hashmap[entry].step = step_[0]; |
|
|
|
hashmap[entry].tag = tag_count; |
|
|
|
|
|
|
|
T tmp_entry = (entry + 1) % hashmap_length_; |
|
|
|
size_t delete_count = 1; |
|
|
|
while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) { |
|
|
|
@@ -195,7 +185,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
} |
|
|
|
delete_count++; |
|
|
|
} |
|
|
|
|
|
|
|
output_swap_cache_idx[i] = hashmap[tmp_entry].value; |
|
|
|
output_old_emb_idx[i] = hashmap[tmp_entry].key; |
|
|
|
hashmap[entry].value = output_swap_cache_idx[i]; |
|
|
|
@@ -204,19 +193,15 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
total_delete_count += (compress_count + delete_count); |
|
|
|
total_insert_count += tag_count; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << "Insert count: " << total_insert_count / miss_count; |
|
|
|
MS_LOG(INFO) << "Delete count: " << total_delete_count / miss_count; |
|
|
|
|
|
|
|
// update step |
|
|
|
step_[0] += 1; |
|
|
|
|
|
|
|
// update cache idx |
|
|
|
for (size_t i = 0; i < miss_count; ++i) { |
|
|
|
int idx = miss_idx[i]; |
|
|
|
output_cache_idx[idx] = output_swap_cache_idx[i]; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<size_t> out_shape; |
|
|
|
out_shape.emplace_back(miss_count); |
|
|
|
std::vector<TypeId> dtypes; |
|
|
|
|