Browse Source

add tf.data Prefetch unit test #446

tags/v0.20
Oceania2018 5 years ago
parent
commit
b431f976c8
4 changed files with 25 additions and 5 deletions
  1. +2
    -4
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Data/DatasetManager.cs
  3. +19
    -0
      test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs
  4. +1
    -1
      test/TensorFlowNET.UnitTest/Keras/LayersTest.cs

+ 2
- 4
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -270,12 +270,10 @@ namespace Tensorflow
int i = 0; int i = 0;
foreach (var val in values) foreach (var val in values)
{ {
i += step;

if (i < start)
if (i++ < start)
continue; continue;


yield return (i - step - start, val);
yield return (i - 1, val);
} }
} }




+ 3
- 0
src/TensorFlowNET.Core/Data/DatasetManager.cs View File

@@ -12,5 +12,8 @@ namespace Tensorflow


public IDatasetV2 range(int count, TF_DataType output_type = TF_DataType.TF_INT64) public IDatasetV2 range(int count, TF_DataType output_type = TF_DataType.TF_INT64)
=> new RangeDataset(count, output_type: output_type); => new RangeDataset(count, output_type: output_type);

public IDatasetV2 range(int start, int stop, int step = 1, TF_DataType output_type = TF_DataType.TF_INT64)
=> new RangeDataset(stop, start: start, step: step, output_type: output_type);
} }
} }

+ 19
- 0
test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs View File

@@ -26,5 +26,24 @@ namespace TensorFlowNET.UnitTest.Dataset
value++; value++;
} }
} }

[TestMethod]
public void Prefetch()
{
int iStep = 0;
long value = 1;

var dataset = tf.data.Dataset.range(1, 5, 2);
dataset = dataset.prefetch(2);

foreach (var (step, item) in enumerate(dataset))
{
Assert.AreEqual(iStep, step);
iStep++;

Assert.AreEqual(value, (long)item.Item1);
value += 2;
}
}
} }
} }

+ 1
- 1
test/TensorFlowNET.UnitTest/Keras/LayersTest.cs View File

@@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest.Keras
/// <summary> /// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding /// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
/// </summary> /// </summary>
[TestMethod]
[TestMethod, Ignore]
public void Embedding() public void Embedding()
{ {
var model = tf.keras.Sequential(); var model = tf.keras.Sequential();


Loading…
Cancel
Save