| @@ -0,0 +1,10 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class DatasetOptions | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -60,12 +60,18 @@ namespace Tensorflow | |||||
| preserve_cardinality: preserve_cardinality, | preserve_cardinality: preserve_cardinality, | ||||
| use_legacy_function: use_legacy_function); | use_legacy_function: use_legacy_function); | ||||
| public IDatasetV2 map(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func, int num_parallel_calls = -1) | |||||
| => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls); | |||||
| public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func) | public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func) | ||||
| => new FlatMapDataset(this, map_func); | => new FlatMapDataset(this, map_func); | ||||
| public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget) | public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget) | ||||
| => new ModelDataset(this, algorithm, cpu_budget); | => new ModelDataset(this, algorithm, cpu_budget); | ||||
| public IDatasetV2 with_options(DatasetOptions options) | |||||
| => new OptionsDataset(this, options); | |||||
| public IDatasetV2 apply_options() | public IDatasetV2 apply_options() | ||||
| { | { | ||||
| // (1) Apply threading options | // (1) Apply threading options | ||||
| @@ -94,7 +100,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| public override string ToString() | public override string ToString() | ||||
| => $"{GetType().Name} shapes: ({structure[0].shape}, {structure[1].shape}), types: (tf.{structure[0].dtype.as_numpy_name()}, tf.{structure[1].dtype.as_numpy_name()})"; | |||||
| => $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}"; | |||||
| public IEnumerator<(Tensor, Tensor)> GetEnumerator() | public IEnumerator<(Tensor, Tensor)> GetEnumerator() | ||||
| { | { | ||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
| @@ -14,7 +15,7 @@ namespace Tensorflow | |||||
| Func<Tensor, IDatasetV2> map_func) : base(input_dataset) | Func<Tensor, IDatasetV2> map_func) : base(input_dataset) | ||||
| { | { | ||||
| var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype); | var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype); | ||||
| structure = func.OutputStructure; | |||||
| variant_tensor = ops.flat_map_dataset(input_dataset.variant_tensor, | variant_tensor = ops.flat_map_dataset(input_dataset.variant_tensor, | ||||
| func, | func, | ||||
| output_types, | output_types, | ||||
| @@ -62,10 +62,15 @@ namespace Tensorflow | |||||
| bool preserve_cardinality = false, | bool preserve_cardinality = false, | ||||
| bool use_legacy_function = false); | bool use_legacy_function = false); | ||||
| IDatasetV2 map(Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func, | |||||
| int num_parallel_calls = -1); | |||||
| IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func); | IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func); | ||||
| IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget); | IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget); | ||||
| IDatasetV2 with_options(DatasetOptions options); | |||||
| /// <summary> | /// <summary> | ||||
| /// Apply options, such as optimization configuration, to the dataset. | /// Apply options, such as optimization configuration, to the dataset. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -0,0 +1,21 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// An identity `Dataset` that stores options. | |||||
| /// </summary> | |||||
| public class OptionsDataset : UnaryUnchangedStructureDataset | |||||
| { | |||||
| DatasetOptions options; | |||||
| public OptionsDataset(IDatasetV2 input_dataset, DatasetOptions options) | |||||
| : base(input_dataset) | |||||
| { | |||||
| this.options = options; | |||||
| variant_tensor = input_dataset.variant_tensor; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,34 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Functions; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| //A `Dataset` that maps a function over elements in its input in parallel. | |||||
| public class ParallelMapDataset : UnaryDataset | |||||
| { | |||||
| public ParallelMapDataset(IDatasetV2 input_dataset, | |||||
| Func<Tensor, (Tensor, Tensor), (Tensor, Tensor)> map_func, | |||||
| int num_parallel_calls = -1, | |||||
| bool use_inter_op_parallelism = true, | |||||
| bool preserve_cardinality = false, | |||||
| bool use_legacy_function = false) : base(input_dataset) | |||||
| { | |||||
| var func = new ConcreteFunction(map_func, | |||||
| input_dataset.element_spec.Select(x => x.dtype).ToArray(), | |||||
| input_dataset.element_spec.Select(x => x.shape).ToArray()); | |||||
| structure = func.OutputStructure; | |||||
| var _num_parallel_calls = tf.convert_to_tensor(num_parallel_calls, dtype: tf.int64, | |||||
| name: "num_parallel_calls"); | |||||
| variant_tensor = ops.parallel_map_dataset_v2(input_dataset.variant_tensor, | |||||
| _num_parallel_calls, | |||||
| func, | |||||
| output_types, | |||||
| output_shapes); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -15,11 +15,9 @@ namespace Tensorflow | |||||
| public TensorDataset(Tensor feature, Tensor label) | public TensorDataset(Tensor feature, Tensor label) | ||||
| { | { | ||||
| _tensors = new[] { feature, label }; | _tensors = new[] { feature, label }; | ||||
| var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); | |||||
| structure = batched_spec.Select(x => x._unbatch()).ToArray(); | |||||
| structure = _tensors.Select(x => x.ToTensorSpec()).ToArray(); | |||||
| variant_tensor = ops.tensor_dataset(_tensors, output_shapes); | variant_tensor = ops.tensor_dataset(_tensors, output_shapes); | ||||
| } | } | ||||
| public TensorDataset(Tensor element) | public TensorDataset(Tensor element) | ||||
| { | { | ||||
| @@ -2,16 +2,19 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Framework.Models; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public class ZipDataset : DatasetV2 | public class ZipDataset : DatasetV2 | ||||
| { | { | ||||
| dataset_ops ops = new dataset_ops(); | |||||
| public ZipDataset(params IDatasetV2[] ds) | public ZipDataset(params IDatasetV2[] ds) | ||||
| { | { | ||||
| var input_datasets = ds.Select(x => x.variant_tensor).ToArray(); | var input_datasets = ds.Select(x => x.variant_tensor).ToArray(); | ||||
| structure = ds.Select(x => x.structure[0]).ToArray(); | |||||
| var _structure = new List<TensorSpec>(); | |||||
| foreach (var dataset in ds) | |||||
| _structure.AddRange(dataset.structure); | |||||
| structure = _structure.ToArray(); | |||||
| variant_tensor = ops.zip_dataset(input_datasets, output_types, output_shapes); | variant_tensor = ops.zip_dataset(input_datasets, output_types, output_shapes); | ||||
| } | } | ||||
| } | } | ||||