| @@ -441,7 +441,7 @@ static std::string get_attr_s(const mlir::Attribute& attr) | |||
| return s; | |||
| } | |||
| static int get_attr_i(const mlir::Attribute& attr) | |||
| static int get_attr_b(const mlir::Attribute& attr) | |||
| { | |||
| int i; | |||
| @@ -451,12 +451,28 @@ static int get_attr_i(const mlir::Attribute& attr) | |||
| i = a.getValue() ? 1 : 0; | |||
| } | |||
| else if (attr.isa<mlir::IntegerAttr>()) | |||
| else | |||
| { | |||
| fprintf(stderr, "not BoolAttr\n"); | |||
| } | |||
| return i; | |||
| } | |||
| static int get_attr_i(const mlir::Attribute& attr) | |||
| { | |||
| int i; | |||
| if (attr.isa<mlir::IntegerAttr>()) | |||
| { | |||
| mlir::IntegerAttr a = attr.cast<mlir::IntegerAttr>(); | |||
| i = (int)a.getInt(); | |||
| } | |||
| else | |||
| { | |||
| fprintf(stderr, "not IntegerAttr\n"); | |||
| } | |||
| return i; | |||
| } | |||
| @@ -471,6 +487,10 @@ static float get_attr_f(const mlir::Attribute& attr) | |||
| f = (float)a.getValueAsDouble(); | |||
| } | |||
| else | |||
| { | |||
| fprintf(stderr, "not FloatAttr\n"); | |||
| } | |||
| return f; | |||
| } | |||
| @@ -504,6 +524,10 @@ static std::vector<int> get_attr_ai(const mlir::Attribute& attr) | |||
| v.push_back(ii.getSExtValue()); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| fprintf(stderr, "not ArrayAttr or DenseIntElementsAttr\n"); | |||
| } | |||
| return v; | |||
| } | |||
| @@ -537,6 +561,10 @@ static std::vector<float> get_attr_af(const mlir::Attribute& attr) | |||
| v.push_back(ff.convertToFloat()); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| fprintf(stderr, "not ArrayAttr or DenseFPElementsAttr\n"); | |||
| } | |||
| return v; | |||
| } | |||
| @@ -550,6 +578,15 @@ static std::string get_operation_attr_s(const mlir::Operation& _operation, const | |||
| return get_attr_s(attr); | |||
| } | |||
| static int get_operation_attr_b(const mlir::Operation& _operation, const char* key) | |||
| { | |||
| mlir::Operation& operation = const_cast<mlir::Operation&>(_operation); | |||
| mlir::Attribute attr = operation.getAttr(key); | |||
| return get_attr_b(attr); | |||
| } | |||
| static int get_operation_attr_i(const mlir::Operation& _operation, const char* key) | |||
| { | |||
| mlir::Operation& operation = const_cast<mlir::Operation&>(_operation); | |||
| @@ -818,7 +855,7 @@ int main(int argc, char** argv) | |||
| std::vector<int> v = get_attr_ai(R); | |||
| int keep_dims = get_operation_attr_i(operation, "keep_dims"); | |||
| int keep_dims = get_operation_attr_b(operation, "keep_dims"); | |||
| if (keep_dims == 0 && v.size() == 2 && v[0] == 1 && v[1] == 2) | |||
| { | |||
| @@ -962,11 +999,12 @@ int main(int argc, char** argv) | |||
| } | |||
| else if (op == "tf.ConcatV2") | |||
| { | |||
| std::string axis_name = get_mlir_value_uniq_id(operation.getOperand(num_input)); | |||
| std::string axis_name = get_mlir_value_uniq_id(operation.getOperand(operation.getNumOperands() - 1)); | |||
| const mlir::Attribute& A = weights[axis_name]; | |||
| int axis = get_attr_i(A); | |||
| int axis = get_attr_ai(A)[0]; | |||
| // axis nhc to nhw | |||
| // axis nhwc to nchw | |||
| int dims = operation.getOperand(0).getType().cast<mlir::RankedTensorType>().getShape().size(); | |||
| @@ -1391,7 +1429,7 @@ int main(int argc, char** argv) | |||
| std::vector<int> v = get_attr_ai(R); | |||
| int keep_dims = get_operation_attr_i(operation, "keep_dims"); | |||
| int keep_dims = get_operation_attr_b(operation, "keep_dims"); | |||
| if (keep_dims == 0 && v.size() == 2 && v[0] == 1 && v[1] == 2) | |||
| { | |||
| @@ -1473,8 +1511,8 @@ int main(int argc, char** argv) | |||
| std::vector<int> size = get_attr_ai(P); | |||
| int align_corners = get_operation_attr_i(operation, "align_corners"); | |||
| int half_pixel_centers = get_operation_attr_i(operation, "half_pixel_centers"); | |||
| int align_corners = get_operation_attr_b(operation, "align_corners"); | |||
| int half_pixel_centers = get_operation_attr_b(operation, "half_pixel_centers"); | |||
| if (!(align_corners == 0 && half_pixel_centers == 1)) | |||
| { | |||
| fprintf(stderr, "Unsupported ResizeNearestNeighbor align_corners %d half_pixel_centers %d !\n", align_corners, half_pixel_centers); | |||