diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc index f1b06868d3..6d7e0bab37 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc @@ -119,8 +119,8 @@ int GatherFp16CPUKernel::DoGather(int task_id) { } int8_t *int8_out = reinterpret_cast(out_tensor->data_c()); int data_size = lite::DataTypeSize(kNumberTypeFloat16); - int8_in += thread_stride * limit * data_size; - int8_out += thread_stride * indices_element_size * data_size; + int8_in += thread_stride * limit * inner_size * data_size; + int8_out += thread_stride * indices_element_size * inner_size * data_size; int error_code = Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, data_size); return error_code; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc index f2953903a0..0a2d6c80cd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather_fp32.cc @@ -65,8 +65,8 @@ int GatherCPUKernel::DoGather(int task_id) { int8_t *int8_out = reinterpret_cast(out_tensor->data_c()); int data_size = lite::DataTypeSize(input_tensor->data_type()); - int8_in += thread_stride * limit * data_size; - int8_out += thread_stride * indices_element_size * data_size; + int8_in += thread_stride * limit * inner_size * data_size; + int8_out += thread_stride * indices_element_size * inner_size * data_size; int error_code = Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, data_size); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc index 0f2118e4db..844425633a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/gather_int8.cc @@ -82,8 +82,8 @@ int GatherInt8CPUKernel::DoGather(int task_id) { int count = MSMIN(stride, outer_size - stride * task_id); auto thread_stride = stride * task_id; - input_ptr += thread_stride * limit; - output_ptr += thread_stride * indices_element_size; + input_ptr += thread_stride * inner_size * limit; + output_ptr += thread_stride * inner_size * indices_element_size; return GatherInt8(input_ptr, output_ptr, count, inner_size, limit, indices_ptr, indices_element_size, param_); }