GitOrigin-RevId: 662618954b
tags/v1.10.0
| @@ -71,7 +71,7 @@ public: | |||
| MGE_WIN_DECLSPEC_FUC Result get(const Key& key); | |||
| void clear(); | |||
| MGE_WIN_DECLSPEC_FUC void clear(); | |||
| private: | |||
| struct Hash { | |||
| @@ -9,7 +9,7 @@ | |||
| import os | |||
| from contextlib import contextmanager | |||
| from ._imperative_rt.core2 import get_option, set_option | |||
| from ._imperative_rt.core2 import _clear_algorithm_cache, get_option, set_option | |||
| __compute_mode = "default" | |||
| __conv_format = "default" | |||
| @@ -44,6 +44,9 @@ def benchmark_kernel(mod): | |||
| @benchmark_kernel.setter | |||
| def benchmark_kernel(mod, option: bool): | |||
| global _benchmark_kernel | |||
| # try different strategy, then clear algorithm cache | |||
| if option != _benchmark_kernel: | |||
| _clear_algorithm_cache() | |||
| _benchmark_kernel = option | |||
| @@ -9,6 +9,7 @@ | |||
| import os | |||
| from ..core import _config | |||
| from ..core._imperative_rt.core2 import _clear_algorithm_cache | |||
| from ..core.ops import builtin | |||
| from ..logger import get_logger | |||
| from ..utils.deprecation import deprecated | |||
| @@ -52,7 +53,6 @@ def set_execution_strategy(option): | |||
| * "HEURISTIC": uses heuristic to choose the fastest algorithm. | |||
| * "PROFILE": runs possible algorithms on a real device to find the best one. | |||
| * "REPRODUCIBLE": uses algorithms that are reproducible. | |||
| * "OPTIMIZED": uses algorithms that are optimized. | |||
| The default strategy is "HEURISTIC", these options can be combined to | |||
| form a combination option, e.g. PROFILE_REPRODUCIBLE is a combination | |||
| @@ -70,22 +70,25 @@ def set_execution_strategy(option): | |||
| It can also be set through the environment variable ``MEGENGINE_EXECUTION_STRATEGY``. | |||
| """ | |||
| _benchmark_kernel = False | |||
| _deterministic_kernel = False | |||
| if isinstance(option, Strategy): | |||
| _config._benchmark_kernel = ( | |||
| _benchmark_kernel = ( | |||
| True if option & _valid_string_option["PROFILE"] != Strategy(0) else False | |||
| ) | |||
| _config._deterministic_kernel = ( | |||
| _deterministic_kernel = ( | |||
| True | |||
| if option & _valid_string_option["REPRODUCIBLE"] != Strategy(0) | |||
| else False | |||
| ) | |||
| if _benchmark_kernel != _config._benchmark_kernel: | |||
| _clear_algorithm_cache() | |||
| _config._benchmark_kernel = _benchmark_kernel | |||
| _config._deterministic_kernel = _deterministic_kernel | |||
| return | |||
| assert isinstance(option, str) | |||
| _config._benchmark_kernel = False | |||
| _config._deterministic_kernel = False | |||
| for opt in option.split("_"): | |||
| if not opt in _valid_string_option: | |||
| raise ValueError( | |||
| @@ -93,10 +96,12 @@ def set_execution_strategy(option): | |||
| _valid_string_option.keys() | |||
| ) | |||
| ) | |||
| _config._benchmark_kernel |= _valid_string_option[opt] == Strategy.PROFILE | |||
| _config._deterministic_kernel |= ( | |||
| _valid_string_option[opt] == Strategy.REPRODUCIBLE | |||
| ) | |||
| _benchmark_kernel |= _valid_string_option[opt] == Strategy.PROFILE | |||
| _deterministic_kernel |= _valid_string_option[opt] == Strategy.REPRODUCIBLE | |||
| if _benchmark_kernel != _config._benchmark_kernel: | |||
| _clear_algorithm_cache() | |||
| _config._benchmark_kernel = _benchmark_kernel | |||
| _config._deterministic_kernel = _deterministic_kernel | |||
| @deprecated(version="1.3", reason="use get_execution_strategy() instead") | |||
| @@ -107,6 +112,3 @@ def get_conv_execution_strategy() -> str: | |||
| @deprecated(version="1.3", reason="use set_execution_strategy() instead") | |||
| def set_conv_execution_strategy(option: str): | |||
| return set_execution_strategy(option) | |||
| set_execution_strategy(os.getenv("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC")) | |||
| @@ -26,6 +26,7 @@ | |||
| #include "megbrain/opr/io.h" | |||
| #include "megbrain/plugin/profiler.h" | |||
| #include "megbrain/utils/stats.h" | |||
| #include "megdnn/algorithm_cache.h" | |||
| #include "./common.h" | |||
| #include "./grad.h" | |||
| @@ -1428,6 +1429,8 @@ void init_tensor(py::module m) { | |||
| return set_amp_prec_dtype(false, dtype_name); | |||
| }); | |||
| m.def("_clear_algorithm_cache", [] { megdnn::AlgorithmCache::instance().clear(); }); | |||
| py::register_exception<TraceError>(m, "TraceError"); | |||
| } | |||
| @@ -1,289 +0,0 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import os | |||
| import re | |||
| import subprocess | |||
| import sys | |||
| import numpy as np | |||
| import pytest | |||
| import megengine as mge | |||
| import megengine.autodiff as ad | |||
| import megengine.functional as F | |||
| from megengine import jit | |||
| from megengine.core._trace_option import set_symbolic_shape | |||
| from megengine.core.ops import builtin | |||
| from megengine.core.tensor.utils import make_shape_tuple | |||
| from megengine.functional.debug_param import set_execution_strategy | |||
| from megengine.jit import SublinearMemoryConfig | |||
| from megengine.module import ( | |||
| AdaptiveAvgPool2d, | |||
| AvgPool2d, | |||
| BatchNorm2d, | |||
| Conv2d, | |||
| Linear, | |||
| Module, | |||
| ) | |||
| from megengine.optimizer import SGD | |||
| from megengine.tensor import Tensor | |||
| Strategy = builtin.ops.Convolution.Strategy | |||
| def get_gpu_name(): | |||
| try: | |||
| gpu_info = subprocess.check_output( | |||
| ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"] | |||
| ) | |||
| gpu_info = gpu_info.decode("ascii").split("\n")[0] | |||
| except: | |||
| gpu_info = "None" | |||
| return gpu_info | |||
| def get_cpu_name(): | |||
| cpu_info = "None" | |||
| try: | |||
| cpu_info = subprocess.check_output(["cat", "/proc/cpuinfo"]).decode("ascii") | |||
| for line in cpu_info.split("\n"): | |||
| if "model name" in line: | |||
| return re.sub(".*model name.*:", "", line, 1).strip() | |||
| except: | |||
| pass | |||
| return cpu_info | |||
| def get_xpu_name(): | |||
| if mge.is_cuda_available(): | |||
| return get_gpu_name() | |||
| else: | |||
| return get_cpu_name() | |||
| class MnistNet(Module): | |||
| def __init__(self, has_bn=False, use_adaptive_pooling=False): | |||
| super().__init__() | |||
| self.conv0 = Conv2d(1, 20, kernel_size=5, bias=True) | |||
| if use_adaptive_pooling: | |||
| self.pool0 = AdaptiveAvgPool2d(12) | |||
| else: | |||
| self.pool0 = AvgPool2d(2) | |||
| self.conv1 = Conv2d(20, 20, kernel_size=5, bias=True) | |||
| self.pool1 = AvgPool2d(2) | |||
| self.fc0 = Linear(20 * 4 * 4, 500, bias=True) | |||
| self.fc1 = Linear(500, 10, bias=True) | |||
| self.bn0 = None | |||
| self.bn1 = None | |||
| if has_bn: | |||
| self.bn0 = BatchNorm2d(20) | |||
| self.bn1 = BatchNorm2d(20) | |||
| def forward(self, x): | |||
| x = self.conv0(x) | |||
| if self.bn0: | |||
| x = self.bn0(x) | |||
| x = F.relu(x) | |||
| x = self.pool0(x) | |||
| x = self.conv1(x) | |||
| if self.bn1: | |||
| x = self.bn1(x) | |||
| x = F.relu(x) | |||
| x = self.pool1(x) | |||
| x = F.flatten(x, 1) | |||
| x = self.fc0(x) | |||
| x = F.relu(x) | |||
| x = self.fc1(x) | |||
| return x | |||
| def train(data, label, net, opt, gm): | |||
| with gm: | |||
| pred = net(data) | |||
| loss = F.nn.cross_entropy(pred, label) | |||
| gm.backward(loss) | |||
| return loss | |||
| def update_model(model_path): | |||
| """ | |||
| Update the dumped model with test cases for new reference values. | |||
| The model with pre-trained weights is trained for one iter with the test data attached. | |||
| The loss and updated net state dict is dumped. | |||
| .. code-block:: python | |||
| from test_correctness import update_model | |||
| update_model('mnist_model_with_test.mge') # for gpu | |||
| update_model('mnist_model_with_test_cpu.mge') # for cpu | |||
| """ | |||
| net = MnistNet(has_bn=True) | |||
| checkpoint = mge.load(model_path) | |||
| net.load_state_dict(checkpoint["net_init"]) | |||
| lr = checkpoint["sgd_lr"] | |||
| opt = SGD(net.parameters(), lr=lr) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| data = Tensor(checkpoint["data"], dtype=np.float32) | |||
| label = Tensor(checkpoint["label"], dtype=np.int32) | |||
| opt.clear_grad() | |||
| loss = train(data, label, net, opt, gm) | |||
| opt.step() | |||
| xpu_name = get_xpu_name() | |||
| checkpoint.update( | |||
| {"net_updated": net.state_dict(), "loss": loss.numpy(), "xpu": xpu_name} | |||
| ) | |||
| mge.save(checkpoint, model_path) | |||
| def run_train( | |||
| model_path, | |||
| use_jit, | |||
| use_symbolic, | |||
| sublinear_memory_config=None, | |||
| max_err=None, | |||
| use_adaptive_pooling=False, | |||
| ): | |||
| """ | |||
| Load the model with test cases and run the training for one iter. | |||
| The loss and updated weights are compared with reference value to verify the correctness. | |||
| Dump a new file with updated result by calling update_model | |||
| if you think the test fails due to numerical rounding errors instead of bugs. | |||
| Please think twice before you do so. | |||
| """ | |||
| net = MnistNet(has_bn=True, use_adaptive_pooling=use_adaptive_pooling) | |||
| checkpoint = mge.load(model_path) | |||
| net.load_state_dict(checkpoint["net_init"]) | |||
| lr = checkpoint["sgd_lr"] | |||
| opt = SGD(net.parameters(), lr=lr) | |||
| gm = ad.GradManager().attach(net.parameters()) | |||
| data = Tensor(checkpoint["data"], dtype=np.float32) | |||
| label = Tensor(checkpoint["label"], dtype=np.int32) | |||
| if max_err is None: | |||
| max_err = 1e-5 | |||
| train_func = train | |||
| if use_jit: | |||
| train_func = jit.trace( | |||
| train_func, | |||
| symbolic=use_symbolic, | |||
| sublinear_memory_config=sublinear_memory_config, | |||
| ) | |||
| opt.clear_grad() | |||
| loss = train_func(data, label, net, opt, gm) | |||
| opt.step() | |||
| np.testing.assert_allclose(loss.numpy(), checkpoint["loss"], atol=max_err) | |||
| for param, param_ref in zip( | |||
| net.state_dict().items(), checkpoint["net_updated"].items() | |||
| ): | |||
| assert param[0] == param_ref[0] | |||
| if "bn" in param[0]: | |||
| ref = param_ref[1].reshape(param[1].shape) | |||
| np.testing.assert_allclose(param[1], ref, atol=max_err) | |||
| else: | |||
| np.testing.assert_allclose(param[1], param_ref[1], atol=max_err) | |||
| def run_eval( | |||
| model_path, | |||
| use_symbolic, | |||
| sublinear_memory_config=None, | |||
| max_err=None, | |||
| use_adaptive_pooling=False, | |||
| ): | |||
| """ | |||
| Load the model with test cases and run the training for one iter. | |||
| The loss and updated weights are compared with reference value to verify the correctness. | |||
| Dump a new file with updated result by calling update_model | |||
| if you think the test fails due to numerical rounding errors instead of bugs. | |||
| Please think twice before you do so. | |||
| """ | |||
| net = MnistNet(has_bn=True, use_adaptive_pooling=use_adaptive_pooling) | |||
| checkpoint = mge.load(model_path) | |||
| net.load_state_dict(checkpoint["net_init"]) | |||
| data = Tensor(checkpoint["data"], dtype=np.float32) | |||
| def eval_fun(data, *, net=None): | |||
| pred = net(data) | |||
| return pred | |||
| refer_value = eval_fun(data, net=net) | |||
| eval_fun = jit.trace(eval_fun, symbolic=use_symbolic) | |||
| for _ in range(3): | |||
| new_value = eval_fun(data, net=net) | |||
| np.testing.assert_allclose(new_value.numpy(), refer_value.numpy(), atol=max_err) | |||
| @pytest.mark.skip(reason="close it when cu111 ci") | |||
| def test_correctness(): | |||
| if mge.is_cuda_available(): | |||
| model_name = "mnist_model_with_test.mge" | |||
| else: | |||
| model_name = "mnist_model_with_test_cpu.mge" | |||
| model_path = os.path.join(os.path.dirname(__file__), model_name) | |||
| set_execution_strategy(Strategy.HEURISTIC | Strategy.REPRODUCIBLE) | |||
| run_train(model_path, False, False, max_err=1e-5) | |||
| run_train(model_path, True, False, max_err=1e-5) | |||
| run_train(model_path, True, True, max_err=1e-5) | |||
| # sublinear | |||
| config = SublinearMemoryConfig(genetic_nr_iter=10) | |||
| run_train( | |||
| model_path, True, True, sublinear_memory_config=config, max_err=1e-5, | |||
| ) | |||
| run_eval(model_path, False, max_err=1e-7) | |||
| run_eval(model_path, True, max_err=1e-7) | |||
| @pytest.mark.skip(reason="close it when cu111 ci") | |||
| def test_correctness_use_adaptive_pooling(): | |||
| if mge.is_cuda_available(): | |||
| model_name = "mnist_model_with_test.mge" | |||
| else: | |||
| model_name = "mnist_model_with_test_cpu.mge" | |||
| model_path = os.path.join(os.path.dirname(__file__), model_name) | |||
| set_execution_strategy("HEURISTIC_REPRODUCIBLE") | |||
| run_train(model_path, False, False, max_err=1e-5, use_adaptive_pooling=True) | |||
| run_train(model_path, True, False, max_err=1e-5, use_adaptive_pooling=True) | |||
| run_train(model_path, True, True, max_err=1e-5, use_adaptive_pooling=True) | |||
| # sublinear | |||
| config = SublinearMemoryConfig(genetic_nr_iter=10) | |||
| run_train( | |||
| model_path, | |||
| True, | |||
| True, | |||
| sublinear_memory_config=config, | |||
| max_err=1e-5, | |||
| use_adaptive_pooling=True, | |||
| ) | |||
| run_eval(model_path, False, max_err=1e-7, use_adaptive_pooling=True) | |||
| run_eval(model_path, True, max_err=1e-7, use_adaptive_pooling=True) | |||
| @@ -7,11 +7,8 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import os | |||
| import platform | |||
| import re | |||
| import subprocess | |||
| import sys | |||
| from math import ceil | |||
| import numpy as np | |||
| import pytest | |||
| @@ -20,8 +17,6 @@ import megengine as mge | |||
| import megengine.autodiff as ad | |||
| import megengine.distributed as dist | |||
| import megengine.functional as F | |||
| from megengine.device import get_default_device, set_default_device | |||
| from megengine.functional.debug_param import set_execution_strategy | |||
| from megengine.module import AvgPool2d, BatchNorm2d, Conv2d, Linear, Module | |||
| from megengine.optimizer import SGD | |||
| from megengine.tensor import Tensor | |||
| @@ -198,5 +193,7 @@ def run_test( | |||
| def test_dp_correctness(): | |||
| model_name = "mnist_model_with_test.mge" | |||
| model_path = os.path.join(os.path.dirname(__file__), model_name) | |||
| set_execution_strategy("HEURISTIC_REPRODUCIBLE") | |||
| old = mge.config.deterministic_kernel | |||
| mge.config.deterministic_kernel = True | |||
| run_test(model_path, False, False, max_err=5e-5) | |||
| mge.config.deterministic_kernel = old | |||
| @@ -11,21 +11,9 @@ import itertools | |||
| import numpy as np | |||
| import pytest | |||
| import megengine as mge | |||
| import megengine.module as M | |||
| from megengine import Parameter, tensor | |||
| from megengine.functional.debug_param import ( | |||
| get_execution_strategy, | |||
| set_execution_strategy, | |||
| ) | |||
| from megengine.module import ConvTranspose2d, ConvTranspose3d, LocalConv2d | |||
| @pytest.fixture | |||
| def reproducible(): | |||
| old = get_execution_strategy() | |||
| set_execution_strategy("HEURISTIC_REPRODUCIBLE") | |||
| yield | |||
| set_execution_strategy(old) | |||
| from megengine import tensor | |||
| # NOTE: test in module for convenience. should really test in functional | |||
| @@ -33,7 +21,9 @@ def reproducible(): | |||
| "name", | |||
| ["Conv1d", "Conv2d", "Conv3d", "ConvTranspose2d", "ConvTranspose3d", "LocalConv2d"], | |||
| ) | |||
| def test_conv_dtype_promotion(name, reproducible): | |||
| def test_conv_dtype_promotion(name): | |||
| old = mge.config.deterministic_kernel | |||
| mge.config.deterministic_kernel = True | |||
| N, Ci, Co, K = 2, 16, 32, 3 | |||
| S = (7,) * int(name[-2]) | |||
| if "Local" in name: | |||
| @@ -42,3 +32,4 @@ def test_conv_dtype_promotion(name, reproducible): | |||
| m = getattr(M, name)(Ci, Co, K) | |||
| x = tensor(np.random.random(size=(N, Ci) + S).astype("float16")) | |||
| np.testing.assert_equal(m(x).numpy(), m(x.astype("float32")).numpy()) | |||
| mge.config.deterministic_kernel = old | |||
| @@ -255,9 +255,8 @@ def test_conv_bias_int4(): | |||
| run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu") | |||
| @pytest.mark.require_ngpu(1) | |||
| @pytest.mark.skipif( | |||
| get_cuda_compute_capability(0) < 61, | |||
| get_device_count("gpu") > 0 and get_cuda_compute_capability(0) < 61, | |||
| reason="does not support int8 when gpu compute capability less than 6.1", | |||
| ) | |||
| def test_conv_transpose2d(): | |||
| @@ -5,6 +5,7 @@ import platform | |||
| import numpy as np | |||
| import pytest | |||
| import megengine as mge | |||
| import megengine.core.tensor.dtype as dtype | |||
| import megengine.core.tensor.megbrain_graph as G | |||
| import megengine.functional as F | |||
| @@ -18,10 +19,6 @@ from megengine.device import ( | |||
| get_device_count, | |||
| is_cuda_available, | |||
| ) | |||
| from megengine.functional.debug_param import ( | |||
| get_execution_strategy, | |||
| set_execution_strategy, | |||
| ) | |||
| from megengine.functional.external import tensorrt_runtime_opr | |||
| from megengine.jit.tracing import trace | |||
| from megengine.tensor import Tensor | |||
| @@ -110,25 +107,30 @@ def test_matinv(): | |||
| @pytest.mark.parametrize( | |||
| "execution_strategy", ["HEURISTIC_REPRODUCIBLE", "PROFILE_REPRODUCIBLE"] | |||
| "benchmark_kernel, max_err", [(False, None), (True, 1e-5)], | |||
| ) | |||
| def test_matmul(execution_strategy): | |||
| def test_matmul(monkeypatch, benchmark_kernel, max_err): | |||
| if get_device_count("gpu") == 0 and benchmark_kernel: | |||
| return | |||
| monkeypatch.setenv("MGE_FASTRUN_CACHE_TYPE", "MEMORY") | |||
| old1, old2 = ( | |||
| mge.config.benchmark_kernel, | |||
| mge.config.deterministic_kernel, | |||
| ) | |||
| mge.config.benchmark_kernel = benchmark_kernel | |||
| mge.config.deterministic_kernel = True | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def fwd(data1, data2): | |||
| return F.matmul(data1, data2) | |||
| old = get_execution_strategy() | |||
| set_execution_strategy(execution_strategy) | |||
| max_err = None | |||
| if execution_strategy == "PROFILE_REPRODUCIBLE": | |||
| max_err = 1e-5 | |||
| data1 = Tensor(np.random.random((32, 64))) | |||
| data2 = Tensor(np.random.random((64, 16))) | |||
| result = fwd(data1, data2) | |||
| check_pygraph_dump(fwd, [data1, data2], [result], max_err=max_err) | |||
| set_execution_strategy(old) | |||
| mge.config.benchmark_kernel = old1 | |||
| mge.config.deterministic_kernel = old2 | |||
| monkeypatch.delenv("MGE_FASTRUN_CACHE_TYPE", raising=False) | |||
| def test_batchmatmul(): | |||
| @@ -290,9 +292,8 @@ def test_deformable_ps_roi_pooling(): | |||
| check_pygraph_dump(fwd, [inp, rois, trans], [result]) | |||
| @pytest.mark.require_ngpu(1) | |||
| @pytest.mark.skipif( | |||
| get_cuda_compute_capability(0) < 61, | |||
| get_device_count("gpu") > 0 and get_cuda_compute_capability(0) < 61, | |||
| reason="does not support int8 when gpu compute capability less than 6.1", | |||
| ) | |||
| def test_convbias(): | |||