| @@ -13,6 +13,7 @@ from .._imperative_rt.core2 import ( | |||
| astype_cpp, | |||
| batched_matmul_cpp, | |||
| broadcast_cpp, | |||
| expand_dims_cpp, | |||
| getitem_cpp, | |||
| matmul_cpp, | |||
| reshape_cpp, | |||
| @@ -62,7 +63,6 @@ def _matmul( | |||
| assert dim1 > 0 and dim2 > 0 | |||
| maxdim = dim1 if dim1 > dim2 else dim2 | |||
| compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode) | |||
| if dim1 == 1 and dim2 == 1: # dispatch to Dot | |||
| (result,) = apply(builtin.Dot(), inp1, inp2) | |||
| return result | |||
| @@ -72,34 +72,44 @@ def _matmul( | |||
| # 2x2 | |||
| # nx1(transpose_a=False), n>=3 | |||
| # nx2(transpose_a=False), n>=3 | |||
| return matmul_cpp( | |||
| inp1, | |||
| inp2, | |||
| dim1, | |||
| dim2, | |||
| ret = matmul_cpp( | |||
| inp1 if dim1 > 1 else expand_dims_cpp(inp1, 0), | |||
| inp2 if dim2 > 1 else expand_dims_cpp(inp2, -1), | |||
| max(dim1, 2), | |||
| max(dim2, 2), | |||
| transpose_a, | |||
| transpose_b, | |||
| compute_mode, | |||
| _config._benchmark_kernel, | |||
| _config._deterministic_kernel, | |||
| ) | |||
| if dim1 == 1: | |||
| ret = squeeze_cpp(ret, -2) | |||
| elif dim2 == 1: | |||
| ret = squeeze_cpp(ret, -1) | |||
| return ret | |||
| else: # dispath to BatchedMatrixMul | |||
| # nx1(transpose_a=True), n>=3 | |||
| # nx2(transpose_a=True), n>=3 | |||
| # nxm,n>=3,m>=3 | |||
| # 1xm,m>=3 | |||
| # 2xm,m>=3 | |||
| return batched_matmul_cpp( | |||
| inp1, | |||
| inp2, | |||
| dim1, | |||
| dim2, | |||
| ret = batched_matmul_cpp( | |||
| inp1 if dim1 > 1 else expand_dims_cpp(inp1, 0), | |||
| inp2 if dim2 > 1 else expand_dims_cpp(inp2, -1), | |||
| max(dim1, 2), | |||
| max(dim2, 2), | |||
| transpose_a, | |||
| transpose_b, | |||
| compute_mode, | |||
| _config._benchmark_kernel, | |||
| _config._deterministic_kernel, | |||
| ) | |||
| if dim1 == 1: | |||
| ret = squeeze_cpp(ret, -2) | |||
| elif dim2 == 1: | |||
| ret = squeeze_cpp(ret, -1) | |||
| return ret | |||
| def _unary_elwise(mode): | |||
| @@ -87,6 +87,136 @@ ValueRef make_empty_tensor( | |||
| return res; | |||
| } | |||
| std::optional<ValueRefList> matrix_mul_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| auto&& matmul = op.cast_final_safe<MatrixMul>(); | |||
| size_t dimA = matmul.dimA; | |||
| size_t dimB = matmul.dimB; | |||
| auto&& param = matmul.param(); | |||
| auto&& policy = matmul.policy(); | |||
| mgb_assert(inputs.size() == 2); | |||
| std::array<ValueRef, 2> inps, input_shapes; | |||
| for (size_t i = 0; i < 2; ++i) { | |||
| if (inputs_require_grad[i ^ 1]) { | |||
| inps[i] = inputs[i]; | |||
| input_shapes[i] = get_shape(inputs[i]); | |||
| } | |||
| } | |||
| auto maker = CustomGradMaker(backward, inputs.size()); | |||
| maker.output_size(1).output_captured(0, false); | |||
| maker.backward([inps_ = std::move(inps), input_shapes_ = std::move(input_shapes), | |||
| param, policy, dimA, dimB](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| SmallVector<ValueRef> ret(2); | |||
| if (!grad) { | |||
| return ret; | |||
| } | |||
| size_t dimG = std::max(dimA, dimB); | |||
| if (inps_[1]) { | |||
| if (param.transposeA) { | |||
| // A^T(2) @ B(2) = G(2), A'(2) = B'(2) @ G'^T(2) -> MatrixMul | |||
| auto&& grad_op = MatrixMul::make( | |||
| param.transposeB, true, param.compute_mode, param.format, | |||
| policy.strategy, policy.workspace_limit, dimB, dimG); | |||
| ret[0] = imperative::apply(*grad_op, inps_[1], grad)[0]; | |||
| } else { | |||
| // A(>=2) @ B(2) = G(>=2), A'(>=2) = G'(>=2) @ B(2) -> MatrixMul | |||
| auto&& grad_op = MatrixMul::make( | |||
| false, !param.transposeB, param.compute_mode, param.format, | |||
| policy.strategy, policy.workspace_limit, dimG, dimB); | |||
| ret[0] = imperative::apply(*grad_op, grad, inps_[1])[0]; | |||
| } | |||
| } | |||
| if (inps_[0]) { | |||
| if (param.transposeB) { | |||
| // A(>=2) @ B^T(2) = G(>=2), B'(2) = G'^T(>=2) @ A(>=2) -> MatrixMul | |||
| // (specialized) | |||
| auto&& grad_op = MatrixMul::make( | |||
| true, param.transposeA, param.compute_mode, param.format, | |||
| policy.strategy, policy.workspace_limit, dimG, dimA); | |||
| ret[1] = imperative::apply(*grad_op, grad, inps_[0])[0]; | |||
| } else { | |||
| // A(>=2) @ B(2) = G(>=2), B'(2) = G'(>=2) @ A(>=2) -> MatrixMul | |||
| // (specialized) | |||
| auto&& grad_op = MatrixMul::make( | |||
| !param.transposeA, false, param.compute_mode, param.format, | |||
| policy.strategy, policy.workspace_limit, dimA, dimG); | |||
| ret[1] = imperative::apply(*grad_op, inps_[0], grad)[0]; | |||
| } | |||
| } | |||
| return ret; | |||
| }); | |||
| maker.finalize(); | |||
| return imperative::apply(ApplyOp(op), inputs); | |||
| } | |||
| std::optional<ValueRefList> batched_matrix_mul_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| auto&& bmm = op.cast_final_safe<BatchedMatrixMul>(); | |||
| size_t dimA = bmm.dimA; | |||
| size_t dimB = bmm.dimB; | |||
| auto&& param = bmm.param(); | |||
| auto&& policy = bmm.policy(); | |||
| mgb_assert(inputs.size() == 2); | |||
| std::array<ValueRef, 2> inps, input_shapes; | |||
| for (size_t i = 0; i < 2; ++i) { | |||
| if (inputs_require_grad[i ^ 1]) { | |||
| inps[i] = inputs[i]; | |||
| input_shapes[i] = get_shape(inputs[i]); | |||
| } | |||
| } | |||
| auto maker = CustomGradMaker(backward, inputs.size()); | |||
| maker.output_size(1).output_captured(0, false); | |||
| maker.backward([inps_ = std::move(inps), input_shapes_ = std::move(input_shapes), | |||
| param, policy, dimA, dimB](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| SmallVector<ValueRef> ret(2); | |||
| if (!grad) { | |||
| return ret; | |||
| } | |||
| size_t dimG = std::max(dimA, dimB); | |||
| if (inps_[1]) { | |||
| if (param.transposeA) { | |||
| auto&& grad_op = BatchedMatrixMul::make( | |||
| param.transposeB, true, param.compute_mode, param.format, | |||
| policy.strategy, policy.workspace_limit, dimB, dimG); | |||
| ret[0] = imperative::apply(*grad_op, inps_[1], grad)[0]; | |||
| } else { | |||
| auto&& grad_op = BatchedMatrixMul::make( | |||
| false, !param.transposeB, param.compute_mode, param.format, | |||
| policy.strategy, policy.workspace_limit, dimG, dimB); | |||
| ret[0] = imperative::apply(*grad_op, grad, inps_[1])[0]; | |||
| } | |||
| if (dimG != dimA) { | |||
| ret[0] = reduce_to(ret[0], input_shapes_[0]); | |||
| } | |||
| } | |||
| if (inps_[0]) { | |||
| if (param.transposeB) { | |||
| auto&& grad_op = BatchedMatrixMul::make( | |||
| true, param.transposeA, param.compute_mode, param.format, | |||
| policy.strategy, policy.workspace_limit, dimG, dimA); | |||
| ret[1] = imperative::apply(*grad_op, grad, inps_[0])[0]; | |||
| } else { | |||
| auto&& grad_op = BatchedMatrixMul::make( | |||
| !param.transposeA, false, param.compute_mode, param.format, | |||
| policy.strategy, policy.workspace_limit, dimA, dimG); | |||
| ret[1] = imperative::apply(*grad_op, inps_[0], grad)[0]; | |||
| } | |||
| if (dimG != dimB) { | |||
| ret[1] = reduce_to(ret[1], input_shapes_[1]); | |||
| } | |||
| } | |||
| return ret; | |||
| }); | |||
| maker.finalize(); | |||
| return imperative::apply(ApplyOp(op), inputs); | |||
| } | |||
| std::optional<ValueRefList> elemwise_grad_rule( | |||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | |||
| CustomBackward& backward) { | |||
| @@ -395,6 +525,9 @@ struct Init { | |||
| FastpathCopy::typeinfo(), fastpathcopy_grad_rule); | |||
| CustomBackward::register_grad_rule( | |||
| PixelShuffle::typeinfo(), pixelShuffle_grad_rule); | |||
| CustomBackward::register_grad_rule(MatrixMul::typeinfo(), matrix_mul_grad_rule); | |||
| CustomBackward::register_grad_rule( | |||
| BatchedMatrixMul::typeinfo(), batched_matrix_mul_grad_rule); | |||
| } | |||
| } _; | |||
| @@ -511,3 +511,45 @@ def test_pixel_shuffle(): | |||
| y = f(x) | |||
| grad(y, F.ones_like(y)) | |||
| np.testing.assert_equal(2 * x.numpy(), x.grad.numpy()) | |||
| def test_matmul(): | |||
| def test_one(xdim, ydim, transposeA, transposeB): | |||
| xshape = (1, 4) if xdim == 1 else (2,) * (xdim - 2) + (3, 4) | |||
| yshape = (4, 1) if ydim == 1 else (2,) * (ydim - 2) + (4, 5) | |||
| x = np.random.rand(*xshape).astype("float32") | |||
| y = np.random.rand(*yshape).astype("float32") | |||
| gshape = (x @ y).shape | |||
| g = np.random.rand(*gshape).astype("float32") | |||
| dx = g @ np.swapaxes(y, -1, -2) | |||
| dy = np.swapaxes(x, -1, -2) @ g | |||
| while dx.shape != x.shape: | |||
| dx = dx.sum(0) | |||
| while dy.shape != y.shape: | |||
| dy = dy.sum(0) | |||
| if transposeA: | |||
| x = np.swapaxes(x, -1, -2) | |||
| dx = np.swapaxes(dx, -1, -2) | |||
| if transposeB: | |||
| y = np.swapaxes(y, -1, -2) | |||
| dy = np.swapaxes(dy, -1, -2) | |||
| x = mge.Tensor(x.squeeze()) | |||
| y = mge.Tensor(y.squeeze()) | |||
| g = mge.Tensor(g.squeeze()) | |||
| with Grad() as grad: | |||
| grad.wrt(x, callback=save_to(x)) | |||
| grad.wrt(y, callback=save_to(y)) | |||
| z = F.matmul(x, y, transpose_a=transposeA, transpose_b=transposeB) | |||
| grad(z, g) | |||
| np.testing.assert_almost_equal(dx.squeeze(), x.grad.numpy(), decimal=5) | |||
| np.testing.assert_almost_equal(dy.squeeze(), y.grad.numpy(), decimal=5) | |||
| for xdim in [1, 2, 3, 4]: | |||
| for ydim in [1, 2, 3, 4]: | |||
| for transposeA in [False, True]: | |||
| if xdim == 1 and transposeA == True: | |||
| continue | |||
| for transposeB in [False, True]: | |||
| if ydim == 1 and transposeB == True: | |||
| continue | |||
| test_one(xdim, ydim, transposeA, transposeB) | |||
| @@ -31,18 +31,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| DTypeScalar vi{-1}; | |||
| auto graph = inputs[0]->owner_graph(); | |||
| bool remove_row = false, remove_col = false; | |||
| if (dim1 == 1) { | |||
| dim1 = 2; | |||
| remove_row = true; | |||
| inp1 = inp1.add_axis(0); | |||
| } | |||
| if (dim2 == 1) { | |||
| dim2 = 2; | |||
| remove_col = true; | |||
| inp2 = inp2.add_axis(1); | |||
| } | |||
| SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail; | |||
| if (dim1 > 2) { | |||
| auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||
| @@ -91,17 +79,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| result = result.reshape(tshp); | |||
| } | |||
| auto maxdim = dim1 > dim2 ? dim1 : dim2; | |||
| if (remove_row) { | |||
| std::vector<Desc> remove_param; | |||
| remove_param.push_back(Desc::make_remove(maxdim - 2)); | |||
| result = opr::AxisAddRemove::make(result, remove_param); | |||
| } | |||
| if (remove_col) { | |||
| std::vector<Desc> remove_param; | |||
| remove_param.push_back(Desc::make_remove(maxdim - 1)); | |||
| result = opr::AxisAddRemove::make(result, remove_param); | |||
| } | |||
| return result; | |||
| } | |||
| @@ -113,6 +90,19 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| size_t dim1 = layout1.ndim, dim2 = layout2.ndim; | |||
| DType dst_dtype; | |||
| if (dim1 == dim2 && dim2 >= 3) { // only happens in backward | |||
| for (size_t i = 1; i + 1 < layout1.ndim; ++i) { | |||
| layout1[0] *= layout1[i]; | |||
| layout2[0] *= layout2[i]; | |||
| } | |||
| layout1[1] = layout1[layout1.ndim - 1]; | |||
| layout1.ndim = 2; | |||
| layout1.init_contiguous_stride(); | |||
| layout2[1] = layout2[layout2.ndim - 1]; | |||
| layout2.ndim = 2; | |||
| layout2.init_contiguous_stride(); | |||
| dim1 = dim2 = 2; | |||
| } | |||
| DnnOprCaller<megdnn::MatrixMul> dnn_opr(inputs[0].comp_node); | |||
| dnn_opr.op->param() = matmul.param(); | |||
| @@ -156,6 +146,19 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn); | |||
| dnn_opr.op->param() = matmul.param(); | |||
| if (matmul.dimA == matmul.dimB && matmul.dimB >= 3) { // only happens in backward | |||
| for (size_t i = 1; i + 1 < layout1.ndim; ++i) { | |||
| layout1[0] *= layout1[i]; | |||
| layout2[0] *= layout2[i]; | |||
| } | |||
| layout1[1] = layout1[layout1.ndim - 1]; | |||
| layout1.ndim = 2; | |||
| layout1.init_contiguous_stride(); | |||
| layout2[1] = layout2[layout2.ndim - 1]; | |||
| layout2.ndim = 2; | |||
| layout2.init_contiguous_stride(); | |||
| } | |||
| DType dst_dtype; | |||
| dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); | |||
| @@ -191,11 +194,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| } | |||
| TensorLayout layout_a = layout1, layout_b = layout2; | |||
| if (dim1 == 1) { | |||
| layout_a.add_axis_cont_inplace(0); | |||
| inp_tensornds[0] = inputs[0]->dnn_tensor(); | |||
| inp_tensornds[0].layout = layout_a; | |||
| } else if (dim1 > 2) { | |||
| if (dim1 > 2) { | |||
| size_t batch = std::accumulate( | |||
| layout1.shape, layout1.shape + dim1 - 1, (size_t)1, | |||
| std::multiplies<size_t>()); | |||
| @@ -216,13 +215,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| inp_tensornds[0] = inputs[0]->dnn_tensor(); | |||
| } | |||
| if (dim2 == 1) { | |||
| layout_b.add_axis_inplace(1, 1, 1); | |||
| inp_tensornds[1] = inputs[1]->dnn_tensor(); | |||
| inp_tensornds[1].layout = layout_b; | |||
| } else { | |||
| inp_tensornds[1] = inputs[1]->dnn_tensor(); | |||
| } | |||
| inp_tensornds[1] = inputs[1]->dnn_tensor(); | |||
| TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, dst_dtype); | |||
| dst_layout.init_contiguous_stride(); | |||
| @@ -232,6 +225,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| if (matmul.transposeB) | |||
| std::swap(layout_b.shape[0], layout_b.shape[1]); | |||
| if (matmul.dimA == matmul.dimB && matmul.dimB >= 3) { // only happens in backward | |||
| inp_tensornds[0].layout = layout_a; | |||
| inp_tensornds[1].layout = layout_b; | |||
| } | |||
| DeviceTensorND out = | |||
| BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout); | |||
| size_t sz = setup_algo<megdnn::MatrixMul>( | |||
| @@ -279,18 +277,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto graph = inputs[0]->owner_graph(); | |||
| auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||
| bool remove_row = false, remove_col = false; | |||
| if (dim1 == 1) { | |||
| dim1 = 2; | |||
| remove_row = true; | |||
| inp1 = inp1.add_axis(0); | |||
| } | |||
| if (dim2 == 1) { | |||
| dim2 = 2; | |||
| remove_col = true; | |||
| inp2 = inp2.add_axis(1); | |||
| } | |||
| auto shp1 = inp1.symshape(); | |||
| auto shp2 = inp2.symshape(); | |||
| SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail; | |||
| @@ -349,16 +335,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| result_shp = opr::Concat::make({batch_shape, shp_tail}, 0, cn); | |||
| result = result.reshape(result_shp); | |||
| } | |||
| if (remove_row) { | |||
| std::vector<Desc> remove_param; | |||
| remove_param.push_back(Desc::make_remove(maxdim - 2)); | |||
| result = opr::AxisAddRemove::make(result, remove_param); | |||
| } | |||
| if (remove_col) { | |||
| std::vector<Desc> remove_param; | |||
| remove_param.push_back(Desc::make_remove(maxdim - 1)); | |||
| result = opr::AxisAddRemove::make(result, remove_param); | |||
| } | |||
| return result; | |||
| } | |||
| @@ -418,21 +394,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| DType dst_dtype; | |||
| dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype); | |||
| bool remove_row = false, remove_col = false; | |||
| if (dim1 == 1) { | |||
| dim1 = 2; | |||
| remove_row = true; | |||
| } | |||
| if (dim2 == 1) { | |||
| dim2 = 2; | |||
| remove_col = true; | |||
| } | |||
| if (remove_row) | |||
| layout1.add_axis_cont_inplace(0); | |||
| if (remove_col) | |||
| layout2.add_axis_inplace(1, 1, 1); | |||
| TensorShape tshp, batch_shp; | |||
| size_t j = 0; | |||
| auto inp1 = inputs[0], inp2 = inputs[1]; | |||
| @@ -530,12 +491,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| if (maxdim > 3) { | |||
| dst_layout = dst_layout.reshape(shp1); | |||
| } | |||
| if (remove_row) { | |||
| dst_layout = dst_layout.remove_axis(maxdim - 2); | |||
| } | |||
| if (remove_col) { | |||
| dst_layout = dst_layout.remove_axis(maxdim - 1); | |||
| } | |||
| return {Tensor::make(out.sub(SubTensorSpec::make_from_layout(dst_layout)))}; | |||
| } | |||