@@ -13,6 +13,7 @@ from .._imperative_rt.core2 import (
     astype_cpp,
     batched_matmul_cpp,
     broadcast_cpp,
+    expand_dims_cpp,
     getitem_cpp,
     matmul_cpp,
     reshape_cpp,
@@ -62,7 +63,6 @@ def _matmul(
     assert dim1 > 0 and dim2 > 0
     maxdim = dim1 if dim1 > dim2 else dim2
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     if dim1 == 1 and dim2 == 1:  # dispatch to Dot
         (result,) = apply(builtin.Dot(), inp1, inp2)
         return result
@@ -72,34 +72,44 @@ def _matmul(
         # 2x2
         # nx1(transpose_a=False), n>=3
         # nx2(transpose_a=False), n>=3
-        return matmul_cpp(
-            inp1,
-            inp2,
-            dim1,
-            dim2,
+        ret = matmul_cpp(
+            inp1 if dim1 > 1 else expand_dims_cpp(inp1, 0),
+            inp2 if dim2 > 1 else expand_dims_cpp(inp2, -1),
+            max(dim1, 2),
+            max(dim2, 2),
             transpose_a,
             transpose_b,
             compute_mode,
             _config._benchmark_kernel,
             _config._deterministic_kernel,
         )
+        if dim1 == 1:
+            ret = squeeze_cpp(ret, -2)
+        elif dim2 == 1:
+            ret = squeeze_cpp(ret, -1)
+        return ret
     else:  # dispath to BatchedMatrixMul
         # nx1(transpose_a=True), n>=3
         # nx2(transpose_a=True), n>=3
         # nxm,n>=3,m>=3
         # 1xm,m>=3
         # 2xm,m>=3
-        return batched_matmul_cpp(
-            inp1,
-            inp2,
-            dim1,
-            dim2,
+        ret = batched_matmul_cpp(
+            inp1 if dim1 > 1 else expand_dims_cpp(inp1, 0),
+            inp2 if dim2 > 1 else expand_dims_cpp(inp2, -1),
+            max(dim1, 2),
+            max(dim2, 2),
             transpose_a,
             transpose_b,
             compute_mode,
             _config._benchmark_kernel,
             _config._deterministic_kernel,
         )
+        if dim1 == 1:
+            ret = squeeze_cpp(ret, -2)
+        elif dim2 == 1:
+            ret = squeeze_cpp(ret, -1)
+        return ret
 
 
 def _unary_elwise(mode):
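Note on the hunk above: the kernels no longer receive 1-D operands. The Python wrapper inserts the missing axis with expand_dims_cpp before dispatch and squeezes it back off the result, so matmul_cpp/batched_matmul_cpp always see rank >= 2. A minimal NumPy sketch of that shape bookkeeping (the function name and the use of NumPy in place of the *_cpp helpers are illustrative only):

import numpy as np

def matmul_with_1d_promotion(a, b):
    # Mirror the dispatch above: a 1-D lhs gains a leading axis,
    # a 1-D rhs gains a trailing axis, so the kernel always sees >=2-D inputs.
    a2 = a if a.ndim > 1 else np.expand_dims(a, 0)
    b2 = b if b.ndim > 1 else np.expand_dims(b, -1)
    out = a2 @ b2
    # Strip the promoted axis back off the result (the 1-D @ 1-D case never
    # reaches this path; it is dispatched to Dot earlier).
    if a.ndim == 1:
        out = np.squeeze(out, -2)
    elif b.ndim == 1:
        out = np.squeeze(out, -1)
    return out

assert matmul_with_1d_promotion(np.ones((3, 4)), np.ones(4)).shape == (3,)
assert matmul_with_1d_promotion(np.ones(3), np.ones((3, 5))).shape == (5,)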
@@ -87,6 +87,136 @@ ValueRef make_empty_tensor(
     return res;
 }
 
+std::optional<ValueRefList> matrix_mul_grad_rule(
+        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
+        CustomBackward& backward) {
+    auto&& matmul = op.cast_final_safe<MatrixMul>();
+    size_t dimA = matmul.dimA;
+    size_t dimB = matmul.dimB;
+    auto&& param = matmul.param();
+    auto&& policy = matmul.policy();
+    mgb_assert(inputs.size() == 2);
+    std::array<ValueRef, 2> inps, input_shapes;
+    for (size_t i = 0; i < 2; ++i) {
+        if (inputs_require_grad[i ^ 1]) {
+            inps[i] = inputs[i];
+            input_shapes[i] = get_shape(inputs[i]);
+        }
+    }
+    auto maker = CustomGradMaker(backward, inputs.size());
+    maker.output_size(1).output_captured(0, false);
+    maker.backward([inps_ = std::move(inps), input_shapes_ = std::move(input_shapes),
+                    param, policy, dimA, dimB](Span<ValueRef> grads) {
+        mgb_assert(grads.size() == 1);
+        ValueRef grad = grads[0];
+        SmallVector<ValueRef> ret(2);
+        if (!grad) {
+            return ret;
+        }
+        size_t dimG = std::max(dimA, dimB);
+        if (inps_[1]) {
+            if (param.transposeA) {
+                // A^T(2) @ B(2) = G(2), A'(2) = B'(2) @ G'^T(2) -> MatrixMul
+                auto&& grad_op = MatrixMul::make(
+                        param.transposeB, true, param.compute_mode, param.format,
+                        policy.strategy, policy.workspace_limit, dimB, dimG);
+                ret[0] = imperative::apply(*grad_op, inps_[1], grad)[0];
+            } else {
+                // A(>=2) @ B(2) = G(>=2), A'(>=2) = G'(>=2) @ B(2) -> MatrixMul
+                auto&& grad_op = MatrixMul::make(
+                        false, !param.transposeB, param.compute_mode, param.format,
+                        policy.strategy, policy.workspace_limit, dimG, dimB);
+                ret[0] = imperative::apply(*grad_op, grad, inps_[1])[0];
+            }
+        }
+        if (inps_[0]) {
+            if (param.transposeB) {
+                // A(>=2) @ B^T(2) = G(>=2), B'(2) = G'^T(>=2) @ A(>=2) -> MatrixMul
+                // (specialized)
+                auto&& grad_op = MatrixMul::make(
+                        true, param.transposeA, param.compute_mode, param.format,
+                        policy.strategy, policy.workspace_limit, dimG, dimA);
+                ret[1] = imperative::apply(*grad_op, grad, inps_[0])[0];
+            } else {
+                // A(>=2) @ B(2) = G(>=2), B'(2) = G'(>=2) @ A(>=2) -> MatrixMul
+                // (specialized)
+                auto&& grad_op = MatrixMul::make(
+                        !param.transposeA, false, param.compute_mode, param.format,
+                        policy.strategy, policy.workspace_limit, dimA, dimG);
+                ret[1] = imperative::apply(*grad_op, inps_[0], grad)[0];
+            }
+        }
+        return ret;
+    });
+    maker.finalize();
+    return imperative::apply(ApplyOp(op), inputs);
+}
+
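The four transpose branches above all reduce to the two standard identities dA = G @ op(B)^T and dB = op(A)^T @ G, with the extra transposes expressed by toggling the flags of the backward MatrixMul instead of materializing any transposed tensor. A minimal NumPy sketch of the same identities (the function name and shapes are illustrative, not part of the codebase):

import numpy as np

def matmul_grads(A, B, G, transpose_a=False, transpose_b=False):
    # Closed-form gradients of C = op(A) @ op(B), where op(X) is X.T when the
    # corresponding flag is set and G is dL/dC.
    opA = A.T if transpose_a else A
    opB = B.T if transpose_b else B
    d_opA = G @ opB.T  # gradient w.r.t. op(A)
    d_opB = opA.T @ G  # gradient w.r.t. op(B)
    dA = d_opA.T if transpose_a else d_opA
    dB = d_opB.T if transpose_b else d_opB
    return dA, dB

A, B, G = np.random.rand(3, 4), np.random.rand(5, 4), np.random.rand(3, 5)
dA, dB = matmul_grads(A, B, G, transpose_b=True)  # C = A @ B.T
assert dA.shape == A.shape and dB.shape == B.shape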
+std::optional<ValueRefList> batched_matrix_mul_grad_rule(
+        const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
+        CustomBackward& backward) {
+    auto&& bmm = op.cast_final_safe<BatchedMatrixMul>();
+    size_t dimA = bmm.dimA;
+    size_t dimB = bmm.dimB;
+    auto&& param = bmm.param();
+    auto&& policy = bmm.policy();
+    mgb_assert(inputs.size() == 2);
+    std::array<ValueRef, 2> inps, input_shapes;
+    for (size_t i = 0; i < 2; ++i) {
+        if (inputs_require_grad[i ^ 1]) {
+            inps[i] = inputs[i];
+            input_shapes[i] = get_shape(inputs[i]);
+        }
+    }
+    auto maker = CustomGradMaker(backward, inputs.size());
+    maker.output_size(1).output_captured(0, false);
+    maker.backward([inps_ = std::move(inps), input_shapes_ = std::move(input_shapes),
+                    param, policy, dimA, dimB](Span<ValueRef> grads) {
+        mgb_assert(grads.size() == 1);
+        ValueRef grad = grads[0];
+        SmallVector<ValueRef> ret(2);
+        if (!grad) {
+            return ret;
+        }
+        size_t dimG = std::max(dimA, dimB);
+        if (inps_[1]) {
+            if (param.transposeA) {
+                auto&& grad_op = BatchedMatrixMul::make(
+                        param.transposeB, true, param.compute_mode, param.format,
+                        policy.strategy, policy.workspace_limit, dimB, dimG);
+                ret[0] = imperative::apply(*grad_op, inps_[1], grad)[0];
+            } else {
+                auto&& grad_op = BatchedMatrixMul::make(
+                        false, !param.transposeB, param.compute_mode, param.format,
+                        policy.strategy, policy.workspace_limit, dimG, dimB);
+                ret[0] = imperative::apply(*grad_op, grad, inps_[1])[0];
+            }
+            if (dimG != dimA) {
+                ret[0] = reduce_to(ret[0], input_shapes_[0]);
+            }
+        }
+        if (inps_[0]) {
+            if (param.transposeB) {
+                auto&& grad_op = BatchedMatrixMul::make(
+                        true, param.transposeA, param.compute_mode, param.format,
+                        policy.strategy, policy.workspace_limit, dimG, dimA);
+                ret[1] = imperative::apply(*grad_op, grad, inps_[0])[0];
+            } else {
+                auto&& grad_op = BatchedMatrixMul::make(
+                        !param.transposeA, false, param.compute_mode, param.format,
+                        policy.strategy, policy.workspace_limit, dimA, dimG);
+                ret[1] = imperative::apply(*grad_op, inps_[0], grad)[0];
+            }
+            if (dimG != dimB) {
+                ret[1] = reduce_to(ret[1], input_shapes_[1]);
+            }
+        }
+        return ret;
+    });
+    maker.finalize();
+    return imperative::apply(ApplyOp(op), inputs);
+}
+
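Unlike the plain rule, the batched rule also handles broadcasting over leading batch dimensions: when dimA != dimB, the lower-rank operand was broadcast in the forward pass, so its gradient comes back with extra batch axes and is reduced back to the saved input shape via reduce_to. A NumPy sketch of that step, assuming reduce_to behaves as the usual sum-reduction adjoint of broadcasting (the helper name below is ours):

import numpy as np

def reduce_grad_to(grad, target_shape):
    # Sum out the broadcast leading batch axes, mirroring the assumed
    # semantics of reduce_to(ret[i], input_shapes_[i]) in the rule above.
    while grad.ndim > len(target_shape):
        grad = grad.sum(axis=0)
    return grad

# e.g. a (2, 3, 4) lhs multiplied with a (4, 5) rhs broadcast over the batch:
# the rhs gradient arrives as (2, 4, 5) and must be summed down to (4, 5).
g = np.random.rand(2, 4, 5)
assert reduce_grad_to(g, (4, 5)).shape == (4, 5)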
 std::optional<ValueRefList> elemwise_grad_rule(
         const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
         CustomBackward& backward) {
@@ -395,6 +525,9 @@ struct Init {
                 FastpathCopy::typeinfo(), fastpathcopy_grad_rule);
         CustomBackward::register_grad_rule(
                 PixelShuffle::typeinfo(), pixelShuffle_grad_rule);
+        CustomBackward::register_grad_rule(MatrixMul::typeinfo(), matrix_mul_grad_rule);
+        CustomBackward::register_grad_rule(
+                BatchedMatrixMul::typeinfo(), batched_matrix_mul_grad_rule);
     }
 } _;
@@ -511,3 +511,45 @@ def test_pixel_shuffle():
         y = f(x)
         grad(y, F.ones_like(y))
     np.testing.assert_equal(2 * x.numpy(), x.grad.numpy())
+
+
+def test_matmul():
+    def test_one(xdim, ydim, transposeA, transposeB):
+        xshape = (1, 4) if xdim == 1 else (2,) * (xdim - 2) + (3, 4)
+        yshape = (4, 1) if ydim == 1 else (2,) * (ydim - 2) + (4, 5)
+        x = np.random.rand(*xshape).astype("float32")
+        y = np.random.rand(*yshape).astype("float32")
+        gshape = (x @ y).shape
+        g = np.random.rand(*gshape).astype("float32")
+        dx = g @ np.swapaxes(y, -1, -2)
+        dy = np.swapaxes(x, -1, -2) @ g
+        while dx.shape != x.shape:
+            dx = dx.sum(0)
+        while dy.shape != y.shape:
+            dy = dy.sum(0)
+        if transposeA:
+            x = np.swapaxes(x, -1, -2)
+            dx = np.swapaxes(dx, -1, -2)
+        if transposeB:
+            y = np.swapaxes(y, -1, -2)
+            dy = np.swapaxes(dy, -1, -2)
+        x = mge.Tensor(x.squeeze())
+        y = mge.Tensor(y.squeeze())
+        g = mge.Tensor(g.squeeze())
+        with Grad() as grad:
+            grad.wrt(x, callback=save_to(x))
+            grad.wrt(y, callback=save_to(y))
+            z = F.matmul(x, y, transpose_a=transposeA, transpose_b=transposeB)
+            grad(z, g)
+        np.testing.assert_almost_equal(dx.squeeze(), x.grad.numpy(), decimal=5)
+        np.testing.assert_almost_equal(dy.squeeze(), y.grad.numpy(), decimal=5)
+
+    for xdim in [1, 2, 3, 4]:
+        for ydim in [1, 2, 3, 4]:
+            for transposeA in [False, True]:
+                if xdim == 1 and transposeA == True:
+                    continue
+                for transposeB in [False, True]:
+                    if ydim == 1 and transposeB == True:
+                        continue
+                    test_one(xdim, ydim, transposeA, transposeB)
@@ -31,18 +31,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     DTypeScalar vi{-1};
     auto graph = inputs[0]->owner_graph();
 
-    bool remove_row = false, remove_col = false;
-    if (dim1 == 1) {
-        dim1 = 2;
-        remove_row = true;
-        inp1 = inp1.add_axis(0);
-    }
-    if (dim2 == 1) {
-        dim2 = 2;
-        remove_col = true;
-        inp2 = inp2.add_axis(1);
-    }
-
     SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail;
     if (dim1 > 2) {
         auto idx = opr::ImmutableTensor::make(*graph, vi, config);
@@ -91,17 +79,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
         result = result.reshape(tshp);
     }
 
-    auto maxdim = dim1 > dim2 ? dim1 : dim2;
-    if (remove_row) {
-        std::vector<Desc> remove_param;
-        remove_param.push_back(Desc::make_remove(maxdim - 2));
-        result = opr::AxisAddRemove::make(result, remove_param);
-    }
-    if (remove_col) {
-        std::vector<Desc> remove_param;
-        remove_param.push_back(Desc::make_remove(maxdim - 1));
-        result = opr::AxisAddRemove::make(result, remove_param);
-    }
     return result;
 }
@@ -113,6 +90,19 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     size_t dim1 = layout1.ndim, dim2 = layout2.ndim;
 
     DType dst_dtype;
+    if (dim1 == dim2 && dim2 >= 3) {  // only happens in backward
+        for (size_t i = 1; i + 1 < layout1.ndim; ++i) {
+            layout1[0] *= layout1[i];
+            layout2[0] *= layout2[i];
+        }
+        layout1[1] = layout1[layout1.ndim - 1];
+        layout1.ndim = 2;
+        layout1.init_contiguous_stride();
+        layout2[1] = layout2[layout2.ndim - 1];
+        layout2.ndim = 2;
+        layout2.init_contiguous_stride();
+        dim1 = dim2 = 2;
+    }
 
     DnnOprCaller<megdnn::MatrixMul> dnn_opr(inputs[0].comp_node);
     dnn_opr.op->param() = matmul.param();
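The equal-rank branch marked "only happens in backward" exists because the backward MatrixMul for the 2-D operand of an n-D @ 2-D product receives two operands of the same rank >= 3, e.g. A of shape (b, m, k) and the incoming gradient G of shape (b, m, n); folding every leading axis into one turns the per-batch sum of 2-D products into a single 2-D matmul. A NumPy check of that equivalence (shapes are illustrative):

import numpy as np

# dB for C = A @ B with A: (b, m, k), B: (k, n), incoming gradient G: (b, m, n).
b, m, k, n = 2, 3, 4, 5
A = np.random.rand(b, m, k)
G = np.random.rand(b, m, n)

# Per-batch accumulation versus folding the leading axes into one row axis,
# which is what the layout rewrite above does before calling the 2-D kernel.
dB_loop = sum(A[i].T @ G[i] for i in range(b))
dB_fold = A.reshape(b * m, k).T @ G.reshape(b * m, n)
np.testing.assert_allclose(dB_loop, dB_fold, rtol=1e-6)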
@@ -156,6 +146,19 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     DnnOprCaller<megdnn::MatrixMul> dnn_opr(cn);
     dnn_opr.op->param() = matmul.param();
 
+    if (matmul.dimA == matmul.dimB && matmul.dimB >= 3) {  // only happens in backward
+        for (size_t i = 1; i + 1 < layout1.ndim; ++i) {
+            layout1[0] *= layout1[i];
+            layout2[0] *= layout2[i];
+        }
+        layout1[1] = layout1[layout1.ndim - 1];
+        layout1.ndim = 2;
+        layout1.init_contiguous_stride();
+        layout2[1] = layout2[layout2.ndim - 1];
+        layout2.ndim = 2;
+        layout2.init_contiguous_stride();
+    }
+
     DType dst_dtype;
     dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
@@ -191,11 +194,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     }
 
     TensorLayout layout_a = layout1, layout_b = layout2;
-    if (dim1 == 1) {
-        layout_a.add_axis_cont_inplace(0);
-        inp_tensornds[0] = inputs[0]->dnn_tensor();
-        inp_tensornds[0].layout = layout_a;
-    } else if (dim1 > 2) {
+    if (dim1 > 2) {
         size_t batch = std::accumulate(
                 layout1.shape, layout1.shape + dim1 - 1, (size_t)1,
                 std::multiplies<size_t>());
@@ -216,13 +215,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
         inp_tensornds[0] = inputs[0]->dnn_tensor();
     }
 
-    if (dim2 == 1) {
-        layout_b.add_axis_inplace(1, 1, 1);
-        inp_tensornds[1] = inputs[1]->dnn_tensor();
-        inp_tensornds[1].layout = layout_b;
-    } else {
-        inp_tensornds[1] = inputs[1]->dnn_tensor();
-    }
+    inp_tensornds[1] = inputs[1]->dnn_tensor();
 
     TensorLayout dst_layout = TensorLayout({layout_a[0], layout_b[1]}, dst_dtype);
     dst_layout.init_contiguous_stride();
@@ -232,6 +225,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     if (matmul.transposeB)
         std::swap(layout_b.shape[0], layout_b.shape[1]);
 
+    if (matmul.dimA == matmul.dimB && matmul.dimB >= 3) {  // only happens in backward
+        inp_tensornds[0].layout = layout_a;
+        inp_tensornds[1].layout = layout_b;
+    }
+
     DeviceTensorND out =
             BlobManager::inst()->alloc_workspace_with_defrag(cn, dst_layout);
     size_t sz = setup_algo<megdnn::MatrixMul>(
@@ -279,18 +277,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto graph = inputs[0]->owner_graph();
     auto idx = opr::ImmutableTensor::make(*graph, vi, config);
 
-    bool remove_row = false, remove_col = false;
-    if (dim1 == 1) {
-        dim1 = 2;
-        remove_row = true;
-        inp1 = inp1.add_axis(0);
-    }
-    if (dim2 == 1) {
-        dim2 = 2;
-        remove_col = true;
-        inp2 = inp2.add_axis(1);
-    }
-
     auto shp1 = inp1.symshape();
     auto shp2 = inp2.symshape();
     SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail;
@@ -349,16 +335,6 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
         result_shp = opr::Concat::make({batch_shape, shp_tail}, 0, cn);
         result = result.reshape(result_shp);
     }
-    if (remove_row) {
-        std::vector<Desc> remove_param;
-        remove_param.push_back(Desc::make_remove(maxdim - 2));
-        result = opr::AxisAddRemove::make(result, remove_param);
-    }
-    if (remove_col) {
-        std::vector<Desc> remove_param;
-        remove_param.push_back(Desc::make_remove(maxdim - 1));
-        result = opr::AxisAddRemove::make(result, remove_param);
-    }
     return result;
 }
@@ -418,21 +394,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     DType dst_dtype;
     dnn_opr.op->deduce_dtype(layout1.dtype, layout1.dtype, dst_dtype);
 
-    bool remove_row = false, remove_col = false;
-    if (dim1 == 1) {
-        dim1 = 2;
-        remove_row = true;
-    }
-    if (dim2 == 1) {
-        dim2 = 2;
-        remove_col = true;
-    }
-
-    if (remove_row)
-        layout1.add_axis_cont_inplace(0);
-    if (remove_col)
-        layout2.add_axis_inplace(1, 1, 1);
-
     TensorShape tshp, batch_shp;
     size_t j = 0;
     auto inp1 = inputs[0], inp2 = inputs[1];
@@ -530,12 +491,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     if (maxdim > 3) {
         dst_layout = dst_layout.reshape(shp1);
     }
-    if (remove_row) {
-        dst_layout = dst_layout.remove_axis(maxdim - 2);
-    }
-    if (remove_col) {
-        dst_layout = dst_layout.remove_axis(maxdim - 1);
-    }
     return {Tensor::make(out.sub(SubTensorSpec::make_from_layout(dst_layout)))};
 }