GitOrigin-RevId: a72f5460b6
tags/v1.9.0
| @@ -15,9 +15,15 @@ import numpy as np | |||
| from .. import _config | |||
| from .._imperative_rt.common import CompNode | |||
| from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion | |||
| from .._imperative_rt.core2 import ( | |||
| SymbolVar, | |||
| Tensor, | |||
| apply, | |||
| broadcast_cpp, | |||
| dtype_promotion, | |||
| ) | |||
| from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar | |||
| from .._imperative_rt.core2 import squeeze_cpp, transpose_cpp | |||
| from .._imperative_rt.core2 import reshape_cpp, squeeze_cpp, transpose_cpp | |||
| from ..ops import builtin | |||
| from . import amp | |||
| from .indexing import getitem, setitem | |||
| @@ -331,70 +337,6 @@ def _matmul( | |||
| return result | |||
| def _broadcast(inp, shape): | |||
| auto_infer = False | |||
| if isinstance(shape, (list, tuple)): | |||
| shape_tuple = list(shape) | |||
| for i, s in enumerate(shape_tuple): | |||
| if isinstance(s, type(None)): | |||
| if s is None: | |||
| right = i - len(shape_tuple) | |||
| inp_shape = inp._tuple_shape | |||
| if len(inp_shape) + right >= 0: | |||
| shape_tuple[right] = list(inp_shape)[right] | |||
| auto_infer = True | |||
| continue | |||
| else: | |||
| raise ValueError("invalided Broadcast shape") | |||
| else: | |||
| raise ValueError( | |||
| "expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format( | |||
| i, s | |||
| ) | |||
| ) | |||
| if s < 0: | |||
| raise ValueError( | |||
| "expect shape[{}] >= 0 or use `None` or 'x' and 'X' to auto infer, got {}".format( | |||
| i, s | |||
| ) | |||
| ) | |||
| if auto_infer: | |||
| shape = tuple(shape_tuple) | |||
| try: | |||
| shape_tuple = make_shape_tuple(shape) | |||
| except ValueError: | |||
| shape_tuple = shape | |||
| shape = astensor1d(shape_tuple, inp, dtype="int32", device=inp.device) | |||
| (result,) = apply(builtin.Broadcast(), inp, shape) | |||
| return result | |||
| def _reshape(x, shape): | |||
| unspec_axis = None | |||
| try: | |||
| shape_tuple = make_shape_tuple(shape) | |||
| except ValueError: | |||
| pass | |||
| else: | |||
| # XXX: assume unspec_axis is not changed in trace | |||
| for i, s in enumerate(shape_tuple): | |||
| if s < 0: | |||
| if s != -1: | |||
| raise ValueError("expect shape[{}] >= -1, got {}".format(i, s)) | |||
| if unspec_axis is not None: | |||
| raise ValueError( | |||
| "multiple -1 in shape: {} & {}".format(unspec_axis, i) | |||
| ) | |||
| unspec_axis = i | |||
| shape = astensor1d(shape, x, dtype="int32", device=x.device) | |||
| if unspec_axis is None: | |||
| op = builtin.Reshape() | |||
| else: | |||
| op = builtin.Reshape(axis=unspec_axis) | |||
| (x,) = apply(op, x, shape) | |||
| return x | |||
| def _unary_elwise(mode): | |||
| def f(self): | |||
| return _elwise(self, mode=mode) | |||
| @@ -667,11 +609,11 @@ class ArrayMethodMixin(abc.ABC): | |||
| def reshape(self, *args): | |||
| r"""See :func:`~.reshape`.""" | |||
| return _reshape(self, _expand_args(args)) | |||
| return reshape_cpp(self, args) | |||
| # FIXME: remove this method | |||
| def _broadcast(self, *args): | |||
| return _broadcast(self, _expand_args(args)) | |||
| return broadcast_cpp(self, args) | |||
| def transpose(self, *args): | |||
| r"""See :func:`~.transpose`.""" | |||
| @@ -679,7 +621,7 @@ class ArrayMethodMixin(abc.ABC): | |||
| def flatten(self): | |||
| r"""See :func:`~.flatten`.""" | |||
| return self.reshape(-1) | |||
| return reshape_cpp(self, (-1,)) | |||
| def sum(self, axis=None, keepdims: bool = False): | |||
| r"""Returns the sum of each row of the input tensor in the given dimension ``axis``. | |||
| @@ -15,6 +15,7 @@ from ..core._imperative_rt import CompNode | |||
| from ..core._imperative_rt.core2 import ( | |||
| SymbolVar, | |||
| apply, | |||
| broadcast_cpp, | |||
| dtype_promotion, | |||
| expand_dims_cpp, | |||
| split_cpp, | |||
| @@ -24,7 +25,6 @@ from ..core._wrap import as_device | |||
| from ..core.ops import builtin | |||
| from ..core.ops.builtin import Copy, Identity | |||
| from ..core.ops.special import Const | |||
| from ..core.tensor.array_method import _broadcast | |||
| from ..core.tensor.utils import astensor1d, convert_inputs, get_device, subgraph_fn | |||
| from ..device import get_default_device | |||
| from ..tensor import Tensor | |||
| @@ -360,7 +360,7 @@ def broadcast_to(inp: Tensor, shape: Union[int, Iterable[int]]) -> Tensor: | |||
| [[0. 1. 2.] | |||
| [0. 1. 2.]] | |||
| """ | |||
| return _broadcast(inp, shape) | |||
| return broadcast_cpp(inp, shape) | |||
| def concat(inps: Iterable[Tensor], axis: int = 0, device=None) -> Tensor: | |||
| @@ -135,23 +135,24 @@ std::optional<ValueRefList> elemwise_grad_rule( | |||
| std::optional<ValueRefList> reshape_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| mgb_assert(inputs.size() == 2); | |||
| mgb_assert(inputs.size() == 1 || inputs.size() == 2); | |||
| size_t nr_inp = inputs.size(); | |||
| std::array<ValueRef, 2> input_shapes; | |||
| for (size_t i = 0; i < 2; ++i) { | |||
| for (size_t i = 0; i < nr_inp; ++i) { | |||
| if (inputs_require_grad[i]) { | |||
| input_shapes[i] = get_shape(inputs[i]); | |||
| } | |||
| } | |||
| auto maker = CustomGradMaker(backward, inputs.size()); | |||
| maker.output_size(1).output_captured(0, false); | |||
| maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||
| maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| SmallVector<ValueRef> ret(2); | |||
| SmallVector<ValueRef> ret(nr_inp); | |||
| if (!grad) { | |||
| return ret; | |||
| } | |||
| for (size_t i = 0; i < 2; ++i) { | |||
| for (size_t i = 0; i < nr_inp; ++i) { | |||
| if (shapes[i]) { | |||
| ret[i] = reshape_to(grad, shapes[i]); | |||
| } | |||
| @@ -162,6 +163,37 @@ std::optional<ValueRefList> reshape_grad_rule( | |||
| return imperative::apply(ApplyOp(op), inputs); | |||
| } | |||
| std::optional<ValueRefList> broadcast_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| mgb_assert(inputs.size() == 1 || inputs.size() == 2); | |||
| size_t nr_inp = inputs.size(); | |||
| std::array<ValueRef, 2> input_shapes; | |||
| for (size_t i = 0; i < nr_inp; ++i) { | |||
| if (inputs_require_grad[i]) { | |||
| input_shapes[i] = get_shape(inputs[i]); | |||
| } | |||
| } | |||
| auto maker = CustomGradMaker(backward, inputs.size()); | |||
| maker.output_size(1).output_captured(0, false); | |||
| maker.backward([shapes = std::move(input_shapes), nr_inp](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| SmallVector<ValueRef> ret(nr_inp); | |||
| if (!grad) { | |||
| return ret; | |||
| } | |||
| for (size_t i = 0; i < nr_inp; ++i) { | |||
| if (shapes[i]) { | |||
| ret[i] = reduce_to(grad, shapes[i]); | |||
| } | |||
| } | |||
| return ret; | |||
| }); | |||
| maker.finalize(); | |||
| return imperative::apply(ApplyOp(op), inputs); | |||
| } | |||
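The new `broadcast_grad_rule` mirrors `reshape_grad_rule` but reduces the incoming gradient back to each input's recorded shape via `reduce_to`. A short sketch of the effect at the Python level (illustrative shapes; assumes the usual MegEngine eager-mode autodiff API):

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import GradManager

    x = mge.tensor(np.ones((1, 3), dtype="float32"))
    gm = GradManager().attach([x])
    with gm:
        y = F.broadcast_to(x, (4, 3))
        gm.backward(y.sum())
    # each input element fed 4 broadcast copies, so its gradient is 4
    assert np.allclose(x.grad.numpy(), np.full((1, 3), 4.0))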
| std::optional<ValueRefList> subtensor_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| @@ -330,6 +362,7 @@ struct Init { | |||
| Init() { | |||
| CustomBackward::register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule); | |||
| CustomBackward::register_grad_rule(Reshape::typeinfo(), reshape_grad_rule); | |||
| CustomBackward::register_grad_rule(Broadcast::typeinfo(), broadcast_grad_rule); | |||
| CustomBackward::register_grad_rule(Subtensor::typeinfo(), subtensor_grad_rule); | |||
| CustomBackward::register_grad_rule( | |||
| IndexingMultiAxisVec::typeinfo(), indexingMultiAxisVec_grad_rule); | |||
| @@ -637,6 +637,8 @@ WRAP_FUNC_PY35(split_cpp); | |||
| WRAP_FUNC_PY35(expand_dims_cpp); | |||
| WRAP_FUNC_PY35(squeeze_cpp); | |||
| WRAP_FUNC_PY35(transpose_cpp); | |||
| WRAP_FUNC_PY35(broadcast_cpp); | |||
| WRAP_FUNC_PY35(reshape_cpp); | |||
| #undef WRAP_FUNC_PY35 | |||
| #define MGE_PY_INTERFACE(NAME, FUNC) \ | |||
| { #NAME, (PyCFunction)py35_##FUNC, METH_VARARGS, nullptr } | |||
| @@ -773,6 +775,8 @@ void init_tensor(py::module m) { | |||
| MGE_PY_INTERFACE(expand_dims_cpp, expand_dims_cpp), | |||
| MGE_PY_INTERFACE(squeeze_cpp, squeeze_cpp), | |||
| MGE_PY_INTERFACE(transpose_cpp, transpose_cpp), | |||
| MGE_PY_INTERFACE(broadcast_cpp, broadcast_cpp), | |||
| MGE_PY_INTERFACE(reshape_cpp, reshape_cpp), | |||
| {nullptr, nullptr, 0, nullptr}}; | |||
| for (auto&& def : method_defs) { | |||
| if (def.ml_meth != nullptr) { | |||
| @@ -800,29 +800,46 @@ size_t fast_ndim(py::handle tensor) { | |||
| return getattr(tensor, "ndim").cast<size_t>(); | |||
| } | |||
| py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { | |||
| py::object _expand_args(py::handle args) { | |||
| if (!PyTuple_Check(args.ptr())) { | |||
| return py::reinterpret_borrow<py::object>(args); | |||
| } | |||
| py::tuple args_tup = py::reinterpret_borrow<py::tuple>(args.ptr()); | |||
| if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) || | |||
| is_tensor_or_symbolvar(args_tup[0].ptr()))) { | |||
| return py::reinterpret_borrow<py::object>(args_tup[0]); | |||
| } else { | |||
| return py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr())); | |||
| } | |||
| } | |||
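`_expand_args` normalizes the two accepted calling conventions, varargs and a single sequence/tensor argument, into one shape-like object, so the helpers below only handle the normalized form. For example (illustrative tensor):

    import numpy as np
    import megengine as mge

    x = mge.tensor(np.arange(6, dtype="float32").reshape(2, 3))
    a = x.transpose(1, 0)     # varargs form
    b = x.transpose((1, 0))   # single-sequence form; both normalize identically
    c = x.reshape(3, 2)
    d = x.reshape((3, 2))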
| py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { | |||
| py::object obj = _expand_args(args); | |||
| py::list lis; | |||
| if (!is_tensor_or_symbolvar(obj.ptr()) && PySequence_Check(obj.ptr())) { | |||
| lis = py::reinterpret_steal<py::list>(PySequence_List(obj.ptr())); | |||
| } else { | |||
| py::object np = getattr(obj, "numpy")(); | |||
| PyArrayObject* arr = (PyArrayObject*)np.ptr(); | |||
| PyObject* maybe_list = PyArray_ToList(arr); | |||
| if (PyList_Check(maybe_list)) { | |||
| lis = py::reinterpret_steal<py::list>(maybe_list); | |||
| } | |||
| } | |||
| if (fast_ndim(inp_hdl) == 0) { | |||
| if (args_tup.size() != 0) { | |||
| if (lis.size() != 0) { | |||
| throw py::index_error( | |||
| "transpose for scalar does not accept additional args"); | |||
| } | |||
| return getattr(inp_hdl, "to")(getattr(inp_hdl, "device")); | |||
| } | |||
| std::vector<int32_t> pattern; | |||
| if (!args_tup.size()) { | |||
| if (!lis.size()) { | |||
| size_t ndim = getattr(inp_hdl, "ndim").cast<size_t>(); | |||
| for (size_t i = 0; i < ndim; ++i) { | |||
| pattern.push_back(ndim - i - 1); | |||
| } | |||
| } else { | |||
| py::list lis; | |||
| if (args_tup.size() == 1 && (PySequence_Check(args_tup[0].ptr()) || | |||
| is_tensor_or_symbolvar(args_tup[0].ptr()))) { | |||
| lis = py::reinterpret_steal<py::list>(PySequence_List(args_tup[0].ptr())); | |||
| } else { | |||
| lis = py::reinterpret_steal<py::list>(PySequence_List(args_tup.ptr())); | |||
| } | |||
| for (size_t i = 0; i < lis.size(); ++i) { | |||
| if (PyLong_Check(lis[i].ptr())) { | |||
| pattern.push_back(lis[i].cast<int32_t>()); | |||
| @@ -844,6 +861,182 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { | |||
| return ret[0]; | |||
| } | |||
| std::tuple<std::vector<int32_t>, bool> tuple2vector(py::object shape) { | |||
| std::vector<int32_t> shp; | |||
| if (!PyTuple_Check(shape.ptr())) { | |||
| return {shp, false}; | |||
| } | |||
| py::tuple tup = py::reinterpret_borrow<py::tuple>(shape); | |||
| for (size_t i = 0; i < tup.size(); ++i) { | |||
| if (!PyLong_Check(tup[i].ptr())) { | |||
| return {shp, false}; | |||
| } else { | |||
| shp.push_back(tup[i].cast<int32_t>()); | |||
| } | |||
| } | |||
| return {shp, true}; | |||
| } | |||
| bool enable_fastpath(py::handle inp) { | |||
| if (!TensorWrapper::try_cast(inp.ptr()) || | |||
| TransformationManager::get_instance() | |||
| .segments[TransformationManager::Segment::Trace] | |||
| .size() > 0 || | |||
| TransformationManager::get_instance() | |||
| .segments[TransformationManager::Segment::ModuleTrace] | |||
| .size() > 0) { | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
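`enable_fastpath` gates the static-shape fast path: it is taken only for a plain `TensorWrapper` with no active trace or module-trace transformation; otherwise the target shape is materialized with `astensor1d` and passed as a second tensor input, which is why the grad and shape-inference rules in this change accept both one- and two-input forms. A rough sketch of a case that takes the fallback (assuming that tracing populates the `Trace` segment checked above):

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine.jit import trace

    @trace(symbolic=True)
    def f(x):
        # under trace the fast path is disabled, so Broadcast receives the
        # target shape as a tensor input rather than as its `shape` attribute
        return F.broadcast_to(x, (4, 3))

    out = f(mge.tensor(np.ones((1, 3), dtype="float32")))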
| py::object _broadcast_cpp(py::handle inp_hdl, py::handle args) { | |||
| py::object shape_hdl = _expand_args(args); | |||
| bool auto_infer = false; | |||
| py::list lis; | |||
| py::list new_shape; | |||
| if (PyList_Check(shape_hdl.ptr()) || PyTuple_Check(shape_hdl.ptr())) { | |||
| lis = py::reinterpret_steal<py::list>(PySequence_List(shape_hdl.ptr())); | |||
| for (size_t i = 0; i < lis.size(); ++i) { | |||
| if (lis[i].ptr() == Py_None) { | |||
| auto_infer = true; | |||
| size_t right = lis.size() - i; | |||
| py::object tshp = getattr(inp_hdl, "_tuple_shape"); | |||
| if (tshp.ptr() == Py_None) { | |||
| throw py::index_error("does not support `None` with unknown shape"); | |||
| } | |||
| py::tuple inp_shape = py::reinterpret_borrow<py::tuple>(tshp); | |||
| if (inp_shape.size() >= right) { | |||
| if (enable_fastpath(inp_hdl)) { | |||
| lis[i] = inp_shape[inp_shape.size() - right]; | |||
| } | |||
| new_shape.append(inp_shape[inp_shape.size() - right]); | |||
| } else { | |||
| throw py::value_error("invalid broadcast shape"); | |||
| } | |||
| } else { | |||
| new_shape.append(lis[i]); | |||
| if (PyLong_Check(lis[i].ptr())) { | |||
| int32_t s = lis[i].cast<int32_t>(); | |||
| if (s < 0) { | |||
| throw py::value_error( | |||
| "expect shape[" + std::to_string(i) + | |||
| "] >= 0 or use `None` to auto infer, got " + | |||
| std::to_string(s)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (auto_infer) { | |||
| if (enable_fastpath(inp_hdl)) { | |||
| shape_hdl = py::reinterpret_borrow<py::tuple>(lis); | |||
| } else { | |||
| py::tuple args = py::make_tuple(new_shape, inp_hdl); | |||
| py::dict kwargs; | |||
| kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32()); | |||
| kwargs["device"] = getattr(inp_hdl, "device"); | |||
| shape_hdl = py::reinterpret_steal<py::object>( | |||
| PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr())); | |||
| } | |||
| } | |||
| py::object shape_tuple; | |||
| try { | |||
| shape_tuple = _make_shape_tuple(shape_hdl); | |||
| } catch (py::error_already_set& err) { | |||
| shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl); | |||
| } | |||
| auto [shape, fastpath] = tuple2vector(shape_tuple); | |||
| fastpath &= enable_fastpath(inp_hdl); | |||
| std::shared_ptr<OpDef> op; | |||
| std::vector<PyObject*> p; | |||
| py::object shape_tensor; | |||
| if (fastpath) { | |||
| op = Broadcast::make(shape); | |||
| p.resize(2); | |||
| } else { | |||
| op = Broadcast::make(); | |||
| py::tuple args = py::make_tuple(shape_hdl, inp_hdl); | |||
| py::dict kwargs; | |||
| kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32()); | |||
| kwargs["device"] = getattr(inp_hdl, "device"); | |||
| shape_tensor = py::reinterpret_steal<py::object>( | |||
| PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr())); | |||
| p.resize(3); | |||
| p[2] = shape_tensor.ptr(); | |||
| } | |||
| py::object Op = py::cast(op); | |||
| p[0] = Op.ptr(); | |||
| p[1] = inp_hdl.ptr(); | |||
| py::tuple ret = | |||
| py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size())); | |||
| return ret[0]; | |||
| } | |||
| py::object _reshape_cpp(py::handle inp_hdl, py::handle args) { | |||
| py::object shape_hdl = _expand_args(args); | |||
| py::object shape_tuple; | |||
| try { | |||
| shape_tuple = _make_shape_tuple(shape_hdl); | |||
| } catch (py::error_already_set& err) { | |||
| shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl); | |||
| } | |||
| int32_t unspec_axis = -1; | |||
| if (PyTuple_Check(shape_tuple.ptr())) { | |||
| py::tuple tup = py::reinterpret_borrow<py::tuple>(shape_tuple); | |||
| for (size_t i = 0; i < tup.size(); ++i) { | |||
| py::object obj = py::reinterpret_borrow<py::object>(tup[i]); | |||
| if (obj < py::int_(0)) { | |||
| if (obj.not_equal(py::int_(-1))) { | |||
| throw py::value_error( | |||
| "expect shape [" + std::to_string(i) + "] >= -1, got " + | |||
| repr(obj).cast<std::string>()); | |||
| } | |||
| if (unspec_axis >= 0) { | |||
| throw py::value_error( | |||
| "multiple -1 in shape: " + std::to_string(unspec_axis) + | |||
| " & " + std::to_string(i)); | |||
| } | |||
| unspec_axis = i; | |||
| } | |||
| } | |||
| } | |||
| auto [shape, fastpath] = tuple2vector(shape_tuple); | |||
| fastpath &= enable_fastpath(inp_hdl); | |||
| std::shared_ptr<OpDef> op; | |||
| std::vector<PyObject*> p; | |||
| py::object shape_tensor; | |||
| if (fastpath) { | |||
| if (unspec_axis >= 0) { | |||
| op = Reshape::make(unspec_axis, shape); | |||
| } else { | |||
| op = Reshape::make(::megdnn::param::OptionalAxisV1::INVALID_AXIS, shape); | |||
| } | |||
| p.resize(2); | |||
| } else { | |||
| shape.clear(); | |||
| if (unspec_axis >= 0) { | |||
| op = Reshape::make(unspec_axis, shape); | |||
| } else { | |||
| op = Reshape::make(); | |||
| } | |||
| py::tuple args = py::make_tuple(shape_hdl, inp_hdl); | |||
| py::dict kwargs; | |||
| kwargs["dtype"] = py::cast((mgb::DType)dtype::Int32()); | |||
| kwargs["device"] = getattr(inp_hdl, "device"); | |||
| shape_tensor = py::reinterpret_steal<py::object>( | |||
| PyObject_Call(cpp_astensor1d, args.ptr(), kwargs.ptr())); | |||
| p.resize(3); | |||
| p[2] = shape_tensor.ptr(); | |||
| } | |||
| py::object Op = py::cast(op); | |||
| p[0] = Op.ptr(); | |||
| p[1] = inp_hdl.ptr(); | |||
| py::tuple ret = | |||
| py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size())); | |||
| return ret[0]; | |||
| } | |||
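`_reshape_cpp` keeps the old Python semantics for the `-1` placeholder: at most one axis may be `-1` (recorded as `unspec_axis`), and any other negative entry is rejected. Illustrative eager-mode behavior (values are an assumption, not taken from the diff):

    import numpy as np
    import megengine as mge

    x = mge.tensor(np.arange(12, dtype="float32"))
    y = x.reshape(3, -1)   # the -1 axis is inferred, giving shape (3, 4)
    # x.reshape(-1, -1)    # would raise ValueError: multiple -1 in shape
    # x.reshape(-2, 3)     # would raise ValueError: expect shape [0] >= -1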
| PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { | |||
| try { | |||
| return _make_shape_tuple(py::handle(args[0])).release().ptr(); | |||
| @@ -900,4 +1093,18 @@ PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||
| PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||
| } | |||
| PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||
| try { | |||
| return _broadcast_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr(); | |||
| } | |||
| PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||
| } | |||
| PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||
| try { | |||
| return _reshape_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr(); | |||
| } | |||
| PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||
| } | |||
| } // namespace mgb::imperative::python | |||
| @@ -16,4 +16,8 @@ PyObject* squeeze_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||
| PyObject* transpose_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||
| PyObject* broadcast_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||
| PyObject* reshape_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||
| } // namespace mgb::imperative::python | |||
| @@ -267,7 +267,7 @@ def test_broadcast_auto_infer(is_varnode): | |||
| F.broadcast_to(xx, (None, 1, 2, 3)) | |||
| F.broadcast_to(xx, (1, None, 2, 3)) | |||
| t = tensor(2, dtype=np.int32) | |||
| t = make_tensor(2, network) | |||
| F.broadcast_to(xx, (t, None, 2, 3)) | |||
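The test above exercises the `None` auto-infer rule now implemented in `_broadcast_cpp`: a `None` entry in the target shape is replaced by the input dimension aligned from the right. A small sketch of the eager-mode behavior (shapes are illustrative):

    import numpy as np
    import megengine as mge
    import megengine.functional as F

    x = mge.tensor(np.random.rand(1, 2, 3).astype("float32"))
    # position 1 of the 4-d target aligns with x.shape[0] == 1, so the
    # effective target shape becomes (2, 1, 2, 3)
    y = F.broadcast_to(x, (2, None, 2, 3))
    assert y.numpy().shape == (2, 1, 2, 3)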
| @@ -51,57 +51,75 @@ bool valid_broadcast(const TensorShape& src_shape, const TensorShape& tar_shape) | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||
| auto&& op = def.cast_final_safe<Broadcast>(); | |||
| size_t nr_inp = inputs.size(); | |||
| mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); | |||
| auto&& src = inputs[0]; | |||
| auto&& tshp = inputs[1]; | |||
| TensorShape out_shape; | |||
| if (tshp.layout.ndim == 0 || tshp.value.empty()) { | |||
| out_shape.ndim = 0; | |||
| return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false}; | |||
| } | |||
| mgb_assert( | |||
| tshp.layout.ndim == 1, | |||
| "target shape of Broadcast expects ndim=1; got ndim=%lu actually", | |||
| tshp.layout.ndim); | |||
| size_t target_ndim = tshp.layout.shape[0]; | |||
| out_shape.ndim = target_ndim; | |||
| auto* ptr = tshp.value.ptr<dt_int32>(); | |||
| for (size_t i = 0; i < target_ndim; ++i) { | |||
| out_shape[i] = ptr[i]; | |||
| if (nr_inp == 1) { | |||
| out_shape.ndim = op.shape.size(); | |||
| for (size_t i = 0; i < out_shape.ndim; ++i) { | |||
| out_shape[i] = op.shape[i]; | |||
| } | |||
| } else { | |||
| auto&& tshp = inputs[1]; | |||
| if (tshp.layout.ndim == 0 || tshp.value.empty()) { | |||
| out_shape.ndim = 0; | |||
| return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, | |||
| false}; | |||
| } | |||
| mgb_assert( | |||
| tshp.layout.ndim == 1, | |||
| "target shape of Broadcast expects ndim=1; got ndim=%lu actually", | |||
| tshp.layout.ndim); | |||
| size_t target_ndim = tshp.layout.shape[0]; | |||
| out_shape.ndim = target_ndim; | |||
| auto* ptr = tshp.value.ptr<dt_int32>(); | |||
| for (size_t i = 0; i < target_ndim; ++i) { | |||
| out_shape[i] = ptr[i]; | |||
| } | |||
| } | |||
| mgb_assert( | |||
| valid_broadcast(src.layout, out_shape), | |||
| "the input shape %s can not be broadcasted to target shape %s", | |||
| src.layout.to_string().c_str(), out_shape.to_string().c_str()); | |||
| return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true}; | |||
| } | |||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||
| def.cast_final_safe<Broadcast>(); | |||
| auto&& op = def.cast_final_safe<Broadcast>(); | |||
| size_t nr_inp = inputs.size(); | |||
| mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); | |||
| TensorShape tshp; | |||
| auto&& src = inputs[0]; | |||
| auto&& tshp_nd = inputs[1]; | |||
| auto slayout = src->layout(); | |||
| TensorShape tshp; | |||
| cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu()); | |||
| if (nr_inp == 1) { | |||
| tshp.ndim = op.shape.size(); | |||
| for (size_t i = 0; i < tshp.ndim; ++i) { | |||
| tshp[i] = op.shape[i]; | |||
| } | |||
| } else { | |||
| auto&& tshp_nd = inputs[1]; | |||
| cg::copy_tensor_value_to_shape( | |||
| tshp, tshp_nd->get_value().proxy_to_default_cpu()); | |||
| } | |||
| TensorLayout tlayout = slayout.broadcast(tshp); | |||
| // memory forward | |||
| return {Tensor::make(src->blob(), src->offset(), tlayout)}; | |||
| } | |||
| SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
| SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size()); | |||
| return layout_checker; | |||
| } | |||
| OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) | |||
| .make_from_op_node(make_from_op_node) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||
| .get_input_layout_constraint(get_input_layout_constraint) | |||
| .fallback(); | |||
| } // namespace broadcast | |||
| @@ -118,35 +136,49 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||
| auto&& op = def.cast_final_safe<Reshape>(); | |||
| size_t nr_inp = inputs.size(); | |||
| mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp); | |||
| auto&& src = inputs[0]; | |||
| auto&& tshp = inputs[1]; | |||
| TensorShape out_shape; | |||
| if (tshp.layout.ndim == 0 || tshp.value.empty()) { | |||
| out_shape.ndim = 0; | |||
| return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false}; | |||
| } | |||
| mgb_assert( | |||
| tshp.layout.ndim == 1, | |||
| "target shape of Reshape expects ndim=1; got ndim=%lu actually", | |||
| tshp.layout.ndim); | |||
| if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) { | |||
| return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false}; | |||
| } | |||
| size_t target_ndim = tshp.layout.shape[0]; | |||
| out_shape.ndim = target_ndim; | |||
| auto* ptr = tshp.value.ptr<dt_int32>(); | |||
| for (size_t i = 0; i < target_ndim; ++i) { | |||
| out_shape[i] = ptr[i]; | |||
| } | |||
| if (src.layout.ndim == 0) { | |||
| return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, false}; | |||
| if (nr_inp == 1) { | |||
| if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) { | |||
| return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, | |||
| false}; | |||
| } | |||
| out_shape.ndim = op.shape.size(); | |||
| for (size_t i = 0; i < out_shape.ndim; ++i) { | |||
| out_shape[i] = op.shape[i]; | |||
| } | |||
| if (src.layout.ndim == 0) { | |||
| return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, | |||
| false}; | |||
| } | |||
| } else { | |||
| auto&& tshp = inputs[1]; | |||
| if (tshp.layout.ndim == 0 || tshp.value.empty()) { | |||
| out_shape.ndim = 0; | |||
| return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, | |||
| false}; | |||
| } | |||
| mgb_assert( | |||
| tshp.layout.ndim == 1, | |||
| "target shape of Reshape expects ndim=1; got ndim=%lu actually", | |||
| tshp.layout.ndim); | |||
| if (src.layout.ndim == 0 && op.axis != opr::Reshape::Param::INVALID_AXIS) { | |||
| return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, | |||
| false}; | |||
| } | |||
| size_t target_ndim = tshp.layout.shape[0]; | |||
| out_shape.ndim = target_ndim; | |||
| auto* ptr = tshp.value.ptr<dt_int32>(); | |||
| for (size_t i = 0; i < target_ndim; ++i) { | |||
| out_shape[i] = ptr[i]; | |||
| } | |||
| if (src.layout.ndim == 0) { | |||
| return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, | |||
| false}; | |||
| } | |||
| } | |||
| if (op.axis != opr::Reshape::Param::INVALID_AXIS) { | |||
| mgb_assert(out_shape[op.axis] == -1); | |||
| out_shape[op.axis] = 1; | |||
| @@ -167,19 +199,27 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||
| auto&& op_def = def.cast_final_safe<Reshape>(); | |||
| auto&& op = def.cast_final_safe<Reshape>(); | |||
| size_t nr_inp = inputs.size(); | |||
| mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp); | |||
| auto&& src = inputs[0]; | |||
| auto&& tshp_nd = inputs[1]; | |||
| auto slayout = src->layout(); | |||
| TensorShape tshp; | |||
| cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu()); | |||
| if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) { | |||
| mgb_assert(tshp[op_def.axis] == -1); | |||
| tshp[op_def.axis] = 1; | |||
| tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems(); | |||
| if (nr_inp == 1) { | |||
| tshp.ndim = op.shape.size(); | |||
| for (size_t i = 0; i < tshp.ndim; ++i) { | |||
| tshp[i] = op.shape[i]; | |||
| } | |||
| } else { | |||
| auto&& tshp_nd = inputs[1]; | |||
| cg::copy_tensor_value_to_shape( | |||
| tshp, tshp_nd->get_value().proxy_to_default_cpu()); | |||
| } | |||
| if (op.axis != opr::Reshape::Param::INVALID_AXIS) { | |||
| mgb_assert(tshp[op.axis] == -1); | |||
| tshp[op.axis] = 1; | |||
| tshp[op.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems(); | |||
| } | |||
| TensorLayout tlayout; | |||
| mgb_assert(slayout.try_reshape(tlayout, tshp)); | |||
| @@ -188,17 +228,24 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
| auto&& op_def = def.cast_final_safe<Reshape>(); | |||
| auto&& op = def.cast_final_safe<Reshape>(); | |||
| SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size()); | |||
| layout_checker[0] = [&](const TensorLayout& layout) { | |||
| TensorShape tshp; | |||
| TensorLayout ret; | |||
| cg::copy_tensor_value_to_shape( | |||
| tshp, inputs[1]->get_value().proxy_to_default_cpu()); | |||
| if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) { | |||
| mgb_assert(tshp[op_def.axis] == -1); | |||
| tshp[op_def.axis] = 1; | |||
| tshp[op_def.axis] = layout.total_nr_elems() / tshp.total_nr_elems(); | |||
| if (inputs.size() == 1) { | |||
| tshp.ndim = op.shape.size(); | |||
| for (size_t i = 0; i < tshp.ndim; ++i) { | |||
| tshp[i] = op.shape[i]; | |||
| } | |||
| } else { | |||
| cg::copy_tensor_value_to_shape( | |||
| tshp, inputs[1]->get_value().proxy_to_default_cpu()); | |||
| } | |||
| if (op.axis != opr::Reshape::Param::INVALID_AXIS) { | |||
| mgb_assert(tshp[op.axis] == -1); | |||
| tshp[op.axis] = 1; | |||
| tshp[op.axis] = layout.total_nr_elems() / tshp.total_nr_elems(); | |||
| } | |||
| if (layout.try_reshape(ret, tshp)) { | |||
| return true; | |||
| @@ -243,8 +243,10 @@ ValueRefList get_var_shape_rule( | |||
| ValueRefList reshape_rule( | |||
| const Reshape& reshape, Span<ValueRef> inputs, Span<bool> inputs_mask, | |||
| const Type<ScalarValue>& scalar_type) { | |||
| mgb_assert(inputs.size() == 2); | |||
| bool is_scalar = is_scalar_shape(inputs[1]); | |||
| mgb_assert(inputs.size() == 1 || inputs.size() == 2); | |||
| size_t nr_inp = inputs.size(); | |||
| bool is_scalar = (nr_inp == 2 && is_scalar_shape(inputs[1])) || | |||
| (nr_inp == 1 && reshape.shape.size() == 0); | |||
| if (is_scalar) { | |||
| return {scalar_type.make(imperative::apply( | |||
| reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])}; | |||
| @@ -256,8 +258,10 @@ ValueRefList reshape_rule( | |||
| ValueRefList broadcast_rule( | |||
| const Broadcast& broadcast, Span<ValueRef> inputs, Span<bool> inputs_mask, | |||
| const Type<ScalarValue>& scalar_type) { | |||
| mgb_assert(inputs.size() == 2); | |||
| bool is_scalar = is_scalar_shape(inputs[1]); | |||
| mgb_assert(inputs.size() == 1 || inputs.size() == 2); | |||
| size_t nr_inp = inputs.size(); | |||
| bool is_scalar = (nr_inp == 2 && is_scalar_shape(inputs[1])) || | |||
| (nr_inp == 1 && broadcast.shape.size() == 0); | |||
| if (is_scalar) { | |||
| return {scalar_type.make(imperative::apply( | |||
| broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])}; | |||
| @@ -250,7 +250,11 @@ def Concat: MgbHashableOp<"Concat", [AxisParam]> { | |||
| ); | |||
| } | |||
| def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]>; | |||
| def Broadcast : MgbHashableOp<"Broadcast", [EmptyParam]> { | |||
| let extraArguments = (ins | |||
| MgbArrayAttr<MgbI32Attr>:$shape | |||
| ); | |||
| } | |||
| def Identity: MgbHashableOp<"Identity">; | |||
| @@ -318,7 +322,11 @@ def Dimshuffle: MgbHashableOp<"Dimshuffle"> { | |||
| let results = (outs AnyMemRef); | |||
| } | |||
| def Reshape: MgbHashableOp<"Reshape", [OptionalAxisV1Param]>; | |||
| def Reshape: MgbHashableOp<"Reshape", [OptionalAxisV1Param]> { | |||
| let extraArguments = (ins | |||
| MgbArrayAttr<MgbI32Attr>:$shape | |||
| ); | |||
| } | |||
| // TODO: merge Add/Remove Axis into AxisAddRemove as megbrain? | |||
| def AddAxis: MgbHashableOp<"AddAxis"> { | |||