| @@ -17,7 +17,8 @@ | |||||
| #include "nnacl/fp32/concat.h" | #include "nnacl/fp32/concat.h" | ||||
| #include <string.h> | #include <string.h> | ||||
| void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output) { | |||||
| void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output, | |||||
| int task_id, int thread_num) { | |||||
| int before_axis_size = 1; | int before_axis_size = 1; | ||||
| for (int i = 0; i < axis; ++i) { | for (int i = 0; i < axis; ++i) { | ||||
| before_axis_size *= inputs_output_shape[0][i]; | before_axis_size *= inputs_output_shape[0][i]; | ||||
| @@ -33,10 +34,12 @@ void Concat(void **input, int input_num, int axis, int **inputs_output_shape, si | |||||
| for (int i = 0; i < input_num; ++i) { | for (int i = 0; i < input_num; ++i) { | ||||
| uint8_t *src_base = (input[i]); | uint8_t *src_base = (input[i]); | ||||
| size_t input_stride = after_axis_size * inputs_output_shape[i][axis]; | size_t input_stride = after_axis_size * inputs_output_shape[i][axis]; | ||||
| for (int j = 0; j < before_axis_size; ++j) { | |||||
| uint8_t *src = src_base + j * input_stride; | |||||
| uint8_t *dst = dst_base + j * output_stride + axis_offset * after_axis_size; | |||||
| memcpy(dst, src, input_stride); | |||||
| int offset = UP_DIV(input_stride, thread_num); | |||||
| int count = MSMIN(offset, input_stride - offset * task_id); | |||||
| for (int j = 0; j < before_axis_size; j++) { | |||||
| uint8_t *src = src_base + j * input_stride + task_id * offset; | |||||
| uint8_t *dst = dst_base + j * output_stride + axis_offset * after_axis_size + task_id * offset; | |||||
| memcpy(dst, src, count); | |||||
| } | } | ||||
| axis_offset += inputs_output_shape[i][axis]; | axis_offset += inputs_output_shape[i][axis]; | ||||
| } | } | ||||
| @@ -22,7 +22,8 @@ | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| extern "C" { | extern "C" { | ||||
| #endif | #endif | ||||
| void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output); | |||||
| void Concat(void **input, int input_num, int axis, int **inputs_output_shape, size_t shape_size, void *output, | |||||
| int task_id, int thread_num); | |||||
| #ifdef __cplusplus | #ifdef __cplusplus | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -20,6 +20,8 @@ | |||||
| #include "src/kernel_registry.h" | #include "src/kernel_registry.h" | ||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/runtime/runtime_api.h" | |||||
| #include "src/runtime/thread_pool.h" | |||||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | using mindspore::kernel::KERNEL_ARCH::kCPU; | ||||
| using mindspore::lite::KernelRegistrar; | using mindspore::lite::KernelRegistrar; | ||||
| @@ -42,12 +44,7 @@ int ConcatCPUKernel::Init() { | |||||
| int ConcatCPUKernel::ReSize() { return ConcatBaseCPUKernel::ReSize(); } | int ConcatCPUKernel::ReSize() { return ConcatBaseCPUKernel::ReSize(); } | ||||
| int ConcatCPUKernel::Run() { | |||||
| auto prepare_ret = Prepare(); | |||||
| if (prepare_ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; | |||||
| return prepare_ret; | |||||
| } | |||||
| int ConcatCPUKernel::DoConcat(int task_id) { | |||||
| auto input_num = in_tensors_.size(); | auto input_num = in_tensors_.size(); | ||||
| std::vector<void *> inputs_addr(input_num, nullptr); | std::vector<void *> inputs_addr(input_num, nullptr); | ||||
| std::vector<int *> inputs_output_shape(input_num + 1, nullptr); | std::vector<int *> inputs_output_shape(input_num + 1, nullptr); | ||||
| @@ -63,7 +60,27 @@ int ConcatCPUKernel::Run() { | |||||
| auto output_addr = out_tensors_.at(0)->MutableData(); | auto output_addr = out_tensors_.at(0)->MutableData(); | ||||
| Concat(reinterpret_cast<void **>(inputs_addr.data()), input_num, axis_, inputs_output_shape.data(), | Concat(reinterpret_cast<void **>(inputs_addr.data()), input_num, axis_, inputs_output_shape.data(), | ||||
| output_shape.size(), output_addr); | |||||
| output_shape.size(), output_addr, task_id, thread_count_); | |||||
| return RET_OK; | |||||
| } | |||||
| int ConcatsRun(void *cdata, int task_id) { | |||||
| auto concat_kernel = reinterpret_cast<ConcatCPUKernel *>(cdata); | |||||
| auto error_code = concat_kernel->DoConcat(task_id); | |||||
| if (error_code != RET_OK) { | |||||
| MS_LOG(ERROR) << "ConcatsRun error task_id[" << task_id << "] error_code[" << error_code << "]"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| int ConcatCPUKernel::Run() { | |||||
| auto prepare_ret = Prepare(); | |||||
| if (prepare_ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; | |||||
| return prepare_ret; | |||||
| } | |||||
| int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, ConcatsRun, this, thread_count_); | |||||
| return error_code; | |||||
| } | |||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||
| @@ -35,9 +35,8 @@ class ConcatCPUKernel : public ConcatBaseCPUKernel { | |||||
| ~ConcatCPUKernel() = default; | ~ConcatCPUKernel() = default; | ||||
| int Init() override; | int Init() override; | ||||
| int ReSize() override; | int ReSize() override; | ||||
| int DoConcat(int task_id); | |||||
| int Run() override; | int Run() override; | ||||
| }; | }; | ||||
| } // namespace mindspore::kernel | } // namespace mindspore::kernel | ||||