You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

DatasetTest.cs 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow.Keras;
  7. using Tensorflow.UnitTest;
  8. using static Tensorflow.Binding;
  9. namespace TensorFlowNET.UnitTest.Dataset
  10. {
  11. [TestClass]
  12. public class DatasetTest : EagerModeTestBase
  13. {
  14. [TestMethod]
  15. public void Range()
  16. {
  17. int iStep = 0;
  18. long value = 0;
  19. var dataset = tf.data.Dataset.range(3);
  20. foreach(var (step, item) in enumerate(dataset))
  21. {
  22. Assert.AreEqual(iStep, step);
  23. iStep++;
  24. Assert.AreEqual(value, (long)item.Item1);
  25. value++;
  26. }
  27. }
  28. [TestMethod]
  29. public void Prefetch()
  30. {
  31. int iStep = 0;
  32. long value = 1;
  33. var dataset = tf.data.Dataset.range(1, 5, 2);
  34. dataset = dataset.prefetch(2);
  35. foreach (var (step, item) in enumerate(dataset))
  36. {
  37. Assert.AreEqual(iStep, step);
  38. iStep++;
  39. Assert.AreEqual(value, (long)item.Item1);
  40. value += 2;
  41. }
  42. }
  43. [TestMethod]
  44. public void FromTensorSlices()
  45. {
  46. var X = tf.constant(new[] { 2013, 2014, 2015, 2016, 2017 });
  47. var Y = tf.constant(new[] { 12000, 14000, 15000, 16500, 17500 });
  48. var dataset = tf.data.Dataset.from_tensor_slices(X, Y);
  49. int n = 0;
  50. foreach (var (item_x, item_y) in dataset)
  51. {
  52. print($"x:{item_x.numpy()},y:{item_y.numpy()}");
  53. n += 1;
  54. }
  55. Assert.AreEqual(5, n);
  56. }
  57. [TestMethod]
  58. public void FromTensor()
  59. {
  60. var X = new[] { 2013, 2014, 2015, 2016, 2017 };
  61. var dataset = tf.data.Dataset.from_tensor(X);
  62. int n = 0;
  63. foreach (var x in dataset)
  64. {
  65. Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray<int>()));
  66. n += 1;
  67. }
  68. Assert.AreEqual(1, n);
  69. }
  70. [TestMethod]
  71. public void Shard()
  72. {
  73. long value = 0;
  74. var dataset1 = tf.data.Dataset.range(10);
  75. var dataset2 = dataset1.shard(num_shards: 3, index: 0);
  76. foreach (var item in dataset2)
  77. {
  78. Assert.AreEqual(value, (long)item.Item1);
  79. value += 3;
  80. }
  81. value = 1;
  82. var dataset3 = dataset1.shard(num_shards: 3, index: 1);
  83. foreach (var item in dataset3)
  84. {
  85. Assert.AreEqual(value, (long)item.Item1);
  86. value += 3;
  87. }
  88. }
  89. [TestMethod]
  90. public void Skip()
  91. {
  92. long value = 7;
  93. var dataset = tf.data.Dataset.range(10);
  94. dataset = dataset.skip(7);
  95. foreach (var item in dataset)
  96. {
  97. Assert.AreEqual(value, (long)item.Item1);
  98. value ++;
  99. }
  100. }
  101. }
  102. }