You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

DatasetV2.cs 5.3 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. using System;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Tensorflow.Data;
  6. using Tensorflow.Framework.Models;
  7. using static Tensorflow.Binding;
  8. namespace Tensorflow
  9. {
  /// <summary>
  /// Concrete base implementation of <see cref="IDatasetV2"/>. The dataset's handle is
  /// <see cref="variant_tensor"/> and its element layout is described by <see cref="structure"/>.
  /// Every transformation method returns a new dataset object that wraps <c>this</c>.
  /// </summary>
  public class DatasetV2 : IDatasetV2
  {
      /// <summary>
      /// Op-building helper. Not used by this class itself
      /// (presumably consumed by derived datasets — TODO confirm).
      /// </summary>
      protected dataset_ops ops = new dataset_ops();

      /// <summary>Class names associated with this dataset; stored here but never read by this class.</summary>
      public string[] class_names { get; set; }

      /// <summary>Handle of the underlying dataset; passed as the input of the <c>DatasetCardinality</c> op.</summary>
      public Tensor variant_tensor { get; set; }

      /// <summary>
      /// Per-component specs of one dataset element; the source of <see cref="output_shapes"/>,
      /// <see cref="output_types"/> and <see cref="element_spec"/>.
      /// </summary>
      public TensorSpec[] structure { get; set; }

      /// <summary>Shape of each element component, in <see cref="structure"/> order.</summary>
      public TensorShape[] output_shapes => structure.Select(x => x.shape).ToArray();

      /// <summary>Data type of each element component, in <see cref="structure"/> order.</summary>
      public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray();

      /// <summary>Alias for <see cref="structure"/>.</summary>
      public TensorSpec[] element_spec => structure;

      // ----- Transformations -----------------------------------------------------------
      // Each method below wraps this dataset in the correspondingly named dataset type and
      // forwards its arguments unchanged; the transformation semantics live in those types.

      public IDatasetV2 cache(string filename = "")
          => new CacheDataset(this, filename: filename);

      public IDatasetV2 concatenate(IDatasetV2 dataset)
          => new ConcatenateDataset(this, dataset);

      public IDatasetV2 take(int count = -1)
          => new TakeDataset(this, count: count);

      public IDatasetV2 batch(int batch_size, bool drop_remainder = false)
          => new BatchDataset(this, batch_size, drop_remainder: drop_remainder);

      public IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null)
          => new PrefetchDataset(this, buffer_size: buffer_size, slack_period: slack_period);

      public IDatasetV2 repeat(int count = -1)
          => new RepeatDataset(this, count: count);

      public IDatasetV2 shard(int num_shards, int index)
          => new ShardDataset(this, num_shards, index);

      public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true)
          => new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration);

      public IDatasetV2 skip(int count)
          => new SkipDataset(this, count);

      public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs)
          => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs);

      /// <summary>Sequential map: wraps this dataset in a <see cref="MapDataset"/>.</summary>
      public IDatasetV2 map(Func<Tensors, Tensors> map_func,
          bool use_inter_op_parallelism = true,
          bool preserve_cardinality = true,
          bool use_legacy_function = false)
          => new MapDataset(this,
              map_func,
              use_inter_op_parallelism: use_inter_op_parallelism,
              preserve_cardinality: preserve_cardinality,
              use_legacy_function: use_legacy_function);

      /// <summary>Parallel map: selecting this overload wraps the dataset in a <see cref="ParallelMapDataset"/>.</summary>
      public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls)
          => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);

      public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func)
          => new FlatMapDataset(this, map_func);

      public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget)
          => new ModelDataset(this, algorithm, cpu_budget);

      public IDatasetV2 with_options(DatasetOptions options)
          => new OptionsDataset(this, options);

      /// <summary>
      /// Applies the default dataset options: a fixed set of graph rewrites via <see cref="optimize"/>,
      /// then hill-climb autotuning via <see cref="model"/>.
      /// </summary>
      /// <returns>The rewritten and autotuned dataset wrapping this one.</returns>
      public IDatasetV2 apply_options()
      {
          // (1) Threading options: not applied yet.

          // (2) Graph rewrite options: a fixed list of rewrites with no per-rewrite configs.
          var graph_rewrites = new[]
          {
              "map_and_batch_fusion",
              "noop_elimination",
              "shuffle_and_repeat_fusion"
          };
          var dataset = optimize(graph_rewrites, Array.Empty<string>());

          // (3) Autotune options: always enabled for now; the flag is kept as the hook for a
          // future option. cpu_budget 0 is forwarded unchanged to ModelDataset.
          var autotune = true;
          long cpu_budget = 0;
          if (autotune)
          {
              dataset = dataset.model(AutotuneAlgorithm.HILL_CLIMB, cpu_budget);
          }

          // (4) Stats aggregator options: not applied yet.
          return dataset;
      }

      /// <summary>Runs the <c>DatasetCardinality</c> op on <see cref="variant_tensor"/>.</summary>
      /// <param name="name">Optional name for the op.</param>
      /// <returns>The first output of the op.</returns>
      /// <exception cref="NotImplementedException">Thrown when not executing eagerly; graph mode is not supported yet.</exception>
      public Tensor dataset_cardinality(string name = null)
      {
          if (!tf.Context.executing_eagerly())
          {
              throw new NotImplementedException($"{nameof(dataset_cardinality)} is only implemented in eager mode.");
          }

          var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
              "DatasetCardinality", name,
              null,
              variant_tensor);
          return results[0];
      }

      /// <summary>
      /// Describes the dataset as "&lt;TypeName&gt; shapes: ..., types: tf.&lt;dtype&gt;, ...".
      /// </summary>
      public override string ToString()
      {
          // ToString overrides should not throw (debuggers and loggers call it freely),
          // so an unset structure is rendered as an empty component list.
          var specs = structure ?? Array.Empty<TensorSpec>();
          var shapes = string.Join(", ", specs.Select(x => x.shape));
          var types = string.Join(", ", specs.Select(x => "tf." + x.dtype.as_numpy_name()));
          return $"{GetType().Name} shapes: {shapes}, types: {types}";
      }

      /// <summary>
      /// Iterates the dataset through an <see cref="OwnedIterator"/>, yielding each element as a pair:
      /// the first component, and the second component or <c>null</c> for single-component elements.
      /// </summary>
      public IEnumerator<(Tensor, Tensor)> GetEnumerator()
      {
          using var ownedIterator = new OwnedIterator(this);
          while (true)
          {
              Tensor[] results;
              try
              {
                  results = ownedIterator.next();
              }
              catch (StopIteration)
              {
                  // The iterator signals exhaustion by throwing StopIteration.
                  break;
              }

              // The yield must stay outside the try: C# forbids yield return in a try with a catch.
              yield return (results[0], results.Length > 1 ? results[1] : null);
          }
      }

      IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
  }
  115. }