|
|
|
@@ -37,13 +37,6 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::Tensor *> &in_te |
|
|
|
MS_ASSERT(in_tensors.size() == from_kernels.size()); |
|
|
|
for (auto &iv : in_kernels) { |
|
|
|
for (auto &jv : iv) { |
|
|
|
OpenCLKernel *cur_opencl_op = reinterpret_cast<OpenCLKernel *>(jv); |
|
|
|
schema::Format out_ori_format = cur_opencl_op->GetOutOriFormat(); |
|
|
|
auto tens = cur_opencl_op->out_tensors(); |
|
|
|
if (mem_type == OpenCLMemType::BUF && mem_type == cur_opencl_op->GetMemType() && |
|
|
|
tens[0]->GetFormat() == out_ori_format) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (mem_type == OpenCLMemType::IMG) { |
|
|
|
jv->set_in_tensors({}); |
|
|
|
jv->SetInKernel({}); |
|
|
|
@@ -70,17 +63,8 @@ int SubGraphOpenCLKernel::GenToFormatOp(const std::vector<lite::Tensor *> &in_te |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
OpenCLKernel *cur_opencl_op = reinterpret_cast<OpenCLKernel *>(in_kernels[i][0]); |
|
|
|
schema::Format out_ori_format = cur_opencl_op->GetOutOriFormat(); |
|
|
|
schema::Format in_ori_format = cur_opencl_op->GetInOriFormat(); |
|
|
|
if (mem_type == OpenCLMemType::BUF && mem_type == cur_opencl_op->GetMemType() && |
|
|
|
in_tensors[i]->GetFormat() == out_ori_format) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto dst_format = |
|
|
|
(mem_type == OpenCLMemType::IMG) ? in_kernels[i][0]->in_tensors()[0]->GetFormat() : out_ori_format; |
|
|
|
auto src_format = |
|
|
|
(mem_type == OpenCLMemType::IMG) ? in_ori_format : in_kernels[i][0]->out_tensors()[0]->GetFormat(); |
|
|
|
auto dst_format = (mem_type == OpenCLMemType::IMG) ? schema::Format::Format_NHWC4 : schema::Format::Format_NHWC; |
|
|
|
auto src_format = (mem_type == OpenCLMemType::IMG) ? schema::Format::Format_NHWC : schema::Format::Format_NHWC4; |
|
|
|
lite::Tensor *new_tensor = new (std::nothrow) lite::Tensor(); |
|
|
|
MS_ASSERT(new_tensor); |
|
|
|
if (new_tensor == nullptr) { |
|
|
|
|