From e2bd4eae6e1a0aeecd3c76be92f9cd0785401efc Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 31 Jan 2020 12:34:57 +0800 Subject: [PATCH] write shape as 4-number tuple --- src/net.cpp | 48 +++++++++--------------------------------- tools/ncnnoptimize.cpp | 26 ++++++----------------- 2 files changed, 17 insertions(+), 57 deletions(-) diff --git a/src/net.cpp b/src/net.cpp index 2c008524b..b5843a178 100644 --- a/src/net.cpp +++ b/src/net.cpp @@ -265,26 +265,12 @@ int Net::load_param(const DataReader& dr) { Blob& blob = blobs[layer->tops[j]]; - int dims = psh[0]; - blob.shape.dims = dims; + blob.shape.dims = psh[0]; + blob.shape.w = psh[1]; + blob.shape.h = psh[2]; + blob.shape.c = psh[3]; - if (dims == 1) - { - blob.shape.w = psh[1]; - } - if (dims == 2) - { - blob.shape.w = psh[1]; - blob.shape.h = psh[2]; - } - if (dims == 3) - { - blob.shape.w = psh[1]; - blob.shape.h = psh[2]; - blob.shape.c = psh[3]; - } - - psh += dims; + psh += 4; } } @@ -442,26 +428,12 @@ int Net::load_param_bin(const DataReader& dr) { Blob& blob = blobs[layer->tops[j]]; - int dims = psh[0]; - blob.shape.dims = dims; - - if (dims == 1) - { - blob.shape.w = psh[1]; - } - if (dims == 2) - { - blob.shape.w = psh[1]; - blob.shape.h = psh[2]; - } - if (dims == 3) - { - blob.shape.w = psh[1]; - blob.shape.h = psh[2]; - blob.shape.c = psh[3]; - } + blob.shape.dims = psh[0]; + blob.shape.w = psh[1]; + blob.shape.h = psh[2]; + blob.shape.c = psh[3]; - psh += dims; + psh += 4; } } diff --git a/tools/ncnnoptimize.cpp b/tools/ncnnoptimize.cpp index 2d45ebfbf..9c5e26ade 100644 --- a/tools/ncnnoptimize.cpp +++ b/tools/ncnnoptimize.cpp @@ -1984,43 +1984,31 @@ int NetOptimize::save(const char* parampath, const char* binpath) } // write shape hints - int shape_hint_array_size = 0; + bool shape_ready = true; for (int j=0; jtops[j]; + int dims = blobs[top_blob_index].shape.dims; if (dims == 0) { - shape_hint_array_size = 0; + shape_ready = false; break; } - - shape_hint_array_size += dims + 1; } - if (shape_hint_array_size) + if (shape_ready) { - fprintf(pp, " -23330=%d", shape_hint_array_size); + fprintf(pp, " -23330=%d", top_count*4); for (int j=0; jtops[j]; + int dims = blobs[top_blob_index].shape.dims; int w = blobs[top_blob_index].shape.w; int h = blobs[top_blob_index].shape.h; int c = blobs[top_blob_index].shape.c; - fprintf(pp, ",%d", dims); - if (dims == 1) - { - fprintf(pp, ",%d", w); - } - if (dims == 2) - { - fprintf(pp, ",%d,%d", w, h); - } - if (dims == 3) - { - fprintf(pp, ",%d,%d,%d", w, h, c); - } + fprintf(pp, ",%d,%d,%d,%d", dims, w, h, c); } }