diff --git a/src/TensorFlowNET.Core/Data/BatchDataset.cs b/src/TensorFlowNET.Core/Data/BatchDataset.cs
index 86d65e28..331cc5ba 100644
--- a/src/TensorFlowNET.Core/Data/BatchDataset.cs
+++ b/src/TensorFlowNET.Core/Data/BatchDataset.cs
@@ -28,7 +28,7 @@ namespace Tensorflow
             }
             else
             {
-                _structure = input_dataset.element_spec.Select(x => x._batch(-1)).ToArray();
+                structure = input_dataset.element_spec.Select(x => x._batch(-1)).ToArray();
             }
 
             variant_tensor = ops.batch_dataset_v2(input_dataset.variant_tensor,
diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs
index a422490c..09bc38ff 100644
--- a/src/TensorFlowNET.Core/Data/DatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs
@@ -15,13 +15,13 @@ namespace Tensorflow
         protected dataset_ops ops = new dataset_ops();
         public Tensor variant_tensor { get; set; }
 
-        public TensorSpec[] _structure { get; set; }
+        public TensorSpec[] structure { get; set; }
 
-        public TensorShape[] output_shapes => _structure.Select(x => x.shape).ToArray();
+        public TensorShape[] output_shapes => structure.Select(x => x.shape).ToArray();
 
-        public TF_DataType[] output_types => _structure.Select(x => x.dtype).ToArray();
+        public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray();
 
-        public TensorSpec[] element_spec => _structure;
+        public TensorSpec[] element_spec => structure;
 
         public IDatasetV2 take(int count = -1)
             => new TakeDataset(this, count: count);
@@ -37,13 +37,71 @@ namespace Tensorflow
 
         public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true)
             => new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration);
-        
+
+        /// <summary>
+        /// Wraps this dataset in an OptimizeDataset built from the given rewrites and configs.
+        /// </summary>
+        public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs)
+            => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs);
+
+        /// <summary>
+        /// Wraps this dataset in a ModelDataset built from the given algorithm and CPU budget.
+        /// </summary>
+        public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget)
+            => new ModelDataset(this, algorithm, cpu_budget);
+
+        /// <summary>
+        /// Wraps this dataset with a fixed list of graph rewrites, then with HILL_CLIMB autotune modelling.
+        /// </summary>
+        public IDatasetV2 apply_options()
+        {
+            // (1) Apply threading options (not implemented yet)
+            var graph_rewrites = new[]
+            {
+                "map_and_batch_fusion",
+                "noop_elimination",
+                "shuffle_and_repeat_fusion"
+            };
+
+            var graph_rewrite_configs = new string[0];
+
+            // (2) Apply graph rewrite options
+            var dataset = optimize(graph_rewrites, graph_rewrite_configs);
+
+            // (3) Apply autotune options
+            var autotune = true;
+            long cpu_budget = 0;
+
+            if (autotune)
+                dataset = dataset.model(AutotuneAlgorithm.HILL_CLIMB, cpu_budget);
+
+            // (4) Apply stats aggregator options (not implemented yet)
+
+            return dataset;
+        }
+
         public override string ToString()
-            => $"{GetType().Name} shapes: ({_structure[0].shape}, {_structure[1].shape}), types: (tf.{_structure[0].dtype.as_numpy_name()}, tf.{_structure[1].dtype.as_numpy_name()})";
+            => $"{GetType().Name} shapes: ({structure[0].shape}, {structure[1].shape}), types: (tf.{structure[0].dtype.as_numpy_name()}, tf.{structure[1].dtype.as_numpy_name()})";
 
         public IEnumerator<(Tensor, Tensor)> GetEnumerator()
         {
-            throw new NotImplementedException();
+            // Own the iterator for the whole enumeration. The `using` cleanup runs when
+            // the enumerator is disposed (e.g. at the end of a foreach), so the iterator
+            // resource is released even if the caller stops early.
+            using (var ownedIterator = new OwnedIterator(this))
+            {
+                // NOTE(review): termination assumes next() returns null once the dataset
+                // is exhausted; if it throws instead, that must be handled here - confirm.
+                Tensor[] results = ownedIterator.next();
+                while (results != null)
+                {
+                    yield return (results[0], results[1]);
+
+                    // Advance to the next element; without this the loop would
+                    // yield the first element forever.
+                    results = ownedIterator.next();
+                }
+            }
         }
 
         IEnumerator IEnumerable.GetEnumerator()
diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
index c1b6f863..c5c32013 100644
--- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
@@ -15,7 +15,7 @@ namespace Tensorflow
 
         TensorSpec[] element_spec { get; }
 
-        TensorSpec[] _structure { get; set; }
+        TensorSpec[] structure { get; set; }
 
         /// <summary>
         /// 
@@ -31,5 +31,21 @@ namespace Tensorflow
         IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null);
 
         IDatasetV2 take(int count);
+
+        /// <summary>
+        /// Returns a dataset with the named optimizations applied.
+        /// </summary>
+        IDatasetV2 optimize(string[] optimizations, string[] optimization_configs);
+
+        /// <summary>
+        /// Returns a dataset wrapped for performance modelling with the given autotune settings.
+        /// </summary>
+        IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget);
+
+        /// <summary>
+        /// Apply options, such as optimization configuration, to the dataset.
+        /// </summary>
+        /// <returns>The dataset with the options applied.</returns>
+        IDatasetV2 apply_options();
     }
 }
diff --git a/src/TensorFlowNET.Core/Data/IteratorBase.cs b/src/TensorFlowNET.Core/Data/IteratorBase.cs
new file mode 100644
index 00000000..159c0272
--- /dev/null
+++ b/src/TensorFlowNET.Core/Data/IteratorBase.cs
@@ -0,0 +1,13 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+    /// <summary>
+    /// Base type for dataset iterators. It currently declares no members.
+    /// </summary>
+    public class IteratorBase
+    {
+    }
+}
diff --git a/src/TensorFlowNET.Core/Data/IteratorResourceDeleter.cs b/src/TensorFlowNET.Core/Data/IteratorResourceDeleter.cs
new file mode 100644
index 00000000..2c981c81
--- /dev/null
+++ b/src/TensorFlowNET.Core/Data/IteratorResourceDeleter.cs
@@ -0,0 +1,38 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+
+namespace Tensorflow
+{
+    /// <summary>
+    /// An object which cleans up an iterator resource handle.
+    /// </summary>
+    public class IteratorResourceDeleter : IDisposable
+    {
+        Tensor _handle;
+        Tensor _deleter;
+        dataset_ops ops;
+        bool _disposed;
+
+        public IteratorResourceDeleter(Tensor handle, Tensor deleter)
+        {
+            _handle = handle;
+            _deleter = deleter;
+            ops = new dataset_ops();
+        }
+
+        /// <summary>
+        /// Deletes the iterator resource. Calls after the first one are no-ops.
+        /// </summary>
+        public void Dispose()
+        {
+            if (_disposed)
+                return;
+
+            // Set the flag first so the delete op is never issued twice.
+            _disposed = true;
+            ops.delete_iterator(_handle, _deleter);
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Data/ModelDataset.cs b/src/TensorFlowNET.Core/Data/ModelDataset.cs
new file mode 100644
index 00000000..3e44b0b2
--- /dev/null
+++ b/src/TensorFlowNET.Core/Data/ModelDataset.cs
@@ -0,0 +1,25 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Framework.Models;
+
+namespace Tensorflow
+{
+    /// <summary>
+    /// A `Dataset` that acts as an identity, and models performance.
+    /// </summary>
+    public class ModelDataset : UnaryUnchangedStructureDataset
+    {
+        public ModelDataset(IDatasetV2 input_dataset,
+            AutotuneAlgorithm algorithm,
+            long cpu_budget) :
+            base(input_dataset)
+        {
+            variant_tensor = ops.model_dataset(input_dataset.variant_tensor,
+                output_types,
+                output_shapes,
+                algorithm,
+                cpu_budget);
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Data/OptimizeDataset.cs b/src/TensorFlowNET.Core/Data/OptimizeDataset.cs
new file mode 100644
index 00000000..818980f4
--- /dev/null
+++ b/src/TensorFlowNET.Core/Data/OptimizeDataset.cs
@@ -0,0 +1,34 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+
+namespace Tensorflow
+{
+    /// <summary>
+    /// A `Dataset` that acts as an identity, and applies optimizations.
+    /// </summary>
+    public class OptimizeDataset : UnaryUnchangedStructureDataset
+    {
+        Tensor _optimizations;
+
+        public OptimizeDataset(IDatasetV2 dataset,
+            string[] optimizations = null,
+            string[] optimization_configs = null) :
+            base(dataset)
+        {
+            if (optimizations == null)
+                optimizations = new string[0];
+            if (optimization_configs == null)
+                optimization_configs = new string[0];
+
+            _optimizations = tf.convert_to_tensor(optimizations, dtype: TF_DataType.TF_STRING, name: "optimizations");
+            variant_tensor = ops.optimize_dataset(
+                _input_dataset.variant_tensor,
+                _optimizations,
+                output_types,
+                output_shapes,
+                optimization_configs: optimization_configs);
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Core/Data/OwnedIterator.cs b/src/TensorFlowNET.Core/Data/OwnedIterator.cs
new file mode 100644
index 00000000..bde214e8
--- /dev/null
+++ b/src/TensorFlowNET.Core/Data/OwnedIterator.cs
@@ -0,0 +1,50 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Framework.Models;
+
+namespace Tensorflow
+{
+    /// <summary>
+    /// An iterator producing tf.Tensor objects from a tf.data.Dataset.
+    /// Dispose it to delete the underlying iterator resource.
+    /// </summary>
+    public class OwnedIterator : IteratorBase, IDisposable
+    {
+        IDatasetV2 _dataset;
+        TensorSpec[] _element_spec;
+        dataset_ops ops = new dataset_ops();
+        Tensor _iterator_resource;
+        Tensor _deleter;
+        IteratorResourceDeleter _resource_deleter;
+
+        public OwnedIterator(IDatasetV2 dataset)
+        {
+            _create_iterator(dataset);
+        }
+
+        void _create_iterator(IDatasetV2 dataset)
+        {
+            dataset = dataset.apply_options();
+            _dataset = dataset;
+            _element_spec = dataset.element_spec;
+            (_iterator_resource, _deleter) = ops.anonymous_iterator_v2(_dataset.output_types, _dataset.output_shapes);
+            ops.make_iterator(dataset.variant_tensor, _iterator_resource);
+
+            // Owns the resource; it is deleted when this iterator is disposed
+            _resource_deleter = new IteratorResourceDeleter(_iterator_resource, _deleter);
+        }
+
+        /// <summary>
+        /// Requests the next element's tensors from the iterator resource via iterator_get_next.
+        /// </summary>
+        public Tensor[] next()
+            => ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes);
+
+        /// <summary>
+        /// Deletes the iterator resource through the resource deleter.
+        /// </summary>
+        public void Dispose()
+            => _resource_deleter?.Dispose();
+    }
+}
diff --git a/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs b/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs
index 910d56ba..e35cd8c5 100644
--- a/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs
+++ b/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs
@@ -15,7 +15,7 @@ namespace Tensorflow
         {
             _tensors = new[] { tf.convert_to_tensor(features), tf.convert_to_tensor(labels) };
             var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray();
-            _structure = batched_spec.Select(x => x._unbatch()).ToArray();
+            structure = batched_spec.Select(x => x._unbatch()).ToArray();
             variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes);
         }
 
diff --git a/src/TensorFlowNET.Core/Data/UnaryDataset.cs b/src/TensorFlowNET.Core/Data/UnaryDataset.cs
index 4ebce8cc..3cab9a96 100644
--- a/src/TensorFlowNET.Core/Data/UnaryDataset.cs
+++ b/src/TensorFlowNET.Core/Data/UnaryDataset.cs
@@ -15,7 +15,7 @@ namespace Tensorflow
         public UnaryDataset(IDatasetV2 input_dataset)
         {
             _input_dataset = input_dataset;
-            _structure = input_dataset._structure;
+            structure = input_dataset.structure;
         }
     }
 }