|
|
|
@@ -74,6 +74,26 @@ CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_s |
|
|
|
AnfAlgo::SetNodeAttr(kAttrSize, MakeValue(Convert2Long(unsort_segment_sum_shape)), slice); |
|
|
|
return slice; |
|
|
|
} |
|
|
|
|
|
|
|
bool CheckInputs(const CNodePtr &origin_node) { |
|
|
|
MS_EXCEPTION_IF_NULL(origin_node); |
|
|
|
if (origin_node->size() != kUnsortedSegmentSumInputNum + 1) { |
|
|
|
MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputNum |
|
|
|
<< ". CNode= " << origin_node->DebugString(); |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto x_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0); |
|
|
|
auto y_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 1); |
|
|
|
if (x_shape.empty() || y_shape.empty()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (x_shape[x_shape.size() - 1] != 1) { |
|
|
|
MS_LOG(DEBUG) << "UnsortedSegmentSum is not need fission. The last value of input0's shape is " |
|
|
|
<< x_shape[x_shape.size() - 1]; |
|
|
|
return false; |
|
|
|
} |
|
|
|
return x_shape.size() > y_shape.size(); |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
const BaseRef UnsortSegmentSumFission::DefinePattern() const { |
|
|
|
@@ -88,19 +108,7 @@ const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, con |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
auto origin_node = node->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(origin_node); |
|
|
|
if (origin_node->size() != kUnsortedSegmentSumInputNum + 1) { |
|
|
|
MS_LOG(INFO) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputNum |
|
|
|
<< ". CNode= " << origin_node->DebugString(); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0); |
|
|
|
if (input0_shape.size() < 2) { |
|
|
|
MS_LOG(INFO) << "Input0's shape size less than 2, not optimize"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (input0_shape[input0_shape.size() - 1] != 1) { |
|
|
|
MS_LOG(INFO) << "UnsortedSegmentSum is not need fission. The last value of input0's shape is " |
|
|
|
<< input0_shape[input0_shape.size() - 1]; |
|
|
|
if (!CheckInputs(origin_node)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
size_t pad_dim_size; |
|
|
|
@@ -110,7 +118,7 @@ const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, con |
|
|
|
} else if (input_dtype == kNumberTypeFloat16) { |
|
|
|
pad_dim_size = 16; |
|
|
|
} else { |
|
|
|
MS_LOG(INFO) << "UnsortedSegmentSum data type not in (float21, float16), no need change"; |
|
|
|
MS_LOG(DEBUG) << "UnsortedSegmentSum data type not in (float32, float16), no need change"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
|