diff --git a/src/TensorFlowNET.Core/APIs/tf.tensor.cs b/src/TensorFlowNET.Core/APIs/tf.tensor.cs
index 8ba78f42..3121e354 100644
--- a/src/TensorFlowNET.Core/APIs/tf.tensor.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.tensor.cs
@@ -66,11 +66,18 @@ namespace Tensorflow
/// A name for the operation (optional)
/// if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects;
/// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value.
- public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) => gen_array_ops.split(
+ public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null)
+ => array_ops.split(
value: value,
+ num_split: num_split,
axis: axis,
+ name: name);
+
+ public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
+ => array_ops.split(
+ value: value,
num_split: num_split,
- name: name
- );
+ axis: axis,
+ name: name);
}
}
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
index 5fe9986c..1c5c344f 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
@@ -44,10 +44,11 @@ namespace Tensorflow.Eager
break;
}
c_api.TFE_OpAddInput(op, tensor_handle, status.Handle);
+ status.Check(true);
}
}
if (status.ok())
- SetOpAttrs(op, attrs, status.Handle);
+ SetOpAttrs(op, attrs);
var outputs = new IntPtr[num_outputs];
if (status.ok())
diff --git a/src/TensorFlowNET.Core/Eager/Execute.cs b/src/TensorFlowNET.Core/Eager/Execute.cs
index 52df5a7d..04c11a1d 100644
--- a/src/TensorFlowNET.Core/Eager/Execute.cs
+++ b/src/TensorFlowNET.Core/Eager/Execute.cs
@@ -44,27 +44,27 @@ namespace Tensorflow.Eager
return results;
}
- public (TF_DataType, EagerTensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null)
+ public (TF_DataType, Tensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null)
{
if (args.Length == 0 && default_dtype != TF_DataType.DtInvalid)
return (default_dtype, null);
- if (args.Count(x => x is EagerTensor) == args.Length)
- return ((args[0] as EagerTensor).dtype, args.Select(x => x as EagerTensor).ToArray());
+ if (args.Count(x => x is Tensor) == args.Length)
+ return ((args[0] as Tensor).dtype, args.Select(x => x as Tensor).ToArray());
var dtype = TF_DataType.DtInvalid;
foreach (var x in args)
{
- if (x is EagerTensor et)
+ if (x is Tensor et)
dtype = et.dtype;
}
if (dtype == TF_DataType.DtInvalid)
{
- var ret = new List();
+ var ret = new List();
foreach (var t in args)
{
- ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx) as EagerTensor);
+ ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx) as Tensor);
if (dtype == TF_DataType.DtInvalid)
dtype = ret.Last().dtype;
}
diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs
index 33c5f7c5..36148287 100644
--- a/src/TensorFlowNET.Core/Gradients/array_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs
@@ -117,7 +117,7 @@ namespace Tensorflow.Gradients
new Tensor[] { non_neg_concat_dim, tf.constant(0) },
new Tensor[] { tf.constant(1), tf.constant(-1) });
var squeeze_sizes = array_ops.squeeze(slice);
- out_grads = gen_array_ops.split(grad, squeeze_sizes, (int)non_neg_concat_dim).ToList();
+ out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split:(int)non_neg_concat_dim).ToList();
}
else
{
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
index 1cb352ae..d8d1cb3d 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
@@ -91,7 +91,7 @@ namespace Tensorflow
gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable);
// i = input_gate, j = new_input, f = forget_gate, o = output_gate
- var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one);
+ var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one);
var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]);
var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype);
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index 4cb55119..a4335046 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -18,6 +18,7 @@ using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
+using Tensorflow.Eager;
using Tensorflow.Framework;
using static Tensorflow.Binding;
@@ -668,14 +669,29 @@ namespace Tensorflow
});
}
- public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis,
+ public static Tensor[] split(Tensor value, int num_split, T axis,
string name = "split")
{
- var size_splits = ops.convert_to_tensor(num_or_size_splits);
- return gen_array_ops.split(axis: axis,
- num_split: num_or_size_splits,
- value: value,
- name: name);
+ var size_splits = ops.convert_to_tensor(num_split);
+
+ if (tf.context.executing_eagerly())
+ {
+ return split_eager_fallback(axis, value, num_split: num_split, name: name, ctx: tf.context);
+ }
+
+ var _op = tf._op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split });
+ return _op.outputs;
+ }
+
+ private static Tensor[] split_eager_fallback(Ta axis, Tv value, int num_split, string name, Context ctx = null)
+ {
+ var (_attr_T, input) = tf._execute.args_to_matching_eager(ctx, args: new object[] { value });
+ var axis_tensor = ops.convert_to_tensor(axis, dtype: TF_DataType.TF_INT32);
+ var _inputs_flat = new List { axis_tensor };
+ _inputs_flat.AddRange(input);
+ var _attrs = new object[] { "num_split", num_split, "T", _attr_T };
+
+ return tf._execute.execute(ctx, "Split", num_split, _inputs_flat.ToArray(), _attrs, name: name);
}
public static Tensor slice(Tensor input, Tb begin, Ts size, string name = null)
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index 1bf03e7f..575ea46e 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -483,22 +483,6 @@ namespace Tensorflow
return _op.outputs[0];
}
- public static Tensor[] split(Tensor axis, Tensor value, int num_split, string name = null)
- {
- if (tf.context.executing_eagerly())
- {
- var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
- "Split", name,
- null,
- axis, value, num_split);
-
- return results;
- }
-
- var _op = tf._op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split });
- return _op.outputs;
- }
-
public static Tensor tile(Tensor input, T multiples, string name = null)
{
if (tf.context.executing_eagerly())
diff --git a/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs b/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs
index d2bd19ca..f30321a1 100644
--- a/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs
+++ b/test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs
@@ -32,6 +32,7 @@ namespace Tensorflow.UnitTest.TF_API
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, a.shape));
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, b.shape));
}
+
[TestMethod]
public void ConcatTest()
{
@@ -42,6 +43,7 @@ namespace Tensorflow.UnitTest.TF_API
var concatValue = tf.concat(new[] { a, b, c }, axis: 0);
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape));
}
+
[TestMethod]
public void ConcatDoubleTest()
{//double type has some error
@@ -52,19 +54,19 @@ namespace Tensorflow.UnitTest.TF_API
var concatValue = tf.concat(new[] { a, b, c }, axis: 0);
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape));
}
+
[TestMethod]
- public void SplitTest()
+ public void ConcatAndSplitTest()
{
var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } });
var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } });
var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } });
- var concatValue = tf.concat(new[] { a, b, c }, axis: 0);
+ var value = tf.concat(new[] { a, b, c }, axis: 0);
- var splitValue = tf.split(concatValue, 3, axis: new Tensor(0));
+ var splitValue = tf.split(value, 3, axis: 0);
+ Assert.AreEqual(3, splitValue.Length);
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 2 }, splitValue[0].shape));
-
}
-
}
}