diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index 77549574..bef72417 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -27,6 +27,18 @@ namespace Tensorflow /// public string newaxis = ""; + /// + /// BatchToSpace for N-D tensors of type T. + /// + /// + /// + /// + /// + /// + /// + public Tensor batch_to_space_nd(T input, int[] block_shape, int[,] crops, string name = null) + => gen_array_ops.batch_to_space_nd(input, block_shape, crops, name: name); + public Tensor check_numerics(Tensor tensor, string message, string name = null) => gen_array_ops.check_numerics(tensor, message, name: name); diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 6ed03040..5d037a09 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -26,6 +26,13 @@ namespace Tensorflow public static OpDefLibrary _op_def_lib = new OpDefLibrary(); public static Execute _execute = new Execute(); + public static Tensor batch_to_space_nd(T input, int[] block_shape, int[,] crops, string name = null) + { + var _op = _op_def_lib._apply_op_helper("BatchToSpaceND", name: name, args: new { input, block_shape, crops }); + + return _op.output; + } + public static Tensor check_numerics(Tensor tensor, string message, string name = null) { var _op = _op_def_lib._apply_op_helper("CheckNumerics", name: name, args: new { tensor, message }); diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs index 6b5b5dec..b73f15a9 100644 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorTest.cs @@ -242,5 +242,23 @@ namespace TensorFlowNET.UnitTest Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray())); } } + + [TestMethod] + public void batch_to_space_nd() + { + var inputs = np.arange(24).reshape(4, 2, 3); + var block_shape = new[] { 2, 2 }; + int[,] crops = { { 0, 0 }, { 0, 0 } }; + var tensor = tf.batch_to_space_nd(inputs, block_shape, crops); + + using (var sess = tf.Session()) + { + var result = sess.run(tensor); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 6, 1, 7, 2, 8 }, result[0, 0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 12, 18, 13, 19, 14, 20 }, result[0, 1].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 9, 4, 10, 5, 11 }, result[0, 2].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray())); + } + } } } \ No newline at end of file