From 436afe9703d708035a27ba586a723f71654da2b6 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 22 Aug 2020 19:59:06 -0500 Subject: [PATCH] tf.data.Dataset.from_tensor #446 --- src/TensorFlowNET.Core/Data/DatasetManager.cs | 17 +++++++++- .../Data/GeneratorDataset.cs | 11 +++++++ src/TensorFlowNET.Core/Data/TensorDataset.cs | 33 +++++++++++++++++++ .../Operations/dataset_ops.cs | 18 ++++++++++ .../Dataset/DatasetTest.cs | 17 ++++++++++ 5 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 src/TensorFlowNET.Core/Data/GeneratorDataset.cs create mode 100644 src/TensorFlowNET.Core/Data/TensorDataset.cs diff --git a/src/TensorFlowNET.Core/Data/DatasetManager.cs b/src/TensorFlowNET.Core/Data/DatasetManager.cs index 9110ef73..059aef24 100644 --- a/src/TensorFlowNET.Core/Data/DatasetManager.cs +++ b/src/TensorFlowNET.Core/Data/DatasetManager.cs @@ -1,4 +1,5 @@ -using System; +using NumSharp; +using System; using System.Collections.Generic; using System.Text; using Tensorflow.Data; @@ -7,6 +8,20 @@ namespace Tensorflow { public class DatasetManager { + public IDatasetV2 from_generator(IEnumerable generator, TF_DataType[] output_types, TensorShape[] output_shapes) + => new GeneratorDataset(); + + /// + /// Creates a `Dataset` with a single element, comprising the given tensors. + /// + /// + /// + public IDatasetV2 from_tensor(NDArray tensors) + => new TensorDataset(tensors); + + public IDatasetV2 from_tensor(Tensor tensors) + => new TensorDataset(tensors); + public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels) => new TensorSliceDataset(features, labels); diff --git a/src/TensorFlowNET.Core/Data/GeneratorDataset.cs b/src/TensorFlowNET.Core/Data/GeneratorDataset.cs new file mode 100644 index 00000000..937c1eb4 --- /dev/null +++ b/src/TensorFlowNET.Core/Data/GeneratorDataset.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Data +{ + public class GeneratorDataset : DatasetSource + { + + } +} diff --git a/src/TensorFlowNET.Core/Data/TensorDataset.cs b/src/TensorFlowNET.Core/Data/TensorDataset.cs new file mode 100644 index 00000000..78cd77fe --- /dev/null +++ b/src/TensorFlowNET.Core/Data/TensorDataset.cs @@ -0,0 +1,33 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// A `Dataset` with a single element. + /// + public class TensorDataset : DatasetSource + { + public TensorDataset(Tensor element) + { + _tensors = new[] { element }; + var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); + structure = batched_spec.Select(x => x._unbatch()).ToArray(); + + variant_tensor = ops.tensor_dataset(_tensors, output_shapes); + } + + public TensorDataset(NDArray element) + { + _tensors = new[] { tf.convert_to_tensor(element) }; + var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); + structure = batched_spec.ToArray(); + + variant_tensor = ops.tensor_dataset(_tensors, output_shapes); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/dataset_ops.cs b/src/TensorFlowNET.Core/Operations/dataset_ops.cs index 276dc462..bebe9153 100644 --- a/src/TensorFlowNET.Core/Operations/dataset_ops.cs +++ b/src/TensorFlowNET.Core/Operations/dataset_ops.cs @@ -8,6 +8,24 @@ namespace Tensorflow { public class dataset_ops { + public Tensor tensor_dataset(Tensor[] components, TensorShape[] output_shapes, string name = null) + { + if (tf.Context.executing_eagerly()) + { + var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "TensorDataset", name, + null, + new object[] + { + components, + "output_shapes", output_shapes + }); + return results[0]; + } + + throw new NotImplementedException(""); + } + /// /// Creates a dataset that emits each dim-0 slice of `components` once. /// diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs index 37430980..389aea52 100644 --- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs +++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs @@ -1,7 +1,9 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using System.Collections.Generic; +using System.Linq; using System.Text; +using Tensorflow.Keras; using Tensorflow.UnitTest; using static Tensorflow.Binding; @@ -62,6 +64,21 @@ namespace TensorFlowNET.UnitTest.Dataset Assert.AreEqual(5, n); } + [TestMethod] + public void FromTensor() + { + var X = new[] { 2013, 2014, 2015, 2016, 2017 }; + + var dataset = tf.data.Dataset.from_tensor(X); + int n = 0; + foreach (var x in dataset) + { + Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray())); + n += 1; + } + Assert.AreEqual(1, n); + } + [TestMethod] public void Shard() {