| @@ -28,7 +28,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| _structure = input_dataset.element_spec.Select(x => x._batch(-1)).ToArray(); | |||||
| structure = input_dataset.element_spec.Select(x => x._batch(-1)).ToArray(); | |||||
| } | } | ||||
| variant_tensor = ops.batch_dataset_v2(input_dataset.variant_tensor, | variant_tensor = ops.batch_dataset_v2(input_dataset.variant_tensor, | ||||
| @@ -15,13 +15,13 @@ namespace Tensorflow | |||||
| protected dataset_ops ops = new dataset_ops(); | protected dataset_ops ops = new dataset_ops(); | ||||
| public Tensor variant_tensor { get; set; } | public Tensor variant_tensor { get; set; } | ||||
| public TensorSpec[] _structure { get; set; } | |||||
| public TensorSpec[] structure { get; set; } | |||||
| public TensorShape[] output_shapes => _structure.Select(x => x.shape).ToArray(); | |||||
| public TensorShape[] output_shapes => structure.Select(x => x.shape).ToArray(); | |||||
| public TF_DataType[] output_types => _structure.Select(x => x.dtype).ToArray(); | |||||
| public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray(); | |||||
| public TensorSpec[] element_spec => _structure; | |||||
| public TensorSpec[] element_spec => structure; | |||||
| public IDatasetV2 take(int count = -1) | public IDatasetV2 take(int count = -1) | ||||
| => new TakeDataset(this, count: count); | => new TakeDataset(this, count: count); | ||||
| @@ -37,13 +37,52 @@ namespace Tensorflow | |||||
| public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true) | public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true) | ||||
| => new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration); | => new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration); | ||||
| public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs) | |||||
| => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs); | |||||
| public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget) | |||||
| => new ModelDataset(this, algorithm, cpu_budget); | |||||
| public IDatasetV2 apply_options() | |||||
| { | |||||
| // (1) Apply threading options | |||||
| var graph_rewrites = new[] | |||||
| { | |||||
| "map_and_batch_fusion", | |||||
| "noop_elimination", | |||||
| "shuffle_and_repeat_fusion" | |||||
| }; | |||||
| var graph_rewrite_configs = new string[0]; | |||||
| // (2) Apply graph rewrite options | |||||
| var dataset = optimize(graph_rewrites, graph_rewrite_configs); | |||||
| // (3) Apply autotune options | |||||
| var autotune = true; | |||||
| long cpu_budget = 0; | |||||
| if (autotune) | |||||
| dataset = dataset.model(AutotuneAlgorithm.HILL_CLIMB, cpu_budget); | |||||
| // (4) Apply stats aggregator options | |||||
| return dataset; | |||||
| } | |||||
| public override string ToString() | public override string ToString() | ||||
| => $"{GetType().Name} shapes: ({_structure[0].shape}, {_structure[1].shape}), types: (tf.{_structure[0].dtype.as_numpy_name()}, tf.{_structure[1].dtype.as_numpy_name()})"; | |||||
| => $"{GetType().Name} shapes: ({structure[0].shape}, {structure[1].shape}), types: (tf.{structure[0].dtype.as_numpy_name()}, tf.{structure[1].dtype.as_numpy_name()})"; | |||||
| public IEnumerator<(Tensor, Tensor)> GetEnumerator() | public IEnumerator<(Tensor, Tensor)> GetEnumerator() | ||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| var ownedIterator = new OwnedIterator(this); | |||||
| Tensor[] results = ownedIterator.next(); | |||||
| while (results != null) | |||||
| { | |||||
| yield return (results[0], results[1]); | |||||
| } | |||||
| } | } | ||||
| IEnumerator IEnumerable.GetEnumerator() | IEnumerator IEnumerable.GetEnumerator() | ||||
| @@ -15,7 +15,7 @@ namespace Tensorflow | |||||
| TensorSpec[] element_spec { get; } | TensorSpec[] element_spec { get; } | ||||
| TensorSpec[] _structure { get; set; } | |||||
| TensorSpec[] structure { get; set; } | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -31,5 +31,15 @@ namespace Tensorflow | |||||
| IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null); | IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null); | ||||
| IDatasetV2 take(int count); | IDatasetV2 take(int count); | ||||
| IDatasetV2 optimize(string[] optimizations, string[] optimization_configs); | |||||
| IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget); | |||||
| /// <summary> | |||||
| /// Apply options, such as optimization configuration, to the dataset. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| IDatasetV2 apply_options(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class IteratorBase | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,29 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// An object which cleans up an iterator resource handle. | |||||
| /// </summary> | |||||
| public class IteratorResourceDeleter : IDisposable | |||||
| { | |||||
| Tensor _handle; | |||||
| Tensor _deleter; | |||||
| dataset_ops ops; | |||||
| public IteratorResourceDeleter(Tensor handle, Tensor deleter) | |||||
| { | |||||
| _handle = handle; | |||||
| _deleter = deleter; | |||||
| ops = new dataset_ops(); | |||||
| } | |||||
| public void Dispose() | |||||
| { | |||||
| ops.delete_iterator(_handle, _deleter); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,25 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Framework.Models; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// A `Dataset` that acts as an identity, and models performance. | |||||
| /// </summary> | |||||
| public class ModelDataset : UnaryUnchangedStructureDataset | |||||
| { | |||||
| public ModelDataset(IDatasetV2 input_dataset, | |||||
| AutotuneAlgorithm algorithm, | |||||
| long cpu_budget) : | |||||
| base(input_dataset) | |||||
| { | |||||
| variant_tensor = ops.model_dataset(input_dataset.variant_tensor, | |||||
| output_types, | |||||
| output_shapes, | |||||
| algorithm, | |||||
| cpu_budget); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,34 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// A `Dataset` that acts as an identity, and applies optimizations. | |||||
| /// </summary> | |||||
| public class OptimizeDataset : UnaryUnchangedStructureDataset | |||||
| { | |||||
| Tensor _optimizations; | |||||
| public OptimizeDataset(IDatasetV2 dataset, | |||||
| string[] optimizations = null, | |||||
| string[] optimization_configs = null) : | |||||
| base(dataset) | |||||
| { | |||||
| if (optimizations == null) | |||||
| optimizations = new string[0]; | |||||
| if (optimization_configs == null) | |||||
| optimization_configs = new string[0]; | |||||
| _optimizations = tf.convert_to_tensor(optimizations, dtype: TF_DataType.TF_STRING, name: "optimizations"); | |||||
| variant_tensor = ops.optimize_dataset( | |||||
| _input_dataset.variant_tensor, | |||||
| _optimizations, | |||||
| output_types, | |||||
| output_shapes, | |||||
| optimization_configs: optimization_configs); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,40 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Framework.Models; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// An iterator producing tf.Tensor objects from a tf.data.Dataset. | |||||
| /// </summary> | |||||
| public class OwnedIterator : IteratorBase | |||||
| { | |||||
| IDatasetV2 _dataset; | |||||
| TensorSpec[] _element_spec; | |||||
| dataset_ops ops = new dataset_ops(); | |||||
| Tensor _iterator_resource; | |||||
| Tensor _deleter; | |||||
| IteratorResourceDeleter _resource_deleter; | |||||
| public OwnedIterator(IDatasetV2 dataset) | |||||
| { | |||||
| _create_iterator(dataset); | |||||
| } | |||||
| void _create_iterator(IDatasetV2 dataset) | |||||
| { | |||||
| dataset = dataset.apply_options(); | |||||
| _dataset = dataset; | |||||
| _element_spec = dataset.element_spec; | |||||
| (_iterator_resource, _deleter) = ops.anonymous_iterator_v2(_dataset.output_types, _dataset.output_shapes); | |||||
| ops.make_iterator(dataset.variant_tensor, _iterator_resource); | |||||
| // Delete the resource when this object is deleted | |||||
| _resource_deleter = new IteratorResourceDeleter(_iterator_resource, _deleter); | |||||
| } | |||||
| public Tensor[] next() | |||||
| => ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes); | |||||
| } | |||||
| } | |||||
| @@ -15,7 +15,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| _tensors = new[] { tf.convert_to_tensor(features), tf.convert_to_tensor(labels) }; | _tensors = new[] { tf.convert_to_tensor(features), tf.convert_to_tensor(labels) }; | ||||
| var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); | var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); | ||||
| _structure = batched_spec.Select(x => x._unbatch()).ToArray(); | |||||
| structure = batched_spec.Select(x => x._unbatch()).ToArray(); | |||||
| variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes); | variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes); | ||||
| } | } | ||||
| @@ -15,7 +15,7 @@ namespace Tensorflow | |||||
| public UnaryDataset(IDatasetV2 input_dataset) | public UnaryDataset(IDatasetV2 input_dataset) | ||||
| { | { | ||||
| _input_dataset = input_dataset; | _input_dataset = input_dataset; | ||||
| _structure = input_dataset._structure; | |||||
| structure = input_dataset.structure; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||