Browse Source

Fix tf.split, #569

tags/v0.20
Oceania2018 5 years ago
parent
commit
0c7638f8b5
8 changed files with 49 additions and 39 deletions
  1. +10
    -3
      src/TensorFlowNET.Core/APIs/tf.tensor.cs
  2. +2
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
  3. +6
    -6
      src/TensorFlowNET.Core/Eager/Execute.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Gradients/array_grad.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
  6. +22
    -6
      src/TensorFlowNET.Core/Operations/array_ops.cs
  7. +0
    -16
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  8. +7
    -5
      test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs

+ 10
- 3
src/TensorFlowNET.Core/APIs/tf.tensor.cs View File

@@ -66,11 +66,18 @@ namespace Tensorflow
/// <param name="name">A name for the operation (optional)</param>
/// <returns>if num_or_size_splits is a scalar returns num_or_size_splits Tensor objects;
/// if num_or_size_splits is a 1-D Tensor returns num_or_size_splits.get_shape[0] Tensor objects resulting from splitting value.</returns>
public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null) => gen_array_ops.split(
public Tensor[] split(Tensor value, int num_split, Tensor axis, string name = null)
=> array_ops.split(
value: value,
num_split: num_split,
axis: axis,
name: name);

public Tensor[] split(Tensor value, int num_split, int axis, string name = null)
=> array_ops.split(
value: value,
num_split: num_split,
name: name
);
axis: axis,
name: name);
}
}

+ 2
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs View File

@@ -44,10 +44,11 @@ namespace Tensorflow.Eager
break;
}
c_api.TFE_OpAddInput(op, tensor_handle, status.Handle);
status.Check(true);
}
}
if (status.ok())
SetOpAttrs(op, attrs, status.Handle);
SetOpAttrs(op, attrs);

var outputs = new IntPtr[num_outputs];
if (status.ok())


+ 6
- 6
src/TensorFlowNET.Core/Eager/Execute.cs View File

@@ -44,27 +44,27 @@ namespace Tensorflow.Eager
return results;
}

public (TF_DataType, EagerTensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null)
public (TF_DataType, Tensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null)
{
if (args.Length == 0 && default_dtype != TF_DataType.DtInvalid)
return (default_dtype, null);

if (args.Count(x => x is EagerTensor) == args.Length)
return ((args[0] as EagerTensor).dtype, args.Select(x => x as EagerTensor).ToArray());
if (args.Count(x => x is Tensor) == args.Length)
return ((args[0] as Tensor).dtype, args.Select(x => x as Tensor).ToArray());

var dtype = TF_DataType.DtInvalid;
foreach (var x in args)
{
if (x is EagerTensor et)
if (x is Tensor et)
dtype = et.dtype;
}

if (dtype == TF_DataType.DtInvalid)
{
var ret = new List<EagerTensor>();
var ret = new List<Tensor>();
foreach (var t in args)
{
ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx) as EagerTensor);
ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx) as Tensor);
if (dtype == TF_DataType.DtInvalid)
dtype = ret.Last().dtype;
}


+ 1
- 1
src/TensorFlowNET.Core/Gradients/array_grad.cs View File

@@ -117,7 +117,7 @@ namespace Tensorflow.Gradients
new Tensor[] { non_neg_concat_dim, tf.constant(0) },
new Tensor[] { tf.constant(1), tf.constant(-1) });
var squeeze_sizes = array_ops.squeeze(slice);
out_grads = gen_array_ops.split(grad, squeeze_sizes, (int)non_neg_concat_dim).ToList();
out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split:(int)non_neg_concat_dim).ToList();
}
else
{


+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs View File

@@ -91,7 +91,7 @@ namespace Tensorflow
gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable);

// i = input_gate, j = new_input, f = forget_gate, o = output_gate
var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one);
var tensors = array_ops.split(value: gate_inputs, num_split: 4, axis: one);
var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]);

var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype);


+ 22
- 6
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -18,6 +18,7 @@ using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Framework;
using static Tensorflow.Binding;

@@ -668,14 +669,29 @@ namespace Tensorflow
});
}

public static Tensor[] split(Tensor value, int num_or_size_splits, Tensor axis,
public static Tensor[] split<T>(Tensor value, int num_split, T axis,
string name = "split")
{
var size_splits = ops.convert_to_tensor(num_or_size_splits);
return gen_array_ops.split(axis: axis,
num_split: num_or_size_splits,
value: value,
name: name);
var size_splits = ops.convert_to_tensor(num_split);

if (tf.context.executing_eagerly())
{
return split_eager_fallback(axis, value, num_split: num_split, name: name, ctx: tf.context);
}

var _op = tf._op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split });
return _op.outputs;
}

private static Tensor[] split_eager_fallback<Ta, Tv>(Ta axis, Tv value, int num_split, string name, Context ctx = null)
{
var (_attr_T, input) = tf._execute.args_to_matching_eager(ctx, args: new object[] { value });
var axis_tensor = ops.convert_to_tensor(axis, dtype: TF_DataType.TF_INT32);
var _inputs_flat = new List<Tensor> { axis_tensor };
_inputs_flat.AddRange(input);
var _attrs = new object[] { "num_split", num_split, "T", _attr_T };

return tf._execute.execute(ctx, "Split", num_split, _inputs_flat.ToArray(), _attrs, name: name);
}

public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)


+ 0
- 16
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -483,22 +483,6 @@ namespace Tensorflow
return _op.outputs[0];
}

public static Tensor[] split(Tensor axis, Tensor value, int num_split, string name = null)
{
if (tf.context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Split", name,
null,
axis, value, num_split);

return results;
}

var _op = tf._op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split });
return _op.outputs;
}

public static Tensor tile<T>(Tensor input, T multiples, string name = null)
{
if (tf.context.executing_eagerly())


+ 7
- 5
test/TensorFlowNET.UnitTest/TF_API/TensorOperate.cs View File

@@ -32,6 +32,7 @@ namespace Tensorflow.UnitTest.TF_API
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, a.shape));
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 3, 1 }, b.shape));
}

[TestMethod]
public void ConcatTest()
{
@@ -42,6 +43,7 @@ namespace Tensorflow.UnitTest.TF_API
var concatValue = tf.concat(new[] { a, b, c }, axis: 0);
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape));
}

[TestMethod]
public void ConcatDoubleTest()
{//double type has some error
@@ -52,19 +54,19 @@ namespace Tensorflow.UnitTest.TF_API
var concatValue = tf.concat(new[] { a, b, c }, axis: 0);
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 6, 2 }, concatValue.shape));
}

[TestMethod]
public void SplitTest()
public void ConcatAndSplitTest()
{
var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } });
var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } });
var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } });

var concatValue = tf.concat(new[] { a, b, c }, axis: 0);
var value = tf.concat(new[] { a, b, c }, axis: 0);

var splitValue = tf.split(concatValue, 3, axis: new Tensor(0));
var splitValue = tf.split(value, 3, axis: 0);
Assert.AreEqual(3, splitValue.Length);
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 2, 2 }, splitValue[0].shape));

}

}
}

Loading…
Cancel
Save