Browse Source

UnsortedSegMin/Max output_shape validation fix

tags/v1.1.0
danishnxt 5 years ago
parent
commit
098d588d7d
2 changed files with 8 additions and 8 deletions
  1. +4
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h
  2. +4
    -4
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_min_gpu_kernel.h

+ 4
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_max_gpu_kernel.h View File

@@ -71,10 +71,10 @@ class UnsortedSegmentMaxGpuKernel : public GpuKernel {
} else {
MS_LOG(INFO) << "UnsortedSegmentMax Kernel Input count is 2";
}
auto value_count = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
if (value_count.size() != 1) {
MS_LOG(ERROR) << "For UnsortedSegmentMax, output shape incorrect rank. Expect Rank: 1, got Rank: "
<< value_count.size() << ".";
if (output_shapes.size() < 1) {
MS_LOG(EXCEPTION)
<< "For UnsortedSegmentMax, output shape incorrect rank. Expect Rank at least rank 1, got Rank: "
<< output_shapes.size() << ".";
}
num_segments_ = output_shapes[0];
input_size_ = 1;


+ 4
- 4
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/unsorted_segment_min_gpu_kernel.h View File

@@ -65,10 +65,10 @@ class UnsortedSegmentMinGpuKernel : public GpuKernel {
} else {
MS_LOG(INFO) << "UnsortedSegmentMin Kernel Input count is 2";
}
auto value_count = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
if (value_count.size() != 1) {
MS_LOG(ERROR) << "For UnsortedSegmentMin, output shape incorrect rank. Expect Rank: 1, got Rank: "
<< value_count.size() << ".";
if (output_shapes.size() < 1) {
MS_LOG(EXCEPTION)
<< "For UnsortedSegmentMin, output shape incorrect rank. Expect Rank at least rank 1, got Rank: "
<< output_shapes.size() << ".";
}
num_segments_ = output_shapes[0];
input_size_ = 1;


Loading…
Cancel
Save