From c060fcbafb486feec405cdf474cb3e9a7af37690 Mon Sep 17 00:00:00 2001 From: zhaodezan Date: Thu, 11 Mar 2021 21:11:20 +0800 Subject: [PATCH] reorganize strided_slice_infer code --- .../lite/nnacl/infer/strided_slice_infer.c | 334 ++++++++++-------- 1 file changed, 191 insertions(+), 143 deletions(-) diff --git a/mindspore/lite/nnacl/infer/strided_slice_infer.c b/mindspore/lite/nnacl/infer/strided_slice_infer.c index b7d336519b..a6b16b2c0e 100644 --- a/mindspore/lite/nnacl/infer/strided_slice_infer.c +++ b/mindspore/lite/nnacl/infer/strided_slice_infer.c @@ -21,6 +21,26 @@ const size_t kStridedSliceInputNum = 1; const size_t kStridedSliceMultiInputNumMin = 3; const size_t kStridedSliceMultiInputNumMax = 5; +typedef struct StridedSliceTransferBuffer { + int ndim_; + + int begins_[MAX_SHAPE_SIZE]; + int ends_[MAX_SHAPE_SIZE]; + int strides_[MAX_SHAPE_SIZE]; + int begins_mask_[MAX_SHAPE_SIZE]; + int ends_mask_[MAX_SHAPE_SIZE]; + int ellipsis_mask_[MAX_SHAPE_SIZE]; + int new_axis_mask_[MAX_SHAPE_SIZE]; + int shrink_axis_mask_[MAX_SHAPE_SIZE]; + + size_t begins_size_; + size_t ends_size_; + size_t strides_size_; + size_t ellipsis_mask_size_; + size_t new_axis_mask_size_; + size_t shrink_axis_mask_size_; +} StridedSliceTransferBuffer; + bool CheckInputs(const TensorC *const *inputs, size_t inputs_size) { for (size_t i = 1; i < inputs_size; ++i) { if (inputs[i]->data_ == NULL) { @@ -128,10 +148,8 @@ int HandleAxesInputExist(const TensorC *const *inputs, int *ndim_, int *in_shape return NNACL_OK; } -// note: begin, end, stride length are equal, but may less than rank of input -int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, - OpParameter *parameter) { -#ifdef Debug +int StrideSlicePreCheck(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { if (outputs_size != kStridedSliceOutputNum) { return NNACL_PARAM_INVALID; } @@ -142,6 +160,138 @@ int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, Ten if (parameter == NULL || outputs[0] == NULL || inputs[0] == NULL) { return NNACL_NULL_PTR; } + return NNACL_OK; +} + +void Bit2Vector(StridedSliceTransferBuffer *transfer_buffer, StridedSliceParameter *param) { + for (int i = 0; i < transfer_buffer->ndim_; i++) { + transfer_buffer->begins_mask_[i] = (uint32_t)(param->begins_mask_) & (1 << i); + transfer_buffer->ends_mask_[i] = (uint32_t)(param->ends_mask_) & (1 << i); + transfer_buffer->ellipsis_mask_[i] = (uint32_t)(param->ellipsisMask_) & (1 << i); + transfer_buffer->new_axis_mask_[i] = (uint32_t)(param->newAxisMask_) & (1 << i); + transfer_buffer->shrink_axis_mask_[i] = (uint32_t)(param->shrinkAxisMask_) & (1 << i); + } +} + +void ApplyNewAxisMask(StridedSliceTransferBuffer *transfer_buffer, StridedSliceParameter *param, int *in_shape_, + size_t *in_shape_size) { + for (size_t i = 0; i < transfer_buffer->new_axis_mask_size_; i++) { + if (transfer_buffer->new_axis_mask_[i]) { + transfer_buffer->ndim_ += 1; + ShapeInsert(in_shape_, in_shape_size, i, 1); + transfer_buffer->begins_[i] = 0; + transfer_buffer->ends_[i] = 1; + transfer_buffer->strides_[i] = 1; + + ShapePush(transfer_buffer->begins_, &transfer_buffer->begins_size_, 0); + ShapePush(transfer_buffer->ends_, &transfer_buffer->ends_size_, in_shape_[transfer_buffer->ndim_ - 1]); + ShapePush(transfer_buffer->strides_, &transfer_buffer->strides_size_, 1); + + transfer_buffer->begins_mask_[i] = false; + transfer_buffer->ends_mask_[i] = false; + transfer_buffer->ellipsis_mask_[i] = false; + transfer_buffer->shrink_axis_mask_[i] = false; + } + } +} + +void ApplyBeginMask(StridedSliceTransferBuffer *transfer_buffer) { + for (int i = 0; i < transfer_buffer->ndim_; i++) { + if (transfer_buffer->begins_mask_[i]) { + transfer_buffer->begins_[i] = 0; + } + } +} + +void ApplyEndMask(StridedSliceTransferBuffer *transfer_buffer, int *in_shape_) { + for (int i = 0; i < transfer_buffer->ndim_; i++) { + if (transfer_buffer->ends_mask_[i]) { + transfer_buffer->ends_[i] = in_shape_[i]; + } + } +} + +void ApplyEllipsisMask(StridedSliceTransferBuffer *transfer_buffer, int *in_shape_) { + for (size_t i = 0; i < transfer_buffer->ellipsis_mask_size_; i++) { + if (transfer_buffer->ellipsis_mask_[i]) { + transfer_buffer->begins_[i] = 0; + transfer_buffer->ends_[i] = in_shape_[i]; + break; + } + } +} + +void TransIndexToPositive(StridedSliceTransferBuffer *transfer_buffer, int *in_shape_) { + for (int i = 0; i < (int)(transfer_buffer->begins_size_); ++i) { + if (transfer_buffer->begins_[i] < 0) { + transfer_buffer->begins_[i] += in_shape_[i]; + } + if (transfer_buffer->ends_[i] < 0) { + transfer_buffer->ends_[i] += in_shape_[i]; + } + } +} + +void ApplyShrinkMask(StridedSliceTransferBuffer *transfer_buffer, int *output_shape, size_t *output_shape_size) { + int old_out_shape[MAX_SHAPE_SIZE]; + size_t old_out_shape_size = 0; + ShapeSet(old_out_shape, &old_out_shape_size, output_shape, *output_shape_size); + *output_shape_size = 0; + for (size_t i = 0; i < transfer_buffer->shrink_axis_mask_size_; i++) { + if (transfer_buffer->shrink_axis_mask_[i]) { + transfer_buffer->ends_[i] = transfer_buffer->begins_[i] + 1; + transfer_buffer->strides_[i] = 1; + } else { + ShapePush(output_shape, output_shape_size, old_out_shape[i]); + } + } + for (size_t i = transfer_buffer->shrink_axis_mask_size_; i < old_out_shape_size; i++) { + ShapePush(output_shape, output_shape_size, old_out_shape[i]); + } +} + +void TransferBuffer2Param(StridedSliceTransferBuffer *transfer_buffer, StridedSliceParameter *param, int *in_shape_) { + for (int i = 0; i < transfer_buffer->ndim_; i++) { + param->begins_[i] = transfer_buffer->begins_[i]; + param->ends_[i] = transfer_buffer->ends_[i]; + param->in_shape_[i] = in_shape_[i]; + param->strides_[i] = transfer_buffer->strides_[i]; + } + + for (int i = transfer_buffer->ndim_; i < param->in_shape_length_; i++) { + param->begins_[i] = 0; + param->ends_[i] = in_shape_[i]; + param->in_shape_[i] = in_shape_[i]; + param->strides_[i] = 1; + } +} + +void InitStridedSliceTransferBuffer(StridedSliceTransferBuffer *transfer_buffer) { + transfer_buffer->begins_size_ = 0; + transfer_buffer->ends_size_ = 0; + transfer_buffer->strides_size_ = 0; + transfer_buffer->ellipsis_mask_size_ = 0; + transfer_buffer->new_axis_mask_size_ = 0; + transfer_buffer->shrink_axis_mask_size_ = 0; +} + +void SetMaskSize(StridedSliceTransferBuffer *transfer_buffer) { + transfer_buffer->ellipsis_mask_size_ = transfer_buffer->ndim_; + transfer_buffer->new_axis_mask_size_ = transfer_buffer->ndim_; + transfer_buffer->shrink_axis_mask_size_ = transfer_buffer->ndim_; + transfer_buffer->begins_size_ = transfer_buffer->ndim_; + transfer_buffer->ends_size_ = transfer_buffer->ndim_; + transfer_buffer->strides_size_ = transfer_buffer->ndim_; +} + +// note: begin, end, stride length are equal, but may less than rank of input +int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, + OpParameter *parameter) { +#ifdef Debug + int check_ret = StrideSlicePreCheck(inputs, inputs_size, outputs, outputs_size, parameter); + if (check_ret != NNACL_OK) { + return check_ret; + } #endif const TensorC *input = inputs[0]; @@ -152,42 +302,29 @@ int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, Ten } int in_shape_[MAX_SHAPE_SIZE]; - int begins_[MAX_SHAPE_SIZE]; - int ends_[MAX_SHAPE_SIZE]; - size_t in_shape_size_ = 0; - if (parameter->infer_flag_) { - ShapeSet(in_shape_, &in_shape_size_, input->shape_, input->shape_size_); - } - size_t begins_size_ = 0; - size_t ends_size_ = 0; - int strides_[MAX_SHAPE_SIZE]; - size_t strides_size_ = 0; - int begins_mask_[MAX_SHAPE_SIZE]; - int ends_mask_[MAX_SHAPE_SIZE]; - int ellipsis_mask_[MAX_SHAPE_SIZE]; - size_t ellipsis_mask_size_ = 0; - int new_axis_mask_[MAX_SHAPE_SIZE]; - size_t new_axis_mask_size_ = 0; - int shrink_axis_mask_[MAX_SHAPE_SIZE]; - size_t shrink_axis_mask_size_ = 0; + size_t in_shape_size = 0; + ShapeSet(in_shape_, &in_shape_size, input->shape_, input->shape_size_); + + StridedSliceTransferBuffer transfer_buffer; + InitStridedSliceTransferBuffer(&transfer_buffer); StridedSliceParameter *param = (StridedSliceParameter *)parameter; - param->num_axes_ = in_shape_size_; - param->in_shape_length_ = in_shape_size_; + param->num_axes_ = in_shape_size; + param->in_shape_length_ = in_shape_size; - int ndim_ = 0; + transfer_buffer.ndim_ = 0; if (inputs_size == kStridedSliceInputNum) { - ndim_ = (int)(param->num_axes_); - - for (int i = 0; i < ndim_; i++) { - ShapePush(begins_, &begins_size_, param->begins_[i]); - ShapePush(ends_, &ends_size_, param->ends_[i]); - ShapePush(strides_, &strides_size_, param->strides_[i]); + transfer_buffer.ndim_ = (int)(param->num_axes_); + for (int i = 0; i < transfer_buffer.ndim_; i++) { + ShapePush(transfer_buffer.begins_, &transfer_buffer.begins_size_, param->begins_[i]); + ShapePush(transfer_buffer.ends_, &transfer_buffer.ends_size_, param->ends_[i]); + ShapePush(transfer_buffer.strides_, &transfer_buffer.strides_size_, param->strides_[i]); } } if (!CheckInputs(inputs, inputs_size)) { return NNACL_INFER_INVALID; } + if (inputs_size == 4) { const TensorC *begin_tensor = inputs[1]; int *begin_data = (int *)(begin_tensor->data_); @@ -198,134 +335,45 @@ int StridedSliceInferShape(const TensorC *const *inputs, size_t inputs_size, Ten if (begin_data == NULL || end_data == NULL || stride_data == NULL) { return NNACL_ERR; } - ndim_ = GetElementNum(begin_tensor); - for (int i = 0; i < ndim_; ++i) { - ShapePush(begins_, &begins_size_, begin_data[i]); - ShapePush(ends_, &ends_size_, end_data[i]); - ShapePush(strides_, &strides_size_, stride_data[i]); + transfer_buffer.ndim_ = GetElementNum(begin_tensor); + for (int i = 0; i < transfer_buffer.ndim_; ++i) { + ShapePush(transfer_buffer.begins_, &transfer_buffer.begins_size_, begin_data[i]); + ShapePush(transfer_buffer.ends_, &transfer_buffer.ends_size_, end_data[i]); + ShapePush(transfer_buffer.strides_, &transfer_buffer.strides_size_, stride_data[i]); } } + if (inputs_size == 5) { - int ret = HandleAxesInputExist(inputs, &ndim_, in_shape_, begins_, strides_, ends_); + int ret = HandleAxesInputExist(inputs, &transfer_buffer.ndim_, in_shape_, transfer_buffer.begins_, + transfer_buffer.strides_, transfer_buffer.ends_); if (ret != NNACL_OK) { return ret; } } // set all mask to original input shape - ellipsis_mask_size_ = ndim_; - new_axis_mask_size_ = ndim_; - shrink_axis_mask_size_ = ndim_; - begins_size_ = ndim_; - ends_size_ = ndim_; - strides_size_ = ndim_; - - // convert bit to vector - for (int i = 0; i < ndim_; i++) { - begins_mask_[i] = (uint32_t)(param->begins_mask_) & (1 << i); - ends_mask_[i] = (uint32_t)(param->ends_mask_) & (1 << i); - ellipsis_mask_[i] = (uint32_t)(param->ellipsisMask_) & (1 << i); - new_axis_mask_[i] = (uint32_t)(param->newAxisMask_) & (1 << i); - shrink_axis_mask_[i] = (uint32_t)(param->shrinkAxisMask_) & (1 << i); - } - - // ApplyNewAxisMask(); - for (size_t i = 0; i < new_axis_mask_size_; i++) { - if (new_axis_mask_[i]) { - ndim_ += 1; - ShapeInsert(in_shape_, &in_shape_size_, i, 1); - begins_[i] = 0; - ends_[i] = 1; - strides_[i] = 1; - - ShapePush(begins_, &begins_size_, 0); - ShapePush(ends_, &ends_size_, in_shape_[ndim_ - 1]); - ShapePush(strides_, &strides_size_, 1); - - begins_mask_[i] = false; - ends_mask_[i] = false; - ellipsis_mask_[i] = false; - shrink_axis_mask_[i] = false; - } - } - // ApplyBeginMask(); - for (int i = 0; i < ndim_; i++) { - if (begins_mask_[i]) { - begins_[i] = 0; - } - } - // ApplyEndMask(); - for (int i = 0; i < ndim_; i++) { - if (ends_mask_[i]) { - ends_[i] = in_shape_[i]; - } - } - // ApplyEllipsisMask(); - for (size_t i = 0; i < ellipsis_mask_size_; i++) { - if (ellipsis_mask_[i]) { - begins_[i] = 0; - ends_[i] = in_shape_[i]; - break; - } - } - - if (!parameter->infer_flag_) { - return NNACL_INFER_INVALID; - } + SetMaskSize(&transfer_buffer); + Bit2Vector(&transfer_buffer, param); + ApplyNewAxisMask(&transfer_buffer, param, in_shape_, &in_shape_size); + ApplyBeginMask(&transfer_buffer); + ApplyEndMask(&transfer_buffer, in_shape_); + ApplyEllipsisMask(&transfer_buffer, in_shape_); int output_shape[MAX_SHAPE_SIZE]; size_t output_shape_size = 0; - ShapeSet(output_shape, &output_shape_size, in_shape_, in_shape_size_); - - // TransIndexToPositive(); - for (int i = 0; i < (int)(begins_size_); ++i) { - if (begins_[i] < 0) { - begins_[i] += in_shape_[i]; - } - if (ends_[i] < 0) { - ends_[i] += in_shape_[i]; - } - } - - for (int i = 0; i < ndim_; i++) { - if (strides_[i] == 0) { + ShapeSet(output_shape, &output_shape_size, in_shape_, in_shape_size); + TransIndexToPositive(&transfer_buffer, in_shape_); + for (int i = 0; i < transfer_buffer.ndim_; i++) { + if (transfer_buffer.strides_[i] == 0) { return NNACL_ERR; } - output_shape[i] = (ends_[i] - begins_[i] + strides_[i] + (strides_[i] < 0 ? 1 : -1)) / strides_[i]; + output_shape[i] = (transfer_buffer.ends_[i] - transfer_buffer.begins_[i] + transfer_buffer.strides_[i] + + (transfer_buffer.strides_[i] < 0 ? 1 : -1)) / + transfer_buffer.strides_[i]; } - - // ApplyShrinkMask - int old_out_shape[MAX_SHAPE_SIZE]; - size_t old_out_shape_size = 0; - ShapeSet(old_out_shape, &old_out_shape_size, output_shape, output_shape_size); - output_shape_size = 0; - for (size_t i = 0; i < shrink_axis_mask_size_; i++) { - if (shrink_axis_mask_[i]) { - ends_[i] = begins_[i] + 1; - strides_[i] = 1; - } else { - ShapePush(output_shape, &output_shape_size, old_out_shape[i]); - } - } - for (size_t i = shrink_axis_mask_size_; i < old_out_shape_size; i++) { - ShapePush(output_shape, &output_shape_size, old_out_shape[i]); - } - + ApplyShrinkMask(&transfer_buffer, output_shape, &output_shape_size); SetShapeArray(outputs[0], output_shape, output_shape_size); - - for (int i = 0; i < ndim_; i++) { - param->begins_[i] = begins_[i]; - param->ends_[i] = ends_[i]; - param->in_shape_[i] = in_shape_[i]; - param->strides_[i] = strides_[i]; - } - - for (int i = ndim_; i < param->in_shape_length_; i++) { - param->begins_[i] = 0; - param->ends_[i] = in_shape_[i]; - param->in_shape_[i] = in_shape_[i]; - param->strides_[i] = 1; - } + TransferBuffer2Param(&transfer_buffer, param, in_shape_); return NNACL_OK; }