|
|
|
@@ -55,7 +55,69 @@ void PadStridedSliceParameterTo8D(StridedSliceParameter *param) { |
|
|
|
|
|
|
|
bool LoopContinue(int stride, int i, int end) { return stride > 0 ? i < end : i > end; } |
|
|
|
|
|
|
|
int DoStridedSliceIntFp64Bool(const void *in_data, void *out_data, StridedSliceParameter *param) { |
|
|
|
if (in_data == NULL || out_data == NULL || param == NULL) { |
|
|
|
return NNACL_NULL_PTR; |
|
|
|
} |
|
|
|
if (param->num_axes_ > DIMENSION_8D) { |
|
|
|
return NNACL_PARAM_INVALID; |
|
|
|
} |
|
|
|
|
|
|
|
int *begins = param->begins_; |
|
|
|
int *ends = param->ends_; |
|
|
|
int *strides = param->strides_; |
|
|
|
int *in_shape = param->in_shape_; |
|
|
|
if (param->num_axes_ < DIMENSION_8D) { |
|
|
|
PadStridedSliceParameterTo8D(param); |
|
|
|
} |
|
|
|
size_t dim_offset[DIMENSION_8D - 1]; |
|
|
|
dim_offset[6] = in_shape[7]; |
|
|
|
dim_offset[5] = in_shape[6] * dim_offset[6]; |
|
|
|
dim_offset[4] = in_shape[5] * dim_offset[5]; |
|
|
|
dim_offset[3] = in_shape[4] * dim_offset[4]; |
|
|
|
dim_offset[2] = in_shape[3] * dim_offset[3]; |
|
|
|
dim_offset[1] = in_shape[2] * dim_offset[2]; |
|
|
|
dim_offset[0] = in_shape[1] * dim_offset[1]; |
|
|
|
size_t out_offset = 0; |
|
|
|
int32_t dim0, dim1, dim2, dim3, dim4, dim5, dim6, dim7; |
|
|
|
for (dim0 = begins[0]; LoopContinue(strides[0], dim0, ends[0]); dim0 += strides[0]) { |
|
|
|
for (dim1 = begins[1]; LoopContinue(strides[1], dim1, ends[1]); dim1 += strides[1]) { |
|
|
|
for (dim2 = begins[2]; LoopContinue(strides[2], dim2, ends[2]); dim2 += strides[2]) { |
|
|
|
for (dim3 = begins[3]; LoopContinue(strides[3], dim3, ends[3]); dim3 += strides[3]) { |
|
|
|
for (dim4 = begins[4]; LoopContinue(strides[4], dim4, ends[4]); dim4 += strides[4]) { |
|
|
|
for (dim5 = begins[5]; LoopContinue(strides[5], dim5, ends[5]); dim5 += strides[5]) { |
|
|
|
for (dim6 = begins[6]; LoopContinue(strides[6], dim6, ends[6]); dim6 += strides[6]) { |
|
|
|
for (dim7 = begins[7]; LoopContinue(strides[7], dim7, ends[7]); dim7 += strides[7]) { |
|
|
|
int32_t in_offset = dim0 * dim_offset[0] + dim1 * dim_offset[1] + dim2 * dim_offset[2] + |
|
|
|
dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5 * dim_offset[5] + |
|
|
|
dim6 * dim_offset[6] + dim7; |
|
|
|
if (param->data_type == kDataTypeInt) { |
|
|
|
*((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset); |
|
|
|
} else if (param->data_type == kDataTypeInt8) { |
|
|
|
*((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset); |
|
|
|
} else if (param->data_type == kDataTypeBool) { |
|
|
|
*((bool *)out_data + out_offset) = *((bool *)in_data + in_offset); |
|
|
|
} else if (param->data_type == kDataTypeFloat64) { |
|
|
|
*((double *)out_data + out_offset) = *((double *)in_data + in_offset); |
|
|
|
} else { |
|
|
|
return NNACL_ERR; |
|
|
|
} |
|
|
|
out_offset++; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return NNACL_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *param) { |
|
|
|
if (param->data_type != kDataTypeFloat && param->data_type != kDataTypeFloat16) { |
|
|
|
return DoStridedSliceIntFp64Bool(in_data, out_data, param); |
|
|
|
} |
|
|
|
if (in_data == NULL || out_data == NULL || param == NULL) { |
|
|
|
return NNACL_NULL_PTR; |
|
|
|
} |
|
|
|
@@ -93,10 +155,6 @@ int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *p |
|
|
|
dim6 * dim_offset[6] + dim7; |
|
|
|
if (param->data_type == kDataTypeFloat) { |
|
|
|
*((float *)out_data + out_offset) = *((float *)in_data + in_offset); |
|
|
|
} else if (param->data_type == kDataTypeInt8) { |
|
|
|
*((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset); |
|
|
|
} else if (param->data_type == kDataTypeInt) { |
|
|
|
*((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset); |
|
|
|
#ifdef ENABLE_ARM64 |
|
|
|
} else if (param->data_type == kDataTypeFloat16) { |
|
|
|
*((float16_t *)out_data + out_offset) = *((float16_t *)in_data + in_offset); |
|
|
|
|