Browse Source

!1205 Gpu UnsortedSegmentSum fix

Merge pull request !1205 from chenweifeng/unsorted_segment_sum
tags/v0.3.0-alpha
mindspore-ci-bot Gitee 6 years ago
parent
commit
df1eb2f65d
1 changed files with 10 additions and 5 deletions
  1. +10
    -5
      mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h

+ 10
- 5
mindspore/ccsrc/kernel/gpu/arrays/unsorted_segment_sum_gpu_kernel.h View File

@@ -50,16 +50,21 @@ class UnsortedSegmentSumGpuKernel : public GpuKernel {

bool Init(const CNodePtr &kernel_node) override {
auto input_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto ids_shapes = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto output_shapes = AnfAlgo::GetOutputInferShape(kernel_node, 0);

input_dim0_ = input_shapes[0];
for (size_t i = 1; i < input_shapes.size(); i++) {
input_dim1_ *= input_shapes[i];
auto axis = ids_shapes.size();
for (size_t i = 0; i < input_shapes.size(); i++) {
if (i < axis) {
input_dim0_ *= input_shapes[i];
} else {
input_dim1_ *= input_shapes[i];
}
}

output_dim0_ = output_shapes[0];
for (size_t i = 1; i < output_shapes.size(); i++) {
output_dim1_ *= output_shapes[i];
for (size_t j = 1; j < output_shapes.size(); j++) {
output_dim1_ *= output_shapes[j];
}

InitSizeLists();


Loading…
Cancel
Save