diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/arithmetic_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/arithmetic_fp32.c index e16af7f61c..1f87f7e921 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/arithmetic_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/arithmetic_fp32.c @@ -140,7 +140,7 @@ int ElementLogicalOr(const float *in0, const float *in1, float *out, int size) { } #endif for (; index < size; index++) { - out[index] = (float)((bool)(in0[index]) | (bool)(in1[index])); + out[index] = (float)((unsigned int)(in0[index]) | (unsigned int)(in1[index])); } return NNACL_OK; } @@ -148,7 +148,7 @@ int ElementLogicalOr(const float *in0, const float *in1, float *out, int size) { int ElementLogicalOrBool(const bool *in0, const bool *in1, bool *out, int size) { int index = 0; for (; index < size; index++) { - out[index] = (in0[index]) | (in1[index]); + out[index] = (bool)((unsigned int)(in0[index]) | (unsigned int)(in1[index])); } return NNACL_OK; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/log_softmax_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/log_softmax_fp32.c index 8d46d1557a..326030c160 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/log_softmax_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/log_softmax_fp32.c @@ -15,7 +15,6 @@ */ #include "nnacl/fp32/log_softmax_fp32.h" #include -#include #include "nnacl/fp32/softmax_fp32.h" #include "nnacl/fp32/exp_fp32.h" diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/prelu_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/prelu_fp32.c index 74bd2b924f..cdb14b5589 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/prelu_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/prelu_fp32.c @@ -16,7 +16,7 @@ #include "nnacl/fp32/prelu_fp32.h" #ifdef ENABLE_ARM64 -static inline void PRelu4x16(const float *in, float *out, float *cur_slope, size_t step) { +static inline void PRelu4x16(const float *in, float *out, const float *cur_slope, size_t step) { asm volatile( "mov x10, %[in]\n" "mov x11, %[out]\n" @@ -85,7 +85,7 @@ static inline void PRelu4x16(const float *in, float *out, float *cur_slope, size } #endif -void PRelu(const float *input, float *output, float *slope, int start, int end, int channel) { +void PRelu(const float *input, float *output, const float *slope, int start, int end, int channel) { int i = start; #ifdef ENABLE_ARM64 for (; i < end - 3; i += 4) { @@ -95,7 +95,7 @@ void PRelu(const float *input, float *output, float *slope, int start, int end, for (; j < channel - 15; j += 16) { const float *in = cur_in + j; float *out = cur_out + j; - float *cur_slope = slope + j; + const float *cur_slope = slope + j; size_t step = channel * sizeof(float); PRelu4x16(in, out, cur_slope, step); } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/prelu_fp32.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/prelu_fp32.h index 9d6701e55f..a57a3a3223 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/prelu_fp32.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/prelu_fp32.h @@ -22,7 +22,7 @@ #ifdef __cplusplus extern "C" { #endif -void PRelu(const float *input, float *output, float *slope, int start, int end, int channel); +void PRelu(const float *input, float *output, const float *slope, int start, int end, int channel); void PReluShareChannel(const float *input, float *output, float slope, int start, int end); #ifdef __cplusplus diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/transpose_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/transpose_fp32.c index 8b96d51d25..e1ce47cdca 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/transpose_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/transpose_fp32.c @@ -171,7 +171,7 @@ void TransposeDim6Fp32(const float *in_data, float *out_data, const int *strides } } -void TransposeDimsFp32(const float *in_data, float *out_data, const int *output_shape, int *size, int *position, +void TransposeDimsFp32(const float *in_data, float *out_data, const int *output_shape, const int *size, int *position, TransposeParameter *transpose_param, int task_id, int thread_num) { int *perm = transpose_param->perm_; int *strides = transpose_param->strides_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/transpose_fp32.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/transpose_fp32.h index 6cef303c10..d406ee64a2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/transpose_fp32.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/transpose_fp32.h @@ -26,7 +26,7 @@ extern "C" { #endif int DoTransposeFp32(const float *in_data, float *out_data, const int *output_shape, TransposeParameter *param); -void TransposeDimsFp32(const float *in_data, float *out_data, const int *output_shape, int *size, int *position, +void TransposeDimsFp32(const float *in_data, float *out_data, const int *output_shape, const int *size, int *position, TransposeParameter *transpose_param, int task_id, int thread_num); #ifdef __cplusplus } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/binary_cross_entropy.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/binary_cross_entropy.c index e8a5731c21..2db5416123 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/binary_cross_entropy.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/binary_cross_entropy.c @@ -43,7 +43,7 @@ void BinaryCrossEntropy(const int input_size, const int reduction, const float * if (input_size % 2 == 1) { tmp_loss[0] += tmp_loss[input_size - 1]; } - for (int stride = input_size / 2; stride > 0; stride >>= 1) { + for (int stride = input_size / 2; stride > 0; stride = stride / 2) { for (int i = 0; i < stride; i++) { tmp_loss[i] += tmp_loss[i + stride]; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/pack_ext.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/pack_ext.c index 0ac265340a..75032fb17c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/pack_ext.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/pack_ext.c @@ -222,7 +222,7 @@ void col2im_hwc(const float *data_col, float *data_im, const ConvParameter *conv int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { int offset = (input_row * in_width + input_col) * tot_channels; - float *data_im_ptr = &data_im[offset]; + float *data_im_ptr = data_im + offset; for (int i = 0; i < channels; i++) { data_im_ptr[i] += data_col[i]; } @@ -270,7 +270,7 @@ void rolling_col2im_hwc(const float *data_col, float *data_im, const ConvParamet int input_col = -pad_left + kernel_col * dilation_w + col_stride_offset; if (((unsigned)(input_row) < (unsigned)(in_height)) && ((unsigned)(input_col) < (unsigned)(in_width))) { int offset = (input_row * in_width + input_col) * tot_channels; - float *data_im_ptr = &data_im[offset]; + float *data_im_ptr = data_im + offset; *data_im_ptr += *data_col; } data_col++; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/pooling_grad.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/pooling_grad.c index cd80ee470f..6084d4ab63 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/pooling_grad.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32_grad/pooling_grad.c @@ -36,8 +36,8 @@ void AvgPoolingGrad(const float *input_ptr, float *output_ptr, int count, Poolin const float32x4_t factor = vdupq_n_f32(kk); #endif for (int ib = 0; ib < count; ib++) { - float *out = &output_ptr[(ib * in_h * in_w * channel)]; - const float *inPtr = &input_ptr[(ib * output_h * output_w * channel)]; + float *out = output_ptr + ib * in_h * in_w * channel; + const float *inPtr = input_ptr + ib * output_h * output_w * channel; // iterate over yt for (int yh = 0; yh < output_h; yh++) { int over_h = pad_h - yh * stride_h; @@ -115,9 +115,9 @@ void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_p int output_h = pooling_param->output_h_; for (int ib = 0; ib < output_batch; ib++) { - float *out = &output_ptr[(ib * in_h * in_w * channel)]; - const float *inPtr = &input_ptr[(ib * in_h * in_w * channel)]; - const float *dyPtr = &dy_ptr[(ib * output_h * output_w * channel)]; + float *out = output_ptr + ib * in_h * in_w * channel; + const float *inPtr = input_ptr + ib * in_h * in_w * channel; + const float *dyPtr = dy_ptr + ib * output_h * output_w * channel; for (int yh = 0; yh < output_h; yh++) { int over_h = pad_h - yh * stride_h; int kh_s = MSMAX(0, over_h); @@ -127,7 +127,7 @@ void MaxPoolingGrad(const float *input_ptr, const float *dy_ptr, float *output_p int kw_s = MSMAX(0, over_w); int kw_e = MSMIN(win_w, in_w + over_w); int ic = 0; - for (; ic < (channel & ~3); ic += 4) { + for (; ic <= channel - 4; ic += 4) { int idx = (yw + yh * output_w) * channel + ic; #ifdef ENABLE_ARM uint32x4_t max_idx = vdupq_n_u32(0); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/add_sub_grad_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/add_sub_grad_infer.c index 2b0f5e549c..c81971b627 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/add_sub_grad_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/add_sub_grad_infer.c @@ -16,7 +16,6 @@ #include "nnacl/infer/add_sub_grad_infer.h" #include "nnacl/arithmetic.h" -#include "nnacl/infer/arithmetic_grad_infer.h" #include "nnacl/infer/infer_register.h" int AddSubGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/argmin_max_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/argmin_max_infer.c index 121cf4065a..184f83ae58 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/argmin_max_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/argmin_max_infer.c @@ -52,7 +52,7 @@ int ArgMinMaxInferShape(const TensorC *const *inputs, const size_t inputs_size, if (!parameter->infer_flag_) { return NNACL_INFER_INVALID; } - int output_shape[MAX_SHAPE_SIZE]; + int output_shape[MAX_SHAPE_SIZE] = {0}; size_t output_shape_size = 0; ShapeSet(output_shape, &output_shape_size, input->shape_, input->shape_size_); size_t input_shape_size = input->shape_size_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_grad_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_grad_infer.c index ed00572b17..3713740f7d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_grad_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_grad_infer.c @@ -40,13 +40,13 @@ int ArithmeticGradInferShape(const TensorC *const *inputs, size_t inputs_size, T TensorC *dx1 = outputs[0]; TensorC *dx2 = outputs[1]; - int in_shape0[MAX_SHAPE_SIZE]; + int in_shape0[MAX_SHAPE_SIZE] = {0}; size_t in_shape0_size = 0; ShapeSet(in_shape0, &in_shape0_size, x1->shape_, x1->shape_size_); - int in_shape1[MAX_SHAPE_SIZE]; + int in_shape1[MAX_SHAPE_SIZE] = {0}; size_t in_shape1_size = 0; ShapeSet(in_shape1, &in_shape1_size, x2->shape_, x2->shape_size_); - int out_shape[MAX_SHAPE_SIZE]; + int out_shape[MAX_SHAPE_SIZE] = {0}; size_t out_shape_size = 0; ShapeSet(out_shape, &out_shape_size, dy->shape_, dy->shape_size_); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_infer.c index 74f6ae4eec..7a1050ac41 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/arithmetic_infer.c @@ -17,6 +17,40 @@ #include "nnacl/infer/arithmetic_infer.h" #include "nnacl/infer/infer_register.h" +void UpdateInputShape(const int input_shape0_size, const int input_shape1_size, int *ndim, const int *input_shape0, + const int *input_shape1, int *in_shape0, int *in_shape1) { + if (input_shape0_size < input_shape1_size) { + *ndim = input_shape1_size; + int fill_dim_num = input_shape1_size - input_shape0_size; + int j = 0; + for (size_t i = 0; i < input_shape1_size; i++) { + if (i < fill_dim_num) { + in_shape0[i] = 1; + } else { + in_shape0[i] = input_shape0[j++]; + } + in_shape1[i] = input_shape1[i]; + } + } else if (input_shape0_size > input_shape1_size) { + *ndim = input_shape0_size; + int fill_dim_num = input_shape0_size - input_shape1_size; + int j = 0; + for (size_t i = 0; i < input_shape0_size; i++) { + if (i < fill_dim_num) { + in_shape1[i] = 1; + } else { + in_shape1[i] = input_shape1[j++]; + } + in_shape0[i] = input_shape0[i]; + } + } else { + for (size_t i = 0; i < input_shape0_size; i++) { + in_shape1[i] = input_shape1[i]; + in_shape0[i] = input_shape0[i]; + } + } +} + int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, OpParameter *parameter) { #ifdef Debug @@ -46,75 +80,48 @@ int ArithmeticInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso if (!parameter->infer_flag_) { return NNACL_INFER_INVALID; } - if (input_shape0_size > 10 || input_shape1_size > 10) { + if (input_shape0_size >= MAX_SHAPE_SIZE || input_shape1_size >= MAX_SHAPE_SIZE) { return NNACL_ERR; } - int in_shape0_[10]; - int in_shape1_[10]; - int out_shape_[10]; + int in_shape0[10]; + int in_shape1[10]; + int out_shape[10]; + int ndim = input_shape0_size; + UpdateInputShape(input_shape0_size, input_shape1_size, &ndim, input_shape0, input_shape1, in_shape0, in_shape1); - int ndim_ = input_shape0_size; - if (input_shape0_size < input_shape1_size) { - ndim_ = input_shape1_size; - int fill_dim_num = input_shape1_size - input_shape0_size; - int j = 0; - for (size_t i = 0; i < input_shape1_size; i++) { - if (i < fill_dim_num) { - in_shape0_[i] = 1; - } else { - in_shape0_[i] = input_shape0[j++]; - } - in_shape1_[i] = input_shape1[i]; - } - } else if (input_shape0_size > input_shape1_size) { - ndim_ = input_shape0_size; - int fill_dim_num = input_shape0_size - input_shape1_size; - int j = 0; - for (size_t i = 0; i < input_shape0_size; i++) { - if (i < fill_dim_num) { - in_shape1_[i] = 1; - } else { - in_shape1_[i] = input_shape1[j++]; - } - in_shape0_[i] = input_shape0[i]; - } - } else { - for (size_t i = 0; i < input_shape0_size; i++) { - in_shape1_[i] = input_shape1[i]; - in_shape0_[i] = input_shape0[i]; - } - } - - int output_shape[MAX_SHAPE_SIZE]; + int output_shape[MAX_SHAPE_SIZE] = {0}; size_t output_shape_size = 0; - for (int i = 0; i < ndim_; i++) { - if (in_shape0_[i] != in_shape1_[i]) { - if (in_shape0_[i] == 1) { - out_shape_[i] = in_shape1_[i]; - } else if (in_shape1_[i] == 1) { - out_shape_[i] = in_shape0_[i]; + if (ndim >= MAX_SHAPE_SIZE) { + return NNACL_INFER_INVALID; + } + for (int i = 0; i < ndim; i++) { + if (in_shape0[i] != in_shape1[i]) { + if (in_shape0[i] == 1) { + out_shape[i] = in_shape1[i]; + } else if (in_shape1[i] == 1) { + out_shape[i] = in_shape0[i]; } else { return NNACL_ERR; } param->broadcasting_ = true; } else { - out_shape_[i] = in_shape0_[i]; + out_shape[i] = in_shape0[i]; } - output_shape[output_shape_size] = out_shape_[i]; + output_shape[output_shape_size] = out_shape[i]; output_shape_size++; } SetShapeArray(output, output_shape, output_shape_size); - param->ndim_ = ndim_; - memcpy(param->in_shape0_, in_shape0_, ndim_ * sizeof(int)); - memcpy(param->in_shape1_, in_shape1_, ndim_ * sizeof(int)); - memcpy(param->out_shape_, out_shape_, ndim_ * sizeof(int)); + param->ndim_ = ndim; + memcpy(param->in_shape0_, in_shape0, ndim * sizeof(int)); + memcpy(param->in_shape1_, in_shape1, ndim * sizeof(int)); + memcpy(param->out_shape_, out_shape, ndim * sizeof(int)); param->in_elements_num0_ = 1; param->in_elements_num1_ = 1; param->out_elements_num_ = 1; - for (int i = 0; i < ndim_; i++) { + for (int i = 0; i < ndim; i++) { param->in_elements_num0_ *= param->in_shape0_[i]; param->in_elements_num1_ *= param->in_shape1_[i]; param->out_elements_num_ *= param->out_shape_[i]; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/audio_spectrogram_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/audio_spectrogram_infer.c index 8eaad1aa12..4a72db3f14 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/audio_spectrogram_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/audio_spectrogram_infer.c @@ -17,14 +17,14 @@ #include "nnacl/infer/audio_spectrogram_infer.h" #include "nnacl/infer/infer_register.h" -int Log2Ceil(uint32_t length) { +unsigned Log2Ceil(unsigned length) { if (length == 0) { - return -1; + return 0; } int floor = 0; for (int i = 4; i >= 0; --i) { - const int shift = (1 << i); - uint32_t tmp = length >> shift; + const unsigned shift = (1 << i); + unsigned tmp = length >> shift; if (tmp != 0) { length = tmp; floor += shift; @@ -33,8 +33,8 @@ int Log2Ceil(uint32_t length) { return length == (length & ~(length - 1)) ? floor : floor + 1; } -uint32_t GetFftLength(uint32_t length) { - int shift = Log2Ceil(length); +unsigned GetFftLength(unsigned length) { + unsigned shift = Log2Ceil(length); return 1 << shift; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/batch_to_space_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/batch_to_space_infer.c index 36f0261bcd..68689086af 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/batch_to_space_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/batch_to_space_infer.c @@ -18,7 +18,7 @@ #include "nnacl/infer/infer_register.h" int SetOutputShapeFromParam(const TensorC *const *inputs, TensorC **outputs, OpParameter *parameter) { - int input_shape[MAX_SHAPE_SIZE]; + int input_shape[MAX_SHAPE_SIZE] = {0}; size_t input_shape_size = 0; ShapeSet(input_shape, &input_shape_size, inputs[0]->shape_, inputs[0]->shape_size_); @@ -60,7 +60,7 @@ int SetOutputShapeFromParam(const TensorC *const *inputs, TensorC **outputs, OpP } int SetOutputShapeFromInput(const TensorC *const *inputs, TensorC **outputs) { - int input_shape[MAX_SHAPE_SIZE]; + int input_shape[MAX_SHAPE_SIZE] = {0}; size_t input_shape_size = 0; ShapeSet(input_shape, &input_shape_size, inputs[0]->shape_, inputs[0]->shape_size_); if (input_shape_size != 4) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.c index 59a55006be..92f2711fab 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.c @@ -61,7 +61,7 @@ int TensorListMergeShape(int *element_shape, size_t *element_shape_size, const i return NNACL_OK; } -bool TensorListIsFullyDefined(int *shape, size_t shape_size) { +bool TensorListIsFullyDefined(const int *shape, size_t shape_size) { for (size_t i = 0; i < shape_size; ++i) { if (shape[i] < 0) { return false; @@ -145,7 +145,7 @@ int SetShapeTensor(TensorC *dst, const TensorC *src) { return NNACL_OK; } -int SetShapeArray(TensorC *dst, int *src, size_t src_size) { +int SetShapeArray(TensorC *dst, const int *src, size_t src_size) { for (size_t i = 0; i < src_size; i++) { dst->shape_[i] = src[i]; } @@ -359,7 +359,7 @@ int FftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **ou if (!parameter->infer_flag_) { return NNACL_INFER_INVALID; } - int input_shape[MAX_SHAPE_SIZE]; + int input_shape[MAX_SHAPE_SIZE] = {0}; size_t input_shape_size = 0; ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_); input_shape_size--; @@ -381,23 +381,30 @@ int VectorCInit(VectorC *vc, size_t per_malloc_size) { return NNACL_OK; } -void VectorCSet(VectorC *vc, const int *src_shape, size_t src_shape_size) { +int VectorCSet(VectorC *vc, const int *src_shape, size_t src_shape_size) { if (src_shape_size == 0) { vc->size_ = 0; } else { free(vc->data_); vc->max_size_ = (src_shape_size / vc->per_malloc_size_ + 1) * vc->per_malloc_size_; vc->data_ = (int *)malloc(sizeof(int) * vc->max_size_); + if (vc->data_ == NULL) { + return NNACL_ERR; + } for (size_t i = 0; i < src_shape_size; i++) { vc->data_[i] = src_shape[i]; } vc->size_ = src_shape_size; } + return NNACL_OK; } -void VectorCPush(VectorC *vc, int value) { +int VectorCPush(VectorC *vc, int value) { if (vc->size_ + 1 > vc->max_size_) { int *tmp = (int *)malloc(vc->per_malloc_size_ * sizeof(int) + vc->max_size_ * sizeof(int)); + if (tmp == NULL) { + return NNACL_ERR; + } memcpy(tmp, vc->data_, vc->size_ * sizeof(int)); free(vc->data_); vc->data_ = tmp; @@ -405,6 +412,7 @@ void VectorCPush(VectorC *vc, int value) { } vc->data_[vc->size_] = value; vc->size_++; + return NNACL_OK; } void VectorCInsert(VectorC *vc, int index, int value) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.h index 5a024f43b0..2d4c1fab1f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/common_infer.h @@ -157,7 +157,7 @@ typedef struct VectorC { int MallocTensorListData(TensorListC *tensor_list, TypeIdC dtype, vvector *tensor_shape); int TensorListMergeShape(int *element_shape, size_t *element_shape_size, const int *tmp, size_t tmp_size); -bool TensorListIsFullyDefined(int *shape, size_t shape_size); +bool TensorListIsFullyDefined(const int *shape, size_t shape_size); int GetBatch(const TensorC *tensor); int GetHeight(const TensorC *tensor); @@ -180,7 +180,7 @@ int CheckAugmentNullOutputSize(const TensorC *const *inputs, size_t inputs_size, void SetDataTypeFormat(TensorC *dst, const TensorC *src); int SetShapeTensor(TensorC *dst, const TensorC *src); -int SetShapeArray(TensorC *dst, int *src, size_t src_size); +int SetShapeArray(TensorC *dst, const int *src, size_t src_size); int ShapeSet(int *dst_shape, size_t *dst_shape_size, const int *src_shape, size_t src_shape_size); int ShapePush(int *shape, size_t *shape_size, int value); int ShapeInsert(int *shape, size_t *shape_size, int index, int value); @@ -198,8 +198,8 @@ int FftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **ou OpParameter *parameter); int VectorCInit(VectorC *vc, size_t per_malloc_size); -void VectorCSet(VectorC *vc, const int *src_shape, size_t src_shape_size); -void VectorCPush(VectorC *vc, int value); +int VectorCSet(VectorC *vc, const int *src_shape, size_t src_shape_size); +int VectorCPush(VectorC *vc, int value); void VectorCInsert(VectorC *vc, int index, int value); void VectorCErase(VectorC *vc, int index); bool VectorCEqual(VectorC *vc1, VectorC *vc2); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/concat_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/concat_infer.c index 4729a0aa87..1c17b1fc25 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/concat_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/concat_infer.c @@ -41,13 +41,13 @@ int ConcatInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * if (axis < 0 || axis >= input0_shape_size) { return NNACL_ERR; } - int input0_shape_without_axis[MAX_SHAPE_SIZE]; + int input0_shape_without_axis[MAX_SHAPE_SIZE] = {0}; size_t input0_shape_without_axis_size = 0; ShapeSet(input0_shape_without_axis, &input0_shape_without_axis_size, input0_shape, input0_shape_size); ShapeErase(input0_shape_without_axis, &input0_shape_without_axis_size, axis); int output_axis_dim = input0_shape[axis]; for (size_t i = 1; i < inputs_size; ++i) { - int shape_tmp[MAX_SHAPE_SIZE]; + int shape_tmp[MAX_SHAPE_SIZE] = {0}; size_t shape_tmp_size = 0; ShapeSet(shape_tmp, &shape_tmp_size, inputs[i]->shape_, inputs[i]->shape_size_); if (shape_tmp_size != input0_shape_size) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/crop_and_resize_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/crop_and_resize_infer.c index 5b4f39c590..279e7c4650 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/crop_and_resize_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/crop_and_resize_infer.c @@ -37,7 +37,7 @@ int CropAndResizeInferShape(const TensorC *const *inputs, size_t inputs_size, Te return NNACL_INFER_INVALID; } - int output_shape[MAX_SHAPE_SIZE]; + int output_shape[MAX_SHAPE_SIZE] = {0}; size_t output_shape_size = 0; if (inputs[1]->data_ != NULL) { const TensorC *boxes_tensor = inputs[1]; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/dedepthwise_conv2d_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/dedepthwise_conv2d_infer.c index eaef11caf4..20ac637687 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/dedepthwise_conv2d_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/dedepthwise_conv2d_infer.c @@ -15,7 +15,6 @@ */ #include "nnacl/infer/dedepthwise_conv2d_infer.h" -#include "nnacl/infer/infer_register.h" int DeDepthwiseConv2DInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, OpParameter *parameter) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/depth_to_space_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/depth_to_space_infer.c index 65a1105c78..ba276f0b6b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/depth_to_space_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/depth_to_space_infer.c @@ -35,7 +35,7 @@ int DepthToSpaceInferShape(const TensorC *const *inputs, size_t inputs_size, Ten if (!parameter->infer_flag_) { return NNACL_INFER_INVALID; } - int input_shape[MAX_SHAPE_SIZE]; + int input_shape[MAX_SHAPE_SIZE] = {0}; size_t input_shape_size = 0; ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_); if (input_shape_size != 4) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/depthwise_conv2d_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/depthwise_conv2d_infer.c index 866524d127..7ae3aeb564 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/depthwise_conv2d_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/depthwise_conv2d_infer.c @@ -15,7 +15,6 @@ */ #include "nnacl/infer/depthwise_conv2d_infer.h" -#include "nnacl/infer/infer_register.h" int DepthwiseConv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, OpParameter *parameter) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/embedding_lookup_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/embedding_lookup_infer.c index e838b84ec2..21c872b792 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/embedding_lookup_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/embedding_lookup_infer.c @@ -37,18 +37,18 @@ int EmbeddingLookupInferShape(const TensorC *const *inputs, size_t inputs_size, return NNACL_INFER_INVALID; } - int embedding_shape[MAX_SHAPE_SIZE]; + int embedding_shape[MAX_SHAPE_SIZE] = {0}; size_t embedding_shape_size = 0; ShapeSet(embedding_shape, &embedding_shape_size, params_->shape_, params_->shape_size_); ShapeErase(embedding_shape, &embedding_shape_size, 0); - int output_shape[MAX_SHAPE_SIZE]; + int output_shape[MAX_SHAPE_SIZE] = {0}; size_t output_shape_size = 0; ShapeSet(output_shape, &output_shape_size, ids->shape_, ids->shape_size_); for (size_t i = 0; i < embedding_shape_size; ++i) { ShapePush(output_shape, &output_shape_size, embedding_shape[i]); } for (size_t i = 1; i < inputs_size - 1; ++i) { - int embedding_shape_t[MAX_SHAPE_SIZE]; + int embedding_shape_t[MAX_SHAPE_SIZE] = {0}; size_t embedding_shape_t_size = 0; ShapeSet(embedding_shape_t, &embedding_shape_t_size, inputs[i]->shape_, inputs[i]->shape_size_); ShapeErase(embedding_shape_t, &embedding_shape_t_size, 0); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/fill_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/fill_infer.c index 0a357b34eb..e0fd92ea1c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/fill_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/fill_infer.c @@ -41,7 +41,7 @@ int FillInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o if (num_dims != 0 && dst_shape == NULL) { return NNACL_INFER_INVALID; } - int output_shape[MAX_SHAPE_SIZE]; + int output_shape[MAX_SHAPE_SIZE] = {0}; size_t output_shape_size = 0; for (size_t i = 0; i < num_dims; i++) { ShapePush(output_shape, &output_shape_size, dst_shape[i]); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/flatten_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/flatten_infer.c index d1de665d96..f0d77a120a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/flatten_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/flatten_infer.c @@ -34,7 +34,7 @@ int FlattenInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC return NNACL_INFER_INVALID; } - int input_shape[MAX_SHAPE_SIZE]; + int input_shape[MAX_SHAPE_SIZE] = {0}; size_t input_shape_size = 0; ShapeSet(input_shape, &input_shape_size, input->shape_, input->shape_size_); int output_shape[2]; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/gather_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/gather_infer.c index 78c025cc86..e8be381ff7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/gather_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/gather_infer.c @@ -41,14 +41,14 @@ int GatherInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * size_t indices_shape_size = 0; ShapeSet(indices_shape, &indices_shape_size, indices->shape_, indices->shape_size_); int indices_rank = indices_shape_size; - int in_shape[MAX_SHAPE_SIZE]; + int in_shape[MAX_SHAPE_SIZE] = {0}; size_t in_shape_size = 0; ShapeSet(in_shape, &in_shape_size, input->shape_, input->shape_size_); int in_rank = in_shape_size; if (in_rank < axis + 1) { return NNACL_ERR; } - int out_shape[MAX_SHAPE_SIZE]; + int out_shape[MAX_SHAPE_SIZE] = {0}; size_t out_shape_size = 0; ShapeSet(out_shape, &out_shape_size, in_shape, in_shape_size); ShapeErase(out_shape, &out_shape_size, axis); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/gather_nd_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/gather_nd_infer.c index 9811cc5314..821ea878c3 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/gather_nd_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/gather_nd_infer.c @@ -40,7 +40,7 @@ int GatherNdInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC return NNACL_OK; } int i = 0; - int out_shape[MAX_SHAPE_SIZE]; + int out_shape[MAX_SHAPE_SIZE] = {0}; size_t out_shape_size = 0; for (i = 0; i < indices_rank - 1; ++i) { ShapePush(out_shape, &out_shape_size, indices->shape_[i]); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/group_conv2d_grad_input_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/group_conv2d_grad_input_infer.c index 007ceaaf79..012eb49808 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/group_conv2d_grad_input_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/group_conv2d_grad_input_infer.c @@ -15,7 +15,6 @@ */ #include "nnacl/infer/group_conv2d_grad_input_infer.h" -#include "nnacl/infer/infer_register.h" int GroupConv2dGradInputInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, OpParameter *parameter) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/lsh_projection_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/lsh_projection_infer.c index e2dea5ad31..0ba800f230 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/lsh_projection_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/lsh_projection_infer.c @@ -34,7 +34,7 @@ int LshProjectionInferShape(const TensorC *const *inputs, size_t inputs_size, Te out_tensor->data_type_ = kNumberTypeInt32; out_tensor->format_ = Format_NHWC; - int out_shape[MAX_SHAPE_SIZE]; + int out_shape[MAX_SHAPE_SIZE] = {0}; size_t out_shape_size = 0; LshProjectionParameter *param = (LshProjectionParameter *)parameter; switch (param->lsh_type_) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/matmul_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/matmul_infer.c index fff6ff5577..66d6f9773d 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/matmul_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/matmul_infer.c @@ -36,10 +36,10 @@ int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * return NNACL_INFER_INVALID; } - int a_shape[MAX_SHAPE_SIZE]; + int a_shape[MAX_SHAPE_SIZE] = {0}; size_t a_shape_size = 0; ShapeSet(a_shape, &a_shape_size, input0->shape_, input0->shape_size_); - int b_shape[MAX_SHAPE_SIZE]; + int b_shape[MAX_SHAPE_SIZE] = {0}; size_t b_shape_size = 0; ShapeSet(b_shape, &b_shape_size, input1->shape_, input1->shape_size_); @@ -67,9 +67,15 @@ int MatmulInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * } if (param->a_transpose_) { + if (a_shape_size < 2) { + return NNACL_ERR; + } iswap(&a_shape[a_shape_size - 1], &a_shape[a_shape_size - 2]); } if (param->b_transpose_) { + if (b_shape_size < 2) { + return NNACL_ERR; + } iswap(&b_shape[b_shape_size - 1], &b_shape[b_shape_size - 2]); } int c_shape[MAX_SHAPE_SIZE]; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/mean_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/mean_infer.c index 6ad78f3fb2..906437eb83 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/mean_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/mean_infer.c @@ -15,7 +15,6 @@ */ #include "nnacl/infer/mean_infer.h" -#include "nnacl/infer/infer_register.h" int MeanInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, OpParameter *parameter) { @@ -34,7 +33,7 @@ int MeanInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o } ReduceParameter *param = (ReduceParameter *)parameter; bool keep_dims = (bool)(param->keep_dims_); - int out_shape[MAX_SHAPE_SIZE]; + int out_shape[MAX_SHAPE_SIZE] = {0}; size_t out_shape_size = 0; int *axes = param->axes_; int num_axes = param->num_axes_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/pad_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/pad_infer.c index 97651cb714..da55234393 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/pad_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/pad_infer.c @@ -47,7 +47,7 @@ int PadInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **ou param->paddings_[i] = ((int *)paddings->data_)[i]; } - int output_shape[MAX_SHAPE_SIZE]; + int output_shape[MAX_SHAPE_SIZE] = {0}; size_t output_shape_size = 0; if (input->shape_size_ > 4) { return NNACL_INPUT_TENSOR_ERROR; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/random_standard_normal_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/random_standard_normal_infer.c index 584621087e..ae8640ac41 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/random_standard_normal_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/random_standard_normal_infer.c @@ -36,7 +36,7 @@ int RandomStandardNormalInferShape(const TensorC *const *inputs, size_t inputs_s return NNACL_INFER_INVALID; } int input_num = GetElementNum(inputs[0]); - int output_shape[MAX_SHAPE_SIZE]; + int output_shape[MAX_SHAPE_SIZE] = {0}; size_t output_shape_size = 0; for (int i = 0; i < input_num; i++) { ShapePush(output_shape, &output_shape_size, input_data[i]); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/reduce_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/reduce_infer.c index 9b2eeb1d67..b5e28ae5a1 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/reduce_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/reduce_infer.c @@ -28,7 +28,7 @@ int ReduceOnAllAxes(const TensorC *input, TensorC *output, int *out_shape, size_ return NNACL_OK; } -int ReduceOnSelectedAxes(const TensorC *input, size_t num_axes, int *actual_axes, TensorC *output, int *out_shape, +int ReduceOnSelectedAxes(const TensorC *input, size_t num_axes, const int *actual_axes, TensorC *output, int *out_shape, size_t out_shape_size, bool keep_dims) { for (size_t i = 0; i < input->shape_size_; i++) { bool reduce_axis = false; @@ -67,7 +67,7 @@ int ReduceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * return NNACL_INFER_INVALID; } bool keep_dims = param->keep_dims_; - int out_shape[MAX_SHAPE_SIZE]; + int out_shape[MAX_SHAPE_SIZE] = {0}; const size_t out_shape_size = 0; // get axes from input tensor const TensorC *axes_input = inputs[1]; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/reshape_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/reshape_infer.c index 5f0d8c938f..039620f3ed 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/reshape_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/reshape_infer.c @@ -17,7 +17,7 @@ #include "nnacl/infer/reshape_infer.h" #include "nnacl/infer/infer_register.h" -void CalShape(int *data, const TensorC *const *inputs, int *out_shape, size_t *out_shape_size, int shape_size) { +void CalShape(const int *data, const TensorC *const *inputs, int *out_shape, size_t *out_shape_size, int shape_size) { int input_count = GetElementNum(inputs[0]); int index = 0; int size = 1; @@ -68,6 +68,9 @@ int CalNewShape(const TensorC *in_tensor, int *out_shape, size_t out_shape_size) return NNACL_ERR; } if (infer_index != -1) { + if (out_shape_size_new == 0) { + return NNACL_ERR; + } out_shape[infer_index] = in_shape_size / out_shape_size_new; } return NNACL_OK; @@ -118,6 +121,9 @@ int CalShapeByType(const TensorC *const *inputs, size_t shape_size, int *out_sha case kNumberTypeUInt32: { uint32_t *data = (uint32_t *)(shape_tensor->data_); int *data_int = (int *)malloc(sizeof(int) * shape_size); + if (data_int == NULL) { + return NNACL_ERR; + } for (size_t i = 0; i < shape_size; i++) { data_int[i] = data[i]; } @@ -147,7 +153,7 @@ int ReshapeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC return NNACL_INFER_INVALID; } - int out_shape[MAX_SHAPE_SIZE]; + int out_shape[MAX_SHAPE_SIZE] = {0}; size_t out_shape_size = 0; if (inputs_size == 2) { const TensorC *shape_tensor = inputs[1]; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/resize_grad_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/resize_grad_infer.c index 08b1243a8b..b9b1254dd2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/resize_grad_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/resize_grad_infer.c @@ -39,7 +39,7 @@ int ResizeGradInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso if (input_1->shape_size_ == 4) { ShapeSet(output->shape_, &output->shape_size_, input_1->shape_, input_1->shape_size_); } else if (input_1->shape_size_ == 1 && input_1->shape_[0] == 2 && input_1->data_type_ == kNumberTypeInt32) { - int output_shape[MAX_SHAPE_SIZE]; + int output_shape[MAX_SHAPE_SIZE] = {0}; size_t output_shape_size = 0; int32_t *data = (int32_t *)(input_1->data_); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/resize_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/resize_infer.c index 562c9137e9..4a53d353bd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/resize_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/resize_infer.c @@ -127,7 +127,7 @@ int ResizeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC * return NNACL_INFER_INVALID; } - int output_shape[MAX_SHAPE_SIZE]; + int output_shape[MAX_SHAPE_SIZE] = {0}; size_t output_shape_size = 0; ShapePush(output_shape, &output_shape_size, GetBatch(input)); int ret = CalculateNewHeightAndWidth(inputs, inputs_size, param); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/rfft_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/rfft_infer.c index 695f8d8470..288c85eefb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/rfft_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/rfft_infer.c @@ -35,6 +35,9 @@ int RfftInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **o } ShapeSet(output->shape_, &(output->shape_size_), input->shape_, input->shape_size_); RfftParameter *param = (RfftParameter *)parameter; + if (input->shape_size_ < 1) { + return NNACL_ERR; + } output->shape_[input->shape_size_ - 1] = param->fft_length_ / 2 + 1; ShapePush(output->shape_, &(output->shape_size_), 2); return NNACL_OK; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/sparse_to_dense_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/sparse_to_dense_infer.c index 9c74075ca1..65ec900a99 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/sparse_to_dense_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/sparse_to_dense_infer.c @@ -34,7 +34,7 @@ int SparseToDenseInferShape(const TensorC *const *inputs, size_t inputs_size, Te return NNACL_INFER_INVALID; } int *input1_data = (int *)(input1->data_); - int output_shape[MAX_SHAPE_SIZE]; + int output_shape[MAX_SHAPE_SIZE] = {0}; size_t output_shape_size = 0; for (int i = 0; i < GetElementNum(input1); i++) { ShapePush(output_shape, &output_shape_size, input1_data[i]); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/squeeze_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/squeeze_infer.c index 766e90d2b8..c8fc6ac18a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/squeeze_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/squeeze_infer.c @@ -32,7 +32,7 @@ int SqueezeInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC if (!parameter->infer_flag_) { return NNACL_INFER_INVALID; } - int out_shape[MAX_SHAPE_SIZE]; + int out_shape[MAX_SHAPE_SIZE] = {0}; size_t out_shape_size = 0; for (size_t i = 0; i < param->axis_size_; i++) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/stack_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/stack_infer.c index 3bebcd5b0e..a5d8303bab 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/stack_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/stack_infer.c @@ -31,7 +31,7 @@ int StackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC ** if (!parameter->infer_flag_) { return NNACL_INFER_INVALID; } - int32_t output_shape[MAX_SHAPE_SIZE]; + int32_t output_shape[MAX_SHAPE_SIZE] = {0}; size_t output_shape_size = 0; ShapeSet(output_shape, &output_shape_size, input->shape_, input->shape_size_); int axis = param->axis_ < 0 ? param->axis_ + input->shape_size_ + 1 : param->axis_; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/strided_slice_grad_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/strided_slice_grad_infer.c index 2f1501deb2..096e73e5ba 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/strided_slice_grad_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/strided_slice_grad_infer.c @@ -39,16 +39,16 @@ int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size, SetDataTypeFormat(outputs[0], input); bool inferflag = parameter->infer_flag_; - int in_shape_[MAX_SHAPE_SIZE]; + int in_shape_[MAX_SHAPE_SIZE] = {0}; size_t in_shape_size = 0; if (inferflag) { ShapeSet(in_shape_, &in_shape_size, input->shape_, input->shape_size_); } - int begins_[MAX_SHAPE_SIZE]; + int begins_[MAX_SHAPE_SIZE] = {0}; size_t begins_size = 0; - int ends_[MAX_SHAPE_SIZE]; + int ends_[MAX_SHAPE_SIZE] = {0}; size_t ends_size = 0; - int strides_[MAX_SHAPE_SIZE]; + int strides_[MAX_SHAPE_SIZE] = {0}; size_t strides_size = 0; if (!StridedSliceCheckInputs(inputs, inputs_size)) { @@ -69,17 +69,17 @@ int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size, } // set all mask to original input shape - uint32_t begins_mask_[MAX_SHAPE_SIZE]; - uint32_t ends_mask_[MAX_SHAPE_SIZE]; - uint32_t ellipsis_mask_[MAX_SHAPE_SIZE]; - uint32_t new_axis_mask_[MAX_SHAPE_SIZE]; + uint32_t begins_mask_[MAX_SHAPE_SIZE] = {0}; + uint32_t ends_mask_[MAX_SHAPE_SIZE] = {0}; + uint32_t ellipsis_mask_[MAX_SHAPE_SIZE] = {0}; + uint32_t new_axis_mask_[MAX_SHAPE_SIZE] = {0}; StridedSliceParameter *param = (StridedSliceParameter *)parameter; for (size_t i = 0; i < ndim_; i++) { - begins_mask_[i] = (bool)(param->begins_mask_) & (1 << i); - ends_mask_[i] = (bool)(param->ends_mask_) & (1 << i); - ellipsis_mask_[i] = (bool)(param->ellipsisMask_) & (1 << i); - new_axis_mask_[i] = (bool)(param->newAxisMask_) & (1 << i); + begins_mask_[i] = (unsigned)(param->begins_mask_) & (1 << i); + ends_mask_[i] = (unsigned)(param->ends_mask_) & (1 << i); + ellipsis_mask_[i] = (unsigned)(param->ellipsisMask_) & (1 << i); + new_axis_mask_[i] = (unsigned)(param->newAxisMask_) & (1 << i); } param->num_axes_ = in_shape_size; param->in_shape_length_ = in_shape_size; @@ -133,7 +133,7 @@ int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size, } size_t output_size = inputs[1]->shape_[0]; - int output_shape[MAX_SHAPE_SIZE]; + int output_shape[MAX_SHAPE_SIZE] = {0}; size_t output_shape_size = 0; if (inputs[1]->data_ == NULL) { return NNACL_ERR; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/strided_slice_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/strided_slice_infer.c index a118b3abc5..a067336c23 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/strided_slice_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/strided_slice_infer.c @@ -94,9 +94,6 @@ int HandleAxesInputExist(const TensorC *const *inputs, int *ndim, int *in_shape, return NNACL_ERR; } stride_data = (int *)(stride_tensor->data_); - if (stride_data == NULL) { - return NNACL_ERR; - } } int axes[MAX_SHAPE_SIZE]; @@ -138,6 +135,9 @@ int HandleAxesInputExist(const TensorC *const *inputs, int *ndim, int *in_shape, // begins or ends exceed limit will be set to limit begins[i] = imax(imin(begin_data[axis], input_tensor->shape_[i] - 1), -input_tensor->shape_[i]); ends[i] = imax(imin(end_data[axis], input_tensor->shape_[i]), -input_tensor->shape_[i] - 1); + if (stride_data == NULL) { + return NNACL_ERR; + } strides[i] = stride_data[axis]; } else { begins[i] = 0; @@ -164,12 +164,12 @@ int StrideSlicePreCheck(const TensorC *const *inputs, size_t inputs_size, Tensor } void Bit2Vector(StridedSliceTransferBuffer *transfer_buffer, StridedSliceParameter *param) { - for (int i = 0; i < transfer_buffer->ndim_; i++) { - transfer_buffer->begins_mask_[i] = (uint32_t)(param->begins_mask_) & (1 << i); - transfer_buffer->ends_mask_[i] = (uint32_t)(param->ends_mask_) & (1 << i); - transfer_buffer->ellipsis_mask_[i] = (uint32_t)(param->ellipsisMask_) & (1 << i); - transfer_buffer->new_axis_mask_[i] = (uint32_t)(param->newAxisMask_) & (1 << i); - transfer_buffer->shrink_axis_mask_[i] = (uint32_t)(param->shrinkAxisMask_) & (1 << i); + for (unsigned i = 0; i < (unsigned)transfer_buffer->ndim_; i++) { + transfer_buffer->begins_mask_[i] = (unsigned)(param->begins_mask_) & (1 << i); + transfer_buffer->ends_mask_[i] = (unsigned)(param->ends_mask_) & (1 << i); + transfer_buffer->ellipsis_mask_[i] = (unsigned)(param->ellipsisMask_) & (1 << i); + transfer_buffer->new_axis_mask_[i] = (unsigned)(param->newAxisMask_) & (1 << i); + transfer_buffer->shrink_axis_mask_[i] = (unsigned)(param->shrinkAxisMask_) & (1 << i); } } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_getitem_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_getitem_infer.c index e54de2aef7..e0ce890ff6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_getitem_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_getitem_infer.c @@ -63,7 +63,7 @@ int TensorListGetItemInferShape(const TensorC *const *inputs, size_t inputs_size return NNACL_NULL_PTR; } int *ele_shape_data = (int *)(input2->data_); - int element_shape[MAX_SHAPE_SIZE]; + int element_shape[MAX_SHAPE_SIZE] = {0}; size_t element_shape_size = 0; for (int i = 0; i < GetElementNum(input2); ++i) { ShapePush(element_shape, &element_shape_size, ele_shape_data[i]); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_stack_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_stack_infer.c index 3402dd1f1d..c64936525a 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_stack_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/tensorlist_stack_infer.c @@ -40,7 +40,7 @@ int TensorListStackInferShape(const TensorC *const *inputs, size_t inputs_size, return NNACL_NULL_PTR; } int *ele_shape_ptr = (int *)(ele_shape->data_); - int output_shape[MAX_SHAPE_SIZE]; + int output_shape[MAX_SHAPE_SIZE] = {0}; size_t output_shape_size = 0; for (int i = 0; i < GetElementNum(ele_shape); ++i) { ShapePush(output_shape, &output_shape_size, ele_shape_ptr[i]); diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/transpose_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/transpose_infer.c index 24e974a36d..13f1ec740e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/transpose_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/transpose_infer.c @@ -52,12 +52,12 @@ int TransposeInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor if (perms_num != 0 && perm_data == NULL) { return NNACL_INFER_INVALID; } - int perm[MAX_SHAPE_SIZE]; + int perm[MAX_SHAPE_SIZE] = {0}; size_t perm_size = 0; for (size_t i = 0; i < perms_num; i++) { ShapePush(perm, &perm_size, perm_data[i]); } - int out_shape[MAX_SHAPE_SIZE]; + int out_shape[MAX_SHAPE_SIZE] = {0}; if (input->shape_size_ != 4 && perms_num == 4) { for (size_t i = 0; i < input->shape_size_; ++i) { out_shape[i] = input->shape_[i]; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/unsorted_segment_sum_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/unsorted_segment_sum_infer.c index 55f218d54a..e75483f8c2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/unsorted_segment_sum_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/unsorted_segment_sum_infer.c @@ -30,7 +30,7 @@ int UnsortedSegmentSumInferShape(const TensorC *const *inputs, size_t inputs_siz const TensorC *x = inputs[0]; const TensorC *segment_id = inputs[1]; int num_segments = *(int *)(inputs[2]->data_); - int output_shape[MAX_SHAPE_SIZE]; + int output_shape[MAX_SHAPE_SIZE] = {0}; size_t output_shape_size = 0; ShapePush(output_shape, &output_shape_size, num_segments); for (int index = segment_id->shape_size_; index < (int)(x->shape_size_); index++) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/unsqueeze_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/unsqueeze_infer.c index 428b856e27..2ba99f9d38 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/unsqueeze_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/unsqueeze_infer.c @@ -37,7 +37,7 @@ int UnsqueezeInferShape(const TensorC *const *inputs, size_t inputs_size, Tensor UnSqueezeParameter *param = (UnSqueezeParameter *)parameter; int in_rank = input->shape_size_; int dim_rank = param->num_dim_; - int out_shape[MAX_SHAPE_SIZE]; + int out_shape[MAX_SHAPE_SIZE] = {0}; size_t out_shape_size = 0; if (dim_rank == 0) { for (size_t i = 0; i < input->shape_size_; i++) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/unstack_infer.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/unstack_infer.c index bad2cb768a..c947be6372 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/unstack_infer.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/infer/unstack_infer.c @@ -39,7 +39,7 @@ int UnstackInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC if (!parameter->infer_flag_) { return NNACL_INFER_INVALID; } - int output_shape[MAX_SHAPE_SIZE]; + int output_shape[MAX_SHAPE_SIZE] = {0}; size_t output_shape_size = 0; for (size_t i = 0; i < input->shape_size_; ++i) { if (i != axis) {