Browse Source

[MSLITE][Develop] fix bug of fp32 grad op: unsorted_segmaent_sum

tags/v1.2.0-rc1
yangruoqi713 4 years ago
parent
commit
a0e0bb5f68
2 changed files with 7 additions and 4 deletions
  1. +3
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.cc
  2. +4
    -4
      mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.h

+ 3
- 0
mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.cc View File

@@ -38,6 +38,8 @@ int UnsortedSegmentSumCPUKernel::Init() {
auto input_shape = in_tensors_.at(0)->shape();
auto segment_ids_shape = in_tensors_.at(1)->shape();
auto output_shape = out_tensors_.at(0)->shape();
unit_num_ = 1;
input_dim1_ = 1;
for (size_t i = 0; i < input_shape.size(); ++i) {
unit_num_ *= input_shape[i];
if (i >= segment_ids_shape.size()) {
@@ -45,6 +47,7 @@ int UnsortedSegmentSumCPUKernel::Init() {
}
}
output_dim0_ = output_shape[0];
output_dim1_ = 1;
for (size_t j = 1; j < output_shape.size(); j++) {
output_dim1_ *= output_shape[j];
}


+ 4
- 4
mindspore/lite/src/runtime/kernel/arm/fp32_grad/unsorted_segment_sum.h View File

@@ -32,10 +32,10 @@ class UnsortedSegmentSumCPUKernel : public LiteKernel {
int ReSize() override;
int Run() override;
int Execute(int task_id);
size_t unit_num_;
size_t input_dim1_;
size_t output_dim0_;
size_t output_dim1_;
size_t unit_num_ = 0;
size_t input_dim1_ = 0;
size_t output_dim0_ = 0;
size_t output_dim1_ = 0;

private:
};


Loading…
Cancel
Save