| @@ -28,7 +28,7 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| _structure = input_dataset.element_spec.Select(x => x._batch(-1)).ToArray(); | |||
| structure = input_dataset.element_spec.Select(x => x._batch(-1)).ToArray(); | |||
| } | |||
| variant_tensor = ops.batch_dataset_v2(input_dataset.variant_tensor, | |||
| @@ -15,13 +15,13 @@ namespace Tensorflow | |||
| protected dataset_ops ops = new dataset_ops(); | |||
| public Tensor variant_tensor { get; set; } | |||
| public TensorSpec[] _structure { get; set; } | |||
| public TensorSpec[] structure { get; set; } | |||
| public TensorShape[] output_shapes => _structure.Select(x => x.shape).ToArray(); | |||
| public TensorShape[] output_shapes => structure.Select(x => x.shape).ToArray(); | |||
| public TF_DataType[] output_types => _structure.Select(x => x.dtype).ToArray(); | |||
| public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray(); | |||
| public TensorSpec[] element_spec => _structure; | |||
| public TensorSpec[] element_spec => structure; | |||
| public IDatasetV2 take(int count = -1) | |||
| => new TakeDataset(this, count: count); | |||
| @@ -37,13 +37,52 @@ namespace Tensorflow | |||
| public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true) | |||
| => new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration); | |||
| public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs) | |||
| => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs); | |||
| public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget) | |||
| => new ModelDataset(this, algorithm, cpu_budget); | |||
| public IDatasetV2 apply_options() | |||
| { | |||
| // (1) Apply threading options | |||
| var graph_rewrites = new[] | |||
| { | |||
| "map_and_batch_fusion", | |||
| "noop_elimination", | |||
| "shuffle_and_repeat_fusion" | |||
| }; | |||
| var graph_rewrite_configs = new string[0]; | |||
| // (2) Apply graph rewrite options | |||
| var dataset = optimize(graph_rewrites, graph_rewrite_configs); | |||
| // (3) Apply autotune options | |||
| var autotune = true; | |||
| long cpu_budget = 0; | |||
| if (autotune) | |||
| dataset = dataset.model(AutotuneAlgorithm.HILL_CLIMB, cpu_budget); | |||
| // (4) Apply stats aggregator options | |||
| return dataset; | |||
| } | |||
| public override string ToString() | |||
| => $"{GetType().Name} shapes: ({_structure[0].shape}, {_structure[1].shape}), types: (tf.{_structure[0].dtype.as_numpy_name()}, tf.{_structure[1].dtype.as_numpy_name()})"; | |||
| => $"{GetType().Name} shapes: ({structure[0].shape}, {structure[1].shape}), types: (tf.{structure[0].dtype.as_numpy_name()}, tf.{structure[1].dtype.as_numpy_name()})"; | |||
| public IEnumerator<(Tensor, Tensor)> GetEnumerator() | |||
| { | |||
| throw new NotImplementedException(); | |||
| var ownedIterator = new OwnedIterator(this); | |||
| Tensor[] results = ownedIterator.next(); | |||
| while (results != null) | |||
| { | |||
| yield return (results[0], results[1]); | |||
| } | |||
| } | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| @@ -15,7 +15,7 @@ namespace Tensorflow | |||
| TensorSpec[] element_spec { get; } | |||
| TensorSpec[] _structure { get; set; } | |||
| TensorSpec[] structure { get; set; } | |||
| /// <summary> | |||
| /// | |||
| @@ -31,5 +31,15 @@ namespace Tensorflow | |||
| IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null); | |||
| IDatasetV2 take(int count); | |||
| IDatasetV2 optimize(string[] optimizations, string[] optimization_configs); | |||
| IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget); | |||
| /// <summary> | |||
| /// Apply options, such as optimization configuration, to the dataset. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| IDatasetV2 apply_options(); | |||
| } | |||
| } | |||
| @@ -0,0 +1,10 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class IteratorBase | |||
| { | |||
| } | |||
| } | |||
| @@ -0,0 +1,29 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// An object which cleans up an iterator resource handle. | |||
| /// </summary> | |||
| public class IteratorResourceDeleter : IDisposable | |||
| { | |||
| Tensor _handle; | |||
| Tensor _deleter; | |||
| dataset_ops ops; | |||
| public IteratorResourceDeleter(Tensor handle, Tensor deleter) | |||
| { | |||
| _handle = handle; | |||
| _deleter = deleter; | |||
| ops = new dataset_ops(); | |||
| } | |||
| public void Dispose() | |||
| { | |||
| ops.delete_iterator(_handle, _deleter); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,25 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Framework.Models; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// A `Dataset` that acts as an identity, and models performance. | |||
| /// </summary> | |||
| public class ModelDataset : UnaryUnchangedStructureDataset | |||
| { | |||
| public ModelDataset(IDatasetV2 input_dataset, | |||
| AutotuneAlgorithm algorithm, | |||
| long cpu_budget) : | |||
| base(input_dataset) | |||
| { | |||
| variant_tensor = ops.model_dataset(input_dataset.variant_tensor, | |||
| output_types, | |||
| output_shapes, | |||
| algorithm, | |||
| cpu_budget); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,34 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// A `Dataset` that acts as an identity, and applies optimizations. | |||
| /// </summary> | |||
| public class OptimizeDataset : UnaryUnchangedStructureDataset | |||
| { | |||
| Tensor _optimizations; | |||
| public OptimizeDataset(IDatasetV2 dataset, | |||
| string[] optimizations = null, | |||
| string[] optimization_configs = null) : | |||
| base(dataset) | |||
| { | |||
| if (optimizations == null) | |||
| optimizations = new string[0]; | |||
| if (optimization_configs == null) | |||
| optimization_configs = new string[0]; | |||
| _optimizations = tf.convert_to_tensor(optimizations, dtype: TF_DataType.TF_STRING, name: "optimizations"); | |||
| variant_tensor = ops.optimize_dataset( | |||
| _input_dataset.variant_tensor, | |||
| _optimizations, | |||
| output_types, | |||
| output_shapes, | |||
| optimization_configs: optimization_configs); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,40 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Framework.Models; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// An iterator producing tf.Tensor objects from a tf.data.Dataset. | |||
| /// </summary> | |||
| public class OwnedIterator : IteratorBase | |||
| { | |||
| IDatasetV2 _dataset; | |||
| TensorSpec[] _element_spec; | |||
| dataset_ops ops = new dataset_ops(); | |||
| Tensor _iterator_resource; | |||
| Tensor _deleter; | |||
| IteratorResourceDeleter _resource_deleter; | |||
| public OwnedIterator(IDatasetV2 dataset) | |||
| { | |||
| _create_iterator(dataset); | |||
| } | |||
| void _create_iterator(IDatasetV2 dataset) | |||
| { | |||
| dataset = dataset.apply_options(); | |||
| _dataset = dataset; | |||
| _element_spec = dataset.element_spec; | |||
| (_iterator_resource, _deleter) = ops.anonymous_iterator_v2(_dataset.output_types, _dataset.output_shapes); | |||
| ops.make_iterator(dataset.variant_tensor, _iterator_resource); | |||
| // Delete the resource when this object is deleted | |||
| _resource_deleter = new IteratorResourceDeleter(_iterator_resource, _deleter); | |||
| } | |||
| public Tensor[] next() | |||
| => ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes); | |||
| } | |||
| } | |||
| @@ -15,7 +15,7 @@ namespace Tensorflow | |||
| { | |||
| _tensors = new[] { tf.convert_to_tensor(features), tf.convert_to_tensor(labels) }; | |||
| var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); | |||
| _structure = batched_spec.Select(x => x._unbatch()).ToArray(); | |||
| structure = batched_spec.Select(x => x._unbatch()).ToArray(); | |||
| variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes); | |||
| } | |||
| @@ -15,7 +15,7 @@ namespace Tensorflow | |||
| public UnaryDataset(IDatasetV2 input_dataset) | |||
| { | |||
| _input_dataset = input_dataset; | |||
| _structure = input_dataset._structure; | |||
| structure = input_dataset.structure; | |||
| } | |||
| } | |||
| } | |||