You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

DatasetV2.cs 6.5 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. using System;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Tensorflow.Data;
  6. using Tensorflow.Framework.Models;
  7. using static Tensorflow.Binding;
  namespace Tensorflow
  {
      /// <summary>
      /// Base dataset type. Despite mirroring an abstract concept, this class is
      /// concrete; each transformation method below returns a new dataset object
      /// constructed from <c>this</c>.
      /// </summary>
      public class DatasetV2 : IDatasetV2
      {
          /// <summary>Dataset-op helper instance, available to derived datasets.</summary>
          protected dataset_ops ops = new dataset_ops();

          /// <summary>Class names associated with this dataset, if any.</summary>
          public string[] class_names { get; set; }

          /// <summary>Variant tensor for this dataset; it is the input to the <c>DatasetCardinality</c> op.</summary>
          public Tensor variant_tensor { get; set; }

          /// <summary>Per-component specs; backs <see cref="output_shapes"/>, <see cref="output_types"/> and <see cref="element_spec"/>.</summary>
          public TensorSpec[] structure { get; set; }

          /// <summary>
          /// Number of leading components of each element that <see cref="GetEnumerator"/>
          /// yields as the first tuple item. Any remaining components form the second item.
          /// </summary>
          public int FirstInputTensorCount { get; set; } = 1;

          /// <summary>Shape of every component in <see cref="structure"/>.</summary>
          public Shape[] output_shapes => structure.Select(x => x.shape).ToArray();

          /// <summary>Dtype of every component in <see cref="structure"/>.</summary>
          public TF_DataType[] output_types => structure.Select(x => x.dtype).ToArray();

          /// <summary>Alias for <see cref="structure"/>.</summary>
          public TensorSpec[] element_spec => structure;

          /// <summary>
          /// Value of <see cref="cardinality"/> converted to <c>int</c>.
          /// NOTE(review): relies on <c>numpy()</c>, so presumably only valid in eager mode — confirm.
          /// </summary>
          public int length => cardinality().numpy();

          /// <summary>Wraps this dataset in a <c>CacheDataset</c> using <paramref name="filename"/>.</summary>
          public IDatasetV2 cache(string filename = "")
              => new CacheDataset(this, filename: filename);

          /// <summary>Returns a <c>ConcatenateDataset</c> of this dataset followed by <paramref name="dataset"/>.</summary>
          public IDatasetV2 concatenate(IDatasetV2 dataset)
              => new ConcatenateDataset(this, dataset);

          /// <summary>Wraps this dataset in a <c>TakeDataset</c> with the given <paramref name="count"/>.</summary>
          public IDatasetV2 take(int count = -1)
              => new TakeDataset(this, count: count);

          /// <summary>Wraps this dataset in a <c>BatchDataset</c>.</summary>
          /// <param name="batch_size">Batch size forwarded to <c>BatchDataset</c>.</param>
          /// <param name="drop_remainder">Forwarded to <c>BatchDataset</c>.</param>
          public IDatasetV2 batch(int batch_size, bool drop_remainder = false)
              => new BatchDataset(this, batch_size, drop_remainder: drop_remainder);

          /// <summary>Wraps this dataset in a <c>PrefetchDataset</c>.</summary>
          public IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null)
              => new PrefetchDataset(this, buffer_size: buffer_size, slack_period: slack_period);

          /// <summary>Wraps this dataset in a <c>RepeatDataset</c> with the given <paramref name="count"/>.</summary>
          public IDatasetV2 repeat(int count = -1)
              => new RepeatDataset(this, count: count);

          /// <summary>Wraps this dataset in a <c>ShardDataset</c>.</summary>
          public IDatasetV2 shard(int num_shards, int index)
              => new ShardDataset(this, num_shards, index);

          /// <summary>Wraps this dataset in a <c>ShuffleDataset</c>.</summary>
          public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true)
              => new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration);

          /// <summary>Wraps this dataset in a <c>SkipDataset</c> with the given <paramref name="count"/>.</summary>
          public IDatasetV2 skip(int count)
              => new SkipDataset(this, count);

          /// <summary>Wraps this dataset in an <c>OptimizeDataset</c> with the given optimizations and configs.</summary>
          public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs)
              => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs);

          /// <summary>Wraps this dataset in a (sequential) <c>MapDataset</c> applying <paramref name="map_func"/>.</summary>
          public IDatasetV2 map(Func<Tensors, Tensors> map_func,
              bool use_inter_op_parallelism = true,
              bool preserve_cardinality = true,
              bool use_legacy_function = false)
              => new MapDataset(this,
                  map_func,
                  use_inter_op_parallelism: use_inter_op_parallelism,
                  preserve_cardinality: preserve_cardinality,
                  use_legacy_function: use_legacy_function);

          /// <summary>
          /// Wraps this dataset in a <c>ParallelMapDataset</c> applying <paramref name="map_func"/>
          /// with <paramref name="num_parallel_calls"/>; cardinality is always preserved.
          /// </summary>
          public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls)
              => new ParallelMapDataset(this, map_func,
                  num_parallel_calls: num_parallel_calls,
                  preserve_cardinality: true);

          /// <summary>Wraps this dataset in a <c>FilterDataset</c> using a tensor-valued predicate.</summary>
          public IDatasetV2 filter(Func<Tensors, Tensors> predicate_func)
              => new FilterDataset(this, predicate_func);

          /// <summary>Wraps this dataset in a <c>FilterDataset</c> using a <c>bool</c>-valued predicate.</summary>
          public IDatasetV2 filter(Func<Tensor, bool> predicate_func)
              => new FilterDataset(this, predicate_func);

          /// <summary>Creates an <see cref="OwnedIterator"/> over this dataset.</summary>
          /// <exception cref="NotImplementedException">Thrown when not executing eagerly.</exception>
          public OwnedIterator make_one_shot_iterator()
          {
              if (!tf.Context.executing_eagerly())
                  throw new NotImplementedException("make_one_shot_iterator is only implemented in eager mode.");

              // with ops.colocate_with(self._variant_tensor)
              return new OwnedIterator(this);
          }

          /// <summary>Wraps this dataset in a <c>FlatMapDataset</c> applying <paramref name="map_func"/>.</summary>
          public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func)
              => new FlatMapDataset(this, map_func);

          /// <summary>Wraps this dataset in a <c>ModelDataset</c> with the given autotune algorithm and budgets.</summary>
          public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget, long ram_budget)
              => new ModelDataset(this, algorithm, cpu_budget, ram_budget);

          /// <summary>Wraps this dataset in an <c>OptionsDataset</c> carrying <paramref name="options"/>.</summary>
          public IDatasetV2 with_options(DatasetOptions options)
              => new OptionsDataset(this, options);

          /// <summary>
          /// Applies the fixed default pipeline options: a HILL_CLIMB autotune model followed by
          /// an <c>OptimizeDataset</c> with a hard-coded set of graph rewrites. The returned dataset
          /// inherits this dataset's <see cref="FirstInputTensorCount"/>.
          /// </summary>
          public IDatasetV2 apply_options()
          {
              IDatasetV2 dataset = this;

              // (1) Threading options: not applied yet.

              // (2) Autotune options. Autotuning is always enabled here (not configurable yet).
              // Budgets of 0 are passed through as-is; presumably 0 means "runtime default",
              // as in TF Python — confirm.
              long cpu_budget = 0;
              long ram_budget = 0;
              dataset = dataset.model(AutotuneAlgorithm.HILL_CLIMB, cpu_budget, ram_budget);

              // (3) Graph rewrite options.
              var graph_rewrites = new[]
              {
                  "map_and_batch_fusion",
                  "map_parallelization",
                  "noop_elimination",
                  "shuffle_and_repeat_fusion"
              };
              var graph_rewrite_configs = new string[]
              {
                  "autotune_buffer_sizes:autotune:true",
                  "batch_parallelization:autotune:true",
                  "disable_prefetch_legacy_autotune:autotune:true",
                  "enable_gradient_descent:autotune:true",
                  "map_parallelization:autotune:true"
              };
              dataset = new OptimizeDataset(dataset, Array.Empty<string>(), Array.Empty<string>(), graph_rewrites, graph_rewrite_configs);

              // (4) Stats aggregator options: not applied yet.

              // Keep the input/remainder split used by GetEnumerator on the outermost wrapper.
              dataset.FirstInputTensorCount = this.FirstInputTensorCount;
              return dataset;
          }

          /// <summary>Executes the <c>DatasetCardinality</c> op on <see cref="variant_tensor"/>.</summary>
          /// <param name="name">Optional op name.</param>
          public Tensor cardinality(string name = null)
              => tf.Context.ExecuteOp("DatasetCardinality", name, new ExecuteOpArgs(variant_tensor));

          /// <summary>Describes the dataset's type name, component shapes, dtypes and <see cref="length"/>.</summary>
          public override string ToString()
              => $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, " +
                 $"types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}, " +
                 $"len: {length}";

          /// <summary>
          /// Enumerates elements through an <see cref="OwnedIterator"/>. Each element is split into
          /// its first <see cref="FirstInputTensorCount"/> components and the remainder. The remainder
          /// is <c>null</c> when there are exactly <see cref="FirstInputTensorCount"/> components.
          /// Enumeration ends when the iterator throws <c>StopIteration</c>.
          /// </summary>
          public IEnumerator<(Tensors, Tensors)> GetEnumerator()
          {
              using var ownedIterator = new OwnedIterator(this);
              while (true)
              {
                  Tensor[] results;
                  try
                  {
                      results = ownedIterator.next();
                  }
                  catch (StopIteration)
                  {
                      // The iterator signals exhaustion by throwing.
                      break;
                  }

                  var inputs = new Tensors(results.Take(FirstInputTensorCount));
                  var remainder = results.Length == FirstInputTensorCount
                      ? null
                      : new Tensors(results.Skip(FirstInputTensorCount));
                  yield return (inputs, remainder);
              }
          }

          /// <inheritdoc/>
          IEnumerator IEnumerable.GetEnumerator()
              => GetEnumerator();
      }
  }