| @@ -0,0 +1,35 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Framework; | |||||
| using Tensorflow.Framework.Models; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Data | |||||
| { | |||||
| /// <summary> | |||||
| /// A `Dataset` that concatenates its input with given dataset. | |||||
| /// </summary> | |||||
| public class ConcatenateDataset : DatasetV2 | |||||
| { | |||||
| IDatasetV2 _input_dataset; | |||||
| IDatasetV2 _dataset_to_concatenate; | |||||
| public ConcatenateDataset(IDatasetV2 input_dataset, IDatasetV2 dataset_to_concatenate) | |||||
| { | |||||
| _input_dataset = input_dataset; | |||||
| _dataset_to_concatenate = dataset_to_concatenate; | |||||
| var _structure = new List<TensorSpec>(); | |||||
| foreach(var (i, spec) in enumerate(dataset_to_concatenate.element_spec)) | |||||
| { | |||||
| var shape = _input_dataset.output_shapes[i].most_specific_compatible_shape(spec.shape); | |||||
| _structure.Add(new TensorSpec(shape, dtype: spec.dtype)); | |||||
| } | |||||
| structure = _structure.ToArray(); | |||||
| variant_tensor = ops.concatenate_dataset(input_dataset.variant_tensor, | |||||
| dataset_to_concatenate.variant_tensor, | |||||
| output_types, output_shapes); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -2,6 +2,7 @@ | |||||
| using System.Collections; | using System.Collections; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Data; | |||||
| using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -26,6 +27,9 @@ namespace Tensorflow | |||||
| public IDatasetV2 cache(string filename = "") | public IDatasetV2 cache(string filename = "") | ||||
| => new CacheDataset(this, filename: filename); | => new CacheDataset(this, filename: filename); | ||||
| public IDatasetV2 concatenate(IDatasetV2 dataset) | |||||
| => new ConcatenateDataset(this, dataset); | |||||
| public IDatasetV2 take(int count = -1) | public IDatasetV2 take(int count = -1) | ||||
| => new TakeDataset(this, count: count); | => new TakeDataset(this, count: count); | ||||
| @@ -23,6 +23,13 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| IDatasetV2 cache(string filename = ""); | IDatasetV2 cache(string filename = ""); | ||||
| /// <summary> | |||||
| /// Creates a `Dataset` by concatenating the given dataset with this dataset. | |||||
| /// </summary> | |||||
| /// <param name="dataset"></param> | |||||
| /// <returns></returns> | |||||
| IDatasetV2 concatenate(IDatasetV2 dataset); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| @@ -1,5 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Linq; | |||||
| using Tensorflow.Framework.Models; | using Tensorflow.Framework.Models; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -36,7 +38,10 @@ namespace Tensorflow | |||||
| { | { | ||||
| try | try | ||||
| { | { | ||||
| return ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes); | |||||
| var results = ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes); | |||||
| foreach(var (i, tensor) in enumerate(results)) | |||||
| tensor.set_shape(_element_spec[i].shape); | |||||
| return results; | |||||
| } | } | ||||
| catch (OutOfRangeError ex) | catch (OutOfRangeError ex) | ||||
| { | { | ||||