| @@ -831,3 +831,8 @@ if(MSVC OR WIN32) | |||||
| endif() | endif() | ||||
| endforeach() | endforeach() | ||||
| endif() | endif() | ||||
| if(MGE_WITH_JIT_MLIR) | |||||
| add_subdirectory(tools/mlir/mgb-opt) | |||||
| add_subdirectory(tools/mlir/mgb-file-check) | |||||
| endif() | |||||
| @@ -297,7 +297,7 @@ void JITFusionPass::Impl::process_opr(OperatorNodeBase* opr) { | |||||
| #if MGB_JIT_MLIR | #if MGB_JIT_MLIR | ||||
| //! FIXME mlir does't support broadcast currently. | //! FIXME mlir does't support broadcast currently. | ||||
| auto backend = MGB_GETENV("MGB_JIT_BACKEND"); | auto backend = MGB_GETENV("MGB_JIT_BACKEND"); | ||||
| if (!strcmp(backend, "MLIR")) { | |||||
| if (backend && !strcmp(backend, "MLIR")) { | |||||
| for (VarNode* var : opr->input()) { | for (VarNode* var : opr->input()) { | ||||
| if (!SymbolVar{var}.as_immutable_scalar().valid()) { | if (!SymbolVar{var}.as_immutable_scalar().valid()) { | ||||
| if (opr->node_prop().dep_map().at(var) & | if (opr->node_prop().dep_map().at(var) & | ||||
| @@ -44,6 +44,7 @@ | |||||
| using namespace mlir; | using namespace mlir; | ||||
| namespace { | |||||
| template <typename OpTy> | template <typename OpTy> | ||||
| static void createForAllDimensions(OpBuilder& builder, Location loc, | static void createForAllDimensions(OpBuilder& builder, Location loc, | ||||
| SmallVectorImpl<Value>& values) { | SmallVectorImpl<Value>& values) { | ||||
| @@ -80,7 +81,7 @@ static bool isSinkingBeneficiary(Operation* op) { | |||||
| return isa<ConstantOp, DimOp>(op); | return isa<ConstantOp, DimOp>(op); | ||||
| } | } | ||||
| LogicalResult mlir::sinkOperationsIntoLaunchOp(gpu::LaunchOp launchOp) { | |||||
| LogicalResult sink_operations_into_launch_op(gpu::LaunchOp launchOp) { | |||||
| Region& launchOpBody = launchOp.body(); | Region& launchOpBody = launchOp.body(); | ||||
| // Identify uses from values defined outside of the scope of the launch | // Identify uses from values defined outside of the scope of the launch | ||||
| @@ -232,7 +233,6 @@ static void convertToLaunchFuncOp(gpu::LaunchOp launchOp, | |||||
| launchOp.erase(); | launchOp.erase(); | ||||
| } | } | ||||
| namespace { | |||||
| /// Pass that moves the kernel of each LaunchOp into its separate nested module. | /// Pass that moves the kernel of each LaunchOp into its separate nested module. | ||||
| /// | /// | ||||
| /// This pass moves the kernel code of each LaunchOp into a function created | /// This pass moves the kernel code of each LaunchOp into a function created | ||||
| @@ -258,7 +258,7 @@ public: | |||||
| .str(); | .str(); | ||||
| // Pull in instructions that can be sunk | // Pull in instructions that can be sunk | ||||
| if (failed(sinkOperationsIntoLaunchOp(op))) | |||||
| if (failed(sink_operations_into_launch_op(op))) | |||||
| return WalkResult::interrupt(); | return WalkResult::interrupt(); | ||||
| gpu::GPUFuncOp outlinedFunc = | gpu::GPUFuncOp outlinedFunc = | ||||
| outlineKernelFuncImpl(op, kernelFnName, operands); | outlineKernelFuncImpl(op, kernelFnName, operands); | ||||
| @@ -327,7 +327,6 @@ private: | |||||
| return kernelModule; | return kernelModule; | ||||
| } | } | ||||
| }; | }; | ||||
| } // namespace | } // namespace | ||||
| std::unique_ptr<mlir::Pass> mgb::jit::create_gpu_kernel_outlining_pass() { | std::unique_ptr<mlir::Pass> mgb::jit::create_gpu_kernel_outlining_pass() { | ||||
| @@ -20,13 +20,12 @@ | |||||
| #include "./each_mode.h" | #include "./each_mode.h" | ||||
| #include <llvm/ADT/Sequence.h> | |||||
| #include <mlir/Dialect/Affine/IR/AffineOps.h> | #include <mlir/Dialect/Affine/IR/AffineOps.h> | ||||
| #include <mlir/Pass/Pass.h> | #include <mlir/Pass/Pass.h> | ||||
| #include <mlir/Transforms/DialectConversion.h> | #include <mlir/Transforms/DialectConversion.h> | ||||
| #include "mlir/IR/StandardTypes.h" | #include "mlir/IR/StandardTypes.h" | ||||
| #include <llvm/ADT/Sequence.h> | |||||
| using namespace mgb; | using namespace mgb; | ||||
| using namespace jit; | using namespace jit; | ||||
| @@ -188,6 +187,7 @@ struct ReturnOpLowering : public OpRewritePattern<jit::ReturnOp> { | |||||
| LogicalResult matchAndRewrite(jit::ReturnOp op, | LogicalResult matchAndRewrite(jit::ReturnOp op, | ||||
| PatternRewriter& rewriter) const final { | PatternRewriter& rewriter) const final { | ||||
| // We lower "mgb.return" directly to "std.return". | |||||
| rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op); | rewriter.replaceOpWithNewOp<mlir::ReturnOp>(op); | ||||
| return success(); | return success(); | ||||
| } | } | ||||
| @@ -212,6 +212,7 @@ public: | |||||
| void runOnFunction() override final { | void runOnFunction() override final { | ||||
| ConversionTarget target(getContext()); | ConversionTarget target(getContext()); | ||||
| target.addLegalDialect<AffineDialect, StandardOpsDialect>(); | target.addLegalDialect<AffineDialect, StandardOpsDialect>(); | ||||
| // target.addLegalDialect<AffineDialect>(); | |||||
| target.addIllegalDialect<MgbDialect>(); | target.addIllegalDialect<MgbDialect>(); | ||||
| OwningRewritePatternList patterns; | OwningRewritePatternList patterns; | ||||
| @@ -236,6 +237,16 @@ std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_affine_pass() { | |||||
| return std::make_unique<MgbToAffineLoweringPass>(); | return std::make_unique<MgbToAffineLoweringPass>(); | ||||
| } | } | ||||
| namespace mgb { | |||||
| namespace jit { | |||||
| void register_test_mgb_to_affine_lowering_pass() { | |||||
| PassRegistration<MgbToAffineLoweringPass>( | |||||
| "mgb-convert-to-affine", | |||||
| "Perform conversion from MGB Dialect to Affine Dialect ", | |||||
| [] { return std::make_unique<MgbToAffineLoweringPass>(); }); | |||||
| } | |||||
| } // namespace jit | |||||
| } // namespace mgb | |||||
| #endif // MGB_JIT && MGB_JIT_MLIR | #endif // MGB_JIT && MGB_JIT_MLIR | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -53,6 +53,16 @@ std::unique_ptr<mlir::Pass> mgb::jit::create_lower_to_llvm_pass() { | |||||
| return std::make_unique<AffineToLLVMLoweringPass>(); | return std::make_unique<AffineToLLVMLoweringPass>(); | ||||
| } | } | ||||
| namespace mgb { | |||||
| namespace jit { | |||||
| void register_test_affine_to_llvm_lowering_pass() { | |||||
| PassRegistration<AffineToLLVMLoweringPass>( | |||||
| "mgb-codegen-convert-affine-to-llvm", | |||||
| "Perform final conversion from Affine to LLVMIR ", | |||||
| [] { return std::make_unique<AffineToLLVMLoweringPass>(); }); | |||||
| } | |||||
| } // namespace jit | |||||
| } // namespace mgb | |||||
| #endif // MGB_JIT && MGB_JIT_MLIR | #endif // MGB_JIT && MGB_JIT_MLIR | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -177,6 +177,12 @@ def ReturnOp : GenericOp<"return", | |||||
| The operation takes an no tensor operand and produces no results. | The operation takes an no tensor operand and produces no results. | ||||
| }]; | }]; | ||||
| // The return operation takes an optional input operand to return. This | |||||
| // value must match the return type of the enclosing function. | |||||
| let arguments = (ins); | |||||
| // The return operation only emits the input in the format if it is present. | |||||
| let assemblyFormat = "attr-dict"; | |||||
| } | } | ||||
| def ConstantScalarOp: GenericOp<"sconst", [NoSideEffect]> { | def ConstantScalarOp: GenericOp<"sconst", [NoSideEffect]> { | ||||
| @@ -19,7 +19,7 @@ | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace jit { | namespace jit { | ||||
| inline const bool is_elemwise_float(const mlir::Type& dt) { | |||||
| inline bool is_elemwise_float(const mlir::Type& dt) { | |||||
| if (auto cast = dt.dyn_cast_or_null<mlir::MemRefType>()) { | if (auto cast = dt.dyn_cast_or_null<mlir::MemRefType>()) { | ||||
| if (cast.getElementType().getKind() == mlir::StandardTypes::F32) { | if (cast.getElementType().getKind() == mlir::StandardTypes::F32) { | ||||
| return true; | return true; | ||||
| @@ -0,0 +1,27 @@ | |||||
| configure_lit_site_cfg( | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/utils/lit.site.cfg.py.in | |||||
| ${CMAKE_CURRENT_BINARY_DIR}/utils/lit.site.cfg.py | |||||
| MAIN_CONFIG | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/utils/lit.cfg.py | |||||
| ) | |||||
| set(LLVM_EXTERNAL_LIT "${PROJECT_SOURCE_DIR}/third_party/llvm-project/llvm/utils/lit/lit.py" CACHE STRING "External lit") | |||||
| set(MLIR_MGB_TEST_DEPENDS | |||||
| mgb-file-check | |||||
| count not | |||||
| mgb-opt | |||||
| ) | |||||
| add_lit_testsuite(mgb-mlir-test-lit "Running the mgb regression tests" | |||||
| ${CMAKE_CURRENT_BINARY_DIR}/utils | |||||
| DEPENDS ${MLIR_MGB_TEST_DEPENDS} | |||||
| ) | |||||
| set_target_properties(mgb-mlir-test-lit PROPERTIES FOLDER "Tests") | |||||
| add_lit_testsuites(MLIR_TEST ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| DEPENDS ${MLIR_MGB_TEST_DEPENDS} | |||||
| ) | |||||
| add_custom_target(mlir_pass_check) | |||||
| add_dependencies(mlir_pass_check mgb-mlir-test-lit) | |||||
| @@ -0,0 +1,16 @@ | |||||
| load("//brain/megbrain/src/jit/test/mlir/utils:lit.bzl", "mlir_lit_test_suite") | |||||
| filegroup( | |||||
| name = "mlir_test_tools", | |||||
| testonly = True, | |||||
| data = [ | |||||
| "//brain/megbrain/tools/mlir:mgb-opt", | |||||
| "//brain/megbrain/tools/mlir:mgb-file-check" | |||||
| ], | |||||
| ) | |||||
| mlir_lit_test_suite( | |||||
| name = "mlir_pass_check", | |||||
| data = [":mlir_test_tools"], | |||||
| test_file_exts = ["mlir",] | |||||
| ) | |||||
| @@ -0,0 +1,58 @@ | |||||
| // RUN: mgb-opt --mgb-convert-to-affine --split-input-file -canonicalize -cse %s | mgb-file-check %s | |||||
| // RUN: mgb-opt --mgb-convert-to-affine --mgb-codegen-convert-affine-to-llvm --split-input-file -canonicalize -cse %s | |||||
| func @add_dim1(%lhs: memref<2xf32>, %rhs: memref<2xf32>, %res: memref<2xf32>) -> () { | |||||
| %0 = "mgb.add"(%lhs, %rhs) {name = "add.f"} : | |||||
| (memref<2xf32>, memref<2xf32>) -> memref<2xf32> | |||||
| "mgb.assign"(%0, %res) : (memref<2xf32>, memref<2xf32>) -> () | |||||
| mgb.return | |||||
| } | |||||
| // CHECK-LABEL: func @add_dim1(%arg0: memref<2xf32>, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { | |||||
| // CHECK: %0 = alloc() : memref<2xf32> | |||||
| // CHECK: affine.for %arg3 = 0 to 2 { | |||||
| // CHECK: %1 = affine.load %arg0[%arg3] : memref<2xf32> | |||||
| // CHECK: %2 = affine.load %arg1[%arg3] : memref<2xf32> | |||||
| // CHECK: %3 = addf %1, %2 : f32 | |||||
| // CHECK: affine.store %3, %0[%arg3] : memref<2xf32> | |||||
| // CHECK: } | |||||
| // CHECK: affine.for %arg3 = 0 to 2 { | |||||
| // CHECK: %1 = affine.load %0[%arg3] : memref<2xf32> | |||||
| // CHECK: affine.store %1, %arg2[%arg3] : memref<2xf32> | |||||
| // CHECK: } | |||||
| // CHECK: dealloc %0 : memref<2xf32> | |||||
| // CHECK: return | |||||
| // CHECK: } | |||||
| func @add_dim4(%lhs: memref<4x3x64x64xf32>, %rhs: memref<4x3x64x64xf32>, %res: memref<4x3x64x64xf32>) -> () { | |||||
| %0 = "mgb.add"(%lhs, %rhs) {name = "add.f"} : | |||||
| (memref<4x3x64x64xf32>, memref<4x3x64x64xf32>) -> memref<4x3x64x64xf32> | |||||
| "mgb.assign"(%0, %res) : (memref<4x3x64x64xf32>, memref<4x3x64x64xf32>) -> () | |||||
| mgb.return | |||||
| } | |||||
| // CHECK-LABEL: func @add_dim4(%arg0: memref<4x3x64x64xf32>, %arg1: memref<4x3x64x64xf32>, %arg2: memref<4x3x64x64xf32>) { | |||||
| // CHECK: %0 = alloc() : memref<4x3x64x64xf32> | |||||
| // CHECK: affine.for %arg3 = 0 to 4 { | |||||
| // CHECK: affine.for %arg4 = 0 to 3 { | |||||
| // CHECK: affine.for %arg5 = 0 to 64 { | |||||
| // CHECK: affine.for %arg6 = 0 to 64 { | |||||
| // CHECK: %1 = affine.load %arg0[%arg3, %arg4, %arg5, %arg6] : memref<4x3x64x64xf32> | |||||
| // CHECK: %2 = affine.load %arg1[%arg3, %arg4, %arg5, %arg6] : memref<4x3x64x64xf32> | |||||
| // CHECK: %3 = addf %1, %2 : f32 | |||||
| // CHECK: affine.store %3, %0[%arg3, %arg4, %arg5, %arg6] : memref<4x3x64x64xf32> | |||||
| // CHECK: } | |||||
| // CHECK: } | |||||
| // CHECK: } | |||||
| // CHECK: } | |||||
| // CHECK: affine.for %arg3 = 0 to 4 { | |||||
| // CHECK: affine.for %arg4 = 0 to 3 { | |||||
| // CHECK: affine.for %arg5 = 0 to 64 { | |||||
| // CHECK: affine.for %arg6 = 0 to 64 { | |||||
| // CHECK: %1 = affine.load %0[%arg3, %arg4, %arg5, %arg6] : memref<4x3x64x64xf32> | |||||
| // CHECK: affine.store %1, %arg2[%arg3, %arg4, %arg5, %arg6] : memref<4x3x64x64xf32> | |||||
| // CHECK: } | |||||
| // CHECK: } | |||||
| // CHECK: } | |||||
| // CHECK: } | |||||
| // CHECK: dealloc %0 : memref<4x3x64x64xf32> | |||||
| // CHECK: return | |||||
| // CHECK: } | |||||
| @@ -0,0 +1,5 @@ | |||||
| filegroup( | |||||
| name = "litfiles", | |||||
| srcs = glob(["lit.bzl.*py"]), | |||||
| visibility = ["//visibility:public"], | |||||
| ) | |||||
| @@ -0,0 +1,127 @@ | |||||
| # Test definitions for Lit, the LLVM test runner. | |||||
| # | |||||
| """Lit runner globbing test | |||||
| """ | |||||
| # Default values used by the test runner. | |||||
| _default_test_file_exts = ["mlir", "pbtxt", "td"] | |||||
| _default_size = "small" | |||||
| _default_tags = [] | |||||
| # These are patterns which we should never match, for tests, subdirectories, or | |||||
| # test input data files. | |||||
| _ALWAYS_EXCLUDE = [ | |||||
| "**/LICENSE.txt", | |||||
| "**/README.txt", | |||||
| "**/lit.local.cfg", | |||||
| # Exclude input files that have spaces in their names, since bazel | |||||
| # cannot cope with such "targets" in the srcs list. | |||||
| "**/* *", | |||||
| "**/* */**", | |||||
| ] | |||||
| def _run_lit_test(name, data, size, tags, features): | |||||
| """Runs lit on all tests it can find in `data` under megbrain/src/jit/test/mlir/ir. | |||||
| Note that, due to Bazel's hermetic builds, lit only sees the tests that | |||||
| are included in the `data` parameter, regardless of what other tests might | |||||
| exist in the directory searched. | |||||
| Args: | |||||
| name: str, the name of the test, including extension. | |||||
| data: [str], the data input to the test. | |||||
| size: str, the size of the test. | |||||
| tags: [str], tags to attach to the test. | |||||
| features: [str], list of extra features to enable. | |||||
| """ | |||||
| native.py_test( | |||||
| name = name, | |||||
| srcs = ["@llvm-project//llvm:lit"], | |||||
| tags = tags, | |||||
| args = [ | |||||
| "brain/megbrain/src/jit/test/mlir/utils --config-prefix=lit.bzl -v", | |||||
| ] + features, | |||||
| data = data + [ | |||||
| "//brain/megbrain/src/jit/test/mlir/utils:litfiles", | |||||
| "//brain/megbrain/tools/mlir:mgb-file-check", | |||||
| "@llvm-project//llvm:count", | |||||
| "@llvm-project//llvm:not", | |||||
| ], | |||||
| size = size, | |||||
| main = "lit.py", | |||||
| ) | |||||
| def mlir_lit_test_suite( | |||||
| name, | |||||
| exclude = [], | |||||
| test_file_exts = _default_test_file_exts, | |||||
| default_size = _default_size, | |||||
| size_override = {}, | |||||
| data = [], | |||||
| per_test_extra_data = {}, | |||||
| default_tags = _default_tags, | |||||
| tags_override = {}, | |||||
| features = []): | |||||
| """Creates all plausible Lit tests (and their inputs) under this directory. | |||||
| Args: | |||||
| name: str, name of the generated test suite. | |||||
| exclude: [str], paths to exclude (for tests and inputs). | |||||
| test_file_exts: [str], extensions for files that are tests. | |||||
| default_size: str, the test size for targets not in "size_override". | |||||
| size_override: {str: str}, sizes to use for specific tests. | |||||
| data: [str], additional input data to the test. | |||||
| per_test_extra_data: {str: [str]}, extra data to attach to a given file. | |||||
| default_tags: [str], additional tags to attach to the test. | |||||
| tags_override: {str: str}, tags to add to specific tests. | |||||
| features: [str], list of extra features to enable. | |||||
| """ | |||||
| # Ignore some patterns by default for tests and input data. | |||||
| exclude = _ALWAYS_EXCLUDE + exclude | |||||
| test_names = [] | |||||
| tests = native.glob( | |||||
| ["*." + ext for ext in test_file_exts], | |||||
| exclude = exclude, | |||||
| ) | |||||
| # Run tests individually such that errors can be attributed to a specific | |||||
| # failure. | |||||
| for i in range(len(tests)): | |||||
| cur_test = tests[i] | |||||
| # Instantiate this test with updated parameters. | |||||
| internal_name = cur_test | |||||
| lit_test( | |||||
| name = internal_name, | |||||
| data = data + per_test_extra_data.pop(cur_test, []), | |||||
| size = size_override.pop(cur_test, default_size), | |||||
| tags = ["windows_fail"] + default_tags + tags_override.pop(cur_test, []), | |||||
| features = features, | |||||
| ) | |||||
| test_names.append(internal_name + ".test") | |||||
| native.test_suite( | |||||
| name = name, | |||||
| tests = test_names, | |||||
| tags = default_tags, | |||||
| ) | |||||
| def lit_test( | |||||
| name, | |||||
| data = [], | |||||
| size = _default_size, | |||||
| tags = _default_tags, | |||||
| features = []): | |||||
| """Runs test files under lit. | |||||
| Args: | |||||
| name: str, the name of the test. | |||||
| data: [str], labels that should be provided as data inputs. | |||||
| size: str, the size of the test. | |||||
| tags: [str], tags to attach to the test. | |||||
| features: [str], list of extra features to enable. | |||||
| """ | |||||
| _run_lit_test(name + ".test", data + [name], size, tags, features) | |||||
| @@ -0,0 +1,52 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import os | |||||
| import platform | |||||
| import re | |||||
| import subprocess | |||||
| import tempfile | |||||
| import lit.formats | |||||
| import lit.util | |||||
| from lit.llvm import llvm_config | |||||
| from lit.llvm.subst import ToolSubst | |||||
| from lit.llvm.subst import FindTool | |||||
| # Configuration file for the 'lit' test runner. | |||||
| # name: The name of this test suite. | |||||
| config.name = 'MLIR_TEST' | |||||
| config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) | |||||
| # suffixes: A list of file extensions to treat as test files. | |||||
| config.suffixes = ['.mlir'] | |||||
| # test_source_root: The root path where tests are located. | |||||
| config.test_source_root = config.mlir_test_dir | |||||
| # test_exec_root: The root path where tests should be run. | |||||
| config.test_exec_root = os.environ['RUNFILES_DIR'] | |||||
| llvm_config.use_default_substitutions() | |||||
| # Tweak the PATH to include the tools dir. | |||||
| llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) | |||||
| tool_dirs = config.mlir_mgb_tools_dirs + [config.mlir_tools_dir, config.llvm_tools_dir] | |||||
| tool_names = [ | |||||
| 'mgb-opt', | |||||
| 'mlir-tblgen', | |||||
| 'mlir-translate', | |||||
| 'mgb-file-check', | |||||
| ] | |||||
| tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] | |||||
| llvm_config.add_tool_substitutions(tools, tool_dirs) | |||||
| @@ -0,0 +1,43 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| """Lit runner site configuration.""" | |||||
| import os | |||||
| import lit.llvm | |||||
| config.llvm_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'llvm-project', 'llvm') | |||||
| config.mlir_obj_root = os.path.join(os.environ['TEST_SRCDIR']) | |||||
| config.mlir_tools_dir = os.path.join(os.environ['TEST_SRCDIR'], 'llvm-project', 'mlir') | |||||
| config.suffixes = ['.td', '.mlir', '.pbtxt'] | |||||
| mlir_mgb_tools_dirs = [ | |||||
| 'brain/megbrain/tools/mlir', | |||||
| ] | |||||
| config.mlir_mgb_tools_dirs = [ | |||||
| os.path.join(os.environ['TEST_SRCDIR'], os.environ['TEST_WORKSPACE'], s) | |||||
| for s in mlir_mgb_tools_dirs | |||||
| ] | |||||
| test_dir = os.environ['TEST_TARGET'] | |||||
| test_dir = test_dir.strip('/').rsplit(':', 1)[0] | |||||
| config.mlir_test_dir = os.path.join( | |||||
| os.environ['TEST_SRCDIR'], | |||||
| os.environ['TEST_WORKSPACE'], | |||||
| test_dir, | |||||
| ) | |||||
| lit.llvm.initialize(lit_config, config) | |||||
| # Let the main config do the real work. | |||||
| lit_config.load_config( | |||||
| config, | |||||
| os.path.join( | |||||
| os.path.join( | |||||
| os.environ['TEST_SRCDIR'], | |||||
| os.environ['TEST_WORKSPACE'], | |||||
| 'brain/megbrain/src/jit/test/mlir/utils/lit.bzl.cfg.py', | |||||
| ))) | |||||
| @@ -0,0 +1,49 @@ | |||||
| @LIT_SITE_CFG_IN_HEADER@ | |||||
| import sys | |||||
| config.host_triple = "@LLVM_HOST_TRIPLE@" | |||||
| config.target_triple = "@TARGET_TRIPLE@" | |||||
| config.llvm_src_root = "@LLVM_SOURCE_DIR@" | |||||
| config.llvm_obj_root = "@LLVM_BINARY_DIR@" | |||||
| config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" | |||||
| config.llvm_lib_dir = "@LLVM_LIBRARY_DIR@" | |||||
| config.llvm_shlib_dir = "@SHLIBDIR@" | |||||
| config.llvm_shlib_ext = "@SHLIBEXT@" | |||||
| config.llvm_exe_ext = "@EXEEXT@" | |||||
| config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" | |||||
| config.python_executable = "@PYTHON_EXECUTABLE@" | |||||
| config.gold_executable = "@GOLD_EXECUTABLE@" | |||||
| config.ld64_executable = "@LD64_EXECUTABLE@" | |||||
| config.enable_shared = @ENABLE_SHARED@ | |||||
| config.enable_assertions = @ENABLE_ASSERTIONS@ | |||||
| config.targets_to_build = "@TARGETS_TO_BUILD@" | |||||
| config.native_target = "@LLVM_NATIVE_ARCH@" | |||||
| config.llvm_bindings = "@LLVM_BINDINGS@".split(' ') | |||||
| config.host_os = "@HOST_OS@" | |||||
| config.host_cc = "@HOST_CC@" | |||||
| config.host_cxx = "@HOST_CXX@" | |||||
| # Note: ldflags can contain double-quoted paths, so must use single quotes here. | |||||
| config.host_ldflags = '@HOST_LDFLAGS@' | |||||
| config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@" | |||||
| config.llvm_host_triple = '@LLVM_HOST_TRIPLE@' | |||||
| config.host_arch = "@HOST_ARCH@" | |||||
| config.mgb_src_root = "@CMAKE_SOURCE_DIR@" | |||||
| config.mgb_obj_root = "@CMAKE_BINARY_DIR@" | |||||
| # Support substitution of the tools_dir with user parameters. This is | |||||
| # used when we can't determine the tool dir at configuration time. | |||||
| try: | |||||
| config.llvm_tools_dir = config.llvm_tools_dir % lit_config.params | |||||
| config.llvm_shlib_dir = config.llvm_shlib_dir % lit_config.params | |||||
| except KeyError: | |||||
| e = sys.exc_info()[1] | |||||
| key, = e.args | |||||
| lit_config.fatal("unable to find %r parameter, use '--param=%s=VALUE'" % (key,key)) | |||||
| import lit.llvm | |||||
| lit.llvm.initialize(lit_config, config) | |||||
| # Let the main config do the real work. | |||||
| lit_config.load_config(config, "@CMAKE_SOURCE_DIR@/src/jit/test/mlir/utils/lit.cfg.py") | |||||
| @@ -0,0 +1,58 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| import os | |||||
| import platform | |||||
| import re | |||||
| import subprocess | |||||
| import tempfile | |||||
| import lit.formats | |||||
| import lit.util | |||||
| from lit.llvm import llvm_config | |||||
| from lit.llvm.subst import ToolSubst | |||||
| from lit.llvm.subst import FindTool | |||||
| # Configuration file for the 'lit' test runner. | |||||
| # name: The name of this test suite. | |||||
| config.name = 'MLIR_TEST' | |||||
| config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) | |||||
| # suffixes: A list of file extensions to treat as test files. | |||||
| config.suffixes = ['.mlir'] | |||||
| # test_source_root: The root path where tests are located. | |||||
| config.test_source_root = os.path.join(os.path.dirname(__file__), '../ir') | |||||
| # test_exec_root: The root path where tests should be run. | |||||
| config.test_exec_root = config.test_source_root | |||||
| # llvm_config.use_default_substitutions() | |||||
| # Tweak the PATH to include the tools dir. | |||||
| llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) | |||||
| tool_dirs = [ | |||||
| os.path.join(config.mgb_obj_root, 'tools/mlir'), | |||||
| os.path.join(config.mgb_obj_root, 'tools/mlir/mgb-opt'), | |||||
| os.path.join(config.mgb_obj_root, 'tools/mlir/mgb-file-check'), | |||||
| config.llvm_tools_dir] | |||||
| tool_names = [ | |||||
| 'mgb-opt', | |||||
| 'mlir-tblgen', | |||||
| 'mlir-translate', | |||||
| 'mgb-file-check', | |||||
| ] | |||||
| tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] | |||||
| llvm_config.add_tool_substitutions(tools, tool_dirs) | |||||
| lit.llvm.initialize(lit_config, config) | |||||
| @@ -0,0 +1,49 @@ | |||||
| @LIT_SITE_CFG_IN_HEADER@ | |||||
| import sys | |||||
| config.host_triple = "@LLVM_HOST_TRIPLE@" | |||||
| config.target_triple = "@TARGET_TRIPLE@" | |||||
| config.llvm_src_root = "@LLVM_SOURCE_DIR@" | |||||
| config.llvm_obj_root = "@LLVM_BINARY_DIR@" | |||||
| config.llvm_tools_dir = "@LLVM_BINARY_DIR@/bin" | |||||
| config.llvm_lib_dir = "@LLVM_LIBRARY_DIR@" | |||||
| config.llvm_shlib_dir = "@SHLIBDIR@" | |||||
| config.llvm_shlib_ext = "@SHLIBEXT@" | |||||
| config.llvm_exe_ext = "@EXEEXT@" | |||||
| config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" | |||||
| config.python_executable = "@PYTHON_EXECUTABLE@" | |||||
| config.gold_executable = "@GOLD_EXECUTABLE@" | |||||
| config.ld64_executable = "@LD64_EXECUTABLE@" | |||||
| config.enable_shared = @ENABLE_SHARED@ | |||||
| config.enable_assertions = @ENABLE_ASSERTIONS@ | |||||
| config.targets_to_build = "@TARGETS_TO_BUILD@" | |||||
| config.native_target = "@LLVM_NATIVE_ARCH@" | |||||
| config.llvm_bindings = "@LLVM_BINDINGS@".split(' ') | |||||
| config.host_os = "@HOST_OS@" | |||||
| config.host_cc = "@HOST_CC@" | |||||
| config.host_cxx = "@HOST_CXX@" | |||||
| # Note: ldflags can contain double-quoted paths, so must use single quotes here. | |||||
| config.host_ldflags = '@HOST_LDFLAGS@' | |||||
| config.llvm_use_sanitizer = "@LLVM_USE_SANITIZER@" | |||||
| config.llvm_host_triple = '@LLVM_HOST_TRIPLE@' | |||||
| config.host_arch = "@HOST_ARCH@" | |||||
| config.mgb_src_root = "@CMAKE_SOURCE_DIR@" | |||||
| config.mgb_obj_root = "@CMAKE_BINARY_DIR@" | |||||
| # Support substitution of the tools_dir with user parameters. This is | |||||
| # used when we can't determine the tool dir at configuration time. | |||||
| try: | |||||
| config.llvm_tools_dir = config.llvm_tools_dir % lit_config.params | |||||
| config.llvm_shlib_dir = config.llvm_shlib_dir % lit_config.params | |||||
| except KeyError: | |||||
| e = sys.exc_info()[1] | |||||
| key, = e.args | |||||
| lit_config.fatal("unable to find %r parameter, use '--param=%s=VALUE'" % (key,key)) | |||||
| import lit.llvm | |||||
| lit.llvm.initialize(lit_config, config) | |||||
| # Let the main config do the real work. | |||||
| lit_config.load_config(config, "@CMAKE_SOURCE_DIR@/src/jit/test/mlir/utils/lit.cfg.py") | |||||
| @@ -43,3 +43,9 @@ endif() | |||||
| if (MGE_WITH_DISTRIBUTED) | if (MGE_WITH_DISTRIBUTED) | ||||
| target_link_libraries(megbrain_test megray) | target_link_libraries(megbrain_test megray) | ||||
| endif() | endif() | ||||
| if(MGE_WITH_JIT) | |||||
| if(MGE_WITH_JIT_MLIR) | |||||
| add_subdirectory(${PROJECT_SOURCE_DIR}/src/jit/test/mlir ${CMAKE_CURRENT_BINARY_DIR}/../src/jit/test/mlir) | |||||
| endif() | |||||
| endif() | |||||