Browse Source

tf.data abstraction in IDatasetV2.

tags/v0.20
Oceania2018 5 years ago
parent
commit
ea2b4734a9
10 changed files with 198 additions and 11 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Data/BatchDataset.cs
  2. +46
    -7
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  3. +11
    -1
      src/TensorFlowNET.Core/Data/IDatasetV2.cs
  4. +10
    -0
      src/TensorFlowNET.Core/Data/IteratorBase.cs
  5. +29
    -0
      src/TensorFlowNET.Core/Data/IteratorResourceDeleter.cs
  6. +25
    -0
      src/TensorFlowNET.Core/Data/ModelDataset.cs
  7. +34
    -0
      src/TensorFlowNET.Core/Data/OptimizeDataset.cs
  8. +40
    -0
      src/TensorFlowNET.Core/Data/OwnedIterator.cs
  9. +1
    -1
      src/TensorFlowNET.Core/Data/TensorSliceDataset.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Data/UnaryDataset.cs

+ 1
- 1
src/TensorFlowNET.Core/Data/BatchDataset.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow
} }
else else
{ {
_structure = input_dataset.element_spec.Select(x => x._batch(-1)).ToArray();
structure = input_dataset.element_spec.Select(x => x._batch(-1)).ToArray();
} }


variant_tensor = ops.batch_dataset_v2(input_dataset.variant_tensor, variant_tensor = ops.batch_dataset_v2(input_dataset.variant_tensor,


+ 46
- 7
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -15,13 +15,13 @@ namespace Tensorflow
protected dataset_ops ops = new dataset_ops(); protected dataset_ops ops = new dataset_ops();
public Tensor variant_tensor { get; set; } public Tensor variant_tensor { get; set; }


public TensorSpec[] _structure { get; set; }
public TensorSpec[] structure { get; set; }


public TensorShape[] output_shapes => _structure.Select(x => x.shape).ToArray();
public TensorShape[] output_shapes => structure.Select(x => x.shape).ToArray();
public TF_DataType[] output_types => _structure.Select(x => x.dtype).ToArray();
public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray();
public TensorSpec[] element_spec => _structure;
public TensorSpec[] element_spec => structure;


public IDatasetV2 take(int count = -1) public IDatasetV2 take(int count = -1)
=> new TakeDataset(this, count: count); => new TakeDataset(this, count: count);
@@ -37,13 +37,52 @@ namespace Tensorflow


public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true) public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true)
=> new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration); => new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration);

public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs)
=> new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs);

public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget)
=> new ModelDataset(this, algorithm, cpu_budget);

public IDatasetV2 apply_options()
{
// (1) Apply threading options
var graph_rewrites = new[]
{
"map_and_batch_fusion",
"noop_elimination",
"shuffle_and_repeat_fusion"
};

var graph_rewrite_configs = new string[0];

// (2) Apply graph rewrite options
var dataset = optimize(graph_rewrites, graph_rewrite_configs);

// (3) Apply autotune options
var autotune = true;
long cpu_budget = 0;

if (autotune)
dataset = dataset.model(AutotuneAlgorithm.HILL_CLIMB, cpu_budget);

// (4) Apply stats aggregator options

return dataset;
}

public override string ToString() public override string ToString()
=> $"{GetType().Name} shapes: ({_structure[0].shape}, {_structure[1].shape}), types: (tf.{_structure[0].dtype.as_numpy_name()}, tf.{_structure[1].dtype.as_numpy_name()})";
=> $"{GetType().Name} shapes: ({structure[0].shape}, {structure[1].shape}), types: (tf.{structure[0].dtype.as_numpy_name()}, tf.{structure[1].dtype.as_numpy_name()})";


public IEnumerator<(Tensor, Tensor)> GetEnumerator() public IEnumerator<(Tensor, Tensor)> GetEnumerator()
{ {
throw new NotImplementedException();
var ownedIterator = new OwnedIterator(this);

Tensor[] results = ownedIterator.next();
while (results != null)
{
yield return (results[0], results[1]);
}
} }


IEnumerator IEnumerable.GetEnumerator() IEnumerator IEnumerable.GetEnumerator()


+ 11
- 1
src/TensorFlowNET.Core/Data/IDatasetV2.cs View File

@@ -15,7 +15,7 @@ namespace Tensorflow


TensorSpec[] element_spec { get; } TensorSpec[] element_spec { get; }


TensorSpec[] _structure { get; set; }
TensorSpec[] structure { get; set; }


/// <summary> /// <summary>
/// ///
@@ -31,5 +31,15 @@ namespace Tensorflow
IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null); IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null);


IDatasetV2 take(int count); IDatasetV2 take(int count);

IDatasetV2 optimize(string[] optimizations, string[] optimization_configs);

IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget);

