GitOrigin-RevId: 0f0dc001cf
tags/v1.5.0
@@ -181,11 +181,6 @@ def synchronized(func: Callable):
     return wrapper
 
-def _get_device_count_worker(queue, device_type):
-    num = get_device_count(device_type)
-    queue.put(num)
-
 def _check_device_initialized(device_type: str, rank: int):
     try:
         test = Tensor(1, device=(device_type + str(rank)))
@@ -198,19 +193,6 @@ def _check_device_initialized(device_type: str, rank: int):
         raise RuntimeError(errmsg)
 
-def get_device_count_by_fork(device_type: str):
-    """
-    Get device count in fork thread.
-    See https://stackoverflow.com/questions/22950047/cuda-initialization-error-after-fork
-    for more information.
-    """
-    q = mp.Queue()
-    p = mp.Process(target=_get_device_count_worker, args=(q, device_type))
-    p.start()
-    p.join()
-    return q.get()
-
 def bcast_list_(inps: list, group: Group = WORLD):
     """
     Broadcast tensors between given group.
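Why the fork-based helper existed, and why it can be dropped: a CUDA context does not survive fork(), so once a process has initialized CUDA, any forked child that touches CUDA fails (see the Stack Overflow link in the removed docstring). The helper sidestepped this by counting devices in a throwaway subprocess. This patch instead makes the device count itself fork-safe in the C++ comp-node layer (final hunks below), so plain get_device_count becomes safe to call before forking. A minimal sketch of the failure mode being guarded against, assuming a Linux CUDA build (illustrative, not part of the patch):

    import os

    from megengine.device import get_device_count

    get_device_count("gpu")  # before this patch: could initialize CUDA here
    pid = os.fork()
    if pid == 0:
        # with CUDA already initialized by the parent, CUDA calls in the
        # child used to fail with "initialization error"
        get_device_count("gpu")
        os._exit(0)
    os.waitpid(pid, 0)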
@@ -13,9 +13,10 @@ import queue
 from .. import _exit
 from ..core._imperative_rt.core2 import full_sync
+from ..device import get_device_count
 from ..logger import get_logger
 from .group import _set_machine_ranks, group_barrier, init_process_group
-from .helper import _check_device_initialized, get_device_count_by_fork
+from .helper import _check_device_initialized
 from .server import Client, Server
 
 WARN_SUBPROCESS_EXIT_WITHOUT_RETURN = (
@@ -91,9 +92,7 @@ class launcher:
         backend="auto",
     ):
         self.func = func
-        self.n_gpus = (
-            n_gpus if n_gpus is not None else get_device_count_by_fork(device_type)
-        )
+        self.n_gpus = n_gpus if n_gpus is not None else get_device_count(device_type)
         self.world_size = world_size if world_size is not None else self.n_gpus
         self.rank_start = rank_start
         self.master_ip = master_ip
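Nothing changes for launcher users except the source of the default device count. A hedged usage sketch (main is a made-up worker; assumes a multi-GPU host):

    import megengine.distributed as dist

    @dist.launcher  # n_gpus now defaults to megengine.device.get_device_count(device_type)
    def main():
        print(dist.get_rank())

    main()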
@@ -1188,11 +1188,11 @@ def copy(inp, device=None):
         import numpy as np
         import platform
         from megengine import tensor
-        from megengine.distributed.helper import get_device_count_by_fork
+        from megengine.device import get_device_count
         import megengine.functional as F
 
         x = tensor([1, 2, 3], np.int32)
-        if 1 == get_device_count_by_fork("gpu"):
+        if 1 == get_device_count("gpu"):
             y = F.copy(x, "cpu1")
             print(y.numpy())
         else:
@@ -15,7 +15,7 @@ import megengine.functional
 import megengine.module
 from megengine import Parameter
 from megengine.core._imperative_rt.core2 import sync
-from megengine.distributed.helper import get_device_count_by_fork
+from megengine.device import get_device_count
 from megengine.experimental.autograd import (
     disable_higher_order_directive,
     enable_higher_order_directive,
@@ -25,7 +25,7 @@ from megengine.module import Linear, Module
 sys.path.append(os.path.join(os.path.dirname(__file__), "helpers"))
 
-_ngpu = get_device_count_by_fork("gpu")
+_ngpu = get_device_count("gpu")
 
 @pytest.fixture(autouse=True)
@@ -16,7 +16,6 @@ import megengine.autodiff as ad
 import megengine.distributed as dist
 import megengine.optimizer as optimizer
 from megengine import Parameter, tensor
-from megengine.distributed.helper import get_device_count_by_fork
 from megengine.module import Module
 from megengine.optimizer import SGD
@@ -18,7 +18,6 @@ import megengine.functional as F
 import megengine.module as M
 import megengine.optimizer as optim
 from megengine.autodiff import GradManager
-from megengine.distributed.helper import get_device_count_by_fork
 from megengine.jit import trace
@@ -20,7 +20,6 @@ from megengine.core._imperative_rt import CompNode, TensorAttr, imperative
 from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync
 from megengine.core.autodiff.grad import Grad
 from megengine.core.ops.builtin import Elemwise, Identity
-from megengine.distributed.helper import get_device_count_by_fork
 from megengine.functional.distributed import remote_recv, remote_send
@@ -31,7 +31,7 @@ from megengine.core.tensor.dtype import (
     quint4,
     quint8,
 )
-from megengine.distributed.helper import get_device_count_by_fork
+from megengine.device import get_device_count
 from megengine.tensor import Tensor
@@ -184,8 +184,7 @@ def test_dtype_int4_ffi_handle():
 @pytest.mark.skipif(
-    get_device_count_by_fork("gpu") != 0,
-    reason="TypeCvt to quint4 is not supported on GPU",
+    get_device_count("gpu") != 0, reason="TypeCvt to quint4 is not supported on GPU",
 )
 def test_quint4_typecvt():
     device = "xpux"
@@ -17,11 +17,7 @@ import megengine as mge
 import megengine.distributed as dist
 from megengine.core.ops.builtin import CollectiveComm, ParamPackConcat, ParamPackSplit
 from megengine.device import get_default_device
-from megengine.distributed.helper import (
-    get_device_count_by_fork,
-    param_pack_concat,
-    param_pack_split,
-)
+from megengine.distributed.helper import param_pack_concat, param_pack_split
 
 def _assert_q_empty(q):
@@ -22,8 +22,7 @@ from megengine import Parameter, Tensor, is_cuda_available, tensor
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.autodiff.grad import Grad
 from megengine.core.tensor.utils import make_shape_tuple
-from megengine.distributed.helper import get_device_count_by_fork
-from megengine.jit import trace
+from megengine.device import get_device_count
 
 def test_where():
@@ -613,7 +612,7 @@ def test_nms():
 @pytest.mark.skipif(
-    get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8"
+    get_device_count("gpu") > 0, reason="cuda does not support nchw int8"
 )
 def test_conv_bias():
     inp_scale = 1.5
@@ -715,9 +714,7 @@ def test_conv_bias():
     run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "relu")
 
-@pytest.mark.skipif(
-    get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda"
-)
+@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda")
 def test_batch_conv_bias():
     inp_scale = 1.5
     w_scale = 2.5
@@ -16,7 +16,6 @@ import megengine.distributed as dist
 from megengine import Parameter, tensor
 from megengine.core._imperative_rt.core2 import sync
 from megengine.device import get_default_device, set_default_device
-from megengine.distributed.helper import get_device_count_by_fork
 from megengine.functional.distributed import (
     all_gather,
     all_reduce_max,
@@ -18,7 +18,6 @@ from megengine import tensor
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.tensor import megbrain_graph as G
 from megengine.core.tensor.utils import astensor1d
-from megengine.distributed.helper import get_device_count_by_fork
 from megengine.jit import trace
 from megengine.utils.network import Network, set_symbolic_shape
 from megengine.utils.network_node import VarNode
@@ -16,7 +16,6 @@ import megengine as mge
 import megengine.distributed as dist
 from megengine import Tensor
 from megengine.core._trace_option import use_symbolic_shape
-from megengine.distributed.helper import get_device_count_by_fork
 from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
 
 _assert_allclose = functools.partial(np.testing.assert_allclose, atol=5e-6, rtol=5e-6)
@@ -6,7 +6,7 @@ import pytest
 import megengine.utils.comp_graph_tools as cgtools
 from megengine import jit, tensor
-from megengine.distributed.helper import get_device_count_by_fork
+from megengine.device import get_device_count
 from megengine.functional import expand_dims
 from megengine.module import (
     BatchMatMulActivation,
@@ -101,9 +101,7 @@ def test_qat_conv():
     np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy())
 
-@pytest.mark.skipif(
-    get_device_count_by_fork("gpu") > 0, reason="no int8 algorithm on cuda"
-)
+@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda")
 def test_qat_batchmatmul_activation():
     batch = 4
     in_features = 8
@@ -13,7 +13,7 @@ import pytest
 import megengine as mge
 import megengine.distributed as dist
-from megengine.distributed.helper import get_device_count_by_fork
+from megengine.device import get_device_count
 from megengine.quantization import QuantMode, create_qparams
 from megengine.quantization.observer import (
     ExponentialMovingAverageObserver,
@@ -78,7 +78,7 @@ def test_passive_observer():
 @pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
 def test_sync_min_max_observer():
-    word_size = get_device_count_by_fork("gpu")
+    word_size = get_device_count("gpu")
     x = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
     np_min, np_max = x.min(), x.max()
@@ -96,7 +96,7 @@ def test_sync_min_max_observer():
 @pytest.mark.require_ngpu(2)
 @pytest.mark.isolated_distributed
 def test_sync_exponential_moving_average_observer():
-    word_size = get_device_count_by_fork("gpu")
+    word_size = get_device_count("gpu")
     t = np.random.rand()
     x1 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
     x2 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
@@ -12,7 +12,7 @@ import pytest
 import megengine as mge
 import megengine.functional as F
 from megengine.core.tensor import dtype
-from megengine.distributed.helper import get_device_count_by_fork
+from megengine.device import get_device_count
 from megengine.functional.elemwise import _elemwise_multi_type, _elwise
 from megengine.quantization import QuantMode, create_qparams
@@ -68,7 +68,7 @@ def test_elemwise(kind):
 @pytest.mark.skipif(
-    get_device_count_by_fork("gpu") > 0, reason="cuda does not support nchw int8"
+    get_device_count("gpu") > 0, reason="cuda does not support nchw int8"
 )
 def test_conv_bias():
     inp_scale = np.float32(np.random.rand() + 1)
@@ -26,12 +26,12 @@ from megengine.core.ops.builtin import (
     PoissonRNG,
     UniformRNG,
 )
-from megengine.distributed.helper import get_device_count_by_fork
+from megengine.device import get_device_count
 from megengine.random import RNG, seed, uniform
 
 @pytest.mark.skipif(
-    get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
+    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
 )
 def test_gaussian_op():
     shape = (
@@ -61,7 +61,7 @@ def test_gaussian_op():
 @pytest.mark.skipif(
-    get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
+    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
 )
 def test_uniform_op():
     shape = (
@@ -89,7 +89,7 @@ def test_uniform_op():
 @pytest.mark.skipif(
-    get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
+    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
 )
 def test_gamma_op():
     _shape, _scale = 2, 0.8
@@ -117,7 +117,7 @@ def test_gamma_op():
 @pytest.mark.skipif(
-    get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
+    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
 )
 def test_beta_op():
     _alpha, _beta = 2, 0.8
@@ -148,7 +148,7 @@ def test_beta_op():
 @pytest.mark.skipif(
-    get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
+    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
 )
 def test_poisson_op():
     lam = F.full([8, 9, 11, 12], value=2, dtype="float32")
@@ -171,7 +171,7 @@ def test_poisson_op():
 @pytest.mark.skipif(
-    get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2",
+    get_device_count("xpu") <= 2, reason="xpu counts need > 2",
 )
 def test_permutation_op():
     n = 1000
@@ -205,7 +205,7 @@ def test_permutation_op():
 @pytest.mark.skipif(
-    get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
+    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
 )
 def test_UniformRNG():
     m1 = RNG(seed=111, device="xpu0")
@@ -233,7 +233,7 @@ def test_UniformRNG():
 @pytest.mark.skipif(
-    get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
+    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
 )
 def test_NormalRNG():
     m1 = RNG(seed=111, device="xpu0")
@@ -262,7 +262,7 @@ def test_NormalRNG():
 @pytest.mark.skipif(
-    get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
+    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
 )
 def test_GammaRNG():
     m1 = RNG(seed=111, device="xpu0")
@@ -295,7 +295,7 @@ def test_GammaRNG():
 @pytest.mark.skipif(
-    get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
+    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
 )
 def test_BetaRNG():
     m1 = RNG(seed=111, device="xpu0")
@@ -330,7 +330,7 @@ def test_BetaRNG():
 @pytest.mark.skipif(
-    get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
+    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
 )
 def test_PoissonRNG():
     m1 = RNG(seed=111, device="xpu0")
@@ -359,7 +359,7 @@ def test_PoissonRNG():
 @pytest.mark.skipif(
-    get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1",
+    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
 )
 def test_PermutationRNG():
     m1 = RNG(seed=111, device="xpu0")
@@ -13,8 +13,7 @@ import megengine.random as rand
 from megengine.core._imperative_rt.core2 import apply
 from megengine.core._wrap import Device
 from megengine.core.ops import builtin
-from megengine.device import is_cuda_available
-from megengine.distributed.helper import get_device_count_by_fork
+from megengine.device import get_device_count, is_cuda_available
 from megengine.functional.external import tensorrt_runtime_opr
 from megengine.jit.tracing import trace
 from megengine.tensor import Tensor
@@ -273,7 +272,7 @@ def test_deformable_ps_roi_pooling():
 @pytest.mark.skipif(
-    get_device_count_by_fork("gpu") > 0,
+    get_device_count("gpu") > 0,
     reason="does not support int8 when gpu compute capability less than 6.1",
 )
 def test_convbias():
@@ -27,8 +27,14 @@ using namespace mgb;
 #include <thread>
 
+#include <cuda.h>
 #include <cuda_runtime.h>
 
+#ifdef __unix__
+#include <unistd.h>
+#include <sys/wait.h>
+#endif
+
 using CudaCompNodeImpl = CudaCompNode::CompNodeImpl;
 
 namespace {
@@ -700,19 +706,90 @@ void CudaCompNode::EventImpl::do_device_wait_by(Impl* cn_impl) {
 /* ===================== CudaCompNode static methods ===================== */
 
+namespace {
+
+#ifndef __unix__
+CUresult get_device_count_forksafe(int* pcnt) {
+    cuInit(0);
+    return cuDeviceGetCount(pcnt);
+}
+#else
+struct RAIICloseFD : NonCopyableObj {
+    int m_fd = -1;
+
+    RAIICloseFD(int fd) : m_fd(fd) {}
+    ~RAIICloseFD() { close(); }
+    void close() {
+        if (m_fd != -1) {
+            ::close(m_fd);
+            m_fd = -1;
+        }
+    }
+};
+
+// an implementation that does not call cuInit in the current process
+CUresult get_device_count_forksafe(int* pcnt) {
+    auto err = cuDeviceGetCount(pcnt);
+    if (err != CUDA_ERROR_NOT_INITIALIZED)
+        return err;
+    // cuInit has not been called yet: call it in a child process, so this
+    // process never owns a CUDA context and remains safe to fork
+    int fd[2];
+    mgb_assert(pipe(fd) == 0, "pipe() failed");
+    int fdr = fd[0], fdw = fd[1];
+    RAIICloseFD fdr_guard(fdr);
+    RAIICloseFD fdw_guard(fdw);
+    auto cpid = fork();
+    mgb_assert(cpid != -1, "fork() failed");
+    if (cpid == 0) {
+        fdr_guard.close();
+        do {
+            err = cuInit(0);
+            if (err != CUDA_SUCCESS)
+                break;
+            err = cuDeviceGetCount(pcnt);
+        } while (0);
+        // report the error code first, then the count only on success
+        auto sz = write(fdw, &err, sizeof(err));
+        if (sz == sizeof(err) && err == CUDA_SUCCESS) {
+            sz = write(fdw, pcnt, sizeof(*pcnt));
+        }
+        fdw_guard.close();
+        std::quick_exit(0);
+    }
+    fdw_guard.close();
+    auto sz = read(fdr, &err, sizeof(err));
+    mgb_assert(sz == sizeof(err), "failed to read error code from child");
+    if (err == CUDA_SUCCESS) {
+        sz = read(fdr, pcnt, sizeof(*pcnt));
+        mgb_assert(sz == sizeof(*pcnt), "failed to read device count from child");
+        return err;
+    }
+    // try again; maybe another thread called cuInit while we forked
+    auto err2 = cuDeviceGetCount(pcnt);
+    if (err2 == CUDA_SUCCESS)
+        return err2;
+    if (err2 == CUDA_ERROR_NOT_INITIALIZED)
+        return err;
+    return err2;
+}
+#endif
+
+const char* cu_get_error_string(CUresult err) {
+    const char* ret = nullptr;
+    cuGetErrorString(err, &ret);
+    if (!ret)
+        ret = "unknown cuda error";
+    return ret;
+}
+
+}  // namespace
+
 bool CudaCompNode::available() {
     static int result = -1;
     static Spinlock mtx;
     MGB_LOCK_GUARD(mtx);
     if (result == -1) {
         int ndev = -1;
-        auto err = cudaGetDeviceCount(&ndev);
-        result = err == cudaSuccess && ndev > 0;
+        auto err = get_device_count_forksafe(&ndev);
+        result = err == CUDA_SUCCESS && ndev > 0;
         if (!result) {
             mgb_log_warn("cuda unavailable: %s(%d) ndev=%d",
-                cudaGetErrorString(err), static_cast<int>(err), ndev);
+                cu_get_error_string(err), static_cast<int>(err), ndev);
         }
-        if (err == cudaErrorInitializationError) {
+        if (err == CUDA_ERROR_NOT_INITIALIZED) {
            mgb_throw(std::runtime_error, "cuda initialization error.");
         }
     }
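The pipe-and-fork exchange above is the heart of the patch. A rough Python analogue for readers who prefer not to trace the C++ (illustrative only: the function name, the ctypes loading of libcuda, and the hard-coded error constants are assumptions, not MegEngine API):

    import ctypes
    import os
    import struct

    CUDA_SUCCESS = 0
    CUDA_ERROR_NOT_INITIALIZED = 3

    def get_device_count_forksafe():
        libcuda = ctypes.CDLL("libcuda.so.1")
        cnt = ctypes.c_int(0)
        err = libcuda.cuDeviceGetCount(ctypes.byref(cnt))
        if err != CUDA_ERROR_NOT_INITIALIZED:
            return err, cnt.value  # driver already initialized by someone else
        # driver uninitialized: run cuInit + the count in a child, so the
        # parent never owns a CUDA context and remains fork-safe
        r, w = os.pipe()
        pid = os.fork()
        if pid == 0:
            os.close(r)
            err = libcuda.cuInit(0)
            if err == CUDA_SUCCESS:
                err = libcuda.cuDeviceGetCount(ctypes.byref(cnt))
            os.write(w, struct.pack("ii", err, cnt.value))
            os._exit(0)
        os.close(w)
        err, n = struct.unpack("ii", os.read(r, 8))
        os.close(r)
        os.waitpid(pid, 0)
        return err, n

The C++ version additionally writes the count only after a success code and retries cuDeviceGetCount in the parent in case another thread raced it with cuInit; both refinements are omitted here for brevity.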
@@ -857,11 +934,11 @@ size_t CudaCompNode::get_device_count(bool warn) {
     static Spinlock mtx;
     MGB_LOCK_GUARD(mtx);
     if (cnt == -1) {
-        auto err = cudaGetDeviceCount(&cnt);
-        if (err != cudaSuccess) {
+        auto err = get_device_count_forksafe(&cnt);
+        if (err != CUDA_SUCCESS) {
             if (warn)
                 mgb_log_error("cudaGetDeviceCount failed: %s (err %d)",
-                    cudaGetErrorString(err), int(err));
+                    cu_get_error_string(err), int(err));
             cnt = 0;
         }
         mgb_assert(cnt >= 0);
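Net effect at the Python level: counting devices no longer initializes CUDA in the calling process, so it is safe to query before forking workers. A hedged sketch (assumes Linux's fork start method and a CUDA build; worker is made up):

    import multiprocessing as mp

    import megengine
    from megengine.device import get_device_count

    def worker(rank):
        # each child can still initialize CUDA normally on its own device
        megengine.Tensor([1], device="gpu{}".format(rank))

    if __name__ == "__main__":
        n = get_device_count("gpu")  # no CUDA context created here anymore
        procs = [mp.Process(target=worker, args=(i,)) for i in range(n)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()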