From: @fangzehua Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -35,7 +35,6 @@ void AssignCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| } | } | ||||
| } | } | ||||
| input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | ||||
| if (input_x_dtype_ == kNumberTypeFloat32 || input_x_dtype_ == kNumberTypeInt32) { | if (input_x_dtype_ == kNumberTypeFloat32 || input_x_dtype_ == kNumberTypeInt32) { | ||||
| input_x_dtype_size_ = 4; | input_x_dtype_size_ = 4; | ||||
| } else if (input_x_dtype_ == kNumberTypeFloat64 || input_x_dtype_ == kNumberTypeInt64) { | } else if (input_x_dtype_ == kNumberTypeFloat64 || input_x_dtype_ == kNumberTypeInt64) { | ||||
| @@ -75,6 +74,5 @@ void AssignCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; | MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; | ||||
| } | } | ||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -60,7 +60,6 @@ MS_REG_CPU_KERNEL( | |||||
| Assign, | Assign, | ||||
| KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | ||||
| AssignCPUKernel); | AssignCPUKernel); | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -20,7 +20,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| template <typename T> | template <typename T> | ||||
| void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) { | void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) { | ||||
| T i = (entry + 1) % length, off = 1; | T i = (entry + 1) % length, off = 1; | ||||
| @@ -107,6 +106,5 @@ void CacheSwapHashmapCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inpu | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -25,7 +25,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| class CacheSwapHashmapCPUKernel : public CPUKernel { | class CacheSwapHashmapCPUKernel : public CPUKernel { | ||||
| public: | public: | ||||
| CacheSwapHashmapCPUKernel() = default; | CacheSwapHashmapCPUKernel() = default; | ||||
| @@ -82,7 +81,6 @@ MS_REG_CPU_KERNEL(CacheSwapHashmap, | |||||
| .AddOutputAttr(kNumberTypeInt32) | .AddOutputAttr(kNumberTypeInt32) | ||||
| .AddOutputAttr(kNumberTypeInt32), | .AddOutputAttr(kNumberTypeInt32), | ||||
| CacheSwapHashmapCPUKernel); | CacheSwapHashmapCPUKernel); | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -22,7 +22,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| template <typename T> | template <typename T> | ||||
| struct HashmapEntry { | struct HashmapEntry { | ||||
| T key; | T key; | ||||
| @@ -60,8 +59,9 @@ T HashFunc(const T &key, const size_t &m) { | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) { | |||||
| int Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) { | |||||
| T i = (entry + 1) % length, off = 1; | T i = (entry + 1) % length, off = 1; | ||||
| int compress_count = 0; | |||||
| for (; !entry_p[i].IsEmpty(); i = (i + 1) % length, off++) { | for (; !entry_p[i].IsEmpty(); i = (i + 1) % length, off++) { | ||||
| if (entry_p[i].tag > off) { | if (entry_p[i].tag > off) { | ||||
| entry_p[entry].key = entry_p[i].key; | entry_p[entry].key = entry_p[i].key; | ||||
| @@ -72,21 +72,20 @@ void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) { | |||||
| off = 0; | off = 0; | ||||
| entry = i; | entry = i; | ||||
| } | } | ||||
| compress_count++; | |||||
| } | } | ||||
| return compress_count; | |||||
| } | } | ||||
| void MapCacheIdxCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void MapCacheIdxCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| node_ = kernel_node; | |||||
| auto hashmap_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto hashmap_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||||
| if (hashmap_shape.size() != 2) { | if (hashmap_shape.size() != 2) { | ||||
| MS_LOG(EXCEPTION) << "Dimension of HashMap must be 2, (n, 4)"; | MS_LOG(EXCEPTION) << "Dimension of HashMap must be 2, (n, 4)"; | ||||
| } | } | ||||
| for (size_t i = 0; i < emb_idx_shape.size(); ++i) { | |||||
| batch_size_ *= emb_idx_shape[i]; | |||||
| } | |||||
| hashmap_length_ = hashmap_shape[0]; | hashmap_length_ = hashmap_shape[0]; | ||||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | ||||
| } | } | ||||
| @@ -108,100 +107,124 @@ bool MapCacheIdxCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| template <typename T> | template <typename T> | ||||
| void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1); | |||||
| batch_size_ = 1; | |||||
| for (size_t i = 0; i < emb_idx_shape.size(); ++i) { | |||||
| batch_size_ *= emb_idx_shape[i]; | |||||
| } | |||||
| HashmapEntry<T> *hashmap = reinterpret_cast<HashmapEntry<T> *>(inputs[0]->addr); | HashmapEntry<T> *hashmap = reinterpret_cast<HashmapEntry<T> *>(inputs[0]->addr); | ||||
| auto input_indices = reinterpret_cast<T *>(inputs[1]->addr); | auto input_indices = reinterpret_cast<T *>(inputs[1]->addr); | ||||
| T *step_ = reinterpret_cast<T *>(inputs[2]->addr); | T *step_ = reinterpret_cast<T *>(inputs[2]->addr); | ||||
| T emb_max_num = *reinterpret_cast<T *>(inputs[3]->addr); | T emb_max_num = *reinterpret_cast<T *>(inputs[3]->addr); | ||||
| T cache_max_num = *reinterpret_cast<T *>(inputs[4]->addr); | |||||
| T offset = *reinterpret_cast<T *>(inputs[4]->addr); | |||||
| auto output_cache_idx = reinterpret_cast<T *>(outputs[0]->addr); | auto output_cache_idx = reinterpret_cast<T *>(outputs[0]->addr); | ||||
| auto output_old_emb_idx = reinterpret_cast<T *>(outputs[1]->addr); | auto output_old_emb_idx = reinterpret_cast<T *>(outputs[1]->addr); | ||||
| auto output_miss_emb_idx = reinterpret_cast<T *>(outputs[2]->addr); | auto output_miss_emb_idx = reinterpret_cast<T *>(outputs[2]->addr); | ||||
| auto output_swap_cache_idx = reinterpret_cast<T *>(outputs[3]->addr); | auto output_swap_cache_idx = reinterpret_cast<T *>(outputs[3]->addr); | ||||
| std::vector<T> output_miss_idx(batch_size_, -1); | |||||
| std::vector<T> miss_idx; | |||||
| size_t miss_count = 0; | |||||
| float total_count = 0; | float total_count = 0; | ||||
| int count_size = 0; | int count_size = 0; | ||||
| float hit_count = 0; | float hit_count = 0; | ||||
| // search_cache_idx | // search_cache_idx | ||||
| for (size_t i = 0; i < batch_size_; ++i) { | for (size_t i = 0; i < batch_size_; ++i) { | ||||
| if (input_indices[i] == emb_max_num) { | |||||
| output_miss_idx[i] = -1; | |||||
| output_cache_idx[i] = cache_max_num; | |||||
| output_miss_emb_idx[i] = -1; | |||||
| T key = input_indices[i] - offset; | |||||
| if (key >= emb_max_num || key < 0) { | |||||
| output_cache_idx[i] = -1; | |||||
| continue; | continue; | ||||
| } | } | ||||
| T key = input_indices[i]; | |||||
| T tmp_entry = HashFunc(key, hashmap_length_); | T tmp_entry = HashFunc(key, hashmap_length_); | ||||
| int count = 1; | |||||
| size_t count = 1; | |||||
| count_size += 1; | count_size += 1; | ||||
| while ((!hashmap[tmp_entry].IsEmpty() && !hashmap[tmp_entry].IsKey(key))) { | while ((!hashmap[tmp_entry].IsEmpty() && !hashmap[tmp_entry].IsKey(key))) { | ||||
| tmp_entry = (tmp_entry + 1) % hashmap_length_; | tmp_entry = (tmp_entry + 1) % hashmap_length_; | ||||
| if (count > hashmap_length_) { | |||||
| MS_LOG(ERROR) << "Hashmap is full, search cache idx failed!"; | |||||
| break; | |||||
| } | |||||
| count += 1; | count += 1; | ||||
| } | } | ||||
| total_count += count; | total_count += count; | ||||
| if (hashmap[tmp_entry].IsEmpty()) { | if (hashmap[tmp_entry].IsEmpty()) { | ||||
| output_miss_idx[i] = i; | |||||
| output_miss_emb_idx[i] = key; | |||||
| miss_idx.emplace_back(i); | |||||
| output_miss_emb_idx[miss_count] = key; | |||||
| output_cache_idx[i] = -1; | output_cache_idx[i] = -1; | ||||
| miss_count++; | |||||
| } else { | } else { | ||||
| hit_count += 1; | hit_count += 1; | ||||
| output_miss_idx[i] = -1; | |||||
| output_cache_idx[i] = hashmap[tmp_entry].value; | output_cache_idx[i] = hashmap[tmp_entry].value; | ||||
| hashmap[tmp_entry].step = step_[0]; | hashmap[tmp_entry].step = step_[0]; | ||||
| output_miss_emb_idx[i] = -1; | |||||
| } | } | ||||
| } | } | ||||
| MS_LOG(INFO) << "avg search count: " << total_count / count_size; | |||||
| MS_LOG(INFO) << "cache hit rate: " << hit_count / count_size; | |||||
| MS_LOG(INFO) << "Miss count: " << miss_count; | |||||
| MS_LOG(INFO) << "Avg search count: " << total_count / count_size; | |||||
| MS_LOG(INFO) << "Cache hit rate: " << hit_count / count_size; | |||||
| float total_insert_count = 0; | |||||
| float total_delete_count = 0; | |||||
| // swap hash map | // swap hash map | ||||
| for (size_t i = 0; i < batch_size_; ++i) { | |||||
| if (output_miss_emb_idx[i] < 0) { | |||||
| output_swap_cache_idx[i] = -1; | |||||
| output_old_emb_idx[i] = -1; | |||||
| } else { | |||||
| T emb_idx = output_miss_emb_idx[i]; | |||||
| T entry = HashFunc(emb_idx, hashmap_length_); | |||||
| T tag_count = 1; | |||||
| while (!hashmap[entry].IsEmpty()) { | |||||
| entry = (entry + 1) % hashmap_length_; | |||||
| tag_count++; | |||||
| for (size_t i = 0; i < miss_count; ++i) { | |||||
| T emb_idx = output_miss_emb_idx[i]; | |||||
| T entry = HashFunc(emb_idx, hashmap_length_); | |||||
| size_t tag_count = 1; | |||||
| while (!hashmap[entry].IsEmpty()) { | |||||
| entry = (entry + 1) % hashmap_length_; | |||||
| if (tag_count > hashmap_length_) { | |||||
| MS_LOG(ERROR) << "Hashmap is full, insert new key failed!"; | |||||
| break; | |||||
| } | } | ||||
| tag_count++; | |||||
| } | |||||
| hashmap[entry].key = emb_idx; | |||||
| hashmap[entry].step = step_[0]; | |||||
| hashmap[entry].tag = tag_count; | |||||
| T tmp_entry = (entry + 1) % hashmap_length_; | |||||
| hashmap[entry].key = emb_idx; | |||||
| hashmap[entry].step = step_[0]; | |||||
| hashmap[entry].tag = tag_count; | |||||
| while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) { | |||||
| tmp_entry = (tmp_entry + 1) % hashmap_length_; | |||||
| T tmp_entry = (entry + 1) % hashmap_length_; | |||||
| size_t delete_count = 1; | |||||
| while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) { | |||||
| tmp_entry = (tmp_entry + 1) % hashmap_length_; | |||||
| if (delete_count > hashmap_length_) { | |||||
| MS_LOG(ERROR) << "Hashmap is full, delete old key failed!"; | |||||
| break; | |||||
| } | } | ||||
| output_swap_cache_idx[i] = hashmap[tmp_entry].value; | |||||
| output_old_emb_idx[i] = hashmap[tmp_entry].key; | |||||
| hashmap[entry].value = output_swap_cache_idx[i]; | |||||
| hashmap[tmp_entry].SetEmpty(); | |||||
| Compress(hashmap, hashmap_length_, tmp_entry); | |||||
| delete_count++; | |||||
| } | } | ||||
| output_swap_cache_idx[i] = hashmap[tmp_entry].value; | |||||
| output_old_emb_idx[i] = hashmap[tmp_entry].key; | |||||
| hashmap[entry].value = output_swap_cache_idx[i]; | |||||
| hashmap[tmp_entry].SetEmpty(); | |||||
| int compress_count = Compress(hashmap, hashmap_length_, tmp_entry); | |||||
| total_delete_count += (compress_count + delete_count); | |||||
| total_insert_count += tag_count; | |||||
| } | } | ||||
| MS_LOG(INFO) << "Insert count: " << total_insert_count / miss_count; | |||||
| MS_LOG(INFO) << "Delete count: " << total_delete_count / miss_count; | |||||
| // update step | // update step | ||||
| step_[0] += 1; | step_[0] += 1; | ||||
| // update cache idx | // update cache idx | ||||
| for (size_t i = 0; i < batch_size_; ++i) { | |||||
| if (output_miss_idx[i] < 0 || output_miss_idx[i] >= cache_max_num) { | |||||
| continue; | |||||
| } | |||||
| output_cache_idx[i] = output_swap_cache_idx[i]; | |||||
| for (size_t i = 0; i < miss_count; ++i) { | |||||
| int idx = miss_idx[i]; | |||||
| output_cache_idx[idx] = output_swap_cache_idx[i]; | |||||
| } | } | ||||
| } | |||||
| std::vector<size_t> out_shape; | |||||
| out_shape.emplace_back(miss_count); | |||||
| std::vector<TypeId> dtypes; | |||||
| for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node_); i++) { | |||||
| dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i)); | |||||
| } | |||||
| AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape}, | |||||
| node_.get()); | |||||
| } | |||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -27,7 +27,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| class MapCacheIdxCPUKernel : public CPUKernel { | class MapCacheIdxCPUKernel : public CPUKernel { | ||||
| public: | public: | ||||
| MapCacheIdxCPUKernel() = default; | MapCacheIdxCPUKernel() = default; | ||||
| @@ -45,6 +44,7 @@ class MapCacheIdxCPUKernel : public CPUKernel { | |||||
| size_t batch_size_{1}; | size_t batch_size_{1}; | ||||
| size_t hashmap_length_{1}; | size_t hashmap_length_{1}; | ||||
| TypeId dtype_{kTypeUnknown}; | TypeId dtype_{kTypeUnknown}; | ||||
| CNodePtr node_ = nullptr; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL(MapCacheIdx, | MS_REG_CPU_KERNEL(MapCacheIdx, | ||||
| @@ -98,7 +98,6 @@ MS_REG_CPU_KERNEL(MapCacheIdx, | |||||
| .AddOutputAttr(kNumberTypeInt32) | .AddOutputAttr(kNumberTypeInt32) | ||||
| .AddOutputAttr(kNumberTypeInt32), | .AddOutputAttr(kNumberTypeInt32), | ||||
| MapCacheIdxCPUKernel); | MapCacheIdxCPUKernel); | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -99,6 +99,5 @@ void SearchCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs | |||||
| MS_LOG(INFO) << "avg search count: " << total_count / count_size; | MS_LOG(INFO) << "avg search count: " << total_count / count_size; | ||||
| MS_LOG(INFO) << "cache hit rate: " << hit_count / count_size; | MS_LOG(INFO) << "cache hit rate: " << hit_count / count_size; | ||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -27,7 +27,6 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| template <typename T> | template <typename T> | ||||
| struct HashmapEntry { | struct HashmapEntry { | ||||
| T key; | T key; | ||||
| @@ -133,7 +132,6 @@ MS_REG_CPU_KERNEL(SearchCacheIdx, | |||||
| .AddOutputAttr(kNumberTypeInt32) | .AddOutputAttr(kNumberTypeInt32) | ||||
| .AddOutputAttr(kNumberTypeInt32), | .AddOutputAttr(kNumberTypeInt32), | ||||
| SearchCacheIdxCPUKernel); | SearchCacheIdxCPUKernel); | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -21,20 +21,9 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| void UpdateCacheCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void UpdateCacheCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||||
| auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); | |||||
| if (indices_shape.size() < 2) { | |||||
| MS_LOG(EXCEPTION) << "indices shape less than 2"; | |||||
| } | |||||
| for (size_t i = 0; i < indices_shape.size(); ++i) { | |||||
| batch_size_ *= indices_shape[i]; | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||||
| node_ = kernel_node; | |||||
| for (size_t i = 0; i < update_shape.size(); ++i) { | |||||
| update_size_ *= update_shape[i]; | |||||
| } | |||||
| update_length_ = update_size_ / batch_size_; | |||||
| input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | ||||
| indices_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1); | indices_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1); | ||||
| @@ -64,6 +53,19 @@ bool UpdateCacheCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| template <typename T> | template <typename T> | ||||
| void UpdateCacheCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | void UpdateCacheCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1); | |||||
| auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 2); | |||||
| batch_size_ = 1; | |||||
| for (size_t i = 0; i < indices_shape.size(); ++i) { | |||||
| batch_size_ *= indices_shape[i]; | |||||
| } | |||||
| MS_LOG(INFO) << "UpdateCache batch_size:" << batch_size_; | |||||
| update_size_ = 1; | |||||
| for (size_t i = 0; i < update_shape.size(); ++i) { | |||||
| update_size_ *= update_shape[i]; | |||||
| } | |||||
| update_length_ = update_shape[1]; | |||||
| char *input_x = reinterpret_cast<char *>(inputs[0]->addr); | char *input_x = reinterpret_cast<char *>(inputs[0]->addr); | ||||
| T *indices = reinterpret_cast<T *>(inputs[1]->addr); | T *indices = reinterpret_cast<T *>(inputs[1]->addr); | ||||
| char *update = reinterpret_cast<char *>(inputs[2]->addr); | char *update = reinterpret_cast<char *>(inputs[2]->addr); | ||||
| @@ -80,6 +82,5 @@ void UpdateCacheCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -46,6 +46,7 @@ class UpdateCacheCPUKernel : public CPUKernel { | |||||
| TypeId input_x_dtype_{kTypeUnknown}; | TypeId input_x_dtype_{kTypeUnknown}; | ||||
| TypeId indices_dtype_{kTypeUnknown}; | TypeId indices_dtype_{kTypeUnknown}; | ||||
| size_t input_x_dtype_size_ = 4; | size_t input_x_dtype_size_ = 4; | ||||
| CNodePtr node_ = nullptr; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL(UpdateCache, | MS_REG_CPU_KERNEL(UpdateCache, | ||||
| @@ -101,7 +102,6 @@ MS_REG_CPU_KERNEL(UpdateCache, | |||||
| .AddInputAttr(kNumberTypeInt64) | .AddInputAttr(kNumberTypeInt64) | ||||
| .AddOutputAttr(kNumberTypeInt64), | .AddOutputAttr(kNumberTypeInt64), | ||||
| UpdateCacheCPUKernel); | UpdateCacheCPUKernel); | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -201,7 +201,12 @@ AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &prim | |||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list); | |||||
| AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list); | const AbstractBasePtrList &args_spec_list); | ||||
| AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| @@ -273,6 +273,99 @@ AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const Primitiv | |||||
| return std::make_shared<AbstractTensor>(x->element(), x->shape()); | return std::make_shared<AbstractTensor>(x->element(), x->shape()); | ||||
| } | } | ||||
| AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 5); | |||||
| auto hash_map = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||||
| MS_EXCEPTION_IF_NULL(hash_map); | |||||
| MS_EXCEPTION_IF_NULL(hash_map->shape()); | |||||
| auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||||
| auto indices_shp = indices->shape(); | |||||
| MS_EXCEPTION_IF_NULL(indices); | |||||
| MS_EXCEPTION_IF_NULL(indices_shp); | |||||
| ShapeVector shape; | |||||
| ShapeVector min_shape; | |||||
| ShapeVector max_shape; | |||||
| if (!indices_shp->max_shape().empty()) { | |||||
| max_shape = indices_shp->max_shape(); | |||||
| } else { | |||||
| max_shape = indices_shp->shape(); | |||||
| } | |||||
| for (size_t i = 0; i < max_shape.size(); i++) { | |||||
| shape.emplace_back(Shape::SHP_ANY); | |||||
| min_shape.emplace_back(1); | |||||
| } | |||||
| auto cache_idx = std::make_shared<AbstractTensor>(hash_map->element(), indices->shape()); | |||||
| auto old_emb_idx = | |||||
| std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||||
| auto miss_emb_idx = | |||||
| std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||||
| auto swap_emb_idx = | |||||
| std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||||
| AbstractBasePtrList elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx}; | |||||
| return std::make_shared<AbstractTuple>(elements); | |||||
| } | |||||
| AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string op_name = primitive->name(); | |||||
| CheckArgsSize(op_name, args_spec_list, 3); | |||||
| auto cache_table = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||||
| auto cache_table_shp = cache_table->shape(); | |||||
| MS_EXCEPTION_IF_NULL(cache_table); | |||||
| MS_EXCEPTION_IF_NULL(cache_table_shp); | |||||
| auto swap_cache_idx = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||||
| auto swap_cache_idx_shp = swap_cache_idx->shape(); | |||||
| MS_EXCEPTION_IF_NULL(swap_cache_idx); | |||||
| MS_EXCEPTION_IF_NULL(swap_cache_idx_shp); | |||||
| auto cache_table_shape = cache_table_shp->shape(); | |||||
| auto swap_cache_idx_shape = swap_cache_idx_shp->shape(); | |||||
| ShapeVector shape; | |||||
| shape.emplace_back(swap_cache_idx_shape[0]); | |||||
| shape.emplace_back(cache_table_shape[1]); | |||||
| auto swap_cache_idx_max_shape = swap_cache_idx_shp->max_shape(); | |||||
| ShapeVector max_shape; | |||||
| ShapeVector min_shape; | |||||
| if (!swap_cache_idx_max_shape.empty()) { | |||||
| max_shape.emplace_back(swap_cache_idx_max_shape[0]); | |||||
| max_shape.emplace_back(cache_table_shape[1]); | |||||
| } else { | |||||
| max_shape = shape; | |||||
| } | |||||
| for (size_t i = 0; i < max_shape.size(); ++i) { | |||||
| min_shape.emplace_back(1); | |||||
| } | |||||
| AbstractTensorPtr ret = | |||||
| std::make_shared<AbstractTensor>(cache_table->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||||
| return ret; | |||||
| } | |||||
| AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||||
| const AbstractBasePtrList &args_spec_list) { | |||||
| const std::string op_name = primitive->name(); | |||||
| auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||||
| MS_EXCEPTION_IF_NULL(input_x); | |||||
| MS_EXCEPTION_IF_NULL(input_x->shape()); | |||||
| auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||||
| MS_EXCEPTION_IF_NULL(indices); | |||||
| MS_EXCEPTION_IF_NULL(indices->shape()); | |||||
| ShapeVector shape; | |||||
| shape.emplace_back(1); | |||||
| AbstractTensorPtr ret = std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape)); | |||||
| return ret; | |||||
| } | |||||
| AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | ||||
| const AbstractBasePtrList &args_spec_list) { | const AbstractBasePtrList &args_spec_list) { | ||||
| const std::string op_name = primitive->name(); | const std::string op_name = primitive->name(); | ||||
| @@ -57,6 +57,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||||
| {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, | {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, | ||||
| {prim::kPrimScatterAdd, {InferImplScatterAdd, true}}, | {prim::kPrimScatterAdd, {InferImplScatterAdd, true}}, | ||||
| {prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}}, | {prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}}, | ||||
| {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}}, | |||||
| {prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, true}}, | |||||
| {prim::kPrimUpdateCache, {InferImplUpdateCache, true}}, | |||||
| {prim::kPrimDiv, {InferImplDiv, true}}, | {prim::kPrimDiv, {InferImplDiv, true}}, | ||||
| {prim::kPrimRealDiv, {InferImplRealDiv, true}}, | {prim::kPrimRealDiv, {InferImplRealDiv, true}}, | ||||
| {prim::kPrimShape, {InferImplShape, false}}, | {prim::kPrimShape, {InferImplShape, false}}, | ||||
| @@ -98,6 +98,9 @@ inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>( | |||||
| inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin"); | inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin"); | ||||
| inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset"); | inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset"); | ||||
| inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); | inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); | ||||
| inline const PrimitivePtr kPrimMapCacheIdx = std::make_shared<Primitive>("MapCacheIdx"); | |||||
| inline const PrimitivePtr kPrimUpdateCache = std::make_shared<Primitive>("UpdateCache"); | |||||
| inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("CacheSwapTable"); | |||||
| inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | ||||
| inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); | inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); | ||||
| inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2"); | inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2"); | ||||
| @@ -15,11 +15,11 @@ | |||||
| """cache_ops""" | """cache_ops""" | ||||
| from ..._checkparam import Validator as validator | from ..._checkparam import Validator as validator | ||||
| from ...common import dtype as mstype | from ...common import dtype as mstype | ||||
| from ..primitive import PrimitiveWithInfer, prim_attr_register | |||||
| from ..primitive import PrimitiveWithInfer, prim_attr_register, PrimitiveWithCheck | |||||
| from .. import signature as sig | from .. import signature as sig | ||||
| class UpdateCache(PrimitiveWithInfer): | |||||
| class UpdateCache(PrimitiveWithCheck): | |||||
| """ | """ | ||||
| Update the value fo input_x, similar to ScatterNdUpdate. | Update the value fo input_x, similar to ScatterNdUpdate. | ||||
| The diffirent is that UpdateCache will not update when indices < 0 or indices >= max_num. | The diffirent is that UpdateCache will not update when indices < 0 or indices >= max_num. | ||||
| @@ -47,15 +47,12 @@ class UpdateCache(PrimitiveWithInfer): | |||||
| self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'], | self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'], | ||||
| outputs=['out']) | outputs=['out']) | ||||
| def infer_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape): | |||||
| if len(indices_shape) < 2: | |||||
| raise ValueError("The dimension of 'indices' in UpdateCache must >= 2, " | |||||
| "but got %d." % len(indices_shape)) | |||||
| def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape): | |||||
| return [1] | return [1] | ||||
| def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype): | |||||
| validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name) | |||||
| def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype): | |||||
| validator.check_tensor_dtype_valid( | |||||
| "indices", indices_dtype, mstype.int_type, self.name) | |||||
| return input_x_dtype | return input_x_dtype | ||||
| @@ -139,7 +136,8 @@ class SearchCacheIdx(PrimitiveWithInfer): | |||||
| def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype): | def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype): | ||||
| args = {"hashmap": hashmap_dtype, "indices": indices_dtype} | args = {"hashmap": hashmap_dtype, "indices": indices_dtype} | ||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name) | |||||
| validator.check_tensors_dtypes_same_and_valid( | |||||
| args, mstype.int_type, self.name) | |||||
| out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype) | out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype) | ||||
| return out_dtype | return out_dtype | ||||
| @@ -172,7 +170,6 @@ class CacheSwapHashmap(PrimitiveWithInfer): | |||||
| outputs=['swap_cache_idx', 'old_emb_idx']) | outputs=['swap_cache_idx', 'old_emb_idx']) | ||||
| def infer_shape(self, hashmap_shape, miss_emb_idx_shape, step_shape): | def infer_shape(self, hashmap_shape, miss_emb_idx_shape, step_shape): | ||||
| if len(hashmap_shape) != 2: | if len(hashmap_shape) != 2: | ||||
| raise ValueError("The dimension of 'hashmap' in CacheSwapHashmap must be 2, " | raise ValueError("The dimension of 'hashmap' in CacheSwapHashmap must be 2, " | ||||
| "but got %d." % len(hashmap_shape)) | "but got %d." % len(hashmap_shape)) | ||||
| @@ -181,12 +178,13 @@ class CacheSwapHashmap(PrimitiveWithInfer): | |||||
| return out_shape | return out_shape | ||||
| def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype): | def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype): | ||||
| validator.check_tensor_dtype_valid("miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name) | |||||
| validator.check_tensor_dtype_valid( | |||||
| "miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name) | |||||
| out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype) | out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype) | ||||
| return out_dtype | return out_dtype | ||||
| class CacheSwapTable(PrimitiveWithInfer): | |||||
| class CacheSwapTable(PrimitiveWithCheck): | |||||
| """ | """ | ||||
| Delete a hashmap entry,and insert a new key to hashmap, return the key and value of delete entry. | Delete a hashmap entry,and insert a new key to hashmap, return the key and value of delete entry. | ||||
| @@ -212,21 +210,20 @@ class CacheSwapTable(PrimitiveWithInfer): | |||||
| self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'], | self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'], | ||||
| outputs=['old_value']) | outputs=['old_value']) | ||||
| def infer_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape): | |||||
| def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape): | |||||
| if len(cache_table_shape) != 2: | if len(cache_table_shape) != 2: | ||||
| raise ValueError( | raise ValueError( | ||||
| "cache table shape must be 2, but got %d" % len(cache_table_shape)) | "cache table shape must be 2, but got %d" % len(cache_table_shape)) | ||||
| if swap_cache_idx_shape + cache_table_shape[1:] != miss_value_shape: | |||||
| raise ValueError( | |||||
| "swap_cache_idx_shape + cache_table_shape[1:] must equal to miss_value_shape") | |||||
| return miss_value_shape | return miss_value_shape | ||||
| def infer_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype): | |||||
| validator.check_tensor_dtype_valid("swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name) | |||||
| def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype): | |||||
| validator.check_tensor_dtype_valid( | |||||
| "swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name) | |||||
| return miss_value_dtype | return miss_value_dtype | ||||
| class MapCacheIdx(PrimitiveWithInfer): | |||||
| class MapCacheIdx(PrimitiveWithCheck): | |||||
| """ | """ | ||||
| MapCacheIdx merge SearchCacheIdx, CacheSwapHashmap, UpdateCache together. | MapCacheIdx merge SearchCacheIdx, CacheSwapHashmap, UpdateCache together. | ||||
| When input an indices tensor, it will output the cache indices which search in hashmap. | When input an indices tensor, it will output the cache indices which search in hashmap. | ||||
| @@ -244,21 +241,34 @@ class MapCacheIdx(PrimitiveWithInfer): | |||||
| def __init__(self): | def __init__(self): | ||||
| """init MapCacheIdx""" | """init MapCacheIdx""" | ||||
| self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'cache_max_num'], | |||||
| self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'], | |||||
| outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx']) | outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx']) | ||||
| def infer_shape(self, hashmap_shape, indices_shape, step_shape, emb_max_num_shape, cache_max_num_shape): | |||||
| def __check__(self, hashmap, indices, step, emb_max_num, offset): | |||||
| hashmap_shape = hashmap['shape'] | |||||
| if len(hashmap_shape) != 2: | if len(hashmap_shape) != 2: | ||||
| raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, " | raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, " | ||||
| "but got %d." % len(hashmap_shape)) | "but got %d." % len(hashmap_shape)) | ||||
| out_shape = (indices_shape, indices_shape, | |||||
| indices_shape, indices_shape) | |||||
| return out_shape | |||||
| out_shape = (indices['shape'], -1, -1, -1) | |||||
| def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype): | |||||
| hashmap_dtype = hashmap['dtype'] | |||||
| indices_dtype = indices['dtype'] | |||||
| args = {"hashmap": hashmap_dtype, "indices": indices_dtype} | args = {"hashmap": hashmap_dtype, "indices": indices_dtype} | ||||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name) | |||||
| validator.check_tensor_type_same(args, mstype.int_type, self.name) | |||||
| out_dtype = (hashmap_dtype, hashmap_dtype, | out_dtype = (hashmap_dtype, hashmap_dtype, | ||||
| hashmap_dtype, hashmap_dtype) | hashmap_dtype, hashmap_dtype) | ||||
| return out_dtype | |||||
| out = {'shape': out_shape, | |||||
| 'dtype': out_dtype, | |||||
| 'value': None} | |||||
| if 'max_shape' in indices: | |||||
| out['max_shape'] = (indices['max_shape'], indices['max_shape'], | |||||
| indices['max_shape'], indices['max_shape']) | |||||
| else: | |||||
| out['max_shape'] = (indices['shape'], indices['shape'], | |||||
| indices['shape'], indices['shape']) | |||||
| if 'min_shape' in indices: | |||||
| out['min_shape'] = (indices['min_shape'], 0, 0, 0) | |||||
| else: | |||||
| out['min_shape'] = (0, 0, 0, 0) | |||||
| return out | |||||
| @@ -75,19 +75,6 @@ class CacheSwapHashmapNet(nn.Cell): | |||||
| return self.ops(self.net.hashmap, miss_emb_idx, self.step) | return self.ops(self.net.hashmap, miss_emb_idx, self.step) | ||||
| class MapCacheIdxNet(nn.Cell): | |||||
| def __init__(self, hashmap_np): | |||||
| super().__init__() | |||||
| self.ops = P.MapCacheIdx() | |||||
| self.hashmap = Parameter(Tensor(hashmap_np), name="hashmap") | |||||
| self.emb_max = 25 | |||||
| self.cache_max = 10 | |||||
| self.step = 0 | |||||
| def construct(self, indices): | |||||
| return self.ops(self.hashmap, indices, self.step, self.emb_max, self.cache_max) | |||||
| class UpdateCacheNet(nn.Cell): | class UpdateCacheNet(nn.Cell): | ||||
| def __init__(self, x): | def __init__(self, x): | ||||
| super().__init__() | super().__init__() | ||||
| @@ -165,45 +152,6 @@ def test_cache_swap_hashmap(): | |||||
| np.array(hashmap_np_after_ops, np.int32)) | np.array(hashmap_np_after_ops, np.int32)) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_x86_cpu | |||||
| @pytest.mark.env_onecard | |||||
| def test_map_cache_idx(): | |||||
| hashmap_np = init_hashmap(10) | |||||
| indices_np = np.array([10, 2, 20, 5, 3], np.int32) | |||||
| map_cache_idx = MapCacheIdxNet(hashmap_np) | |||||
| indices = Tensor(indices_np) | |||||
| cache_idx, old_emb_idx, miss_emb_idx, swap_cache_idx = map_cache_idx( | |||||
| indices) | |||||
| expect_cache_idx = [5, 1, 9, 7, 3] | |||||
| expect_old_emb_idx = [-1, -1, 21, 15, -1] | |||||
| expect_miss_emb_idx = [-1, -1, 20, 5, -1] | |||||
| expect_swap_cache_idx = [-1, -1, 9, 7, -1] | |||||
| hashmap_np_after_ops = [[5, 7, 0, 1], | |||||
| [10, 5, 0, 1], | |||||
| [2, 1, 0, 1], | |||||
| [20, 9, 0, 1], | |||||
| [20, 9, 0, 0], | |||||
| [0, 0, 0, 0], | |||||
| [0, 0, 0, 0], | |||||
| [0, 0, 0, 0], | |||||
| [3, 3, 0, 1], | |||||
| [21, 9, -5, 0]] | |||||
| assert np.allclose(cache_idx.asnumpy(), | |||||
| np.array(expect_cache_idx, np.int32)) | |||||
| assert np.allclose(old_emb_idx.asnumpy(), | |||||
| np.array(expect_old_emb_idx, np.int32)) | |||||
| assert np.allclose(miss_emb_idx.asnumpy(), | |||||
| np.array(expect_miss_emb_idx, np.int32)) | |||||
| assert np.allclose(swap_cache_idx.asnumpy(), | |||||
| np.array(expect_swap_cache_idx, np.int32)) | |||||
| assert np.allclose(map_cache_idx.hashmap.data.asnumpy(), | |||||
| np.array(hashmap_np_after_ops, np.int32)) | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_cpu | @pytest.mark.platform_x86_cpu | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||