|
|
|
@@ -221,42 +221,61 @@ void LiteKernel::FindInoutKernels(const std::vector<kernel::LiteKernel *> &scope |
|
|
|
|
|
|
|
std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphInputKernels( |
|
|
|
const std::vector<kernel::LiteKernel *> &kernels) { |
|
|
|
std::vector<kernel::LiteKernel *> input_kernels; |
|
|
|
std::set<kernel::LiteKernel *> input_kernels; |
|
|
|
for (const auto &kernel : kernels) { |
|
|
|
// if kernel has no pre-kernel, kernel is a graph input, it must be a subgraph input |
|
|
|
if (kernel->in_kernels().empty() && !kernel->in_tensors().empty()) { |
|
|
|
input_kernels.emplace_back(kernel); |
|
|
|
input_kernels.insert(kernel); |
|
|
|
continue; |
|
|
|
} |
|
|
|
for (const auto &input : kernel->in_kernels()) { |
|
|
|
auto in_kernel_in_graph = std::find(kernels.begin(), kernels.end(), input); |
|
|
|
auto in_kernel_in_ret = std::find(input_kernels.begin(), input_kernels.end(), kernel); |
|
|
|
if (in_kernel_in_graph == kernels.end() && in_kernel_in_ret == input_kernels.end()) { |
|
|
|
input_kernels.emplace_back(kernel); |
|
|
|
break; |
|
|
|
auto all_input_tensors = kernel->in_tensors(); |
|
|
|
// remove all const tensor from input tensors |
|
|
|
for (auto iter = all_input_tensors.begin(); iter != all_input_tensors.end();) { |
|
|
|
if ((*iter)->IsConst()) { |
|
|
|
iter = all_input_tensors.erase(iter); |
|
|
|
} else { |
|
|
|
iter++; |
|
|
|
} |
|
|
|
} |
|
|
|
for (const auto &kernel_in_subgraph : kernels) { |
|
|
|
// remove input tensors from kernel in subgraph |
|
|
|
for (const auto *tensor : kernel_in_subgraph->out_tensors()) { |
|
|
|
auto ret = std::find(all_input_tensors.begin(), all_input_tensors.end(), tensor); |
|
|
|
if (ret != all_input_tensors.end()) { |
|
|
|
all_input_tensors.erase(ret); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
// if some input tensor is not from kernel in subgraph |
|
|
|
if (!all_input_tensors.empty()) { |
|
|
|
input_kernels.insert(kernel); |
|
|
|
} |
|
|
|
} |
|
|
|
return input_kernels; |
|
|
|
std::vector<kernel::LiteKernel *> result; |
|
|
|
result.insert(result.end(), input_kernels.begin(), input_kernels.end()); |
|
|
|
return result; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphOutputKernels( |
|
|
|
const std::vector<kernel::LiteKernel *> &kernels) { |
|
|
|
std::vector<kernel::LiteKernel *> output_kernels; |
|
|
|
std::set<kernel::LiteKernel *> output_kernels; |
|
|
|
// if kernel has no post-kernel, kernel is a graph output, it must be a subgraph output |
|
|
|
for (const auto &kernel : kernels) { |
|
|
|
if (kernel->out_kernels().empty() && !kernel->out_tensors().empty()) { |
|
|
|
output_kernels.emplace_back(kernel); |
|
|
|
if (kernel->is_model_output() || (kernel->out_kernels().empty() && !kernel->out_tensors().empty())) { |
|
|
|
output_kernels.insert(kernel); |
|
|
|
continue; |
|
|
|
} |
|
|
|
for (const auto &output : kernel->out_kernels()) { |
|
|
|
auto out_kernel_in_graph = std::find(kernels.begin(), kernels.end(), output); |
|
|
|
auto out_kernel_in_ret = std::find(output_kernels.begin(), output_kernels.end(), kernel); |
|
|
|
if (out_kernel_in_graph == kernels.end() && out_kernel_in_ret == output_kernels.end()) { |
|
|
|
output_kernels.emplace_back(kernel); |
|
|
|
if (out_kernel_in_graph == kernels.end()) { |
|
|
|
output_kernels.insert(kernel); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return output_kernels; |
|
|
|
std::vector<kernel::LiteKernel *> result; |
|
|
|
result.insert(result.end(), output_kernels.begin(), output_kernels.end()); |
|
|
|
return result; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vector<kernel::LiteKernel *> &kernels) { |
|
|
|
|