GitOrigin-RevId: 44a0adddba
tags/v1.7.0
| @@ -29,17 +29,24 @@ bool is_transpose_single( | |||||
| * assuming contig layout is: | * assuming contig layout is: | ||||
| * shape: b, m, n, c | * shape: b, m, n, c | ||||
| * stride: mnc, nc, c, 1 | * stride: mnc, nc, c, 1 | ||||
| * assuming non-contig layout is: | |||||
| * shape: b, m, n, c | |||||
| * stride: m*stride_m*c, stride_m*c, c, 1 | |||||
| * | * | ||||
| * then given layout should be: | * then given layout should be: | ||||
| * shape: b, n, m, c | * shape: b, n, m, c | ||||
| * stride: mnc, c, nc, 1 | * stride: mnc, c, nc, 1 | ||||
| * non-contig stride: m*stride_m*c, c, stride_m*c, 1 | |||||
| * | * | ||||
| * if c == 1: | * if c == 1: | ||||
| * shape: b, n, m | * shape: b, n, m | ||||
| * stride: mn, 1, n | * stride: mn, 1, n | ||||
| * non-contig stride: m*stride_m, 1, stride_m | |||||
| * | |||||
| * if b == 1: | * if b == 1: | ||||
| * shape: n, m, c | * shape: n, m, c | ||||
| * stride: c, nc, 1 | * stride: c, nc, 1 | ||||
| * non-contig stride: c, stride_m*c, 1 | |||||
| * | * | ||||
| * if b == 1 && c == 1: | * if b == 1 && c == 1: | ||||
| * shape: n, m | * shape: n, m | ||||
| @@ -65,7 +72,16 @@ bool is_transpose_single( | |||||
| p.n = layout[1]; | p.n = layout[1]; | ||||
| p.m = layout[2]; | p.m = layout[2]; | ||||
| p.c = 1; | p.c = 1; | ||||
| return strd(2, p.n) && strd(0, p.m * p.n); | |||||
| if (strd(2, p.n) && strd(0, p.m * p.n)) { | |||||
| return true; | |||||
| } else if ( | |||||
| allow_no_contig && (size_t)(layout.stride[2]) >= p.n && | |||||
| strd(0, p.m * (size_t)(layout.stride[2])) && strd(1, 1)) { | |||||
| p.stride_m = layout.stride[2]; | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | } | ||||
| if (strd(2, 1)) { | if (strd(2, 1)) { | ||||
| // b == 1 | // b == 1 | ||||
| @@ -41,6 +41,20 @@ TEST_F(AARCH64, Relayout) { | |||||
| } | } | ||||
| } | } | ||||
| TEST_F(AARCH64, RelayoutNonContig) { | |||||
| Checker<Relayout> checker(handle()); | |||||
| std::vector<::megdnn::DType> dtype_vec; | |||||
| dtype_vec.push_back(dtype::Float32()); | |||||
| dtype_vec.push_back(dtype::Int16()); | |||||
| dtype_vec.push_back(dtype::Uint16()); | |||||
| dtype_vec.push_back(dtype::Int8()); | |||||
| for (auto dtype : dtype_vec) { | |||||
| TensorLayout src({4, 90, 15, 29}, {41760, 1, 2784, 96}, dtype); | |||||
| TensorLayout dst({4, 90, 15, 29}, {39150, 435, 29, 1}, dtype); | |||||
| checker.execl({src, dst}); | |||||
| } | |||||
| } | |||||
| TEST_F(AARCH64, RelayoutBig) { | TEST_F(AARCH64, RelayoutBig) { | ||||
| Checker<Relayout> checker(handle()); | Checker<Relayout> checker(handle()); | ||||
| ConsecutiveRNG rng; | ConsecutiveRNG rng; | ||||