| @@ -0,0 +1,23 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| public class CacheDataset : UnaryUnchangedStructureDataset | |||
| { | |||
| Tensor _filename; | |||
| public CacheDataset(IDatasetV2 input_dataset, | |||
| string filename = "") : | |||
| base(input_dataset) | |||
| { | |||
| _filename = tf.convert_to_tensor(filename, dtype: TF_DataType.TF_STRING, name: "filename"); | |||
| variant_tensor = ops.cache_dataset_v2(input_dataset.variant_tensor, | |||
| _filename, | |||
| ops.dummy_memory_cache(), | |||
| output_types, | |||
| output_shapes); | |||
| } | |||
| } | |||
| } | |||
| @@ -23,6 +23,9 @@ namespace Tensorflow | |||
| public TensorSpec[] element_spec => structure; | |||
| public IDatasetV2 cache(string filename = "") | |||
| => new CacheDataset(this, filename: filename); | |||
| public IDatasetV2 take(int count = -1) | |||
| => new TakeDataset(this, count: count); | |||
| @@ -47,6 +50,16 @@ namespace Tensorflow | |||
| public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs) | |||
| => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs); | |||
| public IDatasetV2 map(Func<Tensor, Tensor> map_func, | |||
| bool use_inter_op_parallelism = true, | |||
| bool preserve_cardinality = false, | |||
| bool use_legacy_function = false) | |||
| => new MapDataset(this, | |||
| map_func, | |||
| use_inter_op_parallelism: use_inter_op_parallelism, | |||
| preserve_cardinality: preserve_cardinality, | |||
| use_legacy_function: use_legacy_function); | |||
| public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget) | |||
| => new ModelDataset(this, algorithm, cpu_budget); | |||
| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq.Expressions; | |||
| using System.Text; | |||
| using Tensorflow.Framework.Models; | |||
| @@ -17,6 +18,13 @@ namespace Tensorflow | |||
| TensorSpec[] structure { get; set; } | |||
| /// <summary> | |||
| /// Caches the elements in this dataset. | |||
| /// </summary> | |||
| /// <param name="filename"></param> | |||
| /// <returns></returns> | |||
| IDatasetV2 cache(string filename=""); | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| @@ -49,6 +57,11 @@ namespace Tensorflow | |||
| IDatasetV2 optimize(string[] optimizations, string[] optimization_configs); | |||
| IDatasetV2 map(Func<Tensor, Tensor> map_func, | |||
| bool use_inter_op_parallelism = true, | |||
| bool preserve_cardinality = false, | |||
| bool use_legacy_function = false); | |||
| IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget); | |||
| /// <summary> | |||
| @@ -0,0 +1,28 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// A `Dataset` that maps a function over elements in its input. | |||
| /// </summary> | |||
| public class MapDataset : UnaryDataset | |||
| { | |||
| public MapDataset(IDatasetV2 input_dataset, | |||
| Func<Tensor, Tensor> map_func, | |||
| bool use_inter_op_parallelism = true, | |||
| bool preserve_cardinality = false, | |||
| bool use_legacy_function = false) : base(input_dataset) | |||
| { | |||
| foreach(var input in input_dataset) | |||
| { | |||
| var data = map_func(input.Item1); | |||
| } | |||
| variant_tensor = ops.map_dataset(input_dataset.variant_tensor, | |||
| output_types, | |||
| output_shapes); | |||
| } | |||
| } | |||
| } | |||
| @@ -381,6 +381,9 @@ namespace Tensorflow.Eager | |||
| c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status.Handle); | |||
| status.Check(true); | |||
| break; | |||
| case TF_AttrType.TF_ATTR_FUNC: | |||
| c_api.TFE_OpSetAttrFunctionName(op, key, value.ToString(), value.ToString().Length); | |||
| break; | |||
| default: | |||
| throw new NotImplementedException($"SetOpAttrScalar for {type}"); | |||
| } | |||
| @@ -196,6 +196,9 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_OpSetAttrBool(SafeOpHandle op, string attr_name, bool value); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TFE_OpSetAttrFunctionName(SafeOpHandle op, string attr_name, string data, int length); | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| @@ -155,6 +155,24 @@ namespace Tensorflow | |||
| throw new NotImplementedException(""); | |||
| } | |||
| public Tensor cache_dataset_v2(Tensor input_dataset, Tensor filename, Tensor cache, | |||
| TF_DataType[] output_types, TensorShape[] output_shapes, | |||
| string name = null) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "CacheDatasetV2", name, | |||
| null, | |||
| input_dataset, filename, cache, | |||
| "output_types", output_types, | |||
| "output_shapes", output_shapes); | |||
| return results[0]; | |||
| } | |||
| throw new NotImplementedException(""); | |||
| } | |||
| /// <summary> | |||
| /// Creates a dataset that batches `batch_size` elements from `input_dataset`. | |||
| /// </summary> | |||
| @@ -187,6 +205,24 @@ namespace Tensorflow | |||
| throw new NotImplementedException(""); | |||
| } | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public Tensor dummy_memory_cache(string name = "") | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "DummyMemoryCache", name, | |||
| null); | |||
| return results[0]; | |||
| } | |||
| throw new NotImplementedException(""); | |||
| } | |||
| /// <summary> | |||
| /// Creates a dataset that asynchronously prefetches elements from `input_dataset`. | |||
| /// </summary> | |||
| @@ -354,6 +390,33 @@ namespace Tensorflow | |||
| throw new NotImplementedException(""); | |||
| } | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="dataset"></param> | |||
| /// <param name="iterator"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public Tensor map_dataset(Tensor dataset, TF_DataType[] output_types, TensorShape[] output_shapes, | |||
| bool use_inter_op_parallelism = true, bool preserve_cardinality = false, string name = null) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
| "MapDataset", name, | |||
| null, | |||
| dataset, new Tensor[0], | |||
| "f", "MapDataset", | |||
| "output_types", output_types, | |||
| "output_shapes", output_shapes, | |||
| "use_inter_op_parallelism", use_inter_op_parallelism, | |||
| "preserve_cardinality", preserve_cardinality); | |||
| return results[0]; | |||
| } | |||
| throw new NotImplementedException(""); | |||
| } | |||
| /// <summary> | |||
| /// A container for an iterator resource. | |||
| /// </summary> | |||
| @@ -3,6 +3,7 @@ using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| using Tensorflow.Keras; | |||
| using Tensorflow.UnitTest; | |||
| using static Tensorflow.Binding; | |||
| @@ -116,5 +117,35 @@ namespace TensorFlowNET.UnitTest.Dataset | |||
| value ++; | |||
| } | |||
| } | |||
| [TestMethod, Ignore] | |||
| public void Map() | |||
| { | |||
| long value = 0; | |||
| var dataset = tf.data.Dataset.range(3); | |||
| var dataset1 = dataset.map(x => x); | |||
| foreach (var item in dataset) | |||
| { | |||
| Assert.AreEqual(value, (long)item.Item1); | |||
| value++; | |||
| } | |||
| } | |||
| [TestMethod] | |||
| public void Cache() | |||
| { | |||
| long value = 0; | |||
| var dataset = tf.data.Dataset.range(5); | |||
| dataset = dataset.cache(); | |||
| foreach (var item in dataset) | |||
| { | |||
| Assert.AreEqual(value, (long)item.Item1); | |||
| value++; | |||
| } | |||
| } | |||
| } | |||
| } | |||