From 7979d2fb864255e85fd7a45c3bd10717ba0af4a2 Mon Sep 17 00:00:00 2001 From: pepure Date: Sat, 4 Jul 2020 00:22:00 +0800 Subject: [PATCH] add tf.split problem describe --- src/TensorFlowNET.Core/Operations/gen_array_ops.cs | 10 ++++++++++ .../TensorFlowNET.UnitTest/TF_API/TensorOperate.cs | 14 ++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 2111564c..c7f5591f 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -465,6 +465,16 @@ namespace Tensorflow public static Tensor[] split(Tensor axis, Tensor value, int num_split, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Split", name, + null, + axis, value, num_split); + + return results; + } + var _op = tf._op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split }); return _op.outputs; } diff --git a/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs b/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs index 06448e5e..d2bd19ca 100644 --- a/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs +++ b/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs @@ -52,5 +52,19 @@ namespace Tensorflow.UnitTest.TF_API var concatValue = tf.concat(new[] { a, b, c }, axis: 0); Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape)); } + [TestMethod] + public void SplitTest() + { + var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } }); + var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } }); + var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } }); + + var concatValue = tf.concat(new[] { a, b, c }, axis: 0); + + var splitValue = tf.split(concatValue, 3, axis: new Tensor(0)); + Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 2 }, splitValue[0].shape)); + + } + } }