From: @fangzehua Reviewed-by: Signed-off-by:tags/v1.1.0
| @@ -35,7 +35,6 @@ void AssignCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| } | |||
| } | |||
| input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| if (input_x_dtype_ == kNumberTypeFloat32 || input_x_dtype_ == kNumberTypeInt32) { | |||
| input_x_dtype_size_ = 4; | |||
| } else if (input_x_dtype_ == kNumberTypeFloat64 || input_x_dtype_ == kNumberTypeInt64) { | |||
| @@ -75,6 +74,5 @@ void AssignCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -60,7 +60,6 @@ MS_REG_CPU_KERNEL( | |||
| Assign, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| AssignCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -20,7 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) { | |||
| T i = (entry + 1) % length, off = 1; | |||
| @@ -107,6 +106,5 @@ void CacheSwapHashmapCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inpu | |||
| } | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -25,7 +25,6 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class CacheSwapHashmapCPUKernel : public CPUKernel { | |||
| public: | |||
| CacheSwapHashmapCPUKernel() = default; | |||
| @@ -82,7 +81,6 @@ MS_REG_CPU_KERNEL(CacheSwapHashmap, | |||
| .AddOutputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| CacheSwapHashmapCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -22,7 +22,6 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| struct HashmapEntry { | |||
| T key; | |||
| @@ -60,8 +59,9 @@ T HashFunc(const T &key, const size_t &m) { | |||
| } | |||
| template <typename T> | |||
| void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) { | |||
| int Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) { | |||
| T i = (entry + 1) % length, off = 1; | |||
| int compress_count = 0; | |||
| for (; !entry_p[i].IsEmpty(); i = (i + 1) % length, off++) { | |||
| if (entry_p[i].tag > off) { | |||
| entry_p[entry].key = entry_p[i].key; | |||
| @@ -72,21 +72,20 @@ void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) { | |||
| off = 0; | |||
| entry = i; | |||
| } | |||
| compress_count++; | |||
| } | |||
| return compress_count; | |||
| } | |||
| void MapCacheIdxCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| node_ = kernel_node; | |||
| auto hashmap_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| if (hashmap_shape.size() != 2) { | |||
| MS_LOG(EXCEPTION) << "Dimension of HashMap must be 2, (n, 4)"; | |||
| } | |||
| for (size_t i = 0; i < emb_idx_shape.size(); ++i) { | |||
| batch_size_ *= emb_idx_shape[i]; | |||
| } | |||
| hashmap_length_ = hashmap_shape[0]; | |||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| } | |||
| @@ -108,100 +107,124 @@ bool MapCacheIdxCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| template <typename T> | |||
| void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1); | |||
| batch_size_ = 1; | |||
| for (size_t i = 0; i < emb_idx_shape.size(); ++i) { | |||
| batch_size_ *= emb_idx_shape[i]; | |||
| } | |||
| HashmapEntry<T> *hashmap = reinterpret_cast<HashmapEntry<T> *>(inputs[0]->addr); | |||
| auto input_indices = reinterpret_cast<T *>(inputs[1]->addr); | |||
| T *step_ = reinterpret_cast<T *>(inputs[2]->addr); | |||
| T emb_max_num = *reinterpret_cast<T *>(inputs[3]->addr); | |||
| T cache_max_num = *reinterpret_cast<T *>(inputs[4]->addr); | |||
| T offset = *reinterpret_cast<T *>(inputs[4]->addr); | |||
| auto output_cache_idx = reinterpret_cast<T *>(outputs[0]->addr); | |||
| auto output_old_emb_idx = reinterpret_cast<T *>(outputs[1]->addr); | |||
| auto output_miss_emb_idx = reinterpret_cast<T *>(outputs[2]->addr); | |||
| auto output_swap_cache_idx = reinterpret_cast<T *>(outputs[3]->addr); | |||
| std::vector<T> output_miss_idx(batch_size_, -1); | |||
| std::vector<T> miss_idx; | |||
| size_t miss_count = 0; | |||
| float total_count = 0; | |||
| int count_size = 0; | |||
| float hit_count = 0; | |||
| // search_cache_idx | |||
| for (size_t i = 0; i < batch_size_; ++i) { | |||
| if (input_indices[i] == emb_max_num) { | |||
| output_miss_idx[i] = -1; | |||
| output_cache_idx[i] = cache_max_num; | |||
| output_miss_emb_idx[i] = -1; | |||
| T key = input_indices[i] - offset; | |||
| if (key >= emb_max_num || key < 0) { | |||
| output_cache_idx[i] = -1; | |||
| continue; | |||
| } | |||
| T key = input_indices[i]; | |||
| T tmp_entry = HashFunc(key, hashmap_length_); | |||
| int count = 1; | |||
| size_t count = 1; | |||
| count_size += 1; | |||
| while ((!hashmap[tmp_entry].IsEmpty() && !hashmap[tmp_entry].IsKey(key))) { | |||
| tmp_entry = (tmp_entry + 1) % hashmap_length_; | |||
| if (count > hashmap_length_) { | |||
| MS_LOG(ERROR) << "Hashmap is full, search cache idx failed!"; | |||
| break; | |||
| } | |||
| count += 1; | |||
| } | |||
| total_count += count; | |||
| if (hashmap[tmp_entry].IsEmpty()) { | |||
| output_miss_idx[i] = i; | |||
| output_miss_emb_idx[i] = key; | |||
| miss_idx.emplace_back(i); | |||
| output_miss_emb_idx[miss_count] = key; | |||
| output_cache_idx[i] = -1; | |||
| miss_count++; | |||
| } else { | |||
| hit_count += 1; | |||
| output_miss_idx[i] = -1; | |||
| output_cache_idx[i] = hashmap[tmp_entry].value; | |||
| hashmap[tmp_entry].step = step_[0]; | |||
| output_miss_emb_idx[i] = -1; | |||
| } | |||
| } | |||
| MS_LOG(INFO) << "avg search count: " << total_count / count_size; | |||
| MS_LOG(INFO) << "cache hit rate: " << hit_count / count_size; | |||
| MS_LOG(INFO) << "Miss count: " << miss_count; | |||
| MS_LOG(INFO) << "Avg search count: " << total_count / count_size; | |||
| MS_LOG(INFO) << "Cache hit rate: " << hit_count / count_size; | |||
| float total_insert_count = 0; | |||
| float total_delete_count = 0; | |||
| // swap hash map | |||
| for (size_t i = 0; i < batch_size_; ++i) { | |||
| if (output_miss_emb_idx[i] < 0) { | |||
| output_swap_cache_idx[i] = -1; | |||
| output_old_emb_idx[i] = -1; | |||
| } else { | |||
| T emb_idx = output_miss_emb_idx[i]; | |||
| T entry = HashFunc(emb_idx, hashmap_length_); | |||
| T tag_count = 1; | |||
| while (!hashmap[entry].IsEmpty()) { | |||
| entry = (entry + 1) % hashmap_length_; | |||
| tag_count++; | |||
| for (size_t i = 0; i < miss_count; ++i) { | |||
| T emb_idx = output_miss_emb_idx[i]; | |||
| T entry = HashFunc(emb_idx, hashmap_length_); | |||
| size_t tag_count = 1; | |||
| while (!hashmap[entry].IsEmpty()) { | |||
| entry = (entry + 1) % hashmap_length_; | |||
| if (tag_count > hashmap_length_) { | |||
| MS_LOG(ERROR) << "Hashmap is full, insert new key failed!"; | |||
| break; | |||
| } | |||
| tag_count++; | |||
| } | |||
| hashmap[entry].key = emb_idx; | |||
| hashmap[entry].step = step_[0]; | |||
| hashmap[entry].tag = tag_count; | |||
| T tmp_entry = (entry + 1) % hashmap_length_; | |||
| hashmap[entry].key = emb_idx; | |||
| hashmap[entry].step = step_[0]; | |||
| hashmap[entry].tag = tag_count; | |||
| while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) { | |||
| tmp_entry = (tmp_entry + 1) % hashmap_length_; | |||
| T tmp_entry = (entry + 1) % hashmap_length_; | |||
| size_t delete_count = 1; | |||
| while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) { | |||
| tmp_entry = (tmp_entry + 1) % hashmap_length_; | |||
| if (delete_count > hashmap_length_) { | |||
| MS_LOG(ERROR) << "Hashmap is full, delete old key failed!"; | |||
| break; | |||
| } | |||
| output_swap_cache_idx[i] = hashmap[tmp_entry].value; | |||
| output_old_emb_idx[i] = hashmap[tmp_entry].key; | |||
| hashmap[entry].value = output_swap_cache_idx[i]; | |||
| hashmap[tmp_entry].SetEmpty(); | |||
| Compress(hashmap, hashmap_length_, tmp_entry); | |||
| delete_count++; | |||
| } | |||
| output_swap_cache_idx[i] = hashmap[tmp_entry].value; | |||
| output_old_emb_idx[i] = hashmap[tmp_entry].key; | |||
| hashmap[entry].value = output_swap_cache_idx[i]; | |||
| hashmap[tmp_entry].SetEmpty(); | |||
| int compress_count = Compress(hashmap, hashmap_length_, tmp_entry); | |||
| total_delete_count += (compress_count + delete_count); | |||
| total_insert_count += tag_count; | |||
| } | |||
| MS_LOG(INFO) << "Insert count: " << total_insert_count / miss_count; | |||
| MS_LOG(INFO) << "Delete count: " << total_delete_count / miss_count; | |||
| // update step | |||
| step_[0] += 1; | |||
| // update cache idx | |||
| for (size_t i = 0; i < batch_size_; ++i) { | |||
| if (output_miss_idx[i] < 0 || output_miss_idx[i] >= cache_max_num) { | |||
| continue; | |||
| } | |||
| output_cache_idx[i] = output_swap_cache_idx[i]; | |||
| for (size_t i = 0; i < miss_count; ++i) { | |||
| int idx = miss_idx[i]; | |||
| output_cache_idx[idx] = output_swap_cache_idx[i]; | |||
| } | |||
| } | |||
| std::vector<size_t> out_shape; | |||
| out_shape.emplace_back(miss_count); | |||
| std::vector<TypeId> dtypes; | |||
| for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node_); i++) { | |||
| dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i)); | |||
| } | |||
| AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape}, | |||
| node_.get()); | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -27,7 +27,6 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class MapCacheIdxCPUKernel : public CPUKernel { | |||
| public: | |||
| MapCacheIdxCPUKernel() = default; | |||
| @@ -45,6 +44,7 @@ class MapCacheIdxCPUKernel : public CPUKernel { | |||
| size_t batch_size_{1}; | |||
| size_t hashmap_length_{1}; | |||
| TypeId dtype_{kTypeUnknown}; | |||
| CNodePtr node_ = nullptr; | |||
| }; | |||
| MS_REG_CPU_KERNEL(MapCacheIdx, | |||
| @@ -98,7 +98,6 @@ MS_REG_CPU_KERNEL(MapCacheIdx, | |||
| .AddOutputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| MapCacheIdxCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -99,6 +99,5 @@ void SearchCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs | |||
| MS_LOG(INFO) << "avg search count: " << total_count / count_size; | |||
| MS_LOG(INFO) << "cache hit rate: " << hit_count / count_size; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -27,7 +27,6 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| struct HashmapEntry { | |||
| T key; | |||
| @@ -133,7 +132,6 @@ MS_REG_CPU_KERNEL(SearchCacheIdx, | |||
| .AddOutputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| SearchCacheIdxCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -21,20 +21,9 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| void UpdateCacheCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2); | |||
| if (indices_shape.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "indices shape less than 2"; | |||
| } | |||
| for (size_t i = 0; i < indices_shape.size(); ++i) { | |||
| batch_size_ *= indices_shape[i]; | |||
| } | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| node_ = kernel_node; | |||
| for (size_t i = 0; i < update_shape.size(); ++i) { | |||
| update_size_ *= update_shape[i]; | |||
| } | |||
| update_length_ = update_size_ / batch_size_; | |||
| input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | |||
| indices_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1); | |||
| @@ -64,6 +53,19 @@ bool UpdateCacheCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| template <typename T> | |||
| void UpdateCacheCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1); | |||
| auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 2); | |||
| batch_size_ = 1; | |||
| for (size_t i = 0; i < indices_shape.size(); ++i) { | |||
| batch_size_ *= indices_shape[i]; | |||
| } | |||
| MS_LOG(INFO) << "UpdateCache batch_size:" << batch_size_; | |||
| update_size_ = 1; | |||
| for (size_t i = 0; i < update_shape.size(); ++i) { | |||
| update_size_ *= update_shape[i]; | |||
| } | |||
| update_length_ = update_shape[1]; | |||
| char *input_x = reinterpret_cast<char *>(inputs[0]->addr); | |||
| T *indices = reinterpret_cast<T *>(inputs[1]->addr); | |||
| char *update = reinterpret_cast<char *>(inputs[2]->addr); | |||
| @@ -80,6 +82,5 @@ void UpdateCacheCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| } | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -46,6 +46,7 @@ class UpdateCacheCPUKernel : public CPUKernel { | |||
| TypeId input_x_dtype_{kTypeUnknown}; | |||
| TypeId indices_dtype_{kTypeUnknown}; | |||
| size_t input_x_dtype_size_ = 4; | |||
| CNodePtr node_ = nullptr; | |||
| }; | |||
| MS_REG_CPU_KERNEL(UpdateCache, | |||
| @@ -101,7 +102,6 @@ MS_REG_CPU_KERNEL(UpdateCache, | |||
| .AddInputAttr(kNumberTypeInt64) | |||
| .AddOutputAttr(kNumberTypeInt64), | |||
| UpdateCacheCPUKernel); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -201,7 +201,12 @@ AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &prim | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list); | |||
| AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| @@ -273,6 +273,99 @@ AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const Primitiv | |||
| return std::make_shared<AbstractTensor>(x->element(), x->shape()); | |||
| } | |||
| AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 5); | |||
| auto hash_map = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(hash_map); | |||
| MS_EXCEPTION_IF_NULL(hash_map->shape()); | |||
| auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| auto indices_shp = indices->shape(); | |||
| MS_EXCEPTION_IF_NULL(indices); | |||
| MS_EXCEPTION_IF_NULL(indices_shp); | |||
| ShapeVector shape; | |||
| ShapeVector min_shape; | |||
| ShapeVector max_shape; | |||
| if (!indices_shp->max_shape().empty()) { | |||
| max_shape = indices_shp->max_shape(); | |||
| } else { | |||
| max_shape = indices_shp->shape(); | |||
| } | |||
| for (size_t i = 0; i < max_shape.size(); i++) { | |||
| shape.emplace_back(Shape::SHP_ANY); | |||
| min_shape.emplace_back(1); | |||
| } | |||
| auto cache_idx = std::make_shared<AbstractTensor>(hash_map->element(), indices->shape()); | |||
| auto old_emb_idx = | |||
| std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||
| auto miss_emb_idx = | |||
| std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||
| auto swap_emb_idx = | |||
| std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||
| AbstractBasePtrList elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx}; | |||
| return std::make_shared<AbstractTuple>(elements); | |||
| } | |||
| AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| CheckArgsSize(op_name, args_spec_list, 3); | |||
| auto cache_table = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| auto cache_table_shp = cache_table->shape(); | |||
| MS_EXCEPTION_IF_NULL(cache_table); | |||
| MS_EXCEPTION_IF_NULL(cache_table_shp); | |||
| auto swap_cache_idx = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| auto swap_cache_idx_shp = swap_cache_idx->shape(); | |||
| MS_EXCEPTION_IF_NULL(swap_cache_idx); | |||
| MS_EXCEPTION_IF_NULL(swap_cache_idx_shp); | |||
| auto cache_table_shape = cache_table_shp->shape(); | |||
| auto swap_cache_idx_shape = swap_cache_idx_shp->shape(); | |||
| ShapeVector shape; | |||
| shape.emplace_back(swap_cache_idx_shape[0]); | |||
| shape.emplace_back(cache_table_shape[1]); | |||
| auto swap_cache_idx_max_shape = swap_cache_idx_shp->max_shape(); | |||
| ShapeVector max_shape; | |||
| ShapeVector min_shape; | |||
| if (!swap_cache_idx_max_shape.empty()) { | |||
| max_shape.emplace_back(swap_cache_idx_max_shape[0]); | |||
| max_shape.emplace_back(cache_table_shape[1]); | |||
| } else { | |||
| max_shape = shape; | |||
| } | |||
| for (size_t i = 0; i < max_shape.size(); ++i) { | |||
| min_shape.emplace_back(1); | |||
| } | |||
| AbstractTensorPtr ret = | |||
| std::make_shared<AbstractTensor>(cache_table->element(), std::make_shared<Shape>(shape, min_shape, max_shape)); | |||
| return ret; | |||
| } | |||
| AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0); | |||
| MS_EXCEPTION_IF_NULL(input_x); | |||
| MS_EXCEPTION_IF_NULL(input_x->shape()); | |||
| auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); | |||
| MS_EXCEPTION_IF_NULL(indices); | |||
| MS_EXCEPTION_IF_NULL(indices->shape()); | |||
| ShapeVector shape; | |||
| shape.emplace_back(1); | |||
| AbstractTensorPtr ret = std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape)); | |||
| return ret; | |||
| } | |||
| AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, | |||
| const AbstractBasePtrList &args_spec_list) { | |||
| const std::string op_name = primitive->name(); | |||
| @@ -57,6 +57,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { | |||
| {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}}, | |||
| {prim::kPrimScatterAdd, {InferImplScatterAdd, true}}, | |||
| {prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}}, | |||
| {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}}, | |||
| {prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, true}}, | |||
| {prim::kPrimUpdateCache, {InferImplUpdateCache, true}}, | |||
| {prim::kPrimDiv, {InferImplDiv, true}}, | |||
| {prim::kPrimRealDiv, {InferImplRealDiv, true}}, | |||
| {prim::kPrimShape, {InferImplShape, false}}, | |||
| @@ -98,6 +98,9 @@ inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>( | |||
| inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin"); | |||
| inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset"); | |||
| inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape"); | |||
| inline const PrimitivePtr kPrimMapCacheIdx = std::make_shared<Primitive>("MapCacheIdx"); | |||
| inline const PrimitivePtr kPrimUpdateCache = std::make_shared<Primitive>("UpdateCache"); | |||
| inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("CacheSwapTable"); | |||
| inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile"); | |||
| inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN"); | |||
| inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2"); | |||
| @@ -15,11 +15,11 @@ | |||
| """cache_ops""" | |||
| from ..._checkparam import Validator as validator | |||
| from ...common import dtype as mstype | |||
| from ..primitive import PrimitiveWithInfer, prim_attr_register | |||
| from ..primitive import PrimitiveWithInfer, prim_attr_register, PrimitiveWithCheck | |||
| from .. import signature as sig | |||
| class UpdateCache(PrimitiveWithInfer): | |||
| class UpdateCache(PrimitiveWithCheck): | |||
| """ | |||
| Update the value of input_x, similar to ScatterNdUpdate. | |||
| The difference is that UpdateCache will not update when indices < 0 or indices >= max_num. | |||
| @@ -47,15 +47,12 @@ class UpdateCache(PrimitiveWithInfer): | |||
| self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'], | |||
| outputs=['out']) | |||
| def infer_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape): | |||
| if len(indices_shape) < 2: | |||
| raise ValueError("The dimension of 'indices' in UpdateCache must >= 2, " | |||
| "but got %d." % len(indices_shape)) | |||
| def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape): | |||
| return [1] | |||
| def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype): | |||
| validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name) | |||
| def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype): | |||
| validator.check_tensor_dtype_valid( | |||
| "indices", indices_dtype, mstype.int_type, self.name) | |||
| return input_x_dtype | |||
| @@ -139,7 +136,8 @@ class SearchCacheIdx(PrimitiveWithInfer): | |||
| def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype): | |||
| args = {"hashmap": hashmap_dtype, "indices": indices_dtype} | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name) | |||
| validator.check_tensors_dtypes_same_and_valid( | |||
| args, mstype.int_type, self.name) | |||
| out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype) | |||
| return out_dtype | |||
| @@ -172,7 +170,6 @@ class CacheSwapHashmap(PrimitiveWithInfer): | |||
| outputs=['swap_cache_idx', 'old_emb_idx']) | |||
| def infer_shape(self, hashmap_shape, miss_emb_idx_shape, step_shape): | |||
| if len(hashmap_shape) != 2: | |||
| raise ValueError("The dimension of 'hashmap' in CacheSwapHashmap must be 2, " | |||
| "but got %d." % len(hashmap_shape)) | |||
| @@ -181,12 +178,13 @@ class CacheSwapHashmap(PrimitiveWithInfer): | |||
| return out_shape | |||
| def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype): | |||
| validator.check_tensor_dtype_valid("miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name) | |||
| validator.check_tensor_dtype_valid( | |||
| "miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name) | |||
| out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype) | |||
| return out_dtype | |||
| class CacheSwapTable(PrimitiveWithInfer): | |||
| class CacheSwapTable(PrimitiveWithCheck): | |||
| """ | |||
| Delete a hashmap entry, insert a new key into the hashmap, and return the key and value of the deleted entry. | |||
| @@ -212,21 +210,20 @@ class CacheSwapTable(PrimitiveWithInfer): | |||
| self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'], | |||
| outputs=['old_value']) | |||
| def infer_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape): | |||
| def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape): | |||
| if len(cache_table_shape) != 2: | |||
| raise ValueError( | |||
| "cache table shape must be 2, but got %d" % len(cache_table_shape)) | |||
| if swap_cache_idx_shape + cache_table_shape[1:] != miss_value_shape: | |||
| raise ValueError( | |||
| "swap_cache_idx_shape + cache_table_shape[1:] must equal to miss_value_shape") | |||
| return miss_value_shape | |||
| def infer_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype): | |||
| validator.check_tensor_dtype_valid("swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name) | |||
| def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype): | |||
| validator.check_tensor_dtype_valid( | |||
| "swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name) | |||
| return miss_value_dtype | |||
| class MapCacheIdx(PrimitiveWithInfer): | |||
| class MapCacheIdx(PrimitiveWithCheck): | |||
| """ | |||
| MapCacheIdx merges SearchCacheIdx, CacheSwapHashmap and UpdateCache together. | |||
| Given an indices tensor, it outputs the cache indices found by searching the hashmap. | |||
| @@ -244,21 +241,34 @@ class MapCacheIdx(PrimitiveWithInfer): | |||
| def __init__(self): | |||
| """init MapCacheIdx""" | |||
| self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'cache_max_num'], | |||
| self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'], | |||
| outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx']) | |||
| def infer_shape(self, hashmap_shape, indices_shape, step_shape, emb_max_num_shape, cache_max_num_shape): | |||
| def __check__(self, hashmap, indices, step, emb_max_num, offset): | |||
| hashmap_shape = hashmap['shape'] | |||
| if len(hashmap_shape) != 2: | |||
| raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, " | |||
| "but got %d." % len(hashmap_shape)) | |||
| out_shape = (indices_shape, indices_shape, | |||
| indices_shape, indices_shape) | |||
| return out_shape | |||
| out_shape = (indices['shape'], -1, -1, -1) | |||
| def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype): | |||
| hashmap_dtype = hashmap['dtype'] | |||
| indices_dtype = indices['dtype'] | |||
| args = {"hashmap": hashmap_dtype, "indices": indices_dtype} | |||
| validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name) | |||
| validator.check_tensor_type_same(args, mstype.int_type, self.name) | |||
| out_dtype = (hashmap_dtype, hashmap_dtype, | |||
| hashmap_dtype, hashmap_dtype) | |||
| return out_dtype | |||
| out = {'shape': out_shape, | |||
| 'dtype': out_dtype, | |||
| 'value': None} | |||
| if 'max_shape' in indices: | |||
| out['max_shape'] = (indices['max_shape'], indices['max_shape'], | |||
| indices['max_shape'], indices['max_shape']) | |||
| else: | |||
| out['max_shape'] = (indices['shape'], indices['shape'], | |||
| indices['shape'], indices['shape']) | |||
| if 'min_shape' in indices: | |||
| out['min_shape'] = (indices['min_shape'], 0, 0, 0) | |||
| else: | |||
| out['min_shape'] = (0, 0, 0, 0) | |||
| return out | |||
| @@ -75,19 +75,6 @@ class CacheSwapHashmapNet(nn.Cell): | |||
| return self.ops(self.net.hashmap, miss_emb_idx, self.step) | |||
class MapCacheIdxNet(nn.Cell):
    """Cell that applies P.MapCacheIdx to a hashmap Parameter with fixed scalar arguments."""

    def __init__(self, hashmap_np):
        super().__init__()
        # Scalar arguments forwarded unchanged on every call.
        self.step = 0
        self.emb_max = 25
        self.cache_max = 10
        # The hashmap is held as a Parameter built from the given numpy array.
        self.hashmap = Parameter(Tensor(hashmap_np), name="hashmap")
        self.ops = P.MapCacheIdx()

    def construct(self, indices):
        # Argument order: hashmap, indices, step, emb_max, cache_max.
        return self.ops(self.hashmap, indices, self.step, self.emb_max, self.cache_max)
| class UpdateCacheNet(nn.Cell): | |||
| def __init__(self, x): | |||
| super().__init__() | |||
| @@ -165,45 +152,6 @@ def test_cache_swap_hashmap(): | |||
| np.array(hashmap_np_after_ops, np.int32)) | |||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_map_cache_idx():
    """Run MapCacheIdxNet on five indices and compare all four outputs and the updated hashmap."""
    net = MapCacheIdxNet(init_hashmap(10))
    query = Tensor(np.array([10, 2, 20, 5, 3], np.int32))
    cache_idx, old_emb_idx, miss_emb_idx, swap_cache_idx = net(query)

    # Expected hashmap contents after the op has run.
    expected_hashmap = [[5, 7, 0, 1],
                        [10, 5, 0, 1],
                        [2, 1, 0, 1],
                        [20, 9, 0, 1],
                        [20, 9, 0, 0],
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                        [3, 3, 0, 1],
                        [21, 9, -5, 0]]

    # (actual, expected) pairs, checked in the same order as before.
    checks = (
        (cache_idx, [5, 1, 9, 7, 3]),
        (old_emb_idx, [-1, -1, 21, 15, -1]),
        (miss_emb_idx, [-1, -1, 20, 5, -1]),
        (swap_cache_idx, [-1, -1, 9, 7, -1]),
        (net.hashmap.data, expected_hashmap),
    )
    for actual, expected in checks:
        assert np.allclose(actual.asnumpy(), np.array(expected, np.int32))
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||