diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs
index fd751322..3c89bd58 100644
--- a/src/TensorFlowNET.Core/APIs/tf.linalg.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs
@@ -11,5 +11,8 @@ namespace Tensorflow
public static Tensor matmul(Tensor a, Tensor b)
=> gen_math_ops.mat_mul(a, b);
+
+ public static Tensor batch_matmul(Tensor x, Tensor y)
+ => gen_math_ops.batch_mat_mul(x, y);
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs
index f7f8e35f..de3e7d66 100644
--- a/src/TensorFlowNET.Core/Gradients/math_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs
@@ -153,6 +153,11 @@ namespace Tensorflow.Gradients
return new Tensor[] { grad_a, grad_b };
}
+ public static Tensor[] _BatchMatMul(Operation op, Tensor[] grads)
+ {
+ throw new NotImplementedException();
+ }
+
[RegisterGradient("Mean")]
public static Tensor[] _MeanGrad(Operation op, Tensor[] grads)
{
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index 3dee5e9e..13b76075 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -471,6 +471,41 @@ namespace Tensorflow
return _op.outputs[0];
}
+ /// <summary>
+ /// Multiply slices of the two matrices "x" and "y".
+ /// </summary>
+ /// <remarks>
+ /// The `BatchMatMul` operation is embedded into the
+ /// `MatMul` operation on the DLL side. However the expected
+ /// attributes are not the same, hence we need to expose this
+ /// method to have the right args list on the `_apply_op_helper`
+ /// function.
+ ///
+ /// For each rank > 2 the first rank - 2 dimensions are considered
+ /// as fixed, and have to be consistent across the two matrices. A
+ /// common matrix multiplication is then applied over the residual
+ /// 2 dimensions.
+ ///
+ /// e.g.
+ /// x is (3, 6, 12); y is (3, 12, 6)
+ /// batch_matmul(x, y) ==> (3, 6, 6)
+ /// </remarks>
+ /// <param name="x">The first input tensor, rank >= 2.</param>
+ /// <param name="y">The second input tensor, rank >= 2.</param>
+ /// <param name="adj_x">If true, adjoint the slices of `x` before multiplication.</param>
+ /// <param name="adj_y">If true, adjoint the slices of `y` before multiplication.</param>
+ /// <param name="name">Optional name for the operation.</param>
+ /// <returns>The batched product tensor.</returns>
+ public static Tensor batch_mat_mul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper(
+ "BatchMatMul",
+ name,
+ args: new { x, y, adj_x, adj_y });
+
+ return _op.outputs[0];
+ }
+
/// <summary>
/// Returns the max of x and y (i.e. x > y ? x : y) element-wise.
/// </summary>
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index 29e9d671..c13f0d25 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -497,6 +497,25 @@ namespace Tensorflow
return result;
}
+ public static Tensor batch_matmul(Tensor x, Tensor y,
+ bool adj_x = false, bool adj_y = false,
+ string name = null)
+ {
+ Tensor result = null;
+
+ with(ops.name_scope(name, "MatMul", new Tensor[] { x, y }), scope =>
+ {
+ name = scope;
+
+ x = ops.convert_to_tensor(x, name: "a");
+ y = ops.convert_to_tensor(y, name: "b");
+
+ result = gen_math_ops.batch_mat_mul(x, y, adj_x, adj_y, name);
+ });
+
+ return result;
+ }
+
/// <summary>
/// Returns the complex conjugate of a complex number.
/// </summary>
diff --git a/test/TensorFlowNET.Examples/BasicOperations.cs b/test/TensorFlowNET.Examples/BasicOperations.cs
index ce861df7..39e7ce7e 100644
--- a/test/TensorFlowNET.Examples/BasicOperations.cs
+++ b/test/TensorFlowNET.Examples/BasicOperations.cs
@@ -91,11 +91,69 @@ namespace TensorFlowNET.Examples
// graph: the two constants and matmul.
//
// The output of the op is returned in 'result' as a numpy `ndarray` object.
- return with(tf.Session(), sess =>
+ using (sess = tf.Session())
{
var result = sess.run(product);
Console.WriteLine(result.ToString()); // ==> [[ 12.]]
- return result.Data()[0] == 12;
+ };
+
+ // `BatchMatMul` is actually embedded into the `MatMul` operation on the tensorflow.dll side. Every time we ask
+ // for a multiplication between matrices with rank > 2, the first rank - 2 dimensions are checked to be consistent
+ // across the two matrices and a common matrix multiplication is done on the residual 2 dimensions.
+ //
+ // np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9]).reshape(3, 3, 3)
+ // array([[[1, 2, 3],
+ // [4, 5, 6],
+ // [7, 8, 9]],
+ //
+ // [[1, 2, 3],
+ // [4, 5, 6],
+ // [7, 8, 9]],
+ //
+ // [[1, 2, 3],
+ // [4, 5, 6],
+ // [7, 8, 9]]])
+ var firstTensor = tf.convert_to_tensor(
+ np.reshape(
+ np.array(1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9),
+ 3, 3, 3));
+ //
+ // np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0]).reshape(3,3,2)
+ // array([[[0, 1],
+ // [0, 1],
+ // [0, 1]],
+ //
+ // [[0, 1],
+ // [0, 0],
+ // [1, 0]],
+ //
+ // [[1, 0],
+ // [1, 0],
+ // [1, 0]]])
+ var secondTensor = tf.convert_to_tensor(
+ np.reshape(
+ np.array(0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0),
+ 3, 3, 2));
+ var batchMul = tf.batch_matmul(firstTensor, secondTensor);
+ var checkTensor = np.array(0, 6, 0, 15, 0, 24, 3, 1, 6, 4, 9, 7, 6, 0, 15, 0, 24, 0);
+ return with(tf.Session(), sess =>
+ {
+ var result = sess.run(batchMul);
+ Console.WriteLine(result.ToString());
+ //
+ // ==> array([[[0, 6],
+ // [0, 15],
+ // [0, 24]],
+ //
+ // [[ 3, 1],
+ // [ 6, 4],
+ // [ 9, 7]],
+ //
+ // [[ 6, 0],
+ // [15, 0],
+ // [24, 0]]])
+ return np.reshape(result, 18)
+ .array_equal(checkTensor);
});
}