| @@ -55,7 +55,69 @@ void PadStridedSliceParameterTo8D(StridedSliceParameter *param) { | |||||
| bool LoopContinue(int stride, int i, int end) { return stride > 0 ? i < end : i > end; } | bool LoopContinue(int stride, int i, int end) { return stride > 0 ? i < end : i > end; } | ||||
| int DoStridedSliceIntFp64Bool(const void *in_data, void *out_data, StridedSliceParameter *param) { | |||||
| if (in_data == NULL || out_data == NULL || param == NULL) { | |||||
| return NNACL_NULL_PTR; | |||||
| } | |||||
| if (param->num_axes_ > DIMENSION_8D) { | |||||
| return NNACL_PARAM_INVALID; | |||||
| } | |||||
| int *begins = param->begins_; | |||||
| int *ends = param->ends_; | |||||
| int *strides = param->strides_; | |||||
| int *in_shape = param->in_shape_; | |||||
| if (param->num_axes_ < DIMENSION_8D) { | |||||
| PadStridedSliceParameterTo8D(param); | |||||
| } | |||||
| size_t dim_offset[DIMENSION_8D - 1]; | |||||
| dim_offset[6] = in_shape[7]; | |||||
| dim_offset[5] = in_shape[6] * dim_offset[6]; | |||||
| dim_offset[4] = in_shape[5] * dim_offset[5]; | |||||
| dim_offset[3] = in_shape[4] * dim_offset[4]; | |||||
| dim_offset[2] = in_shape[3] * dim_offset[3]; | |||||
| dim_offset[1] = in_shape[2] * dim_offset[2]; | |||||
| dim_offset[0] = in_shape[1] * dim_offset[1]; | |||||
| size_t out_offset = 0; | |||||
| int32_t dim0, dim1, dim2, dim3, dim4, dim5, dim6, dim7; | |||||
| for (dim0 = begins[0]; LoopContinue(strides[0], dim0, ends[0]); dim0 += strides[0]) { | |||||
| for (dim1 = begins[1]; LoopContinue(strides[1], dim1, ends[1]); dim1 += strides[1]) { | |||||
| for (dim2 = begins[2]; LoopContinue(strides[2], dim2, ends[2]); dim2 += strides[2]) { | |||||
| for (dim3 = begins[3]; LoopContinue(strides[3], dim3, ends[3]); dim3 += strides[3]) { | |||||
| for (dim4 = begins[4]; LoopContinue(strides[4], dim4, ends[4]); dim4 += strides[4]) { | |||||
| for (dim5 = begins[5]; LoopContinue(strides[5], dim5, ends[5]); dim5 += strides[5]) { | |||||
| for (dim6 = begins[6]; LoopContinue(strides[6], dim6, ends[6]); dim6 += strides[6]) { | |||||
| for (dim7 = begins[7]; LoopContinue(strides[7], dim7, ends[7]); dim7 += strides[7]) { | |||||
| int32_t in_offset = dim0 * dim_offset[0] + dim1 * dim_offset[1] + dim2 * dim_offset[2] + | |||||
| dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5 * dim_offset[5] + | |||||
| dim6 * dim_offset[6] + dim7; | |||||
| if (param->data_type == kDataTypeInt) { | |||||
| *((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset); | |||||
| } else if (param->data_type == kDataTypeInt8) { | |||||
| *((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset); | |||||
| } else if (param->data_type == kDataTypeBool) { | |||||
| *((bool *)out_data + out_offset) = *((bool *)in_data + in_offset); | |||||
| } else if (param->data_type == kDataTypeFloat64) { | |||||
| *((double *)out_data + out_offset) = *((double *)in_data + in_offset); | |||||
| } else { | |||||
| return NNACL_ERR; | |||||
| } | |||||
| out_offset++; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return NNACL_OK; | |||||
| } | |||||
| int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *param) { | int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *param) { | ||||
| if (param->data_type != kDataTypeFloat && param->data_type != kDataTypeFloat16) { | |||||
| return DoStridedSliceIntFp64Bool(in_data, out_data, param); | |||||
| } | |||||
| if (in_data == NULL || out_data == NULL || param == NULL) { | if (in_data == NULL || out_data == NULL || param == NULL) { | ||||
| return NNACL_NULL_PTR; | return NNACL_NULL_PTR; | ||||
| } | } | ||||
| @@ -93,10 +155,6 @@ int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *p | |||||
| dim6 * dim_offset[6] + dim7; | dim6 * dim_offset[6] + dim7; | ||||
| if (param->data_type == kDataTypeFloat) { | if (param->data_type == kDataTypeFloat) { | ||||
| *((float *)out_data + out_offset) = *((float *)in_data + in_offset); | *((float *)out_data + out_offset) = *((float *)in_data + in_offset); | ||||
| } else if (param->data_type == kDataTypeInt8) { | |||||
| *((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset); | |||||
| } else if (param->data_type == kDataTypeInt) { | |||||
| *((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset); | |||||
| #ifdef ENABLE_ARM64 | #ifdef ENABLE_ARM64 | ||||
| } else if (param->data_type == kDataTypeFloat16) { | } else if (param->data_type == kDataTypeFloat16) { | ||||
| *((float16_t *)out_data + out_offset) = *((float16_t *)in_data + in_offset); | *((float16_t *)out_data + out_offset) = *((float16_t *)in_data + in_offset); | ||||