|
|
|
@@ -42,9 +42,21 @@ int Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) { |
|
|
|
return compress_count; |
|
|
|
} |
|
|
|
|
|
|
|
void UpdateShape(size_t miss_count, const CNodePtr &node_) { |
|
|
|
std::vector<size_t> out_shape; |
|
|
|
out_shape.emplace_back(miss_count); |
|
|
|
std::vector<TypeId> dtypes; |
|
|
|
size_t output_num = AnfAlgo::GetOutputTensorNum(node_); |
|
|
|
for (size_t i = 0; i < output_num; i++) { |
|
|
|
dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i)); |
|
|
|
} |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape}, |
|
|
|
node_.get()); |
|
|
|
} |
|
|
|
|
|
|
|
void MapCacheIdxCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node); |
|
|
|
node_ = kernel_node; |
|
|
|
node_wpt_ = kernel_node; |
|
|
|
auto hashmap_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); |
|
|
|
if (hashmap_shape.size() != 2) { |
|
|
|
MS_LOG(EXCEPTION) << "Dimension of HashMap must be 2, (n, 4)"; |
|
|
|
@@ -73,6 +85,7 @@ bool MapCacheIdxCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, |
|
|
|
template <typename T> |
|
|
|
void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
const std::vector<kernel::AddressPtr> &outputs) { |
|
|
|
auto node_ = node_wpt_.lock(); |
|
|
|
auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1); |
|
|
|
batch_size_ = 1; |
|
|
|
for (size_t i = 0; i < emb_idx_shape.size(); ++i) { |
|
|
|
@@ -92,7 +105,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
float total_count = 0; |
|
|
|
int count_size = 0; |
|
|
|
float hit_count = 0; |
|
|
|
|
|
|
|
// search_cache_idx |
|
|
|
for (size_t i = 0; i < batch_size_; ++i) { |
|
|
|
T key = input_indices[i] - offset; |
|
|
|
@@ -107,7 +119,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
tmp_entry = (tmp_entry + 1) % hashmap_length_; |
|
|
|
if (count > hashmap_length_) { |
|
|
|
MS_LOG(EXCEPTION) << "Hashmap is full, search cache idx failed, please set a larger vocab_cache_size!"; |
|
|
|
break; |
|
|
|
} |
|
|
|
count += 1; |
|
|
|
} |
|
|
|
@@ -130,7 +141,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
MS_LOG(INFO) << "Avg search count: " << total_count / count_size; |
|
|
|
MS_LOG(INFO) << "Cache hit rate: " << hit_count / count_size; |
|
|
|
} |
|
|
|
|
|
|
|
float total_insert_count = 0; |
|
|
|
float total_delete_count = 0; |
|
|
|
// swap hash map |
|
|
|
@@ -142,7 +152,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
entry = (entry + 1) % hashmap_length_; |
|
|
|
if (tag_count > hashmap_length_) { |
|
|
|
MS_LOG(EXCEPTION) << "Hashmap is full, insert new key failed, please set a larger vocab_cache_size!"; |
|
|
|
break; |
|
|
|
} |
|
|
|
tag_count++; |
|
|
|
} |
|
|
|
@@ -155,7 +164,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
tmp_entry = (tmp_entry + 1) % hashmap_length_; |
|
|
|
if (delete_count > hashmap_length_) { |
|
|
|
MS_LOG(EXCEPTION) << "Hashmap is full, delete old key failed, please set a larger vocab_cache_size!"; |
|
|
|
break; |
|
|
|
} |
|
|
|
delete_count++; |
|
|
|
} |
|
|
|
@@ -171,22 +179,11 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, |
|
|
|
MS_LOG(INFO) << "Insert count: " << total_insert_count / miss_count; |
|
|
|
MS_LOG(INFO) << "Delete count: " << total_delete_count / miss_count; |
|
|
|
} |
|
|
|
// update step |
|
|
|
step_[0] += 1; |
|
|
|
// update cache idx |
|
|
|
for (size_t i = 0; i < miss_count; ++i) { |
|
|
|
int idx = miss_idx[i]; |
|
|
|
output_cache_idx[idx] = output_swap_cache_idx[i]; |
|
|
|
output_cache_idx[miss_idx[i]] = output_swap_cache_idx[i]; |
|
|
|
} |
|
|
|
std::vector<size_t> out_shape; |
|
|
|
out_shape.emplace_back(miss_count); |
|
|
|
std::vector<TypeId> dtypes; |
|
|
|
size_t output_num = AnfAlgo::GetOutputTensorNum(node_); |
|
|
|
for (size_t i = 0; i < output_num; i++) { |
|
|
|
dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i)); |
|
|
|
} |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape}, |
|
|
|
node_.get()); |
|
|
|
UpdateShape(miss_count, node_); |
|
|
|
} |
|
|
|
} // namespace kernel |
|
|
|
} // namespace mindspore |