Browse Source

fix concatv2 axis

tags/20200727
nihui 6 years ago
parent
commit
02e880fc42
1 changed files with 46 additions and 8 deletions
  1. +46
    -8
      tools/mlir/mlir2ncnn.cpp

+ 46
- 8
tools/mlir/mlir2ncnn.cpp View File

@@ -441,7 +441,7 @@ static std::string get_attr_s(const mlir::Attribute& attr)
return s;
}

static int get_attr_i(const mlir::Attribute& attr)
static int get_attr_b(const mlir::Attribute& attr)
{
int i;

@@ -451,12 +451,28 @@ static int get_attr_i(const mlir::Attribute& attr)

i = a.getValue() ? 1 : 0;
}
else if (attr.isa<mlir::IntegerAttr>())
else
{
fprintf(stderr, "not BoolAttr\n");
}

return i;
}

static int get_attr_i(const mlir::Attribute& attr)
{
int i;

if (attr.isa<mlir::IntegerAttr>())
{
mlir::IntegerAttr a = attr.cast<mlir::IntegerAttr>();

i = (int)a.getInt();
}
else
{
fprintf(stderr, "not IntegerAttr\n");
}

return i;
}
@@ -471,6 +487,10 @@ static float get_attr_f(const mlir::Attribute& attr)

f = (float)a.getValueAsDouble();
}
else
{
fprintf(stderr, "not FloatAttr\n");
}

return f;
}
@@ -504,6 +524,10 @@ static std::vector<int> get_attr_ai(const mlir::Attribute& attr)
v.push_back(ii.getSExtValue());
}
}
else
{
fprintf(stderr, "not ArrayAttr or DenseIntElementsAttr\n");
}

return v;
}
@@ -537,6 +561,10 @@ static std::vector<float> get_attr_af(const mlir::Attribute& attr)
v.push_back(ff.convertToFloat());
}
}
else
{
fprintf(stderr, "not ArrayAttr or DenseFPElementsAttr\n");
}

return v;
}
@@ -550,6 +578,15 @@ static std::string get_operation_attr_s(const mlir::Operation& _operation, const
return get_attr_s(attr);
}

static int get_operation_attr_b(const mlir::Operation& _operation, const char* key)
{
mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);

mlir::Attribute attr = operation.getAttr(key);

return get_attr_b(attr);
}

static int get_operation_attr_i(const mlir::Operation& _operation, const char* key)
{
mlir::Operation& operation = const_cast<mlir::Operation&>(_operation);
@@ -818,7 +855,7 @@ int main(int argc, char** argv)

std::vector<int> v = get_attr_ai(R);

int keep_dims = get_operation_attr_i(operation, "keep_dims");
int keep_dims = get_operation_attr_b(operation, "keep_dims");

if (keep_dims == 0 && v.size() == 2 && v[0] == 1 && v[1] == 2)
{
@@ -962,11 +999,12 @@ int main(int argc, char** argv)
}
else if (op == "tf.ConcatV2")
{
std::string axis_name = get_mlir_value_uniq_id(operation.getOperand(num_input));
std::string axis_name = get_mlir_value_uniq_id(operation.getOperand(operation.getNumOperands() - 1));
const mlir::Attribute& A = weights[axis_name];

int axis = get_attr_i(A);
int axis = get_attr_ai(A)[0];

// axis nhc to nhw
// axis nhwc to nchw
int dims = operation.getOperand(0).getType().cast<mlir::RankedTensorType>().getShape().size();

@@ -1391,7 +1429,7 @@ int main(int argc, char** argv)

std::vector<int> v = get_attr_ai(R);

int keep_dims = get_operation_attr_i(operation, "keep_dims");
int keep_dims = get_operation_attr_b(operation, "keep_dims");

if (keep_dims == 0 && v.size() == 2 && v[0] == 1 && v[1] == 2)
{
@@ -1473,8 +1511,8 @@ int main(int argc, char** argv)

std::vector<int> size = get_attr_ai(P);

int align_corners = get_operation_attr_i(operation, "align_corners");
int half_pixel_centers = get_operation_attr_i(operation, "half_pixel_centers");
int align_corners = get_operation_attr_b(operation, "align_corners");
int half_pixel_centers = get_operation_attr_b(operation, "half_pixel_centers");
if (!(align_corners == 0 && half_pixel_centers == 1))
{
fprintf(stderr, "Unsupported ResizeNearestNeighbor align_corners %d half_pixel_centers %d !\n", align_corners, half_pixel_centers);


Loading…
Cancel
Save