|
|
|
@@ -181,17 +181,30 @@ std::vector<kernel::LiteKernel *> LiteKernelUtil::SubgraphOutputKernels( |
|
|
|
|
|
|
|
std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vector<kernel::LiteKernel *> &kernels) { |
|
|
|
std::vector<lite::Tensor *> input_tensors; |
|
|
|
std::vector<lite::Tensor *> all_output_tensors; |
|
|
|
for (const auto &kernel : kernels) { |
|
|
|
auto kernel_out_tensors = kernel->out_tensors(); |
|
|
|
all_output_tensors.insert(all_output_tensors.end(), kernel_out_tensors.begin(), kernel_out_tensors.end()); |
|
|
|
} |
|
|
|
std::vector<kernel::LiteKernel *> input_kernels = SubgraphInputKernels(kernels); |
|
|
|
for (const auto &kernel : input_kernels) { |
|
|
|
for (const auto &tensor : kernel->in_tensors()) { |
|
|
|
auto iter = std::find(all_output_tensors.begin(), all_output_tensors.end(), tensor); |
|
|
|
if (iter == all_output_tensors.end() && !tensor->IsConst()) { |
|
|
|
input_tensors.emplace_back(tensor); |
|
|
|
for (const auto &input_kernel : input_kernels) { |
|
|
|
auto &outer_in_kernels = input_kernel->in_kernels(); |
|
|
|
auto &in_kernel_in_tensors = input_kernel->in_tensors(); |
|
|
|
if (outer_in_kernels.empty()) { |
|
|
|
for (auto &in_kernel_in_tensor : in_kernel_in_tensors) { |
|
|
|
if (!in_kernel_in_tensor->IsConst()) { |
|
|
|
input_tensors.push_back(in_kernel_in_tensor); |
|
|
|
} |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
for (auto outer_in_kernel : outer_in_kernels) { |
|
|
|
auto iter = std::find(kernels.begin(), kernels.end(), outer_in_kernel); |
|
|
|
if (iter != kernels.end()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto &outer_in_kernel_out_tensors = outer_in_kernel->out_tensors(); |
|
|
|
for (auto in_kernel_in_tensor : in_kernel_in_tensors) { |
|
|
|
auto outer_in_kernel_out_tensors_iter = |
|
|
|
std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_kernel_in_tensor); |
|
|
|
if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) { |
|
|
|
input_tensors.emplace_back(in_kernel_in_tensor); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -200,17 +213,26 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect |
|
|
|
|
|
|
|
std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vector<kernel::LiteKernel *> &kernels) { |
|
|
|
std::vector<lite::Tensor *> output_tensors; |
|
|
|
std::vector<lite::Tensor *> all_input_tensors; |
|
|
|
for (const auto &kernel : kernels) { |
|
|
|
auto kernel_in_tensors = kernel->in_tensors(); |
|
|
|
all_input_tensors.insert(all_input_tensors.end(), kernel_in_tensors.begin(), kernel_in_tensors.end()); |
|
|
|
} |
|
|
|
std::vector<kernel::LiteKernel *> output_kernels = SubgraphOutputKernels(kernels); |
|
|
|
for (const auto &kernel : output_kernels) { |
|
|
|
for (const auto &tensor : kernel->out_tensors()) { |
|
|
|
auto iter = std::find(all_input_tensors.begin(), all_input_tensors.end(), tensor); |
|
|
|
if (iter == all_input_tensors.end()) { |
|
|
|
output_tensors.emplace_back(tensor); |
|
|
|
for (const auto &output_kernel : output_kernels) { |
|
|
|
auto &outer_out_kernels = output_kernel->out_kernels(); |
|
|
|
auto &out_kernel_out_tensors = output_kernel->out_tensors(); |
|
|
|
if (outer_out_kernels.empty()) { |
|
|
|
output_tensors.insert(output_tensors.end(), out_kernel_out_tensors.begin(), out_kernel_out_tensors.end()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
for (auto outer_out_kernel : outer_out_kernels) { |
|
|
|
auto iter = std::find(kernels.begin(), kernels.end(), outer_out_kernel); |
|
|
|
if (iter != kernels.end()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto &outer_out_kernel_in_tensors = outer_out_kernel->in_tensors(); |
|
|
|
for (auto out_kernel_out_tensor : out_kernel_out_tensors) { |
|
|
|
auto outer_out_kernel_in_tensors_iter = |
|
|
|
std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor); |
|
|
|
if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) { |
|
|
|
output_tensors.emplace_back(out_kernel_out_tensor); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|