GitOrigin-RevId: ee5984c52d
tags/v1.5.0
@@ -156,7 +156,8 @@ def _logical_binary_elwise(mode, rev=False):
 def _remove_axis(inp: Tensor, axis) -> Tensor:
     def get_axes():
         if axis is None:
-            return [i for i, s in enumerate(inp.shape) if s == 1]
+            shp = inp.shape
+            return [i for i, s in enumerate(shp) if s == 1]
         try:
             return [int(axis)]
         except (TypeError, ValueError):
@@ -6,9 +6,11 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import time
 from typing import List, Optional, Tuple

 from ..device import set_default_device, what_is_xpu
+from ..random import seed
 from .server import Client, Server
@@ -156,6 +158,7 @@ def init_process_group(
     WORLD.reset(list(range(world_size)))
     set_default_device("{}{}".format(device_type, device))
+    seed(int(time.time()) + rank)


 def is_distributed() -> bool:
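Note the new `seed(int(time.time()) + rank)` line: every process in the group now derives its seed from wall-clock time plus its rank, so no two workers start from the same random stream. A minimal sketch of the idea in plain NumPy (the two-rank setup is hypothetical; the real code feeds the seed into MegEngine's MT19937 and the global graph RNG seed):

```python
import time
from numpy.random import MT19937

base = int(time.time())
rank0 = MT19937(seed=base + 0)  # the stream a rank-0 worker would get
rank1 = MT19937(seed=base + 1)  # the stream a rank-1 worker would get
# The rank offset makes the bit streams diverge from the very first draw.
print(rank0.random_raw(), rank1.random_raw())
```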
@@ -7,7 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .distribution import normal, uniform
-from .rng import seed
+from .rng import RNG, seed

 # pylint: disable=undefined-variable
 del distribution, rng  # type: ignore[name-defined]
@@ -9,11 +9,8 @@
 from typing import Iterable, Optional

 from .. import Tensor
-from ..core._imperative_rt import invoke_op
-from ..core._imperative_rt.core2 import apply
-from ..core.ops.builtin import GaussianRNG, UniformRNG
-from ..core.tensor import utils
-from .rng import _random_seed_generator
+from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
+from .rng import _normal, _uniform

 __all__ = ["normal", "uniform"]
@@ -48,14 +45,14 @@ def normal(
     [-1.4939808 -1.5824696 ]]

     """
-    if size is None:
-        size = (1,)
-    op = GaussianRNG(mean, std)
-    _ref = Tensor([], dtype="int32")
-    shape = utils.astensor1d(size, _ref, dtype="int32")
-    shape = Tensor(shape, dtype="int32")
-    (output,) = apply(op, shape)
-    return output
+    return _normal(
+        mean=mean,
+        std=std,
+        size=size,
+        seed=_get_global_rng_seed(),
+        device=None,
+        handle=0,
+    )


 def uniform(
@@ -88,14 +85,11 @@ def uniform(
     [0.09365904 0.62957656]]

     """
-    assert low < high, "Uniform is not defined when low >= high"
-    if size is None:
-        size = (1,)
-    op = UniformRNG()
-    _ref = Tensor([], dtype="int32")
-    shape = utils.astensor1d(size, _ref, dtype="int32")
-    shape = Tensor(shape, dtype="int32")
-    (output,) = apply(op, shape)
-    return low + (high - low) * output
+    return _uniform(
+        low=low,
+        high=high,
+        size=size,
+        seed=_get_global_rng_seed(),
+        device=None,
+        handle=0,
+    )
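Both public functions are now thin wrappers: each forwards to the module-private helper in `rng.py` with the global seed and the null handle (`handle=0`), which selects the per-device default RNG handle. A usage sketch (output values illustrative only; assumes a build with a usable default device):

```python
import megengine.random as rand

x = rand.normal(mean=0, std=1, size=(2, 2))   # routed through _normal(..., handle=0)
y = rand.uniform(low=0, high=1, size=(2, 2))  # routed through _uniform(..., handle=0)
print(x.numpy())
print(y.numpy())
```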
@@ -7,17 +7,94 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import time
+from typing import Iterable, Optional

 from numpy.random import MT19937

+from .. import Tensor
+from ..core._imperative_rt.core2 import apply
+from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
+from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
+from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle
+from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed
+from ..core.ops.builtin import GaussianRNG, UniformRNG
+from ..core.tensor import utils
+from ..device import get_default_device
+
 _rng = None


-def _random_seed_generator():
-    if _rng is None:
-        from ..distributed.group import get_rank
-
-        seed(seed=int(time.time()) + get_rank())
+def _normal(
+    mean: float,
+    std: float,
+    size: Optional[Iterable[int]],
+    seed: int,
+    device: str,
+    handle: int,
+) -> Tensor:
+    if size is None:
+        size = (1,)
+    op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle)
+    _ref = Tensor([], dtype="int32", device=device)
+    shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
+    (output,) = apply(op, shape)
+    return output
+
+
+def _uniform(
+    low: float,
+    high: float,
+    size: Optional[Iterable[int]],
+    seed: int,
+    device: str,
+    handle: int,
+) -> Tensor:
+    assert low < high, "Uniform is not defined when low >= high"
+    if size is None:
+        size = (1,)
+    op = UniformRNG(seed=seed, handle=handle)
+    _ref = Tensor([], dtype="int32", device=device)
+    shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
+    (output,) = apply(op, shape)
+    return low + (high - low) * output
+
+
+class RNG:
+    def __init__(self, seed=0, device=None):
+        self.seed = seed
+        self.device = device if device else get_default_device()
+        self.handle = _new_rng_handle(self.device, self.seed)
+
+    def uniform(
+        self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
+    ):
+        return _uniform(
+            low=low,
+            high=high,
+            size=size,
+            seed=self.seed,
+            device=self.device,
+            handle=self.handle,
+        )
+
+    def normal(
+        self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
+    ):
+        return _normal(
+            mean=mean,
+            std=std,
+            size=size,
+            seed=self.seed,
+            device=self.device,
+            handle=self.handle,
+        )
+
+    def __del__(self):
+        _delete_rng_handle(self.handle)
+
+
+def _random_seed_generator():
+    assert _rng
     while True:
         yield _rng.random_raw()

@@ -25,3 +102,7 @@ def _random_seed_generator():
 def seed(seed: int):
     global _rng  # pylint: disable=global-statement
     _rng = MT19937(seed=seed)
+    _set_global_rng_seed(seed)
+
+
+seed(int(time.time()))
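A usage sketch of the new `RNG` class, distilled from the tests added later in this diff (it assumes at least two xpu devices, as the tests do):

```python
import numpy as np
from megengine.random import RNG

m1 = RNG(seed=111, device="xpu0")
m2 = RNG(seed=111, device="xpu1")

# Equal seeds give equal streams, even across devices: the handle owns the state.
a = m1.uniform(size=(100,))
b = m2.uniform(size=(100,))
np.testing.assert_equal(a.numpy(), b.numpy())

# A second draw from the same generator advances its private stream.
c = m1.uniform(size=(100,))
assert not (a.numpy() == c.numpy()).all()
```

Each `RNG` instance owns a native handle created through `_new_rng_handle` and releases it in `__del__`, so its stream is isolated from the global one used by `megengine.random.normal` and `uniform`.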
@@ -10,7 +10,10 @@
  */
 #include "./ops.h"
 #include "./helper.h"
+#include "./tensor.h"

+#include "megbrain/common.h"
+#include "megbrain/imperative.h"
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/opr_attr.h"
@@ -491,21 +494,15 @@ void init_ops(py::module m) {
     _init_py_op_base(m);
     INIT_ALL_OP(m)

-    m.def("new_rng_handle", &RNGMixin::new_handle);
-    // FIXME: RNG op might execute after handle released due to async dispatch,
-    // which would cause memory leak or use-after-free
-    m.def("delete_rng_handle", &RNGMixin::delete_handle);
-    m.def("set_rng_seed", &set_rng_seed);
-    py::class_<UniformRNG, std::shared_ptr<UniformRNG>, OpDef>(m, "UniformRNG")
-        .def(py::init<>())
-        .def(py::init<mgb::CompNode>())
-        .def(py::init<RNGMixin::Handle>());
-    py::class_<GaussianRNG, std::shared_ptr<GaussianRNG>, OpDef>(m, "GaussianRNG")
-        .def(py::init<>())
-        .def(py::init<mgb::CompNode>())
-        .def(py::init<float, float>())
-        .def(py::init<float, float, mgb::CompNode>())
-        .def(py::init<float, float, RNGMixin::Handle>());
+    m.def("new_rng_handle", &rng::new_handle);
+    m.def("delete_rng_handle", [](size_t handle) {
+        // An RNG op might still execute after its handle is released, due to
+        // async dispatch, so sync before deleting the handle to avoid a memory
+        // leak or use-after-free.
+        python::interpreter_for_py->sync();
+        mgb::CompNode::sync_all();
+        py_task_q.wait_all_task_finish();
+        rng::delete_handle(handle);
+    }, py::call_guard<py::gil_scoped_release>());
+    m.def("set_global_rng_seed", &rng::set_global_rng_seed);
+    m.def("get_global_rng_seed", &rng::get_global_rng_seed);
 }
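The raw bindings can also be driven directly; the sketch below is distilled from the new `test_uniform_op` in this diff (the device name is an assumption, any valid comp node works):

```python
from megengine import tensor
from megengine.core._imperative_rt import CompNode
from megengine.core._imperative_rt.core2 import apply
from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
from megengine.core.ops.builtin import UniformRNG

cn = CompNode("xpu0")
h = new_rng_handle(cn, 233333)          # the handle pins both device and seed
op = UniformRNG(seed=233333, handle=h)  # the op's seed must match the handle's
(out,) = apply(op, tensor((8, 9), dtype="int32"))
delete_rng_handle(h)                    # now syncs internally before freeing
```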
@@ -0,0 +1,121 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+
+import megengine
+from megengine import tensor
+from megengine.core._imperative_rt import CompNode
+from megengine.core._imperative_rt.core2 import apply
+from megengine.core._imperative_rt.ops import (
+    delete_rng_handle,
+    get_global_rng_seed,
+    new_rng_handle,
+)
+from megengine.core.ops.builtin import GaussianRNG, UniformRNG
+from megengine.random import RNG
+from megengine.random.rng import _normal, _uniform
+
+
+def test_gaussian_op():
+    shape = (
+        8,
+        9,
+        11,
+        12,
+    )
+    shape = tensor(shape, dtype="int32")
+    op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0)
+    (output,) = apply(op, shape)
+    assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
+    assert np.fabs(np.sqrt(output.numpy().var()) - 3.0) < 1e-1
+    assert str(output.device) == str(CompNode("xpux"))
+
+    cn = CompNode("xpu2")
+    seed = 233333
+    h = new_rng_handle(cn, seed)
+    op = GaussianRNG(seed=seed, mean=3.0, std=1.0, handle=h)
+    (output,) = apply(op, shape)
+    delete_rng_handle(h)
+    assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
+    assert np.fabs(np.sqrt(output.numpy().var()) - 1.0) < 1e-1
+    assert str(output.device) == str(cn)
+
+
+def test_uniform_op():
+    shape = (
+        8,
+        9,
+        11,
+        12,
+    )
+    shape = tensor(shape, dtype="int32")
+    op = UniformRNG(seed=get_global_rng_seed())
+    (output,) = apply(op, shape)
+    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
+    assert str(output.device) == str(CompNode("xpux"))
+
+    cn = CompNode("xpu2")
+    seed = 233333
+    h = new_rng_handle(cn, seed)
+    op = UniformRNG(seed=seed, handle=h)
+    (output,) = apply(op, shape)
+    delete_rng_handle(h)
+    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
+    assert str(output.device) == str(cn)
+
+
+def test_UniformRNG():
+    m1 = RNG(seed=111, device="xpu0")
+    m2 = RNG(seed=111, device="xpu1")
+    m3 = RNG(seed=222, device="xpu0")
+    out1 = m1.uniform(size=(100,))
+    out1_ = m1.uniform(size=(100,))
+    out2 = m2.uniform(size=(100,))
+    out3 = m3.uniform(size=(100,))
+
+    np.testing.assert_equal(out1.numpy(), out2.numpy())
+    assert out1.device == "xpu0" and out2.device == "xpu1"
+    assert not (out1.numpy() == out3.numpy()).all()
+    assert not (out1.numpy() == out1_.numpy()).all()
+
+    low = -234
+    high = 123
+    out = m1.uniform(low=low, high=high, size=(20, 30, 40))
+    out_shp = out.shape
+    if isinstance(out_shp, tuple):
+        assert out_shp == (20, 30, 40)
+    else:
+        assert all(out.shape.numpy() == np.array([20, 30, 40]))
+    assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1
+
+
+def test_NormalRNG():
+    m1 = RNG(seed=111, device="xpu0")
+    m2 = RNG(seed=111, device="xpu1")
+    m3 = RNG(seed=222, device="xpu0")
+    out1 = m1.normal(size=(100,))
+    out1_ = m1.uniform(size=(100,))
+    out2 = m2.normal(size=(100,))
+    out3 = m3.normal(size=(100,))
+
+    np.testing.assert_equal(out1.numpy(), out2.numpy())
+    assert out1.device == "xpu0" and out2.device == "xpu1"
+    assert not (out1.numpy() == out3.numpy()).all()
+    assert not (out1.numpy() == out1_.numpy()).all()
+
+    mean = -1
+    std = 2
+    out = m1.normal(mean=mean, std=std, size=(20, 30, 40))
+    out_shp = out.shape
+    if isinstance(out_shp, tuple):
+        assert out_shp == (20, 30, 40)
+    else:
+        assert all(out.shape.numpy() == np.array([20, 30, 40]))
+    assert np.abs(out.mean().numpy() - mean) / std < 0.1
+    assert np.abs(np.std(out.numpy()) - std) < 0.1
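The moment checks above are statistical but safe: with 8*9*11*12 = 9504 samples, the standard error of the mean is roughly std/sqrt(n) ≈ 0.03 for std = 3, well inside the 1e-1 tolerance. The same check in plain NumPy, for reference:

```python
import numpy as np

# Sketch only: reproduces the tolerance reasoning with NumPy's own generator.
samples = np.random.default_rng(0).normal(loc=1.0, scale=3.0, size=(8, 9, 11, 12))
stderr = 3.0 / np.sqrt(samples.size)  # ~0.031, so 1e-1 is a comfortable bound
assert np.fabs(samples.mean() - 1.0) < 1e-1
assert np.fabs(samples.std() - 3.0) < 1e-1
```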
@@ -1,76 +0,0 @@
-# -*- coding: utf-8 -*-
-# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
-#
-# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-import numpy as np
-
-from megengine import tensor
-from megengine.core._imperative_rt import CompNode
-from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
-from megengine.core.ops.builtin import GaussianRNG, UniformRNG
-from megengine.core.tensor.core import apply
-
-
-def test_gaussian_rng():
-    shape = (
-        8,
-        9,
-        11,
-        12,
-    )
-    shape = tensor(shape, dtype="int32")
-    op = GaussianRNG(1.0, 3.0)
-    (output,) = apply(op, shape)
-    assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
-    assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1
-    assert str(output.device) == str(CompNode("xpux"))
-
-    cn = CompNode("xpu1")
-    op = GaussianRNG(-1.0, 2.0, cn)
-    (output,) = apply(op, shape)
-    assert np.fabs(output.numpy().mean() - (-1.0)) < 1e-1
-    assert np.sqrt(output.numpy().var()) - 2.0 < 1e-1
-    assert str(output.device) == str(cn)
-
-    cn = CompNode("xpu2")
-    seed = 233333
-    h = new_rng_handle(cn, seed)
-    op = GaussianRNG(3.0, 1.0, h)
-    (output,) = apply(op, shape)
-    delete_rng_handle(h)
-    assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
-    assert np.sqrt(output.numpy().var()) - 1.0 < 1e-1
-    assert str(output.device) == str(cn)
-
-
-def test_uniform_rng():
-    shape = (
-        8,
-        9,
-        11,
-        12,
-    )
-    shape = tensor(shape, dtype="int32")
-    op = UniformRNG()
-    (output,) = apply(op, shape)
-    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
-    assert str(output.device) == str(CompNode("xpux"))
-
-    cn = CompNode("xpu1")
-    op = UniformRNG(cn)
-    (output,) = apply(op, shape)
-    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
-    assert str(output.device) == str(cn)
-
-    cn = CompNode("xpu2")
-    seed = 233333
-    h = new_rng_handle(cn, seed)
-    op = UniformRNG(h)
-    (output,) = apply(op, shape)
-    delete_rng_handle(h)
-    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
-    assert str(output.device) == str(cn)
@@ -2,7 +2,7 @@
  * \file imperative/src/impl/ops/rng.cpp
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
@@ -10,23 +10,23 @@
  */

 #include "megbrain/imperative/ops/rng.h"
 #include <bits/stdint-uintn.h>
 #include "megbrain/comp_node_env.h"
 #include "megbrain/graph/helper.h"
 #include "megbrain/opr/rand.h"
-//#include "megbrain/common.h"

 #include "../op_trait.h"
+#include "../dnn_op_helper.h"

-namespace mgb {
-namespace imperative {
+namespace mgb::imperative::rng {

 namespace {

 template <typename HandleFactory, typename THandle>
 class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
 public:
     using DT = CompNode::DeviceType;
     using Handle = THandle;
+    using OpTypeInfo = size_t;

     template <typename... Args>
     Handle new_handle(Args&&... args) {
@@ -38,27 +38,26 @@ public:
         size_t removed = 0;
         if (!is_finalized()) {
             MGB_LOCK_GUARD(m_mtx);
-            removed = m_handle2op.erase(handle);
+            removed = m_handle2ops.erase(handle);
         }
         static_cast<HandleFactory*>(this)->do_delete_handle(handle);
         return removed;
     }

     template <typename DnnOp>
-    auto get_dnn_op(Handle handle, CompNode cn) {
+    auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) {
         mgb_assert(!is_finalized());
         DnnOpWithMutex* dnn_op_with_mtx;
         {
             MGB_LOCK_GUARD(m_mtx);
-            dnn_op_with_mtx = &m_handle2op[handle];
+            dnn_op_with_mtx = &m_handle2ops[handle][tpinfo];
         }
         auto dnn_handle =
                 MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
-        DnnOp* dnn_op;
         std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
         bool initialized = false;
-        if ((dnn_op = dynamic_cast<DnnOp*>(dnn_op_with_mtx->op.get())) !=
-            nullptr) {
+        DnnOp* dnn_op = static_cast<DnnOp*>(dnn_op_with_mtx->op.get());
+        if (dnn_op != nullptr) {
             mgb_assert(dnn_op->handle() == dnn_handle);
             initialized = true;
         } else {
@@ -77,35 +76,30 @@ private:
     struct DnnOpWithMutex {
         std::mutex mtx;
         std::unique_ptr<megdnn::OperatorBase> op;
+        DnnOpWithMutex(): op{nullptr} {}
     };

     std::shared_ptr<void> on_comp_node_finalize() override {
         MGB_LOCK_GUARD(m_mtx);
-        m_handle2op.clear();
+        m_handle2ops.clear();
         return {};
     }

-    std::unordered_map<Handle, DnnOpWithMutex> m_handle2op;
+    std::unordered_map<Handle, std::unordered_map<OpTypeInfo, DnnOpWithMutex> > m_handle2ops;
     std::mutex m_mtx;
 };
 class RNGDnnOpManager final
-        : public DnnOpManagerT<RNGDnnOpManager, RNGMixin::Handle> {
+        : public DnnOpManagerT<RNGDnnOpManager, Handle> {
 public:
     Handle new_handle(CompNode comp_node, uint64_t seed) {
         MGB_LOCK_GUARD(sm_mtx);
         return DnnOpManagerBase::new_handle(comp_node, seed);
     }

     size_t delete_handle(Handle handle) {
-        size_t ret = 0;
-        {
-            MGB_LOCK_GUARD(sm_mtx);
-            auto iter = sm_partial2full.find(handle);
-            if (iter != sm_partial2full.end()) {
-                for (auto&& h : iter->second) {
-                    ret += DnnOpManagerBase::delete_handle(h.second);
-                }
-                sm_partial2full.erase(iter);
-            }
-        }
-        ret += DnnOpManagerBase::delete_handle(handle);
-        return ret;
+        MGB_LOCK_GUARD(sm_mtx);
+        return DnnOpManagerBase::delete_handle(handle);
     }

     Handle do_new_handle(CompNode comp_node, uint64_t seed) {
@@ -118,32 +112,26 @@ public:
     }

     static uint64_t get_seed(Handle handle) {
         if (!handle) { return glob_default_seed; }
         return reinterpret_cast<HandleData*>(handle)->seed;
     }

     static CompNode get_comp_node(Handle handle) {
         mgb_assert(handle, "invalid handle");
         return reinterpret_cast<HandleData*>(handle)->comp_node;
     }

-    static Handle get_full_handle(Handle handle, CompNode comp_node) {
-        if (get_comp_node(handle).valid()) {
-            return handle;
-        }
-        MGB_LOCK_GUARD(sm_mtx);
-        auto&& full = sm_partial2full[handle][comp_node];
-        if (!full) {
-            full = inst().new_handle(comp_node, get_seed(handle));
-        }
-        return full;
-    }
-
     static Handle get_default_handle(CompNode comp_node) {
-        static Handle glob_partial_handle =
-                inst().new_handle(CompNode{}, glob_default_seed);
-        if (!comp_node.valid()) {
-            return glob_partial_handle;
+        mgb_assert(comp_node.valid());
+        MGB_LOCK_GUARD(sm_mtx);
+        auto&& glob_handle = glob_default_handles[comp_node];
+        if (!glob_handle) {
+            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
+        } else if (get_seed(glob_handle) != glob_default_seed) {
+            inst().DnnOpManagerBase::delete_handle(glob_handle);
+            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
         }
-        return get_full_handle(glob_partial_handle, comp_node);
+        return glob_handle;
     }

     static RNGDnnOpManager& inst() {
@@ -152,9 +140,15 @@ public:
     }

     static void set_glob_default_seed(uint64_t seed) {
+        MGB_LOCK_GUARD(sm_mtx);
         glob_default_seed = seed;
     }

+    static uint64_t get_glob_default_seed() {
+        MGB_LOCK_GUARD(sm_mtx);
+        return glob_default_seed;
+    }
+
 private:
     struct HandleData {
         CompNode comp_node;
@@ -165,16 +159,13 @@ private:
     MemPool<HandleData> m_handle_pool;

     static std::mutex sm_mtx;
-    static std::unordered_map<Handle, CompNode::UnorderedMap<Handle>>
-            sm_partial2full;
+    static CompNode::UnorderedMap<Handle> glob_default_handles;
     static uint64_t glob_default_seed;
 };

 uint64_t RNGDnnOpManager::glob_default_seed = 0;
 std::mutex RNGDnnOpManager::sm_mtx;
-std::unordered_map<RNGDnnOpManager::Handle,
-                   CompNode::UnorderedMap<RNGDnnOpManager::Handle>>
-        RNGDnnOpManager::sm_partial2full;
+CompNode::UnorderedMap<Handle> RNGDnnOpManager::glob_default_handles;

 template <typename Op>
 struct OpMeth;
@@ -185,7 +176,11 @@ struct OpMeth<UniformRNG> {
     using Param = DnnOp::Param;
     using OpNode = mgb::opr::UniformRNG;
     static Param make_param(const UniformRNG& rng) {
-        return {RNGDnnOpManager::get_seed(rng.handle())};
+        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
+        mgb_assert(handle_seed == rng.seed,
+                "inconsistent rng seed: rng op: %lu handle: %lu",
+                handle_seed, rng.seed);
+        return {handle_seed};
     }
 };
@@ -195,7 +190,11 @@ struct OpMeth<GaussianRNG> {
     using Param = DnnOp::Param;
     using OpNode = mgb::opr::GaussianRNG;
     static Param make_param(const GaussianRNG& rng) {
-        return {RNGDnnOpManager::get_seed(rng.handle()), rng.mean, rng.std};
+        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
+        mgb_assert(handle_seed == rng.seed,
+                "inconsistent rng seed: rng op: %lu handle: %lu",
+                handle_seed, rng.seed);
+        return {handle_seed, rng.mean, rng.std};
     }
 };
@@ -206,23 +205,22 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
     auto dest = outputs[0];
     auto cn = dest->comp_node();
-    auto handle = RNGDnnOpManager::get_full_handle(rng.handle(), cn);
-    {
-        auto handle_cn = RNGDnnOpManager::get_comp_node(handle);
-        mgb_assert(cn == handle_cn,
-                "inconsistent comp_node: handle: %s, output: %s",
-                cn.to_string().c_str(), handle_cn.to_string().c_str());
+    auto handle = rng.handle;
+    if (!handle) {
+        handle = RNGDnnOpManager::get_default_handle(cn);
     }

     // retrieve dnn_op from glob cache
     auto dnn_op_thread_safe = RNGDnnOpManager::inst()
-            .get_dnn_op<typename OpMeth<Op>::DnnOp>(handle, cn);
+            .get_dnn_op<typename OpMeth<Op>::DnnOp>(
+                    handle, reinterpret_cast<size_t>(op.dyn_typeinfo()),
+                    cn);
     auto initialized = std::get<0>(dnn_op_thread_safe);
     auto dnn_op = std::get<1>(dnn_op_thread_safe);
     if (initialized) {
         auto handle_seed = RNGDnnOpManager::get_seed(handle);
         mgb_assert(dnn_op->param().seed == handle_seed,
                 "inconsistent rng seed: handle: %lu, dnn_op: %lu",
                 handle_seed, dnn_op->param().seed);
     }
     dnn_op->param() = OpMeth<Op>::make_param(rng);
@@ -239,9 +237,12 @@ template <typename Op>
 SmallVector<LogicalTensorDesc> infer_output_attrs(
         const OpDef& op, const SmallVector<TensorPtr>& inputs) {
     LogicalTensorDesc dest;
-    dest.comp_node = op.cast_final_safe<Op>().comp_node();
-    if (!dest.comp_node.valid())
+    auto handle = op.cast_final_safe<Op>().handle;
+    if (handle) {
+        dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
+    } else {
         dest.comp_node = inputs[0]->comp_node();
+    }

     auto hv = inputs[0]->get_value().proxy_to_default_cpu();
     TensorShape tshape;
@@ -263,15 +264,22 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
 }

 template<typename Op>
-cg::OperatorNodeBase* apply_on_var_node(
-        const OpDef& def, const VarNodeArray& inputs) {
+SymbolVar apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
     size_t nr_inp = inputs.size();
-    mgb_assert(nr_inp == 1, "UniformRNG expects 1 inputs; got %lu actually",
-            nr_inp);
     auto&& rng = def.cast_final_safe<Op>();
+    mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually",
+            rng.dyn_typeinfo()->name,
+            nr_inp);
     auto param = OpMeth<Op>::make_param(rng);
-    return OpMeth<Op>::OpNode::make(
-            inputs[0], param, {rng.comp_node()}).node()->owner_opr();
+    OperatorNodeConfig config;
+    if (rng.handle) {
+        config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)};
+    } else {
+        config = {rng.make_name()};
+    }
+    return OpMeth<Op>::OpNode::make(inputs[0], param, config);
 }
 template<typename T>
@@ -309,28 +317,22 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
 } // anonymous namespace

-RNGMixin::RNGMixin(CompNode cn):
-    m_handle(RNGDnnOpManager::get_default_handle(cn)) {}
-
-uint64_t RNGMixin::seed() const {
-    return RNGDnnOpManager::get_seed(m_handle);
-}
-
-CompNode RNGMixin::comp_node() const {
-    return RNGDnnOpManager::get_comp_node(m_handle);
-}
-
-RNGMixin::Handle RNGMixin::new_handle(CompNode comp_node, uint64_t seed) {
+Handle new_handle(CompNode comp_node, uint64_t seed) {
     return RNGDnnOpManager::inst().new_handle(comp_node, seed);
 }

-size_t RNGMixin::delete_handle(Handle handle) {
+size_t delete_handle(Handle handle) {
     return RNGDnnOpManager::inst().delete_handle(handle);
 }

-void set_rng_seed(uint64_t seed) {
+void set_global_rng_seed(uint64_t seed) {
     RNGDnnOpManager::set_glob_default_seed(seed);
 }

+uint64_t get_global_rng_seed() {
+    return RNGDnnOpManager::get_glob_default_seed();
+}
+
 #define REG_RNG_OP(NAME)\
 namespace { \
 OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
@@ -339,12 +341,10 @@ OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
     .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
     .fallback(); \
 } \
-MGB_DYN_TYPE_OBJ_FINAL_IMPL(NAME);

 REG_RNG_OP(UniformRNG)
 REG_RNG_OP(GaussianRNG)

-} // namespace imperative
-} // namespace mgb
+} // namespace mgb::imperative::rng

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -429,34 +429,6 @@ OP_TRAIT_REG(AssertEqual, AssertEqual)
     .fallback();
 }} // assert_equal

-namespace { namespace uniform_rng {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& op = static_cast<const UniformRNG&>(def);
-    mgb_assert(inputs.size() == 1);
-    OperatorNodeConfig config{op.make_name()};
-    return opr::UniformRNG::make(inputs[0], op.param(), config);
-}
-OP_TRAIT_REG(UniformRNG, UniformRNG)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // uniform_rng
-
-namespace { namespace gaussian_rng {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& op = static_cast<const GaussianRNG&>(def);
-    mgb_assert(inputs.size() == 1);
-    OperatorNodeConfig config{op.make_name()};
-    return opr::GaussianRNG::make(inputs[0], op.param(), config);
-}
-OP_TRAIT_REG(GaussianRNG, GaussianRNG)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // gaussian_rng
-
 namespace { namespace roi_align {
 VarNodeArray apply_on_var_node(
         const OpDef& def,
@@ -2,7 +2,7 @@
  * \file imperative/src/include/megbrain/imperative/ops/rng.h
  * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
@@ -12,84 +12,15 @@
 #pragma once

 #include "megbrain/imperative/op_def.h"
+#include "megbrain/imperative/ops/autogen.h"

-namespace mgb::imperative {
+namespace mgb::imperative::rng {

-class RNGMixin {
-public:
-    using Handle = size_t;
+using Handle = size_t;

-    static Handle new_handle(
-            CompNode comp_node={}, uint64_t seed=0);
-
-    static size_t delete_handle(Handle handle);
+Handle new_handle(CompNode comp_node, uint64_t seed);
+size_t delete_handle(Handle handle);
+void set_global_rng_seed(uint64_t seed);
+uint64_t get_global_rng_seed();

-    Handle handle() const {
-        return m_handle;
-    }
-
-    uint64_t seed() const;
-
-    CompNode comp_node() const;
-
-protected:
-    RNGMixin(Handle handle): m_handle(handle) {}
-    RNGMixin(CompNode comp_node);
-
-private:
-    Handle m_handle;
-};
-
-class GaussianRNG : public OpDefImplBase<GaussianRNG>,
-                    public RNGMixin {
-    MGB_DYN_TYPE_OBJ_FINAL_DECL;
-public:
-    float mean = 1.0f, std = 0.0;
-    GaussianRNG(CompNode comp_node_): RNGMixin(comp_node_) {}
-    GaussianRNG(float mean_=1.0, float std_=0.0, CompNode comp_node_={}):
-        GaussianRNG(comp_node_) { mean = mean_; std = std_; }
-    GaussianRNG(float mean_, float std_, Handle handle):
-        RNGMixin(handle), mean(mean_), std(std_) {}
-
-    size_t hash() const override {
-        XXHash xxhash{};
-        auto append = [&xxhash](auto field){
-            auto hash_val = HashTrait<decltype(field)>::eval(field);
-            xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
-        };
-        append(dyn_typeinfo());
-        append(seed());
-        append(mean);
-        append(std);
-        return xxhash.digest();
-    }
-
-    bool is_same_st(const Hashable& rhs_) const override {
-        auto&& rhs = static_cast<const GaussianRNG&>(rhs_);
-        return rhs.seed() == seed()
-            && rhs.mean == mean
-            && rhs.std == std;
-    }
-};
-
-class UniformRNG : public OpDefImplBase<UniformRNG>,
-                   public RNGMixin {
-    MGB_DYN_TYPE_OBJ_FINAL_DECL;
-public:
-    UniformRNG(CompNode comp_node_={}): RNGMixin(comp_node_) {}
-    UniformRNG(Handle handle): RNGMixin(handle) {}
-
-    size_t hash() const override {
-        return hash_pair_combine(
-                mgb::hash(seed()),
-                reinterpret_cast<std::uintptr_t>(dyn_typeinfo()));
-    }
-
-    bool is_same_st(const Hashable& rhs_) const override {
-        auto&& rhs = static_cast<const UniformRNG&>(rhs_);
-        return rhs.dyn_typeinfo() == dyn_typeinfo()
-            && rhs.seed() == seed();
-    }
-};
-
-void set_rng_seed(uint64_t seed);
-
-} // namespace mgb::imperative
+} // namespace mgb::imperative::rng
@@ -14,6 +14,7 @@

 using namespace mgb;
 using namespace imperative;
+using namespace imperative::rng;

 template<typename Op, typename ...Args>
 void check_rng_basic(Args&& ...args) {
@@ -22,24 +23,31 @@ void check_rng_basic(Args&& ...args) {
              {3, 4, 5, 6},
              {2333}})
         for (auto&& cn: {
-                CompNode::load("cpu0"),
-                CompNode::load("xpu0")})
+                CompNode::load("xpu0"),
+                CompNode::load("xpu1")})
         {
-            auto op = Op::make(std::forward<Args>(args)..., cn);
+            Handle h = new_handle(cn, 123);
+            auto op = Op::make(std::forward<Args>(args)..., h);
             DeviceTensorND tshape_dev;
             cg::copy_shape_to_tensor_value(tshape_dev, tshape);
-            auto outputs = OpDef::apply_on_physical_tensor(*op, {Tensor::make(tshape_dev)});
+            SmallVector<TensorPtr> inputs = {Tensor::make(tshape_dev)};
+            auto outputs = OpDef::apply_on_physical_tensor(*op, inputs);
             ASSERT_TRUE(outputs[0]->layout().eq_shape(tshape));
             ASSERT_TRUE(cn == outputs[0]->comp_node());
+            // sync before deleting the handle
+            for (auto&& p: outputs) {
+                p->get_value();
+            }
+            delete_handle(h);
         }
 }

 TEST(TestImperative, UniformRNGBasic) {
-    check_rng_basic<UniformRNG>();
+    check_rng_basic<UniformRNG>(123);
 }

 TEST(TestImperative, GaussianRNGBasic) {
-    check_rng_basic<GaussianRNG>(2.f, 3.f);
+    check_rng_basic<GaussianRNG>(123, 2.f, 3.f);
 }

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -114,17 +114,33 @@ def TopK: MgbHashableOp<"TopK", [TopKParam]>;

 def NvOf: MgbHashableOp<"NvOf", [NvOfParam]>;

 def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> {
-  let hashFunction = [{return mgb::hash($_self.dyn_typeinfo());}];
-  let cmpFunction = [{return true;}];
+  let extraArguments = (ins
+    MgbSizeTAddr:$handle
+  );
+  let hashFunction = [{
+    return mgb::hash_pair_combine(
+      mgb::hash($_self.dyn_typeinfo()),
+      mgb::hash($_self.handle));
+  }];
+  let cmpFunction = [{return $0.handle == $1.handle;}];
 }

 def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
+  let extraArguments = (ins
+    MgbSizeTAddr:$handle
+  );
   let hashFunction = [{
     return mgb::hash_pair_combine(
       mgb::hash($_self.dyn_typeinfo()),
-      mgb::hash_pair_combine(mgb::hash($_self.mean), mgb::hash($_self.std)));
+      mgb::hash_pair_combine(
+        mgb::hash($_self.handle),
+        mgb::hash_pair_combine(
+          mgb::hash($_self.mean),
+          mgb::hash($_self.std))
+      )
+    );
   }];
-  let cmpFunction = [{return $0.mean == $1.mean && $0.std == $1.std;}];
+  let cmpFunction = [{return $0.handle == $1.handle && $0.mean == $1.mean && $0.std == $1.std;}];
 }

 def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> {