| @@ -36,7 +36,7 @@ public: | |||||
| virtual void exec( | virtual void exec( | ||||
| _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | ||||
| _megdnn_workspace workspace) = 0; | _megdnn_workspace workspace) = 0; | ||||
| void deduce_dtype(DType A, DType B, DType& C); | |||||
| MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C); | |||||
| void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | ||||
| virtual size_t get_workspace_in_bytes( | virtual size_t get_workspace_in_bytes( | ||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | ||||
| @@ -73,7 +73,7 @@ public: | |||||
| virtual void exec( | virtual void exec( | ||||
| _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | ||||
| _megdnn_workspace workspace) = 0; | _megdnn_workspace workspace) = 0; | ||||
| void deduce_dtype(DType A, DType B, DType& C); | |||||
| MGE_WIN_DECLSPEC_FUC void deduce_dtype(DType A, DType B, DType& C); | |||||
| void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | ||||
| virtual size_t get_workspace_in_bytes( | virtual size_t get_workspace_in_bytes( | ||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | ||||
| @@ -44,216 +44,6 @@ def _elwise(*args, mode): | |||||
| return _elwise_apply(args, mode) | return _elwise_apply(args, mode) | ||||
@lru_cache(maxsize=None)
def _get_extentedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    """Return a cached subgraph builder for a MatrixMul that accepts operands
    whose rank is not exactly 2.

    ``dim1`` / ``dim2`` are the ranks of the two operands. A rank-1 operand
    gets a temporary axis (removed again from the result); an operand with
    rank > 2 has its leading axes collapsed into one before the 2-D
    ``builtin.MatrixMul`` and restored on the result afterwards.

    ``strategy`` must expose the real strategy as ``.value`` (it is wrapped so
    that the argument tuple stays hashable for ``lru_cache``).
    """

    @subgraph("extentedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedMatrixMulOp(inputs, f, c):
        # NOTE(review): `f` is used as "apply op to vars" and `c` as "make a
        # constant" -- semantics come from the `subgraph` decorator.
        assert len(inputs) == 2
        lhs, rhs = inputs

        def leading_shape(shape, idx=-1):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def trailing_shape(shape, idx=-1):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def flatten_leading(var, shape):
            # reshape `var` to (prod(shape[:-1]), shape[-1])
            batch = f(builtin.Reduce(mode="product", axis=0), leading_shape(shape))
            target = f(
                builtin.Concat(axis=0, comp_node=device), batch, trailing_shape(shape)
            )
            return f(builtin.Reshape(), var, target)

        def restore_leading(var, shape, flat_shape):
            # reshape `var` to shape[:-1] + flat_shape[-1:]
            target = f(
                builtin.Concat(axis=0, comp_node=device),
                leading_shape(shape),
                trailing_shape(flat_shape),
            )
            return f(builtin.Reshape(), var, target)

        # A vector operand is promoted to a matrix for the multiplication.
        remove_row = dim1 == 1
        remove_col = dim2 == 1
        ndim1 = 2 if remove_row else dim1
        ndim2 = 2 if remove_col else dim2

        if remove_row:
            lhs = f(builtin.AddAxis(axis=[0]), lhs)
        if remove_col:
            rhs = f(builtin.AddAxis(axis=[1]), rhs)

        lhs_shape = f(builtin.GetVarShape(), lhs)
        rhs_shape = f(builtin.GetVarShape(), rhs)
        if ndim1 > 2:
            lhs = flatten_leading(lhs, lhs_shape)
        if ndim2 > 2:
            rhs = flatten_leading(rhs, rhs_shape)

        matmul_op = builtin.MatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(matmul_op, lhs, rhs)
        result_shape = f(builtin.GetVarShape(), result)
        if ndim1 > 2:
            result = restore_leading(result, lhs_shape, result_shape)
        if ndim2 > 2:
            result = restore_leading(result, rhs_shape, result_shape)

        # Drop the axes that were only added to promote vectors to matrices.
        maxdim = max(ndim1, ndim2)
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        return (result,), (True,)

    return extentedMatrixMulOp
@lru_cache(maxsize=None)
def _get_extentedBatchedMatrixMulOp(
    device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
):
    """Return a cached subgraph builder for a BatchedMatrixMul over operands of
    arbitrary (possibly different) rank.

    ``dim1`` / ``dim2`` are the ranks of the two operands. Rank-1 operands get
    a temporary axis; the lower-rank operand is broadcast so that it gains the
    extra leading axes of the other one; when the common rank exceeds 3 all
    batch axes are collapsed into one around ``builtin.BatchedMatrixMul`` and
    restored on the result.

    ``strategy`` must expose the real strategy as ``.value`` (it is wrapped so
    that the argument tuple stays hashable for ``lru_cache``).
    """

    @subgraph("extentedBatchedMatrixMulOp", dtype, device, 2, gopt_level=2)
    def extentedBatchedMatrixMulOp(inputs, f, c):
        # NOTE(review): `f` is used as "apply op to vars" and `c` as "make a
        # constant" -- semantics come from the `subgraph` decorator.
        assert len(inputs) == 2
        lhs, rhs = inputs

        def leading_shape(shape, idx=-2):
            # shape[:idx]
            return f(
                builtin.Subtensor(items=[[0, False, True, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def trailing_shape(shape, idx=-2):
            # shape[idx:]
            return f(
                builtin.Subtensor(items=[[0, True, False, False, False]]),
                shape,
                c(idx, "int32"),
            )

        def collapse_batch(var, shape, batch):
            # reshape `var` to (prod(batch), *shape[-2:])
            flat_batch = f(builtin.Reduce(mode="product", axis=0), batch)
            target = f(
                builtin.Concat(axis=0, comp_node=device),
                flat_batch,
                trailing_shape(shape),
            )
            return f(builtin.Reshape(), var, target)

        # A vector operand is promoted to a matrix for the multiplication.
        remove_row = dim1 == 1
        remove_col = dim2 == 1
        ndim1 = 2 if remove_row else dim1
        ndim2 = 2 if remove_col else dim2

        if remove_row:
            lhs = f(builtin.AddAxis(axis=[0]), lhs)
        if remove_col:
            rhs = f(builtin.AddAxis(axis=[1]), rhs)

        lhs_shape = f(builtin.GetVarShape(), lhs)
        rhs_shape = f(builtin.GetVarShape(), rhs)

        maxdim = max(ndim1, ndim2)
        if ndim1 > ndim2:
            # broadcast rhs: prepend lhs_shape[:-ndim2] to its shape
            rhs_shape = f(
                builtin.Concat(axis=0, comp_node=device),
                leading_shape(lhs_shape, idx=-ndim2),
                rhs_shape,
            )
            rhs = f(builtin.Broadcast(), rhs, rhs_shape)
            batch_shape = leading_shape(lhs_shape)
        elif ndim2 > ndim1:
            # broadcast lhs: prepend rhs_shape[:-ndim1] to its shape
            lhs_shape = f(
                builtin.Concat(axis=0, comp_node=device),
                leading_shape(rhs_shape, idx=-ndim1),
                lhs_shape,
            )
            lhs = f(builtin.Broadcast(), lhs, lhs_shape)
            batch_shape = leading_shape(rhs_shape)
        else:
            batch_shape = leading_shape(lhs_shape)

        # compress inputs to 3d
        if maxdim > 3:
            lhs = collapse_batch(lhs, lhs_shape, batch_shape)
            rhs = collapse_batch(rhs, rhs_shape, batch_shape)

        bmm_op = builtin.BatchedMatrixMul(
            transposeA=transpose_a,
            transposeB=transpose_b,
            compute_mode=compute_mode,
            format=format,
            strategy=strategy.value,
        )
        result = f(bmm_op, lhs, rhs)

        if maxdim > 3:
            # expand the single batch axis back to the full batch shape
            expanded = f(
                builtin.Concat(axis=0, comp_node=device),
                batch_shape,
                trailing_shape(f(builtin.GetVarShape(), result)),
            )
            result = f(builtin.Reshape(), result, expanded)

        # Drop the axes that were only added to promote vectors to matrices.
        if remove_row:
            result = f(builtin.RemoveAxis(axis=[maxdim - 2]), result)
        if remove_col:
            result = f(builtin.RemoveAxis(axis=[maxdim - 1]), result)
        return (result,), (True,)

    return extentedBatchedMatrixMulOp
| class _Hashable: | class _Hashable: | ||||
| def __init__(self, value) -> None: | def __init__(self, value) -> None: | ||||
| self.value = value | self.value = value | ||||
| @@ -267,42 +57,6 @@ class _Hashable: | |||||
| return self.value == o.value | return self.value == o.value | ||||
def symbolicMatrixMul(
    inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy
):
    """Multiply ``inp1`` by ``inp2`` through the cached extended-MatrixMul subgraph.

    The subgraph is looked up by the device and dtype of ``inp1`` together
    with the remaining arguments; ``strategy`` is wrapped in ``_Hashable`` so
    it can take part in the cache key. Returns the single output of the op.
    """
    make_op = _get_extentedMatrixMulOp(
        inp1.device,
        inp1.dtype,
        dim1,
        dim2,
        transpose_a,
        transpose_b,
        compute_mode,
        format,
        strategy=_Hashable(strategy),
    )
    # The subgraph yields exactly one output; unpacking enforces that.
    (output,) = apply(make_op(), inp1, inp2)
    return output
def symbolicBatchedMatrixMul(
    inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy
):
    """Multiply ``inp1`` by ``inp2`` through the cached extended-BatchedMatrixMul
    subgraph.

    The subgraph is looked up by the device and dtype of ``inp1`` together
    with the remaining arguments; ``strategy`` is wrapped in ``_Hashable`` so
    it can take part in the cache key. Returns the single output of the op.
    """
    make_op = _get_extentedBatchedMatrixMulOp(
        inp1.device,
        inp1.dtype,
        dim1,
        dim2,
        transpose_a,
        transpose_b,
        compute_mode,
        format,
        strategy=_Hashable(strategy),
    )
    # The subgraph yields exactly one output; unpacking enforces that.
    (output,) = apply(make_op(), inp1, inp2)
    return output
| def _matmul( | def _matmul( | ||||
| inp1, | inp1, | ||||
| inp2, | inp2, | ||||
| @@ -342,11 +96,8 @@ def _matmul( | |||||
| transpose_a, | transpose_a, | ||||
| transpose_b, | transpose_b, | ||||
| compute_mode, | compute_mode, | ||||
| format, | |||||
| _config._benchmark_kernel, | _config._benchmark_kernel, | ||||
| _config._deterministic_kernel, | _config._deterministic_kernel, | ||||
| strategy, | |||||
| symbolicMatrixMul, | |||||
| ) | ) | ||||
| else: # dispath to BatchedMatrixMul | else: # dispath to BatchedMatrixMul | ||||
| # nx1(transpose_a=True), n>=3 | # nx1(transpose_a=True), n>=3 | ||||
| @@ -362,11 +113,8 @@ def _matmul( | |||||
| transpose_a, | transpose_a, | ||||
| transpose_b, | transpose_b, | ||||
| compute_mode, | compute_mode, | ||||
| format, | |||||
| _config._benchmark_kernel, | _config._benchmark_kernel, | ||||
| _config._deterministic_kernel, | _config._deterministic_kernel, | ||||
| strategy, | |||||
| symbolicBatchedMatrixMul, | |||||
| ) | ) | ||||
| @@ -32,7 +32,7 @@ from ..core.ops.builtin import ( | |||||
| TypeCvt, | TypeCvt, | ||||
| ) | ) | ||||
| from ..core.tensor import amp, megbrain_graph | from ..core.tensor import amp, megbrain_graph | ||||
| from ..core.tensor.array_method import _elwise_apply | |||||
| from ..core.tensor.array_method import _matmul | |||||
| from ..core.tensor.utils import ( | from ..core.tensor.utils import ( | ||||
| astensor1d, | astensor1d, | ||||
| cast_tensors, | cast_tensors, | ||||
| @@ -49,7 +49,7 @@ from ..utils.deprecation import deprecated_func | |||||
| from .debug_param import get_execution_strategy | from .debug_param import get_execution_strategy | ||||
| from .distributed import all_reduce_sum | from .distributed import all_reduce_sum | ||||
| from .elemwise import _elwise, exp, log, log1p, maximum, minimum | from .elemwise import _elwise, exp, log, log1p, maximum, minimum | ||||
| from .math import matmul, max, sum | |||||
| from .math import max, sum | |||||
| from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros | from .tensor import broadcast_to, concat, expand_dims, ones, squeeze, zeros | ||||
| __all__ = [ | __all__ = [ | ||||
| @@ -127,7 +127,7 @@ def linear( | |||||
| bias: bias with shape `(out_features,)`. Default: None | bias: bias with shape `(out_features,)`. Default: None | ||||
| """ | """ | ||||
| compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) | compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) | ||||
| ret = matmul(inp, weight, transpose_b=True, compute_mode=compute_mode) | |||||
| ret = _matmul(inp, weight, transpose_b=True, compute_mode=compute_mode) | |||||
| if bias is not None: | if bias is not None: | ||||
| if amp._enabled: | if amp._enabled: | ||||
| bias = bias.astype("float16") | bias = bias.astype("float16") | ||||
| @@ -1494,73 +1494,61 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { | |||||
| py::object _matmul_cpp( | py::object _matmul_cpp( | ||||
| py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2, | py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2, | ||||
| py::handle transpose_a, py::handle transpose_b, py::handle compute_mode, | py::handle transpose_a, py::handle transpose_b, py::handle compute_mode, | ||||
| py::handle format, py::handle profile, py::handle determistic, | |||||
| py::handle strategy, py::handle func) { | |||||
| if (enable_fastpath(inp1)) { | |||||
| ::megdnn::param::MatrixMul::ComputeMode mode = | |||||
| ::megdnn::param::MatrixMul::ComputeMode::DEFAULT; | |||||
| if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) { | |||||
| mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32; | |||||
| } | |||||
| ::megdnn::param::ExecutionPolicy::Strategy cstrategy; | |||||
| if (profile.cast<bool>()) { | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE; | |||||
| } else { | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC; | |||||
| } | |||||
| if (determistic.cast<bool>()) { | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE; | |||||
| } | |||||
| std::shared_ptr<OpDef> op = MatrixMul::make( | |||||
| transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode, | |||||
| ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX); | |||||
| py::object Op = py::cast(op); | |||||
| PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()}; | |||||
| py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3)); | |||||
| return ret[0]; | |||||
| py::handle profile, py::handle determistic) { | |||||
| ::megdnn::param::MatrixMul::ComputeMode mode = | |||||
| ::megdnn::param::MatrixMul::ComputeMode::DEFAULT; | |||||
| if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) { | |||||
| mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32; | |||||
| } | |||||
| ::megdnn::param::ExecutionPolicy::Strategy cstrategy = | |||||
| static_cast<::megdnn::param::ExecutionPolicy::Strategy>(0); | |||||
| if (profile.cast<bool>()) { | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE; | |||||
| } else { | } else { | ||||
| // fallback to traceable implementation | |||||
| return func( | |||||
| inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, | |||||
| strategy); | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC; | |||||
| } | } | ||||
| if (determistic.cast<bool>()) { | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE; | |||||
| } | |||||
| std::shared_ptr<OpDef> op = MatrixMul::make( | |||||
| transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode, | |||||
| ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX, | |||||
| dim1.cast<uint32_t>(), dim2.cast<uint32_t>()); | |||||
| py::object Op = py::cast(op); | |||||
| PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()}; | |||||
| py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3)); | |||||
| return ret[0]; | |||||
| } | } | ||||
| py::object _batched_matmul_cpp( | py::object _batched_matmul_cpp( | ||||
| py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2, | py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2, | ||||
| py::handle transpose_a, py::handle transpose_b, py::handle compute_mode, | py::handle transpose_a, py::handle transpose_b, py::handle compute_mode, | ||||
| py::handle format, py::handle profile, py::handle determistic, | |||||
| py::handle strategy, py::handle func) { | |||||
| if (enable_fastpath(inp1)) { | |||||
| ::megdnn::param::MatrixMul::ComputeMode mode = | |||||
| ::megdnn::param::MatrixMul::ComputeMode::DEFAULT; | |||||
| if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) { | |||||
| mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32; | |||||
| } | |||||
| ::megdnn::param::ExecutionPolicy::Strategy cstrategy; | |||||
| if (profile.cast<bool>()) { | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE; | |||||
| } else { | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC; | |||||
| } | |||||
| if (determistic.cast<bool>()) { | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE; | |||||
| } | |||||
| std::shared_ptr<OpDef> op = BatchedMatrixMul::make( | |||||
| transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode, | |||||
| ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX); | |||||
| py::object Op = py::cast(op); | |||||
| PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()}; | |||||
| py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3)); | |||||
| return ret[0]; | |||||
| py::handle profile, py::handle determistic) { | |||||
| ::megdnn::param::MatrixMul::ComputeMode mode = | |||||
| ::megdnn::param::MatrixMul::ComputeMode::DEFAULT; | |||||
| if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) { | |||||
| mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32; | |||||
| } | |||||
| ::megdnn::param::ExecutionPolicy::Strategy cstrategy = | |||||
| static_cast<::megdnn::param::ExecutionPolicy::Strategy>(0); | |||||
| if (profile.cast<bool>()) { | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE; | |||||
| } else { | } else { | ||||
| // fallback to traceable implementation | |||||
| return func( | |||||
| inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, | |||||
| strategy); | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC; | |||||
| } | } | ||||
| if (determistic.cast<bool>()) { | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE; | |||||
| } | |||||
| std::shared_ptr<OpDef> op = BatchedMatrixMul::make( | |||||
| transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode, | |||||
| ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX, | |||||
| dim1.cast<uint32_t>(), dim2.cast<uint32_t>()); | |||||
| py::object Op = py::cast(op); | |||||
| PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()}; | |||||
| py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3)); | |||||
| return ret[0]; | |||||
| } | } | ||||
| py::object _pixel_shuffle_cpp(py::handle inp, py::handle val, py::handle func) { | py::object _pixel_shuffle_cpp(py::handle inp, py::handle val, py::handle func) { | ||||
| @@ -1671,7 +1659,7 @@ PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
| try { | try { | ||||
| return _matmul_cpp( | return _matmul_cpp( | ||||
| args[0], args[1], args[2], args[3], args[4], args[5], args[6], | args[0], args[1], args[2], args[3], args[4], args[5], args[6], | ||||
| args[7], args[8], args[9], args[10], args[11]) | |||||
| args[7], args[8]) | |||||
| .release() | .release() | ||||
| .ptr(); | .ptr(); | ||||
| } | } | ||||
| @@ -1682,7 +1670,7 @@ PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs | |||||
| try { | try { | ||||
| return _batched_matmul_cpp( | return _batched_matmul_cpp( | ||||
| args[0], args[1], args[2], args[3], args[4], args[5], args[6], | args[0], args[1], args[2], args[3], args[4], args[5], args[6], | ||||
| args[7], args[8], args[9], args[10], args[11]) | |||||
| args[7], args[8]) | |||||
| .release() | .release() | ||||
| .ptr(); | .ptr(); | ||||
| } | } | ||||
| @@ -20,7 +20,6 @@ import megengine.optimizer as optim | |||||
| from megengine import tensor | from megengine import tensor | ||||
| from megengine.autodiff import GradManager | from megengine.autodiff import GradManager | ||||
| from megengine.jit import trace | from megengine.jit import trace | ||||
| from megengine.traced_module import trace_module | |||||
| @contextlib.contextmanager | @contextlib.contextmanager | ||||
| @@ -2,8 +2,12 @@ | |||||
| #include "../blob_manager_impl.h" | #include "../blob_manager_impl.h" | ||||
| #include "../dnn_op_helper.h" | #include "../dnn_op_helper.h" | ||||
| #include "../op_trait.h" | #include "../op_trait.h" | ||||
| #include "megbrain/graph/symbol_var.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | #include "megbrain/imperative/ops/autogen.h" | ||||
| #include "megbrain/opr/basic_arith.h" | |||||
| #include "megbrain/opr/blas.h" | #include "megbrain/opr/blas.h" | ||||
| #include "megbrain/opr/io.h" | |||||
| #include "megbrain/opr/tensor_manip.h" | |||||
| #include "../algo_chooser.h" | #include "../algo_chooser.h" | ||||
| @@ -12,12 +16,93 @@ namespace imperative { | |||||
| namespace { | namespace { | ||||
| namespace matrix_mul { | namespace matrix_mul { | ||||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | ||||
| auto&& matmul = def.cast_final_safe<MatrixMul>(); | auto&& matmul = def.cast_final_safe<MatrixMul>(); | ||||
| mgb_assert(inputs.size() == 2); | mgb_assert(inputs.size() == 2); | ||||
| OperatorNodeConfig config{matmul.make_name()}; | |||||
| return opr::MatrixMul::make( | |||||
| inputs[0], inputs[1], matmul.param(), matmul.policy(), config); | |||||
| auto inp1 = SymbolVar{inputs[0]}, inp2 = SymbolVar{inputs[1]}; | |||||
| auto dim1 = matmul.dimA, dim2 = matmul.dimB; | |||||
| auto cn = inputs[0]->comp_node(); | |||||
| using Desc = opr::AxisAddRemove::AxisDesc; | |||||
| using IndexDesc = opr::Subtensor::IndexDesc; | |||||
| OperatorNodeConfig config{matmul.make_name(), cn}; | |||||
| DTypeScalar vi{-1}; | |||||
| auto graph = inputs[0]->owner_graph(); | |||||
| bool remove_row = false, remove_col = false; | |||||
| if (dim1 == 1) { | |||||
| dim1 = 2; | |||||
| remove_row = true; | |||||
| inp1 = inp1.add_axis(0); | |||||
| } | |||||
| if (dim2 == 1) { | |||||
| dim2 = 2; | |||||
| remove_col = true; | |||||
| inp2 = inp2.add_axis(1); | |||||
| } | |||||
| SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail; | |||||
| if (dim1 > 2) { | |||||
| auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||||
| auto shp1 = inp1.symshape(); | |||||
| IndexDesc head_desc(1); | |||||
| head_desc[0].end = idx; | |||||
| shp1_head = opr::Subtensor::make(shp1, head_desc); | |||||
| auto batch = opr::Reduce::make(shp1_head, {Reduce::Mode::PRODUCT, 0}); | |||||
| IndexDesc tail_desc(1); | |||||
| tail_desc[0].begin = idx; | |||||
| shp1_tail = opr::Subtensor::make(shp1, tail_desc); | |||||
| auto tshp = opr::Concat::make({batch, shp1_tail}, 0, cn); | |||||
| inp1 = inp1.reshape(tshp); | |||||
| } | |||||
| if (dim2 > 2) { | |||||
| auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||||
| auto shp2 = inp2.symshape(); | |||||
| IndexDesc head_desc(1); | |||||
| head_desc[0].end = idx; | |||||
| shp2_head = opr::Subtensor::make(shp2, head_desc); | |||||
| auto batch = opr::Reduce::make(shp2_head, {Reduce::Mode::PRODUCT, 0}); | |||||
| IndexDesc tail_desc(1); | |||||
| tail_desc[0].begin = idx; | |||||
| auto shp2_tail = opr::Subtensor::make(shp2, tail_desc); | |||||
| auto tshp = opr::Concat::make({batch, shp2_tail}, 0, cn); | |||||
| inp2 = inp2.reshape(tshp); | |||||
| } | |||||
| auto result = | |||||
| opr::MatrixMul::make(inp1, inp2, matmul.param(), matmul.policy(), config); | |||||
| if (dim1 > 2) { | |||||
| auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||||
| auto result_shape = result.symshape(); | |||||
| IndexDesc tail_desc(1); | |||||
| tail_desc[0].begin = idx; | |||||
| auto shp_tail = opr::Subtensor::make(result_shape, tail_desc); | |||||
| auto tshp = opr::Concat::make({shp1_head, shp_tail}, 0, cn); | |||||
| result = result.reshape(tshp); | |||||
| } | |||||
| if (dim2 > 2) { | |||||
| auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||||
| auto result_shape = result.symshape(); | |||||
| IndexDesc tail_desc(1); | |||||
| tail_desc[0].begin = idx; | |||||
| auto shp_tail = opr::Subtensor::make(result_shape, tail_desc); | |||||
| auto tshp = opr::Concat::make({shp2_head, shp_tail}, 0, cn); | |||||
| result = result.reshape(tshp); | |||||
| } | |||||
| auto maxdim = dim1 > dim2 ? dim1 : dim2; | |||||
| if (remove_row) { | |||||
| std::vector<Desc> remove_param; | |||||
| remove_param.push_back(Desc::make_remove(maxdim - 2)); | |||||
| result = opr::AxisAddRemove::make(result, remove_param); | |||||
| } | |||||
| if (remove_col) { | |||||
| std::vector<Desc> remove_param; | |||||
| remove_param.push_back(Desc::make_remove(maxdim - 1)); | |||||
| result = opr::AxisAddRemove::make(result, remove_param); | |||||
| } | |||||
| return result; | |||||
| } | } | ||||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | ||||
| @@ -27,8 +112,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
| auto layout2 = inputs[1].layout; | auto layout2 = inputs[1].layout; | ||||
| size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | ||||
| DType dst_dtype; | |||||
| DnnOprCaller<megdnn::MatrixMul> dnn_opr(inputs[0].comp_node); | |||||
| dnn_opr.op->param() = matmul.param(); | |||||
| dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); | |||||
| if (dim1 == 0 || dim2 == 0) { | if (dim1 == 0 || dim2 == 0) { | ||||
| return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false}; | |||||
| return {{{TensorLayout(dst_dtype), inputs[0].comp_node}}, false}; | |||||
| } | } | ||||
| if (matmul.transposeA) | if (matmul.transposeA) | ||||
| @@ -37,7 +128,8 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
| std::swap(layout2[0], layout2[1]); | std::swap(layout2[0], layout2[1]); | ||||
| mgb_assert(layout1[dim1 - 1] == layout2[0]); | mgb_assert(layout1[dim1 - 1] == layout2[0]); | ||||
| TensorLayout dst_layout(layout1.dtype); | |||||
| TensorLayout dst_layout(dst_dtype); | |||||
| size_t ci = 0; | size_t ci = 0; | ||||
| for (size_t i = 0; i < dim1 - 1; i++) | for (size_t i = 0; i < dim1 - 1; i++) | ||||
| dst_layout[ci++] = layout1[i]; | dst_layout[ci++] = layout1[i]; | ||||
| @@ -61,6 +153,12 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| SmallVector<TensorND> inp_tensornds(inputs.size()); | SmallVector<TensorND> inp_tensornds(inputs.size()); | ||||
| TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout(); | TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout(); | ||||
| DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn); | |||||
| dnn_opr.op->param() = matmul.param(); | |||||
| DType dst_dtype; | |||||
| dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); | |||||
| // only matters when layout1 has dim 2 | // only matters when layout1 has dim 2 | ||||
| if (matmul.transposeA) | if (matmul.transposeA) | ||||
| std::swap(layout1.shape[0], layout1.shape[1]); | std::swap(layout1.shape[0], layout1.shape[1]); | ||||
| @@ -69,7 +167,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| std::swap(layout2.shape[0], layout2.shape[1]); | std::swap(layout2.shape[0], layout2.shape[1]); | ||||
| size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | ||||
| TensorLayout real_dst_layout(layout1.dtype); | |||||
| TensorLayout real_dst_layout(dst_dtype); | |||||
| if (validated) { | if (validated) { | ||||
| real_dst_layout = output_descs[0].layout; | real_dst_layout = output_descs[0].layout; | ||||
| } else { | } else { | ||||
| @@ -126,12 +224,9 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| inp_tensornds[1] = inputs[1]->dnn_tensor(); | inp_tensornds[1] = inputs[1]->dnn_tensor(); | ||||
| } | } | ||||
| TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, layout_a.dtype); | |||||
| TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, dst_dtype); | |||||
| dst_layout.init_contiguous_stride(); | dst_layout.init_contiguous_stride(); | ||||
| DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn); | |||||
| dnn_opr.op->param() = matmul.param(); | |||||
| DeviceTensorND out = | DeviceTensorND out = | ||||
| BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); | BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); | ||||
| size_t sz = setup_algo<megdnn::MatrixMul>( | size_t sz = setup_algo<megdnn::MatrixMul>( | ||||
| @@ -167,9 +262,99 @@ namespace batched_matrix_mul { | |||||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | ||||
| auto&& matmul = def.cast_final_safe<BatchedMatrixMul>(); | auto&& matmul = def.cast_final_safe<BatchedMatrixMul>(); | ||||
| mgb_assert(inputs.size() == 2); | mgb_assert(inputs.size() == 2); | ||||
| OperatorNodeConfig config{matmul.make_name()}; | |||||
| return opr::BatchedMatrixMul::make( | |||||
| inputs[0], inputs[1], matmul.param(), matmul.policy(), config); | |||||
| auto inp1 = SymbolVar{inputs[0]}, inp2 = SymbolVar{inputs[1]}; | |||||
| auto dim1 = matmul.dimA, dim2 = matmul.dimB; | |||||
| auto cn = inputs[0]->comp_node(); | |||||
| using Desc = opr::AxisAddRemove::AxisDesc; | |||||
| using IndexDesc = opr::Subtensor::IndexDesc; | |||||
| OperatorNodeConfig config{matmul.make_name(), cn}; | |||||
| DTypeScalar vi{-2}; | |||||
| auto graph = inputs[0]->owner_graph(); | |||||
| auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||||
| bool remove_row = false, remove_col = false; | |||||
| if (dim1 == 1) { | |||||
| dim1 = 2; | |||||
| remove_row = true; | |||||
| inp1 = inp1.add_axis(0); | |||||
| } | |||||
| if (dim2 == 1) { | |||||
| dim2 = 2; | |||||
| remove_col = true; | |||||
| inp2 = inp2.add_axis(1); | |||||
| } | |||||
| auto shp1 = inp1.symshape(); | |||||
| auto shp2 = inp2.symshape(); | |||||
| SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail; | |||||
| SymbolVar batch_shape; | |||||
| if (dim1 > dim2) { | |||||
| HostTensorND hv = HostTensorND(cn, {1}, dtype::Int32()); | |||||
| auto* ptr = hv.ptr<dt_int32>(); | |||||
| ptr[0] = -dim2; | |||||
| IndexDesc head_desc(1); | |||||
| head_desc[0].end = opr::ImmutableTensor::make(*graph, hv, config); | |||||
| shp1_head = opr::Subtensor::make(shp1, head_desc); | |||||
| shp2 = opr::Concat::make({shp1_head, shp2}, 0, cn); | |||||
| inp2 = inp2.broadcast(shp2); | |||||
| head_desc[0].end = idx; | |||||
| batch_shape = opr::Subtensor::make(shp1, head_desc); | |||||
| } | |||||
| if (dim2 > dim1) { | |||||
| HostTensorND hv = HostTensorND(cn, {1}, dtype::Int32()); | |||||
| auto* ptr = hv.ptr<dt_int32>(); | |||||
| ptr[0] = -dim1; | |||||
| IndexDesc head_desc(1); | |||||
| head_desc[0].end = opr::ImmutableTensor::make(*graph, hv, config); | |||||
| shp2_head = opr::Subtensor::make(shp2, head_desc); | |||||
| shp1 = opr::Concat::make({shp2_head, shp1}, 0, cn); | |||||
| inp1 = inp1.broadcast(shp1); | |||||
| head_desc[0].end = idx; | |||||
| batch_shape = opr::Subtensor::make(shp2, head_desc); | |||||
| } | |||||
| if (dim1 == dim2) { | |||||
| IndexDesc head_desc(1); | |||||
| head_desc[0].end = idx; | |||||
| batch_shape = opr::Subtensor::make(shp1, head_desc); | |||||
| } | |||||
| auto maxdim = dim1 > dim2 ? dim1 : dim2; | |||||
| if (maxdim > 3) { | |||||
| IndexDesc tail_desc(1); | |||||
| tail_desc[0].begin = idx; | |||||
| shp1_tail = opr::Subtensor::make(shp1, tail_desc); | |||||
| auto batch = opr::Reduce::make(batch_shape, {Reduce::Mode::PRODUCT, 0}); | |||||
| shp1 = opr::Concat::make({batch, shp1_tail}, 0, cn); | |||||
| inp1 = inp1.reshape(shp1); | |||||
| shp2_tail = opr::Subtensor::make(shp2, tail_desc); | |||||
| shp2 = opr::Concat::make({batch, shp2_tail}, 0, cn); | |||||
| inp2 = inp2.reshape(shp2); | |||||
| } | |||||
| auto result = opr::BatchedMatrixMul::make( | |||||
| inp1, inp2, matmul.param(), matmul.policy(), config); | |||||
| if (maxdim > 3) { | |||||
| auto result_shp = result.symshape(); | |||||
| IndexDesc tail_desc(1); | |||||
| tail_desc[0].begin = idx; | |||||
| auto shp_tail = opr::Subtensor::make(result_shp, tail_desc); | |||||
| result_shp = opr::Concat::make({batch_shape, shp_tail}, 0, cn); | |||||
| result = result.reshape(result_shp); | |||||
| } | |||||
| if (remove_row) { | |||||
| std::vector<Desc> remove_param; | |||||
| remove_param.push_back(Desc::make_remove(maxdim - 2)); | |||||
| result = opr::AxisAddRemove::make(result, remove_param); | |||||
| } | |||||
| if (remove_col) { | |||||
| std::vector<Desc> remove_param; | |||||
| remove_param.push_back(Desc::make_remove(maxdim - 1)); | |||||
| result = opr::AxisAddRemove::make(result, remove_param); | |||||
| } | |||||
| return result; | |||||
| } | } | ||||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | ||||
| @@ -178,8 +363,14 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
| TensorLayout layout1 = inputs[0].layout, layout2 = inputs[1].layout; | TensorLayout layout1 = inputs[0].layout, layout2 = inputs[1].layout; | ||||
| size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | ||||
| DType dst_dtype; | |||||
| DnnOprCaller<megdnn::MatrixMul> dnn_opr(inputs[0].comp_node); | |||||
| dnn_opr.op->param() = matmul.param(); | |||||
| dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); | |||||
| if (dim1 == 0 || dim2 == 0) { | if (dim1 == 0 || dim2 == 0) { | ||||
| return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false}; | |||||
| return {{{TensorLayout(dst_dtype), inputs[0].comp_node}}, false}; | |||||
| } | } | ||||
| if (matmul.transposeA) | if (matmul.transposeA) | ||||
| @@ -187,7 +378,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
| if (matmul.transposeB) | if (matmul.transposeB) | ||||
| std::swap(layout2[dim2 - 1], layout2[dim2 - 2]); | std::swap(layout2[dim2 - 1], layout2[dim2 - 2]); | ||||
| TensorLayout dst_layout(layout1.dtype); | |||||
| TensorLayout dst_layout(dst_dtype); | |||||
| size_t di = 0; | size_t di = 0; | ||||
| if (dim1 > dim2) { | if (dim1 > dim2) { | ||||
| for (size_t i = 0; i < dim1 - 2; i++) | for (size_t i = 0; i < dim1 - 2; i++) | ||||
| @@ -217,6 +408,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout(); | TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout(); | ||||
| size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | ||||
| DnnOprCaller<megdnn::BatchedMatrixMul> dnn_opr(cn); | |||||
| dnn_opr.op->param() = matmul.param(); | |||||
| DType dst_dtype; | |||||
| dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); | |||||
| bool remove_row = false, remove_col = false; | bool remove_row = false, remove_col = false; | ||||
| if (dim1 == 1) { | if (dim1 == 1) { | ||||
| dim1 = 2; | dim1 = 2; | ||||
| @@ -234,6 +430,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| TensorShape tshp, batch_shp; | TensorShape tshp, batch_shp; | ||||
| size_t j = 0; | size_t j = 0; | ||||
| auto inp1 = inputs[0], inp2 = inputs[1]; | |||||
| if (dim1 > dim2) { | if (dim1 > dim2) { | ||||
| for (size_t i = 0; i < dim1 - 2; i++) | for (size_t i = 0; i < dim1 - 2; i++) | ||||
| tshp[j++] = layout1.shape[i]; | tshp[j++] = layout1.shape[i]; | ||||
| @@ -266,7 +463,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| shp2.ndim += 2; | shp2.ndim += 2; | ||||
| size_t maxdim = dim1 > dim2 ? dim1 : dim2; | size_t maxdim = dim1 > dim2 ? dim1 : dim2; | ||||
| size_t nbatch = batch_shp[0]; | size_t nbatch = batch_shp[0]; | ||||
| auto inp1 = inputs[0], inp2 = inputs[1]; | |||||
| if (maxdim > 3) { | if (maxdim > 3) { | ||||
| nbatch = std::accumulate( | nbatch = std::accumulate( | ||||
| batch_shp.shape, batch_shp.shape + batch_shp.ndim, (size_t)1, | batch_shp.shape, batch_shp.shape + batch_shp.ndim, (size_t)1, | ||||
| @@ -274,29 +470,29 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| TensorLayout layout_a; | TensorLayout layout_a; | ||||
| // batched_matmul does not support memory forwarding, so ensure contiguous | |||||
| // manually | |||||
| TensorShape nl1 = TensorShape( | TensorShape nl1 = TensorShape( | ||||
| {nbatch, layout1[layout1.ndim - 2], layout1[layout1.ndim - 1]}); | {nbatch, layout1[layout1.ndim - 2], layout1[layout1.ndim - 1]}); | ||||
| if (!layout1.try_reshape(layout_a, nl1)) { | |||||
| inp1 = Tensor::make(inputs[0]->blob(), inputs[0]->offset(), layout1); | |||||
| inp1->to_contiguous_inplace(); | |||||
| layout1 = inp1->layout(); | |||||
| } | |||||
| inp1 = Tensor::make(inputs[0]->blob(), inputs[0]->offset(), layout1); | |||||
| inp1->to_contiguous_inplace(); | |||||
| layout1 = inp1->layout(); | |||||
| layout_a = layout1.reshape(nl1); | |||||
| layout1 = layout_a; | layout1 = layout_a; | ||||
| TensorShape nl2 = TensorShape( | TensorShape nl2 = TensorShape( | ||||
| {nbatch, layout2[layout2.ndim - 2], layout2[layout2.ndim - 1]}); | {nbatch, layout2[layout2.ndim - 2], layout2[layout2.ndim - 1]}); | ||||
| if (!layout2.try_reshape(layout_a, nl2)) { | |||||
| inp2 = Tensor::make(inputs[1]->blob(), inputs[1]->offset(), layout2); | |||||
| inp2->to_contiguous_inplace(); | |||||
| layout2 = inp2->layout(); | |||||
| } | |||||
| inp2 = Tensor::make(inputs[1]->blob(), inputs[1]->offset(), layout2); | |||||
| inp2->to_contiguous_inplace(); | |||||
| layout2 = inp2->layout(); | |||||
| layout_a = layout2.reshape(nl2); | |||||
| layout2 = layout_a; | layout2 = layout_a; | ||||
| } | } | ||||
| TensorLayout dst_layout( | TensorLayout dst_layout( | ||||
| {nbatch, matmul.transposeA ? layout1[2] : layout1[1], | {nbatch, matmul.transposeA ? layout1[2] : layout1[1], | ||||
| matmul.transposeB ? layout2[1] : layout2[2]}, | matmul.transposeB ? layout2[1] : layout2[2]}, | ||||
| layout1.dtype); | |||||
| dst_dtype); | |||||
| dst_layout.init_contiguous_stride(); | dst_layout.init_contiguous_stride(); | ||||
| if (dim1 == 0 || dim2 == 0 || layout1[layout1.ndim - 1] == 0) { | if (dim1 == 0 || dim2 == 0 || layout1[layout1.ndim - 1] == 0) { | ||||
| @@ -317,9 +513,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| DeviceTensorND out = | DeviceTensorND out = | ||||
| BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); | BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); | ||||
| DnnOprCaller<megdnn::BatchedMatrixMul> dnn_opr(cn); | |||||
| dnn_opr.op->param() = matmul.param(); | |||||
| size_t sz = setup_algo<megdnn::BatchedMatrixMul>( | size_t sz = setup_algo<megdnn::BatchedMatrixMul>( | ||||
| {layout1, layout2, dst_layout}, dnn_opr.op.get(), 0, false, false, cn, | {layout1, layout2, dst_layout}, dnn_opr.op.get(), 0, false, false, cn, | ||||
| matmul.policy(), false); | matmul.policy(), false); | ||||
| @@ -246,7 +246,12 @@ private: | |||||
| it.name, enumMember.substr(0, d)); | it.name, enumMember.substr(0, d)); | ||||
| body += " break;\n"; | body += " break;\n"; | ||||
| } | } | ||||
| body += " default: break;\n"; | |||||
| body += " default:\n"; | |||||
| body += | |||||
| formatv(" props_.emplace_back(\"{0}\", " | |||||
| "\"INVALID\");\n", | |||||
| it.name); | |||||
| body += " break;\n"; | |||||
| body += " }\n"; | body += " }\n"; | ||||
| } else { | } else { | ||||
| auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr); | auto&& attr = llvm::cast<MgbHashableAttrMixin>(it.attr); | ||||
| @@ -89,19 +89,35 @@ void OpDefEmitter::emit_header() { | |||||
| gen_ctor("", "", " = default;"); | gen_ctor("", "", " = default;"); | ||||
| if (!op.getMgbAttributes().empty()) { | if (!op.getMgbAttributes().empty()) { | ||||
| std::string strategy_val = ""; | |||||
| std::vector<std::string> paramList, initList; | std::vector<std::string> paramList, initList; | ||||
| for (auto&& i : op.getMgbAttributes()) { | for (auto&& i : op.getMgbAttributes()) { | ||||
| if (attr_to_ctype(i.attr).compare("Strategy") == 0) { | |||||
| strategy_val = i.name; | |||||
| } | |||||
| paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name)); | paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name)); | ||||
| initList.push_back(formatv("{0}({0}_)", i.name)); | initList.push_back(formatv("{0}({0}_)", i.name)); | ||||
| } | } | ||||
| paramList.push_back("std::string scope_ = {}"); | paramList.push_back("std::string scope_ = {}"); | ||||
| gen_ctor( | |||||
| llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "), | |||||
| " { set_scope(scope_); }"); | |||||
| if (!strategy_val.empty()) { | |||||
| gen_ctor( | |||||
| llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "), | |||||
| formatv(" {" | |||||
| "\n set_scope(scope_);" | |||||
| "\n mgb_assert(static_cast<uint32_t>({0}) <= " | |||||
| "uint32_t(8));" | |||||
| "\n }", | |||||
| strategy_val)); | |||||
| } else { | |||||
| gen_ctor( | |||||
| llvm::join(paramList, ", "), ": " + llvm::join(initList, ", "), | |||||
| " { set_scope(scope_); }"); | |||||
| } | |||||
| } | } | ||||
| auto packedParams = op.getPackedParams(); | auto packedParams = op.getPackedParams(); | ||||
| if (!packedParams.empty()) { | if (!packedParams.empty()) { | ||||
| std::string strategy_val = ""; | |||||
| std::vector<std::string> paramList, initList; | std::vector<std::string> paramList, initList; | ||||
| for (auto&& p : packedParams) { | for (auto&& p : packedParams) { | ||||
| auto&& paramFields = p.getFields(); | auto&& paramFields = p.getFields(); | ||||
| @@ -111,6 +127,9 @@ void OpDefEmitter::emit_header() { | |||||
| paramFields.empty() ? paramType.str() | paramFields.empty() ? paramType.str() | ||||
| : formatv("{0} {1}", paramType, paramName)); | : formatv("{0} {1}", paramType, paramName)); | ||||
| for (auto&& i : paramFields) { | for (auto&& i : paramFields) { | ||||
| if (i.name.compare("strategy") == 0) { | |||||
| strategy_val = i.name; | |||||
| } | |||||
| initList.push_back(formatv("{0}({1}.{0})", i.name, paramName)); | initList.push_back(formatv("{0}({1}.{0})", i.name, paramName)); | ||||
| } | } | ||||
| } | } | ||||
| @@ -118,9 +137,20 @@ void OpDefEmitter::emit_header() { | |||||
| paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name)); | paramList.push_back(formatv("{0} {1}_", attr_to_ctype(i.attr), i.name)); | ||||
| initList.push_back(formatv("{0}({0}_)", i.name)); | initList.push_back(formatv("{0}({0}_)", i.name)); | ||||
| } | } | ||||
| gen_ctor( | |||||
| llvm::join(paramList, ", "), | |||||
| initList.empty() ? "" : ": " + llvm::join(initList, ", "), " {}"); | |||||
| if (!strategy_val.empty()) { | |||||
| gen_ctor( | |||||
| llvm::join(paramList, ", "), | |||||
| initList.empty() ? "" : ": " + llvm::join(initList, ", "), | |||||
| formatv(" {" | |||||
| "\n mgb_assert(static_cast<uint32_t>({0}) <= " | |||||
| "uint32_t(8));" | |||||
| "\n }", | |||||
| strategy_val)); | |||||
| } else { | |||||
| gen_ctor( | |||||
| llvm::join(paramList, ", "), | |||||
| initList.empty() ? "" : ": " + llvm::join(initList, ", "), " {}"); | |||||
| } | |||||
| } | } | ||||
| if (!packedParams.empty()) { | if (!packedParams.empty()) { | ||||
| @@ -43,9 +43,19 @@ def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> { | |||||
| def MatrixInverse: MgbHashableOp<"MatrixInverse", [EmptyParam]>; | def MatrixInverse: MgbHashableOp<"MatrixInverse", [EmptyParam]>; | ||||
| def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>; | |||||
| def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]> { | |||||
| let extraArguments = (ins | |||||
| MgbUI32Attr:$dimA, | |||||
| MgbUI32Attr:$dimB | |||||
| ); | |||||
| } | |||||
| def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>; | |||||
| def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]> { | |||||
| let extraArguments = (ins | |||||
| MgbUI32Attr:$dimA, | |||||
| MgbUI32Attr:$dimB | |||||
| ); | |||||
| } | |||||
| def Dot: MgbHashableOp<"Dot", [EmptyParam]>; | def Dot: MgbHashableOp<"Dot", [EmptyParam]>; | ||||