| @@ -16,6 +16,7 @@ import numpy as np | |||
| from .. import _config | |||
| from .._imperative_rt.common import CompNode | |||
| from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion | |||
| from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar | |||
| from ..ops import builtin | |||
| from . import amp | |||
| from .indexing import getitem, setitem | |||
| @@ -508,12 +509,8 @@ def _reduce(mode): | |||
| elif self.dtype == np.bool_: | |||
| data = data.astype("int32") | |||
| if axis is None: | |||
| data = data.reshape(-1) | |||
| assert not keepdims, "can not set axis=None and keepdims=True" | |||
| op = builtin.Reduce(mode=mode, axis=0) | |||
| (result,) = apply(op, data) | |||
| result = _remove_axis(result, 0) | |||
| result = _reduce_to_scalar(builtin.Reduce(mode=mode), data) | |||
| elif isinstance(axis, collections.abc.Iterable): | |||
| axis = _normalize_axis(self.ndim, axis, reverse=True) | |||
| for ai in axis: | |||
| @@ -69,7 +69,7 @@ class SGD(Optimizer): | |||
| inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")) | |||
| if inplace_mode: | |||
| _neg_lr = tensor(-lr, dtype="float32") | |||
| c1 = tensor([1.0]) | |||
| c1 = tensor(1.0) | |||
| for param in param_group["params"]: | |||
| if param.grad is None: | |||
| @@ -84,14 +84,15 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
| device: str = None, | |||
| is_const: bool = False, | |||
| no_cache: bool = False, | |||
| name: str = "", | |||
| name: str = None, | |||
| ): | |||
| if name is None: | |||
| name = "" | |||
| else: | |||
| self._set_name(name) | |||
| self._custom_name = name | |||
| self._name = name | |||
| self._short_name = name | |||
| self._set_name(self._name) | |||
| self._prefix = None | |||
| @property | |||
| @@ -46,17 +46,17 @@ void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) { | |||
| if (args[1] != Py_None) { | |||
| callback = py::reinterpret_borrow<py::object>(args[1]); | |||
| } | |||
| GenericFunction generic_callback = | |||
| [=](Span<ValueRef> inputs) -> std::vector<ValueRef> { | |||
| GenericFunction generic_callback = [=](Span<ValueRef> inputs) -> ValueRefList { | |||
| mgb_assert(inputs.size() == 1); | |||
| if (callback) { | |||
| callback(TensorWrapper::make(py_tensor_type, inputs[0])); | |||
| } | |||
| return {}; | |||
| }; | |||
| tw->m_tensor->reset(imperative::apply( | |||
| auto attached_value = imperative::apply( | |||
| AttachGrad(m_key), tw->m_tensor->data(), | |||
| FunctionValue::make(generic_callback))[0]); | |||
| FunctionValue::make(generic_callback))[0]; | |||
| tw->m_tensor->reset(attached_value); | |||
| } | |||
| void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list grads) { | |||
| @@ -98,7 +98,7 @@ ValueRef make_empty_tensor( | |||
| return res; | |||
| } | |||
| std::optional<std::vector<ValueRef>> elemwise_grad_rule( | |||
| std::optional<ValueRefList> elemwise_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| auto& elemwise = op.cast_final_safe<Elemwise>(); | |||
| @@ -117,7 +117,7 @@ std::optional<std::vector<ValueRef>> elemwise_grad_rule( | |||
| maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| std::vector<ValueRef> ret(2); | |||
| ValueRefList ret(2); | |||
| if (!grad) { | |||
| return ret; | |||
| } | |||
| @@ -132,7 +132,7 @@ std::optional<std::vector<ValueRef>> elemwise_grad_rule( | |||
| return imperative::apply(ApplyOp(op), inputs); | |||
| } | |||
| std::optional<std::vector<ValueRef>> reshape_grad_rule( | |||
| std::optional<ValueRefList> reshape_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| mgb_assert(inputs.size() == 2); | |||
| @@ -147,7 +147,7 @@ std::optional<std::vector<ValueRef>> reshape_grad_rule( | |||
| maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| std::vector<ValueRef> ret(2); | |||
| ValueRefList ret(2); | |||
| if (!grad) { | |||
| return ret; | |||
| } | |||
| @@ -162,7 +162,7 @@ std::optional<std::vector<ValueRef>> reshape_grad_rule( | |||
| return imperative::apply(ApplyOp(op), inputs); | |||
| } | |||
| std::optional<std::vector<ValueRef>> subtensor_grad_rule( | |||
| std::optional<ValueRefList> subtensor_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| auto&& subtensor = op.cast_final_safe<Subtensor>(); | |||
| @@ -180,9 +180,9 @@ std::optional<std::vector<ValueRef>> subtensor_grad_rule( | |||
| grad_op_ = std::move(grad_op)](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| std::vector<ValueRef> ret(1); | |||
| ValueRefList ret(1); | |||
| if (grad && inputs[0]) { | |||
| SmallVector<ValueRef> args_(inputs.size() + 1); | |||
| ValueRefList args_(inputs.size() + 1); | |||
| auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); | |||
| args_[0] = zeros; | |||
| args_[1] = grad; | |||
| @@ -197,7 +197,7 @@ std::optional<std::vector<ValueRef>> subtensor_grad_rule( | |||
| return imperative::apply(ApplyOp(op), inputs); | |||
| } | |||
| std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule( | |||
| std::optional<ValueRefList> indexingMultiAxisVec_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>(); | |||
| @@ -215,9 +215,9 @@ std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule( | |||
| grad_op_ = std::move(grad_op)](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| std::vector<ValueRef> ret(1); | |||
| ValueRefList ret(1); | |||
| if (grad && inputs[0]) { | |||
| SmallVector<ValueRef> args_(inputs.size() + 1); | |||
| ValueRefList args_(inputs.size() + 1); | |||
| auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); | |||
| args_[0] = zeros; | |||
| args_[1] = grad; | |||
| @@ -232,7 +232,7 @@ std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule( | |||
| return imperative::apply(ApplyOp(op), inputs); | |||
| } | |||
| std::optional<std::vector<ValueRef>> reduce_grad_rule( | |||
| std::optional<ValueRefList> reduce_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| auto& reduce = op.cast_final_safe<Reduce>(); | |||
| @@ -251,7 +251,7 @@ std::optional<std::vector<ValueRef>> reduce_grad_rule( | |||
| maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| std::vector<ValueRef> ret(1); | |||
| ValueRefList ret(1); | |||
| if (grad && shapes[0]) { | |||
| ret[0] = broadcast_to(grad, shapes[0]); | |||
| } | |||
| @@ -261,7 +261,7 @@ std::optional<std::vector<ValueRef>> reduce_grad_rule( | |||
| return imperative::apply(ApplyOp(op), inputs); | |||
| } | |||
| std::optional<std::vector<ValueRef>> addAxis_grad_rule( | |||
| std::optional<ValueRefList> addAxis_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| auto&& addAxis = op.cast_final_safe<AddAxis>(); | |||
| @@ -274,7 +274,7 @@ std::optional<std::vector<ValueRef>> addAxis_grad_rule( | |||
| maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| std::vector<ValueRef> ret(1); | |||
| ValueRefList ret(1); | |||
| if (grad && flag_) { | |||
| ret[0] = imperative::apply(*grad_op_, grad)[0]; | |||
| } | |||
| @@ -284,7 +284,7 @@ std::optional<std::vector<ValueRef>> addAxis_grad_rule( | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| std::optional<std::vector<ValueRef>> removeAxis_grad_rule( | |||
| std::optional<ValueRefList> removeAxis_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| auto&& removeAxis = op.cast_final_safe<RemoveAxis>(); | |||
| @@ -297,7 +297,7 @@ std::optional<std::vector<ValueRef>> removeAxis_grad_rule( | |||
| maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| std::vector<ValueRef> ret(1); | |||
| ValueRefList ret(1); | |||
| if (grad && flag_) { | |||
| ret[0] = imperative::apply(*grad_op_, grad)[0]; | |||
| } | |||
| @@ -307,7 +307,7 @@ std::optional<std::vector<ValueRef>> removeAxis_grad_rule( | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| std::optional<std::vector<ValueRef>> fastpathcopy_grad_rule( | |||
| std::optional<ValueRefList> fastpathcopy_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| mgb_assert(inputs.size() == 1); | |||
| @@ -316,7 +316,7 @@ std::optional<std::vector<ValueRef>> fastpathcopy_grad_rule( | |||
| maker.backward([](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| std::vector<ValueRef> ret(1); | |||
| ValueRefList ret(1); | |||
| if (grad) { | |||
| ret[0] = grad; | |||
| } | |||
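Taken together, the grad-rule hunks above switch every backward closure from `std::vector<ValueRef>` to the pre-sized `ValueRefList`. Below is a minimal, self-contained sketch of that closure shape, using `std::vector`/`std::optional` as stand-ins for `ValueRefList`/`ValueRef` (all names in the sketch are hypothetical); it mirrors the pass-through behaviour of `fastpathcopy_grad_rule`.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <vector>

// Stand-ins: Grad ~ ValueRef, GradList ~ ValueRefList (hypothetical).
using Grad = std::optional<double>;
using GradList = std::vector<Grad>;

std::function<GradList(const GradList&)> make_identity_backward() {
    return [](const GradList& grads) {
        assert(grads.size() == 1);
        GradList ret(1);        // pre-sized, mirrors "ValueRefList ret(1)"
        if (grads[0]) {
            ret[0] = grads[0];  // pass-through, as in fastpathcopy_grad_rule
        }
        return ret;             // empty entry means "no grad for this input"
    };
}

int main() {
    auto backward = make_identity_backward();
    assert(backward({Grad{2.5}})[0].value() == 2.5);
    assert(!backward({Grad{}})[0].has_value());
}
```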
| @@ -25,24 +25,23 @@ private: | |||
| py::function m_hook_fn; | |||
| int m_enabled = 0; | |||
| std::vector<ValueRef> apply_module_trace_hook( | |||
| const OpDef& op, Span<ValueRef> input_values) { | |||
| ValueRefList apply_module_trace_hook(const OpDef& op, Span<ValueRef> input_values) { | |||
| py::list input_tws; | |||
| for (auto&& input_value : input_values) { | |||
| input_tws.append(TensorWrapper::make(py_tensor_type, input_value)); | |||
| } | |||
| py::list output_tws = m_hook_fn(py::cast(op.shared_from_this()), *input_tws); | |||
| std::vector<ValueRef> outputs; | |||
| ValueRefList outputs(output_tws.size()); | |||
| auto it = outputs.begin(); | |||
| for (auto&& output_tw : output_tws) { | |||
| outputs.push_back( | |||
| TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data()); | |||
| *(it++) = TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data(); | |||
| } | |||
| return outputs; | |||
| } | |||
| public: | |||
| ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {} | |||
| std::vector<ValueRef> apply_transformation( | |||
| ValueRefList apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) override { | |||
| if (op.is<ApplyOp>() && m_enabled > 0) { | |||
| auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs); | |||
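The `apply_module_trace_hook` change above also switches from growing the result with `push_back` to sizing `ValueRefList` once and writing through an iterator. A rough stand-alone illustration of that fill idiom follows, with `std::vector` standing in for `ValueRefList` and purely hypothetical helper names:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Pre-size the output container once, then fill it through an iterator,
// matching how "ValueRefList outputs(output_tws.size())" is filled above.
std::vector<std::string> convert_all(const std::vector<int>& sources) {
    std::vector<std::string> outputs(sources.size());  // sized up front
    auto it = outputs.begin();
    for (auto&& src : sources) {
        *(it++) = std::to_string(src);                  // no push_back / realloc
    }
    return outputs;
}

int main() {
    auto out = convert_all({1, 2, 3});
    assert(out.size() == 3 && out[2] == "3");
}
```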
| @@ -87,7 +87,7 @@ PyObject* py_apply( | |||
| --nargs; | |||
| auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>(); | |||
| SmallVector<ValueRef, 64> tensors(nargs); | |||
| SmallVector<ValueRef, 8> tensors(nargs); | |||
| if (py::isinstance<PySymbolVar>(py::handle(args[0]))) { | |||
| // swap to a special context to reuse scalar handle | |||
| @@ -100,16 +100,15 @@ PyObject* py_apply( | |||
| Transformation::top()); | |||
| std::make_shared<ScalarTransformation>()->register_at( | |||
| Transformation::top()); | |||
| SmallVector<ValueRef> inputs(nargs); | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| auto* py_input = py::handle(args[i]).cast<PySymbolVar*>(); | |||
| ValueRef input = SymbolValue::make(py_input->m_node); | |||
| if (py_input->is_scalar) { | |||
| input = ScalarValue::make(input); | |||
| } | |||
| inputs[i] = input; | |||
| tensors[i] = input; | |||
| } | |||
| auto outputs = imperative::apply(*op, inputs); | |||
| auto outputs = imperative::apply(*op, tensors); | |||
| auto ret = pybind11::tuple(outputs.size()); | |||
| auto typeobj = py::handle(args[0]).get_type(); | |||
| for (size_t i = 0; i < outputs.size(); ++i) { | |||
| @@ -140,7 +139,7 @@ PyObject* py_apply( | |||
| } | |||
| } | |||
| auto outputs = imperative::apply(ApplyOp(*op), {tensors.data(), nargs}); | |||
| auto outputs = imperative::apply(*op, tensors); | |||
| size_t nout = outputs.size(); | |||
| auto ret = py::tuple(nout); | |||
| for (size_t i = 0; i < nout; ++i) { | |||
| @@ -214,16 +213,10 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||
| if (!name.empty()) { | |||
| m_tensor->reset( | |||
| imperative::apply(RenameValue(name), m_tensor->data())[0]); | |||
| mgb_assert( | |||
| ((std::string&)*m_tensor->data().name()) == name, | |||
| "result name incorrect"); | |||
| } | |||
| if (data.ndim() == 0) { | |||
| mgb_assert(m_tensor->is_scalar(), "result should be scalar"); | |||
| } | |||
| } | |||
| } | |||
| mgb_assert(m_tensor->data()); | |||
| } | |||
| PyObject* TensorWrapper::module_trace_info() { | |||
| @@ -1384,15 +1377,20 @@ void init_tensor(py::module m) { | |||
| std::function<bool(py::object, py::object)> array_comparator; | |||
| bool compare_value(ValueRef lhs, ValueRef rhs) { | |||
| if (!lhs.shape()->eq(*rhs.shape())) { | |||
| auto lvalue = lhs.numpy(); | |||
| auto rvalue = rhs.numpy(); | |||
| if (lvalue->shape() != rvalue->shape()) { | |||
| return false; | |||
| } | |||
| HostTensorND lvalue = lhs.numpy()->as_nd(true); | |||
| HostTensorND rvalue = rhs.numpy()->as_nd(true); | |||
| if (lvalue->shape().is_scalar()) { | |||
| return lvalue->item() == rvalue->item(); | |||
| } | |||
| HostTensorND lnd = lvalue->as_nd(true); | |||
| HostTensorND rnd = rvalue->as_nd(true); | |||
| auto larr = py::reinterpret_steal<py::array>( | |||
| npy::ndarray_from_tensor(lvalue, npy::ShareType::TRY_SHARE)); | |||
| npy::ndarray_from_tensor(lnd, npy::ShareType::TRY_SHARE)); | |||
| auto rarr = py::reinterpret_steal<py::array>( | |||
| npy::ndarray_from_tensor(rvalue, npy::ShareType::TRY_SHARE)); | |||
| npy::ndarray_from_tensor(rnd, npy::ShareType::TRY_SHARE)); | |||
| return array_comparator(larr, rarr); | |||
| } | |||
| @@ -1539,6 +1537,19 @@ void init_tensor(py::module m) { | |||
| } | |||
| }); | |||
| m.def("reduce_to_scalar", [](py::object op, py::object tensor) { | |||
| auto* tw = TensorWrapper::try_cast(tensor.ptr()); | |||
| auto make_scalar_shape = [&](CompNode device) { | |||
| return imperative::apply( | |||
| CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}), | |||
| HostStorage::make(device))[0]; | |||
| }; | |||
| auto output = imperative::apply( | |||
| *op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data(), | |||
| make_scalar_shape(tw->m_tensor->comp_node()))[0]; | |||
| return TensorWrapper::make(py_tensor_type, output); | |||
| }); | |||
| m.def("name_tensor", [](std::string name, py::object tensor) { | |||
| auto* tw = TensorWrapper::try_cast(tensor.ptr()); | |||
| auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0]; | |||
| @@ -1546,9 +1557,9 @@ void init_tensor(py::module m) { | |||
| }); | |||
| m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool { | |||
| SmallVector<ValueRef> values; | |||
| for (auto&& tensor : tensors) { | |||
| values.push_back(tensor.cast<TensorWrapper>().m_tensor->data()); | |||
| ValueRefList values(tensors.size()); | |||
| for (size_t i = 0; i < tensors.size(); ++i) { | |||
| values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data(); | |||
| } | |||
| auto outputs = imperative::apply(GetGradKey(), values); | |||
| if (outputs[0].is<GradKeyValue>()) { | |||
| @@ -1559,9 +1570,9 @@ void init_tensor(py::module m) { | |||
| }); | |||
| m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object { | |||
| SmallVector<ValueRef> values; | |||
| for (auto&& tensor : tensors) { | |||
| values.push_back(tensor.cast<TensorWrapper>().m_tensor->data()); | |||
| ValueRefList values(tensors.size()); | |||
| for (size_t i = 0; i < tensors.size(); ++i) { | |||
| values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data(); | |||
| } | |||
| auto outputs = imperative::apply(GetGradKey(), values); | |||
| if (auto* grad_key_val = outputs[0].as<GradKeyValue>()) { | |||
| @@ -1578,7 +1589,7 @@ void init_tensor(py::module m) { | |||
| mgb_assert(GradKeyWrapper::wrap_t::type().isinstance(py_key.ptr())); | |||
| auto* key = reinterpret_cast<GradKeyWrapper::wrap_t*>(py_key.ptr())->inst(); | |||
| GenericFunction generic_backward_fn = | |||
| [backward_fn](Span<ValueRef> output_grads) -> std::vector<ValueRef> { | |||
| [backward_fn](Span<ValueRef> output_grads) -> ValueRefList { | |||
| py::list output_grad_tws; | |||
| for (auto&& output_grad : output_grads) { | |||
| if (output_grad) { | |||
| @@ -1589,23 +1600,25 @@ void init_tensor(py::module m) { | |||
| } | |||
| } | |||
| py::tuple input_grad_tws = backward_fn(*output_grad_tws); | |||
| std::vector<ValueRef> input_grads; | |||
| for (auto&& input_grad_tw : input_grad_tws) { | |||
| ValueRefList input_grads(input_grad_tws.size()); | |||
| for (size_t i = 0; i < input_grad_tws.size(); ++i) { | |||
| auto input_grad_tw = input_grad_tws[i]; | |||
| if (!input_grad_tw.is_none()) { | |||
| input_grads.push_back( | |||
| py::cast<TensorWrapper>(input_grad_tw).m_tensor->data()); | |||
| input_grads[i] = | |||
| py::cast<TensorWrapper>(input_grad_tw).m_tensor->data(); | |||
| } else { | |||
| input_grads.push_back({}); | |||
| input_grads[i] = {}; | |||
| } | |||
| } | |||
| return input_grads; | |||
| }; | |||
| SmallVector<ValueRef> values; | |||
| for (auto&& input : inputs) { | |||
| values.push_back(input.cast<TensorWrapper>().m_tensor->data()); | |||
| ValueRefList values(inputs.size() + outputs.size()); | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| values[i] = inputs[i].cast<TensorWrapper>().m_tensor->data(); | |||
| } | |||
| for (auto&& output : outputs) { | |||
| values.push_back(output.cast<TensorWrapper>().m_tensor->data()); | |||
| for (size_t i = 0; i < outputs.size(); ++i) { | |||
| values[i + inputs.size()] = | |||
| outputs[i].cast<TensorWrapper>().m_tensor->data(); | |||
| } | |||
| auto wrapped_output_values = imperative::apply( | |||
| SetGrad(key->m_key, generic_backward_fn, inputs.size()), values); | |||
| @@ -39,7 +39,7 @@ namespace mgb::imperative::python { | |||
| extern interpreter::Interpreter::Channel* interpreter_for_py; | |||
| extern PyTypeObject* py_tensor_type; | |||
| struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||
| struct Tensor : NonCopyableObj { | |||
| private: | |||
| std::string m_name; | |||
| ValueRef m_data; | |||
| @@ -52,7 +52,7 @@ public: | |||
| ~Tensor() = default; | |||
| inline std::shared_ptr<Tensor> copy() { | |||
| auto ret = std::make_shared<Tensor>(m_data.unwrap()); | |||
| auto ret = std::make_shared<Tensor>(m_data); | |||
| ret->m_name = m_name; | |||
| return ret; | |||
| } | |||
| @@ -11,7 +11,15 @@ | |||
| #pragma once | |||
| #include <optional> | |||
| #include <string> | |||
| #include "pybind11/pybind11.h" | |||
| #include "megbrain/imperative/dispatch.h" | |||
| #include "megbrain/imperative/transformation.h" | |||
| #include "megbrain/imperative/value.h" | |||
| #include "megbrain/utils/small_vector.h" | |||
| namespace mgb::imperative::python { | |||
| struct TransformationManager { | |||
| @@ -58,4 +66,14 @@ struct TransformationManager { | |||
| return sl_instance; | |||
| } | |||
| }; | |||
| class PyValue final : public MixinValueImpl<PyValue, pybind11::object> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| std::string to_string() const { | |||
| return pybind11::str((const pybind11::object&)*this).cast<std::string>(); | |||
| } | |||
| }; | |||
| } // namespace mgb::imperative::python | |||
| @@ -45,7 +45,7 @@ CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout) | |||
| layout.is_contiguous() || layout.is_empty(), "layout should be contiguous"); | |||
| } | |||
| auto CreateTensor::parse(Span<ValueRef> inputs) -> Args { | |||
| auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args { | |||
| Args result; | |||
| for (auto&& input : inputs) { | |||
| if (auto host_storage = input.as_ref<HostStorage>()) { | |||
| @@ -16,70 +16,67 @@ | |||
| #include "megbrain/imperative/utils/map.h" | |||
| namespace mgb { | |||
| void imperative_log_profile_begin(const char* message); | |||
| void imperative_log_profile(const char* message); | |||
| void imperative_log_profile_end(const char* message); | |||
| namespace imperative { | |||
| std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs) { | |||
| static bool log_dispatch = MGB_GETENV("MGE_LOG_OP_DISPATCH"); | |||
| bool enable_watch = ValueRef::any_watching(); | |||
| auto& context = Transformation::get_context(); | |||
| size_t& depth = context.next_transformation; | |||
| static const char tabs_storage[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t"; | |||
| const char* tabs = tabs_storage + sizeof(tabs_storage) / sizeof(char) - depth - 1; | |||
| bool log_current_dispatch = log_dispatch; | |||
| if (enable_watch) { | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| auto& input = inputs[i]; | |||
| if (input.watching()) { | |||
| log_current_dispatch = true; | |||
| mgb_log_debug("%sinput[%zu] is %s", tabs, i, input.to_string().c_str()); | |||
| debug::notify_event("apply"); | |||
| } | |||
| } | |||
| } | |||
| // entrance | |||
| std::vector<ValueRef> outputs; | |||
| if (depth >= context.transformations.size()) { | |||
| // fallback | |||
| if (log_current_dispatch) { | |||
| mgb_log_debug( | |||
| "%sfallback apply %s in %s", tabs, op.to_string().c_str(), | |||
| imperative::to_string(inputs).c_str()); | |||
| namespace { | |||
| MGB_NOINLINE void copy_outputs( | |||
| ForwardAllocator<ValueRef>& allocator, ValueRefList& outputs) { | |||
| size_t nr_outputs = outputs.size(); | |||
| if (mgb_likely(nr_outputs == 1)) { | |||
| ValueRef output_copy; | |||
| output_copy = outputs[0]; | |||
| allocator.clear(); | |||
| outputs = ValueRefList({output_copy}); | |||
| } else if (!outputs.empty()) { | |||
| SmallVector<ValueRef> outputs_copy(nr_outputs); | |||
| for (size_t i = 0; i < nr_outputs; ++i) { | |||
| outputs_copy[i] = outputs[i]; | |||
| } | |||
| outputs = op.fallback(inputs); | |||
| outputs.clear(); | |||
| allocator.clear(); | |||
| outputs = {outputs_copy.begin(), outputs_copy.end()}; | |||
| } else { | |||
| // dispatch to stack top | |||
| auto& transformation = *context.transformations[depth]; | |||
| ++depth; | |||
| context.frames.push_back({op, inputs}); | |||
| CleanupGuard _{[&] { | |||
| context.frames.pop_back(); | |||
| --depth; | |||
| }}; | |||
| if (log_current_dispatch) { | |||
| mgb_log_debug( | |||
| "%s%s apply %s in %s", tabs, transformation.name().c_str(), | |||
| op.to_string().c_str(), imperative::to_string(inputs).c_str()); | |||
| } | |||
| outputs = transformation.apply_transformation(op, inputs); | |||
| allocator.clear(); | |||
| } | |||
| if (log_current_dispatch) { | |||
| mgb_log_debug("%sreturn %s", tabs, imperative::to_string(outputs).c_str()); | |||
| } | |||
| } // namespace | |||
| ValueRefList apply(const Operator& op, Span<ValueRef> inputs) { | |||
| auto& context = Transformation::get_context(); | |||
| size_t& depth = context.next_transformation; | |||
| bool top = depth == 0; | |||
| auto outputs = ([&] { | |||
| if (mgb_unlikely(depth >= context.transformations.size())) { | |||
| return op.fallback(inputs); | |||
| } else { | |||
| auto& transformation = *context.transformations[depth++]; | |||
| CleanupGuard _{[&] { --depth; }}; | |||
| return transformation.apply_transformation(op, inputs); | |||
| } | |||
| })(); | |||
| if (mgb_unlikely(top)) { | |||
| copy_outputs(context.allocator, outputs); | |||
| } | |||
| return outputs; | |||
| } | |||
| std::vector<ValueRef> apply(const OpDef& def, Span<ValueRef> inputs) { | |||
| ValueRefList apply(const OpDef& def, Span<ValueRef> inputs) { | |||
| return imperative::apply(ApplyOp{def}, inputs); | |||
| } | |||
| std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs) { | |||
| ValueRefList apply(const Subgraph& graph, Span<ValueRef> inputs) { | |||
| SmallVector<ValueRef> inputs_storage; | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| inputs_storage.push_back(inputs[i]); | |||
| } | |||
| auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<ValueRef> inputs, | |||
| size_t) { | |||
| auto outputs = imperative::apply(ApplyOp(*op), inputs); | |||
| auto outputs = imperative::apply(*op, inputs); | |||
| return SmallVector<ValueRef>(outputs.begin(), outputs.end()); | |||
| }; | |||
| auto make_const = [](TensorPtr constant) -> ValueRef { | |||
| @@ -101,7 +98,7 @@ std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs) { | |||
| DeviceStorage::make(device_value.storage()))[0]; | |||
| }; | |||
| auto outputs = graph.apply(inputs_storage, apply_functor, make_const); | |||
| return {outputs.begin(), outputs.end()}; | |||
| return ValueRefList{outputs.begin(), outputs.end()}; | |||
| } | |||
| } // namespace imperative | |||
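The new `copy_outputs` helper exists because the outputs of a top-level dispatch may still live in the per-dispatch `ForwardAllocator`; they are copied to owning storage before `allocator.clear()` runs. Below is a toy, self-contained sketch of that copy-out-then-clear idea (the `Arena`/`finish_dispatch` names are made up for illustration):

```cpp
#include <cassert>
#include <deque>
#include <vector>

// Hypothetical stand-ins: a per-dispatch arena and handles that point into it.
struct Arena {
    std::deque<int> slots;                   // per-dispatch allocations
    int* alloc(int v) { slots.push_back(v); return &slots.back(); }
    void clear() { slots.clear(); }          // invalidates every handle
};

// Mirrors copy_outputs(): values are copied out of the arena before it is
// cleared, so the caller never holds a pointer into freed per-dispatch memory.
std::vector<int> finish_dispatch(Arena& arena, const std::vector<int*>& outputs) {
    std::vector<int> owned;
    owned.reserve(outputs.size());
    for (int* p : outputs) owned.push_back(*p);  // copy out first
    arena.clear();                               // then release the arena
    return owned;
}

int main() {
    Arena a;
    std::vector<int*> outs{a.alloc(1), a.alloc(2)};
    auto owned = finish_dispatch(a, outs);
    assert(owned.size() == 2 && owned[0] == 1 && owned[1] == 2);
}
```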
| @@ -126,7 +126,7 @@ public: | |||
| m_frames[m_frames.size() - 1 - i] = {node, node->version()}; | |||
| node = node->parent(); | |||
| } | |||
| mgb_assert(node->is_root(), ""); | |||
| mgb_assert(node->is_root()); | |||
| } | |||
| Trace() = default; | |||
| std::string to_string() const { | |||
| @@ -3,7 +3,7 @@ | |||
| namespace mgb { | |||
| namespace imperative { | |||
| std::vector<ValueRef> Operator::fallback(Span<ValueRef> inputs) const { | |||
| ValueRefList Operator::fallback(Span<ValueRef> inputs) const { | |||
| mgb_throw(MegBrainError, "no fallback implementation for %s", to_string().c_str()); | |||
| } | |||
| @@ -99,19 +99,22 @@ Tensor::Tensor( | |||
| Tensor::Tensor(const HostTensorND& hv) : Tensor(hv.layout(), hv.comp_node()) { | |||
| constexpr int size_threshold = TensorShape::MAX_NDIM; | |||
| if (hv.layout().total_nr_elems() <= size_threshold) { | |||
| size_t nr_elems = hv.layout().total_nr_elems(); | |||
| if (nr_elems <= size_threshold) { | |||
| m_value = hv; | |||
| } | |||
| MGB_RECORD_EVENT( | |||
| profiler::HostToDeviceEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(), | |||
| dev_tensor().raw_ptr()); | |||
| dev_tensor().copy_from_fixlayout(hv); | |||
| // even though hv is saved in m_value, Tensor itself could be | |||
| // released before copy completes | |||
| MGB_RECORD_EVENT( | |||
| profiler::HostToDeviceFinishEvent, hv.layout(), hv.comp_node(), | |||
| hv.raw_ptr(), dev_tensor().raw_ptr()); | |||
| AsyncReleaser::inst()->add(hv); | |||
| if (nr_elems) { | |||
| MGB_RECORD_EVENT( | |||
| profiler::HostToDeviceEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(), | |||
| dev_tensor().raw_ptr()); | |||
| dev_tensor().copy_from_fixlayout(hv); | |||
| // even though hv is saved in m_value, Tensor itself could be | |||
| // released before copy completes | |||
| MGB_RECORD_EVENT( | |||
| profiler::HostToDeviceFinishEvent, hv.layout(), hv.comp_node(), | |||
| hv.raw_ptr(), dev_tensor().raw_ptr()); | |||
| AsyncReleaser::inst()->add(hv); | |||
| } | |||
| } | |||
| Tensor::Tensor(const DeviceTensorND& dv, const HostTensorND& hv) { | |||
| @@ -310,7 +310,8 @@ struct ChromeTimelineEventVisitor : EventVisitor<ChromeTimelineEventVisitor> { | |||
| } else if constexpr (std::is_same_v<TEvent, TensorGetPropEvent>) { | |||
| new_host_event("TensorGetProp", 'X') | |||
| .dur(0) | |||
| .args(current_tensor->detail(current->time)); | |||
| .args(current_tensor->detail(current->time)) | |||
| .arg("kind", imperative::to_string(event.prop)); | |||
| } else if constexpr (std::is_same_v<TEvent, TensorWaitPropEvent>) { | |||
| new_host_event("TensorWaitProp", 'B'); | |||
| } else if constexpr (std::is_same_v<TEvent, TensorWaitPropFinishEvent>) { | |||
| @@ -15,71 +15,109 @@ | |||
| namespace mgb { | |||
| namespace imperative { | |||
| std::vector<ValueRef> InterpreterTransformation::apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) { | |||
| if (auto* op_val = op.as<ApplyOp>()) { | |||
| if (op_val->op().same_type<FastpathCopy>()) { | |||
| return {inputs[0]}; | |||
| } | |||
| SmallVector<Handle> input_handles; | |||
| SmallVector<Handle> output_handles; | |||
| CleanupGuard _{[&] { | |||
| for (auto handle : output_handles) { | |||
| if (handle) { | |||
| m_channel->del(handle); | |||
| } | |||
| DTypeValue::ref_t InterpreterInfo::dtype() const { | |||
| if (!m_dtype) { | |||
| m_dtype = DTypeValue::make(handle()->channel()->get_dtype(handle()->handle())); | |||
| } | |||
| return m_dtype; | |||
| } | |||
| CompNodeValue::ref_t InterpreterInfo::comp_node() const { | |||
| if (!m_comp_node) { | |||
| m_comp_node = CompNodeValue::make( | |||
| handle()->channel()->get_device(handle()->handle())); | |||
| } | |||
| return m_comp_node; | |||
| } | |||
| ShapeValue::ref_t InterpreterInfo::shape() const { | |||
| if (!m_shape) { | |||
| m_shape = ShapeValue::make( | |||
| ValueShape::from(handle()->channel()->get_shape(handle()->handle()))); | |||
| } | |||
| return m_shape; | |||
| } | |||
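The three `InterpreterInfo` getters above lazily cache dtype, device and shape so repeated attribute queries do not hit the channel again. A small stand-alone sketch of the same memoization pattern, with hypothetical `Channel`/`Info` stand-ins:

```cpp
#include <optional>
#include <string>

// Toy stand-ins: the real code queries an interpreter channel per handle.
struct Channel {
    std::string get_dtype(int /*handle*/) const { return "float32"; }
};

class Info {
    const Channel* m_channel;
    int m_handle;
    mutable std::optional<std::string> m_dtype;  // filled on first access
public:
    Info(const Channel* c, int h) : m_channel(c), m_handle(h) {}
    const std::string& dtype() const {
        if (!m_dtype) {
            m_dtype = m_channel->get_dtype(m_handle);  // query once, then reuse
        }
        return *m_dtype;
    }
};

int main() {
    Channel ch;
    Info info(&ch, /*handle=*/0);
    // first call queries the channel, later calls reuse the cached value
    return info.dtype() == info.dtype() ? 0 : 1;
}
```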
| ValueRefList InterpreterTransformation::apply_op( | |||
| const ApplyOp& apply_op, Span<ValueRef> inputs) { | |||
| if (apply_op.op().same_type<FastpathCopy>()) { | |||
| return {inputs[0]}; | |||
| } | |||
| SmallVector<Handle> input_handles; | |||
| SmallVector<Handle> output_handles; | |||
| CleanupGuard _{[&] { | |||
| for (auto handle : output_handles) { | |||
| if (handle) { | |||
| m_channel->del(handle); | |||
| } | |||
| }}; | |||
| for (auto input : inputs) { | |||
| input_handles.push_back(*input.cast<InterpreterValue>().handle()); | |||
| } | |||
| output_handles = | |||
| m_channel->apply_op(op_val->op().shared_from_this(), input_handles); | |||
| std::vector<ValueRef> outputs; | |||
| for (auto& handle : output_handles) { | |||
| outputs.push_back(InterpreterValue::make(share_handle(handle))); | |||
| handle = nullptr; | |||
| } | |||
| return outputs; | |||
| }}; | |||
| for (auto input : inputs) { | |||
| input_handles.push_back(input.cast<InterpreterValue>().handle()->handle()); | |||
| } | |||
| output_handles = | |||
| m_channel->apply_op(apply_op.op().shared_from_this(), input_handles); | |||
| ValueRefList outputs(output_handles.size()); | |||
| for (size_t i = 0; i < output_handles.size(); ++i) { | |||
| outputs[i] = InterpreterValue::make(share_handle(output_handles[i])); | |||
| output_handles[i] = nullptr; | |||
| } | |||
| return outputs; | |||
| } | |||
| ValueRefList InterpreterTransformation::apply_get_attr( | |||
| const GetAttr& get_attr, Span<ValueRef> inputs) { | |||
| auto& input = inputs.item().cast<InterpreterValue>(); | |||
| ValueRef output; | |||
| switch (get_attr.attr()) { | |||
| case GetAttr::DType: | |||
| output = input.dtype(); | |||
| break; | |||
| case GetAttr::Shape: | |||
| output = input.shape(); | |||
| break; | |||
| case GetAttr::Device: | |||
| output = input.comp_node(); | |||
| break; | |||
| case GetAttr::Value: | |||
| output = HostValue::make(m_channel->get_value(input.handle()->handle())); | |||
| break; | |||
| case GetAttr::Data: | |||
| output = DeviceValue::make( | |||
| m_channel->get_dev_tensor(input.handle()->handle())); | |||
| break; | |||
| default: | |||
| mgb_throw( | |||
| MegBrainError, "Interpreter: malformed GetAttr: %s", | |||
| get_attr.to_string().c_str()); | |||
| } | |||
| return {output}; | |||
| } | |||
| ValueRefList InterpreterTransformation::apply_create_tensor( | |||
| const CreateTensor& create_tensor, Span<ValueRef> inputs) { | |||
| auto args = create_tensor.parse(inputs); | |||
| if (!args.device) { | |||
| // implies H2D | |||
| mgb_assert(args.host, "neither host and device value is valid"); | |||
| return {InterpreterValue::make(share_handle( | |||
| m_channel->put(*args.host, args.kind == CreateTensor::Unique)))}; | |||
| } else { | |||
| return {InterpreterValue::make(share_handle(m_channel->put( | |||
| *args.device, args.host ? *args.host : HostTensorND())))}; | |||
| } | |||
| } | |||
| ValueRefList InterpreterTransformation::apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) { | |||
| if (auto* op_val = op.as<ApplyOp>()) { | |||
| return apply_op(*op_val, inputs); | |||
| } else if (auto* get_attr = op.as<GetAttr>()) { | |||
| Handle handle = *inputs[0].cast<InterpreterValue>().handle(); | |||
| ValueRef output; | |||
| switch (get_attr->attr()) { | |||
| case GetAttr::DType: | |||
| output = DTypeValue::make(m_channel->get_dtype(handle)); | |||
| break; | |||
| case GetAttr::Shape: | |||
| output = ShapeValue::make( | |||
| ValueShape::from(m_channel->get_shape(handle))); | |||
| break; | |||
| case GetAttr::Device: | |||
| output = CompNodeValue::make(m_channel->get_device(handle)); | |||
| break; | |||
| case GetAttr::Value: | |||
| output = HostValue::make(m_channel->get_value(handle)); | |||
| break; | |||
| case GetAttr::Data: | |||
| output = DeviceValue::make(m_channel->get_dev_tensor(handle)); | |||
| break; | |||
| default: | |||
| mgb_throw( | |||
| MegBrainError, "Interpreter: malformed GetAttr: %s", | |||
| op.to_string().c_str()); | |||
| } | |||
| return {output}; | |||
| return apply_get_attr(*get_attr, inputs); | |||
| } else if (auto* create_tensor = op.as<CreateTensor>()) { | |||
| auto args = create_tensor->parse(inputs); | |||
| if (!args.device) { | |||
| // implies H2D | |||
| mgb_assert(args.host, "neither host and device value is valid"); | |||
| return {InterpreterValue::make(share_handle( | |||
| m_channel->put(*args.host, args.kind == CreateTensor::Unique)))}; | |||
| } else { | |||
| return {InterpreterValue::make(share_handle(m_channel->put( | |||
| *args.device, args.host ? *args.host : HostTensorND())))}; | |||
| } | |||
| return apply_create_tensor(*create_tensor, inputs); | |||
| } else if (auto* dtr_command = op.as<DTRCommand>()) { | |||
| auto handle = *inputs[0].cast<InterpreterValue>().handle(); | |||
| auto handle = inputs[0].cast<InterpreterValue>().handle()->handle(); | |||
| switch (dtr_command->kind()) { | |||
| case DTRCommand::Drop: | |||
| m_channel->drop(handle); | |||
| @@ -64,12 +64,13 @@ BackwardGraphWithClosure::BackwardGraphWithClosure( | |||
| size_t count = std::count_if( | |||
| save_for_backward.begin(), save_for_backward.end(), ranges::identity{}); | |||
| if (!backward_graph->precomp.empty()) { | |||
| SmallVector<ValueRef> inputs_and_outputs; | |||
| ValueRefList inputs_and_outputs(inputs.size() + outputs.size()); | |||
| auto it = inputs_and_outputs.begin(); | |||
| for (auto&& input : inputs) { | |||
| inputs_and_outputs.push_back(input); | |||
| *it++ = input; | |||
| } | |||
| for (auto&& output : outputs) { | |||
| inputs_and_outputs.push_back(output); | |||
| *it++ = output; | |||
| } | |||
| auto precomp = imperative::apply(backward_graph->precomp, inputs_and_outputs); | |||
| closure.reserve(precomp.size() + count); | |||
| @@ -89,7 +90,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure( | |||
| } | |||
| } | |||
| void BackwardGraphWithClosure::operator()( | |||
| std::vector<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) { | |||
| ValueRefList grads, std::function<void(size_t, ValueRef)> receiver) { | |||
| ValueRef args[closure.size() + grads.size()]; | |||
| size_t nargs = 0; | |||
| for (auto&& value : closure) { | |||
| @@ -120,7 +121,7 @@ void BackwardGraphWithClosure::operator()( | |||
| } | |||
| void CustomBackward::operator()( | |||
| std::vector<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) { | |||
| ValueRefList grads, std::function<void(size_t, ValueRef)> receiver) { | |||
| size_t nargs = grads.size(); | |||
| ValueRef args[nargs]; | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| @@ -201,9 +202,10 @@ void GradKey::backward() { | |||
| mgb_throw(AssertionError, "invalid backward"); | |||
| } else { | |||
| mgb_assert(grad_fn->m_slots.size() > 0); | |||
| std::vector<ValueRef> grads; | |||
| ValueRefList grads(grad_fn->m_slots.size()); | |||
| auto iter = grads.begin(); | |||
| for (auto&& slot : grad_fn->m_slots) { | |||
| grads.push_back(slot.m_grad); | |||
| *iter++ = slot.m_grad; | |||
| } | |||
| backward(grads, grad_receiver); | |||
| } | |||
| @@ -254,21 +256,28 @@ void GradKey::freeze() { | |||
| m_frozen = true; | |||
| } | |||
| std::vector<ValueRef> GradTransformation::apply_transformation( | |||
| ValueRefList GradTransformation::apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) { | |||
| auto unwrap_inputs = [this](Span<ValueRef> inputs) -> SmallVector<ValueRef> { | |||
| SmallVector<ValueRef> unwrapped_inputs; | |||
| for (auto&& input : inputs) { | |||
| if (auto grad_value = as_grad_value(input)) { | |||
| unwrapped_inputs.push_back(grad_value->m_value); | |||
| auto fallback = [&] { | |||
| ValueRefList unwrapped_inputs(inputs.size()); | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| if (auto grad_value = as_grad_value(inputs[i])) { | |||
| unwrapped_inputs[i] = grad_value->m_value; | |||
| } else { | |||
| unwrapped_inputs.push_back(input); | |||
| unwrapped_inputs[i] = inputs[i]; | |||
| } | |||
| } | |||
| return unwrapped_inputs; | |||
| return imperative::apply(op, unwrapped_inputs); | |||
| }; | |||
| if (auto* get_attr = op.as<GetAttr>()) { | |||
| if (auto grad_value = as_grad_value(inputs.item())) { | |||
| return imperative::apply(op, grad_value->m_value); | |||
| } else { | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| } | |||
| if (m_suppressed) { | |||
| return imperative::apply(op, unwrap_inputs(inputs)); | |||
| return fallback(); | |||
| } | |||
| if (auto* op_val = op.as<ApplyOp>()) { | |||
| size_t nr_require_grad = 0; | |||
| @@ -284,20 +293,21 @@ std::vector<ValueRef> GradTransformation::apply_transformation( | |||
| if (nr_require_grad == 0) { | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| SmallVector<ValueRef> captured_inputs; | |||
| SmallVector<bool> inputs_require_grad; | |||
| ValueRefList captured_inputs(inputs.size()); | |||
| SmallVector<bool> inputs_require_grad(inputs.size()); | |||
| // capture the value so that the trace can treat the input as the same value | |||
| auto capture_value = [](ValueRef value) { | |||
| // TODO: fastpath copy shouldn't be an OpDef | |||
| return imperative::apply(ApplyOp(*FastpathCopy::make()), {&value, 1})[0]; | |||
| }; | |||
| for (auto& input : inputs) { | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| auto& input = inputs[i]; | |||
| if (auto grad_value = as_grad_value(input)) { | |||
| captured_inputs.push_back(capture_value(grad_value->m_value)); | |||
| inputs_require_grad.push_back(true); | |||
| captured_inputs[i] = capture_value(grad_value->m_value); | |||
| inputs_require_grad[i] = true; | |||
| } else { | |||
| captured_inputs.push_back(capture_value(input)); | |||
| inputs_require_grad.push_back(false); | |||
| captured_inputs[i] = capture_value(input); | |||
| inputs_require_grad[i] = false; | |||
| } | |||
| } | |||
| decltype(std::declval<GradFn>().m_backward) backward_storage; | |||
| @@ -373,9 +383,11 @@ std::vector<ValueRef> GradTransformation::apply_transformation( | |||
| mgb_assert(!grad_fn->m_slots.empty()); | |||
| m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()}); | |||
| return outputs; | |||
| } else if (op.is<CreateTensor>()) { | |||
| return imperative::apply(op, inputs); | |||
| } else if (auto* attach_grad = op.as<AttachGrad>()) { | |||
| if (!has_key(attach_grad->key())) { | |||
| return imperative::apply(op, unwrap_inputs(inputs)); | |||
| return fallback(); | |||
| } | |||
| auto tensor = inputs[0]; | |||
| GenericFunction callback = (GenericFunction&)inputs[1].cast<FunctionValue>(); | |||
| @@ -386,7 +398,7 @@ std::vector<ValueRef> GradTransformation::apply_transformation( | |||
| return {record_grad(output)}; | |||
| } else if (auto* grad_backward = op.as<GradBackward>()) { | |||
| if (!has_key(grad_backward->key())) { | |||
| return imperative::apply(op, unwrap_inputs(inputs)); | |||
| return fallback(); | |||
| } | |||
| size_t nr_grads = inputs.size() / 2; | |||
| mgb_assert(nr_grads * 2 == inputs.size()); | |||
| @@ -416,7 +428,7 @@ std::vector<ValueRef> GradTransformation::apply_transformation( | |||
| backward.m_output_attrs = | |||
| SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true}); | |||
| backward.m_backward = set_grad->grad_fn(); | |||
| std::vector<ValueRef> outputs; | |||
| ValueRefList outputs(nr_outputs); | |||
| grad_fn->m_key = m_key; | |||
| grad_fn->m_slots.resize(nr_outputs); | |||
| grad_fn->m_dests.reserve(nr_inputs); | |||
| @@ -439,13 +451,13 @@ std::vector<ValueRef> GradTransformation::apply_transformation( | |||
| } else { | |||
| grad_value = GradValue::make(output, m_key, GradSlotPtr(grad_fn, i)); | |||
| } | |||
| outputs.push_back(record_grad(grad_value)); | |||
| outputs[i] = record_grad(grad_value); | |||
| } | |||
| m_key->m_tape.push_back({grad_fn, nullptr}); | |||
| return outputs; | |||
| } else if (auto* gbc = op.as<GetBackwardColsure>()) { | |||
| if (gbc->key() != m_key) { | |||
| return imperative::apply(op, unwrap_inputs(inputs)); | |||
| return fallback(); | |||
| } | |||
| return {FunctionValue::make(make_backward_closure(inputs))}; | |||
| } else if (op.is<DetachGrad>()) { | |||
| @@ -471,21 +483,8 @@ std::vector<ValueRef> GradTransformation::apply_transformation( | |||
| } else { | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| } else if (op.is<CreateTensor>()) { | |||
| return imperative::apply(op, inputs); | |||
| } else { | |||
| SmallVector<ValueRef> unwrapped_inputs; | |||
| for (auto&& input : inputs) { | |||
| if (auto grad_value = as_grad_value(input)) { | |||
| unwrapped_inputs.push_back(grad_value->m_value); | |||
| } else { | |||
| unwrapped_inputs.push_back(input); | |||
| } | |||
| } | |||
| auto outputs = imperative::apply( | |||
| op, {unwrapped_inputs.data(), unwrapped_inputs.size()}); | |||
| mgb_assert(op.kind() == Operator::GetAttrLike || outputs.empty()); | |||
| return outputs; | |||
| return fallback(); | |||
| } | |||
| } | |||
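The refactor above folds the repeated unwrap-and-retry code into a single `fallback` lambda: grad-wrapped inputs are unwrapped element-wise and the operator is re-applied on the plain values. A self-contained sketch of that pattern, with `std::variant` standing in for the grad wrapper (all names hypothetical):

```cpp
#include <cassert>
#include <variant>
#include <vector>

using Plain = int;
struct Wrapped { Plain value; };           // plays the role of a GradValue
using Value = std::variant<Plain, Wrapped>;

std::vector<Plain> apply_plain(const std::vector<Plain>& xs) { return xs; }

// Unwrap wrapped inputs element-wise, then re-dispatch on the plain values.
std::vector<Plain> fallback(const std::vector<Value>& inputs) {
    std::vector<Plain> unwrapped(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (auto* w = std::get_if<Wrapped>(&inputs[i])) {
            unwrapped[i] = w->value;       // strip the wrapper
        } else {
            unwrapped[i] = std::get<Plain>(inputs[i]);
        }
    }
    return apply_plain(unwrapped);
}

int main() {
    std::vector<Value> inputs{Value{1}, Value{Wrapped{2}}};
    std::vector<Plain> expected{1, 2};
    assert(fallback(inputs) == expected);
}
```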
| @@ -500,8 +499,7 @@ GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) { | |||
| y_slots.emplace_back(); | |||
| } | |||
| } | |||
| GenericFunction closure = [grad_key, | |||
| y_slots](Span<ValueRef> dys) -> std::vector<ValueRef> { | |||
| GenericFunction closure = [grad_key, y_slots](Span<ValueRef> dys) -> ValueRefList { | |||
| size_t nr_grads = y_slots.size(); | |||
| mgb_assert(dys.size() == nr_grads); | |||
| for (size_t i = 0; i < nr_grads; ++i) { | |||
| @@ -21,7 +21,7 @@ | |||
| namespace mgb { | |||
| namespace imperative { | |||
| std::vector<ValueRef> LazyEvalTransformation::apply_transformation( | |||
| ValueRefList LazyEvalTransformation::apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) { | |||
| if (auto* op_val = op.as<ApplyOp>()) { | |||
| static std::unordered_set<Typeinfo*> mm_io_ops = { | |||
| @@ -59,9 +59,9 @@ std::vector<ValueRef> LazyEvalTransformation::apply_transformation( | |||
| mgb_assert(!output_nodes.empty()); | |||
| m_io_link = SymbolVar(output_nodes[0]); | |||
| } | |||
| std::vector<ValueRef> outputs; | |||
| for (auto&& output_node : output_nodes) { | |||
| outputs.push_back(record_var(output_node)); | |||
| ValueRefList outputs(output_nodes.size()); | |||
| for (size_t i = 0; i < output_nodes.size(); ++i) { | |||
| outputs[i] = record_var(output_nodes[i]); | |||
| } | |||
| return outputs; | |||
| } else if (auto* create_tensor = op.as<CreateTensor>()) { | |||
| @@ -19,26 +19,8 @@ namespace imperative { | |||
| namespace { | |||
| using ScalarRule = std::function<std::vector<ValueRef>(const OpDef&, Span<ValueRef>)>; | |||
| static std::unordered_map< | |||
| Typeinfo*, std::function<std::vector<ValueRef>(const OpDef&, Span<ValueRef>)>> | |||
| scalar_rules; | |||
| ValueRef unwrap_input(ValueRef input) { | |||
| if (auto scalar_input = input.as_ref<ScalarValue>()) { | |||
| return scalar_input->value(); | |||
| } else { | |||
| return input; | |||
| } | |||
| } | |||
| std::vector<ValueRef> unwrap_inputs(Span<ValueRef> inputs) { | |||
| std::vector<ValueRef> unwrapped_inputs; | |||
| for (auto&& input : inputs) { | |||
| unwrapped_inputs.push_back(unwrap_input(input)); | |||
| } | |||
| return unwrapped_inputs; | |||
| } | |||
| using ScalarRule = ValueRefList (*)(const OpDef&, Span<ValueRef>, Span<bool>); | |||
| static std::unordered_map<Typeinfo*, ScalarRule> scalar_rules; | |||
| ValueRef make_scalar_shape(CompNode device) { | |||
| HostTensorND scalar_shape(device, {1}, dtype::Int32()); | |||
| @@ -49,9 +31,6 @@ ValueRef make_scalar_shape(CompNode device) { | |||
| } | |||
| bool is_scalar_shape(ValueRef shape) { | |||
| if (shape.is<ScalarValue>()) { | |||
| return false; | |||
| } | |||
| // may have performance issue | |||
| auto shape_of_shape = shape.shape(); | |||
| if (!shape_of_shape) { | |||
| @@ -61,74 +40,65 @@ bool is_scalar_shape(ValueRef shape) { | |||
| return *shape_of_shape == ValueShape{0}; | |||
| } | |||
| template <typename T> | |||
| void register_scalar_rule(std::vector<ValueRef> (*rule)(const T&, Span<ValueRef>)) { | |||
| scalar_rules[T::typeinfo()] = [rule](const OpDef& def, Span<ValueRef> inputs) { | |||
| return (*rule)(def.cast_final_safe<T>(), inputs); | |||
| template <typename T, ValueRefList (*rule)(const T&, Span<ValueRef>, Span<bool>)> | |||
| void register_scalar_rule() { | |||
| scalar_rules[T::typeinfo()] = [](const OpDef& def, Span<ValueRef> inputs, | |||
| Span<bool> inputs_mask) { | |||
| return (*rule)(def.cast_final_safe<T>(), inputs, inputs_mask); | |||
| }; | |||
| } | |||
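The new `register_scalar_rule<T, rule>()` takes the rule as a non-type template parameter, so each registry entry can be a plain function pointer rather than a capturing `std::function`. A rough stand-alone sketch of that registration trick, using `std::type_index` in place of MegEngine's `Typeinfo*` (all names hypothetical):

```cpp
#include <cassert>
#include <typeindex>
#include <unordered_map>
#include <vector>

struct OpBase { virtual ~OpBase() = default; };
using Rule = std::vector<int> (*)(const OpBase&, const std::vector<int>&);

static std::unordered_map<std::type_index, Rule> rules;

// The rule is a compile-time constant, so the captureless lambda below can
// call it directly and still decay to a plain function pointer.
template <typename T, std::vector<int> (*rule)(const T&, const std::vector<int>&)>
void register_rule() {
    rules[typeid(T)] = [](const OpBase& op, const std::vector<int>& in) {
        return rule(static_cast<const T&>(op), in);  // downcast, then dispatch
    };
}

struct AddOp : OpBase {};
std::vector<int> add_rule(const AddOp&, const std::vector<int>& in) { return in; }

int main() {
    register_rule<AddOp, add_rule>();
    AddOp op;
    assert(rules.at(typeid(op))(op, {1, 2}).size() == 2);
}
```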
| std::vector<ValueRef> elemwise_rule(const Elemwise& elem, Span<ValueRef> inputs) { | |||
| template <typename TOpDef, size_t nr_inputs> | |||
| ValueRefList elemwise_rule( | |||
| const TOpDef& op_def, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| if constexpr (nr_inputs != 0) { | |||
| mgb_assert(inputs.size() == nr_inputs, "inputs size mismatch"); | |||
| } | |||
| bool all_scalar = true; | |||
| for (auto&& input : inputs) { | |||
| if (!input.is<ScalarValue>()) { | |||
| for (auto&& input_mask : inputs_mask) { | |||
| if (!input_mask) { | |||
| all_scalar = false; | |||
| break; | |||
| } | |||
| } | |||
| auto output = imperative::apply(elem, unwrap_inputs(inputs))[0]; | |||
| auto outputs = imperative::apply(op_def, inputs); | |||
| if (all_scalar) { | |||
| return {ScalarValue::make(output)}; | |||
| } else { | |||
| return {output}; | |||
| outputs[0] = ScalarValue::make(outputs[0]); | |||
| } | |||
| return outputs; | |||
| } | |||
| std::vector<ValueRef> remove_axis_rule( | |||
| const RemoveAxis& remove_axis, Span<ValueRef> inputs) { | |||
| mgb_assert(inputs.size() == 1); | |||
| mgb_assert(!inputs[0].is<ScalarValue>()); | |||
| auto output = imperative::apply(remove_axis, inputs)[0]; | |||
| bool is_scalar = inputs[0].shape()->ndim == remove_axis.axis.size(); | |||
| ValueRefList remove_axis_rule( | |||
| const RemoveAxis& remove_axis, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| mgb_assert(!inputs_mask.item()); | |||
| bool is_scalar = inputs.item().shape()->ndim == remove_axis.axis.size(); | |||
| if (is_scalar && remove_axis.axis.size() == 1) { | |||
| return {ScalarValue::make(inputs.item())}; | |||
| } | |||
| auto outputs = imperative::apply(remove_axis, inputs); | |||
| if (is_scalar) { | |||
| return {ScalarValue::make(output)}; | |||
| } else { | |||
| return {output}; | |||
| outputs[0] = ScalarValue::make(outputs[0]); | |||
| } | |||
| return outputs; | |||
| } | |||
| std::vector<ValueRef> reduce_rule(const Reduce& reduce, Span<ValueRef> inputs) { | |||
| ValueRefList reduce_rule( | |||
| const Reduce& reduce, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| if (inputs.size() == 1) { | |||
| return imperative::apply(reduce, unwrap_inputs(inputs)); | |||
| return imperative::apply(reduce, inputs); | |||
| } | |||
| mgb_assert(inputs.size() == 2); | |||
| bool is_scalar = is_scalar_shape(inputs[1]); | |||
| if (is_scalar) { | |||
| auto unwrapped_input = unwrap_input(inputs[0]); | |||
| CompNode device = *unwrapped_input.device(); | |||
| return {ScalarValue::make(imperative::apply( | |||
| reduce, unwrapped_input, make_scalar_shape(device))[0])}; | |||
| } | |||
| auto output = imperative::apply(reduce, unwrap_inputs(inputs))[0]; | |||
| if (is_scalar) { | |||
| return {ScalarValue::make(output)}; | |||
| } else { | |||
| return {output}; | |||
| } | |||
| } | |||
| std::vector<ValueRef> typecvt_rule(const TypeCvt& typecvt, Span<ValueRef> inputs) { | |||
| mgb_assert(inputs.size() == 1); | |||
| if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) { | |||
| CompNode device = *inputs[0].device(); | |||
| return {ScalarValue::make( | |||
| imperative::apply(typecvt, scalar_input->value())[0])}; | |||
| } else { | |||
| return imperative::apply(typecvt, inputs); | |||
| imperative::apply(reduce, inputs[0], make_scalar_shape(device))[0])}; | |||
| } | |||
| return imperative::apply(reduce, inputs); | |||
| } | |||
| std::vector<ValueRef> collective_comm_rule( | |||
| const CollectiveComm& collective_comm, Span<ValueRef> inputs) { | |||
| ValueRefList collective_comm_rule( | |||
| const CollectiveComm& collective_comm, Span<ValueRef> inputs, | |||
| Span<bool> inputs_mask) { | |||
| mgb_assert(inputs.size() == 1); | |||
| static std::unordered_set<CollectiveComm::Mode> modes = { | |||
| CollectiveComm::Mode::ALL_REDUCE_MAX, CollectiveComm::Mode::ALL_REDUCE_MIN, | |||
| @@ -138,17 +108,17 @@ std::vector<ValueRef> collective_comm_rule( | |||
| if (modes.count(collective_comm.mode) == 0) { | |||
| return imperative::apply(collective_comm, inputs); | |||
| } | |||
| if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) { | |||
| return {ScalarValue::make( | |||
| imperative::apply(collective_comm, scalar_input->value())[0])}; | |||
| if (inputs_mask.item()) { | |||
| return {ScalarValue::make(imperative::apply(collective_comm, inputs[0])[0])}; | |||
| } else { | |||
| return imperative::apply(collective_comm, inputs); | |||
| } | |||
| } | |||
| std::vector<ValueRef> param_pack_split_rule( | |||
| const ParamPackSplit& param_pack_split, Span<ValueRef> inputs) { | |||
| auto outputs = imperative::apply(param_pack_split, unwrap_inputs(inputs)); | |||
| ValueRefList param_pack_split_rule( | |||
| const ParamPackSplit& param_pack_split, Span<ValueRef> inputs, | |||
| Span<bool> inputs_mask) { | |||
| auto outputs = imperative::apply(param_pack_split, inputs); | |||
| size_t nr_outputs = outputs.size(); | |||
| mgb_assert(nr_outputs == param_pack_split.shapes.size()); | |||
| for (size_t i = 0; i < nr_outputs; ++i) { | |||
| @@ -159,29 +129,28 @@ std::vector<ValueRef> param_pack_split_rule( | |||
| return outputs; | |||
| } | |||
| std::vector<ValueRef> dot_rule(const Dot& dot, Span<ValueRef> inputs) { | |||
| return {ScalarValue::make(imperative::apply(dot, unwrap_inputs(inputs))[0])}; | |||
| ValueRefList dot_rule(const Dot& dot, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| return {ScalarValue::make(imperative::apply(dot, inputs)[0])}; | |||
| } | |||
| std::vector<ValueRef> add_axis_rule(const AddAxis& add_axis, Span<ValueRef> inputs) { | |||
| ValueRefList add_axis_rule( | |||
| const AddAxis& add_axis, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| mgb_assert(inputs.size() == 1); | |||
| if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) { | |||
| if (inputs_mask.item()) { | |||
| mgb_assert(add_axis.axis[0] == 0); | |||
| if (add_axis.axis.size() == 1) { | |||
| return {scalar_input->value()}; | |||
| return {inputs[0]}; | |||
| } else { | |||
| std::vector<int32_t> axis(add_axis.axis.begin() + 1, add_axis.axis.end()); | |||
| return imperative::apply( | |||
| ApplyOp(*AddAxis::make(axis, add_axis.scope())), | |||
| scalar_input->value()); | |||
| return imperative::apply(*AddAxis::make(axis, add_axis.scope()), inputs[0]); | |||
| } | |||
| } else { | |||
| return imperative::apply(add_axis, inputs); | |||
| } | |||
| } | |||
| std::vector<ValueRef> remote_recv_rule( | |||
| const RemoteRecv& remote_recv, Span<ValueRef> inputs) { | |||
| ValueRefList remote_recv_rule( | |||
| const RemoteRecv& remote_recv, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| if (remote_recv.shape.empty()) { | |||
| std::vector<int32_t> shape = {1}; | |||
| auto remote_recv_no_scalar = RemoteRecv::make( | |||
| @@ -189,32 +158,32 @@ std::vector<ValueRef> remote_recv_rule( | |||
| remote_recv.rank_from, remote_recv.cn, shape, remote_recv.dtype, | |||
| remote_recv.backend); | |||
| remote_recv_no_scalar->set_scope(remote_recv.scope()); | |||
| return imperative::apply( | |||
| ApplyOp(*remote_recv_no_scalar), unwrap_inputs(inputs)); | |||
| return imperative::apply(ApplyOp(*remote_recv_no_scalar), inputs); | |||
| } else { | |||
| return imperative::apply(remote_recv, unwrap_inputs(inputs)); | |||
| return imperative::apply(remote_recv, inputs); | |||
| } | |||
| } | |||
| std::vector<ValueRef> check_no_finite_rule( | |||
| const CheckNonFinite& check_no_finite, Span<ValueRef> inputs) { | |||
| auto outputs = imperative::apply(check_no_finite, unwrap_inputs(inputs)); | |||
| ValueRefList check_no_finite_rule( | |||
| const CheckNonFinite& check_no_finite, Span<ValueRef> inputs, | |||
| Span<bool> inputs_mask) { | |||
| auto outputs = imperative::apply(check_no_finite, inputs); | |||
| mgb_assert(outputs.size() == inputs.size() + 1, "output size mismatch"); | |||
| outputs.back() = ScalarValue::make(outputs.back()); | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| if (inputs[i].is<ScalarValue>()) { | |||
| if (inputs_mask[i]) { | |||
| outputs[i] = ScalarValue::make(outputs[i]); | |||
| } | |||
| } | |||
| return outputs; | |||
| } | |||
| std::vector<ValueRef> subtensor_rule( | |||
| const Subtensor& subtensor, Span<ValueRef> inputs) { | |||
| ValueRefList subtensor_rule( | |||
| const Subtensor& subtensor, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| mgb_assert(inputs.size() >= 1); | |||
| auto input = inputs[0]; | |||
| bool is_scalar; | |||
| mgb_assert(!input.is<ScalarValue>(), "subtensor shouldn't have scalar input"); | |||
| mgb_assert(!inputs_mask[0], "subtensor shouldn't have scalar input"); | |||
| if (auto shape = input.shape()) { | |||
| size_t ndim = input.shape()->ndim; | |||
| for (auto&& [axis, begin, end, step, idx] : subtensor.items) { | |||
| @@ -226,25 +195,25 @@ std::vector<ValueRef> subtensor_rule( | |||
| } else { | |||
| is_scalar = false; | |||
| } | |||
| auto output = imperative::apply(subtensor, unwrap_inputs(inputs))[0]; | |||
| auto outputs = imperative::apply(subtensor, inputs); | |||
| if (is_scalar) { | |||
| return {ScalarValue::make(output)}; | |||
| } else { | |||
| return {output}; | |||
| outputs[0] = ScalarValue::make(outputs[0]); | |||
| } | |||
| return outputs; | |||
| } | |||
| std::vector<ValueRef> get_var_shape_rule( | |||
| const GetVarShape& get_var_shape, Span<ValueRef> inputs) { | |||
| ValueRefList get_var_shape_rule( | |||
| const GetVarShape& get_var_shape, Span<ValueRef> inputs, | |||
| Span<bool> inputs_mask) { | |||
| bool all_scalar = true; | |||
| mgb_assert(inputs.size() >= 1); | |||
| for (auto&& input : inputs) { | |||
| if (!input.is<ScalarValue>()) { | |||
| for (auto&& input_mask : inputs_mask) { | |||
| if (!input_mask) { | |||
| all_scalar = false; | |||
| } | |||
| } | |||
| if (all_scalar) { | |||
| auto device = inputs[0].cast<ScalarValue>().value().device(); | |||
| auto device = inputs[0].device(); | |||
| auto storage = HostStorage::make(*device); | |||
| // storage->ensure_size(1); | |||
| return imperative::apply( | |||
| @@ -252,88 +221,49 @@ std::vector<ValueRef> get_var_shape_rule( | |||
| CreateTensor::Const, *device, dtype::Int32(), ValueShape{0}), | |||
| storage); | |||
| } else { | |||
| return imperative::apply(get_var_shape, unwrap_inputs(inputs)); | |||
| } | |||
| } | |||
| std::vector<ValueRef> fastpath_copy_rule( | |||
| const FastpathCopy& fastpath_copy, Span<ValueRef> inputs) { | |||
| mgb_assert(inputs.size() == 1); | |||
| bool is_scalar = inputs[0].is<ScalarValue>(); | |||
| auto output = imperative::apply(fastpath_copy, unwrap_inputs(inputs))[0]; | |||
| if (is_scalar) { | |||
| return {ScalarValue::make(output)}; | |||
| } else { | |||
| return {output}; | |||
| return imperative::apply(get_var_shape, inputs); | |||
| } | |||
| } | |||
| std::vector<ValueRef> reshape_rule(const Reshape& reshape, Span<ValueRef> inputs) { | |||
| ValueRefList reshape_rule( | |||
| const Reshape& reshape, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| mgb_assert(inputs.size() == 2); | |||
| bool is_scalar = is_scalar_shape(inputs[1]); | |||
| auto unwrapped_input = inputs[0].is<ScalarValue>() | |||
| ? inputs[0].cast<ScalarValue>().value() | |||
| : inputs[0]; | |||
| if (is_scalar) { | |||
| return {ScalarValue::make(imperative::apply( | |||
| reshape, unwrapped_input, | |||
| make_scalar_shape(*unwrapped_input.device()))[0])}; | |||
| reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])}; | |||
| } else { | |||
| return imperative::apply(reshape, unwrap_inputs(inputs)); | |||
| return imperative::apply(reshape, inputs); | |||
| } | |||
| } | |||
| std::vector<ValueRef> broadcast_rule( | |||
| const Broadcast& broadcast, Span<ValueRef> inputs) { | |||
| ValueRefList broadcast_rule( | |||
| const Broadcast& broadcast, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| mgb_assert(inputs.size() == 2); | |||
| bool is_scalar = is_scalar_shape(inputs[1]); | |||
| auto unwrapped_input = inputs[0].is<ScalarValue>() | |||
| ? inputs[0].cast<ScalarValue>().value() | |||
| : inputs[0]; | |||
| if (is_scalar) { | |||
| return {ScalarValue::make(imperative::apply( | |||
| broadcast, unwrapped_input, | |||
| make_scalar_shape(*unwrapped_input.device()))[0])}; | |||
| } else { | |||
| return imperative::apply(broadcast, unwrap_inputs(inputs)); | |||
| } | |||
| } | |||
| std::vector<ValueRef> copy_rule(const Copy& copy, Span<ValueRef> inputs) { | |||
| mgb_assert(inputs.size() == 1); | |||
| bool is_scalar = inputs[0].is<ScalarValue>(); | |||
| if (is_scalar) { | |||
| return {ScalarValue::make(imperative::apply(copy, unwrap_inputs(inputs))[0])}; | |||
| } else { | |||
| return imperative::apply(copy, unwrap_inputs(inputs)); | |||
| } | |||
| } | |||
| std::vector<ValueRef> inplace_add_rule( | |||
| const InplaceAdd& inplace_add, Span<ValueRef> inputs) { | |||
| mgb_assert(inputs.size() == 4); | |||
| bool is_scalar = inputs[0].is<ScalarValue>(); | |||
| if (is_scalar) { | |||
| return {ScalarValue::make( | |||
| imperative::apply(inplace_add, unwrap_inputs(inputs))[0])}; | |||
| broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])}; | |||
| } else { | |||
| return imperative::apply(inplace_add, unwrap_inputs(inputs)); | |||
| return imperative::apply(broadcast, inputs); | |||
| } | |||
| } | |||
| template <typename T> | |||
| std::vector<ValueRef> subgraph_op_rule(const T& op, Span<ValueRef> inputs) { | |||
| ValueRefList subgraph_op_rule( | |||
| const T& op, Span<ValueRef> inputs, Span<bool> inputs_mask, | |||
| const Type<ScalarValue>& scalar_type) { | |||
| // TODO: add flag instead of assume | |||
| bool all_scalar = true; | |||
| for (auto&& input : inputs) { | |||
| if (!input.is<ScalarValue>()) { | |||
| for (auto&& input_mask : inputs_mask) { | |||
| if (!input_mask) { | |||
| all_scalar = false; | |||
| } | |||
| } | |||
| auto outputs = imperative::apply(op, unwrap_inputs(inputs)); | |||
| auto outputs = imperative::apply(op, inputs); | |||
| if (all_scalar) { | |||
| for (auto& output : outputs) { | |||
| output = ScalarValue::make(output); | |||
| output = scalar_type.make(output); | |||
| } | |||
| } | |||
| return outputs; | |||
| @@ -341,67 +271,54 @@ std::vector<ValueRef> subgraph_op_rule(const T& op, Span<ValueRef> inputs) { | |||
| struct ScalarRuleRegistry { | |||
| ScalarRuleRegistry() { | |||
| register_scalar_rule(elemwise_rule); | |||
| register_scalar_rule(remove_axis_rule); | |||
| register_scalar_rule(reduce_rule); | |||
| register_scalar_rule(typecvt_rule); | |||
| register_scalar_rule(collective_comm_rule); | |||
| register_scalar_rule(param_pack_split_rule); | |||
| register_scalar_rule(dot_rule); | |||
| register_scalar_rule(add_axis_rule); | |||
| register_scalar_rule(remote_recv_rule); | |||
| register_scalar_rule(check_no_finite_rule); | |||
| register_scalar_rule(subtensor_rule); | |||
| register_scalar_rule(get_var_shape_rule); | |||
| register_scalar_rule(fastpath_copy_rule); | |||
| register_scalar_rule(reshape_rule); | |||
| register_scalar_rule(broadcast_rule); | |||
| register_scalar_rule(copy_rule); | |||
| register_scalar_rule(inplace_add_rule); | |||
| register_scalar_rule(subgraph_op_rule<SubgraphOp>); | |||
| register_scalar_rule(subgraph_op_rule<CompiledOp>); | |||
| register_scalar_rule<Elemwise, elemwise_rule<Elemwise, 0>>(); | |||
| register_scalar_rule<RemoveAxis, remove_axis_rule>(); | |||
| register_scalar_rule<Reduce, reduce_rule>(); | |||
| register_scalar_rule<TypeCvt, elemwise_rule<TypeCvt, 1>>(); | |||
| register_scalar_rule<CollectiveComm, collective_comm_rule>(); | |||
| register_scalar_rule<ParamPackSplit, param_pack_split_rule>(); | |||
| register_scalar_rule<Dot, dot_rule>(); | |||
| register_scalar_rule<AddAxis, add_axis_rule>(); | |||
| register_scalar_rule<RemoteRecv, remote_recv_rule>(); | |||
| register_scalar_rule<CheckNonFinite, check_no_finite_rule>(); | |||
| register_scalar_rule<Subtensor, subtensor_rule>(); | |||
| register_scalar_rule<GetVarShape, get_var_shape_rule>(); | |||
| register_scalar_rule<FastpathCopy, elemwise_rule<FastpathCopy, 1>>(); | |||
| register_scalar_rule<Reshape, reshape_rule>(); | |||
| register_scalar_rule<Broadcast, broadcast_rule>(); | |||
| register_scalar_rule<Copy, elemwise_rule<Copy, 1>>(); | |||
| register_scalar_rule<InplaceAdd, elemwise_rule<InplaceAdd, 4>>(); | |||
| register_scalar_rule<SubgraphOp, subgraph_op_rule<SubgraphOp>>(); | |||
| register_scalar_rule<CompiledOp, subgraph_op_rule<CompiledOp>>(); | |||
| } | |||
| } _; | |||
| } // namespace | |||
| std::vector<ValueRef> ScalarTransformation::apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) { | |||
| if (auto apply_op = op.as<ApplyOp>()) { | |||
| auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo()); | |||
| if (iter != scalar_rules.end()) { | |||
| return iter->second(apply_op->op(), inputs); | |||
| } else { | |||
| // TODO: repeat op | |||
| return imperative::apply(op, unwrap_inputs(inputs)); | |||
| } | |||
| } else if (auto* create_tensor = op.as<CreateTensor>()) { | |||
| if (create_tensor->shape().is_scalar()) { | |||
| ValueShape scalar_shape = {1}; | |||
| CreateTensor scalar_op( | |||
| create_tensor->kind(), create_tensor->device(), | |||
| create_tensor->dtype(), scalar_shape); | |||
| return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])}; | |||
| } else { | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| } else if (auto* get_attr = op.as<GetAttr>()) { | |||
| bool is_scalar = inputs.as_array<1>()[0].is<ScalarValue>(); | |||
| auto output = imperative::apply(op, unwrap_inputs(inputs))[0]; | |||
| if (!is_scalar) { | |||
| return {output}; | |||
| ValueRefList ScalarTransformation::apply_get_attr( | |||
| const GetAttr& get_attr, Span<ValueRef> inputs) { | |||
| auto&& input = inputs.item(); | |||
| bool is_scalar = input.is<ScalarValue>(); | |||
| if (!is_scalar) { | |||
| return imperative::apply(get_attr, input); | |||
| } | |||
| auto unwrapped_input = input.cast<ScalarValue>().value(); | |||
| if (get_attr.attr() == GetAttr::Shape) { | |||
| if (!m_empty_shape) { | |||
| m_empty_shape = ShapeValue::make(); | |||
| } | |||
| switch (get_attr->attr()) { | |||
| case GetAttr::Shape: { | |||
| // Scalar Shape | |||
| return {ShapeValue::make()}; | |||
| } | |||
| return {m_empty_shape}; | |||
| } else { | |||
| auto outputs = imperative::apply(get_attr, unwrapped_input); | |||
| auto& output = outputs[0]; | |||
| switch (get_attr.attr()) { | |||
| case GetAttr::Value: { | |||
| auto& hv = output.cast<HostValue>(); | |||
| mgb_assert( | |||
| hv.shape() == ValueShape({1}), | |||
| "underlying value should has shape {1}, got %s", | |||
| hv.shape().to_string().c_str()); | |||
| return {HostValue::make(hv.dtype(), ValueShape(), hv.storage())}; | |||
| output = HostValue::make(hv.dtype(), ValueShape(), hv.storage()); | |||
| break; | |||
| } | |||
| case GetAttr::Data: { | |||
| auto& dv = output.cast<DeviceValue>(); | |||
| @@ -409,22 +326,67 @@ std::vector<ValueRef> ScalarTransformation::apply_transformation( | |||
| dv.shape() == ValueShape({1}), | |||
| "underlying value should has shape {1}, got %s", | |||
| dv.shape().to_string().c_str()); | |||
| return {DeviceValue::make(dv.dtype(), ValueShape(), dv.storage())}; | |||
| output = DeviceValue::make(dv.dtype(), ValueShape(), dv.storage()); | |||
| break; | |||
| } | |||
| default: | |||
| return {output}; | |||
| break; | |||
| } | |||
| return outputs; | |||
| } | |||
| } | |||
| ValueRefList ScalarTransformation::apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) { | |||
| if (auto* get_attr = op.as<GetAttr>()) { | |||
| // fastpath for GetAttr | |||
| return apply_get_attr(*get_attr, inputs); | |||
| } | |||
| size_t nr_inputs = inputs.size(); | |||
| ValueRefList unwrapped_inputs(nr_inputs); | |||
| bool inputs_mask[nr_inputs]; | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| if (auto scalar_value = inputs[i].as_ref<ScalarValue>()) { | |||
| unwrapped_inputs[i] = scalar_value->value(); | |||
| inputs_mask[i] = true; | |||
| } else { | |||
| unwrapped_inputs[i] = inputs[i]; | |||
| inputs_mask[i] = false; | |||
| } | |||
| } | |||
| auto fallback = [&] { return imperative::apply(op, unwrapped_inputs); }; | |||
| if (auto apply_op = op.as<ApplyOp>()) { | |||
| auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo()); | |||
| if (iter != scalar_rules.end()) { | |||
| return iter->second( | |||
| apply_op->op(), unwrapped_inputs, {inputs_mask, nr_inputs}); | |||
| } else { | |||
| // TODO: repeat op | |||
| return fallback(); | |||
| } | |||
| } else if (auto* create_tensor = op.as<CreateTensor>()) { | |||
| if (create_tensor->shape().is_scalar()) { | |||
| ValueShape scalar_shape = {1}; | |||
| CreateTensor scalar_op( | |||
| create_tensor->kind(), create_tensor->device(), | |||
| create_tensor->dtype(), scalar_shape); | |||
| return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])}; | |||
| } else { | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| } else if (op.as<IsScalar>()) { | |||
| return {BoolValue::make(inputs.as_array<1>()[0].is<ScalarValue>())}; | |||
| mgb_assert(nr_inputs == 1); | |||
| return {BoolValue::make(inputs_mask[0])}; | |||
| } else if (op.is<Operator::IdentityLike>()) { | |||
| bool is_scalar = inputs.as_array<1>()[0].is<ScalarValue>(); | |||
| mgb_assert(nr_inputs == 1); | |||
| bool is_scalar = inputs_mask[0]; | |||
| auto outputs = fallback(); | |||
| if (is_scalar) { | |||
| return {ScalarValue::make(imperative::apply(op, unwrap_inputs(inputs))[0])}; | |||
| } else { | |||
| return imperative::apply(op, inputs); | |||
| outputs[0] = ScalarValue::make(outputs[0]); | |||
| } | |||
| return outputs; | |||
| } else { | |||
| return imperative::apply(op, unwrap_inputs(inputs)); | |||
| return fallback(); | |||
| } | |||
| }; | |||
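The dispatch above pre-unwraps every ScalarValue input once and hands each rule a boolean mask, so the rules themselves no longer probe value types. A self-contained toy sketch of that unwrap-and-mask pattern; Tensor, Scalar and identity_like_rule below are illustrative stand-ins, not MegEngine types:

    #include <cassert>
    #include <variant>
    #include <vector>

    // Toy stand-ins: a "tensor" is just an int payload; a scalar is a wrapped tensor.
    struct Tensor { int payload; };
    struct Scalar { Tensor inner; };
    using Value = std::variant<Tensor, Scalar>;

    // Unwrap every input once, remembering which ones were scalars.
    std::vector<Tensor> unwrap(const std::vector<Value>& inputs, std::vector<bool>& mask) {
        std::vector<Tensor> out;
        mask.resize(inputs.size());
        for (size_t i = 0; i < inputs.size(); ++i) {
            if (auto* s = std::get_if<Scalar>(&inputs[i])) {
                out.push_back(s->inner);
                mask[i] = true;
            } else {
                out.push_back(std::get<Tensor>(inputs[i]));
                mask[i] = false;
            }
        }
        return out;
    }

    // A rule only re-wraps its output when every input was a scalar,
    // mirroring the all_scalar checks in the rules above.
    Value identity_like_rule(const std::vector<Value>& inputs) {
        std::vector<bool> mask;
        auto unwrapped = unwrap(inputs, mask);
        Tensor result = unwrapped.at(0);  // pretend the op was applied here
        bool all_scalar = true;
        for (bool m : mask) all_scalar = all_scalar && m;
        return all_scalar ? Value{Scalar{result}} : Value{result};
    }

    int main() {
        std::vector<Value> inputs{Scalar{Tensor{3}}};
        assert(std::holds_alternative<Scalar>(identity_like_rule(inputs)));
    }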
| @@ -0,0 +1,25 @@ | |||
| /** | |||
| * \file imperative/src/impl/transformations/tangent.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "megbrain/imperative/transformations/tangent.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| ValueRefList TangentTransformation::apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) { | |||
| if (auto* apply_op = op.as<ApplyOp>()) { | |||
| } | |||
| mgb_assert(false); | |||
| } | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
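TangentTransformation is only stubbed out here, but the TangentInfo pair it declares (value plus tangent) is the classic forward-mode layout. A minimal dual-number sketch of forward-mode differentiation, unrelated to any actual MegEngine implementation:

    #include <cassert>
    #include <cmath>

    // Forward-mode AD carries a (value, tangent) pair through every op,
    // the same shape of data as the TangentInfo struct above.
    struct Dual {
        double value;
        double tangent;
    };

    Dual mul(Dual a, Dual b) {
        return {a.value * b.value, a.tangent * b.value + a.value * b.tangent};
    }

    Dual sin(Dual a) {
        return {std::sin(a.value), std::cos(a.value) * a.tangent};
    }

    int main() {
        // Differentiate f(x) = sin(x*x) at x = 2 by seeding tangent = 1.
        Dual x{2.0, 1.0};
        Dual y = sin(mul(x, x));
        double expected = std::cos(4.0) * 4.0;  // chain rule by hand
        assert(std::abs(y.tangent - expected) < 1e-12);
    }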
| @@ -153,7 +153,7 @@ VarNodeArray TraceResult::dump( | |||
| return output_nodes; | |||
| } | |||
| std::vector<ValueRef> TracingTransformation::apply_transformation( | |||
| ValueRefList TracingTransformation::apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) { | |||
| if (auto* op_value = op.as<ApplyOp>()) { | |||
| SmallVector<ValueRef> unwrapped_inputs; | |||
| @@ -180,11 +180,12 @@ std::vector<ValueRef> TracingTransformation::apply_transformation( | |||
| } | |||
| const_cast<OpDef&>(op_value->op()).set_scope(scopes_join); | |||
| auto unwrapped_outputs = imperative::apply(op, unwrapped_inputs); | |||
| std::vector<ValueRef> wrapped_outputs; | |||
| ValueRefList wrapped_outputs(unwrapped_outputs.size()); | |||
| SmallVector<size_t> output_ids; | |||
| for (auto&& output : unwrapped_outputs) { | |||
| for (size_t i = 0; i < unwrapped_outputs.size(); ++i) { | |||
| auto&& output = unwrapped_outputs[i]; | |||
| auto wrapped_output = record_var(output, false, VarKind::Internal); | |||
| wrapped_outputs.push_back(wrapped_output); | |||
| wrapped_outputs[i] = wrapped_output; | |||
| output_ids.push_back(wrapped_output->id()); | |||
| } | |||
| m_seq.push_back({op_value->op().shared_from_this(), input_ids, output_ids}); | |||
| @@ -375,6 +376,11 @@ void CompiledTransformation::compile() { | |||
| return accessor; | |||
| }; | |||
| std::vector<VarAccessor> var_accessors(m_vars.size()); | |||
| auto exc_setter = std::bind( | |||
| &CompiledTransformation::set_exception, this, std::placeholders::_1); | |||
| for (auto&& accessor : var_accessors) { | |||
| accessor.exc_setter = exc_setter; | |||
| } | |||
| for (auto&& item : m_seq) { | |||
| bool require_link = bool(item.op) && mm_io_ops.count(item.op->dyn_typeinfo()); | |||
| VarNodeArray input_vars; | |||
| @@ -509,8 +515,8 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { | |||
| } | |||
| } | |||
| TracedValue::ref_t CompiledTransformation::trace_output(size_t id) { | |||
| auto traced_value = TracedValue::make(id); | |||
| auto CompiledTransformation::trace_output(size_t id) -> TracedValue::ref_t { | |||
| auto traced_value = TracedValue::make(id, &m_vars[id], &m_var_accessors[id]); | |||
| m_weak_values.push_back(traced_value); | |||
| return traced_value; | |||
| } | |||
| @@ -520,64 +526,99 @@ TraceResult::SeqItem& CompiledTransformation::next_instruction() { | |||
| return m_seq[m_pc++]; | |||
| } | |||
| std::vector<ValueRef> CompiledTransformation::apply_transformation( | |||
| ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const { | |||
| if (!m_shape) { | |||
| trace_assert(m_accessor->shape_getter, "shape unreadable"); | |||
| m_shape = ShapeValue::make(ValueShape::from(m_accessor->shape_getter())); | |||
| } | |||
| return m_shape; | |||
| } | |||
| DTypeValue::ref_t CompiledTransformation::TracedInfo::dtype() const { | |||
| if (!m_dtype) { | |||
| m_dtype = DTypeValue::make(m_var->dtype); | |||
| } | |||
| return m_dtype; | |||
| } | |||
| CompNodeValue::ref_t CompiledTransformation::TracedInfo::comp_node() const { | |||
| if (!m_comp_node) { | |||
| m_comp_node = CompNodeValue::make(m_var->device); | |||
| } | |||
| return m_comp_node; | |||
| } | |||
| auto CompiledTransformation::TracedInfo::accessor() const -> const VarAccessor& { | |||
| return *m_accessor; | |||
| } | |||
| ValueRefList CompiledTransformation::apply_op( | |||
| const ApplyOp& apply_op, Span<ValueRef> inputs) { | |||
| auto& item = next_instruction(); | |||
| trace_assert(inputs.size() == item.inputs.size(), "input size mismatch"); | |||
| trace_assert(apply_op.op().is_same(*item.op), "operator mismatch"); | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| trace_input(item.inputs[i], inputs[i]); | |||
| } | |||
| ValueRefList outputs(item.outputs.size()); | |||
| for (size_t i = 0; i < item.outputs.size(); ++i) { | |||
| outputs[i] = trace_output(item.outputs[i]); | |||
| } | |||
| return outputs; | |||
| } | |||
| ValueRefList CompiledTransformation::apply_get_attr( | |||
| const GetAttr& get_attr, Span<ValueRef> inputs) { | |||
| if (auto* traced_value = inputs[0].as<TracedValue>()) { | |||
| ValueRef output; | |||
| auto& var_accessor = traced_value->accessor(); | |||
| switch (get_attr.attr()) { | |||
| case GetAttr::Shape: | |||
| output = traced_value->shape(); | |||
| break; | |||
| case GetAttr::Data: | |||
| trace_assert(var_accessor.data_getter, "data unreadable"); | |||
| output = DeviceValue::make(var_accessor.data_getter()); | |||
| break; | |||
| case GetAttr::Value: | |||
| trace_assert(var_accessor.value_getter, "value unreadable"); | |||
| output = HostValue::make(var_accessor.value_getter()); | |||
| break; | |||
| case GetAttr::DType: | |||
| output = traced_value->dtype(); | |||
| break; | |||
| case GetAttr::Device: | |||
| output = traced_value->comp_node(); | |||
| default: | |||
| break; | |||
| } | |||
| return {output}; | |||
| } else { | |||
| return imperative::apply(get_attr, inputs); | |||
| } | |||
| } | |||
| ValueRefList CompiledTransformation::apply_create_tensor( | |||
| const CreateTensor& create_tensor, Span<ValueRef> inputs) { | |||
| if (create_tensor.kind() == CreateTensor::NoTrace) { | |||
| return imperative::apply(create_tensor, inputs); | |||
| } | |||
| auto& item = next_instruction(); | |||
| trace_assert(item.op == nullptr, "operator mismatch"); | |||
| auto input_id = item.inputs[0]; | |||
| auto output_id = item.outputs[0]; | |||
| auto tensor = imperative::apply(create_tensor, inputs)[0]; | |||
| trace_input(input_id, tensor); | |||
| return {trace_output(output_id)}; | |||
| } | |||
| ValueRefList CompiledTransformation::apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) { | |||
| if (auto* op_value = op.as<ApplyOp>()) { | |||
| auto& item = next_instruction(); | |||
| SmallVector<ValueRef> unwrapped_inputs; | |||
| SmallVector<ValueRef> wrapped_inputs; | |||
| trace_assert(inputs.size() == item.inputs.size(), "input size mismatch"); | |||
| trace_assert(op_value->op().is_same(*item.op), "operator mismatch"); | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| trace_input(item.inputs[i], inputs[i]); | |||
| } | |||
| std::vector<ValueRef> outputs; | |||
| for (auto&& output_id : item.outputs) { | |||
| outputs.push_back(trace_output(output_id)); | |||
| } | |||
| return outputs; | |||
| return apply_op(*op_value, inputs); | |||
| } else if (auto* create_tensor = op.as<CreateTensor>()) { | |||
| if (create_tensor->kind() == CreateTensor::NoTrace) { | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| auto& item = next_instruction(); | |||
| trace_assert(item.op == nullptr, "operator mismatch"); | |||
| auto input_id = item.inputs[0]; | |||
| auto output_id = item.outputs[0]; | |||
| auto tensor = imperative::apply(op, inputs)[0]; | |||
| trace_input(input_id, tensor); | |||
| return {trace_output(output_id)}; | |||
| return apply_create_tensor(*create_tensor, inputs); | |||
| } else if (auto* get_attr = op.as<GetAttr>()) { | |||
| if (auto* traced_value = inputs[0].as<TracedValue>()) { | |||
| ValueRef output; | |||
| auto& var = m_vars[traced_value->id()]; | |||
| auto& var_accessor = m_var_accessors[traced_value->id()]; | |||
| switch (get_attr->attr()) { | |||
| case GetAttr::Shape: | |||
| trace_assert(var_accessor.shape_getter, "shape unreadable"); | |||
| output = ShapeValue::make( | |||
| ValueShape::from(var_accessor.shape_getter())); | |||
| break; | |||
| case GetAttr::Data: | |||
| trace_assert(var_accessor.data_getter, "data unreadable"); | |||
| output = DeviceValue::make(var_accessor.data_getter()); | |||
| break; | |||
| case GetAttr::Value: | |||
| trace_assert(var_accessor.value_getter, "value unreadable"); | |||
| output = HostValue::make(var_accessor.value_getter()); | |||
| break; | |||
| case GetAttr::DType: | |||
| output = DTypeValue::make(var.dtype); | |||
| break; | |||
| case GetAttr::Device: | |||
| output = CompNodeValue::make(var.device); | |||
| default: | |||
| break; | |||
| } | |||
| return {output}; | |||
| } else { | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| return apply_get_attr(*get_attr, inputs); | |||
| } else if (auto* trace_mark_var = op.as<TraceMarkVar>()) { | |||
| auto& item = next_instruction(); | |||
| trace_assert(item.op == nullptr, "operator mismatch"); | |||
| @@ -8,50 +8,58 @@ namespace mgb { | |||
| namespace imperative { | |||
| namespace { | |||
| static thread_local size_t nr_watched_values = 0; | |||
| static thread_local uint64_t nr_values = 0; | |||
| static thread_local bool recording_values = false; | |||
| static thread_local std::vector<ValueWeakRef> recorded_values; | |||
| static /*thread_local*/ size_t nr_watched_values = 0; | |||
| static /*thread_local*/ uint64_t nr_values = 0; | |||
| static /*thread_local*/ bool recording_values = false; | |||
| static /*thread_local*/ std::vector<ValueWeakRef> recorded_values; | |||
| static WeakValueMap<uint64_t, ValueWeakRef> registered_values; | |||
| } // namespace | |||
| ValueRef::storage_t& ValueRef::storage() const { | |||
| if (!m_storage) { | |||
| if (mgb_likely(!m_storage->m_successor.m_storage)) { | |||
| return m_storage; | |||
| } | |||
| if (auto& storage = m_storage->m_successor.m_storage) { | |||
| while (storage->m_successor.m_storage) { | |||
| storage = storage->m_successor.m_storage; | |||
| } | |||
| return storage; | |||
| } else { | |||
| return m_storage; | |||
| while (m_storage->m_successor.m_storage) { | |||
| m_storage = m_storage->m_successor.m_storage; | |||
| } | |||
| return m_storage; | |||
| } | |||
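The rewritten storage() walks the successor chain once and shortcuts the mutable m_storage to the final node, so repeated lookups take the fast path. A toy sketch of that path-compression idea, with Node and resolve as illustrative names only:

    #include <memory>

    struct Node {
        std::shared_ptr<Node> successor;  // set when this value was superseded
        int payload = 0;
    };

    // Follow the successor chain and shortcut the handle to the final node,
    // so later calls return immediately (akin to union-find path compression).
    std::shared_ptr<Node>& resolve(std::shared_ptr<Node>& handle) {
        if (!handle || !handle->successor) {
            return handle;  // common fast path
        }
        while (handle->successor) {
            handle = handle->successor;
        }
        return handle;
    }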
| const Value* ValueRef::as(size_t typecode) const { | |||
| auto&& storage = this->storage(); | |||
| if (storage->m_typecode != typecode) { | |||
| return nullptr; | |||
| } | |||
| return static_cast<Value*>(storage.get()); | |||
| } | |||
| bool ValueRef::is(size_t typecode) const { | |||
| return this->storage()->m_typecode == typecode; | |||
| } | |||
| TypedValueRef<DeviceValue> ValueRef::dev_tensor() const { | |||
| return imperative::apply(GetAttr(GetAttr::Data), *this)[0].as_ref<DeviceValue>(); | |||
| return imperative::apply(GetAttr(GetAttr::Data), *this)[0].cast_ref<DeviceValue>(); | |||
| } | |||
| TypedValueRef<HostValue> ValueRef::numpy() const { | |||
| return imperative::apply(GetAttr(GetAttr::Value), *this)[0].as_ref<HostValue>(); | |||
| return imperative::apply(GetAttr(GetAttr::Value), *this)[0].cast_ref<HostValue>(); | |||
| } | |||
| TypedValueRef<CompNodeValue> ValueRef::device() const { | |||
| return imperative::apply(GetAttr(GetAttr::Device), *this)[0] | |||
| .as_ref<CompNodeValue>(); | |||
| .cast_ref<CompNodeValue>(); | |||
| } | |||
| TypedValueRef<ShapeValue> ValueRef::shape() const { | |||
| return imperative::apply(GetAttr(GetAttr::Shape), *this)[0].as_ref<ShapeValue>(); | |||
| return imperative::apply(GetAttr(GetAttr::Shape), *this)[0].cast_ref<ShapeValue>(); | |||
| } | |||
| TypedValueRef<DTypeValue> ValueRef::dtype() const { | |||
| return imperative::apply(GetAttr(GetAttr::DType), *this)[0].as_ref<DTypeValue>(); | |||
| return imperative::apply(GetAttr(GetAttr::DType), *this)[0].cast_ref<DTypeValue>(); | |||
| } | |||
| TypedValueRef<StringValue> ValueRef::name() const { | |||
| return imperative::apply(GetName(), *this)[0].as_ref<StringValue>(); | |||
| return imperative::apply(GetName(), *this)[0].cast_ref<StringValue>(); | |||
| } | |||
| bool ValueRef::is_scalar() const { | |||
| @@ -75,13 +83,15 @@ void ValueRef::unwatch() const { | |||
| } | |||
| ValueRef ValueRef::unwrap() const { | |||
| ValueRef value = *this; | |||
| auto& context = Transformation::get_context(); | |||
| for (size_t i = 0; i < context.next_transformation; ++i) { | |||
| value = context.transformations[i]->unwrap(value); | |||
| if (mgb_unlikely(context.next_transformation)) { | |||
| ValueRef value = *this; | |||
| for (size_t i = 0; i < context.next_transformation; ++i) { | |||
| value = context.transformations[i]->unwrap(value); | |||
| } | |||
| return value; | |||
| } | |||
| mgb_assert(value); | |||
| return value; | |||
| return *this; | |||
| } | |||
| std::string ValueRef::to_string() const { | |||
| @@ -101,13 +111,11 @@ std::string ValueRef::raw_type() const { | |||
| return types[m_storage->m_typecode].name(); | |||
| } | |||
| uint64_t ValueRef::id() const { | |||
| return m_storage ? m_storage->m_id : std::numeric_limits<uint64_t>::max(); | |||
| } | |||
| bool ValueRef::watching() const { | |||
| auto storage = this->storage(); | |||
| return storage && storage->m_watching; | |||
| if (!m_storage) { | |||
| return false; | |||
| } | |||
| return this->storage()->m_watching; | |||
| } | |||
| ValueRef ValueRef::make(ValueRef::storage_t storage) { | |||
| @@ -186,5 +194,96 @@ void Value::try_rethrow() { | |||
| } | |||
| } | |||
| inline void ValueRefList::init(size_t nr_elems) { | |||
| m_size = nr_elems; | |||
| if (m_size > 0) { | |||
| if (m_size == 1) { | |||
| m_data = inline_storage(); | |||
| } else { | |||
| auto& context = Transformation::get_context(); | |||
| m_data = context.allocator.allocate(m_size); | |||
| } | |||
| for (size_t i = 0; i < m_size; ++i) { | |||
| new (m_data + i) ValueRef(); | |||
| } | |||
| } else { | |||
| m_data = nullptr; | |||
| } | |||
| } | |||
| ValueRefList::ValueRefList(size_t nr_elems) { | |||
| init(nr_elems); | |||
| } | |||
| ValueRefList::ValueRefList(std::initializer_list<ValueRef> values) | |||
| : ValueRefList(values.begin(), values.end()) {} | |||
| ValueRefList::ValueRefList(const ValueRefList& rhs) | |||
| : ValueRefList(rhs.cbegin(), rhs.cend()) {} | |||
| ValueRefList::ValueRefList(ValueRefList&& rhs) : ValueRefList() { | |||
| m_size = rhs.m_size; | |||
| if (rhs.m_data == rhs.inline_storage()) { | |||
| m_data = inline_storage(); | |||
| new (m_data) ValueRef(); | |||
| m_data[0] = std::move(rhs.m_data[0]); | |||
| } else { | |||
| m_data = rhs.m_data; | |||
| rhs.m_data = nullptr; | |||
| rhs.m_size = 0; | |||
| } | |||
| } | |||
| ValueRefList& ValueRefList::operator=(const ValueRefList& rhs) { | |||
| if (this == &rhs) { | |||
| return *this; | |||
| } | |||
| clear(); | |||
| init(rhs.m_size); | |||
| for (size_t i = 0; i < m_size; ++i) { | |||
| m_data[i] = rhs.m_data[i]; | |||
| } | |||
| return *this; | |||
| } | |||
| ValueRefList& ValueRefList::operator=(ValueRefList&& rhs) { | |||
| if (this == &rhs) { | |||
| return *this; | |||
| } | |||
| clear(); | |||
| if (rhs.m_data == rhs.inline_storage()) { | |||
| m_data = inline_storage(); | |||
| new (m_data) ValueRef(); | |||
| m_data[0] = rhs.m_data[0]; | |||
| m_size = 1; | |||
| rhs.clear(); | |||
| } else { | |||
| m_data = rhs.m_data; | |||
| m_size = rhs.m_size; | |||
| rhs.m_data = nullptr; | |||
| rhs.m_size = 0; | |||
| } | |||
| return *this; | |||
| } | |||
| ValueRefList::~ValueRefList() { | |||
| clear(); | |||
| } | |||
| void ValueRefList::clear() { | |||
| for (size_t i = 0; i < m_size; ++i) { | |||
| m_data[i].~ValueRef(); | |||
| } | |||
| if (m_data) { | |||
| if (m_size != 1) { | |||
| Transformation::get_context().allocator.deallocate(m_data, m_size); | |||
| } else { | |||
| mgb_assert(m_data == inline_storage()); | |||
| } | |||
| } | |||
| m_data = nullptr; | |||
| m_size = 0; | |||
| } | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
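ValueRefList keeps single-element lists (the common case for operator outputs) in an inline slot and routes larger lists through the per-context allocator. The inline-slot half of that design, reduced to a toy SmallList that is not the real class:

    #include <cassert>
    #include <cstddef>
    #include <new>
    #include <string>

    // Toy list with one inline slot: size-1 lists avoid heap allocation entirely,
    // larger lists fall back to new[].
    template <typename T>
    class SmallList {
        T* m_data = nullptr;
        size_t m_size = 0;
        alignas(T) std::byte m_inline[sizeof(T)];

        T* inline_storage() { return reinterpret_cast<T*>(m_inline); }

    public:
        explicit SmallList(size_t n) : m_size(n) {
            m_data = (n == 1) ? inline_storage() : (n ? new T[n] : nullptr);
            if (n == 1) new (m_data) T();  // placement-new into the inline slot
        }
        ~SmallList() {
            if (m_data == inline_storage()) {
                m_data[0].~T();
            } else {
                delete[] m_data;
            }
        }
        SmallList(const SmallList&) = delete;
        SmallList& operator=(const SmallList&) = delete;
        T& operator[](size_t i) { return m_data[i]; }
        size_t size() const { return m_size; }
    };

    int main() {
        SmallList<std::string> one(1);   // list storage lives inside the object
        one[0] = "scalar";
        SmallList<std::string> many(3);  // heap-allocated
        assert(one.size() == 1 && many.size() == 3);
    }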
| @@ -24,8 +24,6 @@ namespace imperative { | |||
| class GradKey; | |||
| using GenericFunction = std::function<std::vector<ValueRef>(Span<ValueRef>)>; | |||
| /** | |||
| * \brief apply an OpDef to values | |||
| * | |||
| @@ -37,7 +35,7 @@ private: | |||
| public: | |||
| ApplyOp(const OpDef& op) : m_op(op) {} | |||
| const OpDef& op() { return m_op; } | |||
| const OpDef& op() const { return m_op; } | |||
| std::string to_string() const override; | |||
| }; | |||
| @@ -106,7 +104,7 @@ public: | |||
| * \param inputs contains host_storage and device_storage | |||
| * \return Args unpacked args | |||
| */ | |||
| Args parse(Span<ValueRef> inputs); | |||
| Args parse(Span<ValueRef> inputs) const; | |||
| Kind kind() const { return m_kind; } | |||
| CompNode device() const { return m_device; } | |||
| @@ -129,11 +127,11 @@ private: | |||
| public: | |||
| DTRCommand(Kind kind) : m_kind(kind) {} | |||
| Kind kind() { return m_kind; } | |||
| Kind kind() const { return m_kind; } | |||
| std::string to_string() const override; | |||
| std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { return {}; } | |||
| ValueRefList fallback(Span<ValueRef> inputs) const override { return {}; } | |||
| }; | |||
| // deprecated | |||
| @@ -141,9 +139,7 @@ class GetName final : public OperatorImpl<GetName, Operator::GetAttrLike> { | |||
| public: | |||
| std::string to_string() const override; | |||
| std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { | |||
| return {ValueRef()}; | |||
| } | |||
| ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; } | |||
| }; | |||
| /** | |||
| @@ -161,7 +157,7 @@ public: | |||
| std::string to_string() const override; | |||
| std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { | |||
| ValueRefList fallback(Span<ValueRef> inputs) const override { | |||
| return {inputs.as_array<1>()[0]}; | |||
| } | |||
| }; | |||
| @@ -23,7 +23,7 @@ namespace imperative { | |||
| class GradKey; | |||
| using GenericFunction = std::function<std::vector<ValueRef>(Span<ValueRef>)>; | |||
| using GenericFunction = std::function<ValueRefList(Span<ValueRef>)>; | |||
| class ShapeValue final : public MixinValueImpl<ShapeValue, ValueShape> { | |||
| public: | |||
| @@ -97,6 +97,10 @@ public: | |||
| ValueShape shape() const { return m_shape; } | |||
| CompNode device() const { return m_storage.comp_node(); } | |||
| HostTensorStorage storage() const { return m_storage; } | |||
| DTypeScalar item() const { | |||
| mgb_assert(m_shape.is_scalar()); | |||
| return DTypeScalar::make_from_raw(m_dtype, m_storage.ptr()); | |||
| } | |||
| HostTensorND as_nd(bool allow_scalar = false) const; | |||
| }; | |||
| @@ -36,11 +36,11 @@ namespace imperative { | |||
| * | |||
| * \param op | |||
| * \param inputs | |||
| * \return std::vector<ValueRef> | |||
| * \return ValueRefList | |||
| */ | |||
| std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs); | |||
| std::vector<ValueRef> apply(const OpDef& def, Span<ValueRef> inputs); | |||
| std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs); | |||
| ValueRefList apply(const Operator& op, Span<ValueRef> inputs); | |||
| ValueRefList apply(const OpDef& def, Span<ValueRef> inputs); | |||
| ValueRefList apply(const Subgraph& graph, Span<ValueRef> inputs); | |||
| template <typename... TArgs> | |||
| constexpr bool is_all_value_ref_v = | |||
| @@ -49,7 +49,7 @@ constexpr bool is_all_value_ref_v = | |||
| template <typename T, typename... TArgs> | |||
| static auto apply(T&& op, TArgs&&... args) | |||
| -> std::enable_if_t<is_all_value_ref_v<TArgs...>, std::vector<ValueRef>> { | |||
| -> std::enable_if_t<is_all_value_ref_v<TArgs...>, ValueRefList> { | |||
| ValueRef args_arr[sizeof...(TArgs)] = {std::forward<TArgs&&>(args)...}; | |||
| return imperative::apply( | |||
| std::forward<T&&>(op), | |||
| @@ -63,7 +63,7 @@ static auto apply(T&& op, TContainer&& container) -> std::enable_if_t< | |||
| ValueRef> && | |||
| std::is_same_v<decltype(container.size()), size_t> && | |||
| !std::is_same_v<std::decay_t<TContainer>, Span<ValueRef>>, | |||
| std::vector<ValueRef>> { | |||
| ValueRefList> { | |||
| return imperative::apply( | |||
| std::forward<T&&>(op), Span<ValueRef>(container.data(), container.size())); | |||
| } | |||
| @@ -25,6 +25,8 @@ | |||
| namespace mgb { | |||
| namespace imperative { | |||
| using GenericFunction = std::function<ValueRefList(Span<ValueRef>)>; | |||
| /** | |||
| * \brief base class for all operators | |||
| * | |||
| @@ -49,25 +51,24 @@ public: | |||
| Kind kind() const { return m_kind; } | |||
| template <typename U> | |||
| U* as() const { | |||
| const U* as() const { | |||
| if (m_typecode != U::TYPE_CODE) { | |||
| return nullptr; | |||
| } | |||
| return static_cast<U*>(const_cast<Operator*>(this)); | |||
| return static_cast<const U*>(this); | |||
| } | |||
| template <typename U> | |||
| bool is() const { | |||
| return as<U>() != nullptr; | |||
| return m_typecode == U::TYPE_CODE; | |||
| } | |||
| template <Kind kKind> | |||
| bool is() const { | |||
| return kind() == kKind; | |||
| } | |||
| template <typename U> | |||
| U& cast() const { | |||
| U* ptr = as<U>(); | |||
| mgb_assert(ptr); | |||
| return *ptr; | |||
| const U& cast() const { | |||
| mgb_assert(m_typecode == U::TYPE_CODE); | |||
| return static_cast<const U&>(*this); | |||
| } | |||
| virtual std::string to_string() const = 0; | |||
| @@ -77,9 +78,9 @@ public: | |||
| * implementation. | |||
| * | |||
| * \param inputs | |||
| * \return std::vector<ValueRef> | |||
| * \return ValueRefList | |||
| */ | |||
| virtual std::vector<ValueRef> fallback(Span<ValueRef> inputs) const; | |||
| virtual ValueRefList fallback(Span<ValueRef> inputs) const; | |||
| std::type_index type() const { return registered_types()[m_typecode]; } | |||
| @@ -123,7 +123,6 @@ public: | |||
| template <typename T, typename... TArgs> | |||
| static uint64_t record(TArgs&&... args) { | |||
| auto& profiler = get_instance(); | |||
| // auto& mem_pool = get_mem_pool<T>(); | |||
| if constexpr (sm_debug) { | |||
| Status expected = Running; | |||
| mgb_assert(profiler.m_status.compare_exchange_strong(expected, Recording)); | |||
| @@ -18,6 +18,7 @@ | |||
| #include "megbrain/common.h" | |||
| #include "megbrain/imperative/subgraph.h" | |||
| #include "megbrain/imperative/utils/allocator.h" | |||
| #include "megbrain/imperative/utils/local_ptr.h" | |||
| #include "megbrain/imperative/utils/span.h" | |||
| @@ -25,6 +26,7 @@ namespace mgb { | |||
| namespace imperative { | |||
| class ValueRef; | |||
| class ValueRefList; | |||
| class Operator; | |||
| class Transformation; | |||
| @@ -43,6 +45,7 @@ struct TransformationContext { | |||
| // TODO: deprecate TransformationGuard, let next_transformation == frames.size() | |||
| size_t next_transformation = 0; | |||
| std::vector<TransformationFrame> frames; | |||
| ForwardAllocator<ValueRef> allocator; | |||
| }; | |||
| /** | |||
| @@ -86,9 +89,9 @@ public: | |||
| * | |||
| * \param op | |||
| * \param inputs | |||
| * \return std::vector<ValueRef> | |||
| * \return ValueRefList | |||
| */ | |||
| virtual std::vector<ValueRef> apply_transformation( | |||
| virtual ValueRefList apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) = 0; | |||
| virtual ValueRef unwrap(ValueRef value) = 0; | |||
| @@ -187,11 +190,12 @@ public: | |||
| std::swap(context.transformations, current_context.transformations); | |||
| std::swap(context.scopes, current_context.scopes); | |||
| std::swap(context.next_transformation, current_context.next_transformation); | |||
| std::swap(context.allocator, current_context.allocator); | |||
| } | |||
| static TransformationContext& get_context(); | |||
| friend std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs); | |||
| friend ValueRefList apply(const Operator& op, Span<ValueRef> inputs); | |||
| friend class ValueRef; | |||
| }; | |||
| @@ -23,16 +23,38 @@ public: | |||
| using Handle = interpreter::Interpreter::Handle; | |||
| using Channel = interpreter::Interpreter::Channel; | |||
| class RAIIHandle : public NonCopyableObj { | |||
| private: | |||
| Handle m_handle = nullptr; | |||
| Channel* m_channel = nullptr; | |||
| public: | |||
| RAIIHandle(Handle handle, Channel* channel) | |||
| : m_handle(handle), m_channel(channel) {} | |||
| ~RAIIHandle() { m_channel->del(m_handle); } | |||
| Handle handle() const { return m_handle; } | |||
| Channel* channel() const { return m_channel; } | |||
| }; | |||
| private: | |||
| std::shared_ptr<Handle> m_handle = nullptr; | |||
| LocalPtr<RAIIHandle> m_handle; | |||
| std::string m_name; | |||
| mutable DTypeValue::ref_t m_dtype; | |||
| mutable CompNodeValue::ref_t m_comp_node; | |||
| mutable ShapeValue::ref_t m_shape; | |||
| public: | |||
| InterpreterInfo() = default; | |||
| InterpreterInfo(std::shared_ptr<Handle> handle, std::string name = {}) | |||
| InterpreterInfo(LocalPtr<RAIIHandle> handle, std::string name = {}) | |||
| : m_handle(handle), m_name(name) {} | |||
| std::shared_ptr<Handle> handle() const { return m_handle; } | |||
| const LocalPtr<RAIIHandle>& handle() const { return m_handle; } | |||
| DTypeValue::ref_t dtype() const; | |||
| CompNodeValue::ref_t comp_node() const; | |||
| ShapeValue::ref_t shape() const; | |||
| std::string name() const { return m_name; } | |||
| }; | |||
| @@ -60,6 +82,7 @@ class InterpreterTransformation final : public Transformation { | |||
| public: | |||
| using Interpreter = interpreter::Interpreter; | |||
| using Handle = Interpreter::Handle; | |||
| using SharedHandle = LocalPtr<InterpreterInfo::RAIIHandle>; | |||
| using Channel = Interpreter::Channel; | |||
| private: | |||
| @@ -71,7 +94,14 @@ public: | |||
| Channel* channel() { return m_channel.get(); } | |||
| std::vector<ValueRef> apply_transformation( | |||
| ValueRefList apply_op(const ApplyOp& apply_op, Span<ValueRef> inputs); | |||
| ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs); | |||
| ValueRefList apply_create_tensor( | |||
| const CreateTensor& create_tensor, Span<ValueRef> inputs); | |||
| ValueRefList apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) override; | |||
| ValueRef unwrap(ValueRef value) override { | |||
| @@ -81,14 +111,8 @@ public: | |||
| std::string name() const override { return "InterpreterTransformation"; } | |||
| std::shared_ptr<Handle> share_handle(Handle handle) { | |||
| return std::shared_ptr<Handle>( | |||
| new Handle(handle), [channel = m_channel.get()](Handle* ptr) { | |||
| if (ptr) { | |||
| channel->del(*ptr); | |||
| delete ptr; | |||
| } | |||
| }); | |||
| SharedHandle share_handle(Handle handle) { | |||
| return SharedHandle::make(handle, m_channel.get()); | |||
| } | |||
| }; | |||
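share_handle now returns a LocalPtr to an RAIIHandle instead of a shared_ptr carrying a custom deleter; both tie the channel's del() call to the handle's lifetime. A toy comparison of the two styles, with Channel and Handle as illustrative stand-ins:

    #include <cstdio>
    #include <memory>

    // Toy stand-ins for the interpreter channel and its opaque handle type.
    struct Channel {
        void del(void* handle) { std::printf("del %p\n", handle); }
    };
    using Handle = void*;

    // Style A: shared_ptr whose deleter captures the channel (the old approach).
    std::shared_ptr<void> share_with_deleter(Handle h, Channel* chan) {
        return std::shared_ptr<void>(h, [chan](void* p) { chan->del(p); });
    }

    // Style B: a small RAII object owning the handle and the channel pointer,
    // similar in spirit to InterpreterInfo::RAIIHandle.
    class RaiiHandle {
        Handle m_handle;
        Channel* m_channel;
    public:
        RaiiHandle(Handle h, Channel* chan) : m_handle(h), m_channel(chan) {}
        ~RaiiHandle() { m_channel->del(m_handle); }
        RaiiHandle(const RaiiHandle&) = delete;
        RaiiHandle& operator=(const RaiiHandle&) = delete;
        Handle handle() const { return m_handle; }
    };

    int main() {
        Channel chan;
        int dummy = 0;
        auto a = share_with_deleter(&dummy, &chan);            // del() runs when refcount drops
        auto b = std::make_shared<RaiiHandle>(&dummy, &chan);  // del() runs in the destructor
    }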
| @@ -34,9 +34,7 @@ struct BackwardGraphWithClosure { | |||
| std::shared_ptr<OptimizedBackwardGraphResult> backward_graph, | |||
| std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs); | |||
| void operator()( | |||
| std::vector<ValueRef> grads, | |||
| std::function<void(size_t, ValueRef)> receiver); | |||
| void operator()(ValueRefList grads, std::function<void(size_t, ValueRef)> receiver); | |||
| bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; } | |||
| @@ -50,12 +48,11 @@ struct BackwardGraphWithClosure { | |||
| struct CustomBackward; | |||
| using GradRuleFn = | |||
| std::function<std::vector<ValueRef>(Span<ValueRef> inputs, CustomBackward&)>; | |||
| using GradRuleFn = std::function<ValueRefList(Span<ValueRef> inputs, CustomBackward&)>; | |||
| struct CustomBackward { | |||
| using BackwardFn = std::function<std::vector<ValueRef>(Span<ValueRef>)>; | |||
| using BackwardRule = std::function<std::optional<std::vector<ValueRef>>( | |||
| using BackwardFn = std::function<ValueRefList(Span<ValueRef>)>; | |||
| using BackwardRule = std::function<std::optional<ValueRefList>( | |||
| const OpDef&, Span<ValueRef>, Span<bool>, CustomBackward&)>; | |||
| BackwardFn m_backward; | |||
| SmallVector<bool, 8> m_input_has_grad; | |||
| @@ -65,9 +62,7 @@ struct CustomBackward { | |||
| SmallVector<OutputAttr> m_output_attrs; | |||
| public: | |||
| void operator()( | |||
| std::vector<ValueRef> grads, | |||
| std::function<void(size_t, ValueRef)> receiver); | |||
| void operator()(ValueRefList grads, std::function<void(size_t, ValueRef)> receiver); | |||
| bool input_has_grad(size_t i) { return m_input_has_grad[i]; } | |||
| bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; } | |||
| @@ -188,7 +183,7 @@ public: | |||
| std::string to_string() const override; | |||
| bool has_key(std::shared_ptr<GradKey> key) const { return m_key == key; } | |||
| bool has_key(const std::shared_ptr<GradKey>& key) const { return m_key == key; } | |||
| const GradSlotPtr& slot_for(std::shared_ptr<GradKey> key) const { | |||
| mgb_assert(m_key == key); | |||
| @@ -287,7 +282,7 @@ public: | |||
| return false; | |||
| } | |||
| std::vector<ValueRef> apply_transformation( | |||
| ValueRefList apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) override; | |||
| ValueRef unwrap(ValueRef value) override { | |||
| @@ -314,7 +309,7 @@ private: | |||
| public: | |||
| std::string to_string() const override { return "DetachValue"; } | |||
| std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { | |||
| ValueRefList fallback(Span<ValueRef> inputs) const override { | |||
| return {inputs.as_array<1>()[0]}; | |||
| } | |||
| }; | |||
| @@ -325,7 +320,7 @@ private: | |||
| public: | |||
| AttachGrad(std::shared_ptr<GradKey> key) : m_key(key) {} | |||
| std::shared_ptr<GradKey> key() { return m_key; } | |||
| std::shared_ptr<GradKey> key() const { return m_key; } | |||
| std::string to_string() const override { | |||
| return ssprintf("AttachGradValue{key=%s}", m_key->name().c_str()); | |||
| @@ -339,7 +334,7 @@ private: | |||
| public: | |||
| GradBackward(std::shared_ptr<GradKey> key) : m_key(key) {} | |||
| std::shared_ptr<GradKey> key() { return m_key; } | |||
| std::shared_ptr<GradKey> key() const { return m_key; } | |||
| std::string to_string() const override { | |||
| return ssprintf("GradBackwardValue{key=%s}", m_key->name().c_str()); | |||
| @@ -352,13 +347,13 @@ private: | |||
| public: | |||
| IsAttachedTo(std::shared_ptr<GradKey> key) : m_key(key) {} | |||
| std::shared_ptr<GradKey> key() { return m_key; } | |||
| std::shared_ptr<GradKey> key() const { return m_key; } | |||
| std::string to_string() const override { | |||
| return ssprintf("IsAttachedToValue{key=%s}", m_key->name().c_str()); | |||
| } | |||
| std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { | |||
| ValueRefList fallback(Span<ValueRef> inputs) const override { | |||
| return {BoolValue::make(false)}; | |||
| } | |||
| }; | |||
| @@ -373,9 +368,9 @@ public: | |||
| SetGrad(std::shared_ptr<GradKey> key, GenericFunction grad_fn, size_t nr_inputs) | |||
| : m_key(key), m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {} | |||
| GenericFunction grad_fn() { return m_grad_fn; } | |||
| GenericFunction grad_fn() const { return m_grad_fn; } | |||
| size_t nr_inputs() { return m_nr_inputs; } | |||
| size_t nr_inputs() const { return m_nr_inputs; } | |||
| std::string to_string() const override { | |||
| return ssprintf("SetGradValue{key=%s}", m_key->name().c_str()); | |||
| @@ -388,9 +383,7 @@ public: | |||
| std::string to_string() const override { return ssprintf("GetGradKeyValue{}"); } | |||
| std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { | |||
| return {ValueRef()}; | |||
| } | |||
| ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; } | |||
| }; | |||
| class GetBackwardColsure | |||
| @@ -401,7 +394,7 @@ private: | |||
| public: | |||
| GetBackwardColsure(std::shared_ptr<GradKey> key) : m_key(key) {} | |||
| std::shared_ptr<GradKey> key() { return m_key; } | |||
| std::shared_ptr<GradKey> key() const { return m_key; } | |||
| std::string to_string() const override { | |||
| return ssprintf("GetBackwardClosure{key=%s}", m_key->name().c_str()); | |||
| @@ -81,7 +81,7 @@ public: | |||
| ComputingGraph::Options& options() { return m_graph->options(); } | |||
| std::vector<ValueRef> apply_transformation( | |||
| ValueRefList apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) override; | |||
| ValueRef unwrap(ValueRef value) override { | |||
| @@ -11,6 +11,7 @@ | |||
| #pragma once | |||
| #include "megbrain/imperative/basic_operators.h" | |||
| #include "megbrain/imperative/dispatch.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| @@ -45,8 +46,10 @@ public: | |||
| */ | |||
| class ScalarTransformation final : public Transformation { | |||
| private: | |||
| ShapeValue::ref_t m_empty_shape; // [] | |||
| public: | |||
| std::vector<ValueRef> apply_transformation( | |||
| ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs); | |||
| ValueRefList apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) override; | |||
| ValueRef unwrap(ValueRef value) override { | |||
| @@ -50,7 +50,7 @@ private: | |||
| public: | |||
| SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {} | |||
| std::vector<ValueRef> apply_transformation( | |||
| ValueRefList apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) override { | |||
| if (auto* apply_op = op.as<ApplyOp>()) { | |||
| SmallVector<VarNode*> input_nodes; | |||
| @@ -58,9 +58,9 @@ public: | |||
| input_nodes.push_back(input.cast<SymbolValue>().node()); | |||
| } | |||
| auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes); | |||
| std::vector<ValueRef> outputs; | |||
| for (auto&& output_node : output_nodes) { | |||
| outputs.push_back(SymbolValue::make(output_node)); | |||
| ValueRefList outputs(output_nodes.size()); | |||
| for (size_t i = 0; i < output_nodes.size(); ++i) { | |||
| outputs[i] = SymbolValue::make(output_nodes[i]); | |||
| } | |||
| return outputs; | |||
| } else if (auto* create_tensor = op.as<CreateTensor>()) { | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * \file imperative/src/include/megbrain/imperative/grad.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/imperative/basic_operators.h" | |||
| #include "megbrain/imperative/operator.h" | |||
| #include "megbrain/imperative/transformation.h" | |||
| #include "megbrain/imperative/value.h" | |||
| namespace mgb::imperative { | |||
| struct TangentInfo { | |||
| ValueRef value; | |||
| ValueRef tangent; | |||
| }; | |||
| class TangentTransformation final : public Transformation { | |||
| public: | |||
| ValueRefList apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) override; | |||
| ValueRef unwrap(ValueRef value) override { mgb_assert(false); } | |||
| std::string name() const override { return "Tangent"; } | |||
| }; | |||
| } // namespace mgb::imperative | |||
| @@ -126,25 +126,6 @@ public: | |||
| void on_unwatch() override { value().unwatch(); } | |||
| }; | |||
| class TracedInfo { | |||
| private: | |||
| size_t m_id = 0; | |||
| public: | |||
| TracedInfo() = default; | |||
| TracedInfo(size_t id) : m_id(id) {} | |||
| size_t id() const { return m_id; } | |||
| }; | |||
| class TracedValue final : public MixinValueImpl<TracedValue, TracedInfo> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| std::string to_string() const override { | |||
| return ssprintf("TracedValue{\"id\"=%zu}", id()); | |||
| } | |||
| }; | |||
| /** | |||
| * \brief trace operation sequence to TraceResult | |||
| * | |||
| @@ -202,7 +183,7 @@ public: | |||
| return value; | |||
| } | |||
| std::vector<ValueRef> apply_transformation( | |||
| ValueRefList apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) override; | |||
| ValueRef unwrap(ValueRef value) override { | |||
| @@ -248,6 +229,40 @@ public: | |||
| std::function<DeviceTensorND()> data_getter; | |||
| std::function<HostTensorND()> value_getter; | |||
| std::function<void(DeviceTensorND)> data_setter; | |||
| std::function<void(std::exception_ptr)> exc_setter; | |||
| }; | |||
| class TracedInfo { | |||
| private: | |||
| size_t m_id = 0; | |||
| VarInfo* m_var = nullptr; | |||
| VarAccessor* m_accessor = nullptr; | |||
| mutable ShapeValue::ref_t m_shape; | |||
| mutable DTypeValue::ref_t m_dtype; | |||
| mutable CompNodeValue::ref_t m_comp_node; | |||
| public: | |||
| TracedInfo() = default; | |||
| TracedInfo(size_t id, VarInfo* var, VarAccessor* accessor) | |||
| : m_id(id), m_var(var), m_accessor(accessor) {} | |||
| size_t id() const { return m_id; } | |||
| ShapeValue::ref_t shape() const; | |||
| DTypeValue::ref_t dtype() const; | |||
| CompNodeValue::ref_t comp_node() const; | |||
| const VarAccessor& accessor() const; | |||
| void set_exception(std::exception_ptr exc) const { | |||
| m_accessor->exc_setter(exc); | |||
| } | |||
| }; | |||
| class TracedValue final : public MixinValueImpl<TracedValue, TracedInfo> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| std::string to_string() const override { | |||
| return ssprintf("TracedValue{\"id\"=%zu}", id()); | |||
| } | |||
| }; | |||
| private: | |||
| @@ -319,7 +334,14 @@ public: | |||
| TraceResult::SeqItem& next_instruction(); | |||
| std::vector<ValueRef> apply_transformation( | |||
| ValueRefList apply_op(const ApplyOp& apply_op, Span<ValueRef> inputs); | |||
| ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs); | |||
| ValueRefList apply_create_tensor( | |||
| const CreateTensor& create_tensor, Span<ValueRef> inputs); | |||
| ValueRefList apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) override; | |||
| void on_unregister() noexcept override; | |||
| @@ -36,12 +36,12 @@ private: | |||
| public: | |||
| Allocator(pool_type* pool) : m_pool(pool) {} | |||
| T* allocate(size_type n) { | |||
| pointer allocate(size_type n) { | |||
| mgb_assert(n == 1); | |||
| return m_pool->alloc(sizeof(T)); | |||
| } | |||
| void deallocate(pointer* p, size_type n) { | |||
| void deallocate(pointer p, size_type n) { | |||
| mgb_assert(n == 1); | |||
| m_pool->free(p); | |||
| } | |||
| @@ -68,4 +68,114 @@ public: | |||
| bool operator!=(const ThreadLocalAllocatorAdapter& rhs) const { return false; } | |||
| }; | |||
| } // namespace mgb::imperative | |||
| template <typename T> | |||
| class ForwardAllocator { | |||
| public: | |||
| using value_type = T; | |||
| using size_type = std::size_t; | |||
| using pointer = T*; | |||
| static constexpr size_t alignment = alignof(T); | |||
| static constexpr size_t element_offset = | |||
| sizeof(T) + | |||
| ((sizeof(T) % alignment) ? 0 : (alignment - sizeof(T) % alignment)); | |||
| private: | |||
| struct Block { | |||
| std::unique_ptr<std::byte[]> data; | |||
| size_t size = 0; | |||
| size_t capacity = 0; | |||
| T* allocate(size_type n) { | |||
| static_assert(element_offset > std::max(alignment, sizeof(T))); | |||
| size_t begin = size; | |||
| size_t end = begin + element_offset * n; | |||
| if (end > capacity) { | |||
| return nullptr; | |||
| } | |||
| size = end; | |||
| return reinterpret_cast<T*>(data.get() + begin); | |||
| } | |||
| void reset() { size = 0; } | |||
| }; | |||
| std::vector<Block> m_used; | |||
| std::optional<Block> m_current; | |||
| size_t block_size = 16 * 1024 * 1024; | |||
| size_t nr_allocated = 0; | |||
| private: | |||
| Block allocate_block() { | |||
| block_size *= 2; | |||
| return Block{std::make_unique<std::byte[]>(block_size), 0, block_size}; | |||
| } | |||
| public: | |||
| pointer allocate(size_type n) { | |||
| if (!m_current) { | |||
| m_current.emplace(allocate_block()); | |||
| } | |||
| pointer pointer = m_current->allocate(n); | |||
| while (pointer == nullptr) { | |||
| m_used.push_back(allocate_block()); | |||
| std::swap(m_used.back(), *m_current); | |||
| pointer = m_current->allocate(n); | |||
| } | |||
| nr_allocated++; | |||
| return pointer; | |||
| } | |||
| void deallocate(pointer p, size_type n) { | |||
| mgb_assert(nr_allocated > 0); | |||
| nr_allocated--; | |||
| } | |||
| void clear() { | |||
| if (mgb_likely(m_used.empty())) { | |||
| // fastpath | |||
| if (m_current) { | |||
| m_current->reset(); | |||
| } | |||
| } else { | |||
| // trim | |||
| *m_current = allocate_block(); | |||
| m_used.clear(); | |||
| } | |||
| mgb_assert(nr_allocated == 0); | |||
| } | |||
| bool operator==(const ForwardAllocator& rhs) const { return &rhs == this; } | |||
| bool operator!=(const ForwardAllocator& rhs) const { return &rhs != this; } | |||
| }; | |||
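ForwardAllocator is a bump allocator: allocation only advances a cursor inside geometrically growing blocks, deallocation is bookkeeping, and clear() reclaims everything at once between dispatches. A reduced sketch of the same idea; Arena below is illustrative and ignores alignment:

    #include <cassert>
    #include <cstddef>
    #include <memory>
    #include <vector>

    // Bump allocator: hand out bytes from the current block, grow blocks
    // geometrically when one fills up, and release everything in one clear().
    class Arena {
        struct Block {
            std::unique_ptr<std::byte[]> data;
            size_t used = 0;
            size_t capacity = 0;
        };
        std::vector<Block> m_blocks;
        size_t m_next_capacity = 1024;

        void grow() {
            m_blocks.push_back({std::make_unique<std::byte[]>(m_next_capacity), 0, m_next_capacity});
            m_next_capacity *= 2;
        }

    public:
        void* allocate(size_t bytes) {
            if (m_blocks.empty() || m_blocks.back().used + bytes > m_blocks.back().capacity) {
                grow();  // assumes bytes fits in the fresh block
            }
            Block& b = m_blocks.back();
            void* p = b.data.get() + b.used;  // alignment handling omitted for brevity
            b.used += bytes;
            return p;
        }
        // Individual deallocation is a no-op; clear() drops every block at once.
        void clear() { m_blocks.clear(); }
    };

    int main() {
        Arena arena;
        int* x = static_cast<int*>(arena.allocate(sizeof(int)));
        *x = 42;
        assert(*x == 42);
        arena.clear();  // all arena memory released together
    }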
| template <typename T, template <typename> typename TAllocator> | |||
| class ProxyAllocator { | |||
| public: | |||
| using value_type = T; | |||
| using size_type = typename TAllocator<T>::size_type; | |||
| using pointer = typename TAllocator<T>::pointer; | |||
| private: | |||
| TAllocator<T>* m_impl; | |||
| public: | |||
| T* allocate(size_type n) { return m_impl->allocate(n); } | |||
| void deallocate(pointer* p, size_type n) { return m_impl->deallocate(p, n); } | |||
| bool operator==(const ProxyAllocator<T, TAllocator>& rhs) const { | |||
| if (m_impl == rhs.m_impl) { | |||
| return true; | |||
| } else if (bool(m_impl) ^ bool(rhs.m_impl)) { | |||
| return false; | |||
| } else { | |||
| return *m_impl == *rhs.m_impl; | |||
| } | |||
| } | |||
| bool operator!=(const ProxyAllocator<T, TAllocator>& rhs) const { | |||
| return !((*this) == rhs); | |||
| } | |||
| }; | |||
| } // namespace mgb::imperative | |||
| @@ -16,6 +16,8 @@ | |||
| #include "megbrain/imperative/utils/mempool.h" | |||
| #include "megbrain/utils/metahelper.h" | |||
| #define MGB_FAT_LOCAL_PTR 0 | |||
| namespace mgb::imperative { | |||
| template <typename T> | |||
| @@ -52,6 +54,8 @@ private: | |||
| } | |||
| } | |||
| size_t ref_count() const { return m_ref_count; } | |||
| template <typename U> | |||
| friend class LocalPtr; | |||
| @@ -88,14 +92,24 @@ public: | |||
| using storage_t = LocalPtrStorage<T>; | |||
| using pool_t = MemPool<storage_t>; | |||
| using weak_type = LocalWeakPtr<T>; | |||
| using pointer_t = T*; | |||
| private: | |||
| storage_t* m_storage = nullptr; | |||
| #if MGB_FAT_LOCAL_PTR | |||
| pointer_t m_pointer = nullptr; | |||
| #endif | |||
| // (m_storage == nullptr) == (m_pointer == nullptr) | |||
| void emplace(storage_t* ptr) { | |||
| if (ptr) { | |||
| ptr->inc_ref(); | |||
| m_storage = ptr; | |||
| #if MGB_FAT_LOCAL_PTR | |||
| m_pointer = ptr->m_pointer; | |||
| #endif | |||
| } | |||
| } | |||
| @@ -103,8 +117,22 @@ private: | |||
| public: | |||
| LocalPtr() = default; | |||
| LocalPtr(const LocalPtr& rhs) { (*this) = rhs; } | |||
| LocalPtr(LocalPtr&& rhs) { (*this) = std::move(rhs); } | |||
| LocalPtr(const LocalPtr& rhs) { | |||
| auto storage = rhs.m_storage; | |||
| if (storage) { | |||
| storage->inc_ref(); | |||
| } | |||
| m_storage = storage; | |||
| #if MGB_FAT_LOCAL_PTR | |||
| m_pointer = rhs.m_pointer; | |||
| #endif | |||
| } | |||
| LocalPtr(LocalPtr&& rhs) { | |||
| std::swap(m_storage, rhs.m_storage); | |||
| #if MGB_FAT_LOCAL_PTR | |||
| std::swap(m_pointer, rhs.m_pointer); | |||
| #endif | |||
| } | |||
| LocalPtr& operator=(const LocalPtr& rhs) { | |||
| if (this == &rhs) { | |||
| return *this; | |||
| @@ -115,9 +143,11 @@ public: | |||
| } | |||
| if (m_storage) { | |||
| m_storage->dec_ref(); | |||
| // rhs.m_storage may be invalid here | |||
| } | |||
| m_storage = storage; | |||
| #if MGB_FAT_LOCAL_PTR | |||
| m_pointer = rhs.m_pointer; | |||
| #endif | |||
| return *this; | |||
| } | |||
| LocalPtr& operator=(LocalPtr&& rhs) { | |||
| @@ -125,6 +155,9 @@ public: | |||
| return *this; | |||
| } | |||
| std::swap(m_storage, rhs.m_storage); | |||
| #if MGB_FAT_LOCAL_PTR | |||
| std::swap(m_pointer, rhs.m_pointer); | |||
| #endif | |||
| rhs.reset(); | |||
| return *this; | |||
| } | |||
| @@ -186,10 +219,11 @@ public: | |||
| T& operator*() const { return *get(); } | |||
| T* get() const { | |||
| if ((!m_storage) || !m_storage->m_pointer) { | |||
| return nullptr; | |||
| } | |||
| return m_storage->m_pointer; | |||
| #if MGB_FAT_LOCAL_PTR | |||
| return m_pointer; | |||
| #else | |||
| return m_storage ? m_storage->m_pointer : nullptr; | |||
| #endif | |||
| } | |||
| T* operator->() const { return get(); } | |||
| @@ -202,6 +236,9 @@ public: | |||
| if (m_storage) { | |||
| m_storage->dec_ref(); | |||
| m_storage = nullptr; | |||
| #if MGB_FAT_LOCAL_PTR | |||
| m_pointer = nullptr; | |||
| #endif | |||
| } | |||
| } | |||
| @@ -49,8 +49,8 @@ public: | |||
| instance = std::make_unique<MemPool<T>>(); | |||
| sm_instance = instance.get(); | |||
| } | |||
| mgb_assert(sm_instance); | |||
| } | |||
| return *sm_instance; | |||
| } | |||
| }; | |||
| @@ -62,9 +62,9 @@ std::unordered_map<std::thread::id, std::unique_ptr<MemPool<T>>> | |||
| MemPoolUtils<T>::sm_instances; | |||
| template <typename T> | |||
| thread_local MemPool<T>* MemPoolUtils<T>::tm_instance; | |||
| thread_local MemPool<T>* MemPoolUtils<T>::tm_instance = nullptr; | |||
| template <typename T> | |||
| MemPool<T>* MemPoolUtils<T>::sm_instance; | |||
| MemPool<T>* MemPoolUtils<T>::sm_instance = nullptr; | |||
| } // namespace mgb::imperative | |||
| } // namespace mgb::imperative | |||
| @@ -95,6 +95,8 @@ struct ValueShape { | |||
| } | |||
| return true; | |||
| } | |||
| bool operator!=(const ValueShape& rhs) const { return !operator==(rhs); } | |||
| }; | |||
| static_assert(sizeof(size_t) >= sizeof(int)); | |||
| @@ -47,6 +47,17 @@ class StringValue; | |||
| class Operator; | |||
| class ValueRefList; | |||
| template <typename T> | |||
| class Type { | |||
| private: | |||
| const size_t m_code = T::TYPE_CODE; | |||
| public: | |||
| inline size_t code() const { return m_code; } | |||
| }; | |||
| /** | |||
| * \brief a smart reference to a value | |||
| * | |||
| @@ -64,8 +75,9 @@ public: | |||
| protected: | |||
| mutable storage_t m_storage; | |||
| size_t m_id = std::numeric_limits<size_t>::max(); | |||
| ValueRef(storage_t storage) { m_storage = storage; } | |||
| inline ValueRef(storage_t storage); | |||
| private: | |||
| /** | |||
| @@ -75,6 +87,10 @@ private: | |||
| */ | |||
| storage_t& storage() const; | |||
| const Value* as(size_t typecode) const; | |||
| bool is(size_t typecode) const; | |||
| public: | |||
| ValueRef() = default; | |||
| @@ -86,7 +102,7 @@ public: | |||
| * \return false if empty or type of value is not TValue | |||
| */ | |||
| template <typename TValue> | |||
| bool is() const; | |||
| inline bool is(Type<TValue> type = {}) const; | |||
| /** | |||
| * \brief try cast value as target type | |||
| @@ -95,7 +111,7 @@ public: | |||
| * \return TValue* raw pointer on success, otherwise nullptr | |||
| */ | |||
| template <typename TValue> | |||
| const TValue* as() const; | |||
| inline const TValue* as(Type<TValue> type = {}) const; | |||
| /** | |||
| * \brief cast value to target type | |||
| @@ -104,7 +120,7 @@ public: | |||
| * \return TValue& reference of value | |||
| */ | |||
| template <typename TValue> | |||
| const TValue& cast() const; | |||
| inline const TValue& cast(Type<TValue> type = {}) const; | |||
| /** | |||
| * \brief like as(), but returns TypedValueRef instead | |||
| @@ -113,7 +129,13 @@ public: | |||
| * \return TypedValueRef<TValue> reference on success, otherwise an empty reference | |||
| */ | |||
| template <typename TValue> | |||
| inline TypedValueRef<TValue> as_ref() const; | |||
| inline TypedValueRef<TValue> as_ref(Type<TValue> type = {}) const; | |||
| template <typename TValue> | |||
| inline TypedValueRef<TValue> cast_ref(Type<TValue> type = {}) const; | |||
| template <typename TValue> | |||
| void on_cast_failure() const; | |||
| operator bool() const { return bool(m_storage); } | |||
| @@ -132,7 +154,7 @@ public: | |||
| ValueRef unwrap() const; | |||
| std::string to_string() const; | |||
| std::string raw_type() const; | |||
| uint64_t id() const; | |||
| uint64_t id() const { return m_id; } | |||
| size_t hash() const { return id(); } | |||
| static ValueRef make(storage_t storage); | |||
| @@ -144,7 +166,7 @@ public: | |||
| friend class TypedValueRef; | |||
| template <typename T> | |||
| friend class ValueImpl; | |||
| friend std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs); | |||
| friend ValueRefList apply(const Operator& op, Span<ValueRef> inputs); | |||
| }; | |||
| template <> | |||
| @@ -244,7 +266,7 @@ public: | |||
| using ref_t = TypedValueRef<T>; | |||
| using weak_ref_t = TypedValueWeakRef<T>; | |||
| static inline size_t TYPE_CODE = [] { return register_type(typeid(T)); }(); | |||
| static inline const size_t TYPE_CODE = [] { return register_type(typeid(T)); }(); | |||
| /** | |||
| * \brief helper function for construct a value | |||
| @@ -254,7 +276,7 @@ public: | |||
| * \return TypedValueRef<T> reference of value | |||
| */ | |||
| template <typename... TArgs> | |||
| static TypedValueRef<T> make(TArgs&&... args) { | |||
| static MGB_NOINLINE TypedValueRef<T> make(TArgs&&... args) { | |||
| static_assert(std::is_final_v<T>); | |||
| return ValueRef::make(LocalPtr<Value>::make<T>(std::forward<TArgs&&>(args)...)); | |||
| } | |||
| @@ -279,46 +301,60 @@ public: | |||
| bool eq(const TMixin& value) const { return ((const TMixin&)*this) == value; } | |||
| }; | |||
| inline ValueRef::ValueRef(storage_t storage) { | |||
| // mgb_assert(storage); | |||
| m_storage = storage; | |||
| m_id = m_storage->m_id; | |||
| } | |||
| template <typename TValue> | |||
| const TValue* ValueRef::as() const { | |||
| inline const TValue* ValueRef::as(Type<TValue> type) const { | |||
| static_assert(std::is_base_of_v<ValueImpl<TValue>, TValue>); | |||
| auto storage = this->storage(); | |||
| if (!storage) { | |||
| return nullptr; | |||
| } | |||
| if (storage->m_typecode != TValue::TYPE_CODE) { | |||
| return nullptr; | |||
| } | |||
| return static_cast<TValue*>(storage.get()); | |||
| return static_cast<const TValue*>(as(type.code())); | |||
| } | |||
| template <typename TValue> | |||
| const TValue& ValueRef::cast() const { | |||
| auto* ptr = as<TValue>(); | |||
| if (!ptr) { | |||
| // if this is ErrorValue, rethrow directly | |||
| storage()->try_rethrow(); | |||
| mgb_assert( | |||
| ptr, "expect type %s, got %s", typeid(TValue).name(), | |||
| to_string().c_str()); | |||
| inline const TValue& ValueRef::cast(Type<TValue> type) const { | |||
| auto* ptr = as<TValue>(type); | |||
| if (mgb_unlikely(!ptr)) { | |||
| on_cast_failure<TValue>(); | |||
| } | |||
| return *ptr; | |||
| return static_cast<const TValue&>(*ptr); | |||
| } | |||
| template <typename TValue> | |||
| inline bool ValueRef::is(Type<TValue> type) const { | |||
| return is(type.code()); | |||
| } | |||
| template <typename TValue> | |||
| bool ValueRef::is() const { | |||
| auto* ptr = as<TValue>(); | |||
| return ptr != nullptr; | |||
| inline TypedValueRef<TValue> ValueRef::as_ref(Type<TValue> type) const { | |||
| if (!is<TValue>(type)) { | |||
| return {}; | |||
| } | |||
| return TypedValueRef<TValue>(*this); | |||
| } | |||
| template <typename TValue> | |||
| TypedValueRef<TValue> ValueRef::as_ref() const { | |||
| if (!is<TValue>()) { | |||
| inline TypedValueRef<TValue> ValueRef::cast_ref(Type<TValue> type) const { | |||
| if (!m_storage) { | |||
| return {}; | |||
| } | |||
| if (mgb_unlikely(!is<TValue>(type))) { | |||
| on_cast_failure<TValue>(); | |||
| } | |||
| return TypedValueRef<TValue>(*this); | |||
| } | |||
| template <typename TValue> | |||
| void ValueRef::on_cast_failure() const { | |||
| // if this is ErrorValue, rethrow directly | |||
| storage()->try_rethrow(); | |||
| mgb_assert( | |||
| storage()->m_typecode == TValue::TYPE_CODE, "expect type %s, got %s", | |||
| typeid(TValue).name(), to_string().c_str()); | |||
| } | |||
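cast() and cast_ref() now share one failure path: the mismatch branch is wrapped in mgb_unlikely and the reporting is pushed into on_cast_failure, so a successful cast stays a short inline sequence. A minimal sketch of that "hot inline path, cold out-of-line failure" split, using hypothetical helpers (COLD_UNLIKELY, raise_type_error, checked_cast) instead of the MegEngine macros:

    #include <cstdio>
    #include <stdexcept>
    #include <string>

    #if defined(__GNUC__) || defined(__clang__)
    #define COLD_UNLIKELY(x) __builtin_expect(!!(x), 0)
    #else
    #define COLD_UNLIKELY(x) (x)
    #endif

    [[noreturn]] void raise_type_error(const char* expected, const std::string& got) {
        // Keeping the throw out of line keeps the callers small enough to inline.
        throw std::runtime_error(std::string("expect type ") + expected + ", got " + got);
    }

    struct Value {
        int typecode;
        std::string repr;
    };

    template <int CODE>
    const Value& checked_cast(const Value& v) {
        if (COLD_UNLIKELY(v.typecode != CODE)) {
            raise_type_error("matching typecode", v.repr);  // cold path, out of line
        }
        return v;  // hot path: one predicted-not-taken branch
    }

    int main() {
        Value v{1, "IntValue"};
        checked_cast<1>(v);  // succeeds
        try {
            checked_cast<2>(v);  // fails and reports through the shared error helper
        } catch (const std::exception& e) {
            std::printf("%s\n", e.what());
        }
        return 0;
    }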
| /** | |||
| * \brief ValueRef with concrete type, convenient for dereference | |||
| * | |||
| @@ -361,11 +397,87 @@ private: | |||
| public: | |||
| TypedValueWeakRef(ValueRef value) : ValueWeakRef(value) {} | |||
| TypedValueWeakRef(ValueWeakRef value) : ValueWeakRef(value) {} | |||
| TypedValueRef<T> lock() { return ValueWeakRef::lock().template as_ref<T>(); } | |||
| TypedValueRef<T> lock() { | |||
| auto value = ValueWeakRef::lock(); | |||
| if (value) { | |||
| return value.template as_ref<T>(); | |||
| } else { | |||
| return {}; | |||
| } | |||
| } | |||
| }; | |||
| // TODO: add proxy value type, which is meant to be reset in the end | |||
| class ValueRefList { | |||
| private: | |||
| ValueRef* m_data = nullptr; | |||
| size_t m_size = 0; | |||
| std::aligned_storage_t<sizeof(ValueRef), alignof(ValueRef)> m_storage; | |||
| private: | |||
| void init(size_t nr_elems); | |||
| ValueRef* inline_storage() { return reinterpret_cast<ValueRef*>(&m_storage); } | |||
| public: | |||
| ValueRefList() = default; | |||
| ValueRefList(size_t nr_elems); | |||
| ValueRefList(ValueRef item); | |||
| ValueRefList(std::initializer_list<ValueRef> values); | |||
| template <typename TIterator> | |||
| ValueRefList(TIterator begin, TIterator end); | |||
| ValueRefList(const ValueRefList& rhs); | |||
| ValueRefList(ValueRefList&& rhs); | |||
| ValueRefList& operator=(const ValueRefList& rhs); | |||
| ValueRefList& operator=(ValueRefList&& rhs); | |||
| ~ValueRefList(); | |||
| void clear(); | |||
| ValueRef* begin() { return m_data; } | |||
| ValueRef* end() { return m_data + m_size; } | |||
| const ValueRef* cbegin() const { return m_data; } | |||
| const ValueRef* cend() const { return m_data + m_size; } | |||
| size_t size() const { return m_size; } | |||
| ValueRef& at(size_t idx) { | |||
| mgb_assert(idx < m_size); | |||
| return m_data[idx]; | |||
| } | |||
| const ValueRef& at(size_t idx) const { | |||
| mgb_assert(idx < m_size); | |||
| return m_data[idx]; | |||
| } | |||
| ValueRef& operator[](size_t idx) { return m_data[idx]; } | |||
| const ValueRef& operator[](size_t idx) const { return m_data[idx]; } | |||
| ValueRef* data() { return m_data; } | |||
| const ValueRef* data() const { return m_data; } | |||
| bool empty() const { return m_size == 0; } | |||
| ValueRef& front() { | |||
| mgb_assert(m_size > 0); | |||
| return m_data[0]; | |||
| } | |||
| ValueRef& back() { | |||
| mgb_assert(m_size > 0); | |||
| return m_data[m_size - 1]; | |||
| } | |||
| }; | |||
| template <typename TIterator> | |||
| ValueRefList::ValueRefList(TIterator begin, TIterator end) : ValueRefList(end - begin) { | |||
| for (size_t i = 0; i < m_size; ++i) { | |||
| m_data[i] = *(begin + i); | |||
| } | |||
| } | |||
| inline ValueRefList::ValueRefList(ValueRef item) : m_data(inline_storage()), m_size(1) { | |||
| new (m_data) ValueRef(); | |||
| m_data[0] = std::move(item); | |||
| } | |||
| /*class ValueRefList : public SmallVector<ValueRef, 1> { | |||
| public: | |||
| using SmallVector::SmallVector; | |||
| };*/ | |||
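ValueRefList is a purpose-built replacement for std::vector<ValueRef> on the apply() path: the one-output case, which dominates, lives in the aligned inline slot (m_storage) so no heap allocation is needed, while larger results fall back to a heap buffer; the commented-out SmallVector alias shows the alternative that was considered. A compilable sketch of the same single-element small-buffer idea, using a hypothetical SmallList (copy and move omitted for brevity), not the actual class:

    #include <cassert>
    #include <new>
    #include <string>
    #include <utility>

    template <typename T>
    class SmallList {
        T* m_data = nullptr;
        size_t m_size = 0;
        alignas(T) unsigned char m_storage[sizeof(T)];  // inline slot for one element

        T* inline_storage() { return reinterpret_cast<T*>(m_storage); }

    public:
        SmallList() = default;
        explicit SmallList(T item) : m_size(1) {
            m_data = new (inline_storage()) T(std::move(item));  // no heap allocation
        }
        explicit SmallList(size_t n) : m_size(n) {
            m_data = (n == 1) ? new (inline_storage()) T() : new T[n]();
        }
        SmallList(const SmallList&) = delete;
        SmallList& operator=(const SmallList&) = delete;
        ~SmallList() {
            if (!m_data) {
                return;
            }
            if (m_data == inline_storage()) {
                m_data->~T();  // destroy the single inline element in place
            } else {
                delete[] m_data;
            }
        }
        size_t size() const { return m_size; }
        T& operator[](size_t i) { return m_data[i]; }
        bool uses_inline_storage() const {
            return m_data == reinterpret_cast<const T*>(m_storage);
        }
    };

    int main() {
        SmallList<std::string> single(std::string("one"));  // inline, no allocation
        SmallList<std::string> many(3);                      // heap-backed
        many[0] = "a";
        assert(single.uses_inline_storage());
        assert(!many.uses_inline_storage());
        assert(single[0] == "one" && many.size() == 3);
        return 0;
    }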
| } // namespace imperative | |||
| } // namespace mgb | |||