Adding the `BatchMatMul` operation. Expanding `BasicOperations` to test the new operation. (tags/v0.9)
| @@ -11,5 +11,8 @@ namespace Tensorflow | |||||
| /// <summary>Forwards "a" and "b" unchanged to `gen_math_ops.mat_mul` and returns its result.</summary> | /// <summary>Forwards "a" and "b" unchanged to `gen_math_ops.mat_mul` and returns its result.</summary> | ||||
| public static Tensor matmul(Tensor a, Tensor b) | public static Tensor matmul(Tensor a, Tensor b) | ||||
| => gen_math_ops.mat_mul(a, b); | => gen_math_ops.mat_mul(a, b); | ||||
| /// <summary> | |||||
| /// Multiplies slices of "x" and "y" by forwarding to `gen_math_ops.batch_mat_mul`. | |||||
| /// </summary> | |||||
| /// <param name="x">First operand, forwarded unchanged.</param> | |||||
| /// <param name="y">Second operand, forwarded unchanged.</param> | |||||
| /// <param name="adj_x">Forwarded as `adj_x`. Defaults to false, which reproduces the former two-argument behaviour.</param> | |||||
| /// <param name="adj_y">Forwarded as `adj_y`. Defaults to false, which reproduces the former two-argument behaviour.</param> | |||||
| /// <param name="name">Optional operation name, forwarded as `name`. Defaults to null.</param> | |||||
| /// <returns>The tensor returned by `gen_math_ops.batch_mat_mul`.</returns> | |||||
| public static Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null) | |||||
| => gen_math_ops.batch_mat_mul(x, y, adj_x, adj_y, name); | |||||
| } | } | ||||
| } | } | ||||
| @@ -153,6 +153,11 @@ namespace Tensorflow.Gradients | |||||
| return new Tensor[] { grad_a, grad_b }; | return new Tensor[] { grad_a, grad_b }; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Stub, presumably for the gradient of the `BatchMatMul` operation: it ignores both | |||||
| /// arguments and always throws <see cref="NotImplementedException"/>. | |||||
| /// </summary> | |||||
| /// <param name="op">Unused by the current body.</param> | |||||
| /// <param name="grads">Unused by the current body.</param> | |||||
| /// <returns>Never returns; the method always throws.</returns> | |||||
| // NOTE(review): unlike `_MeanGrad`, no `[RegisterGradient(...)]` attribute is visible on this method - confirm whether registration is intentionally deferred until the gradient is implemented. | |||||
| public static Tensor[] _BatchMatMul(Operation op, Tensor[] grads) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| [RegisterGradient("Mean")] | [RegisterGradient("Mean")] | ||||
| public static Tensor[] _MeanGrad(Operation op, Tensor[] grads) | public static Tensor[] _MeanGrad(Operation op, Tensor[] grads) | ||||
| { | { | ||||
| @@ -471,6 +471,41 @@ namespace Tensorflow | |||||
| return _op.outputs[0]; | return _op.outputs[0]; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Creates a `BatchMatMul` operation that multiplies the matrix slices of "x" by those of "y". | |||||
| /// </summary> | |||||
| /// <remarks> | |||||
| /// On the DLL side the `BatchMatMul` operation is embedded into the | |||||
| /// `MatMul` operation, but the two do not expect the same attributes. | |||||
| /// This dedicated method exists so that `_apply_op_helper` receives | |||||
| /// the argument list that `BatchMatMul` expects. | |||||
| /// | |||||
| /// For inputs of rank > 2 the leading (rank - 2) dimensions are treated | |||||
| /// as fixed and have to be consistent across the two inputs. An | |||||
| /// ordinary matrix multiplication is then applied over the trailing | |||||
| /// 2 dimensions. | |||||
| /// | |||||
| /// e.g. | |||||
| /// x is (3, 6, 12); y is (3, 12, 6) | |||||
| /// batch_mat_mul(x, y) ==> (3, 6, 6) | |||||
| /// </remarks> | |||||
| /// <param name="x">First operand, handed to the op as argument `x`.</param> | |||||
| /// <param name="y">Second operand, handed to the op as argument `y`.</param> | |||||
| /// <param name="adj_x">Handed to the op as `adj_x` (in TensorFlow's op definition: use the adjoint of each slice of x). Defaults to false.</param> | |||||
| /// <param name="adj_y">Handed to the op as `adj_y` (in TensorFlow's op definition: use the adjoint of each slice of y). Defaults to false.</param> | |||||
| /// <param name="name">Optional operation name passed to `_apply_op_helper`. Defaults to null.</param> | |||||
| /// <returns>The first output (`outputs[0]`) of the created operation.</returns> | |||||
| public static Tensor batch_mat_mul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null) | |||||
| { | |||||
| // The anonymous object's member names (x, y, adj_x, adj_y) are part of the call and must not be | |||||
| // renamed; presumably the helper matches them by name against the op definition. | |||||
| var opArgs = new { x, y, adj_x, adj_y }; | |||||
| var batchOp = _op_def_lib._apply_op_helper("BatchMatMul", name, args: opArgs); | |||||
| return batchOp.outputs[0]; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the max of x and y (i.e. x > y ? x : y) element-wise. | /// Returns the max of x and y (i.e. x > y ? x : y) element-wise. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -497,6 +497,25 @@ namespace Tensorflow | |||||
| return result; | return result; | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Multiplies "x" by "y" through `gen_math_ops.batch_mat_mul`, inside a name scope | |||||
| /// opened with "MatMul" as its second argument (presumably the default scope name). | |||||
| /// </summary> | |||||
| /// <param name="x">First operand; converted with `ops.convert_to_tensor` under the name "a".</param> | |||||
| /// <param name="y">Second operand; converted with `ops.convert_to_tensor` under the name "b".</param> | |||||
| /// <param name="adj_x">Forwarded unchanged to `gen_math_ops.batch_mat_mul`. Defaults to false.</param> | |||||
| /// <param name="adj_y">Forwarded unchanged to `gen_math_ops.batch_mat_mul`. Defaults to false.</param> | |||||
| /// <param name="name">Optional name handed to `ops.name_scope`. Defaults to null.</param> | |||||
| /// <returns>The tensor produced by `gen_math_ops.batch_mat_mul`.</returns> | |||||
| public static Tensor batch_matmul(Tensor x, Tensor y, | |||||
| bool adj_x = false, bool adj_y = false, | |||||
| string name = null) | |||||
| { | |||||
| Tensor product = null; | |||||
| with(ops.name_scope(name, "MatMul", new Tensor[] { x, y }), scope => | |||||
| { | |||||
| // The scope value, not the caller-supplied name, is what names the generated op. | |||||
| string scopedName = scope; | |||||
| Tensor lhs = ops.convert_to_tensor(x, name: "a"); | |||||
| Tensor rhs = ops.convert_to_tensor(y, name: "b"); | |||||
| product = gen_math_ops.batch_mat_mul(lhs, rhs, adj_x, adj_y, scopedName); | |||||
| }); | |||||
| return product; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the complex conjugate of a complex number. | /// Returns the complex conjugate of a complex number. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -91,11 +91,69 @@ namespace TensorFlowNET.Examples | |||||
| // graph: the two constants and matmul. | // graph: the two constants and matmul. | ||||
| // | // | ||||
| // The output of the op is returned in 'result' as a numpy `ndarray` object. | // The output of the op is returned in 'result' as a numpy `ndarray` object. | ||||
| return with(tf.Session(), sess => | |||||
| using (sess = tf.Session()) | |||||
| { | { | ||||
| var result = sess.run(product); | var result = sess.run(product); | ||||
| Console.WriteLine(result.ToString()); // ==> [[ 12.]] | Console.WriteLine(result.ToString()); // ==> [[ 12.]] | ||||
| return result.Data<int>()[0] == 12; | |||||
| }; | |||||
| // `BatchMatMul` is actually embedded into the `MatMul` operation on the tensorflow.dll side. Whenever we ask | |||||
| // for a multiplication between tensors with rank > 2, the leading (rank - 2) dimensions are checked to be consistent | |||||
| // across the two tensors and an ordinary matrix multiplication is done on the remaining 2 dimensions. | |||||
| // | |||||
| // np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9]).reshape(3, 3, 3) | |||||
| // array([[[1, 2, 3], | |||||
| // [4, 5, 6], | |||||
| // [7, 8, 9]], | |||||
| // | |||||
| // [[1, 2, 3], | |||||
| // [4, 5, 6], | |||||
| // [7, 8, 9]], | |||||
| // | |||||
| // [[1, 2, 3], | |||||
| // [4, 5, 6], | |||||
| // [7, 8, 9]]]) | |||||
| var firstTensor = tf.convert_to_tensor( | |||||
| np.reshape( | |||||
| np.array<float>(1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9), | |||||
| 3, 3, 3)); | |||||
| // | |||||
| // np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0]).reshape(3,3,2) | |||||
| // array([[[0, 1], | |||||
| // [0, 1], | |||||
| // [0, 1]], | |||||
| // | |||||
| // [[0, 1], | |||||
| // [0, 0], | |||||
| // [1, 0]], | |||||
| // | |||||
| // [[1, 0], | |||||
| // [1, 0], | |||||
| // [1, 0]]]) | |||||
| var secondTensor = tf.convert_to_tensor( | |||||
| np.reshape( | |||||
| np.array<float>(0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0), | |||||
| 3, 3, 2)); | |||||
| var batchMul = tf.batch_matmul(firstTensor, secondTensor); | |||||
| var checkTensor = np.array<float>(0, 6, 0, 15, 0, 24, 3, 1, 6, 4, 9, 7, 6, 0, 15, 0, 24, 0); | |||||
| return with(tf.Session(), sess => | |||||
| { | |||||
| var result = sess.run(batchMul); | |||||
| Console.WriteLine(result.ToString()); | |||||
| // | |||||
| // ==> array([[[0, 6], | |||||
| // [0, 15], | |||||
| // [0, 24]], | |||||
| // | |||||
| // [[ 3, 1], | |||||
| // [ 6, 4], | |||||
| // [ 9, 7]], | |||||
| // | |||||
| // [[ 6, 0], | |||||
| // [15, 0], | |||||
| // [24, 0]]]) | |||||
| return np.reshape(result, 18) | |||||
| .array_equal(checkTensor); | |||||
| }); | }); | ||||
| } | } | ||||