diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc index 6daef4c98b..33852c8536 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace kernel { template void ConcatCPUKernel::InitKernel(const CNodePtr &kernel_node) { - node_ = kernel_node; + node_wpt_ = kernel_node; CheckParam(kernel_node); axis_ = LongToInt(AnfAlgo::GetNodeAttr(kernel_node, AXIS)); @@ -35,6 +35,10 @@ template bool ConcatCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, const std::vector &outputs) { + auto node_ = node_wpt_.lock(); + if (!node_) { + MS_LOG(EXCEPTION) << "node_wpt_ is expired."; + } size_t input_num = AnfAlgo::GetInputTensorNum(node_); std::vector> input_flat_shape_list; for (size_t i = 0; i < input_num; i++) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h index 3f6887cdf7..0a63dd7ea1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h @@ -36,7 +36,7 @@ class ConcatCPUKernel : public CPUKernel { private: void CheckParam(const CNodePtr &kernel_node); int axis_ = 0; - CNodePtr node_ = nullptr; + CNodeWeakPtr node_wpt_; }; MS_REG_CPU_KERNEL_T( diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/dynamic_assign_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/dynamic_assign_cpu_kernel.cc index 9bda0aec33..38982941db 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/dynamic_assign_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/dynamic_assign_cpu_kernel.cc @@ -20,7 +20,7 @@ namespace mindspore { namespace kernel { void DynamicAssignCPUKernel::InitKernel(const CNodePtr &kernel_node) { - node_ = kernel_node; + node_wpt_ = kernel_node; input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); input_x_dtype_size_ = GetTypeByte(TypeIdToType(input_x_dtype_)); } @@ -46,6 +46,10 @@ bool DynamicAssignCPUKernel::Launch(const std::vector &input template void DynamicAssignCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { + auto node_ = node_wpt_.lock(); + if (!node_) { + MS_LOG(EXCEPTION) << "node_wpt_ is expired."; + } auto input_x_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); auto input_y_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1); batch_size_ = 1; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/dynamic_assign_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/dynamic_assign_cpu_kernel.h index 33c8f1e87b..3f4e170996 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/dynamic_assign_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/dynamic_assign_cpu_kernel.h @@ -41,7 +41,7 @@ class DynamicAssignCPUKernel : public CPUKernel { size_t batch_size_{1}; TypeId input_x_dtype_{kTypeUnknown}; size_t input_x_dtype_size_ = 4; - CNodePtr node_ = nullptr; + CNodeWeakPtr node_wpt_; }; MS_REG_CPU_KERNEL( diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc index 5e24d73c55..bfbfc8c8c4 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc @@ -49,7 +49,7 @@ void LookUpTableTask(const float *input_addr, const T *indices_addr, float *outp void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { CheckParam(kernel_node); - node_ = kernel_node; + node_wpt_ = kernel_node; std::vector input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); if (input_shape.empty()) { MS_LOG(EXCEPTION) << "param must be at least 1D"; @@ -73,7 +73,11 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { template void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { - if (node_ != nullptr) { + if (!node_wpt_.expired()) { + auto node_ = node_wpt_.lock(); + if (!node_) { + MS_LOG(EXCEPTION) << "node_wpt_ is expired."; + } std::vector input_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); if (input_shape.empty()) { MS_LOG(EXCEPTION) << "param must be at least 1D"; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h index b1639100da..0034379184 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h @@ -41,7 +41,7 @@ class EmbeddingLookUpCPUKernel : public CPUKernel { size_t first_dim_size_{1}; size_t outer_dim_size_{1}; TypeId indices_data_type_{kNumberTypeInt32}; - CNodePtr node_ = nullptr; + CNodeWeakPtr node_wpt_; }; MS_REG_CPU_KERNEL( diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.cc index 00c18d8808..5dee725684 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.cc @@ -42,9 +42,21 @@ int Compress(HashmapEntry *entry_p, const size_t &length, T entry) { return compress_count; } +void UpdateShape(size_t miss_count, const CNodePtr &node_) { + std::vector out_shape; + out_shape.emplace_back(miss_count); + std::vector dtypes; + size_t output_num = AnfAlgo::GetOutputTensorNum(node_); + for (size_t i = 0; i < output_num; i++) { + dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i)); + } + AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape}, + node_.get()); +} + void MapCacheIdxCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); - node_ = kernel_node; + node_wpt_ = kernel_node; auto hashmap_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); if (hashmap_shape.size() != 2) { MS_LOG(EXCEPTION) << "Dimension of HashMap must be 2, (n, 4)"; @@ -73,6 +85,7 @@ bool MapCacheIdxCPUKernel::Launch(const std::vector &inputs, template void MapCacheIdxCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { + auto node_ = node_wpt_.lock(); auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1); batch_size_ = 1; for (size_t i = 0; i < emb_idx_shape.size(); ++i) { @@ -92,7 +105,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector &inputs, float total_count = 0; int count_size = 0; float hit_count = 0; - // search_cache_idx for (size_t i = 0; i < batch_size_; ++i) { T key = input_indices[i] - offset; @@ -107,7 +119,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector &inputs, tmp_entry = (tmp_entry + 1) % hashmap_length_; if (count > hashmap_length_) { MS_LOG(EXCEPTION) << "Hashmap is full, search cache idx failed, please set a larger vocab_cache_size!"; - break; } count += 1; } @@ -130,7 +141,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector &inputs, MS_LOG(INFO) << "Avg search count: " << total_count / count_size; MS_LOG(INFO) << "Cache hit rate: " << hit_count / count_size; } - float total_insert_count = 0; float total_delete_count = 0; // swap hash map @@ -142,7 +152,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector &inputs, entry = (entry + 1) % hashmap_length_; if (tag_count > hashmap_length_) { MS_LOG(EXCEPTION) << "Hashmap is full, insert new key failed, please set a larger vocab_cache_size!"; - break; } tag_count++; } @@ -155,7 +164,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector &inputs, tmp_entry = (tmp_entry + 1) % hashmap_length_; if (delete_count > hashmap_length_) { MS_LOG(EXCEPTION) << "Hashmap is full, delete old key failed, please set a larger vocab_cache_size!"; - break; } delete_count++; } @@ -171,22 +179,11 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector &inputs, MS_LOG(INFO) << "Insert count: " << total_insert_count / miss_count; MS_LOG(INFO) << "Delete count: " << total_delete_count / miss_count; } - // update step step_[0] += 1; - // update cache idx for (size_t i = 0; i < miss_count; ++i) { - int idx = miss_idx[i]; - output_cache_idx[idx] = output_swap_cache_idx[i]; + output_cache_idx[miss_idx[i]] = output_swap_cache_idx[i]; } - std::vector out_shape; - out_shape.emplace_back(miss_count); - std::vector dtypes; - size_t output_num = AnfAlgo::GetOutputTensorNum(node_); - for (size_t i = 0; i < output_num; i++) { - dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i)); - } - AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape}, - node_.get()); + UpdateShape(miss_count, node_); } } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.h index 7c747b92c0..7deeed5bf7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.h @@ -42,7 +42,7 @@ class MapCacheIdxCPUKernel : public CPUKernel { size_t batch_size_{1}; size_t hashmap_length_{1}; TypeId dtype_{kTypeUnknown}; - CNodePtr node_ = nullptr; + CNodeWeakPtr node_wpt_; }; MS_REG_CPU_KERNEL(MapCacheIdx, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/map_uniform_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/map_uniform_cpu_kernel.cc index c16dddbbc3..90a7f665b1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/map_uniform_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/map_uniform_cpu_kernel.cc @@ -24,7 +24,7 @@ namespace mindspore { namespace kernel { void MapUniformCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); - node_ = kernel_node; + node_wpt_ = kernel_node; dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); } @@ -45,6 +45,10 @@ bool MapUniformCPUKernel::Launch(const std::vector &inputs, template void MapUniformCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { + auto node_ = node_wpt_.lock(); + if (!node_) { + MS_LOG(EXCEPTION) << "node_wpt_ is expired."; + } auto input_x_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); batch_size_ = 1; for (size_t i = 0; i < input_x_shape.size(); ++i) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/map_uniform_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/map_uniform_cpu_kernel.h index b69959a9c3..67e41076c5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/map_uniform_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/map_uniform_cpu_kernel.h @@ -41,7 +41,7 @@ class MapUniformCPUKernel : public CPUKernel { private: size_t batch_size_{1}; TypeId dtype_{kTypeUnknown}; - CNodePtr node_ = nullptr; + CNodeWeakPtr node_wpt_; }; MS_REG_CPU_KERNEL(MapUniform, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/pad_and_shift_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/pad_and_shift_cpu_kernel.cc index 035201ca4b..995f602e52 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/pad_and_shift_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/pad_and_shift_cpu_kernel.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace kernel { void PadAndShiftCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); - node_ = kernel_node; + node_wpt_ = kernel_node; input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); type_size_ = GetTypeByte(TypeIdToType(input_x_dtype_)); auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); @@ -77,6 +77,10 @@ void PadAndShiftCPUKernel::LaunchKernel(const std::vector &inputs, std::vector out_shape; out_shape.emplace_back(output_size); std::vector dtypes; + auto node_ = node_wpt_.lock(); + if (!node_) { + MS_LOG(EXCEPTION) << "node_wpt_ is expired."; + } auto output_nums = AnfAlgo::GetOutputTensorNum(node_); for (size_t i = 0; i < output_nums; i++) { dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i)); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/pad_and_shift_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/pad_and_shift_cpu_kernel.h index 799a5d55ac..45be92c999 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/pad_and_shift_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/pad_and_shift_cpu_kernel.h @@ -42,7 +42,7 @@ class PadAndShiftCPUKernel : public CPUKernel { size_t cum_sum_size_{1}; size_t type_size_{4}; TypeId input_x_dtype_{kTypeUnknown}; - CNodePtr node_ = nullptr; + CNodeWeakPtr node_wpt_; }; MS_REG_CPU_KERNEL(PadAndShift, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc index 07f71be8f5..b97d151201 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.cc @@ -20,7 +20,7 @@ namespace mindspore { namespace kernel { void ReshapeCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); - node_ = kernel_node; + node_wpt_ = kernel_node; x_data_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); type_size_ = GetTypeByte(TypeIdToType(x_data_type_)); } @@ -28,6 +28,10 @@ void ReshapeCPUKernel::InitKernel(const CNodePtr &kernel_node) { bool ReshapeCPUKernel::Launch(const std::vector &inputs, const std::vector & /*workspace*/, const std::vector &outputs) { + auto node_ = node_wpt_.lock(); + if (!node_) { + MS_LOG(EXCEPTION) << "node_wpt_ is expired."; + } auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); if (inputs.empty() || outputs.empty()) { MS_LOG(EXCEPTION) << "input or output empty!"; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h index 90e8befdf9..35b91ecb22 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/reshape_cpu_kernel.h @@ -33,7 +33,7 @@ class ReshapeCPUKernel : public CPUKernel { const std::vector &outputs) override; private: - CNodePtr node_ = nullptr; + CNodeWeakPtr node_wpt_; TypeId x_data_type_{kNumberTypeInt32}; size_t type_size_ = 4; }; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_and_filter_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_and_filter_cpu_kernel.cc index 6a8152e86b..b1f7e0c953 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_and_filter_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_and_filter_cpu_kernel.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace kernel { void SubAndFilterCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); - node_ = kernel_node; + node_wpt_ = kernel_node; input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); } @@ -43,6 +43,10 @@ bool SubAndFilterCPUKernel::Launch(const std::vector &inputs template void SubAndFilterCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { + auto node_ = node_wpt_.lock(); + if (!node_) { + MS_LOG(EXCEPTION) << "node_wpt_ is expired."; + } auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 0); batch_size_ = 1; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_and_filter_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_and_filter_cpu_kernel.h index cdc0af1e7d..401ada45d1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_and_filter_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_and_filter_cpu_kernel.h @@ -40,7 +40,7 @@ class SubAndFilterCPUKernel : public CPUKernel { private: size_t batch_size_{1}; TypeId input_x_dtype_{kTypeUnknown}; - CNodePtr node_ = nullptr; + CNodeWeakPtr node_wpt_; }; MS_REG_CPU_KERNEL(SubAndFilter, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/unique_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/unique_cpu_kernel.cc index aa22342329..fbdabb2f31 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/unique_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/unique_cpu_kernel.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace kernel { const size_t kUseBucketUniqueSize = 100000; void UniqueCPUKernel::InitKernel(const CNodePtr &kernel_node) { - node_ = kernel_node; + node_wpt_ = kernel_node; CheckParam(kernel_node); auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); input_size_ = input_shape[0]; @@ -45,7 +45,11 @@ bool UniqueCPUKernel::Launch(const std::vector &inputs, } else if (dtype_ == kNumberTypeFloat32) { LaunchKernel(inputs, workspace, outputs); } - if (node_ != nullptr) { + if (!node_wpt_.expired()) { + auto node_ = node_wpt_.lock(); + if (!node_) { + MS_LOG(EXCEPTION) << "node_wpt_ is expired."; + } std::vector out_shape; out_shape.emplace_back(output_size_); std::vector dtypes; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/unique_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/unique_cpu_kernel.h index 45c3a2c0e7..4736441531 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/unique_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/unique_cpu_kernel.h @@ -60,7 +60,7 @@ class UniqueCPUKernel : public CPUKernel { size_t input_size_{0}; TypeId dtype_{kTypeUnknown}; size_t output_size_{0}; - CNodePtr node_ = nullptr; + CNodeWeakPtr node_wpt_; template static size_t BucketId(DataType data, size_t bucket_num) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc index 466fcd34e1..2251d5a9b3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace kernel { void UpdateCacheCPUKernel::InitKernel(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); - node_ = kernel_node; + node_wpt_ = kernel_node; input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0); indices_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1); @@ -53,6 +53,10 @@ bool UpdateCacheCPUKernel::Launch(const std::vector &inputs, template void UpdateCacheCPUKernel::LaunchKernel(const std::vector &inputs, const std::vector &outputs) { + auto node_ = node_wpt_.lock(); + if (!node_) { + MS_LOG(EXCEPTION) << "node_wpt_ is expired."; + } auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1); auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 2); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.h index 553535dfe8..935fdacf02 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.h @@ -46,7 +46,7 @@ class UpdateCacheCPUKernel : public CPUKernel { TypeId input_x_dtype_{kTypeUnknown}; TypeId indices_dtype_{kTypeUnknown}; size_t input_x_dtype_size_ = 4; - CNodePtr node_ = nullptr; + CNodeWeakPtr node_wpt_; }; MS_REG_CPU_KERNEL(UpdateCache, diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index 966dc834ca..888eadbddf 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -62,6 +62,7 @@ using ValueNodePtr = std::shared_ptr; class CNode; using CNodePtr = std::shared_ptr; using CNodePtrList = std::vector; +using CNodeWeakPtr = std::weak_ptr; class FuncGraph; using FuncGraphSet = OrderedSet;