|
|
|
@@ -339,8 +339,7 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) { |
|
|
|
} |
|
|
|
// when merge is removed, this if is removed automatically |
|
|
|
if (kernel->Type() == schema::PrimitiveType_Merge) { |
|
|
|
MS_ASSERT(kernel->in_kernels().size() == 2); |
|
|
|
return (is_kernel_finish[kernel->in_kernels().at(0)] || is_kernel_finish[kernel->in_kernels().at(1)]); |
|
|
|
return MergeOpIsReady(kernel, is_kernel_finish); |
|
|
|
} else { |
|
|
|
return std::all_of(kernel_inputs.begin(), kernel_inputs.end(), |
|
|
|
[&](kernel::LiteKernel *kernel) { return is_kernel_finish[kernel]; }); |
|
|
|
@@ -370,6 +369,28 @@ int Scheduler::ConstructSubGraphs(std::vector<kernel::LiteKernel *> *kernels) { |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
bool Scheduler::MergeOpIsReady(const kernel::LiteKernel *kernel, |
|
|
|
std::map<const kernel::LiteKernel *, bool> is_kernel_finish) { |
|
|
|
std::map<const lite::Tensor *, bool> merge_in_tensors_map; |
|
|
|
for (auto merge_in_tensor : kernel->in_tensors()) { |
|
|
|
merge_in_tensors_map[merge_in_tensor] = false; |
|
|
|
if (merge_in_tensor->category() == Tensor::CONST_TENSOR || merge_in_tensor->category() == Tensor::CONST_SCALAR) { |
|
|
|
merge_in_tensors_map[merge_in_tensor] = true; |
|
|
|
} |
|
|
|
for (auto merge_in_kernel : kernel->in_kernels()) { |
|
|
|
for (auto tensor : merge_in_kernel->out_tensors()) { |
|
|
|
if (tensor == merge_in_tensor && is_kernel_finish[merge_in_kernel]) { |
|
|
|
merge_in_tensors_map[merge_in_tensor] = true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
auto kernel_in_tensors_num = kernel->in_tensors().size(); |
|
|
|
return std::all_of(kernel->in_tensors().begin(), kernel->in_tensors().begin() + kernel_in_tensors_num / 2, |
|
|
|
[&](lite::Tensor *in_tensor) { return merge_in_tensors_map[in_tensor]; }) || |
|
|
|
std::all_of(kernel->in_tensors().begin() + kernel_in_tensors_num / 2, kernel->in_tensors().end(), |
|
|
|
[&](lite::Tensor *in_tensor) { return merge_in_tensors_map[in_tensor]; }); |
|
|
|
} |
|
|
|
|
|
|
|
kernel::SubGraphKernel *Scheduler::CreateSubGraphKernel(const std::vector<kernel::LiteKernel *> &kernels, |
|
|
|
kernel::SubGraphType type) { |
|
|
|
|