Browse Source

padding_mode for Pooling, fix #261

tags/20180314
nihuini 8 years ago
parent
commit
db5e805eff
5 changed files with 29 additions and 17 deletions
  1. +6
    -6
      src/layer/arm/pooling_arm.cpp
  2. +7
    -6
      src/layer/pooling.cpp
  3. +1
    -0
      src/layer/pooling.h
  4. +7
    -1
      tools/mxnet/mxnet2ncnn.cpp
  5. +8
    -4
      tools/tensorflow/tensorflow2ncnn.cpp

+ 6
- 6
src/layer/arm/pooling_arm.cpp View File

@@ -58,7 +58,7 @@ int Pooling_arm::forward(const Mat& bottom_blob, Mat& top_blob) const
w = bottom_blob_bordered.w;
h = bottom_blob_bordered.h;
}
else if (pad_w == -233 && pad_h == -233)
else if (pad_mode == 2) // tensorflow padding=SAME
{
int wpad = kernel_w + (w - 1) / stride_w * stride_w - w;
int hpad = kernel_h + (h - 1) / stride_h * stride_h - h;
@@ -76,12 +76,12 @@ int Pooling_arm::forward(const Mat& bottom_blob, Mat& top_blob) const
int outw = (w - kernel_w) / stride_w + 1;
int outh = (h - kernel_h) / stride_h + 1;

int wtail = (w - kernel_w) % stride_w;
int htail = (h - kernel_h) % stride_h;
if ((pad_w == -233 && pad_h == -233) || (pad_w == -2333 && pad_h == -2333))
int wtail = 0;
int htail = 0;
if (pad_mode == 0) // full padding
{
wtail = 0;
htail = 0;
wtail = (w - kernel_w) % stride_w;
htail = (h - kernel_h) % stride_h;
}
if (wtail != 0 || htail != 0)
{


+ 7
- 6
src/layer/pooling.cpp View File

@@ -35,6 +35,7 @@ int Pooling::load_param(const ParamDict& pd)
pad_w = pd.get(3, 0);
pad_h = pd.get(13, pad_w);
global_pooling = pd.get(4, 0);
pad_mode = pd.get(5, 0);

return 0;
}
@@ -105,7 +106,7 @@ int Pooling::forward(const Mat& bottom_blob, Mat& top_blob) const
w = bottom_blob_bordered.w;
h = bottom_blob_bordered.h;
}
else if (pad_w == -233 && pad_h == -233)
else if (pad_mode == 2) // tensorflow padding=SAME
{
int wpad = kernel_w + (w - 1) / stride_w * stride_w - w;
int hpad = kernel_h + (h - 1) / stride_h * stride_h - h;
@@ -123,12 +124,12 @@ int Pooling::forward(const Mat& bottom_blob, Mat& top_blob) const
int outw = (w - kernel_w) / stride_w + 1;
int outh = (h - kernel_h) / stride_h + 1;

int wtail = (w - kernel_w) % stride_w;
int htail = (h - kernel_h) % stride_h;
if ((pad_w == -233 && pad_h == -233) || (pad_w == -2333 && pad_h == -2333))
int wtail = 0;
int htail = 0;
if (pad_mode == 0) // full padding
{
wtail = 0;
htail = 0;
wtail = (w - kernel_w) % stride_w;
htail = (h - kernel_h) % stride_h;
}
if (wtail != 0 || htail != 0)
{


+ 1
- 0
src/layer/pooling.h View File

@@ -40,6 +40,7 @@ public:
int pad_w;
int pad_h;
int global_pooling;
int pad_mode;// 0=full 1=valid 2=SAME
};

} // namespace ncnn


+ 7
- 1
tools/mxnet/mxnet2ncnn.cpp View File

@@ -1092,9 +1092,14 @@ int main(int argc, char** argv)
pool = 1;
}

int pad_mode = 1;
if (pooling_convention == "valid")
{
// TODO valid and full mode
pad_mode = 1;
}
else if (pooling_convention == "full")
{
pad_mode = 0;
}

fprintf(pp, " 0=%d", pool);
@@ -1105,6 +1110,7 @@ int main(int argc, char** argv)
if (!pad.empty())
fprintf(pp, " 3=%d", pad[0]);
fprintf(pp, " 4=%d", global_pool);
fprintf(pp, " 5=%d", pad_mode);
}
else if (n.op == "SliceChannel")
{


+ 8
- 4
tools/tensorflow/tensorflow2ncnn.cpp View File

@@ -530,6 +530,7 @@ int main(int argc, char** argv)
int pad = 0;

int global_pooling = 0;
int pad_mode = 1;

tensorflow::AttrValue value_ksize;
if (find_attr_value(node, "ksize", value_ksize))
@@ -552,11 +553,11 @@ int main(int argc, char** argv)
{
if (value_padding.s() == "VALID")
{
pad = 0;
pad_mode = 1;
}
else if (value_padding.s() == "SAME")
{
pad = -233;
pad_mode = 2;
}
}

@@ -567,6 +568,7 @@ int main(int argc, char** argv)
fprintf(pp, " 12=%d", stride_h);
fprintf(pp, " 3=%d", pad);
fprintf(pp, " 4=%d", global_pooling);
fprintf(pp, " 5=%d", pad_mode);
}
else if (node.op() == "Concat" || node.op() == "ConcatV2")
{
@@ -1075,6 +1077,7 @@ int main(int argc, char** argv)
int pad = 0;

int global_pooling = 0;
int pad_mode = 1;

tensorflow::AttrValue value_ksize;
if (find_attr_value(node, "ksize", value_ksize))
@@ -1097,11 +1100,11 @@ int main(int argc, char** argv)
{
if (value_padding.s() == "VALID")
{
pad = -2333;
pad_mode = 1;
}
else if (value_padding.s() == "SAME")
{
pad = -233;
pad_mode = 2;
}
}

@@ -1112,6 +1115,7 @@ int main(int argc, char** argv)
fprintf(pp, " 12=%d", stride_h);
fprintf(pp, " 3=%d", pad);
fprintf(pp, " 4=%d", global_pooling);
fprintf(pp, " 5=%d", pad_mode);
}
else if (node.op() == "Min" || node.op() == "Minimum")
{


Loading…
Cancel
Save