diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 7c260810..9e97154c 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -270,12 +270,10 @@ namespace Tensorflow int i = 0; foreach (var val in values) { - i += step; - - if (i < start) + if (i++ < start) continue; - yield return (i - step - start, val); + yield return (i - 1, val); } } diff --git a/src/TensorFlowNET.Core/Data/DatasetManager.cs b/src/TensorFlowNET.Core/Data/DatasetManager.cs index 515e0dd3..9110ef73 100644 --- a/src/TensorFlowNET.Core/Data/DatasetManager.cs +++ b/src/TensorFlowNET.Core/Data/DatasetManager.cs @@ -12,5 +12,8 @@ namespace Tensorflow public IDatasetV2 range(int count, TF_DataType output_type = TF_DataType.TF_INT64) => new RangeDataset(count, output_type: output_type); + + public IDatasetV2 range(int start, int stop, int step = 1, TF_DataType output_type = TF_DataType.TF_INT64) + => new RangeDataset(stop, start: start, step: step, output_type: output_type); } } diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs index 8db484d7..25ffe217 100644 --- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs +++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs @@ -26,5 +26,24 @@ namespace TensorFlowNET.UnitTest.Dataset value++; } } + + [TestMethod] + public void Prefetch() + { + int iStep = 0; + long value = 1; + + var dataset = tf.data.Dataset.range(1, 5, 2); + dataset = dataset.prefetch(2); + + foreach (var (step, item) in enumerate(dataset)) + { + Assert.AreEqual(iStep, step); + iStep++; + + Assert.AreEqual(value, (long)item.Item1); + value += 2; + } + } } } diff --git a/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs index 35fc46b2..f59fab8c 100644 --- a/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs +++ b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs @@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest.Keras /// /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding /// - [TestMethod] + [TestMethod, Ignore] public void Embedding() { var model = tf.keras.Sequential();