GitOrigin-RevId: b8feb49321
tags/v1.2.0
| @@ -0,0 +1,15 @@ | |||
| # -*- coding: utf-8 -*- | |||
| # MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| # | |||
| # Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from ..core._imperative_rt.core2 import apply | |||
| from ..core.ops import builtin | |||
| from ..core.ops.builtin import InplaceAdd | |||
| def _inplace_add_(dest, delta, alpha, beta): | |||
| return dest._reset(apply(InplaceAdd(), dest, delta, alpha, beta)[0]) | |||
| @@ -502,6 +502,8 @@ class trace: | |||
| # profile | |||
| if self._profiling: | |||
| self._profiler = GraphProfiler(graph) | |||
| if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")): | |||
| graph.options.var_sanity_check_first_run = False | |||
| def _compile(self): | |||
| graph = self._graph = G.Graph() | |||
| @@ -1073,7 +1075,7 @@ def apply_compiled_mode(op: OpDef, *args: RawTensor): | |||
| return active_trace._apply_op(op, args) | |||
| def apply_const_compiled_mode(value, dtype, device, is_const): | |||
| def apply_const_compiled_mode(value, dtype, device, is_const, no_cache): | |||
| if skip_tracing: | |||
| args = [ | |||
| RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x | |||
| @@ -1099,7 +1101,7 @@ def apply_with_tracing(op: OpDef, *args: RawTensor): | |||
| return list(outputs) | |||
| def apply_const_with_tracing(value, dtype, device, is_const): | |||
| def apply_const_with_tracing(value, dtype, device, is_const, no_cache): | |||
| if active_trace._symbolic: | |||
| outputs = apply_const_symbolic_mode(value, dtype, device) | |||
| else: | |||
| @@ -6,8 +6,10 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import os | |||
| from typing import Iterable, Tuple, Union | |||
| from ..functional.inplace import _inplace_add_ | |||
| from ..tensor import Parameter, tensor | |||
| from .optimizer import Optimizer | |||
| @@ -58,15 +60,24 @@ class Adam(Optimizer): | |||
| eps = param_group["eps"] | |||
| beta0, beta1 = param_group["betas"] | |||
| def make_scalar(val): | |||
| return tensor([val]) | |||
| # since `conver_inputs` is disabled for param updates, | |||
| # scalar should be explicitly tansforred to tensor | |||
| _lr = tensor([lr]) | |||
| _weight_decay = tensor([weight_decay]) | |||
| _eps = tensor([eps]) | |||
| _beta0, _beta1 = tensor([beta0]), tensor([beta1]) | |||
| c1 = tensor([1.0]) | |||
| c05 = tensor([0.5]) | |||
| _lr, _neg_lr = map(make_scalar, (lr, -lr)) | |||
| _weight_decay = make_scalar(weight_decay) | |||
| _eps = make_scalar(eps) | |||
| _beta0, _beta1 = map(make_scalar, (beta0, beta1)) | |||
| c1, c05 = map(make_scalar, (1.0, 0.5)) | |||
| inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")) | |||
| if inplace_mode: | |||
| # reduce device sync | |||
| c1_sub_beta0, c1_sub_beta1 = map(make_scalar, (1 - beta0, 1 - beta1)) | |||
| for param in param_group["params"]: | |||
| if param.grad is None: | |||
| @@ -77,18 +88,38 @@ class Adam(Optimizer): | |||
| grad += param * _weight_decay | |||
| states = self._state[param] | |||
| step = states["step"] | |||
| step, exp_avg, exp_avg_sq = ( | |||
| states["step"], | |||
| states["exp_avg"], | |||
| states["exp_avg_sq"], | |||
| ) | |||
| if inplace_mode: | |||
| _inplace_add_(step, c1, alpha=c1, beta=c1) | |||
| _inplace_add_(exp_avg, grad, alpha=_beta0, beta=c1_sub_beta0) | |||
| _inplace_add_( | |||
| exp_avg_sq, grad * grad, alpha=_beta1, beta=c1_sub_beta1, | |||
| ) | |||
| delta = (exp_avg / (c1 - _beta0 ** step)) / ( | |||
| (exp_avg_sq / (c1 - _beta1 ** step)) ** c05 + _eps | |||
| ) | |||
| _inplace_add_(param, delta, alpha=c1, beta=_neg_lr) | |||
| continue | |||
| # step = step + c1 | |||
| step += c1 | |||
| exp_avg = states["exp_avg"] | |||
| exp_avg_sq = states["exp_avg_sq"] | |||
| exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0) | |||
| exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad) | |||
| # exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0) | |||
| exp_avg *= _beta0 | |||
| exp_avg += grad * (c1 - _beta0) | |||
| # exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad) | |||
| exp_avg_sq *= _beta1 | |||
| exp_avg_sq += (c1 - _beta1) * (grad * grad) | |||
| delta = (exp_avg / (c1 - _beta0 ** step)) / ( | |||
| (exp_avg_sq / (c1 - _beta1 ** step)) ** c05 + _eps | |||
| ) | |||
| param -= _lr * delta | |||
| # not inplace change, need to update underlying tensor handler in state | |||
| states["exp_avg"]._reset(exp_avg) | |||
| states["exp_avg_sq"]._reset(exp_avg_sq) | |||
| @@ -96,6 +96,7 @@ class Optimizer(metaclass=ABCMeta): | |||
| "optimizer can only optimize Parameters, but one of the params is " | |||
| + str(type(param)) | |||
| ) | |||
| param._reset(Tensor(param.numpy(), no_cache=True)) | |||
| for name, default in self._defaults.items(): | |||
| if default is required and name not in param_group: | |||
| @@ -121,7 +122,7 @@ class Optimizer(metaclass=ABCMeta): | |||
| initializer = np.zeros(param.shape, dtype=np.float32) | |||
| state_dict = self._state.setdefault(param, {}) | |||
| assert state_name not in state_dict | |||
| state = Tensor(initializer) | |||
| state = Tensor(initializer, no_cache=True) | |||
| state_dict[state_name] = state | |||
| @abstractmethod | |||
| @@ -6,8 +6,10 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import os | |||
| from typing import Iterable, Union | |||
| from ..functional.inplace import _inplace_add_ | |||
| from ..tensor import Parameter, tensor | |||
| from .optimizer import Optimizer | |||
| @@ -54,10 +56,16 @@ class SGD(Optimizer): | |||
| # since `conver_inputs` is disabled for param updates, | |||
| # scalar should be explicitly tansforred to tensor | |||
| _lr = tensor([lr]) | |||
| _weight_decay = tensor([weight_decay]) | |||
| _momentum = tensor([momentum]) | |||
| inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")) | |||
| if inplace_mode: | |||
| _neg_lr = tensor([-lr]) | |||
| c1 = tensor([1.0]) | |||
| for param in param_group["params"]: | |||
| if param.grad is None: | |||
| continue | |||
| @@ -66,10 +74,21 @@ class SGD(Optimizer): | |||
| if weight_decay != 0.0: | |||
| grad += param * _weight_decay | |||
| if inplace_mode: | |||
| if momentum: | |||
| v = self._state[param]["momentum_buffer"] | |||
| _inplace_add_(v, grad, alpha=_momentum, beta=c1) | |||
| _inplace_add_(param, v, alpha=c1, beta=_neg_lr) | |||
| else: | |||
| _inplace_add_(param, grad, alpha=c1, beta=_neg_lr) | |||
| continue | |||
| if momentum: | |||
| v = self._state[param]["momentum_buffer"] | |||
| v = _momentum * v + grad | |||
| # v = v * _momentum + grad | |||
| v *= _momentum | |||
| v += grad | |||
| param -= _lr * v | |||
| self._state[param]["momentum_buffer"]._reset(v) | |||
| else: | |||
| param -= _lr * grad | |||
| @@ -28,7 +28,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
| dmap_callback = None | |||
| q_dict = {"mode": None, "scale": None, "zero_point": None} | |||
| def __new__(cls, data, dtype=None, device=None, is_const=False): | |||
| def __new__(cls, data, dtype=None, device=None, is_const=False, no_cache=False): | |||
| if device is None: | |||
| cn = get_default_device() | |||
| elif isinstance(device, str): | |||
| @@ -49,7 +49,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
| if 0 in data.strides: | |||
| data = data.squeeze().reshape(data.shape) | |||
| obj = _Tensor.__new__(cls, data, dtype, cn, is_const) | |||
| obj = _Tensor.__new__(cls, data, dtype, cn, is_const, no_cache) | |||
| return obj | |||
| @property | |||
| @@ -38,9 +38,9 @@ std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) { | |||
| std::shared_ptr<Tensor> make_tensor(CompNode cn, Tensor* shape, float v = 0) { | |||
| HostTensorND scalar{cn, {{1}, dtype::Float32()}}; | |||
| scalar.ptr<float>()[0] = v; | |||
| interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar); | |||
| interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false); | |||
| auto&& t = std::make_shared<Tensor>(handle); | |||
| auto&& res = broadcast_to(t.get(), shape); | |||
| auto res = broadcast_to(t.get(), shape); | |||
| return res; | |||
| } | |||
| @@ -231,13 +231,14 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||
| } | |||
| } else { | |||
| py::detail::loader_life_support life_sup; // FIXME!!!required to cast DType | |||
| if (nargs != 4 && nargs != 5) { | |||
| throw py::type_error("expect 4 or 5 arguments"); | |||
| } | |||
| auto data = tup[0].cast<py::array>(); | |||
| DType dtype = tup[1].cast<DType>(); | |||
| CompNode cn = tup[2].cast<CompNode>(); | |||
| bool is_const = tup[3].cast<bool>(); | |||
| if (nargs != 4) { | |||
| throw py::type_error("expect 3 arguments"); | |||
| } | |||
| bool no_cache = nargs == 5 ? tup[4].cast<bool>() : false; | |||
| // const op | |||
| if (is_const && is_tracing) { | |||
| @@ -259,10 +260,10 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||
| interpreter::Interpreter::Handle handle; | |||
| constexpr auto size_threshhold = TensorShape::MAX_NDIM; | |||
| if (data.size() > size_threshhold) { | |||
| handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype)); | |||
| handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype), no_cache); | |||
| } else { | |||
| HostTensorND ret(cn); | |||
| handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype)); | |||
| handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype), no_cache); | |||
| } | |||
| m_tensor = std::make_shared<Tensor>(handle); | |||
| @@ -6,6 +6,9 @@ | |||
| # Unless required by applicable law or agreed to in writing, | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| import itertools | |||
| import os | |||
| import numpy as np | |||
| import megengine | |||
| @@ -58,13 +61,16 @@ def test_sgd_momentum(): | |||
| np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5) | |||
| np.testing.assert_almost_equal( | |||
| optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34 | |||
| optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34, 5 | |||
| ) | |||
| def test_sgd_momentum_trace(): | |||
| for symbolic in (True, False): | |||
| origin_inplace = os.getenv("MEGENGINE_INPLACE_UPDATE") | |||
| symbolic = (True, False) | |||
| inplace = (0, 1) | |||
| for symbolic, inplace in itertools.product(symbolic, inplace): | |||
| os.environ["MEGENGINE_INPLACE_UPDATE"] = str(inplace) | |||
| @trace(symbolic=symbolic) | |||
| def train_func(data, *, model=None, optim=None, gm=None): | |||
| @@ -101,5 +107,9 @@ def test_sgd_momentum_trace(): | |||
| train_func(data, model=net, optim=optim, gm=gm) | |||
| np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5) | |||
| np.testing.assert_almost_equal( | |||
| optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34 | |||
| optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34, 5 | |||
| ) | |||
| if origin_inplace: | |||
| os.environ["MEGENGINE_INPLACE_UPDATE"] = origin_inplace | |||
| else: | |||
| del os.environ["MEGENGINE_INPLACE_UPDATE"] | |||
| @@ -325,7 +325,6 @@ def test_raise_on_trace(): | |||
| @trace | |||
| def add_abc(a, b, c): | |||
| print("Hello") | |||
| ps = a + b | |||
| result = ps + c | |||
| if step_count == bad_step: | |||
| @@ -11,6 +11,7 @@ | |||
| #include "megbrain/comp_node_env.h" | |||
| #include "megbrain/comp_node.h" | |||
| #include "megbrain/imperative/physical_tensor.h" | |||
| using namespace megdnn; | |||
| @@ -29,19 +30,21 @@ struct DnnOprCaller { | |||
| Workspace workspace; | |||
| std::unique_ptr<Opr> op; | |||
| DnnOprCaller(CompNode cn): cn(cn) { | |||
| DnnOprCaller(CompNode cn): cn(cn), op(create_operator(cn)) {} | |||
| static std::unique_ptr<Opr> create_operator(CompNode cn) { | |||
| auto&& handle = MegDNNHandle::get( | |||
| CompNodeEnv::from_comp_node(cn)).handle(); | |||
| op = handle->create_operator<Opr>(); | |||
| return handle->create_operator<Opr>(); | |||
| } | |||
| megdnn::Workspace create_workspace(TensorLayout layout) { | |||
| dev_tensor = Tensor::make(layout, cn)->dev_tensor(); | |||
| workspace = megdnn::Workspace(dev_tensor.raw_ptr(), | |||
| workspace = megdnn::Workspace(dev_tensor.raw_ptr(), | |||
| dev_tensor.storage().size()); | |||
| return workspace; | |||
| } | |||
| ~DnnOprCaller() { | |||
| using DT = CompNode::DeviceType; | |||
| if (cn.device_type() == DT::CPU && cn != CompNode::default_cpu()) { | |||
| @@ -52,5 +55,36 @@ struct DnnOprCaller { | |||
| } | |||
| }; | |||
| template <size_t OSize> | |||
| class MegDNNDynOutMallocImpl final: public megdnn::DynOutMallocPolicy { | |||
| using Output = std::array<TensorPtr, OSize>; | |||
| CompNode m_cn; | |||
| Output m_out; | |||
| public: | |||
| MegDNNDynOutMallocImpl(CompNode cn): m_cn{cn} {} | |||
| megdnn::TensorND alloc_output( | |||
| size_t id, DType dtype, const TensorShape &shape, | |||
| void *user_data) override { | |||
| TensorLayout m_layout(shape, dtype); | |||
| m_out[id] = Tensor::make(m_layout, m_cn); | |||
| return m_out[id]->dev_tensor().as_megdnn(); | |||
| } | |||
| void* alloc_workspace(size_t sz, void *user_data) override { | |||
| return m_cn.alloc_device(sz); | |||
| } | |||
| void free_workspace(void *ptr, void *user_data) override { | |||
| m_cn.free_device(ptr); | |||
| } | |||
| TensorPtr at(size_t id) { | |||
| return m_out[id]; | |||
| } | |||
| }; | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| } // namespace mgb | |||
| @@ -28,13 +28,13 @@ Interpreter& Interpreter::inst() { | |||
| return inst_; | |||
| } | |||
| void* ChannelImpl::put(const HostTensorND& value) { | |||
| void* ChannelImpl::put(const HostTensorND& value, bool no_cache) { | |||
| auto info = alloc(); | |||
| info->desc.layout = value.layout(); | |||
| info->desc.comp_node = value.comp_node(); | |||
| info->desc.value = value.proxy_to_default_cpu(); | |||
| m_valid_handle.insert(info); | |||
| m_worker.add_task(Put{info, value}); | |||
| m_worker.add_task(Put{info, value, no_cache}); | |||
| return info; | |||
| } | |||
| @@ -395,7 +395,8 @@ void ChannelImpl::process_one_task(Command& cmd) { | |||
| using T = std::remove_reference_t<decltype(cmd)>; | |||
| try { | |||
| if constexpr (std::is_same_v<T, Put>) { | |||
| produce_tensor(cmd.dest, Tensor::make(cmd.value)); | |||
| auto value = cmd.no_cache ? std::make_shared<Tensor>(cmd.value) : Tensor::make(cmd.value); | |||
| produce_tensor(cmd.dest, std::move(value)); | |||
| } else if constexpr (std::is_same_v<T, ApplyOp>) { | |||
| SmallVector<TensorPtr> tensor_inputs; | |||
| tensor_inputs.reserve(cmd.inputs.size()); | |||
| @@ -45,7 +45,7 @@ struct TensorInfo { | |||
| HostTensorND h_value; | |||
| size_t locked = 0; | |||
| size_t recompute_times = 0; | |||
| struct ComputePath { | |||
| std::shared_ptr<OpDef> op; | |||
| SmallVector<TensorInfoPtr> inputs; | |||
| @@ -57,6 +57,7 @@ struct TensorInfo { | |||
| struct Put { | |||
| TensorInfo* dest; | |||
| HostTensorND value; | |||
| bool no_cache = false; | |||
| }; | |||
| struct ApplyOp { | |||
| std::shared_ptr<OpDef> op; | |||
| @@ -92,7 +93,7 @@ struct ChannelImpl : Interpreter::Channel { | |||
| ChannelImpl() : m_worker(this) {} | |||
| ~ChannelImpl() override; | |||
| Handle put(const HostTensorND& value) override; | |||
| Handle put(const HostTensorND& value, bool no_cache) override; | |||
| Handle put(const DeviceTensorND& value) override; | |||
| void del(Handle) override; | |||
| @@ -20,44 +20,6 @@ namespace mgb::imperative { | |||
| namespace { | |||
| class MegDNNDynOutMallocImpl final: public megdnn::DynOutMallocPolicy { | |||
| using Output = std::array<TensorPtr, 2>; | |||
| CompNode m_cn; | |||
| Output m_out; | |||
| public: | |||
| MegDNNDynOutMallocImpl(CompNode cn): m_cn{cn} {} | |||
| megdnn::TensorND alloc_output( | |||
| size_t id, DType dtype, const TensorShape &shape, | |||
| void *user_data) override; | |||
| void* alloc_workspace(size_t sz, void *user_data) override; | |||
| void free_workspace(void *ptr, void *user_data) override; | |||
| TensorPtr at(size_t id); | |||
| }; | |||
| megdnn::TensorND MegDNNDynOutMallocImpl::alloc_output( | |||
| size_t id, DType dtype, const TensorShape &shape, | |||
| void * /*user_data*/) { | |||
| TensorLayout m_layout(shape, dtype); | |||
| m_out[id] = Tensor::make(m_layout, m_cn); | |||
| return m_out[id]->dev_tensor().as_megdnn(); | |||
| } | |||
| void* MegDNNDynOutMallocImpl::alloc_workspace(size_t sz, void * /*user_data*/) { | |||
| return m_cn.alloc_device(sz); | |||
| } | |||
| void MegDNNDynOutMallocImpl::free_workspace(void *ptr, void * /*user_data*/) { | |||
| m_cn.free_device(ptr); | |||
| } | |||
| TensorPtr MegDNNDynOutMallocImpl::at(size_t id) { | |||
| return m_out[id]; | |||
| } | |||
| cg::OperatorNodeBase* apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| @@ -94,7 +56,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| dtype::Byte()); | |||
| auto dnn_workspace = dnn_op.create_workspace(m_layout); | |||
| MegDNNDynOutMallocImpl policy{inp->comp_node()}; | |||
| MegDNNDynOutMallocImpl<2> policy{inp->comp_node()}; | |||
| dnn_op.op->exec(inp->dev_tensor().as_megdnn(), | |||
| msk->dev_tensor().as_megdnn(), | |||
| @@ -11,8 +11,11 @@ | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/imperative/opr_utility.h" | |||
| #include "megbrain/opr/utility.h" | |||
| #include "../op_trait.h" | |||
| #include "../dnn_op_helper.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| @@ -84,12 +87,142 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| return {Tensor::make(out)}; | |||
| } | |||
| MGB_DEFINE_OPR_CLASS(ForceInplaceElemwise, cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) //{ | |||
| public: | |||
| struct Param{ | |||
| using Mode = megdnn::Elemwise::Param::Mode; | |||
| Mode mode; | |||
| size_t inplace_index; | |||
| }; | |||
| using Mode = Param::Mode; | |||
| ForceInplaceElemwise(const VarNodeArray& inputs, Param param, | |||
| OperatorNodeConfig config = {}) | |||
| : Super(inputs[0]->owner_graph(), config, "device_add_update", inputs), m_param{param} { | |||
| for (auto* input: inputs) { | |||
| add_input({input}); | |||
| } | |||
| add_output(None)-> | |||
| set_fwd_in2out_writable_force(input(param.inplace_index)). | |||
| add_flag(VarNode::Flag::NO_MEM_RECLAIM); | |||
| } | |||
| static SymbolVar make(const VarNodeArray& inputs, Param param) { | |||
| return SymbolVar{inputs[0]}.insert_single_output_opr<ForceInplaceElemwise>( | |||
| inputs, param); | |||
| } | |||
| static cg::OperatorNodeBase* shallow_copy( | |||
| const serialization::OprShallowCopyContext &ctx, | |||
| const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, | |||
| const OperatorNodeConfig &config); | |||
| protected: | |||
| NodeProp* do_make_node_prop() const override { | |||
| auto ret = Super::do_make_node_prop(); | |||
| ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR); | |||
| return ret; | |||
| } | |||
| void create_megdnn_opr() override { | |||
| auto opr = DnnOprCaller<megdnn::Elemwise>::create_operator(comp_node()); | |||
| opr->param().mode = m_param.mode; | |||
| set_megdnn_opr(std::move(opr)); | |||
| } | |||
| void scn_do_execute() override { | |||
| auto to_dnnnd = [&](auto* var){ return var->dev_tensor().as_megdnn(); }; | |||
| megdnn::TensorNDArray inputs_dnnnd; | |||
| for (auto* input: input()) { | |||
| inputs_dnnnd.push_back(to_dnnnd(input)); | |||
| } | |||
| mgb_assert(input(m_param.inplace_index)->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC), | |||
| "ForceInplaceElemwise cannot be applied in internal tensor"); | |||
| auto* out_dest = output(0); | |||
| auto* opr = static_cast<megdnn::Elemwise*>(megdnn_opr()); | |||
| opr->exec(std::move(inputs_dnnnd), | |||
| to_dnnnd(out_dest)); | |||
| } | |||
| void init_output_static_infer_desc() override { | |||
| using namespace cg::static_infer; | |||
| owner_graph()->static_infer_manager().register_shape_infer( | |||
| output(0), ShapeInferDesc::make_identity(input(m_param.inplace_index))); | |||
| } | |||
| private: | |||
| Param m_param; | |||
| void record_execute_deps(ExecDependencyArray& deps) override { | |||
| record_megdnn_opr(deps); | |||
| } | |||
| }; | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(ForceInplaceElemwise); | |||
| cg::OperatorNodeBase* ForceInplaceElemwise::shallow_copy( | |||
| const serialization::OprShallowCopyContext &ctx, | |||
| const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, | |||
| const OperatorNodeConfig &config) { | |||
| auto &&opr = opr_.cast_final_safe<ForceInplaceElemwise>(); | |||
| auto* graph = ctx.owner_graph(opr, inputs); | |||
| return graph->insert_opr(std::make_unique<ForceInplaceElemwise>(inputs, opr.m_param, config)); | |||
| } | |||
| MGB_REG_OPR_SHALLOW_COPY(ForceInplaceElemwise, ForceInplaceElemwise::shallow_copy); | |||
| cg::OperatorNodeBase* apply_inplace_add_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto dest = inputs[0], delta = inputs[1], | |||
| alpha = inputs[2], beta = inputs[3]; | |||
| auto mode = ForceInplaceElemwise::Param::Mode::FUSE_MUL_ADD4; | |||
| return ForceInplaceElemwise::make({alpha, dest, beta, delta}, {mode, 1}).node()->owner_opr(); | |||
| } | |||
| SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor( | |||
| const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs){ | |||
| auto dest = inputs[0], delta = inputs[1], | |||
| alpha = inputs[2], beta = inputs[3]; | |||
| auto tensor_to_scalar = [](const TensorPtr& tensor) -> float { | |||
| return *tensor->get_value().ptr<float>(); | |||
| }; | |||
| DnnOprCaller<megdnn::AddUpdate> caller{dest->comp_node()}; | |||
| caller.op->param() = { tensor_to_scalar(alpha), tensor_to_scalar(beta) }; | |||
| caller.op->exec(dest->dev_tensor().as_megdnn(), delta->dev_tensor().as_megdnn()); | |||
| return { std::make_shared<Tensor>(dest->blob(), dest->offset(), dest->layout()) }; | |||
| } | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_fallible( | |||
| const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| mgb_assert(inputs.size() == 4, "invalid input number for inplace_add"); | |||
| CompNode cn; | |||
| for (auto&& input: inputs) { | |||
| if (!cn.valid()) { | |||
| cn = input.comp_node; | |||
| } else { | |||
| mgb_assert(input.comp_node == cn, "inputs should be in same comp_node"); | |||
| } | |||
| } | |||
| auto dest = inputs[0], delta = inputs[1], | |||
| alpha = inputs[2], beta = inputs[3]; | |||
| bool succeed = dest.layout.ndim != 0; | |||
| if (succeed) { | |||
| mgb_assert(delta.layout.ndim == 0 || dest.layout.eq_shape(delta.layout), "dest and delta must have same shape"); | |||
| mgb_assert(alpha.layout.ndim == 0 || alpha.layout.eq_shape({1}), "alpha should be scalar"); | |||
| mgb_assert(beta.layout.ndim == 0 || beta.layout.eq_shape({1}), "beta should be scalar"); | |||
| } | |||
| mgb_assert(alpha.layout.dtype == dtype::Float32(), "alpha should be float32"); | |||
| mgb_assert(beta.layout.dtype == dtype::Float32(), "beta should be float32"); | |||
| return {{dest}, succeed}; | |||
| } | |||
| OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise) | |||
| .make_from_op_node(make_from_op_node) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||
| .fallback(); | |||
| OP_TRAIT_REG(InplaceAdd, InplaceAdd, opr::AddUpdate) | |||
| .apply_on_var_node(apply_inplace_add_on_var_node) | |||
| .apply_on_physical_tensor(apply_inplace_add_on_physical_tensor) | |||
| .infer_output_attrs_fallible(infer_inplace_add_output_attrs_fallible) | |||
| .fallback(); | |||
| } // anonymous namespace | |||
| } // namespace imperative | |||
| @@ -32,14 +32,22 @@ SmallVector<Tensor*> to_raw_ptr_array( | |||
| return ret; | |||
| } | |||
| SmallVector<LogicalTensorDesc> | |||
| infer_output_attrs(const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs) { | |||
| auto&& graph = ProxyGraph::get_default_graph(); | |||
| return graph->infer_output_attrs(def, to_raw_ptr_array(inputs)); | |||
| } | |||
| } // anonymous namespace | |||
| void exec(const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs_, | |||
| const SmallVector<TensorPtr>& outputs_) { | |||
| const SmallVector<TensorPtr>& inputs, | |||
| const SmallVector<TensorPtr>& outputs) { | |||
| auto&& graph = ProxyGraph::get_default_graph(); | |||
| auto inputs = to_raw_ptr_array(inputs_), | |||
| outputs = to_raw_ptr_array(outputs_); | |||
| auto raw_inputs = to_raw_ptr_array(inputs), | |||
| raw_outputs = to_raw_ptr_array(outputs); | |||
| CompNode::UnorderedSet used_cns; | |||
| for (auto&& out: outputs) { | |||
| for (auto&& out: raw_outputs) { | |||
| auto cn = out->comp_node(); | |||
| if (used_cns.insert(cn).second) { | |||
| for (auto&& in: inputs) { | |||
| @@ -50,7 +58,7 @@ void exec(const OpDef& def, | |||
| } | |||
| } | |||
| } | |||
| graph->invoke_op(def, inputs, outputs); | |||
| graph->invoke_op(def, raw_inputs, raw_outputs); | |||
| for (auto&& cn: used_cns) { | |||
| for (auto&& in: inputs) { | |||
| if (in->comp_node() != cn) { | |||
| @@ -60,14 +68,6 @@ void exec(const OpDef& def, | |||
| } | |||
| } | |||
| SmallVector<LogicalTensorDesc> | |||
| infer_output_attrs(const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs) { | |||
| auto&& graph = ProxyGraph::get_default_graph(); | |||
| return graph->infer_output_attrs(def, to_raw_ptr_array(inputs)); | |||
| } | |||
| } // anonymous namespace | |||
| SmallVector<TensorPtr> | |||
| apply_on_physical_tensor(const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs) { | |||
| @@ -21,7 +21,7 @@ struct Interpreter { | |||
| struct Channel { | |||
| virtual ~Channel() = default; | |||
| virtual Handle put(const HostTensorND& value) = 0; | |||
| virtual Handle put(const HostTensorND& value, bool no_cache) = 0; | |||
| virtual Handle put(const DeviceTensorND& value) = 0; | |||
| virtual void del(Handle) = 0; | |||
| @@ -101,6 +101,10 @@ public: | |||
| return m_layout; | |||
| } | |||
| size_t offset() const { | |||
| return m_offset; | |||
| } | |||
| DeviceTensorND dev_tensor(); | |||
| static TensorPtr make_scalar(DTypeScalar value, CompNode cn); | |||
| @@ -24,6 +24,10 @@ apply_on_physical_tensor(const OpDef& def, | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs); | |||
| void exec(const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs, | |||
| const SmallVector<TensorPtr>& outputs); | |||
| BackwardGraphResult | |||
| make_backward_graph(const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs, | |||
| @@ -239,4 +239,6 @@ def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypePara | |||
| ); | |||
| } | |||
| def InplaceAdd: MgbHashableOp<"InplaceAdd", [EmptyParam]>; | |||
| #endif // MGB_OPS | |||
| @@ -886,12 +886,9 @@ AddUpdate::AddUpdate(VarNode *dest, VarNode *delta, | |||
| m_param{param} | |||
| { | |||
| auto dest_opr = dest->owner_opr(); | |||
| mgb_throw_if(!(dest_opr->same_type<SharedDeviceTensor>() || | |||
| dest_opr->same_type<VolatileSharedDeviceTensor>()), | |||
| mgb_throw_if(dest_opr->same_type<ImmutableTensor>(), | |||
| GraphError, | |||
| "AddUpdate must be applied on SharedDeviceTensor; " | |||
| "got %s{%s} actually", | |||
| dest_opr->cname(), dest_opr->dyn_typeinfo()->name); | |||
| "AddUpdate cannot be applied on ImmutableTensor; "); | |||
| add_input({dest, delta}); | |||
| /* | |||
| @@ -80,6 +80,22 @@ public: | |||
| MGB_TYPEINFO_OBJ_IMPL(ForwardInputToOutput::MutableSrc); | |||
| void ForwardInputToOutput::mixin_init_rt_force_dynamic_mem_alloc_imply_chain( | |||
| OperatorNodeBase &opr) { | |||
| VarNode *valid_out = nullptr; | |||
| for (auto i: opr.output()) { | |||
| if (!i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | |||
| mgb_assert(!valid_out); | |||
| valid_out = i; | |||
| } | |||
| } | |||
| mgb_assert(valid_out); | |||
| // There may be many inputs such as in opr::VirtualDep, but we only forward first one | |||
| opr.input(0)->add_rt_force_dynamic_mem_alloc_imply_chain(valid_out); | |||
| valid_out->add_rt_force_dynamic_mem_alloc_imply_chain(opr.input(0)); | |||
| } | |||
| void ForwardInputToOutput::mixin_mem_plan_fwd_in2out_readonly( | |||
| OperatorNodeBase& opr) { | |||
| m_mem_fwd_success = opr.output(0)->set_fwd_in2out_readonly( | |||
| @@ -67,6 +67,7 @@ class ForwardInputToOutput: public cg::OperatorNodeMixinBase { | |||
| virtual void mixin_scn_do_execute(OperatorNodeBase &opr); | |||
| void mixin_init_rt_force_dynamic_mem_alloc_imply_chain(OperatorNodeBase &opr); | |||
| void mixin_mem_plan_fwd_in2out_readonly(OperatorNodeBase &opr); | |||
| void mixin_init_output_static_infer_desc(OperatorNodeBase &opr); | |||
| virtual cg::static_infer::ValueInferDesc mixin_get_static_infer_desc(OperatorNodeBase &opr); | |||
| @@ -173,8 +174,7 @@ MGB_DEFINE_CLS_WITH_SUPER(ForwardInputToOutput, | |||
| protected: | |||
| using Super::Super; | |||
| void init_rt_force_dynamic_mem_alloc_imply_chain() override { | |||
| mixin::init_rt_force_dynamic_mem_alloc_imply_chain_for_dyn_pass_i2o( | |||
| *this); | |||
| this->mixin_init_rt_force_dynamic_mem_alloc_imply_chain(*this); | |||
| } | |||
| void mem_plan_fwd_in2out_readonly() override { | |||