| @@ -41,6 +41,9 @@ namespace Tensorflow | |||||
| public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true) | public IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true) | ||||
| => new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration); | => new ShuffleDataset(this, buffer_size, seed: seed, reshuffle_each_iteration: reshuffle_each_iteration); | ||||
| public IDatasetV2 skip(int count) | |||||
| => new SkipDataset(this, count); | |||||
| public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs) | public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs) | ||||
| => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs); | => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs); | ||||
| @@ -34,6 +34,13 @@ namespace Tensorflow | |||||
| IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true); | IDatasetV2 shuffle(int buffer_size, int? seed = null, bool reshuffle_each_iteration = true); | ||||
| /// <summary> | |||||
| /// Creates a `Dataset` that skips `count` elements from this dataset. | |||||
| /// </summary> | |||||
| /// <param name="count"></param> | |||||
| /// <returns></returns> | |||||
| IDatasetV2 skip(int count); | |||||
| IDatasetV2 batch(int batch_size, bool drop_remainder = false); | IDatasetV2 batch(int batch_size, bool drop_remainder = false); | ||||
| IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null); | IDatasetV2 prefetch(int buffer_size = -1, int? slack_period = null); | ||||
| @@ -0,0 +1,24 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// A `Dataset` skipping the first `count` elements from its input. | |||||
| /// </summary> | |||||
| public class SkipDataset : UnaryUnchangedStructureDataset | |||||
| { | |||||
| Tensor _count; | |||||
| public SkipDataset(IDatasetV2 input_dataset, | |||||
| int count) : base(input_dataset) | |||||
| { | |||||
| _count = tf.convert_to_tensor(count, dtype: dtypes.int64, name: "count"); | |||||
| variant_tensor = ops.skip_dataset(input_dataset.variant_tensor, | |||||
| _count, | |||||
| output_types, output_shapes); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -106,6 +106,24 @@ namespace Tensorflow | |||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| public Tensor skip_dataset(Tensor input_dataset, Tensor count, | |||||
| TF_DataType[] output_types, TensorShape[] output_shapes, | |||||
| string name = null) | |||||
| { | |||||
| if (tf.Context.executing_eagerly()) | |||||
| { | |||||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
| "SkipDataset", name, | |||||
| null, | |||||
| input_dataset, count, | |||||
| "output_types", output_types, | |||||
| "output_shapes", output_shapes); | |||||
| return results[0]; | |||||
| } | |||||
| throw new NotImplementedException(""); | |||||
| } | |||||
| public Tensor dummy_seed_generator(string name = null) | public Tensor dummy_seed_generator(string name = null) | ||||
| { | { | ||||
| if (tf.Context.executing_eagerly()) | if (tf.Context.executing_eagerly()) | ||||
| @@ -84,5 +84,20 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
| value += 3; | value += 3; | ||||
| } | } | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void Skip() | |||||
| { | |||||
| long value = 7; | |||||
| var dataset = tf.data.Dataset.range(10); | |||||
| dataset = dataset.skip(7); | |||||
| foreach (var item in dataset) | |||||
| { | |||||
| Assert.AreEqual(value, (long)item.Item1); | |||||
| value ++; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||