| @@ -47,7 +47,7 @@ void PadSliceParameterTo4D(SliceParameter *param) { | |||||
| param->param_length_ = DIMENSION_4D; | param->param_length_ = DIMENSION_4D; | ||||
| } | } | ||||
| void DoSlice(const float *input, float *output, SliceParameter *param) { | |||||
| void DoSlice(const float *input, float *output, SliceParameter *param, int thread_id) { | |||||
| int32_t out_dim1 = param->size_[1]; | int32_t out_dim1 = param->size_[1]; | ||||
| int32_t out_dim2 = param->size_[2]; | int32_t out_dim2 = param->size_[2]; | ||||
| int32_t out_dim3 = param->size_[3]; | int32_t out_dim3 = param->size_[3]; | ||||
| @@ -55,7 +55,6 @@ void DoSlice(const float *input, float *output, SliceParameter *param) { | |||||
| size_t out_stride1 = out_stride2 * out_dim2; | size_t out_stride1 = out_stride2 * out_dim2; | ||||
| size_t out_stride0 = out_stride1 * out_dim1; | size_t out_stride0 = out_stride1 * out_dim1; | ||||
| size_t count_per_thread = UP_DIV(out_dim1, param->op_parameter_.thread_num_); | size_t count_per_thread = UP_DIV(out_dim1, param->op_parameter_.thread_num_); | ||||
| int thread_id = param->thread_id_; | |||||
| size_t thread_stride = thread_id * count_per_thread; | size_t thread_stride = thread_id * count_per_thread; | ||||
| size_t copy_size = param->size_[3] * sizeof(float); | size_t copy_size = param->size_[3] * sizeof(float); | ||||
| size_t in_stride2 = param->shape_[3]; | size_t in_stride2 = param->shape_[3]; | ||||
| @@ -23,7 +23,7 @@ | |||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| void PadSliceParameterTo4D(SliceParameter *param); | void PadSliceParameterTo4D(SliceParameter *param); | ||||
| void DoSlice(const float *input, float *output, SliceParameter *param); | |||||
| void DoSlice(const float *input, float *output, SliceParameter *param, int thread_id); | |||||
| void DoSliceNoParallel(const float *input, float *output, SliceParameter *param); | void DoSliceNoParallel(const float *input, float *output, SliceParameter *param); | ||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| @@ -66,7 +66,7 @@ int SliceInt8NoParallel(const int8_t *input, int8_t *output, SliceParameter *par | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| int SliceInt8(const int8_t *input, int8_t *output, SliceParameter *param) { | |||||
| int SliceInt8(const int8_t *input, int8_t *output, SliceParameter *param, int thread_id) { | |||||
| double input_scale = param->quant_arg_.in_args_.scale_; | double input_scale = param->quant_arg_.in_args_.scale_; | ||||
| int input_zp = param->quant_arg_.in_args_.zp_; | int input_zp = param->quant_arg_.in_args_.zp_; | ||||
| double output_scale = param->quant_arg_.out_args_.scale_; | double output_scale = param->quant_arg_.out_args_.scale_; | ||||
| @@ -81,7 +81,6 @@ int SliceInt8(const int8_t *input, int8_t *output, SliceParameter *param) { | |||||
| int out_stride1 = out_stride2 * out_dim2; | int out_stride1 = out_stride2 * out_dim2; | ||||
| int out_stride0 = out_stride1 * out_dim1; | int out_stride0 = out_stride1 * out_dim1; | ||||
| int count_per_thread = UP_DIV(out_dim1, param->op_parameter_.thread_num_); | int count_per_thread = UP_DIV(out_dim1, param->op_parameter_.thread_num_); | ||||
| int thread_id = param->thread_id_; | |||||
| int thread_stride = thread_id * count_per_thread; | int thread_stride = thread_id * count_per_thread; | ||||
| int unit_size = param->size_[3] * sizeof(int8_t); | int unit_size = param->size_[3] * sizeof(int8_t); | ||||
| int in_stride2 = param->shape_[3]; | int in_stride2 = param->shape_[3]; | ||||
| @@ -23,7 +23,7 @@ | |||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| int SliceInt8NoParallel(const int8_t *input, int8_t *output, SliceParameter *param); | int SliceInt8NoParallel(const int8_t *input, int8_t *output, SliceParameter *param); | ||||
| int SliceInt8(const int8_t *input, int8_t *output, SliceParameter *param); | |||||
| int SliceInt8(const int8_t *input, int8_t *output, SliceParameter *param, int thread_id); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -30,7 +30,6 @@ typedef struct SliceParameter { | |||||
| int32_t size_[SLICE_SHAPE_MAX_SIZE]; | int32_t size_[SLICE_SHAPE_MAX_SIZE]; | ||||
| int32_t shape_[SLICE_SHAPE_MAX_SIZE]; | int32_t shape_[SLICE_SHAPE_MAX_SIZE]; | ||||
| int32_t param_length_; | int32_t param_length_; | ||||
| int32_t thread_id_; | |||||
| } SliceParameter; | } SliceParameter; | ||||
| #endif // MINDSPORE_LITE_NNACL_SLICE_PARAMETER_H_ | #endif // MINDSPORE_LITE_NNACL_SLICE_PARAMETER_H_ | ||||
| @@ -78,7 +78,7 @@ int SliceCPUKernel::SliceParallelRun(int thread_id) { | |||||
| const float *input_data = reinterpret_cast<const float *>(in_tensors_[0]->Data()); | const float *input_data = reinterpret_cast<const float *>(in_tensors_[0]->Data()); | ||||
| float *output_data = reinterpret_cast<float *>(out_tensors_[0]->Data()); | float *output_data = reinterpret_cast<float *>(out_tensors_[0]->Data()); | ||||
| SliceParameter *param = reinterpret_cast<SliceParameter *>(op_parameter_); | SliceParameter *param = reinterpret_cast<SliceParameter *>(op_parameter_); | ||||
| DoSlice(input_data, output_data, param); | |||||
| DoSlice(input_data, output_data, param, thread_id); | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -60,8 +60,7 @@ int SliceInt8CPUKernel::DoSlice(int task_id) { | |||||
| const int8_t *input_data = reinterpret_cast<const int8_t *>(in_tensors_[0]->Data()); | const int8_t *input_data = reinterpret_cast<const int8_t *>(in_tensors_[0]->Data()); | ||||
| int8_t *output_data = reinterpret_cast<int8_t *>(out_tensors_[0]->Data()); | int8_t *output_data = reinterpret_cast<int8_t *>(out_tensors_[0]->Data()); | ||||
| param_->thread_id_ = task_id; | |||||
| auto ret = SliceInt8(input_data, output_data, param_); | |||||
| auto ret = SliceInt8(input_data, output_data, param_, task_id); | |||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| MS_LOG(ERROR) << "SliceInt8 error ,task_id[" << task_id << "] error_code[" << ret << "]"; | MS_LOG(ERROR) << "SliceInt8 error ,task_id[" << task_id << "] error_code[" << ret << "]"; | ||||
| } | } | ||||