Browse Source

!6723 fix bug of cpu op arithmetic & softmax

Merge pull request !6723 from 陶云浩/lite
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
62f08d9dda
4 changed files with 21 additions and 12 deletions
  1. +8
    -8
      mindspore/lite/nnacl/arithmetic_common.h
  2. +1
    -1
      mindspore/lite/nnacl/softmax_parameter.h
  3. +8
    -3
      mindspore/lite/src/ops/arithmetic.cc
  4. +4
    -0
      mindspore/lite/src/ops/softmax.cc

+ 8
- 8
mindspore/lite/nnacl/arithmetic_common.h View File

@@ -28,20 +28,20 @@ typedef struct ArithmeticParameter {
bool broadcasting_;
size_t ndim_;
int activation_type_;
int in_shape0_[5];
int in_shape0_[10];
int in_elements_num0_;
int in_shape1_[5];
int in_shape1_[10];
int in_elements_num1_;

int out_shape_[5];
int out_shape_[10];
int out_elements_num_;

int in_strides0_[5];
int in_strides1_[5];
int out_strides_[5];
int in_strides0_[10];
int in_strides1_[10];
int out_strides_[10];

int multiples0_[5];
int multiples1_[5];
int multiples0_[10];
int multiples1_[10];
} ArithmeticParameter;

#ifdef __cplusplus


+ 1
- 1
mindspore/lite/nnacl/softmax_parameter.h View File

@@ -24,7 +24,7 @@ typedef struct SoftmaxParameter {
int32_t axis_;
int element_size_;
int n_dim_;
int input_shape_[4];
int input_shape_[5];
} SoftmaxParameter;

#endif // MINDSPORE_LITE_NNACL_SOFTMAX_PARAMETER_H_

+ 8
- 3
mindspore/lite/src/ops/arithmetic.cc View File

@@ -46,9 +46,14 @@ int Arithmetic::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite
if (!GetInferFlag()) {
return RET_OK;
}
in_shape0_.resize(5);
in_shape1_.resize(5);
out_shape_.resize(5);
if (input_shape0.size() > 10 || input_shape1.size() > 10) {
int wrong_dim = input_shape0.size() > input_shape1.size() ? input_shape0.size() : input_shape1.size();
MS_LOG(ERROR) << "Not support input dim: " << wrong_dim << ", The input dim must be less than 10";
return RET_ERROR;
}
in_shape0_.resize(10);
in_shape1_.resize(10);
out_shape_.resize(10);

ndim_ = input_shape0.size();
if (input_shape0.size() < input_shape1.size()) {


+ 4
- 0
mindspore/lite/src/ops/softmax.cc View File

@@ -82,6 +82,10 @@ int SoftMax::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> out
if (!GetInferFlag()) {
return RET_OK;
}
if (input->shape().size() > 5) {
MS_LOG(ERROR) << "Softmax input dim must be less than 5, get " << input->shape().size();
return RET_ERROR;
}
output->set_shape(input->shape());
return RET_OK;
}


Loading…
Cancel
Save