From: @yangruoqi713 Reviewed-by: @zhanghaibo5,@zhang_xue_tong Signed-off-by: @zhang_xue_tongpull/13751/MERGE
| @@ -119,8 +119,8 @@ int GatherFp16CPUKernel::DoGather(int task_id) { | |||||
| } | } | ||||
| int8_t *int8_out = reinterpret_cast<int8_t *>(out_tensor->data_c()); | int8_t *int8_out = reinterpret_cast<int8_t *>(out_tensor->data_c()); | ||||
| int data_size = lite::DataTypeSize(kNumberTypeFloat16); | int data_size = lite::DataTypeSize(kNumberTypeFloat16); | ||||
| int8_in += thread_stride * limit * data_size; | |||||
| int8_out += thread_stride * indices_element_size * data_size; | |||||
| int8_in += thread_stride * limit * inner_size * data_size; | |||||
| int8_out += thread_stride * indices_element_size * inner_size * data_size; | |||||
| int error_code = Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, data_size); | int error_code = Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, data_size); | ||||
| return error_code; | return error_code; | ||||
| } | } | ||||
| @@ -65,8 +65,8 @@ int GatherCPUKernel::DoGather(int task_id) { | |||||
| int8_t *int8_out = reinterpret_cast<int8_t *>(out_tensor->data_c()); | int8_t *int8_out = reinterpret_cast<int8_t *>(out_tensor->data_c()); | ||||
| int data_size = lite::DataTypeSize(input_tensor->data_type()); | int data_size = lite::DataTypeSize(input_tensor->data_type()); | ||||
| int8_in += thread_stride * limit * data_size; | |||||
| int8_out += thread_stride * indices_element_size * data_size; | |||||
| int8_in += thread_stride * limit * inner_size * data_size; | |||||
| int8_out += thread_stride * indices_element_size * inner_size * data_size; | |||||
| int error_code = Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, data_size); | int error_code = Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, data_size); | ||||
| @@ -82,8 +82,8 @@ int GatherInt8CPUKernel::DoGather(int task_id) { | |||||
| int count = MSMIN(stride, outer_size - stride * task_id); | int count = MSMIN(stride, outer_size - stride * task_id); | ||||
| auto thread_stride = stride * task_id; | auto thread_stride = stride * task_id; | ||||
| input_ptr += thread_stride * limit; | |||||
| output_ptr += thread_stride * indices_element_size; | |||||
| input_ptr += thread_stride * inner_size * limit; | |||||
| output_ptr += thread_stride * inner_size * indices_element_size; | |||||
| return GatherInt8(input_ptr, output_ptr, count, inner_size, limit, indices_ptr, indices_element_size, param_); | return GatherInt8(input_ptr, output_ptr, count, inner_size, limit, indices_ptr, indices_element_size, param_); | ||||
| } | } | ||||