From 02e880fc42a0eda049d3ce729cdf45227a1ee6e7 Mon Sep 17 00:00:00 2001 From: nihui Date: Sat, 11 Jul 2020 13:28:46 +0800 Subject: [PATCH] fix concatv2 axis --- tools/mlir/mlir2ncnn.cpp | 54 ++++++++++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/tools/mlir/mlir2ncnn.cpp b/tools/mlir/mlir2ncnn.cpp index cbd11cc3e..132e4bd6e 100644 --- a/tools/mlir/mlir2ncnn.cpp +++ b/tools/mlir/mlir2ncnn.cpp @@ -441,7 +441,7 @@ static std::string get_attr_s(const mlir::Attribute& attr) return s; } -static int get_attr_i(const mlir::Attribute& attr) +static int get_attr_b(const mlir::Attribute& attr) { int i; @@ -451,12 +451,28 @@ static int get_attr_i(const mlir::Attribute& attr) i = a.getValue() ? 1 : 0; } - else if (attr.isa()) + else + { + fprintf(stderr, "not BoolAttr\n"); + } + + return i; +} + +static int get_attr_i(const mlir::Attribute& attr) +{ + int i; + + if (attr.isa()) { mlir::IntegerAttr a = attr.cast(); i = (int)a.getInt(); } + else + { + fprintf(stderr, "not IntegerAttr\n"); + } return i; } @@ -471,6 +487,10 @@ static float get_attr_f(const mlir::Attribute& attr) f = (float)a.getValueAsDouble(); } + else + { + fprintf(stderr, "not FloatAttr\n"); + } return f; } @@ -504,6 +524,10 @@ static std::vector get_attr_ai(const mlir::Attribute& attr) v.push_back(ii.getSExtValue()); } } + else + { + fprintf(stderr, "not ArrayAttr or DenseIntElementsAttr\n"); + } return v; } @@ -537,6 +561,10 @@ static std::vector get_attr_af(const mlir::Attribute& attr) v.push_back(ff.convertToFloat()); } } + else + { + fprintf(stderr, "not ArrayAttr or DenseFPElementsAttr\n"); + } return v; } @@ -550,6 +578,15 @@ static std::string get_operation_attr_s(const mlir::Operation& _operation, const return get_attr_s(attr); } +static int get_operation_attr_b(const mlir::Operation& _operation, const char* key) +{ + mlir::Operation& operation = const_cast(_operation); + + mlir::Attribute attr = operation.getAttr(key); + + return get_attr_b(attr); +} + static int get_operation_attr_i(const mlir::Operation& _operation, const char* key) { mlir::Operation& operation = const_cast(_operation); @@ -818,7 +855,7 @@ int main(int argc, char** argv) std::vector v = get_attr_ai(R); - int keep_dims = get_operation_attr_i(operation, "keep_dims"); + int keep_dims = get_operation_attr_b(operation, "keep_dims"); if (keep_dims == 0 && v.size() == 2 && v[0] == 1 && v[1] == 2) { @@ -962,11 +999,12 @@ int main(int argc, char** argv) } else if (op == "tf.ConcatV2") { - std::string axis_name = get_mlir_value_uniq_id(operation.getOperand(num_input)); + std::string axis_name = get_mlir_value_uniq_id(operation.getOperand(operation.getNumOperands() - 1)); const mlir::Attribute& A = weights[axis_name]; - int axis = get_attr_i(A); + int axis = get_attr_ai(A)[0]; + // axis nhc to nhw // axis nhwc to nchw int dims = operation.getOperand(0).getType().cast().getShape().size(); @@ -1391,7 +1429,7 @@ int main(int argc, char** argv) std::vector v = get_attr_ai(R); - int keep_dims = get_operation_attr_i(operation, "keep_dims"); + int keep_dims = get_operation_attr_b(operation, "keep_dims"); if (keep_dims == 0 && v.size() == 2 && v[0] == 1 && v[1] == 2) { @@ -1473,8 +1511,8 @@ int main(int argc, char** argv) std::vector size = get_attr_ai(P); - int align_corners = get_operation_attr_i(operation, "align_corners"); - int half_pixel_centers = get_operation_attr_i(operation, "half_pixel_centers"); + int align_corners = get_operation_attr_b(operation, "align_corners"); + int half_pixel_centers = get_operation_attr_b(operation, "half_pixel_centers"); if (!(align_corners == 0 && half_pixel_centers == 1)) { fprintf(stderr, "Unsupported ResizeNearestNeighbor align_corners %d half_pixel_centers %d !\n", align_corners, half_pixel_centers);