| @@ -117,6 +117,8 @@ if(CMAKE_TOOLCHAIN_FILE) | |||
| else() | |||
| message(FATAL_ERROR "Unsupported IOS_ARCH.") | |||
| endif() | |||
| elseif(RISCV_TOOLCHAIN_ROOT) | |||
| set(MGE_ARCH "riscv64") | |||
| elseif(NOT "${ARM_CROSS_BUILD_ARCH}" STREQUAL "") | |||
| set(MGE_ARCH ${ARM_CROSS_BUILD_ARCH}) | |||
| else() | |||
| @@ -664,6 +666,11 @@ if(MGE_ARCH STREQUAL "aarch64") | |||
| endif() | |||
| if(MGE_ARCH STREQUAL "riscv64") | |||
| set(MEGDNN_RISCV64 1) | |||
| set(MEGDNN_64_BIT 1) | |||
| endif() | |||
| set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MARCH}") | |||
| set(MGB_ENABLE_IMPERATIVE ${MGE_BUILD_IMPERATIVE_RT}) | |||
| @@ -0,0 +1,21 @@ | |||
| /** | |||
| * \file dnn/src/common/postprocess.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| namespace megdnn { | |||
| enum class PostprocessMode : uint8_t { | |||
| FLOAT = 0, ///< support all biasmode and no_nonlinemode | |||
| NO_PROCESS, ///< support non bias and identity | |||
| QUANTIZED, ///< support NOBIAS ,BROADCAST_CHANNEL_BIAS and relu hswish | |||
| ///< identify nonline mode | |||
| ADD_BIAS, ///< only add bias | |||
| }; | |||
| } | |||
| @@ -0,0 +1,80 @@ | |||
| /** | |||
| * \file dnn/src/common/postprocess_helper.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/basic_types.h" | |||
| #include "midout.h" | |||
| #include "src/common/postprocess.h" | |||
| namespace { | |||
| #define POST_PROCESS_UNUSED_VAR() \ | |||
| MEGDNN_MARK_USED_VAR(conv_dst_ptr); \ | |||
| MEGDNN_MARK_USED_VAR(bias_ptr); \ | |||
| MEGDNN_MARK_USED_VAR(dst_ptr); \ | |||
| MEGDNN_MARK_USED_VAR(bias_mode); \ | |||
| MEGDNN_MARK_USED_VAR(nonlineMode); \ | |||
| MEGDNN_MARK_USED_VAR(bias_type); \ | |||
| MEGDNN_MARK_USED_VAR(dst_type); \ | |||
| MEGDNN_MARK_USED_VAR(N); \ | |||
| MEGDNN_MARK_USED_VAR(OC); \ | |||
| MEGDNN_MARK_USED_VAR(OH); \ | |||
| MEGDNN_MARK_USED_VAR(OW); \ | |||
| MEGDNN_MARK_USED_VAR(pack_oc_size) | |||
| template <typename ctype, typename dtype = ctype, | |||
| megdnn::PostprocessMode postprocess_mode = | |||
| megdnn::PostprocessMode::FLOAT> | |||
| struct PostProcess { | |||
| static void run(void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, | |||
| megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | |||
| megdnn::DType bias_type, megdnn::DType dst_type, size_t N, | |||
| size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
| POST_PROCESS_UNUSED_VAR(); | |||
| megdnn_throw("not impl PostProcess"); | |||
| } | |||
| }; | |||
| template <typename ctype, typename dtype> | |||
| struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> { | |||
| static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | |||
| megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | |||
| megdnn::DType bias_type, megdnn::DType dst_type, size_t N, | |||
| size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
| POST_PROCESS_UNUSED_VAR(); | |||
| megdnn_throw("not impl PostProcess"); | |||
| } | |||
| }; | |||
| template <typename opctype, typename opdtype> | |||
| struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> { | |||
| static void run(void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr, | |||
| megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | |||
| megdnn::DType bias_type, megdnn::DType dst_type, size_t N, | |||
| size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
| POST_PROCESS_UNUSED_VAR(); | |||
| megdnn_throw("not impl PostProcess"); | |||
| } | |||
| }; | |||
| template <typename ctype, typename dtype> | |||
| struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> { | |||
| static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr, | |||
| megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode, | |||
| megdnn::DType bias_type, megdnn::DType dst_type, size_t N, | |||
| size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) { | |||
| POST_PROCESS_UNUSED_VAR(); | |||
| megdnn_throw("not impl PostProcess"); | |||
| } | |||
| }; | |||
| } // namespace | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| @@ -42,8 +43,12 @@ namespace transpose_fallback { | |||
| #if MEGDNN_X86 | |||
| constexpr size_t BLOCK_LINE_SIZE_BYTES = 64; | |||
| #elif MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
| #elif MEGDNN_AARCH64 || MEGDNN_ARMV7 /*BEGIN-INLINE-INTERNAL*/ || \ | |||
| MEGDNN_MIPS /*END-INLINE-INTERNAL*/ | |||
| constexpr size_t BLOCK_LINE_SIZE_BYTES = 32; | |||
| #elif MEGDNN_RISCV64 | |||
| //! ref U54-MC arch | |||
| constexpr size_t BLOCK_LINE_SIZE_BYTES = 64; | |||
| #else | |||
| #error "unknown megdnn arch" | |||
| #endif | |||
| @@ -6,12 +6,14 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include <stdint.h> | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/postprocess.h" | |||
| #include "src/common/utils.h" | |||
| namespace megdnn { | |||
| @@ -157,13 +159,6 @@ private: \ | |||
| mutable std::string m_name; \ | |||
| uint32_t m_tile_size; | |||
| enum class PostprocessMode : uint8_t { | |||
| FLOAT = 0, ///< support all biasmode and no_nonlinemode | |||
| NO_PROCESS, ///< support non bias and identity | |||
| QUANTIZED, ///< support NOBIAS ,BROADCAST_CHANNEL_BIAS and relu hswish | |||
| ///< identify nonline mode | |||
| ADD_BIAS, ///< only add bias | |||
| }; | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -24,6 +24,8 @@ | |||
| #include "src/x86/conv_bias/postprocess_helper.h" | |||
| #elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||
| #include "src/arm_common/conv_bias/postprocess_helper.h" | |||
| #else | |||
| #include "src/common/postprocess_helper.h" | |||
| #endif | |||
| #include "midout.h" | |||
| @@ -106,7 +108,7 @@ ConvBiasImpl::AlgoConv1x1::get_kerns_according_packmode( | |||
| WorkspaceBundle whole_bundle = get_bundle_according_packmode(param); | |||
| //! NO_PACK not implement get_bundle | |||
| WorkspaceBundle matmul_bundle ={nullptr,{}}; | |||
| WorkspaceBundle matmul_bundle = {nullptr, {}}; | |||
| if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) { | |||
| matmul_bundle = {nullptr, | |||
| {0, 0, m_matmul_algo->get_workspace(matmul_param)}}; | |||
| @@ -281,7 +283,6 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param, | |||
| return false; | |||
| } | |||
| bool ConvBiasImpl::AlgoConv1x1::is_preferred( | |||
| const NCBKernSizeParam& param) const { | |||
| size_t OH = param.osz[0]; | |||
| @@ -25,9 +25,11 @@ | |||
| #include "src/x86/conv_bias/postprocess_helper.h" | |||
| #elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||
| #include "src/arm_common/conv_bias/postprocess_helper.h" | |||
| #include "src/arm_common/matrix_mul/fp32/exec_sgemv.h" | |||
| #include "src/arm_common/matrix_mul/fp16/hgemv.h" | |||
| #include "src/arm_common/matrix_mul/fp32/exec_sgemv.h" | |||
| #include "src/arm_common/matrix_mul/int8/gemv.h" | |||
| #else | |||
| #include "src/common/postprocess_helper.h" | |||
| #endif | |||
| #include "midout.h" | |||
| @@ -249,7 +251,7 @@ size_t ConvBiasImpl::AlgoConv1x1Gemv::get_oc_tile_size_heuristic( | |||
| } | |||
| size_t ConvBiasImpl::AlgoConv1x1Gemv::get_workspace( | |||
| const NCBKernSizeParam& param) const { | |||
| const NCBKernSizeParam& param) const { | |||
| MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv, | |||
| midout_iv("AlgoConv1x1Gemv::get_workspace"_hash)) { | |||
| size_t compt_oc_block_size = get_oc_tile_size_heuristic(param); | |||
| @@ -335,7 +337,8 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||
| #else | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| cb1(param::ConvBias::Format::NCHW, dt_float16, dt_float16, | |||
| PostprocessMode::NO_PROCESS, "NCHW::GEMV::FLOAT16_FLOAT16"_hash); | |||
| PostprocessMode::NO_PROCESS, | |||
| "NCHW::GEMV::FLOAT16_FLOAT16"_hash); | |||
| #endif | |||
| #endif | |||
| cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32, | |||
| @@ -361,7 +364,7 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||
| dt_uint8, PostprocessMode::QUANTIZED, | |||
| "NCHW::GEMV::QUINT8x8x32_QUINT8"_hash); | |||
| break; | |||
| //!no support nchw44 8x8x16 | |||
| //! no support nchw44 8x8x16 | |||
| case param::ConvBias::Format::NCHW44: | |||
| cb1(param::ConvBias::Format::NCHW44, dt_float32, dt_float32, | |||
| PostprocessMode::FLOAT, "NCHW44::GEMV::FLOAT"_hash); | |||
| @@ -377,7 +380,7 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns( | |||
| dt_int8, PostprocessMode::QUANTIZED, | |||
| "NCHW44::GEMV::QINT8x8x32_QINT8"_hash); | |||
| break; | |||
| //!no support nchw44-dot 8x8x16 | |||
| //! no support nchw44-dot 8x8x16 | |||
| case param::ConvBias::Format::NCHW44_DOT: | |||
| cb3(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32, | |||
| dt_int32, dt_int8, dt_int32, dt_int32, | |||
| @@ -19,6 +19,8 @@ | |||
| #include "src/x86/conv_bias/postprocess_helper.h" | |||
| #elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||
| #include "src/arm_common/conv_bias/postprocess_helper.h" | |||
| #else | |||
| #include "src/common/postprocess_helper.h" | |||
| #endif | |||
| namespace megdnn { | |||
| @@ -16,6 +16,8 @@ | |||
| #include "src/x86/conv_bias/postprocess_helper.h" | |||
| #elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||
| #include "src/arm_common/conv_bias/postprocess_helper.h" | |||
| #else | |||
| #include "src/common/postprocess_helper.h" | |||
| #endif | |||
| using namespace megdnn; | |||
| #if MEGDNN_X86 | |||
| @@ -12,10 +12,10 @@ | |||
| #include "src/fallback/convolution/img2col_helper.h" | |||
| #if MEGDNN_X86 | |||
| #include "src/x86/conv_bias/postprocess_helper.h" | |||
| #endif | |||
| #if (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||
| #elif (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||
| #include "src/arm_common/conv_bias/postprocess_helper.h" | |||
| #else | |||
| #include "src/common/postprocess_helper.h" | |||
| #endif | |||
| using namespace megdnn; | |||
| @@ -74,6 +74,10 @@ public: | |||
| } | |||
| #endif | |||
| //! As we haven't riscv64 postprocess yet, im2col and conv1x1 can not pass ci | |||
| //! test. so we just disable all im2col and conv1x1 in riscv64 | |||
| //! FIXME: remove it when impl postprocess for riscv64 | |||
| #if !MEGDNN_RISCV64 | |||
| for (size_t ohw_tile_size : {192, 384, 96, 48, 24}) { | |||
| refhold.emplace_back(new AlgoIm2col( | |||
| static_cast<MatrixMulImpl::AlgoBase*>(algo), | |||
| @@ -86,6 +90,8 @@ public: | |||
| oc_tile_size)); | |||
| all_algos.emplace_back(refhold.back().get()); | |||
| } | |||
| #endif | |||
| #if 0 | |||
| //! As these algos maybe very slow, it will make fastrun search slow, so | |||
| //! we disable it, but for the test of strategyhelper, we just keep it. | |||
| @@ -50,6 +50,7 @@ public: | |||
| _megdnn_tensor_in bias, _megdnn_tensor_in z, | |||
| _megdnn_tensor_out dst, const PreprocessedFilter*, | |||
| _megdnn_workspace workspace) override; | |||
| bool is_thread_safe() const override { return true; } | |||
| void exec_preprocess(const TensorLayout& src_layout, | |||
| _megdnn_tensor_in filter, | |||
| @@ -74,7 +74,7 @@ void mask_conv_test(Handle* handle) { | |||
| arg[8], arg[9], arg[10], arg[11], arg[12]); | |||
| } | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| void mask_conv_benchmark(Handle* handle) { | |||
| auto benchmark = [&](size_t N, size_t IC, size_t OC, size_t IH, size_t IW, | |||
| size_t FH, size_t FW, size_t SH, size_t SW, size_t PH, | |||
| @@ -113,5 +113,6 @@ void mask_conv_benchmark(Handle* handle) { | |||
| arg[7], arg[8], arg[9], arg[10], arg[11], arg[12]); | |||
| } | |||
| } | |||
| #endif | |||
| } // namespace | |||
| @@ -25,9 +25,11 @@ TEST_F(CPU, MASK_CONV) { | |||
| mask_conv_test(handle()); | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(CPU, MASK_CONV_BENCHMARK) { | |||
| mask_conv_benchmark(handle()); | |||
| } | |||
| #endif | |||
| TEST_F(CPU, MASK_PROPAGATE) { | |||
| param::MaskPropagate mask_param; | |||
| @@ -17,6 +17,7 @@ | |||
| using namespace megdnn; | |||
| using namespace test; | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| namespace { | |||
| void sgemm_sgemv_like(const float* __restrict A, const float* __restrict B, | |||
| @@ -70,6 +71,7 @@ TEST_F(CPU, BENCHMARK_MATRIX_MUL) { | |||
| run(m, nk, nk); | |||
| } | |||
| } | |||
| #endif | |||
| TEST_F(CPU, MATRIX_MUL) { | |||
| matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{}, | |||
| @@ -31,6 +31,7 @@ TYPED_TEST(CPU_RELAYOUT, run) { | |||
| } | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(CPU, BENCHMARK_RELAYOUT_CV) { | |||
| relayout::run_cv_benchmark(handle()); | |||
| } | |||
| @@ -55,6 +56,6 @@ TEST_F(CPU, BENCHMARK_RELAYOUT) { | |||
| ASSERT_LE(cpu_time * 5, naive_time); | |||
| } | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -22,10 +22,11 @@ using namespace test; | |||
| TEST_F(CUDA, MASK_CONV) { | |||
| mask_conv_test(handle_cuda()); | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(CUDA, MASK_CONV_BENCHMARK) { | |||
| mask_conv_benchmark(handle_cuda()); | |||
| } | |||
| #endif | |||
| TEST_F(CUDA, MASK_PROPAGATE) { | |||
| Checker<MaskPropagate> checker(handle_cuda()); | |||
| @@ -27,7 +27,7 @@ TYPED_TEST_CASE(FALLBACK_ELEMWISE, elemwise::test_types); | |||
| TYPED_TEST(FALLBACK_ELEMWISE, run) { | |||
| elemwise::run_test<TypeParam>(this->handle()); | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(FALLBACK, BENCHMARK_ELEMWISE) { | |||
| auto naive_handle = create_cpu_handle(2); | |||
| auto run = [&](const TensorShape &shp0, const TensorShape &shp1) { | |||
| @@ -72,6 +72,7 @@ TEST_F(FALLBACK, BENCHMARK_ELEMWISE) { | |||
| // non-contig, fallback to naive | |||
| run({1024, 1024, 32}, {1024, 1, 32}); | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -25,7 +25,7 @@ TYPED_TEST_CASE(FALLBACK_ELEMWISE_MULTI_TYPE, elemwise_multi_type::test_types); | |||
| TYPED_TEST(FALLBACK_ELEMWISE_MULTI_TYPE, run) { | |||
| elemwise_multi_type::run_test<TypeParam>(this->handle()); | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_BENCHMARK_FMA3_INT16x32x32x32) { | |||
| Benchmarker<ElemwiseMultiType> bench{handle()}; | |||
| bench.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32}); | |||
| @@ -64,5 +64,5 @@ TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_BENCHMARK_FMA3_IXxf32xf32xI8) { | |||
| (1024.0 * 1024.0 * 1024.0)); | |||
| } | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -31,7 +31,7 @@ TYPED_TEST(FALLBACK_RELAYOUT, run) { | |||
| relayout::run_test<TypeParam>(this->handle()); | |||
| } | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(FALLBACK, BENCHMARK_RELAYOUT_CV) { | |||
| relayout::run_cv_benchmark(handle()); | |||
| } | |||
| @@ -160,5 +160,6 @@ TEST_F(FALLBACK, BENCHMARK_RELAYOUT) { | |||
| } | |||
| } | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -34,7 +34,7 @@ TEST_F(FALLBACK, ROICOPY) { | |||
| } | |||
| } | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| TEST_F(FALLBACK, BENCHMARK_ROICOPY) { | |||
| auto run = [&](const TensorShapeArray& shapes) { | |||
| Benchmarker<ROICopy> benchmarker(handle()); | |||
| @@ -62,6 +62,7 @@ TEST_F(FALLBACK, BENCHMARK_ROICOPY) { | |||
| run(shapes); | |||
| } | |||
| #endif | |||
| } // namespace test | |||
| @@ -0,0 +1,18 @@ | |||
| set(CMAKE_SYSTEM_NAME Linux) | |||
| set(CMAKE_SYSTEM_PROCESSOR riscv64) | |||
| set(RISCV_CROSS_BUILD_ARCH riscv64) | |||
| if(DEFINED ENV{RISCV_TOOLCHAIN_ROOT}) | |||
| file(TO_CMAKE_PATH $ENV{RISCV_TOOLCHAIN_ROOT} RISCV_TOOLCHAIN_ROOT) | |||
| else() | |||
| message(FATAL_ERROR "RISCV_TOOLCHAIN_ROOT env must be defined") | |||
| endif() | |||
| set(RISCV_TOOLCHAIN_ROOT ${RISCV_TOOLCHAIN_ROOT} CACHE STRING "root path to riscv toolchain") | |||
| set(CMAKE_C_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-gcc") | |||
| set(CMAKE_CXX_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-g++") | |||
| set(CMAKE_FIND_ROOT_PATH "${RISCV_TOOLCHAIN_ROOT}/riscv64-unknown-linux-gnu") | |||
| set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) | |||
| set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) | |||
| set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) | |||