@@ -0,0 +1,23 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+
+namespace Tensorflow
+{
+    public class CacheDataset : UnaryUnchangedStructureDataset
+    {
+        Tensor _filename;
+        public CacheDataset(IDatasetV2 input_dataset,
+            string filename = "") :
+            base(input_dataset)
+        {
+            _filename = tf.convert_to_tensor(filename, dtype: TF_DataType.TF_STRING, name: "filename");
+            variant_tensor = ops.cache_dataset_v2(input_dataset.variant_tensor,
+                _filename,
+                ops.dummy_memory_cache(),
+                output_types,
+                output_shapes);
+        }
+    }
+}
@@ -23,6 +23,9 @@ namespace Tensorflow
         public TensorSpec[] element_spec => structure;
 
+        public IDatasetV2 cache(string filename = "")
+            => new CacheDataset(this, filename: filename);
+
         public IDatasetV2 take(int count = -1)
             => new TakeDataset(this, count: count);
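For orientation, a minimal usage sketch of the new cache() method (not part of this diff; it reuses the tf.data.Dataset.range helper that the unit tests below rely on). With the default empty filename, elements are cached in memory via the DummyMemoryCache resource added later in this change:

    // Hypothetical example: the first full pass over the dataset populates the cache.
    var dataset = tf.data.Dataset.range(5).cache();
    foreach (var item in dataset)
    {
        var element = item.Item1;   // element tensor, as in the Cache unit test below
    }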
@@ -47,6 +50,16 @@ namespace Tensorflow
         public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs)
             => new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs);
 
+        public IDatasetV2 map(Func<Tensor, Tensor> map_func,
+            bool use_inter_op_parallelism = true,
+            bool preserve_cardinality = false,
+            bool use_legacy_function = false)
+            => new MapDataset(this,
+                map_func,
+                use_inter_op_parallelism: use_inter_op_parallelism,
+                preserve_cardinality: preserve_cardinality,
+                use_legacy_function: use_legacy_function);
+
         public IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget)
             => new ModelDataset(this, algorithm, cpu_budget);
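Likewise, a short sketch of the new map() overload (not from the diff): the delegate is a Func<Tensor, Tensor>, and the optional flags are forwarded to MapDataset as op attributes.

    // Hypothetical example: identity mapping, mirroring the (currently ignored) Map unit test.
    var mapped = tf.data.Dataset.range(3)
        .map(x => x, use_inter_op_parallelism: true, preserve_cardinality: false);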
@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using System.Linq.Expressions;
 using System.Text;
 using Tensorflow.Framework.Models;
@@ -17,6 +18,13 @@ namespace Tensorflow
         TensorSpec[] structure { get; set; }
 
+        /// <summary>
+        /// Caches the elements in this dataset.
+        /// </summary>
+        /// <param name="filename">File path used to cache elements on disk; when empty, elements are cached in memory.</param>
+        /// <returns>A `Dataset` that caches the elements of this dataset.</returns>
+        IDatasetV2 cache(string filename = "");
+
         /// <summary>
         /// 
         /// </summary>
@@ -49,6 +57,11 @@ namespace Tensorflow
         IDatasetV2 optimize(string[] optimizations, string[] optimization_configs);
 
+        IDatasetV2 map(Func<Tensor, Tensor> map_func,
+            bool use_inter_op_parallelism = true,
+            bool preserve_cardinality = false,
+            bool use_legacy_function = false);
+
         IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget);
 
         /// <summary>
@@ -0,0 +1,28 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+    /// <summary>
+    /// A `Dataset` that maps a function over elements in its input.
+    /// </summary>
+    public class MapDataset : UnaryDataset
+    {
+        public MapDataset(IDatasetV2 input_dataset,
+            Func<Tensor, Tensor> map_func,
+            bool use_inter_op_parallelism = true,
+            bool preserve_cardinality = false,
+            bool use_legacy_function = false) : base(input_dataset)
+        {
+            foreach (var input in input_dataset)
+            {
+                var data = map_func(input.Item1);
+            }
+
+            variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
+                output_types,
+                output_shapes);
+        }
+    }
+}
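Note on the constructor above: map_func is applied eagerly to every element of input_dataset and the results are discarded, while ops.map_dataset (added further down) fills the op's "f" attribute with the literal name "MapDataset" rather than a serialized function. A rough sketch of the observable behavior (simplified, not taken from the PR):

    // Builds a MapDataset variant, but the op only carries a placeholder function name,
    // so the returned dataset does not yet apply map_func; this is why the Map unit test
    // at the end of this diff is marked [Ignore].
    var mapped = tf.data.Dataset.range(3).map(x => x);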
@@ -381,6 +381,9 @@ namespace Tensorflow.Eager
                     c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status.Handle);
                     status.Check(true);
                     break;
+                case TF_AttrType.TF_ATTR_FUNC:
+                    c_api.TFE_OpSetAttrFunctionName(op, key, value.ToString(), value.ToString().Length);
+                    break;
                 default:
                     throw new NotImplementedException($"SetOpAttrScalar for {type}");
             }
@@ -196,6 +196,9 @@ namespace Tensorflow
         [DllImport(TensorFlowLibName)]
         public static extern void TFE_OpSetAttrBool(SafeOpHandle op, string attr_name, bool value);
 
+        [DllImport(TensorFlowLibName)]
+        public static extern void TFE_OpSetAttrFunctionName(SafeOpHandle op, string attr_name, string data, int length);
+
         /// <summary>
         /// 
         /// </summary>
@@ -155,6 +155,24 @@ namespace Tensorflow
             throw new NotImplementedException("");
         }
 
+        public Tensor cache_dataset_v2(Tensor input_dataset, Tensor filename, Tensor cache,
+            TF_DataType[] output_types, TensorShape[] output_shapes,
+            string name = null)
+        {
+            if (tf.Context.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                    "CacheDatasetV2", name,
+                    null,
+                    input_dataset, filename, cache,
+                    "output_types", output_types,
+                    "output_shapes", output_shapes);
+                return results[0];
+            }
+
+            throw new NotImplementedException("");
+        }
+
         /// <summary>
         /// Creates a dataset that batches `batch_size` elements from `input_dataset`.
         /// </summary>
@@ -187,6 +205,24 @@ namespace Tensorflow
             throw new NotImplementedException("");
         }
 
+        /// <summary>
+        /// Creates a handle to an in-memory cache resource for `CacheDatasetV2`.
+        /// </summary>
+        /// <param name="name"></param>
+        /// <returns></returns>
+        public Tensor dummy_memory_cache(string name = "")
+        {
+            if (tf.Context.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                    "DummyMemoryCache", name,
+                    null);
+                return results[0];
+            }
+
+            throw new NotImplementedException("");
+        }
+
         /// <summary>
         /// Creates a dataset that asynchronously prefetches elements from `input_dataset`.
         /// </summary>
@@ -354,6 +390,33 @@ namespace Tensorflow
             throw new NotImplementedException("");
         }
 
+        /// <summary>
+        /// Creates a dataset that applies a function to the elements of `dataset`.
+        /// </summary>
+        /// <param name="dataset"></param>
+        /// <param name="output_types"></param>
+        /// <param name="output_shapes"></param>
+        /// <param name="name"></param>
+        /// <returns></returns>
+        public Tensor map_dataset(Tensor dataset, TF_DataType[] output_types, TensorShape[] output_shapes,
+            bool use_inter_op_parallelism = true, bool preserve_cardinality = false, string name = null)
+        {
+            if (tf.Context.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                    "MapDataset", name,
+                    null,
+                    dataset, new Tensor[0],
+                    "f", "MapDataset",
+                    "output_types", output_types,
+                    "output_shapes", output_shapes,
+                    "use_inter_op_parallelism", use_inter_op_parallelism,
+                    "preserve_cardinality", preserve_cardinality);
+                return results[0];
+            }
+
+            throw new NotImplementedException("");
+        }
+
         /// <summary>
         /// A container for an iterator resource.
         /// </summary>
@@ -3,6 +3,7 @@ using System;
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
+using Tensorflow;
 using Tensorflow.Keras;
 using Tensorflow.UnitTest;
 using static Tensorflow.Binding;
@@ -116,5 +117,35 @@ namespace TensorFlowNET.UnitTest.Dataset
                 value ++;
             }
         }
+
+        [TestMethod, Ignore]
+        public void Map()
+        {
+            long value = 0;
+            var dataset = tf.data.Dataset.range(3);
+            var dataset1 = dataset.map(x => x);
+
+            foreach (var item in dataset1)
+            {
+                Assert.AreEqual(value, (long)item.Item1);
+                value++;
+            }
+        }
+
+        [TestMethod]
+        public void Cache()
+        {
+            long value = 0;
+            var dataset = tf.data.Dataset.range(5);
+            dataset = dataset.cache();
+
+            foreach (var item in dataset)
+            {
+                Assert.AreEqual(value, (long)item.Item1);
+                value++;
+            }
+        }
     }
 }