| @@ -1,4 +1,5 @@ | |||||
| using System; | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Data; | using Tensorflow.Data; | ||||
| @@ -7,6 +8,20 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class DatasetManager | public class DatasetManager | ||||
| { | { | ||||
| public IDatasetV2 from_generator<T>(IEnumerable<T> generator, TF_DataType[] output_types, TensorShape[] output_shapes) | |||||
| => new GeneratorDataset(); | |||||
| /// <summary> | |||||
| /// Creates a `Dataset` with a single element, comprising the given tensors. | |||||
| /// </summary> | |||||
| /// <param name="tensors"></param> | |||||
| /// <returns></returns> | |||||
| public IDatasetV2 from_tensor(NDArray tensors) | |||||
| => new TensorDataset(tensors); | |||||
| public IDatasetV2 from_tensor(Tensor tensors) | |||||
| => new TensorDataset(tensors); | |||||
| public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels) | public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels) | ||||
| => new TensorSliceDataset(features, labels); | => new TensorSliceDataset(features, labels); | ||||
| @@ -0,0 +1,11 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Data | |||||
| { | |||||
| public class GeneratorDataset : DatasetSource | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,33 @@ | |||||
| using NumSharp; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// A `Dataset` with a single element. | |||||
| /// </summary> | |||||
| public class TensorDataset : DatasetSource | |||||
| { | |||||
| public TensorDataset(Tensor element) | |||||
| { | |||||
| _tensors = new[] { element }; | |||||
| var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); | |||||
| structure = batched_spec.Select(x => x._unbatch()).ToArray(); | |||||
| variant_tensor = ops.tensor_dataset(_tensors, output_shapes); | |||||
| } | |||||
| public TensorDataset(NDArray element) | |||||
| { | |||||
| _tensors = new[] { tf.convert_to_tensor(element) }; | |||||
| var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); | |||||
| structure = batched_spec.ToArray(); | |||||
| variant_tensor = ops.tensor_dataset(_tensors, output_shapes); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -8,6 +8,24 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class dataset_ops | public class dataset_ops | ||||
| { | { | ||||
| public Tensor tensor_dataset(Tensor[] components, TensorShape[] output_shapes, string name = null) | |||||
| { | |||||
| if (tf.Context.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "TensorDataset", name, | |||||
| null, | |||||
| new object[] | |||||
| { | |||||
| components, | |||||
| "output_shapes", output_shapes | |||||
| }); | |||||
| return results[0]; | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Creates a dataset that emits each dim-0 slice of `components` once. | /// Creates a dataset that emits each dim-0 slice of `components` once. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -1,7 +1,9 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras; | |||||
| using Tensorflow.UnitTest; | using Tensorflow.UnitTest; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -62,6 +64,21 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
| Assert.AreEqual(5, n); | Assert.AreEqual(5, n); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void FromTensor() | |||||
| { | |||||
| var X = new[] { 2013, 2014, 2015, 2016, 2017 }; | |||||
| var dataset = tf.data.Dataset.from_tensor(X); | |||||
| int n = 0; | |||||
| foreach (var x in dataset) | |||||
| { | |||||
| Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray<int>())); | |||||
| n += 1; | |||||
| } | |||||
| Assert.AreEqual(1, n); | |||||
| } | |||||
| [TestMethod] | [TestMethod] | ||||
| public void Shard() | public void Shard() | ||||
| { | { | ||||