|
|
|
@@ -22,16 +22,17 @@ bool StridedSliceCheckInputs(const TensorC *const *inputs, size_t inputs_size) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return true; // note: the original code is ndim_ <= in_shape_size |
|
|
|
} |
|
|
|
|
|
|
|
int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size, |
|
|
|
OpParameter *parameter) { |
|
|
|
#ifdef Debug |
|
|
|
int check_ret = CheckAugmentNullSize(inputs, inputs_size, outputs, outputs_size, parameter, 5, 1); |
|
|
|
if (check_ret != NNACL_OK) { |
|
|
|
return check_ret; |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
const TensorC *input = inputs[0]; |
|
|
|
SetDataTypeFormat(outputs[0], input); |
|
|
|
@@ -58,9 +59,7 @@ int StridedSliceGradInferShape(const TensorC *const *inputs, size_t inputs_size, |
|
|
|
int *begin_data = (int *)(begin_tensor->data_); |
|
|
|
int *end_data = (int *)(inputs[3]->data_); |
|
|
|
int *stride_data = (int *)(inputs[4]->data_); |
|
|
|
if (begin_data == NULL || end_data == NULL || stride_data == NULL) { |
|
|
|
return NNACL_ERR; |
|
|
|
} |
|
|
|
|
|
|
|
size_t ndim_ = GetElementNum(begin_tensor); |
|
|
|
for (int i = 0; i < ndim_; ++i) { |
|
|
|
ShapePush(begins_, &begins_size, begin_data[i]); |
|
|
|
|