diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index d637a7fd5..8e1647f8c 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -493,6 +493,9 @@ y = crop(x) | 9 | starts | array | [ ] | | | 10 | ends | array | [ ] | | | 11 | axes | array | [ ] | | +| 19 | starts_expr | str | "" | | +| 20 | ends_expr | str | "" | | +| 21 | axes_expr | str | "" | | # CumulativeSum @@ -1699,6 +1702,7 @@ y = reshape(x) | 1 | h | int | -233 | | | 11 | d | int | -233 | | | 2 | c | int | -233 | | +| 6 | shape_expr | str | "" | | Reshape flag: - 0 = copy from bottom diff --git a/src/layer/arm/crop_arm.cpp b/src/layer/arm/crop_arm.cpp index e6163e4ed..e1b38cd8d 100644 --- a/src/layer/arm/crop_arm.cpp +++ b/src/layer/arm/crop_arm.cpp @@ -143,12 +143,21 @@ int Crop_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) int elempack = bottom_blob.elempack; #if __ARM_NEON - if (elempack == 8) + int _woffset, _hoffset, _doffset, _coffset; + int _outw, _outh, _outd, _outc; + if (!starts_expr.empty() && !ends_expr.empty()) + { + std::vector bottom_blob_shapes(1); + bottom_blob_shapes[0] = bottom_blob.shape(); + eval_crop_expr(bottom_blob_shapes, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else { - int _woffset, _hoffset, _doffset, _coffset; - int _outw, _outh, _outd, _outc; resolve_crop_roi(bottom_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + if (elempack == 8) + { if (dims == 1) { int out_elempack = _outw % 8 == 0 ? 8 : _outw % 4 == 0 ? 4 : 1; @@ -218,7 +227,7 @@ int Crop_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -260,7 +269,7 @@ int Crop_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -291,10 +300,6 @@ int Crop_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (elempack == 4) { - int _woffset, _hoffset, _doffset, _coffset; - int _outw, _outh, _outd, _outc; - resolve_crop_roi(bottom_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); - if (dims == 1) { int out_elempack = _outw % 4 == 0 ? 4 : 1; @@ -364,7 +369,7 @@ int Crop_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -406,7 +411,7 @@ int Crop_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -468,19 +473,28 @@ int Crop_arm::forward(const std::vector& bottom_blobs, std::vector& to Mat& top_blob = top_blobs[0]; #if __ARM_NEON - if (elempack == 8) + int _woffset, _hoffset, _doffset, _coffset; + int _outw, _outh, _outd, _outc; + if (!starts_expr.empty() && !ends_expr.empty()) { - int _woffset, _hoffset, _doffset, _coffset; - int _outw, _outh, _outd, _outc; - if (woffset == -233) - { - resolve_crop_roi(bottom_blob.shape(), (const int*)reference_blob, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); - } - else + std::vector bottom_blob_shapes(bottom_blobs.size()); + for (size_t i = 0; i < bottom_blobs.size(); i++) { - resolve_crop_roi(bottom_blob.shape(), reference_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + bottom_blob_shapes[i] = bottom_blobs[i].shape(); } + eval_crop_expr(bottom_blob_shapes, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else if (woffset == -233) + { + resolve_crop_roi(bottom_blob.shape(), (const int*)reference_blob, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else + { + resolve_crop_roi(bottom_blob.shape(), reference_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + if (elempack == 8) + { if (dims == 1) { int out_elempack = _outw % 8 == 0 ? 8 : _outw % 4 == 0 ? 4 : 1; @@ -550,7 +564,7 @@ int Crop_arm::forward(const std::vector& bottom_blobs, std::vector& to if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -592,7 +606,7 @@ int Crop_arm::forward(const std::vector& bottom_blobs, std::vector& to if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -623,17 +637,6 @@ int Crop_arm::forward(const std::vector& bottom_blobs, std::vector& to if (elempack == 4) { - int _woffset, _hoffset, _doffset, _coffset; - int _outw, _outh, _outd, _outc; - if (woffset == -233) - { - resolve_crop_roi(bottom_blob.shape(), (const int*)reference_blob, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); - } - else - { - resolve_crop_roi(bottom_blob.shape(), reference_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); - } - if (dims == 1) { int out_elempack = _outw % 4 == 0 ? 4 : 1; @@ -703,7 +706,7 @@ int Crop_arm::forward(const std::vector& bottom_blobs, std::vector& to if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -745,7 +748,7 @@ int Crop_arm::forward(const std::vector& bottom_blobs, std::vector& to if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -775,32 +778,23 @@ int Crop_arm::forward(const std::vector& bottom_blobs, std::vector& to } #endif // __ARM_NEON - Mat bottom_blob_unpacked = bottom_blob; - if (elempack != 1) + std::vector bottom_blobs_unpacked(bottom_blobs.size()); + for (size_t i = 0; i < bottom_blobs.size(); i++) { - Option opt_pack1 = opt; - opt_pack1.blob_allocator = opt.workspace_allocator; - - convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack1); - if (bottom_blob_unpacked.empty()) - return -100; - } + Mat bottom_blob_unpacked = bottom_blobs[i]; + if (elempack != 1) + { + Option opt_pack1 = opt; + opt_pack1.blob_allocator = opt.workspace_allocator; - Mat reference_blob_unpacked = reference_blob; - if (ref_elempack != 1) - { - Option opt_pack1 = opt; - opt_pack1.blob_allocator = opt.workspace_allocator; + convert_packing(bottom_blobs[i], bottom_blob_unpacked, 1, opt_pack1); + if (bottom_blob_unpacked.empty()) + return -100; + } - convert_packing(reference_blob, reference_blob_unpacked, 1, opt_pack1); - if (reference_blob_unpacked.empty()) - return -100; + bottom_blobs_unpacked[i] = bottom_blob_unpacked; } - std::vector bottom_blobs_unpacked(2); - bottom_blobs_unpacked[0] = bottom_blob_unpacked; - bottom_blobs_unpacked[1] = reference_blob_unpacked; - return Crop::forward(bottom_blobs_unpacked, top_blobs, opt); } diff --git a/src/layer/crop.cpp b/src/layer/crop.cpp index 5dfda493a..01d1c8bce 100644 --- a/src/layer/crop.cpp +++ b/src/layer/crop.cpp @@ -14,6 +14,8 @@ #include "crop.h" +#include "expression.h" + namespace ncnn { Crop::Crop() @@ -41,13 +43,34 @@ int Crop::load_param(const ParamDict& pd) ends = pd.get(10, Mat()); axes = pd.get(11, Mat()); + starts_expr = pd.get(19, ""); + ends_expr = pd.get(20, ""); + axes_expr = pd.get(21, ""); + + // NCNN_LOGE("%s %s %s", starts_expr.c_str(), ends_expr.c_str(), axes_expr.c_str()); + bool numpy_style_slice = !starts.empty() && !ends.empty(); + if (!starts_expr.empty() && !ends_expr.empty()) + numpy_style_slice = true; + if (outw == 0 && outh == 0 && outd == 0 && outc == 0 && woffset2 == 0 && hoffset2 == 0 && doffset2 == 0 && coffset2 == 0 && !numpy_style_slice) { one_blob_only = false; } + // count reference blobs + if (!starts_expr.empty() || !ends_expr.empty() || !axes_expr.empty()) + { + const int starts_blob_count = count_expression_blobs(starts_expr); + const int ends_blob_count = count_expression_blobs(ends_expr); + const int axes_blob_count = count_expression_blobs(axes_expr); + + // NCNN_LOGE("%d %d %d", starts_blob_count, ends_blob_count, axes_blob_count); + if (starts_blob_count > 1 || ends_blob_count > 1 || axes_blob_count > 1) + one_blob_only = false; + } + return 0; } @@ -89,7 +112,17 @@ int Crop::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons int _woffset, _hoffset, _doffset, _coffset; int _outw = -1, _outh = -1, _outd = -1, _outc; - resolve_crop_roi(bottom_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + + if (!starts_expr.empty() && !ends_expr.empty()) + { + std::vector bottom_blobs(1); + bottom_blobs[0] = bottom_blob; + eval_crop_expr(bottom_blobs, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else + { + resolve_crop_roi(bottom_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } if (dims == 1) { @@ -109,8 +142,6 @@ int Crop::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons copy_cut_border_image(bottom_blob, top_blob, 0, _woffset); if (elemsize == 4) copy_cut_border_image(bottom_blob, top_blob, 0, _woffset); - - return 0; } if (dims == 2) @@ -131,8 +162,6 @@ int Crop::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons copy_cut_border_image(bottom_blob, top_blob, _hoffset, _woffset); if (elemsize == 4) copy_cut_border_image(bottom_blob, top_blob, _hoffset, _woffset); - - return 0; } if (dims == 3) @@ -147,7 +176,7 @@ int Crop::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; @@ -171,8 +200,6 @@ int Crop::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons if (elemsize == 4) copy_cut_border_image(m, borderm, _hoffset, _woffset); } - - return 0; } if (dims == 4) @@ -187,7 +214,7 @@ int Crop::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; @@ -214,8 +241,6 @@ int Crop::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons copy_cut_border_image(m, borderm, _hoffset, _woffset); } } - - return 0; } return 0; @@ -237,7 +262,12 @@ int Crop::forward(const std::vector& bottom_blobs, std::vector& top_bl int _woffset, _hoffset, _doffset, _coffset = -1; int _outw = -1, _outh = -1, _outd = -1, _outc; - if (woffset == -233) + + if (!starts_expr.empty() && !ends_expr.empty()) + { + eval_crop_expr(bottom_blobs, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else if (woffset == -233) { resolve_crop_roi(bottom_blob.shape(), (const int*)reference_blob, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); } @@ -264,8 +294,6 @@ int Crop::forward(const std::vector& bottom_blobs, std::vector& top_bl copy_cut_border_image(bottom_blob, top_blob, 0, _woffset); if (elemsize == 4) copy_cut_border_image(bottom_blob, top_blob, 0, _woffset); - - return 0; } if (dims == 2) @@ -286,8 +314,6 @@ int Crop::forward(const std::vector& bottom_blobs, std::vector& top_bl copy_cut_border_image(bottom_blob, top_blob, _hoffset, _woffset); if (elemsize == 4) copy_cut_border_image(bottom_blob, top_blob, _hoffset, _woffset); - - return 0; } if (dims == 3) @@ -302,7 +328,7 @@ int Crop::forward(const std::vector& bottom_blobs, std::vector& top_bl if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; @@ -326,8 +352,6 @@ int Crop::forward(const std::vector& bottom_blobs, std::vector& top_bl if (elemsize == 4) copy_cut_border_image(m, borderm, _hoffset, _woffset); } - - return 0; } if (dims == 4) @@ -342,7 +366,7 @@ int Crop::forward(const std::vector& bottom_blobs, std::vector& top_bl if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; @@ -369,8 +393,6 @@ int Crop::forward(const std::vector& bottom_blobs, std::vector& top_bl copy_cut_border_image(m, borderm, _hoffset, _woffset); } } - - return 0; } return 0; @@ -649,4 +671,150 @@ void Crop::resolve_crop_roi(const Mat& bottom_blob, const int* param_data, int& } } +int Crop::eval_crop_expr(const std::vector& bottom_blobs, int& _woffset, int& _hoffset, int& _doffset, int& _coffset, int& _outw, int& _outh, int& _outd, int& _outc) const +{ + std::vector _starts; + std::vector _ends; + std::vector _axes; + int er = eval_list_expression(starts_expr, bottom_blobs, _starts); + if (er != 0) + return -1; + + er = eval_list_expression(ends_expr, bottom_blobs, _ends); + if (er != 0) + return -1; + + er = eval_list_expression(axes_expr, bottom_blobs, _axes); + if (er != 0) + return -1; + + // NCNN_LOGE("%d %d %d", _starts[0], _ends[0], _axes[0]); + + const Mat& bottom_blob = bottom_blobs[0]; + const int w = bottom_blob.w; + const int h = bottom_blob.h; + const int d = bottom_blob.d; + const int channels = bottom_blob.c; + const int dims = bottom_blob.dims; + + _woffset = 0; + _hoffset = 0; + _doffset = 0; + _coffset = 0; + _outw = w; + _outh = h; + _outd = d; + _outc = channels; + + const int* starts_ptr = _starts.data(); + const int* ends_ptr = _ends.data(); + const int* axes_ptr = _axes.data(); + + int _axes4[4] = {0, 1, 2, 3}; + int num_axis = (int)_axes.size(); + if (num_axis == 0) + { + num_axis = dims; + } + else + { + for (int i = 0; i < num_axis; i++) + { + int axis = axes_ptr[i]; + if (axis < 0) + axis = dims + axis; + _axes4[i] = axis; + } + } + + for (int i = 0; i < num_axis; i++) + { + int axis = _axes4[i]; + int start = starts_ptr[i]; + int end = ends_ptr[i]; + + if (dims == 1) // axis == 0 + { + if (start == -233) start = 0; + if (end == -233) end = w; + _woffset = start >= 0 ? start : w + start; + _outw = std::min(w, end > 0 ? end : w + end) - _woffset; + } + if (dims == 2) + { + if (axis == 0) + { + if (start == -233) start = 0; + if (end == -233) end = h; + _hoffset = start >= 0 ? start : h + start; + _outh = std::min(h, end > 0 ? end : h + end) - _hoffset; + } + if (axis == 1) + { + if (start == -233) start = 0; + if (end == -233) end = w; + _woffset = start >= 0 ? start : w + start; + _outw = std::min(w, end > 0 ? end : w + end) - _woffset; + } + } + if (dims == 3) + { + if (axis == 0) + { + if (start == -233) start = 0; + if (end == -233) end = channels; + _coffset = start >= 0 ? start : channels + start; + _outc = std::min(channels, end > 0 ? end : channels + end) - _coffset; + } + if (axis == 1) + { + if (start == -233) start = 0; + if (end == -233) end = h; + _hoffset = start >= 0 ? start : h + start; + _outh = std::min(h, end > 0 ? end : h + end) - _hoffset; + } + if (axis == 2) + { + if (start == -233) start = 0; + if (end == -233) end = w; + _woffset = start >= 0 ? start : w + start; + _outw = std::min(w, end > 0 ? end : w + end) - _woffset; + } + } + if (dims == 4) + { + if (axis == 0) + { + if (start == -233) start = 0; + if (end == -233) end = channels; + _coffset = start >= 0 ? start : channels + start; + _outc = std::min(channels, end > 0 ? end : channels + end) - _coffset; + } + if (axis == 1) + { + if (start == -233) start = 0; + if (end == -233) end = d; + _doffset = start >= 0 ? start : d + start; + _outd = std::min(d, end > 0 ? end : d + end) - _doffset; + } + if (axis == 2) + { + if (start == -233) start = 0; + if (end == -233) end = h; + _hoffset = start >= 0 ? start : h + start; + _outh = std::min(h, end > 0 ? end : h + end) - _hoffset; + } + if (axis == 3) + { + if (start == -233) start = 0; + if (end == -233) end = w; + _woffset = start >= 0 ? start : w + start; + _outw = std::min(w, end > 0 ? end : w + end) - _woffset; + } + } + } + + return 0; +} + } // namespace ncnn diff --git a/src/layer/crop.h b/src/layer/crop.h index 826fe1bf1..bad3b4822 100644 --- a/src/layer/crop.h +++ b/src/layer/crop.h @@ -34,6 +34,7 @@ protected: void resolve_crop_roi(const Mat& bottom_blob, int& woffset, int& hoffset, int& doffset, int& coffset, int& outw, int& outh, int& outd, int& outc) const; void resolve_crop_roi(const Mat& bottom_blob, const Mat& reference_blob, int& woffset, int& hoffset, int& doffset, int& coffset, int& outw, int& outh, int& outd, int& outc) const; void resolve_crop_roi(const Mat& bottom_blob, const int* param_data, int& woffset, int& hoffset, int& doffset, int& coffset, int& outw, int& outh, int& outd, int& outc) const; + int eval_crop_expr(const std::vector& bottom_blobs, int& woffset, int& hoffset, int& doffset, int& coffset, int& outw, int& outh, int& outd, int& outc) const; public: // -233 = dynamic offset from reference blob @@ -60,6 +61,11 @@ public: Mat starts; Mat ends; Mat axes; + + // see docs/developer-guide/expression.md + std::string starts_expr; + std::string ends_expr; + std::string axes_expr; }; } // namespace ncnn diff --git a/src/layer/loongarch/crop_loongarch.cpp b/src/layer/loongarch/crop_loongarch.cpp index e7c588bc4..d2e6382c8 100644 --- a/src/layer/loongarch/crop_loongarch.cpp +++ b/src/layer/loongarch/crop_loongarch.cpp @@ -64,12 +64,21 @@ int Crop_loongarch::forward(const Mat& bottom_blob, Mat& top_blob, const Option& int elempack = bottom_blob.elempack; #if __loongarch_sx - if (elempack == 4) + int _woffset, _hoffset, _doffset, _coffset; + int _outw, _outh, _outd, _outc; + if (!starts_expr.empty() && !ends_expr.empty()) + { + std::vector bottom_blob_shapes(1); + bottom_blob_shapes[0] = bottom_blob.shape(); + eval_crop_expr(bottom_blob_shapes, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else { - int _woffset, _hoffset, _doffset, _coffset; - int _outw, _outh, _outd, _outc; resolve_crop_roi(bottom_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + if (elempack == 4) + { if (dims == 1) { int out_elempack = _outw % 4 == 0 ? 4 : 1; @@ -133,7 +142,7 @@ int Crop_loongarch::forward(const Mat& bottom_blob, Mat& top_blob, const Option& if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -172,7 +181,7 @@ int Crop_loongarch::forward(const Mat& bottom_blob, Mat& top_blob, const Option& if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -206,6 +215,8 @@ int Crop_loongarch::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt_pack1.blob_allocator = opt.workspace_allocator; convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack1); + if (bottom_blob_unpacked.empty()) + return -100; } return Crop::forward(bottom_blob_unpacked, top_blob, opt); @@ -229,19 +240,28 @@ int Crop_loongarch::forward(const std::vector& bottom_blobs, std::vector bottom_blob_shapes(bottom_blobs.size()); + for (size_t i = 0; i < bottom_blobs.size(); i++) { - resolve_crop_roi(bottom_blob.shape(), (const int*)reference_blob, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); - } - else - { - resolve_crop_roi(bottom_blob.shape(), reference_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + bottom_blob_shapes[i] = bottom_blobs[i].shape(); } + eval_crop_expr(bottom_blob_shapes, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else if (woffset == -233) + { + resolve_crop_roi(bottom_blob.shape(), (const int*)reference_blob, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else + { + resolve_crop_roi(bottom_blob.shape(), reference_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + if (elempack == 4) + { if (dims == 1) { int out_elempack = _outw % 4 == 0 ? 4 : 1; @@ -305,7 +325,7 @@ int Crop_loongarch::forward(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector bottom_blobs_unpacked(bottom_blobs.size()); + for (size_t i = 0; i < bottom_blobs.size(); i++) { - Option opt_pack1 = opt; - opt_pack1.blob_allocator = opt.workspace_allocator; - - convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack1); - } + Mat bottom_blob_unpacked = bottom_blobs[i]; + if (elempack != 1) + { + Option opt_pack1 = opt; + opt_pack1.blob_allocator = opt.workspace_allocator; - Mat reference_blob_unpacked = reference_blob; - if (ref_elempack != 1) - { - Option opt_pack1 = opt; - opt_pack1.blob_allocator = opt.workspace_allocator; + convert_packing(bottom_blobs[i], bottom_blob_unpacked, 1, opt_pack1); + if (bottom_blob_unpacked.empty()) + return -100; + } - convert_packing(reference_blob, reference_blob_unpacked, 1, opt_pack1); + bottom_blobs_unpacked[i] = bottom_blob_unpacked; } - std::vector bottom_blobs_unpacked(2); - bottom_blobs_unpacked[0] = bottom_blob_unpacked; - bottom_blobs_unpacked[1] = reference_blob_unpacked; - return Crop::forward(bottom_blobs_unpacked, top_blobs, opt); } diff --git a/src/layer/mips/crop_mips.cpp b/src/layer/mips/crop_mips.cpp index b0186cee9..70a3b96ca 100644 --- a/src/layer/mips/crop_mips.cpp +++ b/src/layer/mips/crop_mips.cpp @@ -64,12 +64,21 @@ int Crop_mips::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) int elempack = bottom_blob.elempack; #if __mips_msa - if (elempack == 4) + int _woffset, _hoffset, _doffset, _coffset; + int _outw, _outh, _outd, _outc; + if (!starts_expr.empty() && !ends_expr.empty()) + { + std::vector bottom_blob_shapes(1); + bottom_blob_shapes[0] = bottom_blob.shape(); + eval_crop_expr(bottom_blob_shapes, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else { - int _woffset, _hoffset, _doffset, _coffset; - int _outw, _outh, _outd, _outc; resolve_crop_roi(bottom_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + if (elempack == 4) + { if (dims == 1) { int out_elempack = _outw % 4 == 0 ? 4 : 1; @@ -133,7 +142,7 @@ int Crop_mips::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -172,7 +181,7 @@ int Crop_mips::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -206,6 +215,8 @@ int Crop_mips::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) opt_pack1.blob_allocator = opt.workspace_allocator; convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack1); + if (bottom_blob_unpacked.empty()) + return -100; } return Crop::forward(bottom_blob_unpacked, top_blob, opt); @@ -229,19 +240,28 @@ int Crop_mips::forward(const std::vector& bottom_blobs, std::vector& t Mat& top_blob = top_blobs[0]; #if __mips_msa - if (elempack == 4) + int _woffset, _hoffset, _doffset, _coffset; + int _outw, _outh, _outd, _outc; + if (!starts_expr.empty() && !ends_expr.empty()) { - int _woffset, _hoffset, _doffset, _coffset; - int _outw, _outh, _outd, _outc; - if (woffset == -233) + std::vector bottom_blob_shapes(bottom_blobs.size()); + for (size_t i = 0; i < bottom_blobs.size(); i++) { - resolve_crop_roi(bottom_blob.shape(), (const int*)reference_blob, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); - } - else - { - resolve_crop_roi(bottom_blob.shape(), reference_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + bottom_blob_shapes[i] = bottom_blobs[i].shape(); } + eval_crop_expr(bottom_blob_shapes, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else if (woffset == -233) + { + resolve_crop_roi(bottom_blob.shape(), (const int*)reference_blob, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else + { + resolve_crop_roi(bottom_blob.shape(), reference_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + if (elempack == 4) + { if (dims == 1) { int out_elempack = _outw % 4 == 0 ? 4 : 1; @@ -305,7 +325,7 @@ int Crop_mips::forward(const std::vector& bottom_blobs, std::vector& t if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -344,7 +364,7 @@ int Crop_mips::forward(const std::vector& bottom_blobs, std::vector& t if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -371,28 +391,23 @@ int Crop_mips::forward(const std::vector& bottom_blobs, std::vector& t } #endif // __mips_msa - Mat bottom_blob_unpacked = bottom_blob; - if (elempack != 1) + std::vector bottom_blobs_unpacked(bottom_blobs.size()); + for (size_t i = 0; i < bottom_blobs.size(); i++) { - Option opt_pack1 = opt; - opt_pack1.blob_allocator = opt.workspace_allocator; - - convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack1); - } + Mat bottom_blob_unpacked = bottom_blobs[i]; + if (elempack != 1) + { + Option opt_pack1 = opt; + opt_pack1.blob_allocator = opt.workspace_allocator; - Mat reference_blob_unpacked = reference_blob; - if (ref_elempack != 1) - { - Option opt_pack1 = opt; - opt_pack1.blob_allocator = opt.workspace_allocator; + convert_packing(bottom_blobs[i], bottom_blob_unpacked, 1, opt_pack1); + if (bottom_blob_unpacked.empty()) + return -100; + } - convert_packing(reference_blob, reference_blob_unpacked, 1, opt_pack1); + bottom_blobs_unpacked[i] = bottom_blob_unpacked; } - std::vector bottom_blobs_unpacked(2); - bottom_blobs_unpacked[0] = bottom_blob_unpacked; - bottom_blobs_unpacked[1] = reference_blob_unpacked; - return Crop::forward(bottom_blobs_unpacked, top_blobs, opt); } diff --git a/src/layer/riscv/crop_riscv.cpp b/src/layer/riscv/crop_riscv.cpp index 95717c459..ec4f2edc6 100644 --- a/src/layer/riscv/crop_riscv.cpp +++ b/src/layer/riscv/crop_riscv.cpp @@ -113,12 +113,21 @@ int Crop_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt int elempack = bottom_blob.elempack; #if __riscv_vector - if (elempack == packn) + int _woffset, _hoffset, _doffset, _coffset; + int _outw, _outh, _outd, _outc; + if (!starts_expr.empty() && !ends_expr.empty()) + { + std::vector bottom_blob_shapes(1); + bottom_blob_shapes[0] = bottom_blob.shape(); + eval_crop_expr(bottom_blob_shapes, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else { - int _woffset, _hoffset, _doffset, _coffset; - int _outw, _outh, _outd, _outc; resolve_crop_roi(bottom_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + if (elempack == packn) + { if (dims == 1) { int out_elempack = _outw % packn == 0 ? packn : 1; @@ -188,7 +197,7 @@ int Crop_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -230,7 +239,7 @@ int Crop_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -267,6 +276,8 @@ int Crop_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt opt_pack1.blob_allocator = opt.workspace_allocator; convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack1); + if (bottom_blob_unpacked.empty()) + return -100; } return Crop::forward(bottom_blob_unpacked, top_blob, opt); @@ -296,19 +307,28 @@ int Crop_riscv::forward(const std::vector& bottom_blobs, std::vector& Mat& top_blob = top_blobs[0]; #if __riscv_vector - if (elempack == packn) + int _woffset, _hoffset, _doffset, _coffset; + int _outw, _outh, _outd, _outc; + if (!starts_expr.empty() && !ends_expr.empty()) { - int _woffset, _hoffset, _doffset, _coffset; - int _outw, _outh, _outd, _outc; - if (woffset == -233) + std::vector bottom_blob_shapes(bottom_blobs.size()); + for (size_t i = 0; i < bottom_blobs.size(); i++) { - resolve_crop_roi(bottom_blob.shape(), (const int*)reference_blob, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); - } - else - { - resolve_crop_roi(bottom_blob.shape(), reference_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + bottom_blob_shapes[i] = bottom_blobs[i].shape(); } + eval_crop_expr(bottom_blob_shapes, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else if (woffset == -233) + { + resolve_crop_roi(bottom_blob.shape(), (const int*)reference_blob, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else + { + resolve_crop_roi(bottom_blob.shape(), reference_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + if (elempack == packn) + { if (dims == 1) { int out_elempack = _outw % packn == 0 ? packn : 1; @@ -378,7 +398,7 @@ int Crop_riscv::forward(const std::vector& bottom_blobs, std::vector& if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -420,7 +440,7 @@ int Crop_riscv::forward(const std::vector& bottom_blobs, std::vector& if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -450,28 +470,23 @@ int Crop_riscv::forward(const std::vector& bottom_blobs, std::vector& } #endif // __riscv_vector - Mat bottom_blob_unpacked = bottom_blob; - if (elempack != 1) + std::vector bottom_blobs_unpacked(bottom_blobs.size()); + for (size_t i = 0; i < bottom_blobs.size(); i++) { - Option opt_pack1 = opt; - opt_pack1.blob_allocator = opt.workspace_allocator; - - convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack1); - } + Mat bottom_blob_unpacked = bottom_blobs[i]; + if (elempack != 1) + { + Option opt_pack1 = opt; + opt_pack1.blob_allocator = opt.workspace_allocator; - Mat reference_blob_unpacked = reference_blob; - if (ref_elempack != 1) - { - Option opt_pack1 = opt; - opt_pack1.blob_allocator = opt.workspace_allocator; + convert_packing(bottom_blobs[i], bottom_blob_unpacked, 1, opt_pack1); + if (bottom_blob_unpacked.empty()) + return -100; + } - convert_packing(reference_blob, reference_blob_unpacked, 1, opt_pack1); + bottom_blobs_unpacked[i] = bottom_blob_unpacked; } - std::vector bottom_blobs_unpacked(2); - bottom_blobs_unpacked[0] = bottom_blob_unpacked; - bottom_blobs_unpacked[1] = reference_blob_unpacked; - return Crop::forward(bottom_blobs_unpacked, top_blobs, opt); } diff --git a/src/layer/vulkan/crop_vulkan.cpp b/src/layer/vulkan/crop_vulkan.cpp index 634dbd3ba..790d3bdcc 100644 --- a/src/layer/vulkan/crop_vulkan.cpp +++ b/src/layer/vulkan/crop_vulkan.cpp @@ -52,7 +52,36 @@ int Crop_vulkan::create_pipeline(const Option& opt) int offset_elempack = 1; bool numpy_style_slice = !starts.empty() && !ends.empty(); - if (numpy_style_slice) + if (!starts_expr.empty() && !ends_expr.empty() && !bottom_shapes.empty()) + { + int _woffset, _hoffset, _doffset, _coffset = -1; + int _outw = -1, _outh = -1, _outd = -1, _outc; + + eval_crop_expr(bottom_shapes, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + + if (shape.dims == 1) + { + if (_woffset == 0) + offset_elempack = elempack; + else + offset_elempack = opt.use_shader_pack8 && _woffset % 8 == 0 ? 8 : _woffset % 4 == 0 ? 4 : 1; + } + else if (shape.dims == 2) + { + if (_hoffset == 0) + offset_elempack = elempack; + else + offset_elempack = opt.use_shader_pack8 && _hoffset % 8 == 0 ? 8 : _hoffset % 4 == 0 ? 4 : 1; + } + else // if (shape.dims == 3 || shape.dims == 4) + { + if (_coffset == 0) + offset_elempack = elempack; + else + offset_elempack = opt.use_shader_pack8 && _coffset % 8 == 0 ? 8 : _coffset % 4 == 0 ? 4 : 1; + } + } + else if (numpy_style_slice) { offset_elempack = elempack; @@ -156,7 +185,7 @@ int Crop_vulkan::create_pipeline(const Option& opt) if (out_shape.dims == 4) out_shape_packed = Mat(out_shape.w, out_shape.h, out_shape.d, out_shape.c / out_elempack, (void*)0, out_elemsize, out_elempack); Mat shape_unpacked = shape_packed; - if (one_blob_only && shape.dims != 0 && elempack == out_elempack && elempack > offset_elempack) + if ((one_blob_only || (!starts_expr.empty() && !ends_expr.empty())) && shape.dims != 0 && elempack == out_elempack && elempack > offset_elempack) { size_t offset_elemsize; if (opt.use_fp16_storage) @@ -334,7 +363,16 @@ int Crop_vulkan::forward(const VkMat& bottom_blob, VkMat& top_blob, VkCompute& c int _woffset, _hoffset, _doffset, _coffset; int _outw, _outh, _outd, _outc; - resolve_crop_roi(bottom_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + if (!starts_expr.empty() && !ends_expr.empty()) + { + std::vector bottom_blob_shapes(1); + bottom_blob_shapes[0] = bottom_blob.shape(); + eval_crop_expr(bottom_blob_shapes, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else + { + resolve_crop_roi(bottom_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } int offset_elempack; int out_elempack; @@ -513,7 +551,16 @@ int Crop_vulkan::forward(const std::vector& bottom_blobs, std::vector bottom_blob_shapes(bottom_blobs.size()); + for (size_t i = 0; i < bottom_blobs.size(); i++) + { + bottom_blob_shapes[i] = bottom_blobs[i].shape(); + } + eval_crop_expr(bottom_blob_shapes, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else if (woffset == -233) { resolve_crop_roi(bottom_blob.shape(), (const int*)reference_blob.mapped(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); } @@ -695,7 +742,16 @@ int Crop_vulkan::forward(const VkImageMat& bottom_blob, VkImageMat& top_blob, Vk int _woffset, _hoffset, _doffset, _coffset; int _outw, _outh, _outd, _outc; - resolve_crop_roi(bottom_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + if (!starts_expr.empty() && !ends_expr.empty()) + { + std::vector bottom_blob_shapes(1); + bottom_blob_shapes[0] = bottom_blob.shape(); + eval_crop_expr(bottom_blob_shapes, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else + { + resolve_crop_roi(bottom_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } int offset_elempack; int out_elempack; @@ -874,7 +930,16 @@ int Crop_vulkan::forward(const std::vector& bottom_blobs, std::vecto int _woffset, _hoffset, _doffset, _coffset; int _outw, _outh, _outd, _outc; - if (woffset == -233) + if (!starts_expr.empty() && !ends_expr.empty()) + { + std::vector bottom_blob_shapes(bottom_blobs.size()); + for (size_t i = 0; i < bottom_blobs.size(); i++) + { + bottom_blob_shapes[i] = bottom_blobs[i].shape(); + } + eval_crop_expr(bottom_blob_shapes, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else if (woffset == -233) { resolve_crop_roi(bottom_blob.shape(), (const int*)reference_blob.mapped(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); } diff --git a/src/layer/x86/crop_x86.cpp b/src/layer/x86/crop_x86.cpp index 8f7e697e7..ace9584db 100644 --- a/src/layer/x86/crop_x86.cpp +++ b/src/layer/x86/crop_x86.cpp @@ -116,14 +116,23 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) int elempack = bottom_blob.elempack; #if __SSE2__ + int _woffset, _hoffset, _doffset, _coffset; + int _outw, _outh, _outd, _outc; + if (!starts_expr.empty() && !ends_expr.empty()) + { + std::vector bottom_blob_shapes(1); + bottom_blob_shapes[0] = bottom_blob.shape(); + eval_crop_expr(bottom_blob_shapes, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else + { + resolve_crop_roi(bottom_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + #if __AVX__ #if __AVX512F__ if (elempack == 16) { - int _woffset, _hoffset, _doffset, _coffset; - int _outw, _outh, _outd, _outc; - resolve_crop_roi(bottom_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); - if (dims == 1) { int out_elempack = _outw % 16 == 0 ? 16 : _outw % 8 == 0 ? 8 : _outw % 4 == 0 ? 4 : 1; @@ -187,7 +196,7 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -225,7 +234,7 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -253,10 +262,6 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (elempack == 8) { - int _woffset, _hoffset, _doffset, _coffset; - int _outw, _outh, _outd, _outc; - resolve_crop_roi(bottom_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); - if (dims == 1) { int out_elempack = _outw % 8 == 0 ? 8 : _outw % 4 == 0 ? 4 : 1; @@ -320,7 +325,7 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -358,7 +363,7 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -386,10 +391,6 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (elempack == 4) { - int _woffset, _hoffset, _doffset, _coffset; - int _outw, _outh, _outd, _outc; - resolve_crop_roi(bottom_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); - if (dims == 1) { int out_elempack = _outw % 4 == 0 ? 4 : 1; @@ -453,7 +454,7 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -492,7 +493,7 @@ int Crop_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -546,26 +547,33 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to size_t elemsize = bottom_blob.elemsize; int elempack = bottom_blob.elempack; - int ref_elempack = reference_blob.elempack; - Mat& top_blob = top_blobs[0]; #if __SSE2__ -#if __AVX__ -#if __AVX512F__ - if (elempack == 16) + int _woffset, _hoffset, _doffset, _coffset; + int _outw, _outh, _outd, _outc; + if (!starts_expr.empty() && !ends_expr.empty()) { - int _woffset, _hoffset, _doffset, _coffset; - int _outw, _outh, _outd, _outc; - if (woffset == -233) - { - resolve_crop_roi(bottom_blob.shape(), (const int*)reference_blob, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); - } - else + std::vector bottom_blob_shapes(bottom_blobs.size()); + for (size_t i = 0; i < bottom_blobs.size(); i++) { - resolve_crop_roi(bottom_blob.shape(), reference_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + bottom_blob_shapes[i] = bottom_blobs[i].shape(); } + eval_crop_expr(bottom_blob_shapes, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else if (woffset == -233) + { + resolve_crop_roi(bottom_blob.shape(), (const int*)reference_blob, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } + else + { + resolve_crop_roi(bottom_blob.shape(), reference_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); + } +#if __AVX__ +#if __AVX512F__ + if (elempack == 16) + { if (dims == 1) { int out_elempack = _outw % 16 == 0 ? 16 : _outw % 8 == 0 ? 8 : _outw % 4 == 0 ? 4 : 1; @@ -629,7 +637,7 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -667,7 +675,7 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -695,17 +703,6 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to if (elempack == 8) { - int _woffset, _hoffset, _doffset, _coffset; - int _outw, _outh, _outd, _outc; - if (woffset == -233) - { - resolve_crop_roi(bottom_blob.shape(), (const int*)reference_blob, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); - } - else - { - resolve_crop_roi(bottom_blob.shape(), reference_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); - } - if (dims == 1) { int out_elempack = _outw % 8 == 0 ? 8 : _outw % 4 == 0 ? 4 : 1; @@ -769,7 +766,7 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -807,7 +804,7 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -835,17 +832,6 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to if (elempack == 4) { - int _woffset, _hoffset, _doffset, _coffset; - int _outw, _outh, _outd, _outc; - if (woffset == -233) - { - resolve_crop_roi(bottom_blob.shape(), (const int*)reference_blob, _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); - } - else - { - resolve_crop_roi(bottom_blob.shape(), reference_blob.shape(), _woffset, _hoffset, _doffset, _coffset, _outw, _outh, _outd, _outc); - } - if (dims == 1) { int out_elempack = _outw % 4 == 0 ? 4 : 1; @@ -909,7 +895,7 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to if (_outw == w && _outh == h) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -948,7 +934,7 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to if (_outw == w && _outh == h && _outd == d) { - top_blob = bottom_blob_sliced.clone(); + top_blob = bottom_blob_sliced.clone(opt.blob_allocator); if (top_blob.empty()) return -100; } @@ -975,32 +961,23 @@ int Crop_x86::forward(const std::vector& bottom_blobs, std::vector& to } #endif // __SSE2__ - Mat bottom_blob_unpacked = bottom_blob; - if (elempack != 1) + std::vector bottom_blobs_unpacked(bottom_blobs.size()); + for (size_t i = 0; i < bottom_blobs.size(); i++) { - Option opt_pack1 = opt; - opt_pack1.blob_allocator = opt.workspace_allocator; - - convert_packing(bottom_blob, bottom_blob_unpacked, 1, opt_pack1); - if (bottom_blob_unpacked.empty()) - return -100; - } + Mat bottom_blob_unpacked = bottom_blobs[i]; + if (elempack != 1) + { + Option opt_pack1 = opt; + opt_pack1.blob_allocator = opt.workspace_allocator; - Mat reference_blob_unpacked = reference_blob; - if (ref_elempack != 1) - { - Option opt_pack1 = opt; - opt_pack1.blob_allocator = opt.workspace_allocator; + convert_packing(bottom_blobs[i], bottom_blob_unpacked, 1, opt_pack1); + if (bottom_blob_unpacked.empty()) + return -100; + } - convert_packing(reference_blob, reference_blob_unpacked, 1, opt_pack1); - if (reference_blob_unpacked.empty()) - return -100; + bottom_blobs_unpacked[i] = bottom_blob_unpacked; } - std::vector bottom_blobs_unpacked(2); - bottom_blobs_unpacked[0] = bottom_blob_unpacked; - bottom_blobs_unpacked[1] = reference_blob_unpacked; - return Crop::forward(bottom_blobs_unpacked, top_blobs, opt); } diff --git a/src/paramdict.cpp b/src/paramdict.cpp index be7d4ef71..f1fbead97 100644 --- a/src/paramdict.cpp +++ b/src/paramdict.cpp @@ -358,7 +358,7 @@ int ParamDict::load_param(const DataReader& dr) vstr2[241] = '\0'; // max 255 = 15 + 240 if (vstr[0] == '\"') { - nscan = dr.scan("%255[^\"]\"", vstr2); + nscan = dr.scan("%255[^\"\n]\"", vstr2); } else { diff --git a/tests/test_crop.cpp b/tests/test_crop.cpp index d2a03eb53..93fc3a02f 100644 --- a/tests/test_crop.cpp +++ b/tests/test_crop.cpp @@ -17,18 +17,18 @@ static int test_crop(const ncnn::Mat& a, int woffset, int hoffset, int doffset, int coffset, int outw, int outh, int outd, int outc, int woffset2, int hoffset2, int doffset2, int coffset2) { ncnn::ParamDict pd; - pd.set(0, woffset); // woffset - pd.set(1, hoffset); // hoffset - pd.set(13, doffset); // doffset - pd.set(2, coffset); // coffset - pd.set(3, outw); // outw - pd.set(4, outh); // outh - pd.set(14, outd); // outd - pd.set(5, outc); // outc - pd.set(6, woffset2); // woffset2 - pd.set(7, hoffset2); // hoffset2 - pd.set(15, doffset2); // doffset2 - pd.set(8, coffset2); // coffset2 + pd.set(0, woffset); + pd.set(1, hoffset); + pd.set(13, doffset); + pd.set(2, coffset); + pd.set(3, outw); + pd.set(4, outh); + pd.set(14, outd); + pd.set(5, outc); + pd.set(6, woffset2); + pd.set(7, hoffset2); + pd.set(15, doffset2); + pd.set(8, coffset2); std::vector weights(0); diff --git a/tests/test_crop_3.cpp b/tests/test_crop_3.cpp new file mode 100644 index 000000000..fcf0b9a40 --- /dev/null +++ b/tests/test_crop_3.cpp @@ -0,0 +1,121 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "testutil.h" + +static int test_crop(const ncnn::Mat& a, const char* starts_expr, const char* ends_expr, const char* axes_expr) +{ + ncnn::ParamDict pd; + pd.set(19, std::string(starts_expr)); + pd.set(20, std::string(ends_expr)); + pd.set(21, std::string(axes_expr)); + + std::vector weights(0); + + int ret = test_layer("Crop", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_crop failed a.dims=%d a=(%d %d %d %d) starts_expr=%s ends_expr=%s axes_expr=%s\n", a.dims, a.w, a.h, a.d, a.c, starts_expr, ends_expr, axes_expr); + } + + return ret; +} + +static int test_crop(const std::vector& as, const char* starts_expr, const char* ends_expr, const char* axes_expr) +{ + ncnn::ParamDict pd; + pd.set(19, std::string(starts_expr)); + pd.set(20, std::string(ends_expr)); + pd.set(21, std::string(axes_expr)); + + std::vector weights(0); + + int ret = test_layer("Crop", pd, weights, as, 1); + if (ret != 0) + { + fprintf(stderr, "test_crop failed a.dims=%d a=(%d %d %d %d) starts_expr=%s ends_expr=%s axes_expr=%s\n", as[0].dims, as[0].w, as[0].h, as[0].d, as[0].c, starts_expr, ends_expr, axes_expr); + } + + return ret; +} + +static int test_crop_0() +{ + ncnn::Mat a = RandomMat(13, 12, 25, 48); + ncnn::Mat b = RandomMat(13, 12, 48); + ncnn::Mat c = RandomMat(13, 48); + ncnn::Mat d = RandomMat(128); + + return 0 + || test_crop(a, "2", "-2", "0") + || test_crop(b, "2", "-2", "0") + || test_crop(c, "2", "-2", "0") + || test_crop(d, "2", "-2", "0") + || test_crop(a, "16", "32", "-4") + || test_crop(b, "16", "32", "-3") + || test_crop(c, "16", "32", "-2") + || test_crop(d, "16", "32", "-1") + || test_crop(a, "16,//(0d,4),2,3", "32,-1,-2,-3", "0,1,2,3") + || test_crop(b, "16,//(0h,4),2", "32,-1,-(0w,2)", "0,1,2") + || test_crop(c, "16,//(0w,4)", "32,-2", "0,1") + || test_crop(a, "10", "11", "1") + || test_crop(b, "1,1", "-(0c,15),-(0w,5)", "0,2") + || test_crop(a, "-(0w,3),0h//2,floor(*(0c,0.3))", "-1,0h,ceil(*(0c,0.9))", "3,2,0") + || test_crop(b, "-(0w,3),0h//2,floor(*(0c,0.3))", "-1,0h,ceil(*(0c,0.9))", "2,1,0") + || test_crop(c, "-(0w,3),floor(*(0h,0.3))", "-1,ceil(*(0h,0.9))", "1,0") + || test_crop(d, "floor(*(0w,0.3))", "ceil(*(0w,0.9))", "0"); +} + +static int test_crop_1() +{ + std::vector as(2); + as[0] = RandomMat(14, 15, 13, 48); + as[1] = RandomMat(8, 5, 3, 4); + + std::vector bs(2); + bs[0] = RandomMat(14, 15, 48); + bs[1] = RandomMat(28, 45, 16); + + std::vector cs(2); + cs[0] = RandomMat(24, 48); + cs[1] = RandomMat(28, 6); + + std::vector ds(3); + ds[0] = RandomMat(128); + ds[1] = RandomMat(16); + ds[2] = RandomMat(64); + + return 0 + || test_crop(as, "*(1c,4)", "*(1c,8)", "-4") + || test_crop(bs, "1c", "-(0c,1c)", "-3") + || test_crop(cs, "+(1h,10)", "-(1h,22)", "-2") + || test_crop(ds, "1w", "2w", "-1") + || test_crop(as, "16,//(min(0w,1d),4),2,3", "32,-1,-2,-3", "0,1,2,3") + || test_crop(bs, "16,//(min(0w,1h),4),2", "32,-1,-(0w,2)", "0,1,2") + || test_crop(cs, "16,//(min(0w,1w),4)", "32,-2", "0,1") + || test_crop(bs, "1,//(1w,7)", "+(1c,1),-(0w,2)", "0,2") + || test_crop(as, "-(1w,4)", "neg(1h,3)", "0") + || test_crop(bs, "-(1w,20)", "-2", "0") + || test_crop(bs, "//(1h,15)", "neg(//(1w,7))", "2") + || test_crop(bs, "//(100,0h),round(fmod(100,0c))", "-233,min(1c,0c)", "1,0"); +} + +int main() +{ + SRAND(776757); + + return 0 + || test_crop_0() + || test_crop_1(); +} diff --git a/tests/test_crop_oom.cpp b/tests/test_crop_oom.cpp new file mode 100644 index 000000000..0f22b933d --- /dev/null +++ b/tests/test_crop_oom.cpp @@ -0,0 +1,194 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "testutil.h" + +static int test_crop_oom(const ncnn::Mat& a, int woffset, int hoffset, int doffset, int coffset, int outw, int outh, int outd, int outc, int woffset2, int hoffset2, int doffset2, int coffset2) +{ + ncnn::ParamDict pd; + pd.set(0, woffset); + pd.set(1, hoffset); + pd.set(13, doffset); + pd.set(2, coffset); + pd.set(3, outw); + pd.set(4, outh); + pd.set(14, outd); + pd.set(5, outc); + pd.set(6, woffset2); + pd.set(7, hoffset2); + pd.set(15, doffset2); + pd.set(8, coffset2); + + std::vector weights(0); + + int ret = test_layer_oom("Crop", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_crop_oom failed a.dims=%d a=(%d %d %d %d) woffset=%d hoffset=%d doffset=%d coffset=%d outw=%d outh=%d outd=%d outc=%d woffset2=%d hoffset2=%d doffset2=%d coffset2=%d\n", a.dims, a.w, a.h, a.d, a.c, woffset, hoffset, doffset, coffset, outw, outh, outd, outc, woffset2, hoffset2, doffset2, coffset2); + } + + return ret; +} + +static int test_crop_oom(const ncnn::Mat& a, const char* starts_expr, const char* ends_expr, const char* axes_expr) +{ + ncnn::ParamDict pd; + pd.set(19, std::string(starts_expr)); + pd.set(20, std::string(ends_expr)); + pd.set(21, std::string(axes_expr)); + + std::vector weights(0); + + int ret = test_layer_oom("Crop", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_crop_oom failed a.dims=%d a=(%d %d %d %d) starts_expr=%s ends_expr=%s axes_expr=%s\n", a.dims, a.w, a.h, a.d, a.c, starts_expr, ends_expr, axes_expr); + } + + return ret; +} + +static int test_crop_oom(const std::vector& as, const char* starts_expr, const char* ends_expr, const char* axes_expr) +{ + ncnn::ParamDict pd; + pd.set(19, std::string(starts_expr)); + pd.set(20, std::string(ends_expr)); + pd.set(21, std::string(axes_expr)); + + std::vector weights(0); + + int ret = test_layer_oom("Crop", pd, weights, as, 1); + if (ret != 0) + { + fprintf(stderr, "test_crop_oom failed a.dims=%d a=(%d %d %d %d) starts_expr=%s ends_expr=%s axes_expr=%s\n", as[0].dims, as[0].w, as[0].h, as[0].d, as[0].c, starts_expr, ends_expr, axes_expr); + } + + return ret; +} + +static int test_crop_0() +{ + ncnn::Mat a = RandomMat(13, 12, 25, 48); + ncnn::Mat b = RandomMat(13, 12, 48); + ncnn::Mat c = RandomMat(13, 48); + ncnn::Mat d = RandomMat(128); + + return 0 + || test_crop_oom(a, 1, 1, 1, 1, -233, -233, -233, -233, 1, 1, 1, 1) + || test_crop_oom(b, 1, 1, 0, 1, -233, -233, 0, -233, 1, 1, 0, 1) + || test_crop_oom(c, 1, 1, 0, 0, -233, -233, 0, 0, 1, 1, 0, 0) + || test_crop_oom(d, 1, 0, 0, 0, -233, 0, 0, 0, 1, 0, 0, 0) + || test_crop_oom(a, 2, 2, 2, 2, 6, 6, 6, 16, 0, 0, 0, 0) + || test_crop_oom(b, 2, 2, 0, 2, 6, 6, 0, 16, 0, 0, 0, 0) + || test_crop_oom(c, 2, 2, 0, 0, 6, 16, 0, 0, 0, 0, 0, 0) + || test_crop_oom(d, 2, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0) + || test_crop_oom(a, 3, 3, 3, 16, 3, 4, 5, 16, 0, 0, 0, 0) + || test_crop_oom(b, 3, 3, 0, 16, 3, 4, 0, 16, 0, 0, 0, 0) + || test_crop_oom(c, 3, 16, 0, 0, 3, 16, 0, 0, 0, 0, 0, 0) + || test_crop_oom(d, 16, 0, 0, 0, 32, 0, 0, 0, 0, 0, 0, 0); +} + +static int test_crop_1() +{ + ncnn::Mat a = RandomMat(13, 12, 25, 47); + ncnn::Mat b = RandomMat(13, 12, 47); + ncnn::Mat c = RandomMat(13, 47); + ncnn::Mat d = RandomMat(129); + + return 0 + || test_crop_oom(a, 1, 1, 1, 1, -233, -233, -233, -233, 1, 1, 1, 1) + || test_crop_oom(b, 1, 1, 0, 1, -233, -233, 0, -233, 1, 1, 0, 1) + || test_crop_oom(c, 1, 1, 0, 0, -233, -233, 0, 0, 1, 1, 0, 0) + || test_crop_oom(d, 1, 0, 0, 0, -233, 0, 0, 0, 1, 0, 0, 0) + || test_crop_oom(a, 2, 2, 2, 2, 6, 6, 6, 16, 0, 0, 0, 0) + || test_crop_oom(b, 2, 2, 0, 2, 6, 6, 0, 16, 0, 0, 0, 0) + || test_crop_oom(c, 2, 2, 0, 0, 6, 16, 0, 0, 0, 0, 0, 0) + || test_crop_oom(d, 2, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0) + || test_crop_oom(a, 3, 3, 3, 16, 6, 6, 6, 16, 0, 0, 0, 0) + || test_crop_oom(b, 3, 3, 0, 16, 6, 6, 0, 16, 0, 0, 0, 0) + || test_crop_oom(c, 3, 16, 0, 0, 6, 16, 0, 0, 0, 0, 0, 0) + || test_crop_oom(d, 16, 0, 0, 0, 32, 0, 0, 0, 0, 0, 0, 0); +} + +static int test_crop_2() +{ + ncnn::Mat a = RandomMat(13, 12, 25, 48); + ncnn::Mat b = RandomMat(13, 12, 48); + ncnn::Mat c = RandomMat(13, 48); + ncnn::Mat d = RandomMat(128); + + return 0 + || test_crop_oom(a, "2", "-2", "0") + || test_crop_oom(b, "2", "-2", "0") + || test_crop_oom(c, "2", "-2", "0") + || test_crop_oom(d, "2", "-2", "0") + || test_crop_oom(a, "16", "32", "-4") + || test_crop_oom(b, "16", "32", "-3") + || test_crop_oom(c, "16", "32", "-2") + || test_crop_oom(d, "16", "32", "-1") + || test_crop_oom(a, "16,//(0d,4),2,1", "32,-1,-2,-3", "0,1,2,3") + || test_crop_oom(b, "16,//(0h,4),2", "32,-1,-(0w,2)", "0,1,2") + || test_crop_oom(c, "16,//(0w,4)", "32,-2", "0,1") + || test_crop_oom(a, "10", "11", "1") + || test_crop_oom(b, "1,1", "-(0c,15),-(0w,5)", "0,2") + || test_crop_oom(a, "-(0w,3),0h//2,floor(*(0c,0.3))", "-1,0h,ceil(*(0c,0.9))", "3,2,0") + || test_crop_oom(b, "-(0w,3),0h//2,floor(*(0c,0.3))", "-1,0h,ceil(*(0c,0.9))", "2,1,0") + || test_crop_oom(c, "-(0w,3),floor(*(0h,0.3))", "-1,ceil(*(0h,0.9))", "1,0") + || test_crop_oom(d, "floor(*(0w,0.3))", "ceil(*(0w,0.9))", "0"); +} + +static int test_crop_3() +{ + std::vector as(2); + as[0] = RandomMat(14, 15, 13, 48); + as[1] = RandomMat(8, 5, 3, 4); + + std::vector bs(2); + bs[0] = RandomMat(14, 15, 48); + bs[1] = RandomMat(28, 45, 16); + + std::vector cs(2); + cs[0] = RandomMat(24, 48); + cs[1] = RandomMat(28, 6); + + std::vector ds(3); + ds[0] = RandomMat(128); + ds[1] = RandomMat(16); + ds[2] = RandomMat(64); + + return 0 + || test_crop_oom(as, "*(1c,4)", "*(1c,8)", "-4") + || test_crop_oom(bs, "1c", "-(0c,1c)", "-3") + || test_crop_oom(cs, "+(1h,10)", "-(1h,22)", "-2") + || test_crop_oom(ds, "1w", "2w", "-1") + || test_crop_oom(as, "16,//(min(0w,1d),4),2,3", "32,-1,-2,-3", "0,1,2,3") + || test_crop_oom(bs, "16,//(min(0w,1h),4),2", "32,-1,-(0w,2)", "0,1,2") + || test_crop_oom(cs, "16,//(min(0w,1w),4)", "32,-2", "0,1") + || test_crop_oom(bs, "1,//(1w,7)", "+(1c,1),-(0w,2)", "0,2") + || test_crop_oom(as, "-(1w,4)", "neg(1h,3)", "0") + || test_crop_oom(bs, "-(1w,20)", "-2", "0") + || test_crop_oom(bs, "//(1h,15)", "neg(//(1w,7))", "2") + || test_crop_oom(bs, "//(100,0h),round(fmod(100,0c))", "-233,min(1c,0c)", "1,0"); +} + +int main() +{ + SRAND(776757); + + return 0 + || test_crop_0() + || test_crop_1() + || test_crop_2() + || test_crop_3(); +} diff --git a/tests/test_reshape_1.cpp b/tests/test_reshape_1.cpp index e5c9ded09..f0af96dee 100644 --- a/tests/test_reshape_1.cpp +++ b/tests/test_reshape_1.cpp @@ -30,7 +30,7 @@ static int test_reshape(const ncnn::Mat& a, const char* shape_expr) return ret; } -static int test_reshape_refs(const std::vector& as, const char* shape_expr) +static int test_reshape(const std::vector& as, const char* shape_expr) { ncnn::ParamDict pd; pd.set(6, std::string(shape_expr)); @@ -40,23 +40,7 @@ static int test_reshape_refs(const std::vector& as, const char* shape int ret = test_layer("Reshape", pd, weights, as, 1); if (ret != 0) { - fprintf(stderr, "test_reshape_refs failed a.dims=%d a=(%d %d %d %d) shape_expr=%s\n", as[0].dims, as[0].w, as[0].h, as[0].d, as[0].c, shape_expr); - } - - return ret; -} - -static int test_reshape_refs(const ncnn::Mat& a, const char* shape_expr) -{ - ncnn::ParamDict pd; - pd.set(6, std::string(shape_expr)); - - std::vector weights(0); - - int ret = test_layer("Reshape", pd, weights, a); - if (ret != 0) - { - fprintf(stderr, "test_reshape_refs failed a.dims=%d a=(%d %d %d %d) shape_expr=%s\n", a.dims, a.w, a.h, a.d, a.c, shape_expr); + fprintf(stderr, "test_reshape failed a.dims=%d a=(%d %d %d %d) shape_expr=%s\n", as[0].dims, as[0].w, as[0].h, as[0].d, as[0].c, shape_expr); } return ret; @@ -82,8 +66,8 @@ static int test_reshape_1() as[1] = RandomMat(28, 45, 48); return 0 - || test_reshape_refs(as, "*(1w,0.5),/(1h,3),-(1c,32)") - || test_reshape_refs(as, "*(0w,0h),-(-(1c,0c),16)"); + || test_reshape(as, "*(1w,0.5),/(1h,3),-(1c,32)") + || test_reshape(as, "*(0w,0h),-(-(1c,0c),16)"); } static int test_reshape_2() @@ -91,9 +75,9 @@ static int test_reshape_2() ncnn::Mat a = RandomMat(14, 15, 16); return 0 - || test_reshape_refs(a, "*(0w,0.5),/(0h,3),-1") - || test_reshape_refs(a, "-1") - || test_reshape_refs(a, "*(0w,0h),0c"); + || test_reshape(a, "*(0w,0.5),/(0h,3),-1") + || test_reshape(a, "-1") + || test_reshape(a, "*(0w,0h),0c"); } int main() diff --git a/tools/modelwriter.h b/tools/modelwriter.h index 4b5e2b6cf..716179c80 100644 --- a/tools/modelwriter.h +++ b/tools/modelwriter.h @@ -1235,6 +1235,15 @@ int ModelWriter::save(const char* parampath, const char* binpath) { if (!op->axes.empty()) fprintf_param_int_array(11, op->axes, pp); } + { + if (op->starts_expr != op_default->starts_expr) fprintf(pp, " 19=\"%s\"", op->starts_expr.c_str()); + } + { + if (op->ends_expr != op_default->ends_expr) fprintf(pp, " 20=\"%s\"", op->ends_expr.c_str()); + } + { + if (op->axes_expr != op_default->axes_expr) fprintf(pp, " 21=\"%s\"", op->axes_expr.c_str()); + } } else if (layer->type == "CumulativeSum") { @@ -2345,7 +2354,7 @@ int ModelWriter::save(const char* parampath, const char* binpath) fprintf_param_value(" 11=%d", d) fprintf_param_value(" 2=%d", c) { - if (op->shape_expr != op_default->shape_expr) fprintf(pp, " 6=%s", op->shape_expr.c_str()); + if (op->shape_expr != op_default->shape_expr) fprintf(pp, " 6=\"%s\"", op->shape_expr.c_str()); } } else if (layer->type == "RMSNorm") diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 33e892e5d..979466169 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -403,6 +403,7 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/convert_half_to_float.cpp pass_ncnn/convert_input.cpp pass_ncnn/convert_reshape_expression.cpp + pass_ncnn/convert_slice_expression.cpp pass_ncnn/convert_torch_cat.cpp pass_ncnn/convert_torch_chunk.cpp pass_ncnn/convert_torch_einsum.cpp diff --git a/tools/pnnx/src/pass_level2/Tensor_size.cpp b/tools/pnnx/src/pass_level2/Tensor_size.cpp index 70127abfc..b4db81aa9 100644 --- a/tools/pnnx/src/pass_level2/Tensor_size.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_size.cpp @@ -54,4 +54,26 @@ pnnx.Output output 1 0 out REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_size, 10) REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_size_dynamic, 11) +class Tensor_size_onnx : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +4 3 +pnnx.Input input 0 1 input +aten::size op_0 1 1 input shape +Gather op_1 1 1 shape out axis=0 indices=%dim +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "Tensor.size"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_size_onnx, 10) + } // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn.cpp b/tools/pnnx/src/pass_ncnn.cpp index 06c83ea3a..e0afe6d41 100644 --- a/tools/pnnx/src/pass_ncnn.cpp +++ b/tools/pnnx/src/pass_ncnn.cpp @@ -20,6 +20,7 @@ #include "pass_ncnn/convert_half_to_float.h" #include "pass_ncnn/convert_input.h" #include "pass_ncnn/convert_reshape_expression.h" +#include "pass_ncnn/convert_slice_expression.h" #include "pass_ncnn/convert_torch_cat.h" #include "pass_ncnn/convert_torch_chunk.h" #include "pass_ncnn/convert_torch_einsum.h" @@ -110,6 +111,7 @@ void pass_ncnn(Graph& g, const std::vector& module_operators) ncnn::convert_torch_einsum(g); ncnn::convert_reshape_expression(g); + ncnn::convert_slice_expression(g); ncnn::convert_Tensor_select(g); ncnn::convert_Tensor_slice(g); diff --git a/tools/pnnx/src/pass_ncnn/convert_Tensor_select.cpp b/tools/pnnx/src/pass_ncnn/convert_Tensor_select.cpp index e62aa580a..4d65f1bc7 100644 --- a/tools/pnnx/src/pass_ncnn/convert_Tensor_select.cpp +++ b/tools/pnnx/src/pass_ncnn/convert_Tensor_select.cpp @@ -54,8 +54,27 @@ void convert_Tensor_select(Graph& graph) if (axis > batch_index) axis -= 1; - int dim = op->params.at("dim").i; - int index = op->params.at("index").i; + int dim; + int index; + if (op->has_param("dim")) + { + dim = op->params.at("dim").i; + } + else + { + fprintf(stderr, "select with dynamic dim is not supported\n"); + continue; + } + + if (op->has_param("index")) + { + index = op->params.at("index").i; + } + else + { + fprintf(stderr, "select with dynamic index is not supported\n"); + continue; + } op->params["9"] = std::vector {index}; op->params["10"] = std::vector {index + 1}; diff --git a/tools/pnnx/src/pass_ncnn/convert_slice_expression.cpp b/tools/pnnx/src/pass_ncnn/convert_slice_expression.cpp new file mode 100644 index 000000000..1da8d2471 --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/convert_slice_expression.cpp @@ -0,0 +1,1412 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "convert_slice_expression.h" + +#include +#include + +namespace pnnx { + +namespace ncnn { + +static bool token_is_argument(const std::string& t) +{ + if (t[0] != '@' || t.size() < 2) + return false; + + for (size_t i = 1; i < t.size(); i++) + { + if (t[i] < '0' || t[i] > '9') + return false; + } + + return true; +} + +static bool token_is_ncnn_argument(const std::string& t) +{ + char tt = t[t.size() - 1]; + if ((tt != 'w' && tt != 'h' && tt != 'd' && tt != 'c') || t.size() < 2) + return false; + + for (size_t i = 0; i + 1 < t.size(); i++) + { + if (t[i] < '0' || t[i] > '9') + return false; + } + + return true; +} + +static bool token_is_complex(const std::string& t) +{ + // 2.000000e+00+3.000000e+00j + if (t[t.size() - 1] != 'j') + return false; + + return true; +} + +static bool token_is_literal(const std::string& t) +{ + if (token_is_ncnn_argument(t)) + return false; + + if (token_is_complex(t)) + return true; + + std::istringstream iss(t); + float f; + iss >> std::noskipws >> f; + return iss.eof() && !iss.fail(); +} + +// static void print_tokens(const std::vector& tokens) +// { +// std::string r; +// for (auto x : tokens) +// { +// r += x + " "; +// } +// fprintf(stderr, "tokens = %s\n", r.c_str()); +// } + +static std::vector split_into_tokens(const std::string& expr) +{ + std::vector tokens; + + std::string t; + for (size_t i = 0; i < expr.size(); i++) + { + char ch = expr[i]; + + if (ch == '[') // list + { + t += ch; + tokens.push_back(t); + t.clear(); + } + else if (ch == '(' || ch == ')' || ch == ',' || ch == ']') + { + if (!t.empty()) + { + tokens.push_back(t); + t.clear(); + } + } + else + { + t += ch; + } + } + + if (!t.empty()) + { + tokens.push_back(t); + } + + // filter unknown tokens + for (std::string& t : tokens) + { + if (t == "add") t = "+"; + if (t == "sub") t = "-"; + if (t == "mul") t = "*"; + if (t == "div") t = "/"; + if (t == "floor_divide") t = "//"; + if (t == "maximum") t = "max"; + if (t == "minimum") t = "min"; + if (t == "int") t = "trunc"; + + if (t == "torch.bool" || t == "torch.float" || t == "torch.long") + fprintf(stderr, "shape expression got unsupported op %s\n", t.c_str()); + } + + return tokens; +} + +static std::string transform_nchw_annotation_and_drop_batch_index(const std::vector& tokens, const std::vector& ordered_references, int output_batch_index) +{ + // change nchw annotation to w,h,c / w,h,d,c with batch index dropped + + struct typed_value + { + int type; // 0=i 1=f + union + { + int i; + float f; + }; + + typed_value() + : type(0), i(0) + { + } + typed_value(int _i) + : type(0), i(_i) + { + } + typed_value(float _f) + : type(1), f(_f) + { + } + + int to_int() + { + if (type == 0) + return i; + + // trunc by default + return (int)f; + } + }; + + // scan and stack + std::stack exprstack; + for (int i = (int)tokens.size() - 1; i >= 0; i--) + { + const std::string& t = tokens[i]; + + if (t == "size") + { + std::string a = exprstack.top(); + exprstack.pop(); + + // fprintf(stderr, "size %s\n", a.c_str()); + + if (exprstack.empty()) + { + std::string r = std::string("size(") + a + ")"; + exprstack.push(r); + } + else + { + std::string b = exprstack.top(); + exprstack.pop(); + + // fprintf(stderr, "size %s %s\n", a.c_str(), b.c_str()); + + if (token_is_argument(a) && token_is_literal(b)) + { + int input_index = std::stoi(a.substr(1)); + if (ordered_references[input_index]->shape.empty()) + { + std::string r = std::string("size(") + a + "," + b + ")"; + exprstack.push(r); + } + else + { + if (input_index > 9) + { + // ncnn can only handle at most 10 reference blobs + fprintf(stderr, "expression with large reference id %d is not supported yet\n", input_index); + } + + int bi = std::stoi(b); + + const int a_batch_index = ordered_references[input_index]->params["__batch_index"].i; + + if (bi == a_batch_index) + { + fprintf(stderr, "slice expression refer to batch axis %d is not supported\n", a_batch_index); + std::string r = std::string("size(") + a + "," + b + ")"; + exprstack.push(r); + } + else + { + int a_rank = (int)ordered_references[input_index]->shape.size(); + + if (bi < 0) + bi = a_rank + bi; + + if (bi > a_batch_index) + { + a_rank -= 1; + bi -= 1; + } + + if (a_rank == 1 && bi == 0) + { + exprstack.push(std::to_string(input_index) + "w"); + } + else if (a_rank == 2 && bi == 0) + { + exprstack.push(std::to_string(input_index) + "h"); + } + else if (a_rank == 2 && bi == 1) + { + exprstack.push(std::to_string(input_index) + "w"); + } + else if (a_rank == 3 && bi == 0) + { + exprstack.push(std::to_string(input_index) + "c"); + } + else if (a_rank == 3 && bi == 1) + { + exprstack.push(std::to_string(input_index) + "h"); + } + else if (a_rank == 3 && bi == 2) + { + exprstack.push(std::to_string(input_index) + "w"); + } + else if (a_rank == 4 && bi == 0) + { + exprstack.push(std::to_string(input_index) + "c"); + } + else if (a_rank == 4 && bi == 1) + { + exprstack.push(std::to_string(input_index) + "d"); + } + else if (a_rank == 4 && bi == 2) + { + exprstack.push(std::to_string(input_index) + "h"); + } + else if (a_rank == 4 && bi == 3) + { + exprstack.push(std::to_string(input_index) + "w"); + } + else + { + fprintf(stderr, "slice expression refer to %d-rank dim %d is not supported\n", a_rank, bi); + std::string r = std::string("size(") + a + "," + b + ")"; + exprstack.push(r); + } + } + } + } + else + { + std::string r = std::string("size(") + a + "," + b + ")"; + exprstack.push(r); + } + } + } + else if (t == "ceil" + || t == "floor" + || t == "round" + || t == "trunc") + { + std::string a = exprstack.top(); + exprstack.pop(); + + std::string r = t + "(" + a + ")"; + exprstack.push(r); + } + else if (t == "abs" + || t == "acos" + || t == "acosh" + || t == "asin" + || t == "asinh" + || t == "atan" + || t == "atanh" + || t == "cos" + || t == "cosh" + || t == "erf" + || t == "exp" + || t == "log" + || t == "log10" + || t == "neg" + || t == "reciprocal" + || t == "rsqrt" + || t == "sign" + || t == "sin" + || t == "sinh" + || t == "sqrt" + || t == "square" + || t == "tan" + || t == "tanh") + { + std::string a = exprstack.top(); + exprstack.pop(); + + std::string r = t + "(" + a + ")"; + exprstack.push(r); + } + else if (t == "+" + || t == "-" + || t == "*" + || t == "/" + || t == "//" + || t == "atan2" + || t == "max" + || t == "min" + || t == "fmod" + || t == "pow" + || t == "remainder" + || t == "logaddexp") + { + std::string a = exprstack.top(); + exprstack.pop(); + std::string b = exprstack.top(); + exprstack.pop(); + + std::string r = t + "(" + a + "," + b + ")"; + exprstack.push(r); + } + else if (t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift") + { + std::string a = exprstack.top(); + exprstack.pop(); + std::string b = exprstack.top(); + exprstack.pop(); + + std::string r = t + "(" + a + "," + b + ")"; + exprstack.push(r); + } + else if (t == "[") // list + { + std::vector elements; + while (!exprstack.empty()) + { + std::string a = exprstack.top(); + exprstack.pop(); + + elements.push_back(a); + } + + // drop output batch index + if (output_batch_index != 233) + { + for (int j = output_batch_index; j + 1 < (int)elements.size(); j++) + { + elements[j] = elements[j + 1]; + } + elements.resize(elements.size() - 1); + } + + // reverse order + std::string r; + for (int j = (int)elements.size() - 1; j >= 0; j--) + { + r += elements[j]; + if (j != 0) + r += ","; + } + + exprstack.push(r); + } + else if (t[0] == '@') + { + exprstack.push(t); + } + else + { + // literal + exprstack.push(t); + } + } + + std::string r = exprstack.top(); + exprstack.pop(); + while (!exprstack.empty()) + { + r += std::string(",") + exprstack.top(); + exprstack.pop(); + } + + return r; +} + +static void drop_expression_op(Graph& graph, const Operator* op_this, Operator* op_expr) +{ + if (!op_expr) + return; + + Operand* expr_out = op_expr->outputs[0]; + expr_out->remove_consumer(op_this); + + if (expr_out->consumers.empty()) + { + for (auto& x : op_expr->inputs) + { + x->remove_consumer(op_expr); + } + + graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), expr_out)); + delete expr_out; + + graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op_expr)); + delete op_expr; + } +} + +void convert_slice_expression_single_axis_ranged(Graph& graph) +{ + int op_index = 0; + + // single-axis ranged slice + // pnnx.Expression + // pnnx.Expression + // Tensor.slice + + while (1) + { + bool matched = false; + + for (Operator* op : graph.ops) + { + if (op->type != "Tensor.slice") + continue; + + if (op->inputs.size() == 1) + continue; + + if (!op->has_param("dim")) + continue; + + const int dim = op->params.at("dim").i; + + int start = 0; + int end = 0; + int step = 0; + int select = 0; + Operator* op_start = 0; + Operator* op_end = 0; + Operator* op_step = 0; + Operator* op_select = 0; + + if (op->has_param("start")) + { + start = op->params.at("start").i; + } + else if (op->has_input("start")) + { + op_start = op->named_input("start")->producer; + if (op_start->type != "pnnx.Expression") + continue; + } + else + { + continue; + } + + if (op->has_param("end")) + { + end = op->params.at("end").i; + } + else if (op->has_input("end")) + { + op_end = op->named_input("end")->producer; + if (op_end->type != "pnnx.Expression") + continue; + } + else + { + continue; + } + + if (op->has_param("step")) + { + step = op->params.at("step").i; + } + else if (op->has_input("step")) + { + op_step = op->named_input("step")->producer; + if (op_step->type != "pnnx.Expression") + continue; + } + else + { + continue; + } + + if (op->has_param("select")) + { + select = op->params.at("select").i; + } + else if (op->has_input("select")) + { + op_select = op->named_input("select")->producer; + if (op_select->type != "pnnx.Expression") + continue; + } + + fprintf(stderr, "----------------------------convert_slice_expression_single_axis_ranged\n"); + + matched = true; + + std::string start_expr = op_start ? op_start->params["expr"].s : std::to_string(start); + std::string end_expr = op_end ? op_end->params["expr"].s : std::to_string(end); + std::string step_expr = op_step ? op_step->params["expr"].s : std::to_string(step); + std::string select_expr = op_select ? op_select->params["expr"].s : std::to_string(select); + + bool has_select = !op_step && step == 0; + if (has_select) + { + // simulate select as slice + start_expr = select_expr; + end_expr = std::string("add(") + select_expr + ",1)"; + step_expr = "1"; + } + + // split into tokens + std::vector start_tokens = split_into_tokens(start_expr); + std::vector end_tokens = split_into_tokens(end_expr); + std::vector step_tokens = split_into_tokens(step_expr); + + // collect inputs and references + std::map references; + + // begin with input blob + int reference_index = 0; + { + references[op->inputs[0]] = reference_index++; + } + + for (size_t i = 0; i < start_tokens.size(); i++) + { + std::string& t = start_tokens[i]; + + if (t[0] != '@') + continue; + + int input_index = std::stoi(t.substr(1)); + Operand* r = op_start->inputs[input_index]; + + if (references.find(r) == references.end()) + { + references[r] = reference_index++; + } + + t = "@" + std::to_string(references[r]); + } + for (size_t i = 0; i < end_tokens.size(); i++) + { + std::string& t = end_tokens[i]; + + if (t[0] != '@') + continue; + + int input_index = std::stoi(t.substr(1)); + Operand* r = op_end->inputs[input_index]; + + if (references.find(r) == references.end()) + { + references[r] = reference_index++; + } + + t = "@" + std::to_string(references[r]); + } + for (size_t i = 0; i < step_tokens.size(); i++) + { + std::string& t = step_tokens[i]; + + if (t[0] != '@') + continue; + + int input_index = std::stoi(t.substr(1)); + Operand* r = op_step->inputs[input_index]; + + if (references.find(r) == references.end()) + { + references[r] = reference_index++; + } + + // reuse the same reference + t = "@" + std::to_string(references[r]); + } + + std::vector ordered_references(references.size()); + for (auto x : references) + { + ordered_references[x.second] = x.first; + } + + // change nchw annotation to w,h,c / w,h,d,c with batch index dropped + + const int batch_index = op->outputs[0]->params["__batch_index"].i; + + std::string starts_expr = transform_nchw_annotation_and_drop_batch_index(start_tokens, ordered_references, batch_index); + std::string ends_expr = transform_nchw_annotation_and_drop_batch_index(end_tokens, ordered_references, batch_index); + std::string steps_expr = transform_nchw_annotation_and_drop_batch_index(step_tokens, ordered_references, batch_index); + + if (steps_expr != std::to_string(1)) + { + fprintf(stderr, "slice with step expression %s is not supported\n", steps_expr.c_str()); + } + + op->type = "Crop"; + op->name = std::string("slice1_") + std::to_string(op_index++); + + op->params.clear(); + op->params["19"] = starts_expr; + op->params["20"] = ends_expr; + op->params["21"] = std::to_string(dim); + + // link references to reshape + { + op->inputs = ordered_references; + + for (size_t i = 1; i < op->inputs.size(); i++) + { + op->inputs[i]->consumers.push_back(op); + } + } + + // drop expression + drop_expression_op(graph, op, op_start); + drop_expression_op(graph, op, op_end); + drop_expression_op(graph, op, op_step); + drop_expression_op(graph, op, op_select); + + // reshape for output, squeezing the slice dim + if (has_select) + { + Operand* out = op->outputs[0]; + + Operator* reshape = graph.new_operator_after("Tensor.reshape", op->name + "_ncnnreshape", op); + + Operand* reshape_in = graph.new_operand(op->name + "_ncnnreshape_in"); + + reshape_in->params["__batch_index"] = batch_index; + + reshape->inputs.push_back(reshape_in); + reshape->outputs.push_back(out); + + op->outputs[0] = reshape_in; + + out->producer = reshape; + reshape_in->producer = op; + reshape_in->consumers.push_back(reshape); + + reshape->params["shape"] = out->shape; + } + + break; + } + + if (!matched) + break; + } +} + +void convert_slice_expression_single_axis_select(Graph& graph) +{ + int op_index = 0; + + // single-axis one slice + // pnnx.Expression + // Tensor.select + + while (1) + { + bool matched = false; + + for (Operator* op : graph.ops) + { + if (op->type != "Tensor.select") + continue; + + if (op->inputs.size() == 1) + continue; + + if (!op->has_param("dim")) + continue; + + const int dim = op->params.at("dim").i; + + int start = 0; + Operator* op_start = 0; + + if (op->has_param("index")) + { + start = op->params.at("index").i; + } + else if (op->has_input("index")) + { + op_start = op->named_input("index")->producer; + if (op_start->type != "pnnx.Expression") + continue; + } + else + { + continue; + } + + fprintf(stderr, "----------------------------convert_slice_expression_single_axis_select\n"); + + matched = true; + + std::string start_expr = op_start ? op_start->params["expr"].s : std::to_string(start); + + // split into tokens + std::vector start_tokens = split_into_tokens(start_expr); + + // collect inputs and references + std::map references; + + // begin with input blob + int reference_index = 0; + { + references[op->inputs[0]] = reference_index++; + } + + for (size_t i = 0; i < start_tokens.size(); i++) + { + std::string& t = start_tokens[i]; + + if (t[0] != '@') + continue; + + int input_index = std::stoi(t.substr(1)); + Operand* r = op_start->inputs[input_index]; + + if (references.find(r) == references.end()) + { + references[r] = reference_index++; + } + + t = "@" + std::to_string(references[r]); + } + + std::vector ordered_references(references.size()); + for (auto x : references) + { + ordered_references[x.second] = x.first; + } + + // change nchw annotation to w,h,c / w,h,d,c with batch index dropped + + const int batch_index = op->outputs[0]->params["__batch_index"].i; + + std::string starts_expr = transform_nchw_annotation_and_drop_batch_index(start_tokens, ordered_references, batch_index); + + op->type = "Crop"; + op->name = std::string("slice2_") + std::to_string(op_index++); + + op->params.clear(); + op->params["19"] = starts_expr; + op->params["20"] = std::string("+(") + starts_expr + ",1)"; + op->params["21"] = std::to_string(dim); + + // link references to reshape + { + op->inputs = ordered_references; + + for (size_t i = 1; i < op->inputs.size(); i++) + { + op->inputs[i]->consumers.push_back(op); + } + } + + // drop expression + drop_expression_op(graph, op, op_start); + + // squeezing the select dim + { + Operand* out = op->outputs[0]; + + Operator* squeeze = graph.new_operator_after("torch.squeeze", op->name + "_ncnnsqueeze", op); + + Operand* squeeze_in = graph.new_operand(op->name + "_ncnnsqueeze_in"); + + squeeze->inputs.push_back(squeeze_in); + squeeze->outputs.push_back(out); + + op->outputs[0] = squeeze_in; + + out->producer = squeeze; + squeeze_in->producer = op; + squeeze_in->consumers.push_back(squeeze); + + squeeze->params["dim"] = dim; + + squeeze_in->params["__batch_index"] = batch_index; + } + + break; + } + + if (!matched) + break; + } +} + +static std::vector split_into_raw_tokens(const std::string& expr) +{ + std::vector tokens; + + std::string t; + for (size_t i = 0; i < expr.size(); i++) + { + char ch = expr[i]; + + if (ch == '[') // list + { + t += ch; + tokens.push_back(t); + t.clear(); + } + else if (ch == '(' || ch == ')' || ch == ',' || ch == ']') + { + if (!t.empty()) + { + tokens.push_back(t); + t.clear(); + } + + std::string tt; + tt += ch; + tokens.push_back(tt); + } + else + { + t += ch; + } + } + + if (!t.empty()) + { + tokens.push_back(t); + } + + return tokens; +} + +static void make_slice_indexes_expression(Graph& graph) +{ + // pnnx.Expression pnnx_expr_24 2 1 0 1 13 expr=sub(floor_divide(size(@0,0),64),floor_divide(size(@1,1),128)) #0=(?)f32 #1=(?,?)f32 + // pnnx.Expression pnnx_expr_18 1 1 12 14 expr=sub(size(@0,3),3) #12=(1,15,?,?)f32 + // pnnx.Expression pnnx_expr_13 1 1 12 15 expr=floor_divide(neg(size(@0,2)),7) #12=(1,15,?,?)f32 + // pnnx.Expression pnnx_expr_8 1 1 12 16 expr=floor_divide(size(@0,2),3) #12=(1,15,?,?)f32 + // pnnx.SliceIndexes ncnnstarts 1 1 14 17 indexes=(0,@0,0) + // pnnx.SliceIndexes ncnnends 1 1 15 18 indexes=(0,@0,0) + // pnnx.SliceIndexes ncnnselects 2 1 13 16 19 indexes=(@0,2147483647,@1) + + while (1) + { + bool matched = false; + + for (Operator* op : graph.ops) + { + if (op->type != "pnnx.SliceIndexes") + continue; + + bool slice_index_expr = true; + for (size_t i = 0; i < op->inputs.size(); i++) + { + if (op->inputs[i]->producer->type != "pnnx.Expression") + { + slice_index_expr = false; + break; + } + } + if (!slice_index_expr) + continue; + + matched = true; + + const std::vector& indexes = op->params["indexes"].as; + + std::map references; + std::vector op_expr_si; + + int reference_index = 0; + + std::vector new_indexes; + + for (size_t i = 0; i < indexes.size(); i++) + { + std::string si_expr = indexes[i]; + Operator* op_si = 0; + if (si_expr[0] == '@') + { + int si = std::stoi(si_expr.substr(1)); + op_si = op->inputs[si]->producer; + si_expr = op_si->params.at("expr").s; + + op_expr_si.push_back(op_si); + } + + // split into tokens + std::vector si_tokens = split_into_raw_tokens(si_expr); + + // collect inputs and references + for (size_t j = 0; j < si_tokens.size(); j++) + { + std::string& t = si_tokens[j]; + + if (t[0] != '@') + continue; + + int input_index = std::stoi(t.substr(1)); + Operand* r = op_si->inputs[input_index]; + + if (references.find(r) == references.end()) + { + references[r] = reference_index++; + } + + t = "@" + std::to_string(references[r]); + } + + std::string expr; + for (const std::string& t : si_tokens) + { + expr += t; + } + + new_indexes.push_back(expr); + } + + std::vector ordered_references(references.size()); + for (auto x : references) + { + ordered_references[x.second] = x.first; + } + + op->params["indexes"] = new_indexes; + + // for (auto x : new_indexes) + // { + // fprintf(stderr, "%s ", x.c_str()); + // } + // fprintf(stderr, "\n"); + + // link references to slice indexes expression + { + op->inputs = ordered_references; + + for (size_t i = 1; i < op->inputs.size(); i++) + { + op->inputs[i]->consumers.push_back(op); + } + } + + // drop expression + for (auto op_si : op_expr_si) + { + drop_expression_op(graph, op, op_si); + } + + break; + } + + if (!matched) + break; + } +} + +void convert_slice_expression_multi_axis_ranged(Graph& graph) +{ + int op_index = 0; + + // multi-axis ranged slice + // pnnx.SliceIndexes + // pnnx.SliceIndexes + // pnnx.SliceIndexes + // Tensor.slice + + while (1) + { + bool matched = false; + + for (Operator* op : graph.ops) + { + if (op->type != "Tensor.slice") + continue; + + if (op->inputs.size() == 1) + continue; + + if (!op->has_param("dims")) + continue; + + const std::vector& dims = op->params.at("dims").ai; + + std::vector starts; + std::vector ends; + std::vector steps; + std::vector selects; + Operator* op_starts = 0; + Operator* op_ends = 0; + Operator* op_steps = 0; + Operator* op_selects = 0; + + if (op->has_param("starts")) + { + starts = op->params.at("starts").ai; + } + else if (op->has_input("starts")) + { + op_starts = op->named_input("starts")->producer; + if (op_starts->type != "pnnx.SliceIndexes") + continue; + } + else + { + continue; + } + + if (op->has_param("ends")) + { + ends = op->params.at("ends").ai; + } + else if (op->has_input("ends")) + { + op_ends = op->named_input("ends")->producer; + if (op_ends->type != "pnnx.SliceIndexes") + continue; + } + else + { + continue; + } + + if (op->has_param("steps")) + { + steps = op->params.at("steps").ai; + } + else if (op->has_input("steps")) + { + op_steps = op->named_input("steps")->producer; + if (op_steps->type != "pnnx.SliceIndexes") + continue; + } + else + { + continue; + } + + if (op->has_param("selects")) + { + selects = op->params.at("selects").ai; + } + else if (op->has_input("selects")) + { + op_selects = op->named_input("selects")->producer; + if (op_selects->type != "pnnx.SliceIndexes") + continue; + } + else + { + continue; + } + + fprintf(stderr, "----------------------------convert_slice_expression_multi_axis_ranged\n"); + + matched = true; + + std::vector starts_expr; + std::vector ends_expr; + std::vector steps_expr; + std::vector selects_expr; + if (op_starts) + { + starts_expr = op_starts->params["indexes"].as; + } + else + { + for (int i : starts) + { + starts_expr.push_back(std::to_string(i)); + } + } + if (op_ends) + { + ends_expr = op_ends->params["indexes"].as; + } + else + { + for (int i : ends) + { + ends_expr.push_back(std::to_string(i)); + } + } + if (op_steps) + { + steps_expr = op_steps->params["indexes"].as; + } + else + { + for (int i : steps) + { + steps_expr.push_back(std::to_string(i)); + } + } + if (op_selects) + { + selects_expr = op_selects->params["indexes"].as; + } + else + { + for (int i : selects) + { + selects_expr.push_back(std::to_string(i)); + } + } + + // collect inputs and references + std::map references; + + // begin with input blob + int reference_index = 0; + { + references[op->inputs[0]] = reference_index++; + } + + bool has_select = false; + + const size_t dims_count = dims.size(); + + for (size_t i = 0; i < dims_count; i++) + { + const std::string& start_expr = starts_expr[i]; + const std::string& end_expr = ends_expr[i]; + const std::string& step_expr = steps_expr[i]; + const std::string& select_expr = selects_expr[i]; + + // split into tokens + std::vector start_tokens = split_into_raw_tokens(start_expr); + std::vector end_tokens = split_into_raw_tokens(end_expr); + std::vector step_tokens = split_into_raw_tokens(step_expr); + std::vector select_tokens = split_into_raw_tokens(select_expr); + + bool is_select = true; + if (select_tokens.size() == 1 && select_tokens[0] == std::to_string(INT_MAX)) + { + is_select = false; + } + + if (is_select) + { + has_select = true; + + // simulate select as slice + for (size_t j = 0; j < select_tokens.size(); j++) + { + std::string& t = select_tokens[j]; + + if (t[0] != '@') + continue; + + int input_index = std::stoi(t.substr(1)); + Operand* r = op_selects->inputs[input_index]; + + if (references.find(r) == references.end()) + { + references[r] = reference_index++; + } + + t = "@" + std::to_string(references[r]); + } + + start_tokens = select_tokens; + end_tokens.clear(); + step_tokens.clear(); + end_tokens.push_back("add"); + end_tokens.push_back("("); + for (auto t : select_tokens) + { + end_tokens.push_back(t); + step_tokens.push_back("1"); + } + end_tokens.push_back(","); + end_tokens.push_back("1"); + end_tokens.push_back(")"); + } + else + { + for (size_t j = 0; j < start_tokens.size(); j++) + { + std::string& t = start_tokens[j]; + + if (t[0] != '@') + continue; + + int input_index = std::stoi(t.substr(1)); + Operand* r = op_starts->inputs[input_index]; + + if (references.find(r) == references.end()) + { + references[r] = reference_index++; + } + + t = "@" + std::to_string(references[r]); + } + for (size_t j = 0; j < end_tokens.size(); j++) + { + std::string& t = end_tokens[j]; + + if (t[0] != '@') + continue; + + int input_index = std::stoi(t.substr(1)); + Operand* r = op_ends->inputs[input_index]; + + if (references.find(r) == references.end()) + { + references[r] = reference_index++; + } + + t = "@" + std::to_string(references[r]); + } + for (size_t j = 0; j < step_tokens.size(); j++) + { + std::string& t = step_tokens[j]; + + if (t[0] != '@') + continue; + + int input_index = std::stoi(t.substr(1)); + Operand* r = op_steps->inputs[input_index]; + + if (references.find(r) == references.end()) + { + references[r] = reference_index++; + } + + // reuse the same reference + t = "@" + std::to_string(references[r]); + } + } + + std::string new_start_expr; + std::string new_end_expr; + std::string new_step_expr; + for (const std::string& t : start_tokens) + { + new_start_expr += t; + } + for (const std::string& t : end_tokens) + { + new_end_expr += t; + } + for (const std::string& t : step_tokens) + { + new_step_expr += t; + } + starts_expr[i] = new_start_expr; + ends_expr[i] = new_end_expr; + steps_expr[i] = new_step_expr; + } + + std::vector ordered_references(references.size()); + for (auto x : references) + { + ordered_references[x.second] = x.first; + } + + // change nchw annotation to w,h,c / w,h,d,c with batch index dropped + + const int batch_index = op->outputs[0]->params["__batch_index"].i; + + std::string new_starts_expr; + std::string new_ends_expr; + std::string new_steps_expr; + std::string new_dims_expr; + + for (size_t i = 0; i < dims_count; i++) + { + const std::string& start_expr = starts_expr[i]; + const std::string& end_expr = ends_expr[i]; + const std::string& step_expr = steps_expr[i]; + + // split into tokens + std::vector start_tokens = split_into_tokens(start_expr); + std::vector end_tokens = split_into_tokens(end_expr); + std::vector step_tokens = split_into_tokens(step_expr); + + std::string new_start_expr = transform_nchw_annotation_and_drop_batch_index(start_tokens, ordered_references, batch_index); + std::string new_end_expr = transform_nchw_annotation_and_drop_batch_index(end_tokens, ordered_references, batch_index); + std::string new_step_expr = transform_nchw_annotation_and_drop_batch_index(step_tokens, ordered_references, batch_index); + + if (new_step_expr != std::to_string(1)) + { + fprintf(stderr, "slice with step expression %s is not supported\n", new_step_expr.c_str()); + } + + new_starts_expr += new_start_expr; + new_ends_expr += new_end_expr; + new_steps_expr += new_step_expr; + new_dims_expr += std::to_string(dims[i]); + + if (i + 1 != dims_count) + { + new_starts_expr += ","; + new_ends_expr += ","; + new_steps_expr += ","; + new_dims_expr += ","; + } + } + + op->type = "Crop"; + op->name = std::string("slice3_") + std::to_string(op_index++); + + op->params.clear(); + op->params["19"] = new_starts_expr; + op->params["20"] = new_ends_expr; + op->params["21"] = new_dims_expr; + + // link references to reshape + { + op->inputs = ordered_references; + + for (size_t i = 1; i < op->inputs.size(); i++) + { + op->inputs[i]->consumers.push_back(op); + } + } + + // drop expression + drop_expression_op(graph, op, op_starts); + drop_expression_op(graph, op, op_ends); + drop_expression_op(graph, op, op_steps); + drop_expression_op(graph, op, op_selects); + + // reshape for output, squeezing the slice dim + if (has_select) + { + Operand* out = op->outputs[0]; + + Operator* reshape = graph.new_operator_after("Tensor.reshape", op->name + "_ncnnreshape", op); + + Operand* reshape_in = graph.new_operand(op->name + "_ncnnreshape_in"); + + reshape_in->params["__batch_index"] = batch_index; + + reshape->inputs.push_back(reshape_in); + reshape->outputs.push_back(out); + + op->outputs[0] = reshape_in; + + out->producer = reshape; + reshape_in->producer = op; + reshape_in->consumers.push_back(reshape); + + reshape->params["shape"] = out->shape; + } + + break; + } + + if (!matched) + break; + } +} + +void convert_slice_expression(Graph& graph) +{ + convert_slice_expression_single_axis_ranged(graph); + + convert_slice_expression_single_axis_select(graph); + + make_slice_indexes_expression(graph); + + convert_slice_expression_multi_axis_ranged(graph); +} + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/convert_slice_expression.h b/tools/pnnx/src/pass_ncnn/convert_slice_expression.h new file mode 100644 index 000000000..2c57db53a --- /dev/null +++ b/tools/pnnx/src/pass_ncnn/convert_slice_expression.h @@ -0,0 +1,25 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_ncnn.h" + +namespace pnnx { + +namespace ncnn { + +void convert_slice_expression(Graph& graph); + +} // namespace ncnn + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp b/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp index d4532422b..efca5bcb3 100644 --- a/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp +++ b/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp @@ -225,6 +225,21 @@ static void solve_batch_index_forward(Operand* operand) // give up reshape across batch index } } + else if (op->type == "Tensor.slice" || op->type == "Tensor.select") + { + Operand* r = op->outputs[0]; + if (r->params.find("__batch_index") == r->params.end()) + { + r->params["__batch_index"] = batch_index; + + solve_batch_index_forward(r); + solve_batch_index_backward(r); + } + } + else if (op->type == "pnnx.SliceIndexes") + { + // pass + } else { for (Operand* r : op->outputs) @@ -325,6 +340,21 @@ static void solve_batch_index_backward(Operand* operand) // give up reshape across batch index } } + else if (op->type == "Tensor.slice" || op->type == "Tensor.select") + { + Operand* r = op->inputs[0]; + if (r->params.find("__batch_index") == r->params.end()) + { + r->params["__batch_index"] = batch_index; + + solve_batch_index_backward(r); + solve_batch_index_forward(r); + } + } + else if (op->type == "pnnx.SliceIndexes") + { + // pass + } else { for (Operand* r : op->inputs) diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index 64d88a77a..12f207336 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -218,6 +218,7 @@ pnnx_ncnn_add_test(ncnn_fuse_binaryop_eltwise) pnnx_ncnn_add_test(ncnn_fuse_pad_conv) pnnx_ncnn_add_test(ncnn_numpy_binaryop_broadcast) pnnx_ncnn_add_test(ncnn_reshape_expr) +pnnx_ncnn_add_test(ncnn_slice_expr) if(TorchVision_FOUND) pnnx_ncnn_add_test(torchvision_DeformConv2d) diff --git a/tools/pnnx/tests/ncnn/test_ncnn_slice_expr.py b/tools/pnnx/tests/ncnn/test_ncnn_slice_expr.py new file mode 100644 index 000000000..47fc6b504 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_ncnn_slice_expr.py @@ -0,0 +1,112 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + self.conv0 = nn.Conv2d(in_channels=3, out_channels=15, kernel_size=1) + + def forward(self, x, y, z): + out0 = x[x.size(0)//128:x.size(0)-3] + out0 = out0.clone() * 2 + + out1 = y[...,y.size(0)//8] + out1 = out1.clone() * 3.3 + + z = self.conv0(z) + out2 = z[:,x.size(0)//64-y.size(1)//128,z.size(3)-3:-z.size(2)//7,z.size(2)//3] + out2 = out2.clone() * 1.5 + + out3 = z[...,x.size(0)//128:-z.size(2)//10,z.size(3)//8:-z.size(3)//8] + out3 = out3.clone() * -10 + + return out0, out1, out2, out3 + +def test(): + net = Model().half().float() + net.eval() + + torch.manual_seed(0) + x0 = torch.rand(128) + y0 = torch.rand(64, 16) + z0 = torch.rand(1, 3, 39, 16) + + x1 = torch.rand(256) + y1 = torch.rand(32, 128) + z1 = torch.rand(1, 3, 15, 33) + + a0 = net(x0, y0, z0) + a1 = net(x1, y1, z1) + + # export torchscript + if version.parse(torch.__version__) < version.parse('2.0'): + mod = torch.jit.trace(net, (x0, y0, z0)) + else: + mod = torch.jit.trace(net, (x0, y0, z0), _store_inputs=False) + mod.save("test_ncnn_slice_expr.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_ncnn_slice_expr.pt inputshape=[128],[64,16],[1,3,39,16] inputshape2=[256],[32,128],[1,3,15,33]") + + # ncnn inference + import numpy as np + import ncnn + b0 = [] + b1 = [] + with ncnn.Net() as net: + net.load_param("test_ncnn_slice_expr.ncnn.param") + net.load_model("test_ncnn_slice_expr.ncnn.bin") + + with net.create_extractor() as ex: + ex.input("in0", ncnn.Mat(x0.numpy()).clone()) + ex.input("in1", ncnn.Mat(y0.numpy()).clone()) + ex.input("in2", ncnn.Mat(z0.squeeze(0).numpy()).clone()) + + _, out0 = ex.extract("out0") + _, out1 = ex.extract("out1") + b0.append(torch.from_numpy(np.array(out0))) + b0.append(torch.from_numpy(np.array(out1)).unsqueeze(0)) + + with net.create_extractor() as ex: + ex.input("in0", ncnn.Mat(x1.numpy()).clone()) + ex.input("in1", ncnn.Mat(y1.numpy()).clone()) + ex.input("in2", ncnn.Mat(z1.squeeze(0).numpy()).clone()) + + _, out0 = ex.extract("out0") + _, out1 = ex.extract("out1") + b1.append(torch.from_numpy(np.array(out0))) + b1.append(torch.from_numpy(np.array(out1)).unsqueeze(0)) + + for aa, bb in zip(a0, b0): + if not torch.allclose(aa, bb, 1e-4, 1e-4): + return False + + for aa, bb in zip(a1, b1): + if not torch.allclose(aa, bb, 1e-4, 1e-4): + return False + + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)