| @@ -21,7 +21,7 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| template <typename T> | template <typename T> | ||||
| void ConcatCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | void ConcatCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | ||||
| node_ = kernel_node; | |||||
| node_wpt_ = kernel_node; | |||||
| CheckParam(kernel_node); | CheckParam(kernel_node); | ||||
| axis_ = LongToInt(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS)); | axis_ = LongToInt(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS)); | ||||
| @@ -35,6 +35,10 @@ template <typename T> | |||||
| bool ConcatCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, | bool ConcatCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | const std::vector<kernel::AddressPtr> & /*workspace*/, | ||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| auto node_ = node_wpt_.lock(); | |||||
| if (!node_) { | |||||
| MS_LOG(EXCEPTION) << "node_wpt_ is expired."; | |||||
| } | |||||
| size_t input_num = AnfAlgo::GetInputTensorNum(node_); | size_t input_num = AnfAlgo::GetInputTensorNum(node_); | ||||
| std::vector<std::vector<size_t>> input_flat_shape_list; | std::vector<std::vector<size_t>> input_flat_shape_list; | ||||
| for (size_t i = 0; i < input_num; i++) { | for (size_t i = 0; i < input_num; i++) { | ||||
| @@ -36,7 +36,7 @@ class ConcatCPUKernel : public CPUKernel { | |||||
| private: | private: | ||||
| void CheckParam(const CNodePtr &kernel_node); | void CheckParam(const CNodePtr &kernel_node); | ||||
| int axis_ = 0; | int axis_ = 0; | ||||
| CNodePtr node_ = nullptr; | |||||
| CNodeWeakPtr node_wpt_; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL_T( | MS_REG_CPU_KERNEL_T( | ||||
| @@ -20,7 +20,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace kernel { | namespace kernel { | ||||
| void DynamicAssignCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void DynamicAssignCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| node_ = kernel_node; | |||||
| node_wpt_ = kernel_node; | |||||
| input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | ||||
| input_x_dtype_size_ = GetTypeByte(TypeIdToType(input_x_dtype_)); | input_x_dtype_size_ = GetTypeByte(TypeIdToType(input_x_dtype_)); | ||||
| } | } | ||||
| @@ -46,6 +46,10 @@ bool DynamicAssignCPUKernel::Launch(const std::vector<kernel::AddressPtr> &input | |||||
| template <typename T> | template <typename T> | ||||
| void DynamicAssignCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | void DynamicAssignCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| auto node_ = node_wpt_.lock(); | |||||
| if (!node_) { | |||||
| MS_LOG(EXCEPTION) << "node_wpt_ is expired."; | |||||
| } | |||||
| auto input_x_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); | auto input_x_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); | ||||
| auto input_y_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1); | auto input_y_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1); | ||||
| batch_size_ = 1; | batch_size_ = 1; | ||||
| @@ -41,7 +41,7 @@ class DynamicAssignCPUKernel : public CPUKernel { | |||||
| size_t batch_size_{1}; | size_t batch_size_{1}; | ||||
| TypeId input_x_dtype_{kTypeUnknown}; | TypeId input_x_dtype_{kTypeUnknown}; | ||||
| size_t input_x_dtype_size_ = 4; | size_t input_x_dtype_size_ = 4; | ||||
| CNodePtr node_ = nullptr; | |||||
| CNodeWeakPtr node_wpt_; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL( | MS_REG_CPU_KERNEL( | ||||
| @@ -49,7 +49,7 @@ void LookUpTableTask(const float *input_addr, const T *indices_addr, float *outp | |||||
| void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| CheckParam(kernel_node); | CheckParam(kernel_node); | ||||
| node_ = kernel_node; | |||||
| node_wpt_ = kernel_node; | |||||
| std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| if (input_shape.empty()) { | if (input_shape.empty()) { | ||||
| MS_LOG(EXCEPTION) << "param must be at least 1D"; | MS_LOG(EXCEPTION) << "param must be at least 1D"; | ||||
| @@ -73,7 +73,11 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| template <typename T> | template <typename T> | ||||
| void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, | void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| if (node_ != nullptr) { | |||||
| if (!node_wpt_.expired()) { | |||||
| auto node_ = node_wpt_.lock(); | |||||
| if (!node_) { | |||||
| MS_LOG(EXCEPTION) << "node_wpt_ is expired."; | |||||
| } | |||||
| std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); | std::vector<size_t> input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); | ||||
| if (input_shape.empty()) { | if (input_shape.empty()) { | ||||
| MS_LOG(EXCEPTION) << "param must be at least 1D"; | MS_LOG(EXCEPTION) << "param must be at least 1D"; | ||||
| @@ -41,7 +41,7 @@ class EmbeddingLookUpCPUKernel : public CPUKernel { | |||||
| size_t first_dim_size_{1}; | size_t first_dim_size_{1}; | ||||
| size_t outer_dim_size_{1}; | size_t outer_dim_size_{1}; | ||||
| TypeId indices_data_type_{kNumberTypeInt32}; | TypeId indices_data_type_{kNumberTypeInt32}; | ||||
| CNodePtr node_ = nullptr; | |||||
| CNodeWeakPtr node_wpt_; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL( | MS_REG_CPU_KERNEL( | ||||
| @@ -42,9 +42,21 @@ int Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) { | |||||
| return compress_count; | return compress_count; | ||||
| } | } | ||||
| void UpdateShape(size_t miss_count, const CNodePtr &node_) { | |||||
| std::vector<size_t> out_shape; | |||||
| out_shape.emplace_back(miss_count); | |||||
| std::vector<TypeId> dtypes; | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(node_); | |||||
| for (size_t i = 0; i < output_num; i++) { | |||||
| dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i)); | |||||
| } | |||||
| AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape}, | |||||
| node_.get()); | |||||
| } | |||||
| void MapCacheIdxCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void MapCacheIdxCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| node_ = kernel_node; | |||||
| node_wpt_ = kernel_node; | |||||
| auto hashmap_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto hashmap_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| if (hashmap_shape.size() != 2) { | if (hashmap_shape.size() != 2) { | ||||
| MS_LOG(EXCEPTION) << "Dimension of HashMap must be 2, (n, 4)"; | MS_LOG(EXCEPTION) << "Dimension of HashMap must be 2, (n, 4)"; | ||||
| @@ -73,6 +85,7 @@ bool MapCacheIdxCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| template <typename T> | template <typename T> | ||||
| void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| auto node_ = node_wpt_.lock(); | |||||
| auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1); | auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1); | ||||
| batch_size_ = 1; | batch_size_ = 1; | ||||
| for (size_t i = 0; i < emb_idx_shape.size(); ++i) { | for (size_t i = 0; i < emb_idx_shape.size(); ++i) { | ||||
| @@ -92,7 +105,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||||
| float total_count = 0; | float total_count = 0; | ||||
| int count_size = 0; | int count_size = 0; | ||||
| float hit_count = 0; | float hit_count = 0; | ||||
| // search_cache_idx | // search_cache_idx | ||||
| for (size_t i = 0; i < batch_size_; ++i) { | for (size_t i = 0; i < batch_size_; ++i) { | ||||
| T key = input_indices[i] - offset; | T key = input_indices[i] - offset; | ||||
| @@ -107,7 +119,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||||
| tmp_entry = (tmp_entry + 1) % hashmap_length_; | tmp_entry = (tmp_entry + 1) % hashmap_length_; | ||||
| if (count > hashmap_length_) { | if (count > hashmap_length_) { | ||||
| MS_LOG(EXCEPTION) << "Hashmap is full, search cache idx failed, please set a larger vocab_cache_size!"; | MS_LOG(EXCEPTION) << "Hashmap is full, search cache idx failed, please set a larger vocab_cache_size!"; | ||||
| break; | |||||
| } | } | ||||
| count += 1; | count += 1; | ||||
| } | } | ||||
| @@ -130,7 +141,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||||
| MS_LOG(INFO) << "Avg search count: " << total_count / count_size; | MS_LOG(INFO) << "Avg search count: " << total_count / count_size; | ||||
| MS_LOG(INFO) << "Cache hit rate: " << hit_count / count_size; | MS_LOG(INFO) << "Cache hit rate: " << hit_count / count_size; | ||||
| } | } | ||||
| float total_insert_count = 0; | float total_insert_count = 0; | ||||
| float total_delete_count = 0; | float total_delete_count = 0; | ||||
| // swap hash map | // swap hash map | ||||
| @@ -142,7 +152,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||||
| entry = (entry + 1) % hashmap_length_; | entry = (entry + 1) % hashmap_length_; | ||||
| if (tag_count > hashmap_length_) { | if (tag_count > hashmap_length_) { | ||||
| MS_LOG(EXCEPTION) << "Hashmap is full, insert new key failed, please set a larger vocab_cache_size!"; | MS_LOG(EXCEPTION) << "Hashmap is full, insert new key failed, please set a larger vocab_cache_size!"; | ||||
| break; | |||||
| } | } | ||||
| tag_count++; | tag_count++; | ||||
| } | } | ||||
| @@ -155,7 +164,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||||
| tmp_entry = (tmp_entry + 1) % hashmap_length_; | tmp_entry = (tmp_entry + 1) % hashmap_length_; | ||||
| if (delete_count > hashmap_length_) { | if (delete_count > hashmap_length_) { | ||||
| MS_LOG(EXCEPTION) << "Hashmap is full, delete old key failed, please set a larger vocab_cache_size!"; | MS_LOG(EXCEPTION) << "Hashmap is full, delete old key failed, please set a larger vocab_cache_size!"; | ||||
| break; | |||||
| } | } | ||||
| delete_count++; | delete_count++; | ||||
| } | } | ||||
| @@ -171,22 +179,11 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||||
| MS_LOG(INFO) << "Insert count: " << total_insert_count / miss_count; | MS_LOG(INFO) << "Insert count: " << total_insert_count / miss_count; | ||||
| MS_LOG(INFO) << "Delete count: " << total_delete_count / miss_count; | MS_LOG(INFO) << "Delete count: " << total_delete_count / miss_count; | ||||
| } | } | ||||
| // update step | |||||
| step_[0] += 1; | step_[0] += 1; | ||||
| // update cache idx | |||||
| for (size_t i = 0; i < miss_count; ++i) { | for (size_t i = 0; i < miss_count; ++i) { | ||||
| int idx = miss_idx[i]; | |||||
| output_cache_idx[idx] = output_swap_cache_idx[i]; | |||||
| output_cache_idx[miss_idx[i]] = output_swap_cache_idx[i]; | |||||
| } | } | ||||
| std::vector<size_t> out_shape; | |||||
| out_shape.emplace_back(miss_count); | |||||
| std::vector<TypeId> dtypes; | |||||
| size_t output_num = AnfAlgo::GetOutputTensorNum(node_); | |||||
| for (size_t i = 0; i < output_num; i++) { | |||||
| dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i)); | |||||
| } | |||||
| AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape}, | |||||
| node_.get()); | |||||
| UpdateShape(miss_count, node_); | |||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -42,7 +42,7 @@ class MapCacheIdxCPUKernel : public CPUKernel { | |||||
| size_t batch_size_{1}; | size_t batch_size_{1}; | ||||
| size_t hashmap_length_{1}; | size_t hashmap_length_{1}; | ||||
| TypeId dtype_{kTypeUnknown}; | TypeId dtype_{kTypeUnknown}; | ||||
| CNodePtr node_ = nullptr; | |||||
| CNodeWeakPtr node_wpt_; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL(MapCacheIdx, | MS_REG_CPU_KERNEL(MapCacheIdx, | ||||
| @@ -24,7 +24,7 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| void MapUniformCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void MapUniformCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| node_ = kernel_node; | |||||
| node_wpt_ = kernel_node; | |||||
| dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | ||||
| } | } | ||||
| @@ -45,6 +45,10 @@ bool MapUniformCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| template <typename T> | template <typename T> | ||||
| void MapUniformCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | void MapUniformCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| auto node_ = node_wpt_.lock(); | |||||
| if (!node_) { | |||||
| MS_LOG(EXCEPTION) << "node_wpt_ is expired."; | |||||
| } | |||||
| auto input_x_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); | auto input_x_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); | ||||
| batch_size_ = 1; | batch_size_ = 1; | ||||
| for (size_t i = 0; i < input_x_shape.size(); ++i) { | for (size_t i = 0; i < input_x_shape.size(); ++i) { | ||||
| @@ -41,7 +41,7 @@ class MapUniformCPUKernel : public CPUKernel { | |||||
| private: | private: | ||||
| size_t batch_size_{1}; | size_t batch_size_{1}; | ||||
| TypeId dtype_{kTypeUnknown}; | TypeId dtype_{kTypeUnknown}; | ||||
| CNodePtr node_ = nullptr; | |||||
| CNodeWeakPtr node_wpt_; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL(MapUniform, | MS_REG_CPU_KERNEL(MapUniform, | ||||
| @@ -22,7 +22,7 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| void PadAndShiftCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void PadAndShiftCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| node_ = kernel_node; | |||||
| node_wpt_ = kernel_node; | |||||
| input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | ||||
| type_size_ = GetTypeByte(TypeIdToType(input_x_dtype_)); | type_size_ = GetTypeByte(TypeIdToType(input_x_dtype_)); | ||||
| auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| @@ -77,6 +77,10 @@ void PadAndShiftCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||||
| std::vector<size_t> out_shape; | std::vector<size_t> out_shape; | ||||
| out_shape.emplace_back(output_size); | out_shape.emplace_back(output_size); | ||||
| std::vector<TypeId> dtypes; | std::vector<TypeId> dtypes; | ||||
| auto node_ = node_wpt_.lock(); | |||||
| if (!node_) { | |||||
| MS_LOG(EXCEPTION) << "node_wpt_ is expired."; | |||||
| } | |||||
| auto output_nums = AnfAlgo::GetOutputTensorNum(node_); | auto output_nums = AnfAlgo::GetOutputTensorNum(node_); | ||||
| for (size_t i = 0; i < output_nums; i++) { | for (size_t i = 0; i < output_nums; i++) { | ||||
| dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i)); | dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i)); | ||||
| @@ -42,7 +42,7 @@ class PadAndShiftCPUKernel : public CPUKernel { | |||||
| size_t cum_sum_size_{1}; | size_t cum_sum_size_{1}; | ||||
| size_t type_size_{4}; | size_t type_size_{4}; | ||||
| TypeId input_x_dtype_{kTypeUnknown}; | TypeId input_x_dtype_{kTypeUnknown}; | ||||
| CNodePtr node_ = nullptr; | |||||
| CNodeWeakPtr node_wpt_; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL(PadAndShift, | MS_REG_CPU_KERNEL(PadAndShift, | ||||
| @@ -20,7 +20,7 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| void ReshapeCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void ReshapeCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| node_ = kernel_node; | |||||
| node_wpt_ = kernel_node; | |||||
| x_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | x_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | ||||
| type_size_ = GetTypeByte(TypeIdToType(x_data_type_)); | type_size_ = GetTypeByte(TypeIdToType(x_data_type_)); | ||||
| } | } | ||||
| @@ -28,6 +28,10 @@ void ReshapeCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| bool ReshapeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | bool ReshapeCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | const std::vector<kernel::AddressPtr> & /*workspace*/, | ||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| auto node_ = node_wpt_.lock(); | |||||
| if (!node_) { | |||||
| MS_LOG(EXCEPTION) << "node_wpt_ is expired."; | |||||
| } | |||||
| auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); | auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); | ||||
| if (inputs.empty() || outputs.empty()) { | if (inputs.empty() || outputs.empty()) { | ||||
| MS_LOG(EXCEPTION) << "input or output empty!"; | MS_LOG(EXCEPTION) << "input or output empty!"; | ||||
| @@ -33,7 +33,7 @@ class ReshapeCPUKernel : public CPUKernel { | |||||
| const std::vector<AddressPtr> &outputs) override; | const std::vector<AddressPtr> &outputs) override; | ||||
| private: | private: | ||||
| CNodePtr node_ = nullptr; | |||||
| CNodeWeakPtr node_wpt_; | |||||
| TypeId x_data_type_{kNumberTypeInt32}; | TypeId x_data_type_{kNumberTypeInt32}; | ||||
| size_t type_size_ = 4; | size_t type_size_ = 4; | ||||
| }; | }; | ||||
| @@ -22,7 +22,7 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| void SubAndFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void SubAndFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| node_ = kernel_node; | |||||
| node_wpt_ = kernel_node; | |||||
| input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | ||||
| } | } | ||||
| @@ -43,6 +43,10 @@ bool SubAndFilterCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs | |||||
| template <typename T> | template <typename T> | ||||
| void SubAndFilterCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | void SubAndFilterCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| auto node_ = node_wpt_.lock(); | |||||
| if (!node_) { | |||||
| MS_LOG(EXCEPTION) << "node_wpt_ is expired."; | |||||
| } | |||||
| auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); | auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); | ||||
| batch_size_ = 1; | batch_size_ = 1; | ||||
| @@ -40,7 +40,7 @@ class SubAndFilterCPUKernel : public CPUKernel { | |||||
| private: | private: | ||||
| size_t batch_size_{1}; | size_t batch_size_{1}; | ||||
| TypeId input_x_dtype_{kTypeUnknown}; | TypeId input_x_dtype_{kTypeUnknown}; | ||||
| CNodePtr node_ = nullptr; | |||||
| CNodeWeakPtr node_wpt_; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL(SubAndFilter, | MS_REG_CPU_KERNEL(SubAndFilter, | ||||
| @@ -21,7 +21,7 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| const size_t kUseBucketUniqueSize = 100000; | const size_t kUseBucketUniqueSize = 100000; | ||||
| void UniqueCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void UniqueCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| node_ = kernel_node; | |||||
| node_wpt_ = kernel_node; | |||||
| CheckParam(kernel_node); | CheckParam(kernel_node); | ||||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | ||||
| input_size_ = input_shape[0]; | input_size_ = input_shape[0]; | ||||
| @@ -45,7 +45,11 @@ bool UniqueCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| } else if (dtype_ == kNumberTypeFloat32) { | } else if (dtype_ == kNumberTypeFloat32) { | ||||
| LaunchKernel<float, int>(inputs, workspace, outputs); | LaunchKernel<float, int>(inputs, workspace, outputs); | ||||
| } | } | ||||
| if (node_ != nullptr) { | |||||
| if (!node_wpt_.expired()) { | |||||
| auto node_ = node_wpt_.lock(); | |||||
| if (!node_) { | |||||
| MS_LOG(EXCEPTION) << "node_wpt_ is expired."; | |||||
| } | |||||
| std::vector<size_t> out_shape; | std::vector<size_t> out_shape; | ||||
| out_shape.emplace_back(output_size_); | out_shape.emplace_back(output_size_); | ||||
| std::vector<TypeId> dtypes; | std::vector<TypeId> dtypes; | ||||
| @@ -60,7 +60,7 @@ class UniqueCPUKernel : public CPUKernel { | |||||
| size_t input_size_{0}; | size_t input_size_{0}; | ||||
| TypeId dtype_{kTypeUnknown}; | TypeId dtype_{kTypeUnknown}; | ||||
| size_t output_size_{0}; | size_t output_size_{0}; | ||||
| CNodePtr node_ = nullptr; | |||||
| CNodeWeakPtr node_wpt_; | |||||
| template <typename DataType> | template <typename DataType> | ||||
| static size_t BucketId(DataType data, size_t bucket_num) { | static size_t BucketId(DataType data, size_t bucket_num) { | ||||
| @@ -22,7 +22,7 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| void UpdateCacheCPUKernel::InitKernel(const CNodePtr &kernel_node) { | void UpdateCacheCPUKernel::InitKernel(const CNodePtr &kernel_node) { | ||||
| MS_EXCEPTION_IF_NULL(kernel_node); | MS_EXCEPTION_IF_NULL(kernel_node); | ||||
| node_ = kernel_node; | |||||
| node_wpt_ = kernel_node; | |||||
| input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); | ||||
| indices_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1); | indices_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1); | ||||
| @@ -53,6 +53,10 @@ bool UpdateCacheCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| template <typename T> | template <typename T> | ||||
| void UpdateCacheCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | void UpdateCacheCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | ||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| auto node_ = node_wpt_.lock(); | |||||
| if (!node_) { | |||||
| MS_LOG(EXCEPTION) << "node_wpt_ is expired."; | |||||
| } | |||||
| auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1); | auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1); | ||||
| auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 2); | auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 2); | ||||
| @@ -46,7 +46,7 @@ class UpdateCacheCPUKernel : public CPUKernel { | |||||
| TypeId input_x_dtype_{kTypeUnknown}; | TypeId input_x_dtype_{kTypeUnknown}; | ||||
| TypeId indices_dtype_{kTypeUnknown}; | TypeId indices_dtype_{kTypeUnknown}; | ||||
| size_t input_x_dtype_size_ = 4; | size_t input_x_dtype_size_ = 4; | ||||
| CNodePtr node_ = nullptr; | |||||
| CNodeWeakPtr node_wpt_; | |||||
| }; | }; | ||||
| MS_REG_CPU_KERNEL(UpdateCache, | MS_REG_CPU_KERNEL(UpdateCache, | ||||
| @@ -62,6 +62,7 @@ using ValueNodePtr = std::shared_ptr<ValueNode>; | |||||
| class CNode; | class CNode; | ||||
| using CNodePtr = std::shared_ptr<CNode>; | using CNodePtr = std::shared_ptr<CNode>; | ||||
| using CNodePtrList = std::vector<CNodePtr>; | using CNodePtrList = std::vector<CNodePtr>; | ||||
| using CNodeWeakPtr = std::weak_ptr<CNode>; | |||||
| class FuncGraph; | class FuncGraph; | ||||
| using FuncGraphSet = OrderedSet<FuncGraphPtr>; | using FuncGraphSet = OrderedSet<FuncGraphPtr>; | ||||