|
|
|
@@ -94,6 +94,10 @@ const AnfNodePtr UnsortSegmentSumFission::Process(const FuncGraphPtr &graph, con |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(origin_node, 0); |
|
|
|
if (input0_shape.size() < 2) { |
|
|
|
MS_LOG(INFO) << "Input0's shape size less than 2, not optimize"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (input0_shape[input0_shape.size() - 1] != 1) { |
|
|
|
MS_LOG(INFO) << "UnsortedSegmentSum is not need fission. The last value of input0's shape is " |
|
|
|
<< input0_shape[input0_shape.size() - 1]; |
|
|
|
|