From ce65edcc84d52def8e2ab7d41d47ffb327e62266 Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 12 Mar 2019 16:35:16 +0800 Subject: [PATCH] fix flatten pack1to4 --- src/layer/flatten.cpp | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/layer/flatten.cpp b/src/layer/flatten.cpp index f3f40ccb5..2a0fc5855 100644 --- a/src/layer/flatten.cpp +++ b/src/layer/flatten.cpp @@ -138,7 +138,7 @@ int Flatten::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute& cmd, constants[3].i = bottom_blob.c; constants[4].i = bottom_blob.cstep; constants[5].i = top_blob.dims; - constants[6].i = (packing == 1 && out_packing == 4) ? top_blob.w * out_packing : top_blob.w; + constants[6].i = (packing == 1 && out_packing == 4) ? total : top_blob.w; constants[7].i = top_blob.h; constants[8].i = top_blob.c; constants[9].i = top_blob.cstep; @@ -148,7 +148,18 @@ int Flatten::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute& cmd, // record cmd.record_prepare_compute_barrier(bottom_blob); cmd.record_prepare_compute_barrier(top_blob); - cmd.record_pipeline(pipeline, bindings, constants, top_blob); + if (packing == 1 && out_packing == 4) + { + VkMat dispatcher; + dispatcher.w = total; + dispatcher.h = 1; + dispatcher.c = 1; + cmd.record_pipeline(pipeline, bindings, constants, dispatcher); + } + else + { + cmd.record_pipeline(pipeline, bindings, constants, top_blob); + } return 0; }