From 0c7638f8b50454fcfc72672a78902d0e784532c2 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 4 Jul 2020 20:41:28 -0500 Subject: [PATCH] Fix tf.split, #569 --- src/TensorFlowNET.Core/APIs/tf.tensor.cs | 13 +++++++-- .../Eager/EagerRunner.TFE_Execute.cs | 3 +- src/TensorFlowNET.Core/Eager/Execute.cs | 12 ++++---- .../Gradients/array_grad.cs | 2 +- .../Operations/NnOps/BasicLSTMCell.cs | 2 +- .../Operations/array_ops.cs | 28 +++++++++++++++---- .../Operations/gen_array_ops.cs | 16 ----------- .../TF_API/TensorOperate.cs | 12 ++++---- 8 files changed, 49 insertions(+), 39 deletions(-) diff --git a/src/TensorFlowNET.Core/APIs/tf.tensor.cs b/src/TensorFlowNET.Core/APIs/tf.tensor.cs index 8ba78f42..3121e354 100644 --- a/src/TensorFlowNET.Core/APIs/tf.tensor.cs +++ b/src/TensorFlowNET.Core/APIs/tf.tensor.cs @@ -66,11 +66,18 @@ namespace Tensorflow /// A name for the operation (optional) /// if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects; /// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value. - public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) => gen_array_ops.split( + public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) + => array_ops.split( value: value, + num_split: num_split, axis: axis, + name: name); + + public Tensor[] split(Tensor value, int num_split, int axis, string name = null) + => array_ops.split( + value: value, num_split: num_split, - name: name - ); + axis: axis, + name: name); } } diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs index 5fe9986c..1c5c344f 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs @@ -44,10 +44,11 @@ namespace Tensorflow.Eager break; } c_api.TFE_OpAddInput(op, tensor_handle, status.Handle); + status.Check(true); } } if (status.ok()) - SetOpAttrs(op, attrs, status.Handle); + SetOpAttrs(op, attrs); var outputs = new IntPtr[num_outputs]; if (status.ok()) diff --git a/src/TensorFlowNET.Core/Eager/Execute.cs b/src/TensorFlowNET.Core/Eager/Execute.cs index 52df5a7d..04c11a1d 100644 --- a/src/TensorFlowNET.Core/Eager/Execute.cs +++ b/src/TensorFlowNET.Core/Eager/Execute.cs @@ -44,27 +44,27 @@ namespace Tensorflow.Eager return results; } - public (TF_DataType, EagerTensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null) + public (TF_DataType, Tensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null) { if (args.Length == 0 && default_dtype != TF_DataType.DtInvalid) return (default_dtype, null); - if (args.Count(x => x is EagerTensor) == args.Length) - return ((args[0] as EagerTensor).dtype, args.Select(x => x as EagerTensor).ToArray()); + if (args.Count(x => x is Tensor) == args.Length) + return ((args[0] as Tensor).dtype, args.Select(x => x as Tensor).ToArray()); var dtype = TF_DataType.DtInvalid; foreach (var x in args) { - if (x is EagerTensor et) + if (x is Tensor et) dtype = et.dtype; } if (dtype == TF_DataType.DtInvalid) { - var ret = new List(); + var ret = new List(); foreach (var t in args) { - ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx) as EagerTensor); + ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx) as Tensor); if (dtype == TF_DataType.DtInvalid) dtype = ret.Last().dtype; } diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs index 33c5f7c5..36148287 100644 --- a/src/TensorFlowNET.Core/Gradients/array_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs @@ -117,7 +117,7 @@ namespace Tensorflow.Gradients new Tensor[] { non_neg_concat_dim, tf.constant(0) }, new Tensor[] { tf.constant(1), tf.constant(-1) }); var squeeze_sizes = array_ops.squeeze(slice); - out_grads = gen_array_ops.split(grad, squeeze_sizes, (int)non_neg_concat_dim).ToList(); + out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split:(int)non_neg_concat_dim).ToList(); } else { diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs index 1cb352ae..d8d1cb3d 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs @@ -91,7 +91,7 @@ namespace Tensorflow gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable); // i = input_gate, j = new_input, f = forget_gate, o = output_gate - var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one); + var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one); var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]); var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype); diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 4cb55119..a4335046 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -18,6 +18,7 @@ using NumSharp; using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Eager; using Tensorflow.Framework; using static Tensorflow.Binding; @@ -668,14 +669,29 @@ namespace Tensorflow }); } - public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis, + public static Tensor[] split(Tensor value, int num_split, T axis, string name = "split") { - var size_splits = ops.convert_to_tensor(num_or_size_splits); - return gen_array_ops.split(axis: axis, - num_split: num_or_size_splits, - value: value, - name: name); + var size_splits = ops.convert_to_tensor(num_split); + + if (tf.context.executing_eagerly()) + { + return split_eager_fallback(axis, value, num_split: num_split, name: name, ctx: tf.context); + } + + var _op = tf._op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split }); + return _op.outputs; + } + + private static Tensor[] split_eager_fallback(Ta axis, Tv value, int num_split, string name, Context ctx = null) + { + var (_attr_T, input) = tf._execute.args_to_matching_eager(ctx, args: new object[] { value }); + var axis_tensor = ops.convert_to_tensor(axis, dtype: TF_DataType.TF_INT32); + var _inputs_flat = new List { axis_tensor }; + _inputs_flat.AddRange(input); + var _attrs = new object[] { "num_split", num_split, "T", _attr_T }; + + return tf._execute.execute(ctx, "Split", num_split, _inputs_flat.ToArray(), _attrs, name: name); } public static Tensor slice(Tensor input, Tb begin, Ts size, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 1bf03e7f..575ea46e 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -483,22 +483,6 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor[] split(Tensor axis, Tensor value, int num_split, string name = null) - { - if (tf.context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Split", name, - null, - axis, value, num_split); - - return results; - } - - var _op = tf._op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split }); - return _op.outputs; - } - public static Tensor tile(Tensor input, T multiples, string name = null) { if (tf.context.executing_eagerly()) diff --git a/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs b/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs index d2bd19ca..f30321a1 100644 --- a/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs +++ b/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs @@ -32,6 +32,7 @@ namespace Tensorflow.UnitTest.TF_API Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, a.shape)); Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, b.shape)); } + [TestMethod] public void ConcatTest() { @@ -42,6 +43,7 @@ namespace Tensorflow.UnitTest.TF_API var concatValue = tf.concat(new[] { a, b, c }, axis: 0); Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape)); } + [TestMethod] public void ConcatDoubleTest() {//double type has some error @@ -52,19 +54,19 @@ namespace Tensorflow.UnitTest.TF_API var concatValue = tf.concat(new[] { a, b, c }, axis: 0); Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape)); } + [TestMethod] - public void SplitTest() + public void ConcatAndSplitTest() { var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } }); var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } }); var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } }); - var concatValue = tf.concat(new[] { a, b, c }, axis: 0); + var value = tf.concat(new[] { a, b, c }, axis: 0); - var splitValue = tf.split(concatValue, 3, axis: new Tensor(0)); + var splitValue = tf.split(value, 3, axis: 0); + Assert.AreEqual(3, splitValue.Length); Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 2 }, splitValue[0].shape)); - } - } }