Browse Source

IDataset cardinality.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
3d943a10c1
10 changed files with 155 additions and 33 deletions
  1. +2
    -0
      src/TensorFlowNET.Core/APIs/tf.data.cs
  2. +26
    -26
      src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs
  3. +10
    -2
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  4. +58
    -0
      src/TensorFlowNET.Core/Data/FilterDataset.cs
  5. +4
    -1
      src/TensorFlowNET.Core/Data/IDatasetV2.cs
  6. +19
    -0
      src/TensorFlowNET.Core/Operations/dataset_ops.cs
  7. +13
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Equal.cs
  8. +6
    -0
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  9. +1
    -1
      src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
  10. +16
    -3
      test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

+ 2
- 0
src/TensorFlowNET.Core/APIs/tf.data.cs View File

@@ -23,6 +23,8 @@ namespace Tensorflow
public class DataOps public class DataOps
{ {
public int AUTOTUNE = -1; public int AUTOTUNE = -1;
public int INFINITE_CARDINALITY = -1;
public int UNKNOWN_CARDINALITY = -2;
public DatasetManager Dataset { get; } = new DatasetManager(); public DatasetManager Dataset { get; } = new DatasetManager();
} }
} }


+ 26
- 26
src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs View File

@@ -29,48 +29,48 @@ namespace Tensorflow.Contexts
/// </summary> /// </summary>
public sealed partial class Context public sealed partial class Context
{ {
// [DebuggerStepThrough]
public Tensors ExecuteOp(string OpType, string Name, ExecuteOpArgs args)
Tensors ExecGraphAction(string OpType, string Name, ExecuteOpArgs args)
{ {
Func<Tensors> graphAction = () =>
var keywords = new Dictionary<string, object>();
if (args.OpInputArgs != null)
{ {
var keywords = new Dictionary<string, object>();
if(args.OpInputArgs != null)
{
foreach (var (i, input) in enumerate(args.OpInputArgs))
keywords[$"input_{i}"] = input;
}
foreach (var (i, input) in enumerate(args.OpInputArgs))
keywords[$"input_{i}"] = input;
}


if(args.OpAttrs != null)
{
foreach (var attr in args.OpAttrs)
keywords[attr.Key] = attr.Value;
}
if (args.OpAttrs != null)
{
foreach (var attr in args.OpAttrs)
keywords[attr.Key] = attr.Value;
}


return tf.OpDefLib._apply_op_helper(OpType, Name, keywords).outputs;
};
return tf.OpDefLib._apply_op_helper(OpType, Name, keywords).outputs;
}


Func<Tensors> eagerAction = () =>
Tensors ExecEagerAction(string OpType, string Name, ExecuteOpArgs args)
{
var opExecInfo = new FastPathOpExecInfo(OpType, Name, args.OpInputArgs)
{ {
var opExecInfo = new FastPathOpExecInfo(OpType, Name, args.OpInputArgs)
{
attrs = args.OpAttrs
};
return tf.Runner.TFE_FastPathExecute(opExecInfo);
attrs = args.OpAttrs
}; };
return tf.Runner.TFE_FastPathExecute(opExecInfo);
}


// [DebuggerStepThrough]
public Tensors ExecuteOp(string opType, string name, ExecuteOpArgs args)
{
if (tf.Context.has_graph_arg(args.OpInputArgs)) if (tf.Context.has_graph_arg(args.OpInputArgs))
{ {
if (executing_eagerly()) if (executing_eagerly())
{ {
graph_mode(); graph_mode();
var result = graphAction();
var result = ExecGraphAction(opType, name, args);
restore_mode(); restore_mode();
return result; return result;
} }
else else
{ {
var result = graphAction();
var result = ExecGraphAction(opType, name, args);
if (tf.Runner.MustRecordGradient()) if (tf.Runner.MustRecordGradient())
{ {
var op = result[0].op; var op = result[0].op;
@@ -92,14 +92,14 @@ namespace Tensorflow.Contexts
args1[i + 1] = arg.Value; args1[i + 1] = arg.Value;
i += 2; i += 2;
} }
tf.Runner.RecordGradient(OpType, op.inputs, args1, op.outputs);
tf.Runner.RecordGradient(opType, op.inputs, args1, op.outputs);
} }
return result; return result;
} }
} }
else else
{ {
return eagerAction();
return ExecEagerAction(opType, name, args);
} }
} }
} }


