Browse Source

strided slice fp64 bool

pull/15880/head
zhaozhenlong 4 years ago
parent
commit
01abb41f95
1 changed files with 62 additions and 4 deletions
  1. +62
    -4
      mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/strided_slice_fp32.c

+ 62
- 4
mindspore/ccsrc/backend/kernel_compiler/cpu/nnacl/fp32/strided_slice_fp32.c View File

@@ -55,7 +55,69 @@ void PadStridedSliceParameterTo8D(StridedSliceParameter *param) {

bool LoopContinue(int stride, int i, int end) { return stride > 0 ? i < end : i > end; }

int DoStridedSliceIntFp64Bool(const void *in_data, void *out_data, StridedSliceParameter *param) {
if (in_data == NULL || out_data == NULL || param == NULL) {
return NNACL_NULL_PTR;
}
if (param->num_axes_ > DIMENSION_8D) {
return NNACL_PARAM_INVALID;
}

int *begins = param->begins_;
int *ends = param->ends_;
int *strides = param->strides_;
int *in_shape = param->in_shape_;
if (param->num_axes_ < DIMENSION_8D) {
PadStridedSliceParameterTo8D(param);
}
size_t dim_offset[DIMENSION_8D - 1];
dim_offset[6] = in_shape[7];
dim_offset[5] = in_shape[6] * dim_offset[6];
dim_offset[4] = in_shape[5] * dim_offset[5];
dim_offset[3] = in_shape[4] * dim_offset[4];
dim_offset[2] = in_shape[3] * dim_offset[3];
dim_offset[1] = in_shape[2] * dim_offset[2];
dim_offset[0] = in_shape[1] * dim_offset[1];
size_t out_offset = 0;
int32_t dim0, dim1, dim2, dim3, dim4, dim5, dim6, dim7;
for (dim0 = begins[0]; LoopContinue(strides[0], dim0, ends[0]); dim0 += strides[0]) {
for (dim1 = begins[1]; LoopContinue(strides[1], dim1, ends[1]); dim1 += strides[1]) {
for (dim2 = begins[2]; LoopContinue(strides[2], dim2, ends[2]); dim2 += strides[2]) {
for (dim3 = begins[3]; LoopContinue(strides[3], dim3, ends[3]); dim3 += strides[3]) {
for (dim4 = begins[4]; LoopContinue(strides[4], dim4, ends[4]); dim4 += strides[4]) {
for (dim5 = begins[5]; LoopContinue(strides[5], dim5, ends[5]); dim5 += strides[5]) {
for (dim6 = begins[6]; LoopContinue(strides[6], dim6, ends[6]); dim6 += strides[6]) {
for (dim7 = begins[7]; LoopContinue(strides[7], dim7, ends[7]); dim7 += strides[7]) {
int32_t in_offset = dim0 * dim_offset[0] + dim1 * dim_offset[1] + dim2 * dim_offset[2] +
dim3 * dim_offset[3] + dim4 * dim_offset[4] + dim5 * dim_offset[5] +
dim6 * dim_offset[6] + dim7;
if (param->data_type == kDataTypeInt) {
*((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset);
} else if (param->data_type == kDataTypeInt8) {
*((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset);
} else if (param->data_type == kDataTypeBool) {
*((bool *)out_data + out_offset) = *((bool *)in_data + in_offset);
} else if (param->data_type == kDataTypeFloat64) {
*((double *)out_data + out_offset) = *((double *)in_data + in_offset);
} else {
return NNACL_ERR;
}
out_offset++;
}
}
}
}
}
}
}
}
return NNACL_OK;
}

int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *param) {
if (param->data_type != kDataTypeFloat && param->data_type != kDataTypeFloat16) {
return DoStridedSliceIntFp64Bool(in_data, out_data, param);
}
if (in_data == NULL || out_data == NULL || param == NULL) {
return NNACL_NULL_PTR;
}
@@ -93,10 +155,6 @@ int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *p
dim6 * dim_offset[6] + dim7;
if (param->data_type == kDataTypeFloat) {
*((float *)out_data + out_offset) = *((float *)in_data + in_offset);
} else if (param->data_type == kDataTypeInt8) {
*((int8_t *)out_data + out_offset) = *((int8_t *)in_data + in_offset);
} else if (param->data_type == kDataTypeInt) {
*((int32_t *)out_data + out_offset) = *((int32_t *)in_data + in_offset);
#ifdef ENABLE_ARM64
} else if (param->data_type == kDataTypeFloat16) {
*((float16_t *)out_data + out_offset) = *((float16_t *)in_data + in_offset);


Loading…
Cancel
Save