|
|
|
@@ -71,10 +71,10 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel { |
|
|
|
} else { |
|
|
|
MS_LOG(INFO) << "UnsortedSegmentMax Kernel Input count is 2"; |
|
|
|
} |
|
|
|
auto value_count = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0); |
|
|
|
if (value_count.size() != 1) { |
|
|
|
MS_LOG(ERROR) << "For UnsortedSegmentMax, output shape incorrect rank. Expect Rank: 1, got Rank: " |
|
|
|
<< value_count.size() << "."; |
|
|
|
if (output_shapes.size() < 1) { |
|
|
|
MS_LOG(EXCEPTION) |
|
|
|
<< "For UnsortedSegmentMax, output shape incorrect rank. Expect Rank at least rank 1, got Rank: " |
|
|
|
<< output_shapes.size() << "."; |
|
|
|
} |
|
|
|
num_segments_ = output_shapes[0]; |
|
|
|
input_size_ = 1; |
|
|
|
|