Add `BatchMatMul` operation

Adding the `BatchMatMul` operation.
Expanding `BasicOperations` to test the new operation.
Antonio Cifonelli · 6 years ago · tags/v0.9
commit 0f4cb1aedb
5 changed files with 122 additions and 2 deletions
1. src/TensorFlowNET.Core/APIs/tf.linalg.cs (+3, -0)
2. src/TensorFlowNET.Core/Gradients/math_grad.cs (+5, -0)
3. src/TensorFlowNET.Core/Operations/gen_math_ops.cs (+35, -0)
4. src/TensorFlowNET.Core/Operations/math_ops.cs (+19, -0)
5. test/TensorFlowNET.Examples/BasicOperations.cs (+60, -2)

src/TensorFlowNET.Core/APIs/tf.linalg.cs (+3, -0)

@@ -11,5 +11,8 @@ namespace Tensorflow

        public static Tensor matmul(Tensor a, Tensor b)
            => gen_math_ops.mat_mul(a, b);

        public static Tensor batch_matmul(Tensor x, Tensor y)
            => gen_math_ops.batch_mat_mul(x, y);
    }
}
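
The new `tf.batch_matmul` wrapper keeps the same call shape as `tf.matmul`. As orientation only (not part of the commit), a minimal usage sketch built on the same NumSharp helpers used in `BasicOperations.cs` below; names, values and shapes are illustrative:

    // Two stacks of 2x3 and 3x2 matrices; batch_matmul multiplies them slice by slice,
    // so the result has shape (2, 2, 2).
    var a = tf.convert_to_tensor(np.reshape(
        np.array<float>(1, 2, 3, 4, 5, 6, 6, 5, 4, 3, 2, 1), 2, 2, 3));
    var b = tf.convert_to_tensor(np.reshape(
        np.array<float>(1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1), 2, 3, 2));

    var c = tf.batch_matmul(a, b);

    using (var sess = tf.Session())
    {
        var result = sess.run(c);   // NDArray with shape (2, 2, 2)
        Console.WriteLine(result.ToString());
    }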

src/TensorFlowNET.Core/Gradients/math_grad.cs (+5, -0)

@@ -153,6 +153,11 @@ namespace Tensorflow.Gradients
            return new Tensor[] { grad_a, grad_b };
        }

        public static Tensor[] _BatchMatMul(Operation op, Tensor[] grads)
        {
            throw new NotImplementedException();
        }

        [RegisterGradient("Mean")]
        public static Tensor[] _MeanGrad(Operation op, Tensor[] grads)
        {
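
`_BatchMatMul` is only a placeholder in this commit: it throws and carries no `[RegisterGradient]` attribute, so the new op cannot be differentiated yet. For reference only (not part of the commit), a sketch of what the gradient could look like for the default `adj_x = adj_y = false` case, reusing the `batch_mat_mul` kernel added below and assuming the same `op.inputs` / `grads` accessors the neighbouring gradient methods use:

    // Sketch only: for z = x · y taken per batch slice,
    //     dx = dz · yᵀ   and   dy = xᵀ · dz,
    // expressed here through the adj_x / adj_y flags of batch_mat_mul.
    [RegisterGradient("BatchMatMul")]
    public static Tensor[] _BatchMatMulGrad(Operation op, Tensor[] grads)
    {
        var grad = grads[0];
        var x = op.inputs[0];
        var y = op.inputs[1];

        var grad_x = gen_math_ops.batch_mat_mul(grad, y, adj_x: false, adj_y: true);
        var grad_y = gen_math_ops.batch_mat_mul(x, grad, adj_x: true, adj_y: false);

        return new Tensor[] { grad_x, grad_y };
    }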


src/TensorFlowNET.Core/Operations/gen_math_ops.cs (+35, -0)

@@ -471,6 +471,41 @@ namespace Tensorflow
            return _op.outputs[0];
        }
        /// <summary>
        /// Multiplies slices of the two tensors `x` and `y` in batches.
        /// </summary>
        /// <remarks>
        /// The `BatchMatMul` operation is embedded into the
        /// `MatMul` operation on the DLL side. However, the expected
        /// attributes are not the same, so we need to expose this
        /// method to pass the right args list to the `_apply_op_helper`
        /// function.
        ///
        /// For inputs of rank > 2, the first rank - 2 dimensions are treated
        /// as batch dimensions: they are held fixed, have to be consistent
        /// across the two tensors, and an ordinary matrix multiplication is
        /// applied over the remaining 2 dimensions.
        ///
        /// e.g.
        ///     x is (3, 6, 12); y is (3, 12, 6)
        ///     batch_matmul(x, y) ==> (3, 6, 6)
        /// </remarks>
        /// <param name="x">The left-hand input tensor.</param>
        /// <param name="y">The right-hand input tensor.</param>
        /// <param name="adj_x">If true, each 2-D slice of `x` is adjointed (transposed) before the multiplication.</param>
        /// <param name="adj_y">If true, each 2-D slice of `y` is adjointed (transposed) before the multiplication.</param>
        /// <param name="name">An optional name for the operation.</param>
        /// <returns>The batched product tensor.</returns>
        public static Tensor batch_mat_mul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null)
        {
            var _op = _op_def_lib._apply_op_helper(
                "BatchMatMul",
                name,
                args: new { x, y, adj_x, adj_y });
            return _op.outputs[0];
        }
        /// <summary>
        /// Returns the max of x and y (i.e. x > y ? x : y) element-wise.
        /// </summary>
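
The `adj_x` / `adj_y` arguments map onto the `BatchMatMul` op's attributes of the same name, which adjoint (transpose, and conjugate for complex types) each 2-D slice before multiplying. A small shape sketch (illustrative, not from the commit; `x` and `y` stand for tensors of the indicated shapes):

    // x has shape (3, 12, 6), y has shape (3, 12, 6):
    //   batch_mat_mul(x, y, adj_x: true)  ==> (3, 6, 6)    each slice: xᵀ (6x12) · y (12x6)
    //   batch_mat_mul(x, y, adj_y: true)  ==> (3, 12, 12)  each slice: x (12x6) · yᵀ (6x12)
    var z = gen_math_ops.batch_mat_mul(x, y, adj_x: true);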


src/TensorFlowNET.Core/Operations/math_ops.cs (+19, -0)

@@ -497,6 +497,25 @@ namespace Tensorflow
            return result;
        }

        public static Tensor batch_matmul(Tensor x, Tensor y,
            bool adj_x = false, bool adj_y = false,
            string name = null)
        {
            Tensor result = null;

            with(ops.name_scope(name, "MatMul", new Tensor[] { x, y }), scope =>
            {
                name = scope;

                x = ops.convert_to_tensor(x, name: "a");
                y = ops.convert_to_tensor(y, name: "b");

                result = gen_math_ops.batch_mat_mul(x, y, adj_x, adj_y, name);
            });

            return result;
        }

        /// <summary>
        /// Returns the complex conjugate of a complex number.
        /// </summary>
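
Note that the `tf.batch_matmul` wrapper in `tf.linalg.cs` above does not expose the adjoint flags, so a caller that needs a transposed-slice product would go through `math_ops.batch_matmul` (or `gen_math_ops.batch_mat_mul`) directly, e.g. (illustrative only, `x` and `y` assumed to be tensors):

    // Each 2-D slice of y is transposed before the per-slice multiplication.
    var z = math_ops.batch_matmul(x, y, adj_y: true);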


test/TensorFlowNET.Examples/BasicOperations.cs (+60, -2)

@@ -91,11 +91,69 @@ namespace TensorFlowNET.Examples
            // graph: the two constants and matmul.
            //
            // The output of the op is returned in 'result' as a numpy `ndarray` object.
-           return with(tf.Session(), sess =>
+           using (sess = tf.Session())
            {
                var result = sess.run(product);
                Console.WriteLine(result.ToString()); // ==> [[ 12.]]
                return result.Data<int>()[0] == 12;
            };

            // `BatchMatMul` is actually embedded into the `MatMul` operation on the tensorflow.dll side. Every time we ask
            // for a multiplication between matrices with rank > 2, the first rank - 2 dimensions are checked to be consistent
            // across the two matrices, and a common matrix multiplication is done on the residual 2 dimensions.
            //
            // np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9]).reshape(3, 3, 3)
            // array([[[1, 2, 3],
            //         [4, 5, 6],
            //         [7, 8, 9]],
            //
            //        [[1, 2, 3],
            //         [4, 5, 6],
            //         [7, 8, 9]],
            //
            //        [[1, 2, 3],
            //         [4, 5, 6],
            //         [7, 8, 9]]])
            var firstTensor = tf.convert_to_tensor(
                np.reshape(
                    np.array<float>(1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9),
                    3, 3, 3));
            //
            // np.array([0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0]).reshape(3, 3, 2)
            // array([[[0, 1],
            //         [0, 1],
            //         [0, 1]],
            //
            //        [[0, 1],
            //         [0, 0],
            //         [1, 0]],
            //
            //        [[1, 0],
            //         [1, 0],
            //         [1, 0]]])
            var secondTensor = tf.convert_to_tensor(
                np.reshape(
                    np.array<float>(0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0),
                    3, 3, 2));
            var batchMul = tf.batch_matmul(firstTensor, secondTensor);
            var checkTensor = np.array<float>(0, 6, 0, 15, 0, 24, 3, 1, 6, 4, 9, 7, 6, 0, 15, 0, 24, 0);
            return with(tf.Session(), sess =>
            {
                var result = sess.run(batchMul);
                Console.WriteLine(result.ToString());
                //
                // ==> array([[[ 0,  6],
                //             [ 0, 15],
                //             [ 0, 24]],
                //
                //            [[ 3,  1],
                //             [ 6,  4],
                //             [ 9,  7]],
                //
                //            [[ 6,  0],
                //             [15,  0],
                //             [24,  0]]])
                return np.reshape(result, 18)
                    .array_equal(checkTensor);
            });
        }
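
As a quick sanity check on `checkTensor`: the first output slice is the first slice of `firstTensor` times the first slice of `secondTensor`, i.e. [[1, 2, 3], [4, 5, 6], [7, 8, 9]] times [[0, 1], [0, 1], [0, 1]]. Every entry of the first output column is 0, and the second column holds the row sums 1+2+3 = 6, 4+5+6 = 15 and 7+8+9 = 24, giving [[0, 6], [0, 15], [0, 24]], which matches the first six entries of `checkTensor` (0, 6, 0, 15, 0, 24).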


