GitOrigin-RevId: 7a47f5d0d5
@@ -18,6 +18,7 @@ import megengine._internal as mgb
 from megengine._internal.plugin import CompGraphProfiler

 from ..core import Tensor, graph, tensor
+from .sublinear_memory_config import SublinearMemConfig


 def sideeffect(f):
@@ -78,10 +79,12 @@ class trace:
     * accelerated evaluation via :meth:`.__call__`

     :param func: Positional only argument.
-    :param symbolic: Whether to use symbolic tensor.
+    :param symbolic: Whether to use symbolic tensor. Default: False
     :param opt_level: Optimization level for compiling trace.
     :param log_level: Log level.
-    :param profiling: Whether to profile compiled trace.
+    :param enable_sublinear: Enable sublinear memory optimization. Default: False
+    :param sublinear_mem_config: Configuration for sublinear memory optimization.
+    :param profiling: Whether to profile compiled trace. Default: False
     """

     _active_instance = None
@@ -103,12 +106,16 @@ class trace:
         symbolic: bool = False,
         opt_level: int = None,
         log_level: int = None,
+        enable_sublinear: bool = False,
+        sublinear_mem_config: SublinearMemConfig = None,
         profiling: bool = False
     ):
         self.__wrapped__ = func
         self._symbolic = symbolic
         self._graph_opt_level = opt_level
         self._log_level = log_level
+        self._enable_sublinear = enable_sublinear
+        self._sublinear_mem_config = sublinear_mem_config
         self._status = self._UNSTARTED
         self._args = None
         self._kwargs = None
@@ -280,11 +287,35 @@ class trace:
     def _apply_graph_options(self, cg):
         # graph opt level
-        if not self._graph_opt_level is None:
+        if not (self._graph_opt_level is None):
             cg.set_option("graph_opt_level", self._graph_opt_level)
         # log level
-        if not self._log_level is None:
+        if not (self._log_level is None):
             cg.set_option("log_level", self._log_level)
+        # sublinear
+        if self._enable_sublinear:
+            cg.set_option("enable_sublinear_memory_opt", True)
+            if not (self._sublinear_mem_config is None):
+                cg.set_option(
+                    "sublinear_mem_cofig.lb_memory",
+                    self._sublinear_mem_config.lb_memory,
+                )
+                cg.set_option(
+                    "sublinear_mem_cofig.genetic_nr_iter",
+                    self._sublinear_mem_config.genetic_nr_iter,
+                )
+                cg.set_option(
+                    "sublinear_mem_cofig.genetic_pool_size",
+                    self._sublinear_mem_config.genetic_pool_size,
+                )
+                cg.set_option(
+                    "sublinear_mem_cofig.thresh_nr_try",
+                    self._sublinear_mem_config.thresh_nr_try,
+                )
+                cg.set_option(
+                    "sublinear_mem_cofig.num_worker",
+                    self._sublinear_mem_config.num_worker,
+                )
         # profile
         if self._profiling:
             self._profiler = CompGraphProfiler(cg)
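For reference, a minimal sketch of how the two new trace parameters fit together (this mirrors test_sublinear further down in this diff; the decorated function and its input are placeholders):

    from megengine import jit
    from megengine.jit import SublinearMemConfig

    config = SublinearMemConfig(genetic_nr_iter=10)

    # enable_sublinear turns the optimization on; the config object is optional
    @jit.trace(symbolic=True, enable_sublinear=True, sublinear_mem_config=config)
    def f(x):
        return x + x

    f([0.0])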
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ..core.device import get_device_count
+
+
+class SublinearMemConfig:
+    r"""
+    Configuration for sublinear memory optimization.
+
+    :param thresh_nr_try: number of samples to take, both when searching the linear space
+        and around the current thresh, in sublinear memory optimization. Default: 10.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY'.
+    :param genetic_nr_iter: number of iterations of the genetic algorithm used to find the
+        best checkpoints. Default: 0.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER'.
+    :param genetic_pool_size: number of samples for the crossover random selection
+        during genetic optimization. Default: 20.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE'.
+    :param lb_memory: lower bound of the bottleneck size in MB for sublinear memory optimization;
+        can be used to trade memory against speed manually. Default: 0.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB'.
+    :param num_worker: number of worker threads used to search for the optimum checkpoints
+        in sublinear memory optimization. Default: half of the CPU count of the system.
+        It can also be set through the environment variable 'MGB_SUBLINEAR_MEMORY_WORKERS'.
+    """
+
+    def __init__(
+        self,
+        thresh_nr_try: int = 10,
+        genetic_nr_iter: int = 0,
+        genetic_pool_size: int = 20,
+        lb_memory: int = 0,
+        num_worker: int = get_device_count("cpu") // 2,  # floor division; "/" would yield a float
+    ):
+        self.thresh_nr_try = thresh_nr_try
+        self.genetic_nr_iter = genetic_nr_iter
+        self.genetic_pool_size = genetic_pool_size
+        self.lb_memory = lb_memory
+        self.num_worker = num_worker
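A sketch of constructing the config with non-default values; every field is optional, and anything left unset keeps the defaults documented above (the particular numbers here are hypothetical, for illustration only):

    config = SublinearMemConfig(
        thresh_nr_try=20,    # sample more thresholds during the linear search
        genetic_nr_iter=50,  # 0 disables the genetic search, so turn it on
        lb_memory=1024,      # manual memory/speed tradeoff: bottleneck lower bound in MB
    )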
@@ -42,7 +42,8 @@ bool _config::set_comp_graph_option(
             std::is_same<decltype(opt.name_chk), bool>::value ||     \
             std::is_same<decltype(opt.name_chk), uint8_t>::value ||  \
             std::is_same<decltype(opt.name_chk), int16_t>::value ||  \
-            std::is_same<decltype(opt.name_chk), uint16_t>::value,   \
+            std::is_same<decltype(opt.name_chk), uint16_t>::value || \
+            std::is_same<decltype(opt.name_chk), int32_t>::value,    \
             "not bool/int opt");                                     \
         if (name == #name_chk) {                                     \
             auto ret = opt.name_chk;                                 \
@@ -66,6 +67,11 @@ bool _config::set_comp_graph_option(
         SET_CG_OPTION(allocate_static_mem_after_graph_compile);
         SET_CG_OPTION(log_level);
         SET_CG_OPTION(enable_sublinear_memory_opt);
+        SET_CG_OPTION(sublinear_mem_cofig.lb_memory);
+        SET_CG_OPTION(sublinear_mem_cofig.genetic_nr_iter);
+        SET_CG_OPTION(sublinear_mem_cofig.genetic_pool_size);
+        SET_CG_OPTION(sublinear_mem_cofig.thresh_nr_try);
+        SET_CG_OPTION(sublinear_mem_cofig.num_worker);
         SET_CG_OPTION(enable_var_mem_defragment);
         SET_CG_OPTION(eager_evaluation);
         SET_CG_OPTION(enable_memory_swap);
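These SET_CG_OPTION entries are what make the dotted option names resolve at the binding layer; from Python, the equivalent direct calls would look like this sketch (cg being a computing graph handle, exactly as _apply_graph_options above uses it):

    cg.set_option("enable_sublinear_memory_opt", True)
    cg.set_option("sublinear_mem_cofig.genetic_nr_iter", 50)
    cg.set_option("sublinear_mem_cofig.lb_memory", 1024)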
@@ -17,6 +17,7 @@ import megengine as mge
 import megengine.functional as F
 from megengine import jit, tensor
 from megengine.functional.debug_param import set_conv_execution_strategy
+from megengine.jit import SublinearMemConfig
 from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module
 from megengine.optimizer import SGD
 from megengine.test import assertTensorClose
@@ -130,7 +131,14 @@ def update_model(model_path):
     mge.save(checkpoint, model_path)


-def run_test(model_path, use_jit, use_symbolic):
+def run_test(
+    model_path,
+    use_jit,
+    use_symbolic,
+    enable_sublinear=False,
+    sublinear_mem_config=None,
+    max_err=None,
+):
     """
     Load the model with test cases and run the training for one iter.
@@ -152,11 +160,17 @@ def run_test(model_path, use_jit, use_symbolic):
     data.set_value(checkpoint["data"])
     label.set_value(checkpoint["label"])

-    max_err = 1e-5
+    if max_err is None:
+        max_err = 1e-5

     train_func = train
     if use_jit:
-        train_func = jit.trace(train_func, symbolic=use_symbolic)
+        train_func = jit.trace(
+            train_func,
+            symbolic=use_symbolic,
+            enable_sublinear=enable_sublinear,
+            sublinear_mem_config=sublinear_mem_config,
+        )

     opt.zero_grad()
     loss = train_func(data, label, net=net, opt=opt)
@@ -183,3 +197,14 @@ def test_correctness():
     run_test(model_path, False, False)
     run_test(model_path, True, False)
     run_test(model_path, True, True)
+
+    # sublinear
+    config = SublinearMemConfig(genetic_nr_iter=10)
+    run_test(
+        model_path,
+        True,
+        True,
+        enable_sublinear=True,
+        sublinear_mem_config=config,
+        max_err=1e-5,
+    )
@@ -18,6 +18,7 @@ import megengine._internal as mgb
 import megengine.module as M
 from megengine import jit, tensor
 from megengine.core.tensor import Tensor
+from megengine.jit import SublinearMemConfig
 from megengine.test import assertTensorClose
@@ -185,3 +186,14 @@ def test_dump_bn_fused():
         mgb.cgtools.get_type(inputs[0]) == "MultipleDeviceTensorHolder"
         and mgb.cgtools.get_type(inputs[1]) == "ConvolutionForward"
     )
+
+
+# Simply verify that the options are passed down
+def test_sublinear():
+    config = SublinearMemConfig(genetic_nr_iter=10)
+
+    @jit.trace(symbolic=True, enable_sublinear=True, sublinear_mem_config=config)
+    def f(x):
+        return x + x
+
+    f([0.0])
@@ -217,7 +217,8 @@ ComputingGraphImpl::Components::Components(ComputingGraphImpl* owner)
           static_infer_comp_seq_manager{owner},
           grad_manager{owner},
 #if MGB_ENABLE_SUBLINEAR
-          seq_modifier_for_sublinear_memory{owner},
+          seq_modifier_for_sublinear_memory{owner,
+                                            &(owner->options().sublinear_mem_cofig)},
 #endif
 #if MGB_ENABLE_MEMORY_SWAP
           memory_swap_support{owner},
@@ -681,14 +681,6 @@ class SeqModifierForSublinearMemory::ActionSearcherSingleCN {
     std::vector<std::future<void>> m_futures;
     std::mutex m_mtx;

-    struct Config {
-        size_t thresh_nr_try = 10;
-        size_t genetic_nr_iter = 0;
-        size_t genetic_pool_size = 20;
-        double lb_memory = 0;
-    };
-    Config m_config;
-
     /*!
      * \brief check given thresh, and update states
      * \return bottleneck value for given thresh
@@ -725,20 +717,22 @@ class SeqModifierForSublinearMemory::ActionSearcherSingleCN {
 public:
     ActionSearcherSingleCN(SeqModifierForSublinearMemory* par)
             : m_par_modifier{par} {
+        auto& m_config = m_par_modifier->m_config;
+        //! allow environment variables to override the settings
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_THRESH_NR_TRY")) {
-            m_config.thresh_nr_try = std::stoi(env);
+            m_config->thresh_nr_try = std::stoi(env);
         }
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER")) {
-            m_config.genetic_nr_iter = std::stoi(env);
+            m_config->genetic_nr_iter = std::stoi(env);
         }
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_GENETIC_POOL_SIZE")) {
             auto psize = static_cast<size_t>(std::stoi(env));
-            mgb_assert(psize > 0 || m_config.genetic_nr_iter == 0,
+            mgb_assert(psize > 0 || m_config->genetic_nr_iter == 0,
                        "invalid pool size %zu in genetic algorithm,", psize);
-            m_config.genetic_pool_size = psize;
+            m_config->genetic_pool_size = psize;
         }
         if (auto env = MGB_GETENV("MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB")) {
-            m_config.lb_memory = std::stod(env) * 1024 * 1024;
+            m_config->lb_memory = std::stoi(env) * 1024 * 1024;
         }
     }
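Since these MGB_GETENV lookups run when the searcher is constructed, the same knobs can be overridden from the environment without touching SublinearMemConfig at all; a sketch, assuming the variables are set before the traced function is first compiled in the same process:

    import os

    # takes precedence over the corresponding config fields
    os.environ["MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER"] = "50"
    os.environ["MGB_SUBLINEAR_MEMORY_LOWER_BOUND_MB"] = "1024"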
@@ -812,7 +806,7 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_preset() {
         invoke_search(thresh);
     }

-    size_t NR_TRY = m_config.thresh_nr_try;
+    size_t NR_TRY = m_par_modifier->m_config->thresh_nr_try;

     // search in linear space
     auto step = init_thresh / (NR_TRY + 1);
@@ -833,8 +827,8 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_preset() {
 void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_genetic() {
     RNGxorshf rng(2333);

-    size_t POOL_SIZE = m_config.genetic_pool_size;
-    size_t NR_ITER = m_config.genetic_nr_iter;
+    size_t POOL_SIZE = m_par_modifier->m_config->genetic_pool_size;
+    size_t NR_ITER = m_par_modifier->m_config->genetic_nr_iter;

     auto mutation = [&](const SplitPointSet& sps) {
         auto s = *sps;
         size_t length = s.size();
@@ -953,7 +947,7 @@ void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_genetic() {
 }

 void SeqModifierForSublinearMemory::ActionSearcherSingleCN::search_refine() {
-    size_t lower_bound = m_config.lb_memory;
+    size_t lower_bound = m_par_modifier->m_config->lb_memory;
     if (m_min_bottleneck >= lower_bound)
         return;
     OprFootprint footprint;
@@ -1052,7 +1046,7 @@ SeqModifierForSublinearMemory::ActionSearcherSingleCN::search(
     msg.push_back('\n');
     msg.append(ssprintf("m_min_bottleneck: %-10.2f\n",
                         m_min_bottleneck * SIZE2MB));
-    if (!m_config.genetic_nr_iter) {
+    if (!m_par_modifier->m_config->genetic_nr_iter) {
         msg.append(ssprintf(
                 "\nGenetic algorithm is currently DISABLED, "
                 "set MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER [default = 0]"
@@ -1124,7 +1118,7 @@ SeqModifierForSublinearMemory::search_action(
                    "invalid planner concurrency: %zu", set);
         planner_concur = set;
     } else {
-        planner_concur = sys::get_cpu_count() / 2;
+        planner_concur = m_config->num_worker;
     }

     mgb_log_debug("use %zu threads to search for sublinear memory plan; "
@@ -1350,8 +1344,8 @@ SeqModifierForSublinearMemory::prev_min_bottleneck() {
 }

 SeqModifierForSublinearMemory::SeqModifierForSublinearMemory(
-        ComputingGraphImpl* owner)
-        : m_mem_opt(owner), m_owner_graph(owner) {}
+        ComputingGraphImpl* owner, Config* config_p)
+        : m_config(config_p), m_mem_opt(owner), m_owner_graph(owner) {}

 #endif  // !MGB_ENABLE_SUBLINEAR
@@ -12,6 +12,7 @@
 #pragma once

 #include "./memory_optimizer.h"
+#include "megbrain/graph/cg.h"
 #include "megbrain/utils/async_worker.h"

 #if MGB_ENABLE_SUBLINEAR
@@ -31,6 +32,10 @@ class SeqModifierForSublinearMemory {
     using SeqModifyAction = std::unordered_map<OperatorNodeBase*, OprNodeArray>;
     using SplitPointSet = std::shared_ptr<std::vector<size_t>>;

+    //! Config options
+    using Config = mgb::cg::ComputingGraph::Options::SublinearMemConfig;
+    Config* m_config;
+
     //! get modifications to be taken under some specific constraints
     class ModifyActionPlanner;
@@ -104,7 +109,7 @@ class SeqModifierForSublinearMemory {
     }

 public:
-    SeqModifierForSublinearMemory(ComputingGraphImpl* owner);
+    SeqModifierForSublinearMemory(ComputingGraphImpl* owner, Config* config_g);

     //! see memory_optimizer set_priority_before_opt
     void set_priority_before_opt(const VarNodeArray& endpoints) {
@@ -16,6 +16,7 @@
 #include "megbrain/graph/static_infer.h"
 #include "megbrain/graph/seq_comp_node_opt.h"
 #include "megbrain/utils/event.h"
+#include "megbrain/system.h"

 #if MGB_ENABLE_JSON
 #include "megbrain/utils/json.h"
@@ -300,6 +301,15 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
             //! whether to enable sublinear memory optimization
             bool enable_sublinear_memory_opt = false;

+            //! control parameters for sublinear memory optimization
+            struct SublinearMemConfig {
+                int thresh_nr_try = 10;
+                int genetic_nr_iter = 0;
+                int genetic_pool_size = 20;
+                int lb_memory = 0;
+                int num_worker = sys::get_cpu_count() / 2;
+            } sublinear_mem_cofig;
+
             //! do not re-profile to select best impl algo when input shape
             //! changes (use previous algo)
             bool no_profiling_on_shape_change = false;
@@ -504,57 +504,47 @@ TEST(TestSublinearMemory, DepsInTopoSort) {
 }

 TEST(TestSublinearMemory, BadOpr) {
-    constexpr const char* KEY = "MGB_SUBLINEAR_MEMORY_GENETIC_NR_ITER";
-    auto old_value = getenv(KEY);
-    setenv(KEY, "50", 1);
-    MGB_TRY {
-        HostTensorGenerator<> gen;
-        auto cn = CompNode::load("xpu0");
-        constexpr size_t N = 1024, Scale = 2;
-        auto host_x = gen({N}, cn);
-        for (bool bad : {false, true}) {
-            auto graph = ComputingGraph::make();
-            auto x = opr::Host2DeviceCopy::make_no_fwd(*graph, host_x),
-                 bad_var = SublinearBadOpr::make(x, bad, Scale),
-                 y0 = opr::reduce_sum(bad_var, x.make_scalar_dt(1)),
-                 y1 = SublinearBadOpr::make(y0, false, N * Scale),
-                 y = y1 + 1,
-                 z = opr::reduce_max(bad_var, x.make_scalar_dt(1));
-            set_priority(y0, 0);
-            set_priority(y1, 1);
-            set_priority(y, 2);
-            set_priority(z, 3);
-            graph->options().graph_opt_level = 0;
-            graph->options().enable_sublinear_memory_opt = 1;
-            auto func = graph->compile({{y, {}}, {z, {}}});
-            auto&& results = static_cast<cg::ComputingGraphImpl*>(graph.get())
-                ->seq_modifier_for_sublinear_memory().prev_min_bottleneck();
-            // bottleneck:
-            //   if bad : y = y1 + 1, bad_var should be saved to calculate
-            //       z later, total memory usage is
-            //       N * Scale * 2 (bad_var and y1) + 1 (immutable tensor 1)
-            //   else : bad_var = BadOpr(x), total memory usage is
-            //       N (x) + N * Scale (bad_var), bad_var would be recomputed
-            //       when calculating z = reduce(bad_var)
-            size_t expect = bad ? N * Scale * 2 + 1 : N * Scale + N;
-            ASSERT_EQ(results.at(cn), expect * host_x->dtype().size());
-            size_t nr_bad_opr = 0;
-            auto count_up = [&nr_bad_opr](cg::OperatorNodeBase* op) {
-                if (op->dyn_typeinfo() == SublinearBadOpr::typeinfo()) {
-                    ++nr_bad_opr;
-                }
-                return true;
-            };
-            func->iter_opr_seq(count_up);
-            ASSERT_EQ(nr_bad_opr, bad ? 2 : 3);
-        }
-    } MGB_FINALLY(
-        if (old_value) {
-            setenv(KEY, old_value, 1);
-        } else {
-            unsetenv(KEY);
-        }
-    );
+    HostTensorGenerator<> gen;
+    auto cn = CompNode::load("xpu0");
+    constexpr size_t N = 1024, Scale = 2;
+    auto host_x = gen({N}, cn);
+    for (bool bad : {false, true}) {
+        auto graph = ComputingGraph::make();
+        auto x = opr::Host2DeviceCopy::make_no_fwd(*graph, host_x),
+             bad_var = SublinearBadOpr::make(x, bad, Scale),
+             y0 = opr::reduce_sum(bad_var, x.make_scalar_dt(1)),
+             y1 = SublinearBadOpr::make(y0, false, N * Scale),
+             y = y1 + 1,
+             z = opr::reduce_max(bad_var, x.make_scalar_dt(1));
+        set_priority(y0, 0);
+        set_priority(y1, 1);
+        set_priority(y, 2);
+        set_priority(z, 3);
+        graph->options().graph_opt_level = 0;
+        graph->options().enable_sublinear_memory_opt = 1;
+        graph->options().sublinear_mem_cofig.genetic_nr_iter = 50;
+        auto func = graph->compile({{y, {}}, {z, {}}});
+        auto&& results = static_cast<cg::ComputingGraphImpl*>(graph.get())
+            ->seq_modifier_for_sublinear_memory().prev_min_bottleneck();
+        // bottleneck:
+        //   if bad : y = y1 + 1, bad_var should be saved to calculate
+        //       z later, total memory usage is
+        //       N * Scale * 2 (bad_var and y1) + 1 (immutable tensor 1)
+        //   else : bad_var = BadOpr(x), total memory usage is
+        //       N (x) + N * Scale (bad_var), bad_var would be recomputed
+        //       when calculating z = reduce(bad_var)
+        size_t expect = bad ? N * Scale * 2 + 1 : N * Scale + N;
+        ASSERT_EQ(results.at(cn), expect * host_x->dtype().size());
+        size_t nr_bad_opr = 0;
+        auto count_up = [&nr_bad_opr](cg::OperatorNodeBase* op) {
+            if (op->dyn_typeinfo() == SublinearBadOpr::typeinfo()) {
+                ++nr_bad_opr;
+            }
+            return true;
+        };
+        func->iter_opr_seq(count_up);
+        ASSERT_EQ(nr_bad_opr, bad ? 2 : 3);
+    }
 }

 #else