From 01abb41f95eb462e24a1625fdd30e7d844c1fcfa Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Thu, 29 Apr 2021 11:31:33 +0800 Subject: [PATCH] strided slice fp64 bool --- .../cpu/nnacl/fp32/strided_slice_fp32.c | 66 +++++++++++++++++-- 1 file changed, 62 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/strided_slice_fp32.c b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/strided_slice_fp32.c index 3cb6be5fb9..d510cacccd 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/strided_slice_fp32.c +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/strided_slice_fp32.c @@ -55,7 +55,69 @@ void PadStridedSliceParameterTo8D(StridedSliceParameter *param) { bool LoopContinue(int stride, int i, int end) { return stride > 0 ? i < end : i > end; } +int DoStridedSliceIntFp64Bool(const void *in_data, void *out_data, StridedSliceParameter *param) { + if (in_data == NULL || out_data == NULL || param == NULL) { + return NNACL_NULL_PTR; + } + if (param->num_axes_ > DIMENSION_8D) { + return NNACL_PARAM_INVALID; + } + + int *begins = param->begins_; + int *ends = param->ends_; + int *strides = param->strides_; + int *in_shape = param->in_shape_; + if (param->num_axes_ < DIMENSION_8D) { + PadStridedSliceParameterTo8D(param); + } + size_t dim_offset[DIMENSION_8D - 1]; + dim_offset[6] = in_shape[7]; + dim_offset[5] = in_shape[6] * dim_offset[6]; + dim_offset[4] = in_shape[5] * dim_offset[5]; + dim_offset[3] = in_shape[4] * dim_offset[4]; + dim_offset[2] = in_shape[3] * dim_offset[3]; + dim_offset[1] = in_shape[2] * dim_offset[2]; + dim_offset[0] = in_shape[1] * dim_offset[1]; + size_t out_offset = 0; + int32_t dim0, dim1, dim2, dim3, dim4, dim5, dim6, dim7; + for (dim0 = begins[0]; LoopContinue(strides[0], dim0, ends[0]); dim0 += strides[0]) { + for (dim1 = begins[1]; LoopContinue(strides[1], dim1, ends[1]); dim1 += strides[1]) { + for (dim2 = begins[2]; LoopContinue(strides[2], dim2, ends[2]); dim2 += strides[2]) { + for (dim3 = begins[3]; LoopContinue(strides[3], dim3, ends[3]); dim3 += strides[3]) { + for (dim4 = begins[4]; LoopContinue(strides[4], dim4, ends[4]); dim4 += strides[4]) { + for (dim5 = begins[5]; LoopContinue(strides[5], dim5, ends[5]); dim5 += strides[5]) { + for (dim6 = begins[6]; LoopContinue(strides[6], dim6, ends[6]); dim6 += strides[6]) { + for (dim7 = begins[7]; LoopContinue(strides[7], dim7, ends[7]); dim7 += strides[7]) { + int32_t in_offset = dim0 * dim_offset[0] + dim1 * dim_offset[1] + dim2 * dim_offset[2] + + dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5 * dim_offset[5] + + dim6 * dim_offset[6] + dim7; + if (param->data_type == kDataTypeInt) { + *((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset); + } else if (param->data_type == kDataTypeInt8) { + *((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset); + } else if (param->data_type == kDataTypeBool) { + *((bool *)out_data + out_offset) = *((bool *)in_data + in_offset); + } else if (param->data_type == kDataTypeFloat64) { + *((double *)out_data + out_offset) = *((double *)in_data + in_offset); + } else { + return NNACL_ERR; + } + out_offset++; + } + } + } + } + } + } + } + } + return NNACL_OK; +} + int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *param) { + if (param->data_type != kDataTypeFloat && param->data_type != kDataTypeFloat16) { + return DoStridedSliceIntFp64Bool(in_data, out_data, param); + } if (in_data == NULL || out_data == NULL || param == NULL) { return NNACL_NULL_PTR; } @@ -93,10 +155,6 @@ int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *p dim6 * dim_offset[6] + dim7; if (param->data_type == kDataTypeFloat) { *((float *)out_data + out_offset) = *((float *)in_data + in_offset); - } else if (param->data_type == kDataTypeInt8) { - *((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset); - } else if (param->data_type == kDataTypeInt) { - *((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset); #ifdef ENABLE_ARM64 } else if (param->data_type == kDataTypeFloat16) { *((float16_t *)out_data + out_offset) = *((float16_t *)in_data + in_offset);