Browse Source

Add CacheDataset.

tags/v0.20
Oceania2018 5 years ago
parent
commit
a3cf7ae1ac
8 changed files with 177 additions and 0 deletions
  1. +23
    -0
      src/TensorFlowNET.Core/Data/CacheDataset.cs
  2. +13
    -0
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  3. +13
    -0
      src/TensorFlowNET.Core/Data/IDatasetV2.cs
  4. +28
    -0
      src/TensorFlowNET.Core/Data/MapDataset.cs
  5. +3
    -0
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  6. +3
    -0
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  7. +63
    -0
      src/TensorFlowNET.Core/Operations/dataset_ops.cs
  8. +31
    -0
      test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

+ 23
- 0
src/TensorFlowNET.Core/Data/CacheDataset.cs View File

@@ -0,0 +1,23 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class CacheDataset : UnaryUnchangedStructureDataset
{
Tensor _filename;
public CacheDataset(IDatasetV2 input_dataset,
string filename = "") :
base(input_dataset)
{
_filename = tf.convert_to_tensor(filename, dtype: TF_DataType.TF_STRING, name: "filename");
variant_tensor = ops.cache_dataset_v2(input_dataset.variant_tensor,
_filename,
ops.dummy_memory_cache(),
output_types,
output_shapes);
}
}
}

+ 13
- 0
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -23,6 +23,9 @@ namespace Tensorflow
public TensorSpec[] element_spec => structure;

public IDatasetV2 cache(string filename = "")
=> new CacheDataset(this, filename: filename);

public IDatasetV2 take(int count = -1)
=> new TakeDataset(this, count: count);

@@ -47,6 +50,16 @@ namespace Tensorflow
public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs)
=> new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs);

public IDatasetV2 map(Func<Tensor, Tensor> map_func,
bool use_inter_op_parallelism = true,
bool preserve_cardinality = false,
bool use_legacy_function = false)
=> new MapDataset(this,
map_func,
use_inter_op_parallelism: use_inter_op_parallelism,
preserve_cardinality: preserve_cardinality,
use_legacy_function: use_legacy_function);

public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget)
=> new ModelDataset(this, algorithm, cpu_budget);



+ 13
- 0
src/TensorFlowNET.Core/Data/IDatasetV2.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Text;
using Tensorflow.Framework.Models;

@@ -17,6 +18,13 @@ namespace Tensorflow

TensorSpec[] structure { get; set; }

/// <summary>
/// Caches the elements in this dataset.
/// </summary>
/// <param name="filename"></param>
/// <returns></returns>
IDatasetV2 cache(string filename="");

/// <summary>
///
/// </summary>
@@ -49,6 +57,11 @@ namespace Tensorflow

IDatasetV2 optimize(string[] optimizations, string[] optimization_configs);

IDatasetV2 map(Func<Tensor, Tensor> map_func,
bool use_inter_op_parallelism = true,
bool preserve_cardinality = false,
bool use_legacy_function = false);

IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget);

/// <summary>


+ 28
- 0
src/TensorFlowNET.Core/Data/MapDataset.cs View File

@@ -0,0 +1,28 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// A `Dataset` that maps a function over elements in its input.
/// </summary>
public class MapDataset : UnaryDataset
{
public MapDataset(IDatasetV2 input_dataset,
Func<Tensor, Tensor> map_func,
bool use_inter_op_parallelism = true,
bool preserve_cardinality = false,
bool use_legacy_function = false) : base(input_dataset)
{
foreach(var input in input_dataset)
{
var data = map_func(input.Item1);
}

variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
output_types,
output_shapes);
}
}
}

+ 3
- 0
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -381,6 +381,9 @@ namespace Tensorflow.Eager
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status.Handle);
status.Check(true);
break;
case TF_AttrType.TF_ATTR_FUNC:
c_api.TFE_OpSetAttrFunctionName(op, key, value.ToString(), value.ToString().Length);
break;
default:
throw new NotImplementedException($"SetOpAttrScalar for {type}");
}


+ 3
- 0
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -196,6 +196,9 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TFE_OpSetAttrBool(SafeOpHandle op, string attr_name, bool value);

[DllImport(TensorFlowLibName)]
public static extern void TFE_OpSetAttrFunctionName(SafeOpHandle op, string attr_name, string data, int length);

/// <summary>
///
/// </summary>


+ 63
- 0
src/TensorFlowNET.Core/Operations/dataset_ops.cs View File

@@ -155,6 +155,24 @@ namespace Tensorflow
throw new NotImplementedException("");
}

public Tensor cache_dataset_v2(Tensor input_dataset, Tensor filename, Tensor cache,
TF_DataType[] output_types, TensorShape[] output_shapes,
string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"CacheDatasetV2", name,
null,
input_dataset, filename, cache,
"output_types", output_types,
"output_shapes", output_shapes);
return results[0];
}

throw new NotImplementedException("");
}

/// <summary>
/// Creates a dataset that batches `batch_size` elements from `input_dataset`.
/// </summary>
@@ -187,6 +205,24 @@ namespace Tensorflow
throw new NotImplementedException("");
}

/// <summary>
///
/// </summary>
/// <param name="name"></param>
/// <returns></returns>
public Tensor dummy_memory_cache(string name = "")
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"DummyMemoryCache", name,
null);
return results[0];
}

throw new NotImplementedException("");
}

/// <summary>
/// Creates a dataset that asynchronously prefetches elements from `input_dataset`.
/// </summary>
@@ -354,6 +390,33 @@ namespace Tensorflow
throw new NotImplementedException("");
}

/// <summary>
///
/// </summary>
/// <param name="dataset"></param>
/// <param name="iterator"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor map_dataset(Tensor dataset, TF_DataType[] output_types, TensorShape[] output_shapes,
bool use_inter_op_parallelism = true, bool preserve_cardinality = false, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"MapDataset", name,
null,
dataset, new Tensor[0],
"f", "MapDataset",
"output_types", output_types,
"output_shapes", output_shapes,
"use_inter_op_parallelism", use_inter_op_parallelism,
"preserve_cardinality", preserve_cardinality);
return results[0];
}

throw new NotImplementedException("");
}

/// <summary>
/// A container for an iterator resource.
/// </summary>


+ 31
- 0
test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow;
using Tensorflow.Keras;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;
@@ -116,5 +117,35 @@ namespace TensorFlowNET.UnitTest.Dataset
value ++;
}
}

[TestMethod, Ignore]
public void Map()
{
long value = 0;

var dataset = tf.data.Dataset.range(3);
var dataset1 = dataset.map(x => x);

foreach (var item in dataset)
{
Assert.AreEqual(value, (long)item.Item1);
value++;
}
}

[TestMethod]
public void Cache()
{
long value = 0;

var dataset = tf.data.Dataset.range(5);
dataset = dataset.cache();

foreach (var item in dataset)
{
Assert.AreEqual(value, (long)item.Item1);
value++;
}
}
}
}

Loading…
Cancel
Save