Browse Source

ctest 6

pull/5872/head
佰阅 1 year ago
parent
commit
8376eb7d3d
2 changed files with 49 additions and 13 deletions
  1. +12
    -8
      src/layer/flip.cpp
  2. +37
    -5
      tests/test_flip.cpp

+ 12
- 8
src/layer/flip.cpp View File

@@ -458,7 +458,6 @@ int Flip::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
}
else if (axis.w == 3)
{
return 0; // 在线debug
// dch3、dcw4、chw6
int axis0 = axis_ptr[0] < 0 ? 4 + axis_ptr[0] : axis_ptr[0];
int axis1 = axis_ptr[1] < 0 ? 4 + axis_ptr[1] : axis_ptr[1];
@@ -469,17 +468,19 @@ int Flip::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
// 对应dch,除w外,其余全翻转
for (int c = 0; c < channels; c++)
{
int flipped_c = channels - 1 - c; // 翻转c维度
int flipped_c = channels - 1 - c;

for (int z = 0; z < d; z++)
{
int flipped_d = d - 1 - z; // 翻转d维度
int flipped_d = d - 1 - z;

for (int i = 0; i < h; i++)
{
const float* ptr = bottom_blob.channel(c).row(z * h + i);
float* outptr = const_cast<float*>(top_blob.channel(flipped_c).row(flipped_d * h + (h - 1 - i))); // 翻转h维度
memcpy(outptr, ptr, w * sizeof(float)); // w维度保持不变
// 修改前:const float* ptr = bottom_blob.channel(c).row(z * h + i);
// 修改为:使用depth()访问方式
const float* ptr = bottom_blob.channel(c).depth(z).row(i);
float* outptr = const_cast<float*>(top_blob.channel(flipped_c).depth(flipped_d).row(h - 1 - i));
memcpy(outptr, ptr, w * sizeof(float));
}
}
}
@@ -520,9 +521,12 @@ int Flip::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons

for (int i = 0; i < h; i++)
{
const float* ptr = bottom_blob.channel(c).row(z * h + i);
float* outptr = const_cast<float*>(top_blob.channel(c).row(flipped_d * h + (h - 1 - i))); // 翻转h维度
// const float* ptr = bottom_blob.channel(c).row(z * h + i);
// float* outptr = const_cast<float*>(top_blob.channel(c).row(flipped_d * h + (h - 1 - i))); // 翻转h维度

// 修改为使用depth()访问方式
const float* ptr = bottom_blob.channel(c).depth(z).row(i);
float* outptr = const_cast<float*>(top_blob.channel(c).depth(flipped_d).row(h - 1 - i)); // 翻转h维度
// 翻转w维度
for (int k = 0; k < w; k++)
{


+ 37
- 5
tests/test_flip.cpp View File

@@ -124,9 +124,41 @@ static int test_flip_3()
int main()
{
SRAND(7767517);
return 0
|| test_flip_0()
|| test_flip_1()
|| test_flip_2()
|| test_flip_3();
// return 0
// || test_flip_0()
// || test_flip_1()
// || test_flip_2()
// || test_flip_3();

// debug 测出所有异常
test_flip(RandomMat(2, 3, 4, 5), IntArrayMat(0));
test_flip(RandomMat(3, 2, 4, 5), IntArrayMat(1));
test_flip(RandomMat(4, 3, 2, 5), IntArrayMat(2));
test_flip(RandomMat(2, 3, 1, 5), IntArrayMat(3));
test_flip(RandomMat(6, 3, 4, 5), IntArrayMat(0, 1));
test_flip(RandomMat(2, 3, 1, 6), IntArrayMat(0, 2));
test_flip(RandomMat(5, 1, 2, 5), IntArrayMat(0, 3));
test_flip(RandomMat(5, 2, 1, 5), IntArrayMat(1, 2));
test_flip(RandomMat(4, 5, 2, 3), IntArrayMat(1, 3));
test_flip(RandomMat(2, 6, 4, 5), IntArrayMat(2, 3));
test_flip(RandomMat(6, 1, 4, 5), IntArrayMat(0, 1, 2));
test_flip(RandomMat(5, 2, 1, 5), IntArrayMat(0, 1, 3));
test_flip(RandomMat(4, 3, 3, 5), IntArrayMat(0, 2, 3));
test_flip(RandomMat(4, 3, 4, 5), IntArrayMat(1, 2, 3));
test_flip(RandomMat(6, 3, 3, 2), IntArrayMat(0, 1, 2, 3));

test_flip(RandomMat(2, 3, 5), IntArrayMat(0));
test_flip(RandomMat(3, 3, 5), IntArrayMat(1));
test_flip(RandomMat(4, 3, 5), IntArrayMat(2));
test_flip(RandomMat(3, 1, 5), IntArrayMat(0, 1));
test_flip(RandomMat(3, 2, 5), IntArrayMat(0, 2));
test_flip(RandomMat(3, 3, 4), IntArrayMat(1, 2));
test_flip(RandomMat(4, 3, 2), IntArrayMat(0, 1, 2));

test_flip(RandomMat(8, 2), IntArrayMat(-2));
test_flip(RandomMat(16, 3), IntArrayMat(-1));
test_flip(RandomMat(7, 2), IntArrayMat(-2, -1));

test_flip(RandomMat(18), IntArrayMat(-1));
return 0;
}

Loading…
Cancel
Save