| @@ -27,6 +27,18 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public string newaxis = ""; | public string newaxis = ""; | ||||
| /// <summary> | |||||
| /// BatchToSpace for N-D tensors of type T. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <param name="input"></param> | |||||
| /// <param name="block_shape"></param> | |||||
| /// <param name="crops"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor batch_to_space_nd<T>(T input, int[] block_shape, int[,] crops, string name = null) | |||||
| => gen_array_ops.batch_to_space_nd(input, block_shape, crops, name: name); | |||||
| public Tensor check_numerics(Tensor tensor, string message, string name = null) | public Tensor check_numerics(Tensor tensor, string message, string name = null) | ||||
| => gen_array_ops.check_numerics(tensor, message, name: name); | => gen_array_ops.check_numerics(tensor, message, name: name); | ||||
| @@ -26,6 +26,13 @@ namespace Tensorflow | |||||
| public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | public static OpDefLibrary _op_def_lib = new OpDefLibrary(); | ||||
| public static Execute _execute = new Execute(); | public static Execute _execute = new Execute(); | ||||
| public static Tensor batch_to_space_nd<T>(T input, int[] block_shape, int[,] crops, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("BatchToSpaceND", name: name, args: new { input, block_shape, crops }); | |||||
| return _op.output; | |||||
| } | |||||
| public static Tensor check_numerics(Tensor tensor, string message, string name = null) | public static Tensor check_numerics(Tensor tensor, string message, string name = null) | ||||
| { | { | ||||
| var _op = _op_def_lib._apply_op_helper("CheckNumerics", name: name, args: new { tensor, message }); | var _op = _op_def_lib._apply_op_helper("CheckNumerics", name: name, args: new { tensor, message }); | ||||
| @@ -242,5 +242,23 @@ namespace TensorFlowNET.UnitTest | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray<int>())); | Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0 }, result[2].ToArray<int>())); | ||||
| } | } | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void batch_to_space_nd() | |||||
| { | |||||
| var inputs = np.arange(24).reshape(4, 2, 3); | |||||
| var block_shape = new[] { 2, 2 }; | |||||
| int[,] crops = { { 0, 0 }, { 0, 0 } }; | |||||
| var tensor = tf.batch_to_space_nd(inputs, block_shape, crops); | |||||
| using (var sess = tf.Session()) | |||||
| { | |||||
| var result = sess.run(tensor); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 6, 1, 7, 2, 8 }, result[0, 0].ToArray<int>())); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 12, 18, 13, 19, 14, 20 }, result[0, 1].ToArray<int>())); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 9, 4, 10, 5, 11 }, result[0, 2].ToArray<int>())); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>())); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||