| @@ -55,7 +55,69 @@ void PadStridedSliceParameterTo8D(StridedSliceParameter *param) { | |||
| bool LoopContinue(int stride, int i, int end) { return stride > 0 ? i < end : i > end; } | |||
| int DoStridedSliceIntFp64Bool(const void *in_data, void *out_data, StridedSliceParameter *param) { | |||
| if (in_data == NULL || out_data == NULL || param == NULL) { | |||
| return NNACL_NULL_PTR; | |||
| } | |||
| if (param->num_axes_ > DIMENSION_8D) { | |||
| return NNACL_PARAM_INVALID; | |||
| } | |||
| int *begins = param->begins_; | |||
| int *ends = param->ends_; | |||
| int *strides = param->strides_; | |||
| int *in_shape = param->in_shape_; | |||
| if (param->num_axes_ < DIMENSION_8D) { | |||
| PadStridedSliceParameterTo8D(param); | |||
| } | |||
| size_t dim_offset[DIMENSION_8D - 1]; | |||
| dim_offset[6] = in_shape[7]; | |||
| dim_offset[5] = in_shape[6] * dim_offset[6]; | |||
| dim_offset[4] = in_shape[5] * dim_offset[5]; | |||
| dim_offset[3] = in_shape[4] * dim_offset[4]; | |||
| dim_offset[2] = in_shape[3] * dim_offset[3]; | |||
| dim_offset[1] = in_shape[2] * dim_offset[2]; | |||
| dim_offset[0] = in_shape[1] * dim_offset[1]; | |||
| size_t out_offset = 0; | |||
| int32_t dim0, dim1, dim2, dim3, dim4, dim5, dim6, dim7; | |||
| for (dim0 = begins[0]; LoopContinue(strides[0], dim0, ends[0]); dim0 += strides[0]) { | |||
| for (dim1 = begins[1]; LoopContinue(strides[1], dim1, ends[1]); dim1 += strides[1]) { | |||
| for (dim2 = begins[2]; LoopContinue(strides[2], dim2, ends[2]); dim2 += strides[2]) { | |||
| for (dim3 = begins[3]; LoopContinue(strides[3], dim3, ends[3]); dim3 += strides[3]) { | |||
| for (dim4 = begins[4]; LoopContinue(strides[4], dim4, ends[4]); dim4 += strides[4]) { | |||
| for (dim5 = begins[5]; LoopContinue(strides[5], dim5, ends[5]); dim5 += strides[5]) { | |||
| for (dim6 = begins[6]; LoopContinue(strides[6], dim6, ends[6]); dim6 += strides[6]) { | |||
| for (dim7 = begins[7]; LoopContinue(strides[7], dim7, ends[7]); dim7 += strides[7]) { | |||
| int32_t in_offset = dim0 * dim_offset[0] + dim1 * dim_offset[1] + dim2 * dim_offset[2] + | |||
| dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5 * dim_offset[5] + | |||
| dim6 * dim_offset[6] + dim7; | |||
| if (param->data_type == kDataTypeInt) { | |||
| *((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset); | |||
| } else if (param->data_type == kDataTypeInt8) { | |||
| *((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset); | |||
| } else if (param->data_type == kDataTypeBool) { | |||
| *((bool *)out_data + out_offset) = *((bool *)in_data + in_offset); | |||
| } else if (param->data_type == kDataTypeFloat64) { | |||
| *((double *)out_data + out_offset) = *((double *)in_data + in_offset); | |||
| } else { | |||
| return NNACL_ERR; | |||
| } | |||
| out_offset++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *param) { | |||
| if (param->data_type != kDataTypeFloat && param->data_type != kDataTypeFloat16) { | |||
| return DoStridedSliceIntFp64Bool(in_data, out_data, param); | |||
| } | |||
| if (in_data == NULL || out_data == NULL || param == NULL) { | |||
| return NNACL_NULL_PTR; | |||
| } | |||
| @@ -93,10 +155,6 @@ int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *p | |||
| dim6 * dim_offset[6] + dim7; | |||
| if (param->data_type == kDataTypeFloat) { | |||
| *((float *)out_data + out_offset) = *((float *)in_data + in_offset); | |||
| } else if (param->data_type == kDataTypeInt8) { | |||
| *((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset); | |||
| } else if (param->data_type == kDataTypeInt) { | |||
| *((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset); | |||
| #ifdef ENABLE_ARM64 | |||
| } else if (param->data_type == kDataTypeFloat16) { | |||
| *((float16_t *)out_data + out_offset) = *((float16_t *)in_data + in_offset); | |||