| @@ -66,11 +66,18 @@ namespace Tensorflow | |||
| /// <param name="name">A name for the operation (optional)</param> | |||
| /// <returns>if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects; | |||
| /// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value.</returns> | |||
| public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) => gen_array_ops.split( | |||
| public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) | |||
| => array_ops.split( | |||
| value: value, | |||
| num_split: num_split, | |||
| axis: axis, | |||
| name: name); | |||
| public Tensor[] split(Tensor value, int num_split, int axis, string name = null) | |||
| => array_ops.split( | |||
| value: value, | |||
| num_split: num_split, | |||
| name: name | |||
| ); | |||
| axis: axis, | |||
| name: name); | |||
| } | |||
| } | |||
| @@ -44,10 +44,11 @@ namespace Tensorflow.Eager | |||
| break; | |||
| } | |||
| c_api.TFE_OpAddInput(op, tensor_handle, status.Handle); | |||
| status.Check(true); | |||
| } | |||
| } | |||
| if (status.ok()) | |||
| SetOpAttrs(op, attrs, status.Handle); | |||
| SetOpAttrs(op, attrs); | |||
| var outputs = new IntPtr[num_outputs]; | |||
| if (status.ok()) | |||
| @@ -44,27 +44,27 @@ namespace Tensorflow.Eager | |||
| return results; | |||
| } | |||
| public (TF_DataType, EagerTensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null) | |||
| public (TF_DataType, Tensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null) | |||
| { | |||
| if (args.Length == 0 && default_dtype != TF_DataType.DtInvalid) | |||
| return (default_dtype, null); | |||
| if (args.Count(x => x is EagerTensor) == args.Length) | |||
| return ((args[0] as EagerTensor).dtype, args.Select(x => x as EagerTensor).ToArray()); | |||
| if (args.Count(x => x is Tensor) == args.Length) | |||
| return ((args[0] as Tensor).dtype, args.Select(x => x as Tensor).ToArray()); | |||
| var dtype = TF_DataType.DtInvalid; | |||
| foreach (var x in args) | |||
| { | |||
| if (x is EagerTensor et) | |||
| if (x is Tensor et) | |||
| dtype = et.dtype; | |||
| } | |||
| if (dtype == TF_DataType.DtInvalid) | |||
| { | |||
| var ret = new List<EagerTensor>(); | |||
| var ret = new List<Tensor>(); | |||
| foreach (var t in args) | |||
| { | |||
| ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx) as EagerTensor); | |||
| ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx) as Tensor); | |||
| if (dtype == TF_DataType.DtInvalid) | |||
| dtype = ret.Last().dtype; | |||
| } | |||
| @@ -117,7 +117,7 @@ namespace Tensorflow.Gradients | |||
| new Tensor[] { non_neg_concat_dim, tf.constant(0) }, | |||
| new Tensor[] { tf.constant(1), tf.constant(-1) }); | |||
| var squeeze_sizes = array_ops.squeeze(slice); | |||
| out_grads = gen_array_ops.split(grad, squeeze_sizes, (int)non_neg_concat_dim).ToList(); | |||
| out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split:(int)non_neg_concat_dim).ToList(); | |||
| } | |||
| else | |||
| { | |||
| @@ -91,7 +91,7 @@ namespace Tensorflow | |||
| gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable); | |||
| // i = input_gate, j = new_input, f = forget_gate, o = output_gate | |||
| var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one); | |||
| var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one); | |||
| var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]); | |||
| var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype); | |||
| @@ -18,6 +18,7 @@ using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Framework; | |||
| using static Tensorflow.Binding; | |||
| @@ -668,14 +669,29 @@ namespace Tensorflow | |||
| }); | |||
| } | |||
| public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis, | |||
| public static Tensor[] split<T>(Tensor value, int num_split, T axis, | |||
| string name = "split") | |||
| { | |||
| var size_splits = ops.convert_to_tensor(num_or_size_splits); | |||
| return gen_array_ops.split(axis: axis, | |||
| num_split: num_or_size_splits, | |||
| value: value, | |||
| name: name); | |||
| var size_splits = ops.convert_to_tensor(num_split); | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| return split_eager_fallback(axis, value, num_split: num_split, name: name, ctx: tf.context); | |||
| } | |||
| var _op = tf._op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split }); | |||
| return _op.outputs; | |||
| } | |||
| private static Tensor[] split_eager_fallback<Ta, Tv>(Ta axis, Tv value, int num_split, string name, Context ctx = null) | |||
| { | |||
| var (_attr_T, input) = tf._execute.args_to_matching_eager(ctx, args: new object[] { value }); | |||
| var axis_tensor = ops.convert_to_tensor(axis, dtype: TF_DataType.TF_INT32); | |||
| var _inputs_flat = new List<Tensor> { axis_tensor }; | |||
| _inputs_flat.AddRange(input); | |||
| var _attrs = new object[] { "num_split", num_split, "T", _attr_T }; | |||
| return tf._execute.execute(ctx, "Split", num_split, _inputs_flat.ToArray(), _attrs, name: name); | |||
| } | |||
| public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) | |||
| @@ -483,22 +483,6 @@ namespace Tensorflow | |||
| return _op.outputs[0]; | |||
| } | |||
| public static Tensor[] split(Tensor axis, Tensor value, int num_split, string name = null) | |||
| { | |||
| if (tf.context.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name, | |||
| "Split", name, | |||
| null, | |||
| axis, value, num_split); | |||
| return results; | |||
| } | |||
| var _op = tf._op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split }); | |||
| return _op.outputs; | |||
| } | |||
| public static Tensor tile<T>(Tensor input, T multiples, string name = null) | |||
| { | |||
| if (tf.context.executing_eagerly()) | |||
| @@ -32,6 +32,7 @@ namespace Tensorflow.UnitTest.TF_API | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, a.shape)); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, b.shape)); | |||
| } | |||
| [TestMethod] | |||
| public void ConcatTest() | |||
| { | |||
| @@ -42,6 +43,7 @@ namespace Tensorflow.UnitTest.TF_API | |||
| var concatValue = tf.concat(new[] { a, b, c }, axis: 0); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape)); | |||
| } | |||
| [TestMethod] | |||
| public void ConcatDoubleTest() | |||
| {//double type has some error | |||
| @@ -52,19 +54,19 @@ namespace Tensorflow.UnitTest.TF_API | |||
| var concatValue = tf.concat(new[] { a, b, c }, axis: 0); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape)); | |||
| } | |||
| [TestMethod] | |||
| public void SplitTest() | |||
| public void ConcatAndSplitTest() | |||
| { | |||
| var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } }); | |||
| var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } }); | |||
| var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } }); | |||
| var concatValue = tf.concat(new[] { a, b, c }, axis: 0); | |||
| var value = tf.concat(new[] { a, b, c }, axis: 0); | |||
| var splitValue = tf.split(concatValue, 3, axis: new Tensor(0)); | |||
| var splitValue = tf.split(value, 3, axis: 0); | |||
| Assert.AreEqual(3, splitValue.Length); | |||
| Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 2 }, splitValue[0].shape)); | |||
| } | |||
| } | |||
| } | |||