| @@ -233,7 +233,7 @@ std::optional<ValueRefList> indexingMultiAxisVec_grad_rule( | |||||
| const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad, | ||||
| CustomBackward& backward) { | CustomBackward& backward) { | ||||
| auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>(); | auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>(); | ||||
| auto&& grad_op = IndexingSetMultiAxisVec::make(indexingMultiAxisVec.items); | |||||
| auto&& grad_op = IndexingIncrMultiAxisVec::make(indexingMultiAxisVec.items); | |||||
| SmallVector<ValueRef> inputs2; | SmallVector<ValueRef> inputs2; | ||||
| if (inputs_require_grad[0]) { | if (inputs_require_grad[0]) { | ||||
| inputs2.push_back(get_shape(inputs[0])); | inputs2.push_back(get_shape(inputs[0])); | ||||
| @@ -316,7 +316,7 @@ def test_IndexingMultiAxisVec(): | |||||
| def f(x): | def f(x): | ||||
| x = x * 1 | x = x * 1 | ||||
| y = x[[0, 2], [0, 2]] | |||||
| y = x[[0, 0, 2, 1], [2, 2, 1, 0]] | |||||
| refs["x"] = TensorWeakRef(x) | refs["x"] = TensorWeakRef(x) | ||||
| return y | return y | ||||
| @@ -326,7 +326,7 @@ def test_IndexingMultiAxisVec(): | |||||
| grad(y, F.ones_like(y)) | grad(y, F.ones_like(y)) | ||||
| np.testing.assert_equal( | np.testing.assert_equal( | ||||
| np.array([[1, 0, 0], [0, 0, 0], [0, 0, 1]], dtype=np.float32), x.grad.numpy() | |||||
| np.array([[0, 0, 2], [1, 0, 0], [0, 1, 0]], dtype=np.float32), x.grad.numpy() | |||||
| ) | ) | ||||