GitOrigin-RevId: 170d9eeab2
@@ -11,6 +11,7 @@ from ..core._imperative_rt.core2 import (
     set_cpp_apply_with_tracing,
 )
 from .dtr_config import DTRConfig
+from .graph_opt_config import GraphOptimizationConfig
 from .sublinear_memory_config import SublinearMemoryConfig
 from .tracing import (
     apply_const_with_tracing,
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+
+class GraphOptimizationConfig:
+    r"""
+    Configuration for graph optimization: False for OFF, True for ON. The default value
+    None means that opt_level will decide whether this optimization will be applied or not.
+
+    :param jit_fuse_dimshuffle: whether to fuse dimshuffle in JIT optimization
+    :param jit_fuse_reduce: whether to fuse reduce in JIT optimization
+    """
+
+    def __init__(self):
+        self.jit_fuse_dimshuffle = None
+        self.jit_fuse_reduce = None
+
+    def __repr__(self):
+        val2str = {None: "UNSET", False: "OFF", True: "ON"}
+        return (
+            "GraphOptimizationConfig {"
+            + " jit_fuse_dimshuffle = "
+            + val2str[self.jit_fuse_dimshuffle]
+            + ", jit_fuse_reduce = "
+            + val2str[self.jit_fuse_reduce]
+            + " }"
+        )
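For orientation, a minimal usage sketch of the new config object (the printed
form follows __repr__ above):

    from megengine.jit import GraphOptimizationConfig

    config = GraphOptimizationConfig()
    config.jit_fuse_reduce = True  # force reduce fusion ON; None leaves it to opt_level
    print(config)
    # GraphOptimizationConfig { jit_fuse_dimshuffle = UNSET, jit_fuse_reduce = ON }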
@@ -38,6 +38,7 @@ from ..core.tensor import megbrain_graph as G
 from ..core.tensor.utils import setscalar
 from ..utils.naming import AutoNaming
 from .dtr_config import DTRConfig
+from .graph_opt_config import GraphOptimizationConfig
 from .sublinear_memory_config import SublinearMemoryConfig
@@ -129,6 +130,7 @@ class trace:
         If not None, it enables sublinear memory optimization with given setting.
     :param profiling: whether to profile compiled trace. Default: False
     :param opt_level: optimization level for compiling trace. Default: 2
+    :param graph_opt_config: configuration for graph optimization. Default: None
     :param symbolic_shape: whether to use symbolic shape for tracing. Default: True
     """
@@ -146,6 +148,7 @@ class trace:
         dtr_config: DTRConfig = None,
         profiling: bool = False,
         opt_level: int = 2,
+        graph_opt_config: GraphOptimizationConfig = None,
         symbolic_shape: bool = True,
     ):
         self.__wrapped__ = function
@@ -156,6 +159,7 @@ class trace:
         self._profiling = profiling
         self._profiler = None
         self._graph_opt_level = opt_level
+        self._graph_opt_config = graph_opt_config
         self._symbolic_shape = symbolic_shape
         self._output_handles = set()
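A sketch of how the config reaches trace(), mirroring the Python test added
near the end of this commit:

    from megengine import tensor
    from megengine.jit import GraphOptimizationConfig, trace

    config = GraphOptimizationConfig()
    config.jit_fuse_dimshuffle = True

    @trace(opt_level=1, graph_opt_config=config)
    def add_one(x):
        return x + 1

    y = add_one(tensor(2))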
@@ -502,7 +506,14 @@ class trace:
             graph.options.dtr_config.evictee_minimum_size = (
                 self._dtr_config.evictee_minimum_size
             )
+        # graph optimization
+        if self._graph_opt_config is not None:
+            mapping = {None: 0, False: 1, True: 2}
+            jit_config = graph.options.graph_opt.jit_config
+            jit_config.fuse_dimshuffle = mapping[
+                self._graph_opt_config.jit_fuse_dimshuffle
+            ]
+            jit_config.fuse_reduce = mapping[self._graph_opt_config.jit_fuse_reduce]
         # sublinear
         if self._sublinear_memory_config is not None:
             graph.options.enable_sublinear_memory_opt = True
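(The integer values are not arbitrary: 0, 1 and 2 correspond to the UNSET, OFF
and ON constants declared on the C++ JITConfig struct later in this diff, so
the Python tri-state None/False/True round-trips into the graph options
unchanged.)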
@@ -421,12 +421,20 @@ void init_graph_rt(py::module m) {
 #undef CURRENT_CLASS
 #define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt

-    py::class_<cg::ComputingGraph::Options::GraphOpt>(PyComputingGraphOptions, "GraphOpt")
+    auto PyGraphOpt = py::class_<cg::ComputingGraph::Options::GraphOpt>(
+            PyComputingGraphOptions, "GraphOpt")
         DEF_READWRITE(jit)
+        DEF_READWRITE(jit_config)
         DEF_READWRITE(tensorrt);

 #undef CURRENT_CLASS
+#define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt::JITConfig
+
+    py::class_<cg::ComputingGraph::Options::GraphOpt::JITConfig>(PyGraphOpt, "JITConfig")
+        DEF_READWRITE(fuse_dimshuffle)
+        DEF_READWRITE(fuse_reduce);
+
+#undef CURRENT_CLASS
 #define CURRENT_CLASS cg::ComputingGraph::Options::SublinearMemConfig

     py::class_<cg::ComputingGraph::Options::SublinearMemConfig>(PyComputingGraphOptions, "SublinearMemConfig")
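These bindings are what the Python layer above writes through. Purely as an
illustration, the same fields can be poked directly on a traced function's
graph (private attributes as used by the Python test below; integer values per
the UNSET/OFF/ON constants defined later in this diff):

    options = traced_func._graph.options
    options.graph_opt.jit_config.fuse_dimshuffle = 2  # ON
    options.graph_opt.jit_config.fuse_reduce = 1      # OFF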
@@ -25,7 +25,7 @@ from megengine.core.ops import builtin as ops
 from megengine.core.ops.builtin import Elemwise
 from megengine.core.tensor.utils import isscalar
 from megengine.functional import exp, log
-from megengine.jit import exclude_from_trace, trace
+from megengine.jit import GraphOptimizationConfig, exclude_from_trace, trace
 from megengine.module import Module
 from megengine.random import normal, uniform
 from megengine.utils.naming import AutoNaming
@@ -605,3 +605,30 @@ def test_trace_advance_indexing(shape_mode):
     for _ in range(3):
         result_trace = f_traced(**params)
         np.testing.assert_equal(expected, result_trace.numpy())
+
+
+@pytest.mark.require_ngpu(1)  # nvrtc backend
+def test_trace_jit_config():
+    def run(fuse_dimshuffle, fuse_reduce):
+        config = GraphOptimizationConfig()
+        config.jit_fuse_dimshuffle = fuse_dimshuffle
+        config.jit_fuse_reduce = fuse_reduce
+
+        # set opt_level = 1 to avoid fusing dimshuffle and reduce at the same time
+        @trace(opt_level=1, graph_opt_config=config)
+        def func(x):
+            return x + 1
+
+        x = tensor(2)
+        y = func(x)
+        func._compile()
+
+        options = func._graph.options
+        mapping = {None: 0, False: 1, True: 2}
+        assert options.graph_opt.jit == 0
+        assert options.graph_opt.jit_config.fuse_dimshuffle == mapping[fuse_dimshuffle]
+        assert options.graph_opt.jit_config.fuse_reduce == mapping[fuse_reduce]
+
+    for fuse_dimshuffle in [None, False, True]:
+        for fuse_reduce in [None, False, True]:
+            run(fuse_dimshuffle, fuse_reduce)
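Note the assertion options.graph_opt.jit == 0: graph_opt_config fills in only
jit_config and leaves the legacy graph_opt.jit level untouched, so this test
verifies the option plumbing rather than whether fusion actually fires.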
@@ -145,6 +145,24 @@ size_t ComputingGraph::prealloc_static_storage(size_t size) {
 }
 #endif

+/* ========================== JITConfig ========================== */
+
+bool ComputingGraph::Options::GraphOpt::JITConfig::enabled() const {
+    if (fuse_dimshuffle != UNSET) return true;
+    if (fuse_reduce != UNSET) return true;
+    return false;
+}
+
+void ComputingGraph::Options::GraphOpt::JITConfig::update(
+        const JITConfig& modifier) {
+    if (modifier.fuse_dimshuffle != UNSET) {
+        this->fuse_dimshuffle = modifier.fuse_dimshuffle;
+    }
+    if (modifier.fuse_reduce != UNSET) {
+        this->fuse_reduce = modifier.fuse_reduce;
+    }
+}
+
 /* ========================== CallbackCaller ========================== */
 MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
                      SingleCNOperatorNodeBase) // {
@@ -538,12 +556,18 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
 #if MGB_JIT
-    if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) {
-        setenv("MGB_JIT_BACKEND","NVRTC",1);
+    if (std::abs(options().graph_opt_level) == 0 &&
+        (options().graph_opt.jit || options().graph_opt.jit_config.enabled())) {
+        // Deprecated usage added previously. It allows NVRTC JIT optimization
+        // when graph_opt_level is 0. This usage is not recommended any more.
+        mgb_log_warn(
+                "It is not recommended to enable JIT optimization when "
+                "graph_opt_level is 0.");
+        setenv("MGB_JIT_BACKEND", "NVRTC", 1);
         gopt::GraphOptimizer optimizer;
-        optimizer.add_pass<gopt::JITFusionPass>(
-                sopr_stat.has_virtual_grad,
-                std::max<uint8_t>(options().graph_opt.jit, 1));
+        optimizer.add_pass<gopt::JITFusionPass>(sopr_stat.has_virtual_grad,
+                                                options().graph_opt.jit,
+                                                options().graph_opt.jit_config);
         optimizer.apply_inplace(dest_vars);
     }
 #endif
@@ -338,6 +338,20 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
             //! this value indicates JIT level: 1 for basic elemwise opr; 2
             //! for including reduce oprs
             uint8_t jit = 0;

+            //! jit configurations
+            struct JITConfig {
+                static const int UNSET = 0;
+                static const int OFF = 1;
+                static const int ON = 2;
+
+                int fuse_dimshuffle = UNSET;
+                int fuse_reduce = UNSET;
+
+                bool enabled() const;
+                void update(const JITConfig& modifier);
+            } jit_config;
+
             //! whether to enable fine-grained TensorRT opr replace
             bool tensorrt = false;
         } graph_opt;
@@ -645,11 +645,21 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
     add_pass<RemoveRedundantCopyPass>();

 #if MGB_JIT
-    bool need_jit = false;
-    if (comp_graph_opt && (std::abs(comp_graph_opt->graph_opt_level) >= 3 ||
-                           comp_graph_opt->graph_opt.jit)) {
-        need_jit = true;
+    using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
+    int jit_opt_level = 0;
+    JITConfig jit_config;
+
+    // for more detail on what is happening here, see comments on the
+    // constructor of class JITFusionPass in fusion_pass.h
+    if (comp_graph_opt) {
+        jit_opt_level = comp_graph_opt->graph_opt.jit;
+        if (comp_graph_opt->graph_opt_level >= 3) {
+            jit_opt_level = std::max(jit_opt_level, 1);
+        }
+        jit_config = comp_graph_opt->graph_opt.jit_config;
     }
+    bool need_jit = (jit_opt_level > 0) || jit_config.enabled();
     if (need_jit && after_grad) {
         add_pass<gopt::RecompTypeCvtPass>();
     }
@@ -662,9 +672,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
 #if MGB_JIT
     if (need_jit) {
-        add_pass<gopt::JITFusionPass>(
-                after_grad,
-                std::max<uint8_t>(comp_graph_opt->graph_opt.jit, 1));
+        add_pass<gopt::JITFusionPass>(after_grad, jit_opt_level, jit_config);
     }
 #endif
@@ -428,14 +428,33 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
     return false;
 }

-JITFusionPass::JITFusionPass(bool after_grad, int8_t jit_opt_level)
+JITFusionPass::JITFusionPass(bool after_grad, int jit_opt_level,
+                             const JITConfig& jit_config)
         : m_after_grad{after_grad}, m_feature_bits{JITFeatureBits::NONE} {
-    // TODO reduce and dimshuffle can not coexsit now.
-    if (jit_opt_level >= 2) {
-        m_feature_bits |= JITFeatureBits::REDUCE;
-    } else {
+    // get default config from jit_opt_level
+    JITConfig config;
+    if (jit_opt_level == 1) {
+        config.fuse_dimshuffle = JITConfig::ON;
+        config.fuse_reduce = JITConfig::OFF;
+    } else if (jit_opt_level >= 2) {
+        config.fuse_dimshuffle = JITConfig::OFF;
+        config.fuse_reduce = JITConfig::ON;
+    }
+
+    // overwrite default config with custom settings
+    config.update(jit_config);
+
+    bool fuse_dimshuffle = config.fuse_dimshuffle == JITConfig::ON;
+    bool fuse_reduce = config.fuse_reduce == JITConfig::ON;
+    if (fuse_dimshuffle && fuse_reduce) {
+        mgb_assert(false, "reduce and dimshuffle can not coexist now");
+    }
+
+    if (fuse_dimshuffle) {
         m_feature_bits |= JITFeatureBits::DIMSHUFFLE;
     }
+    if (fuse_reduce) {
+        m_feature_bits |= JITFeatureBits::REDUCE;
+    }
 }

 const char* JITFusionPass::name() const {
@@ -39,7 +39,40 @@ class JITFusionPass final : public Pass {
     JITFeatureBits m_feature_bits;

 public:
-    JITFusionPass(bool after_grad = true, int8_t jit_opt_level = 1);
+    using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
+
+    /*
+     * Explanation of how graph_opt_level, jit_opt_level and jit_config
+     * control the behavior of JIT optimization:
+     *
+     * The design of this API is restricted by the historical burden of
+     * jit_opt_level and we have to support the old interface jit_opt_level
+     * and the new interface jit_config at the same time.
+     *
+     * How JITFusionPass decides its behavior:
+     * (1) When graph_opt_level is 3, it sets jit_opt_level to 1
+     * (2) When the user-defined jit_opt_level is greater than 1, it
+     *     overwrites the previous value of jit_opt_level
+     * (3) We get a default jit_config from jit_opt_level:
+     *     jit_opt_level = 0: JIT optimization OFF
+     *     jit_opt_level = 1: dimshuffle ON, reduce OFF
+     *     jit_opt_level = 2: dimshuffle OFF, reduce ON
+     * (4) The user-defined jit_config provides more precise control and
+     *     overwrites the default settings defined by jit_opt_level
+     *
+     * Situations in which JIT optimization is ON:
+     * (1) graph_opt_level = 3
+     * (2) graph_opt_level = 2, jit_opt_level > 0
+     * (3) graph_opt_level = 2, jit_opt_level = 0, jit_config is set
+     * (4) graph_opt_level = 0, jit_opt_level > 0 (deprecated usage)
+     *
+     * Situations in which JIT optimization is OFF:
+     * (1) graph_opt_level = 2, jit_opt_level = 0, jit_config is unset
+     * (2) graph_opt_level = 1
+     * (3) graph_opt_level = 0, jit_opt_level = 0
+     */
+    JITFusionPass(bool after_grad = true, int jit_opt_level = 0,
+                  const JITConfig& jit_config = {});

     const char* name() const override;
     void apply(OptState& opt) const override;
 };
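To make the decision table above concrete, here is a hypothetical pure-Python
mirror of steps (3) and (4) plus the coexistence check from the constructor in
fusion_pass.cpp (resolve_features and its names are illustrative, not part of
the commit):

    UNSET, OFF, ON = 0, 1, 2  # values per JITConfig in the header above

    def resolve_features(jit_opt_level, fuse_dimshuffle=UNSET, fuse_reduce=UNSET):
        # step (3): default config derived from jit_opt_level
        if jit_opt_level == 1:
            default_dimshuffle, default_reduce = ON, OFF
        elif jit_opt_level >= 2:
            default_dimshuffle, default_reduce = OFF, ON
        else:
            default_dimshuffle, default_reduce = UNSET, UNSET
        # step (4): user-set fields overwrite the defaults (JITConfig::update)
        if fuse_dimshuffle == UNSET:
            fuse_dimshuffle = default_dimshuffle
        if fuse_reduce == UNSET:
            fuse_reduce = default_reduce
        # reduce and dimshuffle can not coexist now (mgb_assert in the ctor)
        assert not (fuse_dimshuffle == ON and fuse_reduce == ON)
        return fuse_dimshuffle == ON, fuse_reduce == ON

    assert resolve_features(1) == (True, False)  # jit_opt_level=1: dimshuffle only
    assert resolve_features(2, fuse_reduce=OFF) == (False, False)  # user override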
@@ -27,6 +27,8 @@
 #include "megbrain/test/helper.h"
 #include "megbrain/opr/dnn/convolution.h"

+#include "../../core/impl/graph/cg_impl_seq.h"
+
 #if MGB_JIT
 using namespace mgb;
@@ -1455,6 +1457,122 @@ TEST(TestJITNvrtc, DimshuffleGrad) {
     }
 }

+TEST(TestJITNvrtc, JITConfig) {
+    using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
+    using CompSeq = cg::ComputingGraphImpl::ComputingSequence;
+    using ReduceMode = opr::Reduce::Param::Mode;
+    static const int UNSET = JITConfig::UNSET;
+    static const int OFF = JITConfig::OFF;
+    static const int ON = JITConfig::ON;
+
+    REQUIRE_GPU(1);
+    set_backend(Backend::NVRTC);
+    auto cn = CompNode::load("gpu0");
+    HostTensorGenerator<> gen;
+
+    auto run = [&](int graph_opt_level, int jit_opt_level,
+                   const JITConfig& jit_config, bool expect_dimshuffle_fused,
+                   bool expect_reduce_fused, bool expect_jit_enabled) {
+        auto cg = ComputingGraph::make();
+        cg->options().graph_opt_level = graph_opt_level;
+        cg->options().graph_opt.jit = jit_opt_level;
+        cg->options().graph_opt.jit_config = jit_config;
+
+        auto host_x = gen({2, 3, 4, 5}, cn);
+        auto x = opr::SharedDeviceTensor::make(*cg, *host_x);
+
+        // three types of operations to be fused by JIT
+        x = (2 * x + 3) * (3 * x - 1);                       // Elemwise
+        x = opr::Dimshuffle::make(x, {1, 2, 3, 0});          // Dimshuffle
+        x = opr::Reduce::make(x + 2, {ReduceMode::SUM, 2});  // Reduce
+
+        auto func = cg->compile({make_callback_copy(x + 1, *host_x)});
+        auto comp_seq = dynamic_cast<CompSeq*>(func.get());
+        ASSERT_TRUE(comp_seq != nullptr);
+
+        bool dimshuffle_found = false, reduce_found = false,
+             jit_executor_found = false;
+        auto on_opr = [&](cg::OperatorNodeBase* opr) {
+            if (opr->same_type<opr::Dimshuffle>()) {
+                dimshuffle_found = true;
+            } else if (opr->same_type<opr::Reduce>()) {
+                reduce_found = true;
+            } else if (opr->same_type<JITExecutor>()) {
+                jit_executor_found = true;
+            }
+            return true;
+        };
+        comp_seq->iter_opr_seq(on_opr);
+
+        ASSERT_EQ(expect_dimshuffle_fused, !dimshuffle_found);
+        ASSERT_EQ(expect_reduce_fused, !reduce_found);
+        ASSERT_EQ(expect_jit_enabled, jit_executor_found);
+    };
+
+    // graph_opt_level = 1, always OFF
+    for (int jit_opt_level : {0, 1, 2}) {
+        for (int fuse_dimshuffle : {UNSET, OFF, ON}) {
+            for (int fuse_reduce : {UNSET, OFF, ON}) {
+                run(1, jit_opt_level, JITConfig{fuse_dimshuffle, fuse_reduce},
+                    false, false, false);
+            }
+        }
+    }
+
+    // some test cases are commented because dimshuffle and reduce can not be
+    // fused at the same time
+
+    for (int graph_opt_level : {0, 2}) {
+        // jit_opt_level = 0, default = {OFF, OFF}
+        run(graph_opt_level, 0, JITConfig{UNSET, UNSET}, false, false, false);
+        run(graph_opt_level, 0, JITConfig{UNSET, OFF}, false, false, true);
+        run(graph_opt_level, 0, JITConfig{UNSET, ON}, false, true, true);
+        run(graph_opt_level, 0, JITConfig{OFF, UNSET}, false, false, true);
+        run(graph_opt_level, 0, JITConfig{OFF, OFF}, false, false, true);
+        run(graph_opt_level, 0, JITConfig{OFF, ON}, false, true, true);
+        run(graph_opt_level, 0, JITConfig{ON, UNSET}, true, false, true);
+        run(graph_opt_level, 0, JITConfig{ON, OFF}, true, false, true);
+        // run(graph_opt_level, 0, JITConfig{ON, ON}, true, true, true);
+    }
+
+    {
+        // graph_opt_level = 3, jit_opt_level = 0, default = {ON, OFF}
+        run(3, 0, JITConfig{UNSET, UNSET}, true, false, true);
+        run(3, 0, JITConfig{UNSET, OFF}, true, false, true);
+        // run(3, 0, JITConfig{UNSET, ON}, true, true, true);
+        run(3, 0, JITConfig{OFF, UNSET}, false, false, true);
+        run(3, 0, JITConfig{OFF, OFF}, false, false, true);
+        run(3, 0, JITConfig{OFF, ON}, false, true, true);
+        run(3, 0, JITConfig{ON, UNSET}, true, false, true);
+        run(3, 0, JITConfig{ON, OFF}, true, false, true);
+        // run(3, 0, JITConfig{ON, ON}, true, true, true);
+    }
+
+    for (int graph_opt_level : {0, 2, 3}) {
+        // jit_opt_level = 1, default = {ON, OFF}
+        run(graph_opt_level, 1, JITConfig{UNSET, UNSET}, true, false, true);
+        run(graph_opt_level, 1, JITConfig{UNSET, OFF}, true, false, true);
+        // run(graph_opt_level, 1, JITConfig{UNSET, ON}, true, true, true);
+        run(graph_opt_level, 1, JITConfig{OFF, UNSET}, false, false, true);
+        run(graph_opt_level, 1, JITConfig{OFF, OFF}, false, false, true);
+        run(graph_opt_level, 1, JITConfig{OFF, ON}, false, true, true);
+        run(graph_opt_level, 1, JITConfig{ON, UNSET}, true, false, true);
+        run(graph_opt_level, 1, JITConfig{ON, OFF}, true, false, true);
+        // run(graph_opt_level, 1, JITConfig{ON, ON}, true, true, true);
+
+        // jit_opt_level = 2, default = {OFF, ON}
+        run(graph_opt_level, 2, JITConfig{UNSET, UNSET}, false, true, true);
+        run(graph_opt_level, 2, JITConfig{UNSET, OFF}, false, false, true);
+        run(graph_opt_level, 2, JITConfig{UNSET, ON}, false, true, true);
+        run(graph_opt_level, 2, JITConfig{OFF, UNSET}, false, true, true);
+        run(graph_opt_level, 2, JITConfig{OFF, OFF}, false, false, true);
+        run(graph_opt_level, 2, JITConfig{OFF, ON}, false, true, true);
+        // run(graph_opt_level, 2, JITConfig{ON, UNSET}, true, true, true);
+        run(graph_opt_level, 2, JITConfig{ON, OFF}, true, false, true);
+        // run(graph_opt_level, 2, JITConfig{ON, ON}, true, true, true);
+    }
+}
+
 TEST(TestJITExecutor, GradBehavior) {
     REQUIRE_GPU(1);
     auto cn = CompNode::load("gpu0");