GitOrigin-RevId: 170d9eeab2
tags/v1.5.0
@@ -11,6 +11,7 @@ from ..core._imperative_rt.core2 import (
     set_cpp_apply_with_tracing,
 )
 from .dtr_config import DTRConfig
+from .graph_opt_config import GraphOptimizationConfig
 from .sublinear_memory_config import SublinearMemoryConfig
 from .tracing import (
     apply_const_with_tracing,
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+class GraphOptimizationConfig:
+    r"""
+    Configuration for graph optimization: False for OFF, True for ON. The default value
+    None means that opt_level will decide whether this optimization will be applied or not.
+
+    :param jit_fuse_dimshuffle: whether to fuse dimshuffle in JIT optimization
+    :param jit_fuse_reduce: whether to fuse reduce in JIT optimization
+    """
+
+    def __init__(self):
+        self.jit_fuse_dimshuffle = None
+        self.jit_fuse_reduce = None
+
+    def __repr__(self):
+        val2str = {None: "UNSET", False: "OFF", True: "ON"}
+        return (
+            "GraphOptimizationConfig {"
+            + " jit_fuse_dimshuffle = "
+            + val2str[self.jit_fuse_dimshuffle]
+            + ", jit_fuse_reduce = "
+            + val2str[self.jit_fuse_reduce]
+            + " }"
+        )
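Note for reviewers: a minimal usage sketch of the tri-state config added above (the repr strings follow val2str in __repr__; the import path is the one exported by megengine.jit in this patch):

    from megengine.jit import GraphOptimizationConfig

    config = GraphOptimizationConfig()
    print(config)
    # GraphOptimizationConfig { jit_fuse_dimshuffle = UNSET, jit_fuse_reduce = UNSET }
    config.jit_fuse_dimshuffle = True   # force dimshuffle fusion ON
    config.jit_fuse_reduce = False      # force reduce fusion OFF
    print(config)
    # GraphOptimizationConfig { jit_fuse_dimshuffle = ON, jit_fuse_reduce = OFF }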
@@ -38,6 +38,7 @@ from ..core.tensor import megbrain_graph as G
 from ..core.tensor.utils import setscalar
 from ..utils.naming import AutoNaming
 from .dtr_config import DTRConfig
+from .graph_opt_config import GraphOptimizationConfig
 from .sublinear_memory_config import SublinearMemoryConfig
@@ -129,6 +130,7 @@ class trace:
         If not None, it enables sublinear memory optimization with given setting.
     :param profiling: whether to profile compiled trace. Default: False
     :param opt_level: optimization level for compiling trace. Default: 2
+    :param graph_opt_config: configuration for graph optimization. Default: None
     :param symbolic_shape: whether to use symbolic shape for tracing. Default: True
     """
@@ -146,6 +148,7 @@ class trace:
         dtr_config: DTRConfig = None,
         profiling: bool = False,
         opt_level: int = 2,
+        graph_opt_config: GraphOptimizationConfig = None,
         symbolic_shape: bool = True,
     ):
         self.__wrapped__ = function
@@ -156,6 +159,7 @@ class trace:
         self._profiling = profiling
         self._profiler = None
         self._graph_opt_level = opt_level
+        self._graph_opt_config = graph_opt_config
         self._symbolic_shape = symbolic_shape
         self._output_handles = set()
@@ -502,7 +506,14 @@ class trace:
             graph.options.dtr_config.evictee_minimum_size = (
                 self._dtr_config.evictee_minimum_size
             )
+        # graph optimization
+        if self._graph_opt_config is not None:
+            mapping = {None: 0, False: 1, True: 2}
+            jit_config = graph.options.graph_opt.jit_config
+            jit_config.fuse_dimshuffle = mapping[
+                self._graph_opt_config.jit_fuse_dimshuffle
+            ]
+            jit_config.fuse_reduce = mapping[self._graph_opt_config.jit_fuse_reduce]
         # sublinear
         if self._sublinear_memory_config is not None:
             graph.options.enable_sublinear_memory_opt = True
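This is where the Python tri-state maps onto the C++ enum: None becomes 0 (UNSET), False becomes 1 (OFF), True becomes 2 (ON). A short usage sketch of the new trace parameter, adapted from the test added below:

    from megengine import tensor
    from megengine.jit import GraphOptimizationConfig, trace

    config = GraphOptimizationConfig()
    config.jit_fuse_reduce = True  # forwarded as 2 (ON) at compile time

    @trace(opt_level=2, graph_opt_config=config)
    def step(x):
        return (x + 1) * 2

    y = step(tensor(2.0))  # once compiled, graph.options.graph_opt.jit_config.fuse_reduce == 2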
@@ -421,12 +421,20 @@ void init_graph_rt(py::module m) {
 #undef CURRENT_CLASS
 #define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt
-    py::class_<cg::ComputingGraph::Options::GraphOpt>(PyComputingGraphOptions, "GraphOpt")
+    auto PyGraphOpt = py::class_<cg::ComputingGraph::Options::GraphOpt>(
+            PyComputingGraphOptions, "GraphOpt")
         DEF_READWRITE(jit)
+        DEF_READWRITE(jit_config)
         DEF_READWRITE(tensorrt);
+#undef CURRENT_CLASS
+#define CURRENT_CLASS cg::ComputingGraph::Options::GraphOpt::JITConfig
+    py::class_<cg::ComputingGraph::Options::GraphOpt::JITConfig>(PyGraphOpt, "JITConfig")
+        DEF_READWRITE(fuse_dimshuffle)
+        DEF_READWRITE(fuse_reduce);
 #undef CURRENT_CLASS
 #define CURRENT_CLASS cg::ComputingGraph::Options::SublinearMemConfig
     py::class_<cg::ComputingGraph::Options::SublinearMemConfig>(PyComputingGraphOptions, "SublinearMemConfig")
@@ -25,7 +25,7 @@ from megengine.core.ops import builtin as ops
 from megengine.core.ops.builtin import Elemwise
 from megengine.core.tensor.utils import isscalar
 from megengine.functional import exp, log
-from megengine.jit import exclude_from_trace, trace
+from megengine.jit import GraphOptimizationConfig, exclude_from_trace, trace
 from megengine.module import Module
 from megengine.random import normal, uniform
 from megengine.utils.naming import AutoNaming
@@ -605,3 +605,30 @@ def test_trace_advance_indexing(shape_mode):
     for _ in range(3):
         result_trace = f_traced(**params)
         np.testing.assert_equal(expected, result_trace.numpy())
+
+
+@pytest.mark.require_ngpu(1)  # nvrtc backend
+def test_trace_jit_config():
+    def run(fuse_dimshuffle, fuse_reduce):
+        config = GraphOptimizationConfig()
+        config.jit_fuse_dimshuffle = fuse_dimshuffle
+        config.jit_fuse_reduce = fuse_reduce
+
+        # set opt_level = 1 to avoid fusing dimshuffle and reduce at the same time
+        @trace(opt_level=1, graph_opt_config=config)
+        def func(x):
+            return x + 1
+
+        x = tensor(2)
+        y = func(x)
+        func._compile()
+
+        options = func._graph.options
+        mapping = {None: 0, False: 1, True: 2}
+        assert options.graph_opt.jit == 0
+        assert options.graph_opt.jit_config.fuse_dimshuffle == mapping[fuse_dimshuffle]
+        assert options.graph_opt.jit_config.fuse_reduce == mapping[fuse_reduce]
+
+    for fuse_dimshuffle in [None, False, True]:
+        for fuse_reduce in [None, False, True]:
+            run(fuse_dimshuffle, fuse_reduce)
@@ -145,6 +145,24 @@ size_t ComputingGraph::prealloc_static_storage(size_t size) {
 }
 #endif
+
+/* ========================== JITConfig ========================== */
+
+bool ComputingGraph::Options::GraphOpt::JITConfig::enabled() const {
+    if (fuse_dimshuffle != UNSET) return true;
+    if (fuse_reduce != UNSET) return true;
+    return false;
+}
+
+void ComputingGraph::Options::GraphOpt::JITConfig::update(
+        const JITConfig& modifier) {
+    if (modifier.fuse_dimshuffle != UNSET) {
+        this->fuse_dimshuffle = modifier.fuse_dimshuffle;
+    }
+    if (modifier.fuse_reduce != UNSET) {
+        this->fuse_reduce = modifier.fuse_reduce;
+    }
+}
+
 /* ========================== CallbackCaller ========================== */
 MGB_DEFINE_OPR_CLASS(ComputingGraphImpl::CallbackCaller,
                      SingleCNOperatorNodeBase) // {
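The merge rule in update() is what lets the old and new interfaces coexist: fields the modifier leaves UNSET keep their current value, everything else is overwritten. An illustrative Python paraphrase (not part of the patch; names are ad hoc):

    UNSET, OFF, ON = 0, 1, 2

    def update(target, modifier):
        # an explicitly set modifier field wins; UNSET keeps the target's value
        fuse_dimshuffle = modifier[0] if modifier[0] != UNSET else target[0]
        fuse_reduce = modifier[1] if modifier[1] != UNSET else target[1]
        return (fuse_dimshuffle, fuse_reduce)

    assert update((ON, OFF), (UNSET, ON)) == (ON, ON)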
@@ -538,12 +556,18 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
 #if MGB_JIT
-    if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) {
-        setenv("MGB_JIT_BACKEND","NVRTC",1);
+    if (std::abs(options().graph_opt_level) == 0 &&
+        (options().graph_opt.jit || options().graph_opt.jit_config.enabled())) {
+        // Deprecated usage added previously. It allows NVRTC JIT optimization
+        // when graph_opt_level is 0. This usage is not recommended any more.
+        mgb_log_warn(
+                "It is not recommended to enable JIT optimization when "
+                "graph_opt_level is 0.");
+        setenv("MGB_JIT_BACKEND", "NVRTC", 1);
         gopt::GraphOptimizer optimizer;
-        optimizer.add_pass<gopt::JITFusionPass>(
-                sopr_stat.has_virtual_grad,
-                std::max<uint8_t>(options().graph_opt.jit, 1));
+        optimizer.add_pass<gopt::JITFusionPass>(sopr_stat.has_virtual_grad,
+                                                options().graph_opt.jit,
+                                                options().graph_opt.jit_config);
         optimizer.apply_inplace(dest_vars);
     }
 #endif
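As the warning spells out, graph_opt_level 0 plus an enabled jit_config is a deprecated but still-working combination. A hedged Python illustration, assuming trace's opt_level feeds graph_opt_level as in the tracing.py hunk above:

    from megengine.jit import GraphOptimizationConfig, trace

    config = GraphOptimizationConfig()
    config.jit_fuse_reduce = True

    # deprecated: opt_level = 0 would normally disable graph optimization, but an
    # enabled jit_config still switches the NVRTC JIT fusion pass on (with a warning)
    @trace(opt_level=0, graph_opt_config=config)
    def step(x):
        return x + 1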
@@ -338,6 +338,20 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
             //! this value indicates JIT level: 1 for basic elemwise opr; 2
             //! for including reduce oprs
             uint8_t jit = 0;
+
+            //! jit configurations
+            struct JITConfig {
+                static const int UNSET = 0;
+                static const int OFF = 1;
+                static const int ON = 2;
+
+                int fuse_dimshuffle = UNSET;
+                int fuse_reduce = UNSET;
+
+                bool enabled() const;
+                void update(const JITConfig& modifier);
+            } jit_config;
+
             //! whether to enable fine-grained TensorRT opr replace
             bool tensorrt = false;
         } graph_opt;
@@ -645,11 +645,21 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
     add_pass<RemoveRedundantCopyPass>();
 #if MGB_JIT
-    bool need_jit = false;
-    if (comp_graph_opt && (std::abs(comp_graph_opt->graph_opt_level) >= 3 ||
-                           comp_graph_opt->graph_opt.jit)) {
-        need_jit = true;
-    }
+    using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
+    int jit_opt_level = 0;
+    JITConfig jit_config;
+    // for more detail on what is happening here, see comments on the
+    // constructor of class JITFusionPass in fusion_pass.h
+    if (comp_graph_opt) {
+        jit_opt_level = comp_graph_opt->graph_opt.jit;
+        if (comp_graph_opt->graph_opt_level >= 3) {
+            jit_opt_level = std::max(jit_opt_level, 1);
+        }
+        jit_config = comp_graph_opt->graph_opt.jit_config;
+    }
+    bool need_jit = (jit_opt_level > 0) || jit_config.enabled();
     if (need_jit && after_grad) {
         add_pass<gopt::RecompTypeCvtPass>();
     }
@@ -662,9 +672,7 @@ GraphOptimizer& GraphOptimizer::add_preset_passes(
 #if MGB_JIT
     if (need_jit) {
-        add_pass<gopt::JITFusionPass>(
-                after_grad,
-                std::max<uint8_t>(comp_graph_opt->graph_opt.jit, 1));
+        add_pass<gopt::JITFusionPass>(after_grad, jit_opt_level, jit_config);
     }
 #endif
@@ -428,14 +428,33 @@ bool JITFusionPass::Impl::can_be_fused(cg::OperatorNodeBase* opr) const {
     return false;
 }
-JITFusionPass::JITFusionPass(bool after_grad, int8_t jit_opt_level)
+JITFusionPass::JITFusionPass(bool after_grad, int jit_opt_level,
+                             const JITConfig& jit_config)
         : m_after_grad{after_grad}, m_feature_bits{JITFeatureBits::NONE} {
-    // TODO reduce and dimshuffle can not coexsit now.
-    if (jit_opt_level >= 2) {
-        m_feature_bits |= JITFeatureBits::REDUCE;
-    } else {
-        m_feature_bits |= JITFeatureBits::DIMSHUFFLE;
-    }
+    // get default config from jit_opt_level
+    JITConfig config;
+    if (jit_opt_level == 1) {
+        config.fuse_dimshuffle = JITConfig::ON;
+        config.fuse_reduce = JITConfig::OFF;
+    } else if (jit_opt_level >= 2) {
+        config.fuse_dimshuffle = JITConfig::OFF;
+        config.fuse_reduce = JITConfig::ON;
+    }
+
+    // overwrite default config with custom settings
+    config.update(jit_config);
+    bool fuse_dimshuffle = config.fuse_dimshuffle == JITConfig::ON;
+    bool fuse_reduce = config.fuse_reduce == JITConfig::ON;
+
+    if (fuse_dimshuffle && fuse_reduce) {
+        mgb_assert(false, "reduce and dimshuffle can not coexist now");
+    }
+    if (fuse_dimshuffle) {
+        m_feature_bits |= JITFeatureBits::DIMSHUFFLE;
+    }
+    if (fuse_reduce) {
+        m_feature_bits |= JITFeatureBits::REDUCE;
+    }
 }

 const char* JITFusionPass::name() const {
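The constructor now encodes the whole decision table: the legacy jit_opt_level yields a default config, which an explicit jit_config then overrides field by field. A compact Python paraphrase of that derivation (illustrative only, not part of the patch):

    UNSET, OFF, ON = 0, 1, 2

    def effective_config(jit_opt_level, jit_config):
        # defaults derived from the legacy jit_opt_level
        if jit_opt_level == 1:
            default = (ON, OFF)      # fuse dimshuffle, not reduce
        elif jit_opt_level >= 2:
            default = (OFF, ON)      # fuse reduce, not dimshuffle
        else:
            default = (UNSET, UNSET)
        # explicit jit_config fields override the defaults
        fuse_dimshuffle = jit_config[0] if jit_config[0] != UNSET else default[0]
        fuse_reduce = jit_config[1] if jit_config[1] != UNSET else default[1]
        assert not (fuse_dimshuffle == ON and fuse_reduce == ON), \
            "reduce and dimshuffle can not coexist now"
        return fuse_dimshuffle, fuse_reduce

    assert effective_config(2, (UNSET, OFF)) == (OFF, OFF)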
@@ -39,7 +39,40 @@ class JITFusionPass final : public Pass {
     JITFeatureBits m_feature_bits;

 public:
-    JITFusionPass(bool after_grad = true, int8_t jit_opt_level = 1);
+    using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
+
+    /*
+     * Explanation of how graph_opt_level, jit_opt_level and jit_config
+     * control the behavior of JIT optimization:
+     *
+     * The design of this API is restricted by the historical burden of
+     * jit_opt_level: we have to support the old interface jit_opt_level and
+     * the new interface jit_config at the same time.
+     *
+     * How JITFusionPass decides its behavior:
+     * (1) When graph_opt_level is 3, it sets jit_opt_level to 1
+     * (2) When the user-defined jit_opt_level is greater than 1, it overwrites
+     *     the previous value of jit_opt_level
+     * (3) We get a default jit_config from jit_opt_level:
+     *     jit_opt_level = 0: JIT optimization OFF
+     *     jit_opt_level = 1: dimshuffle ON, reduce OFF
+     *     jit_opt_level = 2: dimshuffle OFF, reduce ON
+     * (4) The user-defined jit_config provides more precise control and
+     *     overwrites the default settings defined by jit_opt_level
+     *
+     * Situations in which JIT optimization is ON:
+     * (1) graph_opt_level = 3
+     * (2) graph_opt_level = 2, jit_opt_level > 0
+     * (3) graph_opt_level = 2, jit_opt_level = 0, jit_config is set
+     * (4) graph_opt_level = 0, jit_opt_level > 0 (deprecated usage)
+     *
+     * Situations in which JIT optimization is OFF:
+     * (1) graph_opt_level = 2, jit_opt_level = 0, jit_config is unset
+     * (2) graph_opt_level = 1
+     * (3) graph_opt_level = 0, jit_opt_level = 0
+     */
+    JITFusionPass(bool after_grad = true, int jit_opt_level = 0,
+                  const JITConfig& jit_config = {});
+
     const char* name() const override;
     void apply(OptState& opt) const override;
 };
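Condensed, the ON/OFF rules above amount to a small predicate. A Python paraphrase (the graph_opt_level = 0 with jit_config set case is omitted in the comment but enabled per the compile_prepare hunk and exercised by the test below):

    def jit_enabled(graph_opt_level, jit_opt_level, jit_config_is_set):
        if graph_opt_level == 1:
            return False                     # always OFF
        if graph_opt_level >= 3:
            return True                      # always ON
        # graph_opt_level 0 or 2: the legacy level or an explicit config enables JIT
        return jit_opt_level > 0 or jit_config_is_set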
@@ -27,6 +27,8 @@
 #include "megbrain/test/helper.h"
 #include "megbrain/opr/dnn/convolution.h"
+#include "../../core/impl/graph/cg_impl_seq.h"

 #if MGB_JIT
 using namespace mgb;
@@ -1455,6 +1457,122 @@ TEST(TestJITNvrtc, DimshuffleGrad) {
     }
 }
+
+TEST(TestJITNvrtc, JITConfig) {
+    using JITConfig = cg::ComputingGraph::Options::GraphOpt::JITConfig;
+    using CompSeq = cg::ComputingGraphImpl::ComputingSequence;
+    using ReduceMode = opr::Reduce::Param::Mode;
+
+    static const int UNSET = JITConfig::UNSET;
+    static const int OFF = JITConfig::OFF;
+    static const int ON = JITConfig::ON;
+
+    REQUIRE_GPU(1);
+    set_backend(Backend::NVRTC);
+    auto cn = CompNode::load("gpu0");
+    HostTensorGenerator<> gen;
+
+    auto run = [&](int graph_opt_level, int jit_opt_level,
+                   const JITConfig& jit_config, bool expect_dimshuffle_fused,
+                   bool expect_reduce_fused, bool expect_jit_enabled) {
+        auto cg = ComputingGraph::make();
+        cg->options().graph_opt_level = graph_opt_level;
+        cg->options().graph_opt.jit = jit_opt_level;
+        cg->options().graph_opt.jit_config = jit_config;
+
+        auto host_x = gen({2, 3, 4, 5}, cn);
+        auto x = opr::SharedDeviceTensor::make(*cg, *host_x);
+
+        // three types of operations to be fused by JIT
+        x = (2 * x + 3) * (3 * x - 1);                      // Elemwise
+        x = opr::Dimshuffle::make(x, {1, 2, 3, 0});         // Dimshuffle
+        x = opr::Reduce::make(x + 2, {ReduceMode::SUM, 2}); // Reduce
+
+        auto func = cg->compile({make_callback_copy(x + 1, *host_x)});
+        auto comp_seq = dynamic_cast<CompSeq*>(func.get());
+        ASSERT_TRUE(comp_seq != nullptr);
+
+        bool dimshuffle_found = false, reduce_found = false,
+             jit_executor_found = false;
+        auto on_opr = [&](cg::OperatorNodeBase* opr) {
+            if (opr->same_type<opr::Dimshuffle>()) {
+                dimshuffle_found = true;
+            } else if (opr->same_type<opr::Reduce>()) {
+                reduce_found = true;
+            } else if (opr->same_type<JITExecutor>()) {
+                jit_executor_found = true;
+            }
+            return true;
+        };
+        comp_seq->iter_opr_seq(on_opr);
+        ASSERT_EQ(expect_dimshuffle_fused, !dimshuffle_found);
+        ASSERT_EQ(expect_reduce_fused, !reduce_found);
+        ASSERT_EQ(expect_jit_enabled, jit_executor_found);
+    };
+
+    // graph_opt_level = 1, always OFF
+    for (int jit_opt_level : {0, 1, 2}) {
+        for (int fuse_dimshuffle : {UNSET, OFF, ON}) {
+            for (int fuse_reduce : {UNSET, OFF, ON}) {
+                run(1, jit_opt_level, JITConfig{fuse_dimshuffle, fuse_reduce},
+                    false, false, false);
+            }
+        }
+    }
+
+    // some test cases are commented because dimshuffle and reduce can not be
+    // fused at the same time
+    for (int graph_opt_level : {0, 2}) {
+        // jit_opt_level = 0, default = {OFF, OFF}
+        run(graph_opt_level, 0, JITConfig{UNSET, UNSET}, false, false, false);
+        run(graph_opt_level, 0, JITConfig{UNSET, OFF}, false, false, true);
+        run(graph_opt_level, 0, JITConfig{UNSET, ON}, false, true, true);
+        run(graph_opt_level, 0, JITConfig{OFF, UNSET}, false, false, true);
+        run(graph_opt_level, 0, JITConfig{OFF, OFF}, false, false, true);
+        run(graph_opt_level, 0, JITConfig{OFF, ON}, false, true, true);
+        run(graph_opt_level, 0, JITConfig{ON, UNSET}, true, false, true);
+        run(graph_opt_level, 0, JITConfig{ON, OFF}, true, false, true);
+        // run(graph_opt_level, 0, JITConfig{ON, ON}, true, true, true);
+    }
+
+    {
+        // graph_opt_level = 3, jit_opt_level = 0, default = {ON, OFF}
+        run(3, 0, JITConfig{UNSET, UNSET}, true, false, true);
+        run(3, 0, JITConfig{UNSET, OFF}, true, false, true);
+        // run(3, 0, JITConfig{UNSET, ON}, true, true, true);
+        run(3, 0, JITConfig{OFF, UNSET}, false, false, true);
+        run(3, 0, JITConfig{OFF, OFF}, false, false, true);
+        run(3, 0, JITConfig{OFF, ON}, false, true, true);
+        run(3, 0, JITConfig{ON, UNSET}, true, false, true);
+        run(3, 0, JITConfig{ON, OFF}, true, false, true);
+        // run(3, 0, JITConfig{ON, ON}, true, true, true);
+    }
+
+    for (int graph_opt_level : {0, 2, 3}) {
+        // jit_opt_level = 1, default = {ON, OFF}
+        run(graph_opt_level, 1, JITConfig{UNSET, UNSET}, true, false, true);
+        run(graph_opt_level, 1, JITConfig{UNSET, OFF}, true, false, true);
+        // run(graph_opt_level, 1, JITConfig{UNSET, ON}, true, true, true);
+        run(graph_opt_level, 1, JITConfig{OFF, UNSET}, false, false, true);
+        run(graph_opt_level, 1, JITConfig{OFF, OFF}, false, false, true);
+        run(graph_opt_level, 1, JITConfig{OFF, ON}, false, true, true);
+        run(graph_opt_level, 1, JITConfig{ON, UNSET}, true, false, true);
+        run(graph_opt_level, 1, JITConfig{ON, OFF}, true, false, true);
+        // run(graph_opt_level, 1, JITConfig{ON, ON}, true, true, true);
+
+        // jit_opt_level = 2, default = {OFF, ON}
+        run(graph_opt_level, 2, JITConfig{UNSET, UNSET}, false, true, true);
+        run(graph_opt_level, 2, JITConfig{UNSET, OFF}, false, false, true);
+        run(graph_opt_level, 2, JITConfig{UNSET, ON}, false, true, true);
+        run(graph_opt_level, 2, JITConfig{OFF, UNSET}, false, true, true);
+        run(graph_opt_level, 2, JITConfig{OFF, OFF}, false, false, true);
+        run(graph_opt_level, 2, JITConfig{OFF, ON}, false, true, true);
+        // run(graph_opt_level, 2, JITConfig{ON, UNSET}, true, true, true);
+        run(graph_opt_level, 2, JITConfig{ON, OFF}, true, false, true);
+        // run(graph_opt_level, 2, JITConfig{ON, ON}, true, true, true);
+    }
+}
+
 TEST(TestJITExecutor, GradBehavior) {
     REQUIRE_GPU(1);
     auto cn = CompNode::load("gpu0");