|
|
|
@@ -15,6 +15,7 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "src/runtime/agent/npu/npu_executor.h" |
|
|
|
#include <unordered_map> |
|
|
|
#include "include/errorcode.h" |
|
|
|
#include "src/runtime/agent/npu/npu_manager.h" |
|
|
|
#include "nnacl/pack.h" |
|
|
|
@@ -100,15 +101,24 @@ bool IsSameShapeOutTensor(Tensor *tensor, std::shared_ptr<hiai::AiTensor> npu_te |
|
|
|
} |
|
|
|
|
|
|
|
int NPUExecutor::Run(const std::vector<Tensor *> &in_tensors, const std::vector<Tensor *> &out_tensors, |
|
|
|
const std::vector<kernel::LiteKernel *> &out_kernels, |
|
|
|
const std::vector<kernel::LiteKernel *> &in_kernels, |
|
|
|
const std::vector<kernel::LiteKernel *> &kernels, Allocator *allocator, |
|
|
|
const KernelCallBack &before, const KernelCallBack &after) { |
|
|
|
hiai::AiContext context; |
|
|
|
std::vector<bool> inputs_visited(in_tensors.size(), false); |
|
|
|
std::unordered_map<lite::Tensor *, int> tensor_uses; |
|
|
|
for (const auto ker : in_kernels) { |
|
|
|
for (const auto ker_input : ker->in_tensors()) { |
|
|
|
if (tensor_uses.find(ker_input) == tensor_uses.end()) { |
|
|
|
tensor_uses.insert({ker_input, 1}); |
|
|
|
} else { |
|
|
|
tensor_uses[ker_input]++; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
for (int i = 0; i < npu_input_tensors_.size(); ++i) { |
|
|
|
int index = 0; |
|
|
|
for (; index < in_tensors.size(); index++) { |
|
|
|
if (!inputs_visited[index] && IsSameShapeInTensor(in_tensors[index], npu_input_tensors_[i])) { |
|
|
|
if (tensor_uses[in_tensors[index]] > 0 && IsSameShapeInTensor(in_tensors[index], npu_input_tensors_[i])) { |
|
|
|
void *data = in_tensors[index]->data_c(); |
|
|
|
if (data == nullptr) { |
|
|
|
MS_LOG(ERROR) << "For " << model_name_ << ", the " << i << "th input data is nullptr"; |
|
|
|
@@ -116,7 +126,7 @@ int NPUExecutor::Run(const std::vector<Tensor *> &in_tensors, const std::vector< |
|
|
|
} |
|
|
|
|
|
|
|
memcpy(npu_input_tensors_[i]->GetBuffer(), data, in_tensors[index]->Size()); |
|
|
|
inputs_visited[index] = true; |
|
|
|
tensor_uses[in_tensors[index]]--; |
|
|
|
in_tensors[index]->DecRefCount(); |
|
|
|
break; |
|
|
|
} |
|
|
|
|