GitOrigin-RevId: 4ceb2eb601
tags/v1.10.0
| @@ -20,9 +20,10 @@ from .._imperative_rt.core2 import ( | |||||
| Tensor, | Tensor, | ||||
| apply, | apply, | ||||
| astype_cpp, | astype_cpp, | ||||
| batched_matmul_cpp, | |||||
| broadcast_cpp, | broadcast_cpp, | ||||
| dtype_promotion, | |||||
| getitem_cpp, | getitem_cpp, | ||||
| matmul_cpp, | |||||
| ) | ) | ||||
| from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar | from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar | ||||
| from .._imperative_rt.core2 import reshape_cpp, setitem_cpp, squeeze_cpp, transpose_cpp | from .._imperative_rt.core2 import reshape_cpp, setitem_cpp, squeeze_cpp, transpose_cpp | ||||
| @@ -266,6 +267,42 @@ class _Hashable: | |||||
| return self.value == o.value | return self.value == o.value | ||||
| def symbolicMatrixMul( | |||||
| inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy | |||||
| ): | |||||
| extentedMatrixMulOp = _get_extentedMatrixMulOp( | |||||
| inp1.device, | |||||
| inp1.dtype, | |||||
| dim1, | |||||
| dim2, | |||||
| transpose_a, | |||||
| transpose_b, | |||||
| compute_mode, | |||||
| format, | |||||
| strategy=_Hashable(strategy), | |||||
| ) | |||||
| (result,) = apply(extentedMatrixMulOp(), inp1, inp2) | |||||
| return result | |||||
| def symbolicBatchedMatrixMul( | |||||
| inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy | |||||
| ): | |||||
| extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp( | |||||
| inp1.device, | |||||
| inp1.dtype, | |||||
| dim1, | |||||
| dim2, | |||||
| transpose_a, | |||||
| transpose_b, | |||||
| compute_mode, | |||||
| format, | |||||
| strategy=_Hashable(strategy), | |||||
| ) | |||||
| (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2) | |||||
| return result | |||||
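Note: symbolicMatrixMul and symbolicBatchedMatrixMul keep the previous subgraph-based implementation as traceable fallbacks; the new matmul_cpp / batched_matmul_cpp entry points call back into them when the C++ fast path is unavailable. A minimal sketch of that split (fast_impl and symbolic_impl are illustrative names, not the real entry points):

# Hedged sketch of the fast-path / symbolic-fallback split introduced here;
# fast_impl stands in for the C++ MatrixMul apply, symbolic_impl for
# symbolicMatrixMul. Both names are illustrative.
def matmul_dispatch(inp1, inp2, fastpath_enabled, fast_impl, symbolic_impl, **kwargs):
    if fastpath_enabled:
        # eager mode: build and apply the builtin MatrixMul op directly in C++
        return fast_impl(inp1, inp2, **kwargs)
    # tracing / symbolic mode: reuse the old subgraph-based Python helper,
    # which wraps the strategy in _Hashable so it can serve as a cache key
    return symbolic_impl(inp1, inp2, **kwargs)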
| def _matmul( | def _matmul( | ||||
| inp1, | inp1, | ||||
| inp2, | inp2, | ||||
| @@ -274,16 +311,6 @@ def _matmul( | |||||
| compute_mode="default", | compute_mode="default", | ||||
| format="default", | format="default", | ||||
| ): | ): | ||||
| if amp._enabled: | |||||
| compute_mode = "float32" | |||||
| inp1, inp2 = cast_tensors(inp1, inp2) | |||||
| else: | |||||
| dtype = dtype_promotion(inp1, inp2) | |||||
| if inp1.dtype != dtype: | |||||
| inp1 = inp1.astype(dtype) | |||||
| if inp2.dtype != dtype: | |||||
| inp2 = inp2.astype(dtype) | |||||
| dim1, dim2 = inp1.ndim, inp2.ndim | dim1, dim2 = inp1.ndim, inp2.ndim | ||||
| assert dim1 > 0 and dim2 > 0 | assert dim1 > 0 and dim2 > 0 | ||||
| maxdim = dim1 if dim1 > dim2 else dim2 | maxdim = dim1 if dim1 > dim2 else dim2 | ||||
| @@ -301,34 +328,46 @@ def _matmul( | |||||
| if dim1 == 1 and dim2 == 1: # dispatch to Dot | if dim1 == 1 and dim2 == 1: # dispatch to Dot | ||||
| (result,) = apply(builtin.Dot(), inp1, inp2) | (result,) = apply(builtin.Dot(), inp1, inp2) | ||||
| return result | return result | ||||
| elif maxdim <= 2 or dim2 <= 2: # dispatch to MatrixMul | |||||
| extentedMatrixMulOp = _get_extentedMatrixMulOp( | |||||
| inp1.device, | |||||
| inp1.dtype, | |||||
| elif maxdim <= 2 or (dim2 <= 2 and not transpose_a): # dispatch to MatrixMul | |||||
| # 2x1 | |||||
| # 1x2 | |||||
| # 2x2 | |||||
| # nx1(transpose_a=False), n>=3 | |||||
| # nx2(transpose_a=False), n>=3 | |||||
| return matmul_cpp( | |||||
| inp1, | |||||
| inp2, | |||||
| dim1, | dim1, | ||||
| dim2, | dim2, | ||||
| transpose_a, | transpose_a, | ||||
| transpose_b, | transpose_b, | ||||
| compute_mode, | compute_mode, | ||||
| format, | format, | ||||
| strategy=_Hashable(strategy), | |||||
| _config._benchmark_kernel, | |||||
| _config._deterministic_kernel, | |||||
| strategy, | |||||
| symbolicMatrixMul, | |||||
| ) | ) | ||||
| (result,) = apply(extentedMatrixMulOp(), inp1, inp2) | |||||
| return result | |||||
| else: # dispatch to BatchedMatrixMul | else: # dispatch to BatchedMatrixMul | ||||
| extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp( | |||||
| inp1.device, | |||||
| inp1.dtype, | |||||
| # nx1(transpose_a=True), n>=3 | |||||
| # nx2(transpose_a=True), n>=3 | |||||
| # nxm,n>=3,m>=3 | |||||
| # 1xm,m>=3 | |||||
| # 2xm,m>=3 | |||||
| return batched_matmul_cpp( | |||||
| inp1, | |||||
| inp2, | |||||
| dim1, | dim1, | ||||
| dim2, | dim2, | ||||
| transpose_a, | transpose_a, | ||||
| transpose_b, | transpose_b, | ||||
| compute_mode, | compute_mode, | ||||
| format, | format, | ||||
| strategy=_Hashable(strategy), | |||||
| _config._benchmark_kernel, | |||||
| _config._deterministic_kernel, | |||||
| strategy, | |||||
| symbolicBatchedMatrixMul, | |||||
| ) | ) | ||||
| (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2) | |||||
| return result | |||||
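The two comment blocks above spell out the new dispatch table. A small stand-alone restatement of the rule, useful for checking which kernel a given rank combination hits (assumption: this mirrors the branch conditions above and is not library code):

def which_kernel(dim1, dim2, transpose_a):
    # mirrors the branch order in _matmul above
    if dim1 == 1 and dim2 == 1:
        return "Dot"
    maxdim = max(dim1, dim2)
    if maxdim <= 2 or (dim2 <= 2 and not transpose_a):
        return "MatrixMul"        # 2x1, 1x2, 2x2, nx1/nx2 with n>=3 and no transpose_a
    return "BatchedMatrixMul"     # nxm with n>=3 or m>=3 otherwise

assert which_kernel(1, 1, False) == "Dot"
assert which_kernel(2, 2, False) == "MatrixMul"
assert which_kernel(4, 2, True) == "BatchedMatrixMul"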
| def _unary_elwise(mode): | def _unary_elwise(mode): | ||||
| @@ -10,7 +10,7 @@ import collections | |||||
| import math | import math | ||||
| from typing import Iterable, Optional, Sequence, Tuple, Union | from typing import Iterable, Optional, Sequence, Tuple, Union | ||||
| from ..core._imperative_rt.core2 import Const, apply, dtype_promotion | |||||
| from ..core._imperative_rt.core2 import Const, apply | |||||
| from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder | from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder | ||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..core.tensor.array_method import _matmul | from ..core.tensor.array_method import _matmul | ||||
| @@ -17,7 +17,6 @@ from ..core._imperative_rt.core2 import ( | |||||
| apply, | apply, | ||||
| dtype_promotion, | dtype_promotion, | ||||
| ) | ) | ||||
| from ..core._imperative_rt.ops import SubgraphBuilder as _SubgraphBuilder | |||||
| from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed | from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed | ||||
| from ..core.ops import builtin | from ..core.ops import builtin | ||||
| from ..core.ops.builtin import ( | from ..core.ops.builtin import ( | ||||
| @@ -177,16 +176,6 @@ def conv1d( | |||||
| assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT" | assert compute_mode.lower() == "default" or compute_mode.name == "DEFAULT" | ||||
| assert inp.ndim == 3, "the input dimension of conv1d should be 3" | assert inp.ndim == 3, "the input dimension of conv1d should be 3" | ||||
| assert weight.ndim == 3, "the weight dimension of conv1d should be 3" | assert weight.ndim == 3, "the weight dimension of conv1d should be 3" | ||||
| if amp._enabled: | |||||
| compute_mode = "float32" | |||||
| inp, weight, bias = cast_tensors(inp, weight, bias) | |||||
| else: | |||||
| dtype = dtype_promotion(inp, weight) | |||||
| if inp.dtype != dtype: | |||||
| inp = inp.astype(dtype) | |||||
| if weight.dtype != dtype: | |||||
| weight = weight.astype(dtype) | |||||
| if bias is not None: | if bias is not None: | ||||
| assert bias.ndim == 3, "the bias dimension of conv1d should be 3" | assert bias.ndim == 3, "the bias dimension of conv1d should be 3" | ||||
| @@ -522,12 +511,6 @@ def local_conv2d( | |||||
| pad_h, pad_w = expand_hw(padding) | pad_h, pad_w = expand_hw(padding) | ||||
| dilate_h, dilate_w = expand_hw(dilation) | dilate_h, dilate_w = expand_hw(dilation) | ||||
| dtype = dtype_promotion(inp, weight) | |||||
| if inp.dtype != dtype: | |||||
| inp = inp.astype(dtype) | |||||
| if weight.dtype != dtype: | |||||
| weight = weight.astype(dtype) | |||||
| # local conv only support "dense" mode, but weight could contain group dimension. | # local conv only support "dense" mode, but weight could contain group dimension. | ||||
| op = builtin.GroupLocal( | op = builtin.GroupLocal( | ||||
| stride_h=stride_h, | stride_h=stride_h, | ||||
| @@ -433,6 +433,8 @@ WRAP_FUNC_PY35(reshape_cpp); | |||||
| WRAP_FUNC_PY35(adaptive_pool2d_cpp); | WRAP_FUNC_PY35(adaptive_pool2d_cpp); | ||||
| WRAP_FUNC_PY35(Const); | WRAP_FUNC_PY35(Const); | ||||
| WRAP_FUNC_PY35(astype_cpp); | WRAP_FUNC_PY35(astype_cpp); | ||||
| WRAP_FUNC_PY35(matmul_cpp); | |||||
| WRAP_FUNC_PY35(batched_matmul_cpp); | |||||
| WRAP_FUNC_PY35(convert_single_value_cpp); | WRAP_FUNC_PY35(convert_single_value_cpp); | ||||
| WRAP_FUNC_PY35(convert_inputs_cpp); | WRAP_FUNC_PY35(convert_inputs_cpp); | ||||
| WRAP_FUNC_PY35(astensor1d_cpp); | WRAP_FUNC_PY35(astensor1d_cpp); | ||||
| @@ -588,6 +590,8 @@ void init_tensor(py::module m) { | |||||
| MGE_PY_INTERFACE(adaptive_pool2d_cpp, adaptive_pool2d_cpp), | MGE_PY_INTERFACE(adaptive_pool2d_cpp, adaptive_pool2d_cpp), | ||||
| MGE_PY_INTERFACE(Const, Const), | MGE_PY_INTERFACE(Const, Const), | ||||
| MGE_PY_INTERFACE(astype_cpp, astype_cpp), | MGE_PY_INTERFACE(astype_cpp, astype_cpp), | ||||
| MGE_PY_INTERFACE(matmul_cpp, matmul_cpp), | |||||
| MGE_PY_INTERFACE(batched_matmul_cpp, batched_matmul_cpp), | |||||
| MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp), | MGE_PY_INTERFACE(convert_single_value_cpp, convert_single_value_cpp), | ||||
| MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp), | MGE_PY_INTERFACE(convert_inputs_cpp, convert_inputs_cpp), | ||||
| MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp), | MGE_PY_INTERFACE(astensor1d_cpp, astensor1d_cpp), | ||||
| @@ -1490,6 +1490,78 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) { | |||||
| return ret[0]; | return ret[0]; | ||||
| } | } | ||||
| py::object _matmul_cpp( | |||||
| py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2, | |||||
| py::handle transpose_a, py::handle transpose_b, py::handle compute_mode, | |||||
| py::handle format, py::handle profile, py::handle determistic, | |||||
| py::handle strategy, py::handle func) { | |||||
| if (enable_fastpath(inp1)) { | |||||
| ::megdnn::param::MatrixMul::ComputeMode mode = | |||||
| ::megdnn::param::MatrixMul::ComputeMode::DEFAULT; | |||||
| if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) { | |||||
| mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32; | |||||
| } | |||||
| ::megdnn::param::ExecutionPolicy::Strategy cstrategy; | |||||
| if (profile.cast<bool>()) { | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE; | |||||
| } else { | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC; | |||||
| } | |||||
| if (determistic.cast<bool>()) { | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE; | |||||
| } | |||||
| std::shared_ptr<OpDef> op = MatrixMul::make( | |||||
| transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode, | |||||
| ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX); | |||||
| py::object Op = py::cast(op); | |||||
| PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()}; | |||||
| py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3)); | |||||
| return ret[0]; | |||||
| } else { | |||||
| // fallback to traceable implementation | |||||
| return func( | |||||
| inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, | |||||
| strategy); | |||||
| } | |||||
| } | |||||
| py::object _batched_matmul_cpp( | |||||
| py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2, | |||||
| py::handle transpose_a, py::handle transpose_b, py::handle compute_mode, | |||||
| py::handle format, py::handle profile, py::handle determistic, | |||||
| py::handle strategy, py::handle func) { | |||||
| if (enable_fastpath(inp1)) { | |||||
| ::megdnn::param::MatrixMul::ComputeMode mode = | |||||
| ::megdnn::param::MatrixMul::ComputeMode::DEFAULT; | |||||
| if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) { | |||||
| mode = ::megdnn::param::MatrixMul::ComputeMode::FLOAT32; | |||||
| } | |||||
| ::megdnn::param::ExecutionPolicy::Strategy cstrategy; | |||||
| if (profile.cast<bool>()) { | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::PROFILE; | |||||
| } else { | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC; | |||||
| } | |||||
| if (determistic.cast<bool>()) { | |||||
| cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE; | |||||
| } | |||||
| std::shared_ptr<OpDef> op = BatchedMatrixMul::make( | |||||
| transpose_a.cast<bool>(), transpose_b.cast<bool>(), mode, | |||||
| ::megdnn::param::MatrixMul::Format::DEFAULT, cstrategy, UINT64_MAX); | |||||
| py::object Op = py::cast(op); | |||||
| PyObject* p[3] = {Op.ptr(), inp1.ptr(), inp2.ptr()}; | |||||
| py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 3)); | |||||
| return ret[0]; | |||||
| } else { | |||||
| // fallback to traceable implementation | |||||
| return func( | |||||
| inp1, inp2, dim1, dim2, transpose_a, transpose_b, compute_mode, format, | |||||
| strategy); | |||||
| } | |||||
| } | |||||
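Both C++ helpers translate the Python-side benchmark/deterministic flags into a megdnn execution strategy before building the op. A hedged Python sketch of that flag composition (the enum values are illustrative; the real bits live in megdnn's ExecutionPolicy::Strategy):

from enum import IntFlag

# Illustrative stand-in for megdnn::param::ExecutionPolicy::Strategy
class Strategy(IntFlag):
    HEURISTIC = 1
    PROFILE = 2
    REPRODUCIBLE = 4

def make_strategy(benchmark_kernel: bool, deterministic_kernel: bool) -> Strategy:
    s = Strategy.PROFILE if benchmark_kernel else Strategy.HEURISTIC
    if deterministic_kernel:
        s |= Strategy.REPRODUCIBLE
    return s

assert make_strategy(False, True) == Strategy.HEURISTIC | Strategy.REPRODUCIBLE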
| PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { | PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) { | ||||
| try { | try { | ||||
| return _make_shape_tuple(args[0]).release().ptr(); | return _make_shape_tuple(args[0]).release().ptr(); | ||||
| @@ -1574,6 +1646,28 @@ PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
| PYEXT17_TRANSLATE_EXC_RET(nullptr) | PYEXT17_TRANSLATE_EXC_RET(nullptr) | ||||
| } | } | ||||
| PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
| try { | |||||
| return _matmul_cpp( | |||||
| args[0], args[1], args[2], args[3], args[4], args[5], args[6], | |||||
| args[7], args[8], args[9], args[10], args[11]) | |||||
| .release() | |||||
| .ptr(); | |||||
| } | |||||
| PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
| } | |||||
| PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs) { | |||||
| try { | |||||
| return _batched_matmul_cpp( | |||||
| args[0], args[1], args[2], args[3], args[4], args[5], args[6], | |||||
| args[7], args[8], args[9], args[10], args[11]) | |||||
| .release() | |||||
| .ptr(); | |||||
| } | |||||
| PYEXT17_TRANSLATE_EXC_RET(nullptr) | |||||
| } | |||||
| PyObject* convert_single_value_cpp( | PyObject* convert_single_value_cpp( | ||||
| PyObject* self, PyObject* const* args, size_t nargs) { | PyObject* self, PyObject* const* args, size_t nargs) { | ||||
| try { | try { | ||||
| @@ -30,6 +30,10 @@ PyObject* Const(PyObject* self, PyObject* const* args, size_t nargs); | |||||
| PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs); | PyObject* astype_cpp(PyObject* self, PyObject* const* args, size_t nargs); | ||||
| PyObject* matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
| PyObject* batched_matmul_cpp(PyObject* self, PyObject* const* args, size_t nargs); | |||||
| PyObject* convert_single_value_cpp(PyObject* self, PyObject* const* args, size_t nargs); | PyObject* convert_single_value_cpp(PyObject* self, PyObject* const* args, size_t nargs); | ||||
| PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs); | PyObject* convert_inputs_cpp(PyObject* self, PyObject* const* args, size_t nargs); | ||||
| @@ -1,87 +0,0 @@ | |||||
| #include "megbrain/imperative/opr_utility.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | |||||
| #include "megbrain/imperative/utils/stats.h" | |||||
| #include "megbrain/opr/basic_arith.h" | |||||
| #include "megbrain/opr/blas.h" | |||||
| #include "megbrain/opr/utility.h" | |||||
| #include "../blob_manager_impl.h" | |||||
| #include "../dnn_op_helper.h" | |||||
| #include "../op_trait.h" | |||||
| namespace mgb { | |||||
| namespace imperative { | |||||
| namespace { | |||||
| namespace dot { | |||||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
| auto&& op = def.cast_final_safe<Dot>(); | |||||
| mgb_assert(inputs.size() == 2); | |||||
| OperatorNodeConfig config{op.make_name()}; | |||||
| return opr::Dot::make(inputs[0], inputs[1], config); | |||||
| } | |||||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
| auto comp_node = inputs[0]->comp_node(); | |||||
| using TensorND = megdnn::TensorND; | |||||
| SmallVector<TensorND> inp_tensornds; | |||||
| inp_tensornds.reserve(inputs.size()); | |||||
| DnnOprCaller<megdnn::Dot> dnn_opr(comp_node); | |||||
| for (unsigned i = 0; i < inputs.size(); ++i) { | |||||
| auto dnn_ten = inputs[i]->dnn_tensor(); | |||||
| inp_tensornds.push_back(dnn_ten); | |||||
| } | |||||
| TensorLayout oup_layout{inputs[0]->dtype()}; | |||||
| auto inp1_tensor = inputs[0]->dnn_tensor(); | |||||
| auto inp2_tensor = inputs[1]->dnn_tensor(); | |||||
| dnn_opr.op->deduce_layout(inp1_tensor.layout, inp2_tensor.layout, oup_layout); | |||||
| if (inputs[0]->layout().is_empty() || inputs[1]->layout().is_empty()) { | |||||
| DnnOprCaller<megdnn::Fill> fill_opr(comp_node); | |||||
| DeviceTensorND out = | |||||
| BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout); | |||||
| fill_opr.op->param() = 0; | |||||
| fill_opr.op->exec(out.as_megdnn(), {}); | |||||
| return {Tensor::make(out)}; | |||||
| } | |||||
| auto sz = dnn_opr.op->get_workspace_in_bytes( | |||||
| inp_tensornds[0].layout, inp_tensornds[1].layout, output_descs[0].layout); | |||||
| DeviceTensorND out_devtensor = | |||||
| BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout); | |||||
| TensorLayout w_layout({sz}, dtype::Byte()); | |||||
| auto dnn_wk = dnn_opr.create_workspace(w_layout); | |||||
| dnn_opr.op->exec( | |||||
| inp_tensornds[0], inp_tensornds[1], out_devtensor.as_megdnn(), dnn_wk); | |||||
| return {Tensor::make(out_devtensor)}; | |||||
| } | |||||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||||
| mgb_assert( | |||||
| inputs.size() == 2, "Dot expects 2 inputs; got %lu actually", | |||||
| inputs.size()); | |||||
| SmallVector<LogicalTensorDesc> dests(1); | |||||
| dests[0].layout = TensorLayout(TensorShape{1}, inputs[0].layout.dtype); | |||||
| dests[0].comp_node = inputs[0].comp_node; | |||||
| bool validated = inputs[0].layout.ndim != 0 && inputs[1].layout.ndim != 0; | |||||
| return {dests, validated}; | |||||
| } | |||||
| OP_TRAIT_REG(Dot, Dot, mgb::opr::Dot) | |||||
| .apply_on_var_node(apply_on_var_node) | |||||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||||
| .fallback(); | |||||
| } // namespace dot | |||||
| } // anonymous namespace | |||||
| } // namespace imperative | |||||
| } // namespace mgb | |||||
| @@ -0,0 +1,435 @@ | |||||
| #include <numeric> | |||||
| #include "../blob_manager_impl.h" | |||||
| #include "../dnn_op_helper.h" | |||||
| #include "../op_trait.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | |||||
| #include "megbrain/opr/blas.h" | |||||
| #include "../algo_chooser.h" | |||||
| namespace mgb { | |||||
| namespace imperative { | |||||
| namespace { | |||||
| namespace matrix_mul { | |||||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
| auto&& matmul = def.cast_final_safe<MatrixMul>(); | |||||
| mgb_assert(inputs.size() == 2); | |||||
| OperatorNodeConfig config{matmul.make_name()}; | |||||
| return opr::MatrixMul::make( | |||||
| inputs[0], inputs[1], matmul.param(), matmul.policy(), config); | |||||
| } | |||||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||||
| auto&& matmul = def.cast_final_safe<MatrixMul>(); | |||||
| auto layout1 = inputs[0].layout; | |||||
| auto layout2 = inputs[1].layout; | |||||
| size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | |||||
| if (dim1 == 0 || dim2 == 0) { | |||||
| return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false}; | |||||
| } | |||||
| if (matmul.transposeA) | |||||
| std::swap(layout1[0], layout1[1]); | |||||
| if (matmul.transposeB) | |||||
| std::swap(layout2[0], layout2[1]); | |||||
| mgb_assert(layout1[dim1 - 1] == layout2[0]); | |||||
| TensorLayout dst_layout(layout1.dtype); | |||||
| size_t ci = 0; | |||||
| for (size_t i = 0; i < dim1 - 1; i++) | |||||
| dst_layout[ci++] = layout1[i]; | |||||
| if (dim2 == 2) | |||||
| dst_layout[ci++] = layout2[1]; | |||||
| dst_layout.ndim = ci; | |||||
| dst_layout.init_contiguous_stride(); | |||||
| SmallVector<LogicalTensorDesc> out_descs(1u); | |||||
| out_descs[0] = {dst_layout, inputs[0].comp_node}; | |||||
| return {out_descs, true}; | |||||
| } | |||||
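For reference, the shape rule implemented above keeps every leading dimension of the first operand and appends the last dimension of the second operand when it is 2-D. A hedged restatement in plain Python (not library code):

def matmul_out_shape(shape1, shape2, transpose_a=False, transpose_b=False):
    s1, s2 = list(shape1), list(shape2)
    if transpose_a and len(s1) == 2:
        s1[0], s1[1] = s1[1], s1[0]
    if transpose_b and len(s2) == 2:
        s2[0], s2[1] = s2[1], s2[0]
    assert s1[-1] == s2[0], "inner dimensions must match"
    out = s1[:-1]                  # keep all leading dims of the first input
    if len(s2) == 2:
        out.append(s2[1])          # a 1-D second input contributes no column dim
    return tuple(out)

assert matmul_out_shape((5, 4, 3), (3, 7)) == (5, 4, 7)
assert matmul_out_shape((5, 4, 3), (3,)) == (5, 4)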
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
| auto&& matmul = def.cast_final_safe<MatrixMul>(); | |||||
| auto&& cn = inputs[0]->comp_node(); | |||||
| using TensorND = megdnn::TensorND; | |||||
| SmallVector<TensorND> inp_tensornds(inputs.size()); | |||||
| TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout(); | |||||
| // only matters when layout1 has dim 2 | |||||
| if (matmul.transposeA) | |||||
| std::swap(layout1.shape[0], layout1.shape[1]); | |||||
| // only matters when layout2 has dim 2 | |||||
| if (matmul.transposeB) | |||||
| std::swap(layout2.shape[0], layout2.shape[1]); | |||||
| size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | |||||
| TensorLayout real_dst_layout(layout1.dtype); | |||||
| if (validated) { | |||||
| real_dst_layout = output_descs[0].layout; | |||||
| } else { | |||||
| size_t ri = 0; | |||||
| for (size_t i = 0; i < dim1 - 2; i++) | |||||
| real_dst_layout[ri++] = layout1[i]; | |||||
| real_dst_layout[ri++] = layout1[dim1 - 2]; | |||||
| if (dim2 == 2) | |||||
| real_dst_layout[ri++] = layout2[dim2 - 1]; | |||||
| real_dst_layout.ndim = ri; | |||||
| real_dst_layout.init_contiguous_stride(); | |||||
| } | |||||
| if (dim1 == 0 || dim2 == 0 || layout1[layout1.ndim - 1] == 0) { | |||||
| DeviceTensorND out = | |||||
| BlobManager::inst()->alloc_workspace_with_defrag(cn, real_dst_layout); | |||||
| if (!out.empty()) { | |||||
| dev_tensor_memset(out, 0); | |||||
| } | |||||
| return {Tensor::make(out)}; | |||||
| } | |||||
| TensorLayout layout_a = layout1, layout_b = layout2; | |||||
| if (dim1 == 1) { | |||||
| layout_a.add_axis_cont_inplace(0); | |||||
| inp_tensornds[0] = inputs[0]->dnn_tensor(); | |||||
| inp_tensornds[0].layout = layout_a; | |||||
| } else if (dim1 > 2) { | |||||
| size_t batch = std::accumulate( | |||||
| layout1.shape, layout1.shape + dim1 - 1, (size_t)1, | |||||
| std::multiplies<size_t>()); | |||||
| TensorShape na = TensorShape{batch, layout1[dim1 - 1]}; | |||||
| auto inp1 = inputs[0]; | |||||
| if (!layout1.try_reshape(layout_a, na)) { | |||||
| inp1 = Tensor::make(inp1->blob(), inp1->offset(), layout1); | |||||
| inp1->to_contiguous_inplace(); | |||||
| layout1 = inp1->layout(); | |||||
| layout_a = TensorLayout{{batch, layout1[dim1 - 1]}, layout1.dtype}; | |||||
| } | |||||
| layout_a.init_contiguous_stride(); | |||||
| inp_tensornds[0] = inp1->dnn_tensor(); | |||||
| inp_tensornds[0].layout = layout_a; | |||||
| } else { | |||||
| inp_tensornds[0] = inputs[0]->dnn_tensor(); | |||||
| } | |||||
| if (dim2 == 1) { | |||||
| layout_b.add_axis_inplace(1, 1, 1); | |||||
| inp_tensornds[1] = inputs[1]->dnn_tensor(); | |||||
| inp_tensornds[1].layout = layout_b; | |||||
| } else { | |||||
| inp_tensornds[1] = inputs[1]->dnn_tensor(); | |||||
| } | |||||
| TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, layout_a.dtype); | |||||
| dst_layout.init_contiguous_stride(); | |||||
| DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn); | |||||
| dnn_opr.op->param() = matmul.param(); | |||||
| DeviceTensorND out = | |||||
| BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); | |||||
| size_t sz = setup_algo<megdnn::MatrixMul>( | |||||
| {layout_a, layout_b, dst_layout}, dnn_opr.op.get(), 0, false, false, cn, | |||||
| matmul.policy(), false); | |||||
| TensorLayout w_layout({sz}, dtype::Byte()); | |||||
| auto dnn_wk = dnn_opr.create_workspace(w_layout); | |||||
| dnn_opr.op->exec(inp_tensornds[0], inp_tensornds[1], out.as_megdnn(), dnn_wk); | |||||
| return {Tensor::make(out.sub(SubTensorSpec::make_from_layout(real_dst_layout)))}; | |||||
| } | |||||
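The n-D case above folds all leading dimensions of the first operand into a single row dimension, runs one 2-D MatrixMul, and views the result back. A numpy illustration of the same trick (assumption: this mirrors the reshape logic, it is not the kernel itself):

import numpy as np

def folded_matmul(a, b):
    lead = a.shape[:-1]                 # leading dims of the first input
    a2 = a.reshape(-1, a.shape[-1])     # collapse to 2-D
    out = a2 @ b                        # the single 2-D matmul kernel call
    return out.reshape(*lead, b.shape[-1])

x = np.random.rand(5, 4, 3)
w = np.random.rand(3, 7)
assert np.allclose(folded_matmul(x, w), x @ w)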
| SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||||
| SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size()); | |||||
| layout_checker[0] = layout_checker[1] = [](const TensorLayout& layout) { | |||||
| return layout.is_contiguous(); | |||||
| }; | |||||
| return layout_checker; | |||||
| } | |||||
| OP_TRAIT_REG(MatrixMul, MatrixMul) | |||||
| .apply_on_var_node(apply_on_var_node) | |||||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||||
| .get_input_layout_constraint(get_input_layout_constraint) | |||||
| .fallback(); | |||||
| } // namespace matrix_mul | |||||
| } // namespace | |||||
| namespace { | |||||
| namespace batched_matrix_mul { | |||||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
| auto&& matmul = def.cast_final_safe<BatchedMatrixMul>(); | |||||
| mgb_assert(inputs.size() == 2); | |||||
| OperatorNodeConfig config{matmul.make_name()}; | |||||
| return opr::BatchedMatrixMul::make( | |||||
| inputs[0], inputs[1], matmul.param(), matmul.policy(), config); | |||||
| } | |||||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||||
| auto&& matmul = def.cast_final_safe<BatchedMatrixMul>(); | |||||
| TensorLayout layout1 = inputs[0].layout, layout2 = inputs[1].layout; | |||||
| size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | |||||
| if (dim1 == 0 || dim2 == 0) { | |||||
| return {{{TensorLayout(layout1.dtype), inputs[0].comp_node}}, false}; | |||||
| } | |||||
| if (matmul.transposeA) | |||||
| std::swap(layout1[dim1 - 1], layout1[dim1 - 2]); | |||||
| if (matmul.transposeB) | |||||
| std::swap(layout2[dim2 - 1], layout2[dim2 - 2]); | |||||
| TensorLayout dst_layout(layout1.dtype); | |||||
| size_t di = 0; | |||||
| if (dim1 > dim2) { | |||||
| for (size_t i = 0; i < dim1 - 2; i++) | |||||
| dst_layout[di++] = layout1[i]; | |||||
| } else { | |||||
| for (size_t i = 0; i < dim2 - 2; i++) | |||||
| dst_layout[di++] = layout2[i]; | |||||
| } | |||||
| if (dim1 > 1) | |||||
| dst_layout[di++] = layout1[dim1 - 2]; | |||||
| if (dim2 > 1) | |||||
| dst_layout[di++] = layout2[dim2 - 1]; | |||||
| dst_layout.ndim = di; | |||||
| dst_layout.init_contiguous_stride(); | |||||
| SmallVector<LogicalTensorDesc> out_descs(1u); | |||||
| out_descs[0] = {dst_layout, inputs[0].comp_node}; | |||||
| return {out_descs, true}; | |||||
| } | |||||
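The batched shape rule above takes the batch dimensions from the higher-rank operand and then appends the (m, n) tail. A hedged plain-Python restatement:

def batched_matmul_out_shape(shape1, shape2, transpose_a=False, transpose_b=False):
    s1, s2 = list(shape1), list(shape2)
    if transpose_a and len(s1) >= 2:
        s1[-1], s1[-2] = s1[-2], s1[-1]
    if transpose_b and len(s2) >= 2:
        s2[-1], s2[-2] = s2[-2], s2[-1]
    out = s1[:-2] if len(s1) > len(s2) else s2[:-2]   # batch dims
    if len(s1) > 1:
        out.append(s1[-2])
    if len(s2) > 1:
        out.append(s2[-1])
    return tuple(out)

assert batched_matmul_out_shape((8, 5, 3), (3, 7)) == (8, 5, 7)
assert batched_matmul_out_shape((3,), (8, 3, 7)) == (8, 7)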
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
| auto&& matmul = def.cast_final_safe<BatchedMatrixMul>(); | |||||
| auto&& cn = inputs[0]->comp_node(); | |||||
| TensorLayout layout1 = inputs[0]->layout(), layout2 = inputs[1]->layout(); | |||||
| size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | |||||
| bool remove_row = false, remove_col = false; | |||||
| if (dim1 == 1) { | |||||
| dim1 = 2; | |||||
| remove_row = true; | |||||
| } | |||||
| if (dim2 == 1) { | |||||
| dim2 = 2; | |||||
| remove_col = true; | |||||
| } | |||||
| if (remove_row) | |||||
| layout1.add_axis_cont_inplace(0); | |||||
| if (remove_col) | |||||
| layout2.add_axis_inplace(1, 1, 1); | |||||
| TensorShape tshp, batch_shp; | |||||
| size_t j = 0; | |||||
| if (dim1 > dim2) { | |||||
| for (size_t i = 0; i < dim1 - 2; i++) | |||||
| tshp[j++] = layout1.shape[i]; | |||||
| batch_shp = tshp; | |||||
| batch_shp.ndim = dim1 - 2; | |||||
| tshp[j++] = layout2[layout2.ndim - 2]; | |||||
| tshp[j++] = layout2[layout2.ndim - 1]; | |||||
| tshp.ndim = j; | |||||
| layout2 = layout2.broadcast(tshp); | |||||
| } | |||||
| if (dim2 > dim1) { | |||||
| for (size_t i = 0; i < dim2 - 2; i++) | |||||
| tshp[j++] = layout2.shape[i]; | |||||
| batch_shp = tshp; | |||||
| batch_shp.ndim = dim2 - 2; | |||||
| tshp[j++] = layout1[layout1.ndim - 2]; | |||||
| tshp[j++] = layout1[layout1.ndim - 1]; | |||||
| tshp.ndim = j; | |||||
| layout1 = layout1.broadcast(tshp); | |||||
| } | |||||
| if (dim1 == dim2) { | |||||
| for (size_t i = 0; i < dim1 - 2; i++) | |||||
| tshp[j++] = layout1.shape[i]; | |||||
| batch_shp = tshp; | |||||
| batch_shp.ndim = dim1 - 2; | |||||
| } | |||||
| TensorShape shp1 = batch_shp, shp2 = batch_shp; | |||||
| shp1.ndim += 2; | |||||
| shp2.ndim += 2; | |||||
| size_t maxdim = dim1 > dim2 ? dim1 : dim2; | |||||
| size_t nbatch = batch_shp[0]; | |||||
| auto inp1 = inputs[0], inp2 = inputs[1]; | |||||
| if (maxdim > 3) { | |||||
| nbatch = std::accumulate( | |||||
| batch_shp.shape, batch_shp.shape + batch_shp.ndim, (size_t)1, | |||||
| std::multiplies<size_t>()); | |||||
| TensorLayout layout_a; | |||||
| TensorShape nl1 = TensorShape( | |||||
| {nbatch, layout1[layout1.ndim - 2], layout1[layout1.ndim - 1]}); | |||||
| if (!layout1.try_reshape(layout_a, nl1)) { | |||||
| inp1 = Tensor::make(inputs[0]->blob(), inputs[0]->offset(), layout1); | |||||
| inp1->to_contiguous_inplace(); | |||||
| layout1 = inp1->layout(); | |||||
| } | |||||
| layout1 = layout_a; | |||||
| TensorShape nl2 = TensorShape( | |||||
| {nbatch, layout2[layout2.ndim - 2], layout2[layout2.ndim - 1]}); | |||||
| if (!layout2.try_reshape(layout_a, nl2)) { | |||||
| inp2 = Tensor::make(inputs[1]->blob(), inputs[1]->offset(), layout2); | |||||
| inp2->to_contiguous_inplace(); | |||||
| layout2 = inp2->layout(); | |||||
| } | |||||
| layout2 = layout_a; | |||||
| } | |||||
| TensorLayout dst_layout( | |||||
| {nbatch, matmul.transposeA ? layout1[2] : layout1[1], | |||||
| matmul.transposeB ? layout2[1] : layout2[2]}, | |||||
| layout1.dtype); | |||||
| dst_layout.init_contiguous_stride(); | |||||
| if (dim1 == 0 || dim2 == 0 || layout1[layout1.ndim - 1] == 0) { | |||||
| DeviceTensorND out = | |||||
| BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); | |||||
| if (!out.empty()) { | |||||
| dev_tensor_memset(out, 0); | |||||
| } | |||||
| return {Tensor::make(out)}; | |||||
| } | |||||
| using TensorND = megdnn::TensorND; | |||||
| TensorND inp_nd1 = inp1->dnn_tensor(); | |||||
| inp_nd1.layout = layout1; | |||||
| TensorND inp_nd2 = inp2->dnn_tensor(); | |||||
| inp_nd2.layout = layout2; | |||||
| DeviceTensorND out = | |||||
| BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); | |||||
| DnnOprCaller<megdnn::BatchedMatrixMul> dnn_opr(cn); | |||||
| dnn_opr.op->param() = matmul.param(); | |||||
| size_t sz = setup_algo<megdnn::BatchedMatrixMul>( | |||||
| {layout1, layout2, dst_layout}, dnn_opr.op.get(), 0, false, false, cn, | |||||
| matmul.policy(), false); | |||||
| TensorLayout w_layout({sz}, dtype::Byte()); | |||||
| auto dnn_wk = dnn_opr.create_workspace(w_layout); | |||||
| dnn_opr.op->exec(inp_nd1, inp_nd2, out.as_megdnn(), dnn_wk); | |||||
| shp1[shp1.ndim - 2] = dst_layout[dst_layout.ndim - 2]; | |||||
| shp1[shp1.ndim - 1] = dst_layout[dst_layout.ndim - 1]; | |||||
| if (maxdim > 3) { | |||||
| dst_layout = dst_layout.reshape(shp1); | |||||
| } | |||||
| if (remove_row) { | |||||
| dst_layout = dst_layout.remove_axis(maxdim - 2); | |||||
| } | |||||
| if (remove_col) { | |||||
| dst_layout = dst_layout.remove_axis(maxdim - 1); | |||||
| } | |||||
| return {Tensor::make(out.sub(SubTensorSpec::make_from_layout(dst_layout)))}; | |||||
| } | |||||
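When more than one batch dimension is present (maxdim > 3), the code above collapses all batch dimensions into a single one, broadcasts the lower-rank operand, runs the 3-D batched kernel once, and reshapes the result back. A numpy sketch of that idea (illustrative only):

import numpy as np

def folded_batched_matmul(a, b):
    batch = a.shape[:-2]
    a3 = a.reshape(-1, *a.shape[-2:])                        # fold batch dims
    b3 = np.broadcast_to(b, batch + b.shape[-2:]).reshape(-1, *b.shape[-2:])
    out = np.einsum("bik,bkj->bij", a3, b3)                  # the 3-D batched kernel call
    return out.reshape(*batch, a.shape[-2], b.shape[-1])

x = np.random.rand(2, 3, 5, 4)
y = np.random.rand(4, 6)
assert np.allclose(folded_batched_matmul(x, y), x @ y)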
| SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint( | |||||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||||
| SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size()); | |||||
| layout_checker[0] = layout_checker[1] = [](const TensorLayout& layout) { | |||||
| return layout.is_contiguous(); | |||||
| }; | |||||
| return layout_checker; | |||||
| } | |||||
| OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) | |||||
| .apply_on_var_node(apply_on_var_node) | |||||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||||
| .get_input_layout_constraint(get_input_layout_constraint) | |||||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||||
| .fallback(); | |||||
| } // namespace batched_matrix_mul | |||||
| } // namespace | |||||
| namespace { | |||||
| namespace dot { | |||||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
| auto&& op = def.cast_final_safe<Dot>(); | |||||
| mgb_assert(inputs.size() == 2); | |||||
| OperatorNodeConfig config{op.make_name()}; | |||||
| return opr::Dot::make(inputs[0], inputs[1], config); | |||||
| } | |||||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||||
| auto comp_node = inputs[0]->comp_node(); | |||||
| using TensorND = megdnn::TensorND; | |||||
| SmallVector<TensorND> inp_tensornds; | |||||
| inp_tensornds.reserve(inputs.size()); | |||||
| DnnOprCaller<megdnn::Dot> dnn_opr(comp_node); | |||||
| for (unsigned i = 0; i < inputs.size(); ++i) { | |||||
| auto dnn_ten = inputs[i]->dnn_tensor(); | |||||
| inp_tensornds.push_back(dnn_ten); | |||||
| } | |||||
| TensorLayout oup_layout{inputs[0]->dtype()}; | |||||
| auto inp1_tensor = inputs[0]->dnn_tensor(); | |||||
| auto inp2_tensor = inputs[1]->dnn_tensor(); | |||||
| dnn_opr.op->deduce_layout(inp1_tensor.layout, inp2_tensor.layout, oup_layout); | |||||
| if (inputs[0]->layout().is_empty() || inputs[1]->layout().is_empty()) { | |||||
| DeviceTensorND out = | |||||
| BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout); | |||||
| if (!out.empty()) { | |||||
| dev_tensor_memset(out, 0); | |||||
| } | |||||
| return {Tensor::make(out)}; | |||||
| } | |||||
| auto sz = dnn_opr.op->get_workspace_in_bytes( | |||||
| inp_tensornds[0].layout, inp_tensornds[1].layout, output_descs[0].layout); | |||||
| DeviceTensorND out_devtensor = | |||||
| BlobManager::inst()->alloc_workspace_with_defrag(comp_node, oup_layout); | |||||
| TensorLayout w_layout({sz}, dtype::Byte()); | |||||
| auto dnn_wk = dnn_opr.create_workspace(w_layout); | |||||
| dnn_opr.op->exec( | |||||
| inp_tensornds[0], inp_tensornds[1], out_devtensor.as_megdnn(), dnn_wk); | |||||
| return {Tensor::make(out_devtensor)}; | |||||
| } | |||||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||||
| mgb_assert( | |||||
| inputs.size() == 2, "Dot expects 2 inputs; got %lu actually", | |||||
| inputs.size()); | |||||
| SmallVector<LogicalTensorDesc> dests(1); | |||||
| dests[0].layout = TensorLayout(TensorShape{1}, inputs[0].layout.dtype); | |||||
| dests[0].comp_node = inputs[0].comp_node; | |||||
| bool validated = inputs[0].layout.ndim != 0 && inputs[1].layout.ndim != 0; | |||||
| return {dests, validated}; | |||||
| } | |||||
| OP_TRAIT_REG(Dot, Dot, mgb::opr::Dot) | |||||
| .apply_on_var_node(apply_on_var_node) | |||||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||||
| .fallback(); | |||||
| } // namespace dot | |||||
| } // anonymous namespace | |||||
| } // namespace imperative | |||||
| } // namespace mgb | |||||
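A hedged usage sketch of how these kernels are reached from user code; which one runs depends only on the input ranks (the shapes here are arbitrary examples):

import megengine
import megengine.functional as F

v = megengine.tensor([1.0, 2.0, 3.0])   # rank 1
m = F.ones((3, 4))                      # rank 2
b = F.ones((5, 4, 2))                   # rank 3

print(F.matmul(v, v).shape)   # 1-D x 1-D -> Dot, scalar result
print(F.matmul(v, m).shape)   # 1-D x 2-D -> MatrixMul, (4,)
print(F.matmul(m, b).shape)   # 2-D x 3-D -> BatchedMatrixMul, (5, 3, 2)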
| @@ -123,7 +123,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| inputs[0]->dev_tensor().reset(inputs[0]->dev_tensor().storage(), src); | inputs[0]->dev_tensor().reset(inputs[0]->dev_tensor().storage(), src); | ||||
| auto mode = op_def.param().mode; | auto mode = op_def.param().mode; | ||||
| DnnOprCaller<megdnn::Fill> fill_op(comp_node); | |||||
| if (!keepdim && src.ndim > 1) { | if (!keepdim && src.ndim > 1) { | ||||
| layout.remove_axis_inplace(axis); | layout.remove_axis_inplace(axis); | ||||
| @@ -135,12 +134,12 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| switch (mode) { | switch (mode) { | ||||
| case Reduce::Mode::SUM: | case Reduce::Mode::SUM: | ||||
| if (!out.empty()) { | if (!out.empty()) { | ||||
| fill_op.op->param() = 0; | |||||
| fill_op.op->exec(out.as_megdnn(), {}); | |||||
| dev_tensor_memset(out, 0); | |||||
| } | } | ||||
| break; | break; | ||||
| case Reduce::Mode::PRODUCT: | case Reduce::Mode::PRODUCT: | ||||
| if (!out.empty()) { | if (!out.empty()) { | ||||
| DnnOprCaller<megdnn::Fill> fill_op(comp_node); | |||||
| fill_op.op->param() = 1; | fill_op.op->param() = 1; | ||||
| fill_op.op->exec(out.as_megdnn(), {}); | fill_op.op->exec(out.as_megdnn(), {}); | ||||
| } | } | ||||
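The SUM branch above now zero-fills the empty-input result with a plain device memset instead of launching a Fill op, while PRODUCT still uses Fill with 1. The values themselves are the usual reduction identities, e.g.:

import numpy as np

# reduction over an empty axis yields the identity element of the reduction
assert np.sum(np.empty((0, 4)), axis=0).tolist() == [0.0, 0.0, 0.0, 0.0]
assert np.prod(np.empty((0, 4)), axis=0).tolist() == [1.0, 1.0, 1.0, 1.0]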
| @@ -319,34 +319,6 @@ OP_TRAIT_REG(BatchConvBias, BatchConvBias) | |||||
| } // namespace batch_conv_bias | } // namespace batch_conv_bias | ||||
| } // namespace | } // namespace | ||||
| namespace { | |||||
| namespace matrix_mul { | |||||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
| auto&& matmul = static_cast<const MatrixMul&>(def); | |||||
| mgb_assert(inputs.size() == 2); | |||||
| OperatorNodeConfig config{matmul.make_name()}; | |||||
| return opr::MatrixMul::make( | |||||
| inputs[0], inputs[1], matmul.param(), matmul.policy(), config); | |||||
| } | |||||
| OP_TRAIT_REG(MatrixMul, MatrixMul).apply_on_var_node(apply_on_var_node).fallback(); | |||||
| } // namespace matrix_mul | |||||
| } // namespace | |||||
| namespace { | |||||
| namespace batched_matrix_mul { | |||||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
| auto&& matmul = static_cast<const BatchedMatrixMul&>(def); | |||||
| mgb_assert(inputs.size() == 2); | |||||
| OperatorNodeConfig config{matmul.make_name()}; | |||||
| return opr::BatchedMatrixMul::make( | |||||
| inputs[0], inputs[1], matmul.param(), matmul.policy(), config); | |||||
| } | |||||
| OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) | |||||
| .apply_on_var_node(apply_on_var_node) | |||||
| .fallback(); | |||||
| } // namespace batched_matrix_mul | |||||
| } // namespace | |||||
| namespace { | namespace { | ||||
| namespace argsort { | namespace argsort { | ||||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | ||||
| @@ -183,6 +183,57 @@ ValueRefList convolution_rule(const OpDef& op, Span<ValueRef> inputs) { | |||||
| return imperative::apply(op, converted); | return imperative::apply(op, converted); | ||||
| } | } | ||||
| ValueRefList matmul_rule(const OpDef& op, Span<ValueRef> inputs) { | |||||
| auto&& conv_op = const_cast<MatrixMul&>(op.cast_final_safe<MatrixMul>()); | |||||
| SmallVector<DType> dtypes = get_value_dtypes(inputs); | |||||
| mgb::DType target_dtype; | |||||
| if (DTypePromoteCfg::amp_dtype_autocast_enabled) { | |||||
| conv_op.compute_mode = MatrixMul::ComputeMode::FLOAT32; | |||||
| target_dtype = DTypePromoteCfg::amp_low_prec_dtype; | |||||
| } else { | |||||
| target_dtype = get_promoted_dtype(dtypes); | |||||
| } | |||||
| ValueRefList converted(inputs.size()); | |||||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||||
| if (dtypes[i] != target_dtype) { | |||||
| converted[i] = imperative::apply( | |||||
| ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; | |||||
| } else { | |||||
| converted[i] = inputs[i]; | |||||
| } | |||||
| } | |||||
| return imperative::apply(op, converted); | |||||
| } | |||||
| ValueRefList batch_matmul_rule(const OpDef& op, Span<ValueRef> inputs) { | |||||
| auto&& conv_op = | |||||
| const_cast<BatchedMatrixMul&>(op.cast_final_safe<BatchedMatrixMul>()); | |||||
| SmallVector<DType> dtypes = get_value_dtypes(inputs); | |||||
| mgb::DType target_dtype; | |||||
| if (DTypePromoteCfg::amp_dtype_autocast_enabled) { | |||||
| conv_op.compute_mode = BatchedMatrixMul::ComputeMode::FLOAT32; | |||||
| target_dtype = DTypePromoteCfg::amp_low_prec_dtype; | |||||
| } else { | |||||
| target_dtype = get_promoted_dtype(dtypes); | |||||
| } | |||||
| ValueRefList converted(inputs.size()); | |||||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||||
| if (dtypes[i] != target_dtype) { | |||||
| converted[i] = imperative::apply( | |||||
| ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; | |||||
| } else { | |||||
| converted[i] = inputs[i]; | |||||
| } | |||||
| } | |||||
| return imperative::apply(op, converted); | |||||
| } | |||||
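matmul_rule and batch_matmul_rule replace the Python-side casts removed from _matmul, conv1d and local_conv2d: under AMP autocast the op computes in float32 while inputs are cast to the low-precision AMP dtype, otherwise inputs are promoted to a common dtype. A hedged sketch of the decision, with numpy promotion standing in for megdnn's and illustrative names:

import numpy as np

def promote_matmul_inputs(dtypes, amp_enabled, amp_low_prec=np.float16):
    if amp_enabled:
        compute_mode = "float32"            # accumulate in float32
        target = amp_low_prec               # cast inputs to the AMP dtype
    else:
        compute_mode = "default"
        target = np.result_type(*dtypes)    # standard dtype promotion
    return compute_mode, target

assert promote_matmul_inputs([np.float16, np.float32], amp_enabled=False) == (
    "default",
    np.dtype("float32"),
)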
| // differ from Convolution, ConvolutionBackwardData is used in both | // differ from Convolution, ConvolutionBackwardData is used in both | ||||
| // functional.conv_transpose2d and quantize.conv_transpose2d | // functional.conv_transpose2d and quantize.conv_transpose2d | ||||
| ValueRefList convolution_backward_rule(const OpDef& op, Span<ValueRef> inputs) { | ValueRefList convolution_backward_rule(const OpDef& op, Span<ValueRef> inputs) { | ||||
| @@ -259,8 +310,11 @@ struct DTypePromoteRuleRegistry { | |||||
| DTypePromoteRuleRegistry() { | DTypePromoteRuleRegistry() { | ||||
| register_dtype_promote_rule<Elemwise>(elemwise_rule); | register_dtype_promote_rule<Elemwise>(elemwise_rule); | ||||
| register_dtype_promote_rule<Concat>(naive_promote_rule); | register_dtype_promote_rule<Concat>(naive_promote_rule); | ||||
| register_dtype_promote_rule<GroupLocal>(naive_promote_rule); | |||||
| register_dtype_promote_rule<Reduce>(reduce_rule); | register_dtype_promote_rule<Reduce>(reduce_rule); | ||||
| register_dtype_promote_rule<Convolution>(convolution_rule); | register_dtype_promote_rule<Convolution>(convolution_rule); | ||||
| register_dtype_promote_rule<MatrixMul>(matmul_rule); | |||||
| register_dtype_promote_rule<BatchedMatrixMul>(batch_matmul_rule); | |||||
| register_dtype_promote_rule<ConvolutionBackwardData>(convolution_backward_rule); | register_dtype_promote_rule<ConvolutionBackwardData>(convolution_backward_rule); | ||||
| register_dtype_promote_rule<BatchNorm>(batch_norm_rule); | register_dtype_promote_rule<BatchNorm>(batch_norm_rule); | ||||
| register_dtype_promote_rule<Convolution3D>(naive_promote_rule); | register_dtype_promote_rule<Convolution3D>(naive_promote_rule); | ||||