From 3d943a10c17082321c97ec6aedb371cf91dec2e4 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 19 Jun 2021 21:07:39 -0500 Subject: [PATCH] IDataset cardinality. --- src/TensorFlowNET.Core/APIs/tf.data.cs | 2 + .../Contexts/Context.ExecuteOp.cs | 52 ++++++++--------- src/TensorFlowNET.Core/Data/DatasetV2.cs | 12 +++- src/TensorFlowNET.Core/Data/FilterDataset.cs | 58 +++++++++++++++++++ src/TensorFlowNET.Core/Data/IDatasetV2.cs | 5 +- .../Operations/dataset_ops.cs | 19 ++++++ .../Tensors/Tensor.Equal.cs | 13 +++++ src/TensorFlowNET.Core/Tensors/constant_op.cs | 6 ++ .../Engine/DataAdapters/DataHandler.cs | 2 +- .../Dataset/DatasetTest.cs | 19 +++++- 10 files changed, 155 insertions(+), 33 deletions(-) create mode 100644 src/TensorFlowNET.Core/Data/FilterDataset.cs create mode 100644 src/TensorFlowNET.Core/Tensors/Tensor.Equal.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.data.cs b/src/TensorFlowNET.Core/APIs/tf.data.cs index b4a92ce5..6c41a839 100644 --- a/src/TensorFlowNET.Core/APIs/tf.data.cs +++ b/src/TensorFlowNET.Core/APIs/tf.data.cs @@ -23,6 +23,8 @@ namespace Tensorflow public class DataOps { public int AUTOTUNE = -1; + public int INFINITE_CARDINALITY = -1; + public int UNKNOWN_CARDINALITY = -2; public DatasetManager Dataset { get; } = new DatasetManager(); } } diff --git a/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs b/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs index d6eedd47..5b256455 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs @@ -29,48 +29,48 @@ namespace Tensorflow.Contexts /// public sealed partial class Context { - // [DebuggerStepThrough] - public Tensors ExecuteOp(string OpType, string Name, ExecuteOpArgs args) + Tensors ExecGraphAction(string OpType, string Name, ExecuteOpArgs args) { - Func graphAction = () => + var keywords = new Dictionary(); + if (args.OpInputArgs != null) { - var keywords = new Dictionary(); - if(args.OpInputArgs != null) - { - foreach (var (i, input) in enumerate(args.OpInputArgs)) - keywords[$"input_{i}"] = input; - } + foreach (var (i, input) in enumerate(args.OpInputArgs)) + keywords[$"input_{i}"] = input; + } - if(args.OpAttrs != null) - { - foreach (var attr in args.OpAttrs) - keywords[attr.Key] = attr.Value; - } + if (args.OpAttrs != null) + { + foreach (var attr in args.OpAttrs) + keywords[attr.Key] = attr.Value; + } - return tf.OpDefLib._apply_op_helper(OpType, Name, keywords).outputs; - }; + return tf.OpDefLib._apply_op_helper(OpType, Name, keywords).outputs; + } - Func eagerAction = () => + Tensors ExecEagerAction(string OpType, string Name, ExecuteOpArgs args) + { + var opExecInfo = new FastPathOpExecInfo(OpType, Name, args.OpInputArgs) { - var opExecInfo = new FastPathOpExecInfo(OpType, Name, args.OpInputArgs) - { - attrs = args.OpAttrs - }; - return tf.Runner.TFE_FastPathExecute(opExecInfo); + attrs = args.OpAttrs }; + return tf.Runner.TFE_FastPathExecute(opExecInfo); + } + // [DebuggerStepThrough] + public Tensors ExecuteOp(string opType, string name, ExecuteOpArgs args) + { if (tf.Context.has_graph_arg(args.OpInputArgs)) { if (executing_eagerly()) { graph_mode(); - var result = graphAction(); + var result = ExecGraphAction(opType, name, args); restore_mode(); return result; } else { - var result = graphAction(); + var result = ExecGraphAction(opType, name, args); if (tf.Runner.MustRecordGradient()) { var op = result[0].op; @@ -92,14 +92,14 @@ namespace Tensorflow.Contexts args1[i + 1] = arg.Value; i += 2; } - tf.Runner.RecordGradient(OpType, op.inputs, args1, op.outputs); + tf.Runner.RecordGradient(opType, op.inputs, args1, op.outputs); } return result; } } else { - return eagerAction(); + return ExecEagerAction(opType, name, args); } } } diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs index 0ae6187a..ab07168f 100644 --- a/src/TensorFlowNET.Core/Data/DatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs @@ -70,6 +70,12 @@ namespace Tensorflow num_parallel_calls: num_parallel_calls, preserve_cardinality: true); + public IDatasetV2 filter(Func predicate_func) + => new FilterDataset(this, predicate_func); + + public IDatasetV2 filter(Func predicate_func) + => new FilterDataset(this, predicate_func); + public OwnedIterator make_one_shot_iterator() { if (tf.Context.executing_eagerly()) @@ -105,13 +111,15 @@ namespace Tensorflow // (3) Apply graph rewrite options var graph_rewrites = new[] { - "noop_elimination", "map_and_batch_fusion", + "map_parallelization", + "noop_elimination", "shuffle_and_repeat_fusion" }; var graph_rewrite_configs = new string[] { "autotune_buffer_sizes:autotune:true", + "batch_parallelization:autotune:true", "disable_prefetch_legacy_autotune:autotune:true", "enable_gradient_descent:autotune:true", "map_parallelization:autotune:true" @@ -124,7 +132,7 @@ namespace Tensorflow return dataset; } - public Tensor dataset_cardinality(string name = null) + public Tensor cardinality(string name = null) => tf.Context.ExecuteOp("DatasetCardinality", name, new ExecuteOpArgs(variant_tensor)); public override string ToString() diff --git a/src/TensorFlowNET.Core/Data/FilterDataset.cs b/src/TensorFlowNET.Core/Data/FilterDataset.cs new file mode 100644 index 00000000..84dfa0ae --- /dev/null +++ b/src/TensorFlowNET.Core/Data/FilterDataset.cs @@ -0,0 +1,58 @@ +using System; +using Tensorflow.Functions; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// A `Dataset` that filters its input according to a predicate function. + /// + public class FilterDataset : UnaryDataset + { + public FilterDataset(IDatasetV2 input_dataset, + Func predicate_func) : base(input_dataset) + { + Func predicate_func_update = x => + { + var result = predicate_func(x); + return constant_op.constant(result); + }; + + var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}"); + func.Enter(); + var inputs = new Tensors(); + foreach (var input in input_dataset.element_spec) + inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg")); + var outputs = predicate_func_update(inputs); + func.ToGraph(inputs, outputs); + func.Exit(); + + structure = func.OutputStructure; + + variant_tensor = ops.filter_dataset(input_dataset.variant_tensor, + func, + output_types, + output_shapes); + } + + public FilterDataset(IDatasetV2 input_dataset, + Func predicate_func) : base(input_dataset) + { + var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}"); + func.Enter(); + var inputs = new Tensors(); + foreach (var input in input_dataset.element_spec) + inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg")); + var outputs = predicate_func(inputs); + func.ToGraph(inputs, outputs); + func.Exit(); + + structure = func.OutputStructure; + + variant_tensor = ops.filter_dataset(input_dataset.variant_tensor, + func, + output_types, + output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs index 143b8f12..88d8bcb6 100644 --- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs @@ -72,6 +72,9 @@ namespace Tensorflow IDatasetV2 map(Func map_func, int num_parallel_calls); + IDatasetV2 filter(Func map_func); + IDatasetV2 filter(Func map_func); + OwnedIterator make_one_shot_iterator(); IDatasetV2 flat_map(Func map_func); @@ -91,6 +94,6 @@ namespace Tensorflow /// /// /// - Tensor dataset_cardinality(string name = null); + Tensor cardinality(string name = null); } } diff --git a/src/TensorFlowNET.Core/Operations/dataset_ops.cs b/src/TensorFlowNET.Core/Operations/dataset_ops.cs index c350ba9e..9fda99f7 100644 --- a/src/TensorFlowNET.Core/Operations/dataset_ops.cs +++ b/src/TensorFlowNET.Core/Operations/dataset_ops.cs @@ -249,6 +249,25 @@ namespace Tensorflow preserve_cardinality })); + /// + /// Creates a dataset containing elements of `input_dataset` matching `predicate`. + /// + /// + /// + /// + /// + /// + /// + public Tensor filter_dataset(Tensor dataset, ConcreteFunction predicate, TF_DataType[] output_types, TensorShape[] output_shapes, + string name = null) + => tf.Context.ExecuteOp("FilterDataset", name, new ExecuteOpArgs(dataset, new Tensor[0]) + .SetAttributes(new + { + predicate, + output_types, + output_shapes + })); + /// /// Creates a dataset that applies `f` to the outputs of `input_dataset`. /// diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Equal.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Equal.cs new file mode 100644 index 00000000..c3cbdb6a --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Equal.cs @@ -0,0 +1,13 @@ +using System; +using System.Runtime.CompilerServices; + +namespace Tensorflow +{ + public partial class Tensor + { + public static Tensor operator !=(Tensor x, int y) + => gen_math_ops.not_equal(x, math_ops.cast(y, dtype: x.dtype)); + public static Tensor operator ==(Tensor x, int y) + => gen_math_ops.equal(x, math_ops.cast(y, dtype: x.dtype)); + } +} diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 2aa63883..adfd9c24 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -144,6 +144,12 @@ namespace Tensorflow break; } } + else if (dtype != TF_DataType.DtInvalid && + value is NDArray nd && + dtypes.as_dtype(nd.dtype) != dtype) + { + value = nd.astype(dtype.as_numpy_dtype()); + } if (dtype == TF_DataType.TF_STRING && value is byte[] bytes) return new EagerTensor(bytes, ctx.DeviceName, TF_DataType.TF_STRING); diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs index a5b26e2c..fdc7fbbe 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs @@ -87,7 +87,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters if (adapter_steps > -1) return adapter_steps; - var size = dataset.dataset_cardinality(); + var size = dataset.cardinality(); return size.numpy(); } diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs index f624476c..b705284b 100644 --- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs +++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs @@ -147,10 +147,10 @@ namespace TensorFlowNET.UnitTest.Dataset public void Cardinality() { var dataset = tf.data.Dataset.range(10); - var cardinality = dataset.dataset_cardinality(); + var cardinality = dataset.cardinality(); Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); dataset = dataset.map(x => x[0] + 1); - cardinality = dataset.dataset_cardinality(); + cardinality = dataset.cardinality(); Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); } @@ -159,10 +159,23 @@ namespace TensorFlowNET.UnitTest.Dataset { var dataset = tf.data.Dataset.range(10); dataset = dataset.map(x => x, num_parallel_calls: -1); - var cardinality = dataset.dataset_cardinality(); + var cardinality = dataset.cardinality(); Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); } + [TestMethod] + public void CardinalityWithRepeat() + { + var dataset = tf.data.Dataset.range(10); + dataset = dataset.repeat(); + var cardinality = dataset.cardinality(); + Assert.IsTrue((cardinality == tf.data.INFINITE_CARDINALITY).numpy()); + + dataset = dataset.filter(x => true); + cardinality = dataset.cardinality(); + Assert.IsTrue((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy()); + } + [TestMethod] public void Shuffle() {