Browse Source

Check attr exists before getting it in embeddinglookup cpu kernel

tags/v0.6.0-beta
yujianfeng 5 years ago
parent
commit
3fdc3629af
2 changed files with 8 additions and 2 deletions
  1. +6
    -2
      mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc
  2. +2
    -0
      mindspore/ccsrc/utils/utils.h

+ 6
- 2
mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc View File

@@ -36,7 +36,9 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
} }
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
axis_ = 4 - input_shape_.size(); axis_ = 4 - input_shape_.size();
reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "reduce_scatter_flag");
if (AnfAlgo::HasNodeAttr(kAttrReduceScatterFlag, kernel_node)) {
reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, kAttrReduceScatterFlag);
}
#ifdef ENABLE_MPI #ifdef ENABLE_MPI
if (reduce_scatter_flag_) { if (reduce_scatter_flag_) {
size_t gatherv2_out_lens = 1; size_t gatherv2_out_lens = 1;
@@ -65,7 +67,9 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(EXCEPTION) << "Not Enable MPI, please build version with -M on when set reduce_scatter_flag true"; MS_LOG(EXCEPTION) << "Not Enable MPI, please build version with -M on when set reduce_scatter_flag true";
} }
#endif #endif
offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "offset");
if (AnfAlgo::HasNodeAttr(kAttrOffset, kernel_node)) {
offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, kAttrOffset);
}
CPUKernelUtils::ExpandDimsTo4(&input_shape_); CPUKernelUtils::ExpandDimsTo4(&input_shape_);
CPUKernelUtils::ExpandDimsTo4(&output_shape_); CPUKernelUtils::ExpandDimsTo4(&output_shape_);
} }


+ 2
- 0
mindspore/ccsrc/utils/utils.h View File

@@ -223,6 +223,8 @@ constexpr auto kAttrNumSplit = "num_split";
constexpr auto kAttrOutputNum = "output_num"; constexpr auto kAttrOutputNum = "output_num";
constexpr auto kAttrSizeSplits = "size_splits"; constexpr auto kAttrSizeSplits = "size_splits";
constexpr auto kAttrOutputDefault = "output_default"; constexpr auto kAttrOutputDefault = "output_default";
constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
constexpr auto kAttrOffset = "offset";


// attr value // attr value
constexpr auto kValueTargetSwitch = "target_switch"; constexpr auto kValueTargetSwitch = "target_switch";


Loading…
Cancel
Save