/// <summary>
/// Apply options, such as optimization configuration, to the dataset.
/// </summary>
/// <returns></returns>
IDatasetV2 apply_options();
} }
} }

+ 10
- 0
src/TensorFlowNET.Core/Data/IteratorBase.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class IteratorBase
{
}
}

+ 29
- 0
src/TensorFlowNET.Core/Data/IteratorResourceDeleter.cs View File

@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
/// <summary>
/// An object which cleans up an iterator resource handle.
/// </summary>
public class IteratorResourceDeleter : IDisposable
{
Tensor _handle;
Tensor _deleter;
dataset_ops ops;

public IteratorResourceDeleter(Tensor handle, Tensor deleter)
{
_handle = handle;
_deleter = deleter;
ops = new dataset_ops();
}

public void Dispose()
{
ops.delete_iterator(_handle, _deleter);
}
}
}

+ 25
- 0
src/TensorFlowNET.Core/Data/ModelDataset.cs View File

@@ -0,0 +1,25 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Framework.Models;

namespace Tensorflow
{
/// <summary>
/// A `Dataset` that acts as an identity, and models performance.
/// </summary>
public class ModelDataset : UnaryUnchangedStructureDataset
{
public ModelDataset(IDatasetV2 input_dataset,
AutotuneAlgorithm algorithm,
long cpu_budget) :
base(input_dataset)
{
variant_tensor = ops.model_dataset(input_dataset.variant_tensor,
output_types,
output_shapes,
algorithm,
cpu_budget);
}
}
}

+ 34
- 0
src/TensorFlowNET.Core/Data/OptimizeDataset.cs View File

@@ -0,0 +1,34 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
/// <summary>
/// A `Dataset` that acts as an identity, and applies optimizations.
/// </summary>
public class OptimizeDataset : UnaryUnchangedStructureDataset
{
Tensor _optimizations;

public OptimizeDataset(IDatasetV2 dataset,
string[] optimizations = null,
string[] optimization_configs = null) :
base(dataset)
{
if (optimizations == null)
optimizations = new string[0];
if (optimization_configs == null)
optimization_configs = new string[0];

_optimizations = tf.convert_to_tensor(optimizations, dtype: TF_DataType.TF_STRING, name: "optimizations");
variant_tensor = ops.optimize_dataset(
_input_dataset.variant_tensor,
_optimizations,
output_types,
output_shapes,
optimization_configs: optimization_configs);
}
}
}

+ 40
- 0
src/TensorFlowNET.Core/Data/OwnedIterator.cs View File

@@ -0,0 +1,40 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Framework.Models;

namespace Tensorflow
{
/// <summary>
/// An iterator producing tf.Tensor objects from a tf.data.Dataset.
/// </summary>
public class OwnedIterator : IteratorBase
{
IDatasetV2 _dataset;
TensorSpec[] _element_spec;
dataset_ops ops = new dataset_ops();
Tensor _iterator_resource;
Tensor _deleter;
IteratorResourceDeleter _resource_deleter;

public OwnedIterator(IDatasetV2 dataset)
{
_create_iterator(dataset);
}

void _create_iterator(IDatasetV2 dataset)
{
dataset = dataset.apply_options();
_dataset = dataset;
_element_spec = dataset.element_spec;
(_iterator_resource, _deleter) = ops.anonymous_iterator_v2(_dataset.output_types, _dataset.output_shapes);
ops.make_iterator(dataset.variant_tensor, _iterator_resource);

// Delete the resource when this object is deleted
_resource_deleter = new IteratorResourceDeleter(_iterator_resource, _deleter);
}

public Tensor[] next()
=> ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes);
}
}

+ 1
- 1
src/TensorFlowNET.Core/Data/TensorSliceDataset.cs View File

@@ -15,7 +15,7 @@ namespace Tensorflow
{ {
_tensors = new[] { tf.convert_to_tensor(features), tf.convert_to_tensor(labels) }; _tensors = new[] { tf.convert_to_tensor(features), tf.convert_to_tensor(labels) };
var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray();
_structure = batched_spec.Select(x => x._unbatch()).ToArray();
structure = batched_spec.Select(x => x._unbatch()).ToArray();
variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes); variant_tensor = ops.tensor_slice_dataset(_tensors, output_shapes);
} }


+ 1
- 1
src/TensorFlowNET.Core/Data/UnaryDataset.cs View File

@@ -15,7 +15,7 @@ namespace Tensorflow
public UnaryDataset(IDatasetV2 input_dataset) public UnaryDataset(IDatasetV2 input_dataset)
{ {
_input_dataset = input_dataset; _input_dataset = input_dataset;
_structure = input_dataset._structure;
structure = input_dataset.structure;
} }
} }
} }

Loading…
Cancel
Save