Browse Source

write shape as 4-number tuple

tags/20200226
nihui 6 years ago
parent
commit
e2bd4eae6e
2 changed files with 17 additions and 57 deletions
  1. +10
    -38
      src/net.cpp
  2. +7
    -19
      tools/ncnnoptimize.cpp

+ 10
- 38
src/net.cpp View File

@@ -265,26 +265,12 @@ int Net::load_param(const DataReader& dr)
{
Blob& blob = blobs[layer->tops[j]];

int dims = psh[0];
blob.shape.dims = dims;
blob.shape.dims = psh[0];
blob.shape.w = psh[1];
blob.shape.h = psh[2];
blob.shape.c = psh[3];

if (dims == 1)
{
blob.shape.w = psh[1];
}
if (dims == 2)
{
blob.shape.w = psh[1];
blob.shape.h = psh[2];
}
if (dims == 3)
{
blob.shape.w = psh[1];
blob.shape.h = psh[2];
blob.shape.c = psh[3];
}

psh += dims;
psh += 4;
}
}

@@ -442,26 +428,12 @@ int Net::load_param_bin(const DataReader& dr)
{
Blob& blob = blobs[layer->tops[j]];

int dims = psh[0];
blob.shape.dims = dims;

if (dims == 1)
{
blob.shape.w = psh[1];
}
if (dims == 2)
{
blob.shape.w = psh[1];
blob.shape.h = psh[2];
}
if (dims == 3)
{
blob.shape.w = psh[1];
blob.shape.h = psh[2];
blob.shape.c = psh[3];
}
blob.shape.dims = psh[0];
blob.shape.w = psh[1];
blob.shape.h = psh[2];
blob.shape.c = psh[3];

psh += dims;
psh += 4;
}
}



+ 7
- 19
tools/ncnnoptimize.cpp View File

@@ -1984,43 +1984,31 @@ int NetOptimize::save(const char* parampath, const char* binpath)
}

// write shape hints
int shape_hint_array_size = 0;
bool shape_ready = true;
for (int j=0; j<top_count; j++)
{
int top_blob_index = layer->tops[j];

int dims = blobs[top_blob_index].shape.dims;
if (dims == 0)
{
shape_hint_array_size = 0;
shape_ready = false;
break;
}

shape_hint_array_size += dims + 1;
}
if (shape_hint_array_size)
if (shape_ready)
{
fprintf(pp, " -23330=%d", shape_hint_array_size);
fprintf(pp, " -23330=%d", top_count*4);
for (int j=0; j<top_count; j++)
{
int top_blob_index = layer->tops[j];

int dims = blobs[top_blob_index].shape.dims;
int w = blobs[top_blob_index].shape.w;
int h = blobs[top_blob_index].shape.h;
int c = blobs[top_blob_index].shape.c;
fprintf(pp, ",%d", dims);

if (dims == 1)
{
fprintf(pp, ",%d", w);
}
if (dims == 2)
{
fprintf(pp, ",%d,%d", w, h);
}
if (dims == 3)
{
fprintf(pp, ",%d,%d,%d", w, h, c);
}
fprintf(pp, ",%d,%d,%d,%d", dims, w, h, c);
}
}



Loading…
Cancel
Save