@@ -36,7 +36,7 @@ public:
     virtual void exec(
             _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
             _megdnn_workspace workspace) = 0;
-    void deduce_dtype(DType A, DType B, DType& C);
+    MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C);
     void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
     virtual size_t get_workspace_in_bytes(
             const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
@@ -73,7 +73,7 @@ public:
     virtual void exec(
             _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
             _megdnn_workspace workspace) = 0;
-    void deduce_dtype(DType A, DType B, DType& C);
+    MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C);
     void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C);
     virtual size_t get_workspace_in_bytes(
             const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0;
@@ -44,216 +44,6 @@ def _elwise(*args, mode):
     return _elwise_apply(args, mode)
 
-@lru_cache(maxsize=None)
-def _get_extentedMatrixMulOp(
-    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
-):
-    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
-    def extentedMatrixMulOp(inputs, f, c):
-        assert len(inputs) == 2
-        inp1, inp2 = inputs
-        _dim1, _dim2 = dim1, dim2
-
-        def build_shape_head(shape, idx=-1):
-            # shape[:idx]
-            return f(
-                builtin.Subtensor(items=[[0, False, True, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        def build_shape_tail(shape, idx=-1):
-            # shape[idx:]
-            return f(
-                builtin.Subtensor(items=[[0, True, False, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        remove_row, remove_col = False, False
-        if _dim1 == 1:
-            _dim1 = 2
-            remove_row = True
-        if _dim2 == 1:
-            _dim2 = 2
-            remove_col = True
-        if remove_row:
-            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
-        if remove_col:
-            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
-        shape1 = f(builtin.GetVarShape(), inp1)
-        shape2 = f(builtin.GetVarShape(), inp2)
-        if _dim1 > 2:
-            inp1 = f(
-                builtin.Reshape(),
-                inp1,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape1)),
-                    build_shape_tail(shape1),
-                ),
-            )
-        if _dim2 > 2:
-            inp2 = f(
-                builtin.Reshape(),
-                inp2,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), build_shape_head(shape2)),
-                    build_shape_tail(shape2),
-                ),
-            )
-        op = builtin.MatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-            strategy=strategy.value,
-        )
-        result = f(op, inp1, inp2)
-        result_shape = f(builtin.GetVarShape(), result)
-        if _dim1 > 2:
-            result = f(
-                builtin.Reshape(),
-                result,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    build_shape_head(shape1),
-                    build_shape_tail(result_shape),
-                ),
-            )
-        if _dim2 > 2:
-            result = f(
-                builtin.Reshape(),
-                result,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    build_shape_head(shape2),
-                    build_shape_tail(result_shape),
-                ),
-            )
-        maxdim = _dim1 if _dim1 > _dim2 else _dim2
-        if remove_row:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
-        if remove_col:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
-        return (result,), (True,)
-
-    return extentedMatrixMulOp
-
-@lru_cache(maxsize=None)
-def _get_extentedBatchedMatrixMulOp(
-    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
-):
-    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
-    def extentedBatchedMatrixMulOp(inputs, f, c):
-        assert len(inputs) == 2
-        inp1, inp2 = inputs
-        _dim1, _dim2 = dim1, dim2
-
-        def build_shape_head(shape, idx=-2):
-            # shape[:idx]
-            return f(
-                builtin.Subtensor(items=[[0, False, True, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        def build_shape_tail(shape, idx=-2):
-            # shape[idx:]
-            return f(
-                builtin.Subtensor(items=[[0, True, False, False, False]]),
-                shape,
-                c(idx, "int32"),
-            )
-
-        remove_row, remove_col = False, False
-        if _dim1 == 1:
-            _dim1 = 2
-            remove_row = True
-        if _dim2 == 1:
-            _dim2 = 2
-            remove_col = True
-        if remove_row:
-            inp1 = f(builtin.AddAxis(axis=[0,]), inp1)
-        if remove_col:
-            inp2 = f(builtin.AddAxis(axis=[1,]), inp2)
-        shape1 = f(builtin.GetVarShape(), inp1)
-        shape2 = f(builtin.GetVarShape(), inp2)
-        maxdim = _dim1 if _dim1 > _dim2 else _dim2
-        if _dim1 > _dim2:
-            # broadcast
-            shape2 = f(
-                builtin.Concat(axis=0, comp_node=device),
-                build_shape_head(shape1, idx=-_dim2),  # shape1[:-_dim2]
-                shape2,
-            )
-            inp2 = f(builtin.Broadcast(), inp2, shape2)
-            batch_shape = build_shape_head(shape1)
-        if _dim2 > _dim1:
-            # broadcast
-            shape1 = f(
-                builtin.Concat(axis=0, comp_node=device),
-                build_shape_head(shape2, idx=-_dim1),  # shape2[:-_dim1]
-                shape1,
-            )
-            inp1 = f(builtin.Broadcast(), inp1, shape1)
-            batch_shape = build_shape_head(shape2)
-        if _dim1 == _dim2:
-            batch_shape = build_shape_head(shape1)
-
-        # compress inputs to 3d
-        if maxdim > 3:
-            inp1 = f(
-                builtin.Reshape(),
-                inp1,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
-                    build_shape_tail(shape1),
-                ),
-            )
-            inp2 = f(
-                builtin.Reshape(),
-                inp2,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    f(builtin.Reduce(mode="product", axis=0), batch_shape),
-                    build_shape_tail(shape2),
-                ),
-            )
-        op = builtin.BatchedMatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-            strategy=strategy.value,
-        )
-        result = f(op, inp1, inp2)
-        if maxdim > 3:
-            result = f(
-                builtin.Reshape(),
-                result,
-                f(
-                    builtin.Concat(axis=0, comp_node=device),
-                    batch_shape,
-                    build_shape_tail(f(builtin.GetVarShape(), result)),
-                ),
-            )
-        if remove_row:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
-        if remove_col:
-            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
-        return (result,), (True,)
-
-    return extentedBatchedMatrixMulOp
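The two `_get_extented*MatrixMulOp` helpers removed above built a traceable Python subgraph that normalized arbitrary-rank operands onto the 2-D `MatrixMul` / 3-D `BatchedMatrixMul` primitives; the equivalent construction now happens in C++ in `apply_on_var_node` (later in this diff). A minimal NumPy sketch of the non-batched trick, simplified to the case where only the lhs carries extra leading dimensions (the helper name and shapes here are illustrative, not from the codebase):

```python
import numpy as np

def matmul_flatten(a, b):
    """(d0, ..., dn, K) @ (K, M) -> (d0, ..., dn, M) via one 2-D matmul."""
    remove_row, remove_col = a.ndim == 1, b.ndim == 1
    if remove_row:
        a = a[None, :]                    # promote 1-D lhs to (1, K)
    if remove_col:
        b = b[:, None]                    # promote 1-D rhs to (K, 1)
    head = a.shape[:-1]                   # shape1[:-1], restored afterwards
    r = a.reshape(-1, a.shape[-1]) @ b    # collapse leading dims into rows
    r = r.reshape(*head, r.shape[-1])
    if remove_row:
        r = r.squeeze(-2)                 # undo the axis added for a 1-D lhs
    if remove_col:
        r = r.squeeze(-1)                 # undo the axis added for a 1-D rhs
    return r

assert matmul_flatten(np.ones((4, 5, 3)), np.ones((3, 2))).shape == (4, 5, 2)
```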
 class _Hashable:
     def __init__(self, value) -> None:
         self.value = value
@@ -267,42 +57,6 @@ class _Hashable:
         return self.value == o.value
 
-def symbolicMatrixMul(
-    inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy
-):
-    extentedMatrixMulOp = _get_extentedMatrixMulOp(
-        inp1.device,
-        inp1.dtype,
-        dim1,
-        dim2,
-        transpose_a,
-        transpose_b,
-        compute_mode,
-        format,
-        strategy=_Hashable(strategy),
-    )
-    (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
-    return result
-
-def symbolicBatchedMatrixMul(
-    inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy
-):
-    extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
-        inp1.device,
-        inp1.dtype,
-        dim1,
-        dim2,
-        transpose_a,
-        transpose_b,
-        compute_mode,
-        format,
-        strategy=_Hashable(strategy),
-    )
-    (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
-    return result
 
 def _matmul(
     inp1,
     inp2,
@@ -342,11 +96,8 @@ def _matmul(
             transpose_a,
             transpose_b,
             compute_mode,
-            format,
             _config._benchmark_kernel,
             _config._deterministic_kernel,
-            strategy,
-            symbolicMatrixMul,
         )
     else:  # dispath to BatchedMatrixMul
         # nx1(transpose_a=True), n>=3
@@ -362,11 +113,8 @@ def _matmul(
             transpose_a,
             transpose_b,
             compute_mode,
-            format,
            _config._benchmark_kernel,
            _config._deterministic_kernel,
-            strategy,
-            symbolicBatchedMatrixMul,
        )
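`_matmul` now passes only the benchmark/deterministic flags down to the C++ entry points; the `format`, `strategy`, and symbolic-fallback arguments are gone because the fastpath no longer needs a Python fallback. A hypothetical condensation of the dispatch rule implied by the comments above (a sketch inferred from the diff, not the literal source):

```python
def _dispatch(dim1, dim2, transpose_a):
    # Hypothetical condensation of the dispatch in _matmul.
    if dim1 == 1 and dim2 == 1:
        return "MatrixMul"        # vector dot product
    if dim2 <= 2 and not (dim1 >= 3 and transpose_a):
        return "MatrixMul"        # (..., K) @ (K, M): one flattened 2-D matmul
    return "BatchedMatrixMul"     # batch dims must be broadcast and kept
```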
@@ -32,7 +32,7 @@ from ..core.ops.builtin import (
     TypeCvt,
 )
 from ..core.tensor import amp, megbrain_graph
-from ..core.tensor.array_method import _elwise_apply
+from ..core.tensor.array_method import _matmul
 from ..core.tensor.utils import (
     astensor1d,
     cast_tensors,
@@ -49,7 +49,7 @@ from ..utils.deprecation import deprecated_func
 from .debug_param import get_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import _elwise, exp, log, log1p, maximum, minimum
-from .math import matmul, max, sum
+from .math import max, sum
 from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros
 
 __all__ = [
@@ -127,7 +127,7 @@ def linear(
         bias: bias with shape `(out_features,)`. Default: None
     """
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
-    ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
+    ret = _matmul(inp, weight, transpose_b=True, compute_mode=compute_mode)
     if bias is not None:
         if amp._enabled:
             bias = bias.astype("float16")
@@ -1494,73 +1494,61 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
 py::object _matmul_cpp(
         py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
         py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
-        py::handle format, py::handle profile, py::handle determistic,
-        py::handle strategy, py::handle func) {
-    if (enable_fastpath(inp1)) {
-        ::megdnn::param::MatrixMul::ComputeMode mode =
-                ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
-        if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
-            mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
-        }
-        ::megdnn::param::ExecutionPolicy::Strategy cstrategy;
-        if (profile.cast<bool>()) {
-            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
-        } else {
-            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
-        }
-        if (determistic.cast<bool>()) {
-            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
-        }
-        std::shared_ptr<OpDef> op = MatrixMul::make(
-                transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
-                ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX);
-        py::object Op = py::cast(op);
-        PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
-        py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
-        return ret[0];
+        py::handle profile, py::handle determistic) {
+    ::megdnn::param::MatrixMul::ComputeMode mode =
+            ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
+    if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
+        mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
+    }
+    ::megdnn::param::ExecutionPolicy::Strategy cstrategy =
+            static_cast<::megdnn::param::ExecutionPolicy::Strategy>(0);
+    if (profile.cast<bool>()) {
+        cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
     } else {
-        // fallback to traceable implementation
-        return func(
-                inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format,
-                strategy);
+        cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
     }
+    if (determistic.cast<bool>()) {
+        cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
+    }
+    std::shared_ptr<OpDef> op = MatrixMul::make(
+            transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
+            ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX,
+            dim1.cast<uint32_t>(), dim2.cast<uint32_t>());
+    py::object Op = py::cast(op);
+    PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
+    py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
+    return ret[0];
 }
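Two details worth noting in the rewritten `_matmul_cpp`: `cstrategy` is now explicitly zero-initialized via `static_cast` (the old local was declared without an initializer before the `|=` updates), and `dim1`/`dim2` are forwarded into `MatrixMul::make` so the op records its input ranks. The flag composition itself is unchanged; a small Python model of it, with assumed bit values mirroring `megdnn::param::ExecutionPolicy::Strategy` (the real encoding lives in the generated param definitions):

```python
from enum import IntFlag

class Strategy(IntFlag):
    # Assumed bit values; illustrative only.
    HEURISTIC = 1
    PROFILE = 2
    REPRODUCIBLE = 4
    OPTIMIZED = 8

def build_strategy(profile: bool, deterministic: bool) -> Strategy:
    # Same composition as cstrategy above: PROFILE and HEURISTIC are mutually
    # exclusive; REPRODUCIBLE is an independent extra bit.
    s = Strategy.PROFILE if profile else Strategy.HEURISTIC
    if deterministic:
        s |= Strategy.REPRODUCIBLE
    return s

assert build_strategy(True, True) == Strategy.PROFILE | Strategy.REPRODUCIBLE
```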
 py::object _batched_matmul_cpp(
         py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
         py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
-        py::handle format, py::handle profile, py::handle determistic,
-        py::handle strategy, py::handle func) {
-    if (enable_fastpath(inp1)) {
-        ::megdnn::param::MatrixMul::ComputeMode mode =
-                ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
-        if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
-            mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
-        }
-        ::megdnn::param::ExecutionPolicy::Strategy cstrategy;
-        if (profile.cast<bool>()) {
-            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
-        } else {
-            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
-        }
-        if (determistic.cast<bool>()) {
-            cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
-        }
-        std::shared_ptr<OpDef> op = BatchedMatrixMul::make(
-                transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
-                ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX);
-        py::object Op = py::cast(op);
-        PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
-        py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
-        return ret[0];
+        py::handle profile, py::handle determistic) {
+    ::megdnn::param::MatrixMul::ComputeMode mode =
+            ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
+    if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
+        mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32;
+    }
+    ::megdnn::param::ExecutionPolicy::Strategy cstrategy =
+            static_cast<::megdnn::param::ExecutionPolicy::Strategy>(0);
+    if (profile.cast<bool>()) {
+        cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE;
     } else {
-        // fallback to traceable implementation
-        return func(
-                inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format,
-                strategy);
+        cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
     }
+    if (determistic.cast<bool>()) {
+        cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
+    }
+    std::shared_ptr<OpDef> op = BatchedMatrixMul::make(
+            transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode,
+            ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX,
+            dim1.cast<uint32_t>(), dim2.cast<uint32_t>());
+    py::object Op = py::cast(op);
+    PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()};
+    py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3));
+    return ret[0];
 }
 
 py::object _pixel_shuffle_cpp(py::handle inp, py::handle val, py::handle func) {
@@ -1671,7 +1659,7 @@ PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
     try {
         return _matmul_cpp(
                        args[0], args[1], args[2], args[3], args[4], args[5], args[6],
-                       args[7], args[8], args[9], args[10], args[11])
+                       args[7], args[8])
                 .release()
                 .ptr();
     }
@@ -1682,7 +1670,7 @@ PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs
     try {
         return _batched_matmul_cpp(
                        args[0], args[1], args[2], args[3], args[4], args[5], args[6],
-                       args[7], args[8], args[9], args[10], args[11])
+                       args[7], args[8])
                 .release()
                 .ptr();
     }
@@ -20,7 +20,6 @@ import megengine.optimizer as optim
 from megengine import tensor
 from megengine.autodiff import GradManager
 from megengine.jit import trace
-from megengine.traced_module import trace_module
 
 @contextlib.contextmanager
@@ -2,8 +2,12 @@
 #include "../blob_manager_impl.h"
 #include "../dnn_op_helper.h"
 #include "../op_trait.h"
+#include "megbrain/graph/symbol_var.h"
 #include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/opr/basic_arith.h"
 #include "megbrain/opr/blas.h"
+#include "megbrain/opr/io.h"
+#include "megbrain/opr/tensor_manip.h"
 
 #include "../algo_chooser.h"
@@ -12,12 +16,93 @@ namespace imperative {
 namespace {
 namespace matrix_mul {
 auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& matmul = def.cast_final_safe<MatrixMul>();
     mgb_assert(inputs.size() == 2);
-    OperatorNodeConfig config{matmul.make_name()};
-    return opr::MatrixMul::make(
-            inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
+    auto inp1 = SymbolVar{inputs[0]}, inp2 = SymbolVar{inputs[1]};
+    auto dim1 = matmul.dimA, dim2 = matmul.dimB;
+
+    auto cn = inputs[0]->comp_node();
+    using Desc = opr::AxisAddRemove::AxisDesc;
+    using IndexDesc = opr::Subtensor::IndexDesc;
+    OperatorNodeConfig config{matmul.make_name(), cn};
+    DTypeScalar vi{-1};
+    auto graph = inputs[0]->owner_graph();
+
+    bool remove_row = false, remove_col = false;
+    if (dim1 == 1) {
+        dim1 = 2;
+        remove_row = true;
+        inp1 = inp1.add_axis(0);
+    }
+    if (dim2 == 1) {
+        dim2 = 2;
+        remove_col = true;
+        inp2 = inp2.add_axis(1);
+    }
+    SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail;
+    if (dim1 > 2) {
+        auto idx = opr::ImmutableTensor::make(*graph, vi, config);
+        auto shp1 = inp1.symshape();
+        IndexDesc head_desc(1);
+        head_desc[0].end = idx;
+        shp1_head = opr::Subtensor::make(shp1, head_desc);
+        auto batch = opr::Reduce::make(shp1_head, {Reduce::Mode::PRODUCT, 0});
+        IndexDesc tail_desc(1);
+        tail_desc[0].begin = idx;
+        shp1_tail = opr::Subtensor::make(shp1, tail_desc);
+        auto tshp = opr::Concat::make({batch, shp1_tail}, 0, cn);
+        inp1 = inp1.reshape(tshp);
+    }
+    if (dim2 > 2) {
+        auto idx = opr::ImmutableTensor::make(*graph, vi, config);
+        auto shp2 = inp2.symshape();
+        IndexDesc head_desc(1);
+        head_desc[0].end = idx;
+        shp2_head = opr::Subtensor::make(shp2, head_desc);
+        auto batch = opr::Reduce::make(shp2_head, {Reduce::Mode::PRODUCT, 0});
+        IndexDesc tail_desc(1);
+        tail_desc[0].begin = idx;
+        auto shp2_tail = opr::Subtensor::make(shp2, tail_desc);
+        auto tshp = opr::Concat::make({batch, shp2_tail}, 0, cn);
+        inp2 = inp2.reshape(tshp);
+    }
+    auto result =
+            opr::MatrixMul::make(inp1, inp2, matmul.param(), matmul.policy(), config);
+    if (dim1 > 2) {
+        auto idx = opr::ImmutableTensor::make(*graph, vi, config);
+        auto result_shape = result.symshape();
+        IndexDesc tail_desc(1);
+        tail_desc[0].begin = idx;
+        auto shp_tail = opr::Subtensor::make(result_shape, tail_desc);
+        auto tshp = opr::Concat::make({shp1_head, shp_tail}, 0, cn);
+        result = result.reshape(tshp);
+    }
+    if (dim2 > 2) {
+        auto idx = opr::ImmutableTensor::make(*graph, vi, config);
+        auto result_shape = result.symshape();
+        IndexDesc tail_desc(1);
+        tail_desc[0].begin = idx;
+        auto shp_tail = opr::Subtensor::make(result_shape, tail_desc);
+        auto tshp = opr::Concat::make({shp2_head, shp_tail}, 0, cn);
+        result = result.reshape(tshp);
+    }
+    auto maxdim = dim1 > dim2 ? dim1 : dim2;
+    if (remove_row) {
+        std::vector<Desc> remove_param;
+        remove_param.push_back(Desc::make_remove(maxdim - 2));
+        result = opr::AxisAddRemove::make(result, remove_param);
+    }
+    if (remove_col) {
+        std::vector<Desc> remove_param;
+        remove_param.push_back(Desc::make_remove(maxdim - 1));
+        result = opr::AxisAddRemove::make(result, remove_param);
+    }
+    return result;
 }
 
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
@@ -27,8 +112,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     auto layout2 = inputs[1].layout;
     size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
 
+    DType dst_dtype;
+    DnnOprCaller<megdnn::MatrixMul> dnn_opr(inputs[0].comp_node);
+    dnn_opr.op->param() = matmul.param();
+    dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
+
     if (dim1 == 0 || dim2 == 0) {
-        return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false};
+        return {{{TensorLayout(dst_dtype), inputs[0].comp_node}}, false};
     }
 
     if (matmul.transposeA)
@@ -37,7 +128,8 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         std::swap(layout2[0], layout2[1]);
     mgb_assert(layout1[dim1 - 1] == layout2[0]);
 
-    TensorLayout dst_layout(layout1.dtype);
+    TensorLayout dst_layout(dst_dtype);
     size_t ci = 0;
     for (size_t i = 0; i < dim1 - 1; i++)
         dst_layout[ci++] = layout1[i];
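`infer_output_attrs_fallible` previously assumed the output dtype equals `layout1.dtype`; it now asks the DNN op via `deduce_dtype`, which matters whenever the result dtype differs from the inputs'. A toy model of the contract, with an assumed dtype table (qint8 accumulating into qint32 is one plausible entry; the authoritative rules live in MegDNN):

```python
# Toy model of deduce_dtype; table entries are illustrative assumptions.
_DEDUCE = {
    ("float32", "float32"): "float32",
    ("float16", "float16"): "float16",
    ("qint8", "qint8"): "qint32",  # quantized matmul widens on accumulation
}

def deduce_dtype(a: str, b: str) -> str:
    return _DEDUCE[(a, b)]

assert deduce_dtype("qint8", "qint8") == "qint32"
```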
@@ -61,6 +153,12 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     SmallVector<TensorND> inp_tensornds(inputs.size());
     TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout();
 
+    DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn);
+    dnn_opr.op->param() = matmul.param();
+    DType dst_dtype;
+    dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
+
     // only matters when layout1 has dim 2
     if (matmul.transposeA)
         std::swap(layout1.shape[0], layout1.shape[1]);
@@ -69,7 +167,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
         std::swap(layout2.shape[0], layout2.shape[1]);
 
     size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
-    TensorLayout real_dst_layout(layout1.dtype);
+    TensorLayout real_dst_layout(dst_dtype);
     if (validated) {
         real_dst_layout = output_descs[0].layout;
     } else {
@@ -126,12 +224,9 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
         inp_tensornds[1] = inputs[1]->dnn_tensor();
     }
 
-    TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, layout_a.dtype);
+    TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, dst_dtype);
     dst_layout.init_contiguous_stride();
 
-    DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn);
-    dnn_opr.op->param() = matmul.param();
-
     DeviceTensorND out =
             BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout);
     size_t sz = setup_algo<megdnn::MatrixMul>(
@@ -167,9 +262,99 @@ namespace batched_matrix_mul {
 auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& matmul = def.cast_final_safe<BatchedMatrixMul>();
     mgb_assert(inputs.size() == 2);
-    OperatorNodeConfig config{matmul.make_name()};
-    return opr::BatchedMatrixMul::make(
-            inputs[0], inputs[1], matmul.param(), matmul.policy(), config);
+    auto inp1 = SymbolVar{inputs[0]}, inp2 = SymbolVar{inputs[1]};
+    auto dim1 = matmul.dimA, dim2 = matmul.dimB;
+
+    auto cn = inputs[0]->comp_node();
+    using Desc = opr::AxisAddRemove::AxisDesc;
+    using IndexDesc = opr::Subtensor::IndexDesc;
+    OperatorNodeConfig config{matmul.make_name(), cn};
+    DTypeScalar vi{-2};
+    auto graph = inputs[0]->owner_graph();
+    auto idx = opr::ImmutableTensor::make(*graph, vi, config);
+
+    bool remove_row = false, remove_col = false;
+    if (dim1 == 1) {
+        dim1 = 2;
+        remove_row = true;
+        inp1 = inp1.add_axis(0);
+    }
+    if (dim2 == 1) {
+        dim2 = 2;
+        remove_col = true;
+        inp2 = inp2.add_axis(1);
+    }
+    auto shp1 = inp1.symshape();
+    auto shp2 = inp2.symshape();
+    SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail;
+    SymbolVar batch_shape;
+    if (dim1 > dim2) {
+        HostTensorND hv = HostTensorND(cn, {1}, dtype::Int32());
+        auto* ptr = hv.ptr<dt_int32>();
+        ptr[0] = -dim2;
+        IndexDesc head_desc(1);
+        head_desc[0].end = opr::ImmutableTensor::make(*graph, hv, config);
+        shp1_head = opr::Subtensor::make(shp1, head_desc);
+        shp2 = opr::Concat::make({shp1_head, shp2}, 0, cn);
+        inp2 = inp2.broadcast(shp2);
+        head_desc[0].end = idx;
+        batch_shape = opr::Subtensor::make(shp1, head_desc);
+    }
+    if (dim2 > dim1) {
+        HostTensorND hv = HostTensorND(cn, {1}, dtype::Int32());
+        auto* ptr = hv.ptr<dt_int32>();
+        ptr[0] = -dim1;
+        IndexDesc head_desc(1);
+        head_desc[0].end = opr::ImmutableTensor::make(*graph, hv, config);
+        shp2_head = opr::Subtensor::make(shp2, head_desc);
+        shp1 = opr::Concat::make({shp2_head, shp1}, 0, cn);
+        inp1 = inp1.broadcast(shp1);
+        head_desc[0].end = idx;
+        batch_shape = opr::Subtensor::make(shp2, head_desc);
+    }
+    if (dim1 == dim2) {
+        IndexDesc head_desc(1);
+        head_desc[0].end = idx;
+        batch_shape = opr::Subtensor::make(shp1, head_desc);
+    }
+    auto maxdim = dim1 > dim2 ? dim1 : dim2;
+
+    if (maxdim > 3) {
+        IndexDesc tail_desc(1);
+        tail_desc[0].begin = idx;
+        shp1_tail = opr::Subtensor::make(shp1, tail_desc);
+        auto batch = opr::Reduce::make(batch_shape, {Reduce::Mode::PRODUCT, 0});
+        shp1 = opr::Concat::make({batch, shp1_tail}, 0, cn);
+        inp1 = inp1.reshape(shp1);
+        shp2_tail = opr::Subtensor::make(shp2, tail_desc);
+        shp2 = opr::Concat::make({batch, shp2_tail}, 0, cn);
+        inp2 = inp2.reshape(shp2);
+    }
+    auto result = opr::BatchedMatrixMul::make(
+            inp1, inp2, matmul.param(), matmul.policy(), config);
+    if (maxdim > 3) {
+        auto result_shp = result.symshape();
+        IndexDesc tail_desc(1);
+        tail_desc[0].begin = idx;
+        auto shp_tail = opr::Subtensor::make(result_shp, tail_desc);
+        result_shp = opr::Concat::make({batch_shape, shp_tail}, 0, cn);
+        result = result.reshape(result_shp);
+    }
+    if (remove_row) {
+        std::vector<Desc> remove_param;
+        remove_param.push_back(Desc::make_remove(maxdim - 2));
+        result = opr::AxisAddRemove::make(result, remove_param);
+    }
+    if (remove_col) {
+        std::vector<Desc> remove_param;
+        remove_param.push_back(Desc::make_remove(maxdim - 1));
+        result = opr::AxisAddRemove::make(result, remove_param);
+    }
+    return result;
 }
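The batched `apply_on_var_node` mirrors the removed Python subgraph: promote 1-D operands, prepend the higher-rank operand's extra leading dims to the lower-rank one and broadcast, collapse all batch dims into one axis, run a single 3-D `BatchedMatrixMul`, then restore the batch shape. A NumPy sketch of the same computation (illustrative; the graph version prepends leading dims rather than using general NumPy broadcasting, and it also strips the axes added for 1-D operands, which is omitted here):

```python
import numpy as np

def batched_matmul(a, b):
    if a.ndim == 1:
        a = a[None, :]
    if b.ndim == 1:
        b = b[:, None]
    # align and broadcast the batch prefixes (everything but the last 2 dims)
    batch = np.broadcast_shapes(a.shape[:-2], b.shape[:-2])
    a = np.broadcast_to(a, batch + a.shape[-2:])
    b = np.broadcast_to(b, batch + b.shape[-2:])
    # collapse batch dims so one 3-D batched matmul suffices
    a3 = a.reshape(-1, *a.shape[-2:])
    b3 = b.reshape(-1, *b.shape[-2:])
    r = np.einsum("nik,nkj->nij", a3, b3)
    return r.reshape(*batch, *r.shape[-2:])  # restore the batch shape

assert batched_matmul(np.ones((2, 4, 3, 5)), np.ones((4, 5, 6))).shape == (2, 4, 3, 6)
```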
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
@@ -178,8 +363,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     TensorLayout layout1 = inputs[0].layout, layout2 = inputs[1].layout;
     size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
 
+    DType dst_dtype;
+    DnnOprCaller<megdnn::MatrixMul> dnn_opr(inputs[0].comp_node);
+    dnn_opr.op->param() = matmul.param();
+    dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
+
     if (dim1 == 0 || dim2 == 0) {
-        return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false};
+        return {{{TensorLayout(dst_dtype), inputs[0].comp_node}}, false};
     }
 
     if (matmul.transposeA)
@@ -187,7 +378,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     if (matmul.transposeB)
         std::swap(layout2[dim2 - 1], layout2[dim2 - 2]);
 
-    TensorLayout dst_layout(layout1.dtype);
+    TensorLayout dst_layout(dst_dtype);
     size_t di = 0;
     if (dim1 > dim2) {
         for (size_t i = 0; i < dim1 - 2; i++)
@@ -217,6 +408,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout();
     size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
 
+    DnnOprCaller<megdnn::BatchedMatrixMul> dnn_opr(cn);
+    dnn_opr.op->param() = matmul.param();
+    DType dst_dtype;
+    dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
+
     bool remove_row = false, remove_col = false;
     if (dim1 == 1) {
         dim1 = 2;
@@ -234,6 +430,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     TensorShape tshp, batch_shp;
     size_t j = 0;
+    auto inp1 = inputs[0], inp2 = inputs[1];
     if (dim1 > dim2) {
         for (size_t i = 0; i < dim1 - 2; i++)
             tshp[j++] = layout1.shape[i];
@@ -266,7 +463,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     shp2.ndim += 2;
     size_t maxdim = dim1 > dim2 ? dim1 : dim2;
     size_t nbatch = batch_shp[0];
-    auto inp1 = inputs[0], inp2 = inputs[1];
     if (maxdim > 3) {
         nbatch = std::accumulate(
                 batch_shp.shape, batch_shp.shape + batch_shp.ndim, (size_t)1,
@@ -274,29 +470,29 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
         TensorLayout layout_a;
         // batched_matmul does not support memory forwarding, so ensure contiguous
         // manually
         TensorShape nl1 = TensorShape(
                 {nbatch, layout1[layout1.ndim - 2], layout1[layout1.ndim - 1]});
-        if (!layout1.try_reshape(layout_a, nl1)) {
-            inp1 = Tensor::make(inputs[0]->blob(), inputs[0]->offset(), layout1);
-            inp1->to_contiguous_inplace();
-            layout1 = inp1->layout();
-        }
+        inp1 = Tensor::make(inputs[0]->blob(), inputs[0]->offset(), layout1);
+        inp1->to_contiguous_inplace();
+        layout1 = inp1->layout();
+        layout_a = layout1.reshape(nl1);
+        layout1 = layout_a;
         TensorShape nl2 = TensorShape(
                 {nbatch, layout2[layout2.ndim - 2], layout2[layout2.ndim - 1]});
-        if (!layout2.try_reshape(layout_a, nl2)) {
-            inp2 = Tensor::make(inputs[1]->blob(), inputs[1]->offset(), layout2);
-            inp2->to_contiguous_inplace();
-            layout2 = inp2->layout();
-        }
+        inp2 = Tensor::make(inputs[1]->blob(), inputs[1]->offset(), layout2);
+        inp2->to_contiguous_inplace();
+        layout2 = inp2->layout();
+        layout_a = layout2.reshape(nl2);
+        layout2 = layout_a;
     }
     TensorLayout dst_layout(
             {nbatch, matmul.transposeA ? layout1[2] : layout1[1],
              matmul.transposeB ? layout2[1] : layout2[2]},
-            layout1.dtype);
+            dst_dtype);
     dst_layout.init_contiguous_stride();
 
     if (dim1 == 0 || dim2 == 0 || layout1[layout1.ndim - 1] == 0) {
@@ -317,9 +513,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     DeviceTensorND out =
             BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout);
 
-    DnnOprCaller<megdnn::BatchedMatrixMul> dnn_opr(cn);
-    dnn_opr.op->param() = matmul.param();
-
     size_t sz = setup_algo<megdnn::BatchedMatrixMul>(
             {layout1, layout2, dst_layout}, dnn_opr.op.get(), 0, false, false, cn,
             matmul.policy(), false);
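The physical-tensor path used to attempt a stride-preserving `try_reshape` and copy only on failure; it now always materializes a contiguous tensor before collapsing batch dims, which is simpler and matches the comment that batched matmul cannot forward memory. A short NumPy analogue of the new behavior (an assumed equivalence, for intuition only):

```python
import numpy as np

def collapse_batch(x, nbatch):
    # Always make the operand contiguous first; the (nbatch, rows, cols)
    # reshape below is then guaranteed to succeed without surprises.
    x = np.ascontiguousarray(x)
    return x.reshape(nbatch, x.shape[-2], x.shape[-1])
```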
@@ -246,7 +246,12 @@ private:
                             it.name, enumMember.substr(0, d));
                 body += " break;\n";
             }
-            body += " default: break;\n";
+            body += " default:\n";
+            body +=
+                    formatv(" props_.emplace_back(\"{0}\", "
+                            "\"INVALID\");\n",
+                            it.name);
+            body += " break;\n";
             body += " }\n";
         } else {
             auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr);
@@ -89,19 +89,35 @@ void OpDefEmitter::emit_header() {
     gen_ctor("", "", " = default;");
 
     if (!op.getMgbAttributes().empty()) {
+        std::string strategy_val = "";
         std::vector<std::string> paramList, initList;
         for (auto&& i : op.getMgbAttributes()) {
+            if (attr_to_ctype(i.attr).compare("Strategy") == 0) {
+                strategy_val = i.name;
+            }
            paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name));
            initList.push_back(formatv("{0}({0}_)", i.name));
        }
        paramList.push_back("std::string scope_ = {}");
-        gen_ctor(
-                llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "),
-                " { set_scope(scope_); }");
+        if (!strategy_val.empty()) {
+            gen_ctor(
+                    llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "),
+                    formatv(" {"
+                            "\n set_scope(scope_);"
+                            "\n mgb_assert(static_cast<uint32_t>({0}) <= "
+                            "uint32_t(8));"
+                            "\n }",
+                            strategy_val));
+        } else {
+            gen_ctor(
+                    llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "),
+                    " { set_scope(scope_); }");
+        }
     }
 
     auto packedParams = op.getPackedParams();
     if (!packedParams.empty()) {
+        std::string strategy_val = "";
         std::vector<std::string> paramList, initList;
         for (auto&& p : packedParams) {
             auto&& paramFields = p.getFields();
@@ -111,6 +127,9 @@ void OpDefEmitter::emit_header() {
                     paramFields.empty() ? paramType.str()
                                         : formatv("{0} {1}", paramType, paramName));
             for (auto&& i : paramFields) {
+                if (i.name.compare("strategy") == 0) {
+                    strategy_val = i.name;
+                }
                 initList.push_back(formatv("{0}({1}.{0})", i.name, paramName));
             }
         }
@@ -118,9 +137,20 @@ void OpDefEmitter::emit_header() {
             paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name));
             initList.push_back(formatv("{0}({0}_)", i.name));
         }
-        gen_ctor(
-                llvm::join(paramList, ", "),
-                initList.empty() ? "" : ": " + llvm::join(initList, ", "), " {}");
+        if (!strategy_val.empty()) {
+            gen_ctor(
+                    llvm::join(paramList, ", "),
+                    initList.empty() ? "" : ": " + llvm::join(initList, ", "),
+                    formatv(" {"
+                            "\n mgb_assert(static_cast<uint32_t>({0}) <= "
+                            "uint32_t(8));"
+                            "\n }",
+                            strategy_val));
+        } else {
+            gen_ctor(
+                    llvm::join(paramList, ", "),
+                    initList.empty() ? "" : ": " + llvm::join(initList, ", "), " {}");
+        }
     }
 
     if (!packedParams.empty()) {
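The emitter changes make generated `OpDef` constructors validate a `Strategy` attribute at construction time instead of silently accepting any integer (and the `props_` switch now reports `INVALID` for unknown enum values rather than ignoring them). A Python sketch of the check the generated `mgb_assert` performs, with the bound taken from the emitted constant:

```python
def check_strategy(strategy: int) -> None:
    # Mirrors the generated "mgb_assert(static_cast<uint32_t>(strategy) <=
    # uint32_t(8))": reject values beyond the largest emitted strategy bound.
    assert 0 <= strategy <= 8, f"invalid Strategy value: {strategy}"
```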
@@ -43,9 +43,19 @@ def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
 def MatrixInverse: MgbHashableOp<"MatrixInverse", [EmptyParam]>;
 
-def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;
+def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]> {
+  let extraArguments = (ins
+    MgbUI32Attr:$dimA,
+    MgbUI32Attr:$dimB
+  );
+}
 
-def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;
+def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]> {
+  let extraArguments = (ins
+    MgbUI32Attr:$dimA,
+    MgbUI32Attr:$dimB
+  );
+}
 
 def Dot: MgbHashableOp<"Dot", [EmptyParam]>;
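The new `extraArguments` give both matmul op definitions two plain `uint32` attributes, `dimA`/`dimB`, carrying the input ranks from Python (`dim1`/`dim2` in `matmul_cpp`) to `apply_on_var_node`, which needs them to rebuild the rank-normalization graph. A rough Python picture of what this adds to the generated, hashable op definition (field set abridged; a frozen dataclass stands in for the real OpDef):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class MatrixMulDef:
    transposeA: bool
    transposeB: bool
    strategy: int
    dimA: int  # rank of input A, new in this diff
    dimB: int  # rank of input B, new in this diff
```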