|
|
|
@@ -20,6 +20,8 @@ |
|
|
|
#include "src/kernel_registry.h" |
|
|
|
#include "schema/model_generated.h" |
|
|
|
#include "include/errorcode.h" |
|
|
|
#include "src/runtime/runtime_api.h" |
|
|
|
#include "src/runtime/thread_pool.h" |
|
|
|
|
|
|
|
using mindspore::kernel::KERNEL_ARCH::kCPU; |
|
|
|
using mindspore::lite::KernelRegistrar; |
|
|
|
@@ -42,12 +44,7 @@ int ConcatCPUKernel::Init() { |
|
|
|
|
|
|
|
int ConcatCPUKernel::ReSize() { return ConcatBaseCPUKernel::ReSize(); } |
|
|
|
|
|
|
|
int ConcatCPUKernel::Run() { |
|
|
|
auto prepare_ret = Prepare(); |
|
|
|
if (prepare_ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; |
|
|
|
return prepare_ret; |
|
|
|
} |
|
|
|
int ConcatCPUKernel::DoConcat(int task_id) { |
|
|
|
auto input_num = in_tensors_.size(); |
|
|
|
std::vector<void *> inputs_addr(input_num, nullptr); |
|
|
|
std::vector<int *> inputs_output_shape(input_num + 1, nullptr); |
|
|
|
@@ -63,7 +60,27 @@ int ConcatCPUKernel::Run() { |
|
|
|
auto output_addr = out_tensors_.at(0)->MutableData(); |
|
|
|
|
|
|
|
Concat(reinterpret_cast<void **>(inputs_addr.data()), input_num, axis_, inputs_output_shape.data(), |
|
|
|
output_shape.size(), output_addr); |
|
|
|
output_shape.size(), output_addr, task_id, thread_count_); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int ConcatsRun(void *cdata, int task_id) { |
|
|
|
auto concat_kernel = reinterpret_cast<ConcatCPUKernel *>(cdata); |
|
|
|
auto error_code = concat_kernel->DoConcat(task_id); |
|
|
|
if (error_code != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "ConcatsRun error task_id[" << task_id << "] error_code[" << error_code << "]"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int ConcatCPUKernel::Run() { |
|
|
|
auto prepare_ret = Prepare(); |
|
|
|
if (prepare_ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret; |
|
|
|
return prepare_ret; |
|
|
|
} |
|
|
|
int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, ConcatsRun, this, thread_count_); |
|
|
|
return error_code; |
|
|
|
} |
|
|
|
} // namespace mindspore::kernel |