| @@ -7,7 +7,7 @@ | |||||
| # software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # pylint: disable=redefined-builtin | # pylint: disable=redefined-builtin | ||||
| from . import metric, vision | |||||
| from . import metric, utils, vision | |||||
| from .elemwise import * | from .elemwise import * | ||||
| from .math import * | from .math import * | ||||
| from .nn import * | from .nn import * | ||||
| @@ -11,6 +11,7 @@ from typing import Iterable, Union | |||||
| import numpy as np | import numpy as np | ||||
| from ..tensor import Tensor | from ..tensor import Tensor | ||||
| from .elemwise import abs, maximum, minimum | |||||
| from .math import topk as _topk | from .math import topk as _topk | ||||
| from .tensor import broadcast_to, transpose | from .tensor import broadcast_to, transpose | ||||
| @@ -0,0 +1,57 @@ | |||||
| # -*- coding: utf-8 -*- | |||||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| # | |||||
| # Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, | |||||
| # software distributed under the License is distributed on an | |||||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| from ..core._imperative_rt.core2 import apply | |||||
| from ..core._imperative_rt.core2 import sync as _sync | |||||
| from ..core.ops.builtin import AssertEqual | |||||
| from ..tensor import Tensor | |||||
| from .elemwise import abs, maximum, minimum | |||||
| def _assert_equal( | |||||
| expect: Tensor, actual: Tensor, *, maxerr: float = 0.0001, verbose: bool = False | |||||
| ): | |||||
| r""" | |||||
| Asserts two tensors equal and returns expected value (first input). | |||||
| It is a variant of python assert which is symbolically traceable (similar to ``numpy.testing.assert_equal``). | |||||
| If we want to verify the correctness of model, just ``assert`` its states and outputs. | |||||
| While sometimes we need to verify the correctness at different backends for *dumped* model | |||||
| (or in :class:`~jit.trace` context), and no python code could be executed in that case. | |||||
| Thus we have to use :func:`~functional.utils._assert_equal` instead. | |||||
| :param expect: expected tensor value | |||||
| :param actual: tensor to check value | |||||
| :param maxerr: max allowed error; error is defined as the minimal of absolute and relative error | |||||
| :param verbose: whether to print maxerr to stdout during opr exec | |||||
| :return: expected tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| x = tensor([1, 2, 3], np.float32) | |||||
| y = tensor([1, 2, 3], np.float32) | |||||
| print(F.utils._assert_equal(x, y, maxerr=0).numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [1. 2. 3.] | |||||
| """ | |||||
| err = ( | |||||
| abs(expect - actual) | |||||
| / maximum(minimum(abs(expect), abs(actual)), Tensor(1.0, dtype="float32")) | |||||
| ).max() | |||||
| result = apply(AssertEqual(maxerr=maxerr, verbose=verbose), expect, actual, err)[0] | |||||
| _sync() # sync interpreter to get exception | |||||
| return result | |||||
| @@ -28,7 +28,12 @@ from ..core._imperative_rt.core2 import ( | |||||
| unset_compiled, | unset_compiled, | ||||
| unset_tracing, | unset_tracing, | ||||
| ) | ) | ||||
| from ..core._imperative_rt.ops import CollectiveComm, RemoteRecv, RemoteSend | |||||
| from ..core._imperative_rt.ops import ( | |||||
| AssertEqual, | |||||
| CollectiveComm, | |||||
| RemoteRecv, | |||||
| RemoteSend, | |||||
| ) | |||||
| from ..core._trace_option import set_symbolic_shape | from ..core._trace_option import set_symbolic_shape | ||||
| from ..core._wrap import device as as_device | from ..core._wrap import device as as_device | ||||
| from ..core.ops.builtin import BackwardGraph, OpDef | from ..core.ops.builtin import BackwardGraph, OpDef | ||||
| @@ -110,7 +115,7 @@ class TensorInfo: | |||||
| self.data_reader = None | self.data_reader = None | ||||
| _io_op_types = {CollectiveComm, RemoteSend, RemoteRecv} | |||||
| _io_op_types = {AssertEqual, CollectiveComm, RemoteSend, RemoteRecv} | |||||
| class trace: | class trace: | ||||
| @@ -21,6 +21,7 @@ from megengine.core._trace_option import use_symbolic_shape | |||||
| from megengine.core.autodiff.grad import Grad | from megengine.core.autodiff.grad import Grad | ||||
| from megengine.core.tensor.utils import make_shape_tuple | from megengine.core.tensor.utils import make_shape_tuple | ||||
| from megengine.distributed.helper import get_device_count_by_fork | from megengine.distributed.helper import get_device_count_by_fork | ||||
| from megengine.jit import trace | |||||
| def test_where(): | def test_where(): | ||||
| @@ -746,3 +747,18 @@ def test_ones(val): | |||||
| shp = tensor(val) | shp = tensor(val) | ||||
| np_shp = np.array(val) | np_shp = np.array(val) | ||||
| np.testing.assert_equal(F.ones(shp), np.ones(np_shp)) | np.testing.assert_equal(F.ones(shp), np.ones(np_shp)) | ||||
| def test_assert_equal(): | |||||
| shape = (2, 3, 4, 5) | |||||
| x = F.ones(shape, dtype=np.float32) | |||||
| y = F.zeros(shape, dtype=np.float32) + 1.00001 | |||||
| z = F.utils._assert_equal(x, y) | |||||
| def test_assert_not_equal(): | |||||
| shape = (2, 3, 4, 5) | |||||
| x = F.ones(shape, dtype=np.float32) | |||||
| y = F.zeros(shape, dtype=np.float32) + 1.1 | |||||
| with pytest.raises(RuntimeError): | |||||
| z = F.utils._assert_equal(x, y) | |||||
| @@ -451,20 +451,22 @@ OP_TRAIT_REG(Identity, Identity) | |||||
| namespace { namespace assert_equal { | namespace { namespace assert_equal { | ||||
| auto apply_on_var_node( | auto apply_on_var_node( | ||||
| const OpDef& def, | |||||
| const VarNodeArray& inputs) { | |||||
| auto&& op = static_cast<const AssertEqual&>(def); | |||||
| mgb_assert(inputs.size() == 2); | |||||
| OperatorNodeConfig config{op.make_name()}; | |||||
| return opr::AssertEqual::make(inputs[0], inputs[1], op.param(), config); | |||||
| const OpDef& def, | |||||
| const VarNodeArray& inputs) { | |||||
| auto&& op = def.cast_final<AssertEqual>(); | |||||
| if (inputs.size() == 2) { | |||||
| return opr::AssertEqual::make(inputs[0], inputs[1], op.param()); | |||||
| } else { | |||||
| // workaround for MiniGraph, which only allow one opr in the graph | |||||
| mgb_assert(inputs.size() == 3); | |||||
| return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], op.param(), {}); | |||||
| } | } | ||||
| } | |||||
| OP_TRAIT_REG(AssertEqual, AssertEqual) | OP_TRAIT_REG(AssertEqual, AssertEqual) | ||||
| .apply_on_var_node(apply_on_var_node) | .apply_on_var_node(apply_on_var_node) | ||||
| .fallback(); | .fallback(); | ||||
| }} | |||||
| }} // assert_equal | |||||
| namespace { namespace uniform_rng { | namespace { namespace uniform_rng { | ||||
| auto apply_on_var_node( | auto apply_on_var_node( | ||||
| @@ -445,6 +445,12 @@ public: | |||||
| size_t nr_oprs_in_graph() const override {return m_opr_refkeeper.size();} | size_t nr_oprs_in_graph() const override {return m_opr_refkeeper.size();} | ||||
| void record_async_error(std::unique_ptr<MegBrainError> async_exc) override { | |||||
| if (!ProxyGraph::tm_async_error) { | |||||
| std::swap(async_exc, tm_async_error); | |||||
| } | |||||
| } | |||||
| std::unique_ptr<cg::AsyncExecutable> compile(const OutputSpec &out_spec) override {mgb_assert(0);} | std::unique_ptr<cg::AsyncExecutable> compile(const OutputSpec &out_spec) override {mgb_assert(0);} | ||||
| SmallVector<std::unique_ptr<cg::AsyncExecutable>> compile_multi_part( | SmallVector<std::unique_ptr<cg::AsyncExecutable>> compile_multi_part( | ||||
| const SmallVector<OutputSpec>& out_specs) override {mgb_assert(0);} | const SmallVector<OutputSpec>& out_specs) override {mgb_assert(0);} | ||||
| @@ -457,7 +463,6 @@ public: | |||||
| size_t get_device_memory_size(CompNode cn) override {mgb_assert(0);} | size_t get_device_memory_size(CompNode cn) override {mgb_assert(0);} | ||||
| size_t clear_device_memory() override {mgb_assert(0);} | size_t clear_device_memory() override {mgb_assert(0);} | ||||
| void set_as_subgraph(ComputingGraph &par_graph) override {mgb_assert(0);} | void set_as_subgraph(ComputingGraph &par_graph) override {mgb_assert(0);} | ||||
| void record_async_error(std::unique_ptr<MegBrainError> async_exc) override {mgb_assert(0);} | |||||
| }; | }; | ||||
| std::atomic<size_t> ProxyGraph::ProxyGraphImpl::m_node_id = 0; | std::atomic<size_t> ProxyGraph::ProxyGraphImpl::m_node_id = 0; | ||||
| @@ -861,6 +866,8 @@ TensorPtr ProxyGraph::as_tensor(cg::OperatorNodeBase* opr, bool share) { | |||||
| } | } | ||||
| } | } | ||||
| thread_local std::unique_ptr<MegBrainError> ProxyGraph::tm_async_error; | |||||
| } // namespace imperative | } // namespace imperative | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -24,6 +24,9 @@ namespace imperative { | |||||
| class ProxyGraph : public NonCopyableObj { | class ProxyGraph : public NonCopyableObj { | ||||
| public: | public: | ||||
| static ProxyGraph* get_default_graph(); | static ProxyGraph* get_default_graph(); | ||||
| static std::unique_ptr<MegBrainError> get_async_error() { | |||||
| return std::move(tm_async_error); | |||||
| } | |||||
| /********************** Physical Tensor API **********************/ | /********************** Physical Tensor API **********************/ | ||||
| @@ -98,6 +101,8 @@ private: | |||||
| std::unique_ptr<ExecEnv> m_env; | std::unique_ptr<ExecEnv> m_env; | ||||
| std::unique_ptr<StaticInferManager> m_static_infer_manager; | std::unique_ptr<StaticInferManager> m_static_infer_manager; | ||||
| std::unique_ptr<SeqCompNodeOptimizer> m_seq_comp_node_optimizer; | std::unique_ptr<SeqCompNodeOptimizer> m_seq_comp_node_optimizer; | ||||
| static thread_local std::unique_ptr<MegBrainError> tm_async_error; | |||||
| }; | }; | ||||
| } // namespace imperative | } // namespace imperative | ||||
| @@ -101,6 +101,10 @@ apply_on_physical_tensor(const OpDef& def, | |||||
| } | } | ||||
| } | } | ||||
| exec(def, inputs, outputs); | exec(def, inputs, outputs); | ||||
| auto async_error = ProxyGraph::get_async_error(); | |||||
| if (async_error) { | |||||
| throw *async_error; | |||||
| } | |||||
| return outputs; | return outputs; | ||||
| } | } | ||||