+ 10
- 2
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -70,6 +70,12 @@ namespace Tensorflow
num_parallel_calls: num_parallel_calls, num_parallel_calls: num_parallel_calls,
preserve_cardinality: true); preserve_cardinality: true);


public IDatasetV2 filter(Func<Tensors, Tensors> predicate_func)
=> new FilterDataset(this, predicate_func);

public IDatasetV2 filter(Func<Tensor, bool> predicate_func)
=> new FilterDataset(this, predicate_func);

public OwnedIterator make_one_shot_iterator() public OwnedIterator make_one_shot_iterator()
{ {
if (tf.Context.executing_eagerly()) if (tf.Context.executing_eagerly())
@@ -105,13 +111,15 @@ namespace Tensorflow
// (3) Apply graph rewrite options // (3) Apply graph rewrite options
var graph_rewrites = new[] var graph_rewrites = new[]
{ {
"noop_elimination",
"map_and_batch_fusion", "map_and_batch_fusion",
"map_parallelization",
"noop_elimination",
"shuffle_and_repeat_fusion" "shuffle_and_repeat_fusion"
}; };
var graph_rewrite_configs = new string[] var graph_rewrite_configs = new string[]
{ {
"autotune_buffer_sizes:autotune:true", "autotune_buffer_sizes:autotune:true",
"batch_parallelization:autotune:true",
"disable_prefetch_legacy_autotune:autotune:true", "disable_prefetch_legacy_autotune:autotune:true",
"enable_gradient_descent:autotune:true", "enable_gradient_descent:autotune:true",
"map_parallelization:autotune:true" "map_parallelization:autotune:true"
@@ -124,7 +132,7 @@ namespace Tensorflow
return dataset; return dataset;
} }


public Tensor dataset_cardinality(string name = null)
public Tensor cardinality(string name = null)
=> tf.Context.ExecuteOp("DatasetCardinality", name, new ExecuteOpArgs(variant_tensor)); => tf.Context.ExecuteOp("DatasetCardinality", name, new ExecuteOpArgs(variant_tensor));


public override string ToString() public override string ToString()


+ 58
- 0
src/TensorFlowNET.Core/Data/FilterDataset.cs View File

@@ -0,0 +1,58 @@
using System;
using Tensorflow.Functions;
using static Tensorflow.Binding;

namespace Tensorflow
{
/// <summary>
/// A `Dataset` that filters its input according to a predicate function.
/// </summary>
public class FilterDataset : UnaryDataset
{
public FilterDataset(IDatasetV2 input_dataset,
Func<Tensor, bool> predicate_func) : base(input_dataset)
{
Func<Tensors, Tensors> predicate_func_update = x =>
{
var result = predicate_func(x);
return constant_op.constant(result);
};

var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}");
func.Enter();
var inputs = new Tensors();
foreach (var input in input_dataset.element_spec)
inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
var outputs = predicate_func_update(inputs);
func.ToGraph(inputs, outputs);
func.Exit();

structure = func.OutputStructure;

variant_tensor = ops.filter_dataset(input_dataset.variant_tensor,
func,
output_types,
output_shapes);
}

public FilterDataset(IDatasetV2 input_dataset,
Func<Tensors, Tensors> predicate_func) : base(input_dataset)
{
var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}");
func.Enter();
var inputs = new Tensors();
foreach (var input in input_dataset.element_spec)
inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg"));
var outputs = predicate_func(inputs);
func.ToGraph(inputs, outputs);
func.Exit();

structure = func.OutputStructure;

variant_tensor = ops.filter_dataset(input_dataset.variant_tensor,
func,
output_types,
output_shapes);
}
}
}

+ 4
- 1
src/TensorFlowNET.Core/Data/IDatasetV2.cs View File

@@ -72,6 +72,9 @@ namespace Tensorflow
IDatasetV2 map(Func<Tensors, Tensors> map_func, IDatasetV2 map(Func<Tensors, Tensors> map_func,
int num_parallel_calls); int num_parallel_calls);


IDatasetV2 filter(Func<Tensors, Tensors> map_func);
IDatasetV2 filter(Func<Tensor, bool> map_func);

OwnedIterator make_one_shot_iterator(); OwnedIterator make_one_shot_iterator();


IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func); IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func);
@@ -91,6 +94,6 @@ namespace Tensorflow
/// </summary> /// </summary>
/// <param name="name"></param> /// <param name="name"></param>
/// <returns></returns> /// <returns></returns>
Tensor dataset_cardinality(string name = null);
Tensor cardinality(string name = null);
} }
} }

+ 19
- 0
src/TensorFlowNET.Core/Operations/dataset_ops.cs View File

@@ -249,6 +249,25 @@ namespace Tensorflow
preserve_cardinality preserve_cardinality
})); }));


/// <summary>
/// Creates a dataset containing elements of `input_dataset` matching `predicate`.
/// </summary>
/// <param name="dataset"></param>
/// <param name="predicate"></param>
/// <param name="output_types"></param>
/// <param name="output_shapes"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor filter_dataset(Tensor dataset, ConcreteFunction predicate, TF_DataType[] output_types, TensorShape[] output_shapes,
string name = null)
=> tf.Context.ExecuteOp("FilterDataset", name, new ExecuteOpArgs(dataset, new Tensor[0])
.SetAttributes(new
{
predicate,
output_types,
output_shapes
}));

/// <summary> /// <summary>
/// Creates a dataset that applies `f` to the outputs of `input_dataset`. /// Creates a dataset that applies `f` to the outputs of `input_dataset`.
/// </summary> /// </summary>


+ 13
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Equal.cs View File

@@ -0,0 +1,13 @@
using System;
using System.Runtime.CompilerServices;

namespace Tensorflow
{
public partial class Tensor
{
public static Tensor operator !=(Tensor x, int y)
=> gen_math_ops.not_equal(x, math_ops.cast(y, dtype: x.dtype));
public static Tensor operator ==(Tensor x, int y)
=> gen_math_ops.equal(x, math_ops.cast(y, dtype: x.dtype));
}
}

+ 6
- 0
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -144,6 +144,12 @@ namespace Tensorflow
break; break;
} }
} }
else if (dtype != TF_DataType.DtInvalid &&
value is NDArray nd &&
dtypes.as_dtype(nd.dtype) != dtype)
{
value = nd.astype(dtype.as_numpy_dtype());
}


if (dtype == TF_DataType.TF_STRING && value is byte[] bytes) if (dtype == TF_DataType.TF_STRING && value is byte[] bytes)
return new EagerTensor(bytes, ctx.DeviceName, TF_DataType.TF_STRING); return new EagerTensor(bytes, ctx.DeviceName, TF_DataType.TF_STRING);


+ 1
- 1
src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs View File

@@ -87,7 +87,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
if (adapter_steps > -1) if (adapter_steps > -1)
return adapter_steps; return adapter_steps;


var size = dataset.dataset_cardinality();
var size = dataset.cardinality();
return size.numpy(); return size.numpy();
} }




+ 16
- 3
test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs View File

@@ -147,10 +147,10 @@ namespace TensorFlowNET.UnitTest.Dataset
public void Cardinality() public void Cardinality()
{ {
var dataset = tf.data.Dataset.range(10); var dataset = tf.data.Dataset.range(10);
var cardinality = dataset.dataset_cardinality();
var cardinality = dataset.cardinality();
Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
dataset = dataset.map(x => x[0] + 1); dataset = dataset.map(x => x[0] + 1);
cardinality = dataset.dataset_cardinality();
cardinality = dataset.cardinality();
Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
} }


@@ -159,10 +159,23 @@ namespace TensorFlowNET.UnitTest.Dataset
{ {
var dataset = tf.data.Dataset.range(10); var dataset = tf.data.Dataset.range(10);
dataset = dataset.map(x => x, num_parallel_calls: -1); dataset = dataset.map(x => x, num_parallel_calls: -1);
var cardinality = dataset.dataset_cardinality();
var cardinality = dataset.cardinality();
Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); Assert.AreEqual(new long[] { 10 }, cardinality.numpy());
} }


[TestMethod]
public void CardinalityWithRepeat()
{
var dataset = tf.data.Dataset.range(10);
dataset = dataset.repeat();
var cardinality = dataset.cardinality();
Assert.IsTrue((cardinality == tf.data.INFINITE_CARDINALITY).numpy());

dataset = dataset.filter(x => true);
cardinality = dataset.cardinality();
Assert.IsTrue((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy());
}

[TestMethod] [TestMethod]
public void Shuffle() public void Shuffle()
{ {


Loading…
Cancel
Save