|
|
|
@@ -50,16 +50,21 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel { |
|
|
|
|
|
|
|
bool Init(const CNodePtr &kernel_node) override { |
|
|
|
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); |
|
|
|
auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); |
|
|
|
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0); |
|
|
|
|
|
|
|
input_dim0_ = input_shapes[0]; |
|
|
|
for (size_t i = 1; i < input_shapes.size(); i++) { |
|
|
|
input_dim1_ *= input_shapes[i]; |
|
|
|
auto axis = ids_shapes.size(); |
|
|
|
for (size_t i = 0; i < input_shapes.size(); i++) { |
|
|
|
if (i < axis) { |
|
|
|
input_dim0_ *= input_shapes[i]; |
|
|
|
} else { |
|
|
|
input_dim1_ *= input_shapes[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
output_dim0_ = output_shapes[0]; |
|
|
|
for (size_t i = 1; i < output_shapes.size(); i++) { |
|
|
|
output_dim1_ *= output_shapes[i]; |
|
|
|
for (size_t j = 1; j < output_shapes.size(); j++) { |
|
|
|
output_dim1_ *= output_shapes[j]; |
|
|
|
} |
|
|
|
|
|
|
|
InitSizeLists(); |
|
|
|
|