|
|
|
@@ -77,14 +77,14 @@ void OpenCLSubGraph::ReplaceOutTensorAndKernelToConvert(const lite::Tensor *in_t |
|
|
|
for (auto &iv : in_kernels) { |
|
|
|
MS_ASSERT(iv); |
|
|
|
auto kernels = (mem_type == MemType::IMG) ? iv->in_kernels() : iv->out_kernels(); |
|
|
|
auto fk = std::find_if(kernels.begin(), kernels.end(), [&](kernel::LiteKernel *kv) { return kv == nullptr; }); |
|
|
|
auto fk = std::find_if(kernels.begin(), kernels.end(), [&](kernel::LiteKernel *kv) { return kv == iv; }); |
|
|
|
if (fk != kernels.end()) { |
|
|
|
*fk = in_convert_op; |
|
|
|
} else { |
|
|
|
kernels.emplace_back(in_convert_op); |
|
|
|
} |
|
|
|
auto tensors = (mem_type == MemType::IMG) ? iv->in_tensors() : iv->out_tensors(); |
|
|
|
auto ft = std::find_if(tensors.begin(), tensors.end(), [&](lite::Tensor *kv) { return kv == nullptr; }); |
|
|
|
auto ft = std::find_if(tensors.begin(), tensors.end(), [&](lite::Tensor *kv) { return kv == in_tensor; }); |
|
|
|
if (ft != tensors.end()) { |
|
|
|
*ft = new_tensor; |
|
|
|
} else { |
|
|
|
@@ -118,8 +118,6 @@ int OpenCLSubGraph::GenToFormatOp(const std::vector<lite::Tensor *> &in_tensors, |
|
|
|
GetKernelFromToTensor(in_tensors, nodes_, &loop_kernels, true); |
|
|
|
} |
|
|
|
|
|
|
|
ReplaceOutTensorAndKernelToNull(in_tensors, in_kernels, mem_type); |
|
|
|
|
|
|
|
for (size_t i = 0; i < in_tensors.size(); ++i) { |
|
|
|
auto *in_tensor = in_tensors.at(i); |
|
|
|
auto dst_format = (mem_type == MemType::IMG) ? schema::Format::Format_NHWC4 : schema::Format::Format_NHWC; |
|
|
|
|