| @@ -23,6 +23,8 @@ namespace Tensorflow | |||||
| public class DataOps | public class DataOps | ||||
| { | { | ||||
| public int AUTOTUNE = -1; | public int AUTOTUNE = -1; | ||||
| public int INFINITE_CARDINALITY = -1; | |||||
| public int UNKNOWN_CARDINALITY = -2; | |||||
| public DatasetManager Dataset { get; } = new DatasetManager(); | public DatasetManager Dataset { get; } = new DatasetManager(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -29,48 +29,48 @@ namespace Tensorflow.Contexts | |||||
| /// </summary> | /// </summary> | ||||
| public sealed partial class Context | public sealed partial class Context | ||||
| { | { | ||||
| // [DebuggerStepThrough] | |||||
| public Tensors ExecuteOp(string OpType, string Name, ExecuteOpArgs args) | |||||
| Tensors ExecGraphAction(string OpType, string Name, ExecuteOpArgs args) | |||||
| { | { | ||||
| Func<Tensors> graphAction = () => | |||||
| var keywords = new Dictionary<string, object>(); | |||||
| if (args.OpInputArgs != null) | |||||
| { | { | ||||
| var keywords = new Dictionary<string, object>(); | |||||
| if(args.OpInputArgs != null) | |||||
| { | |||||
| foreach (var (i, input) in enumerate(args.OpInputArgs)) | |||||
| keywords[$"input_{i}"] = input; | |||||
| } | |||||
| foreach (var (i, input) in enumerate(args.OpInputArgs)) | |||||
| keywords[$"input_{i}"] = input; | |||||
| } | |||||
| if(args.OpAttrs != null) | |||||
| { | |||||
| foreach (var attr in args.OpAttrs) | |||||
| keywords[attr.Key] = attr.Value; | |||||
| } | |||||
| if (args.OpAttrs != null) | |||||
| { | |||||
| foreach (var attr in args.OpAttrs) | |||||
| keywords[attr.Key] = attr.Value; | |||||
| } | |||||
| return tf.OpDefLib._apply_op_helper(OpType, Name, keywords).outputs; | |||||
| }; | |||||
| return tf.OpDefLib._apply_op_helper(OpType, Name, keywords).outputs; | |||||
| } | |||||
| Func<Tensors> eagerAction = () => | |||||
| Tensors ExecEagerAction(string OpType, string Name, ExecuteOpArgs args) | |||||
| { | |||||
| var opExecInfo = new FastPathOpExecInfo(OpType, Name, args.OpInputArgs) | |||||
| { | { | ||||
| var opExecInfo = new FastPathOpExecInfo(OpType, Name, args.OpInputArgs) | |||||
| { | |||||
| attrs = args.OpAttrs | |||||
| }; | |||||
| return tf.Runner.TFE_FastPathExecute(opExecInfo); | |||||
| attrs = args.OpAttrs | |||||
| }; | }; | ||||
| return tf.Runner.TFE_FastPathExecute(opExecInfo); | |||||
| } | |||||
| // [DebuggerStepThrough] | |||||
| public Tensors ExecuteOp(string opType, string name, ExecuteOpArgs args) | |||||
| { | |||||
| if (tf.Context.has_graph_arg(args.OpInputArgs)) | if (tf.Context.has_graph_arg(args.OpInputArgs)) | ||||
| { | { | ||||
| if (executing_eagerly()) | if (executing_eagerly()) | ||||
| { | { | ||||
| graph_mode(); | graph_mode(); | ||||
| var result = graphAction(); | |||||
| var result = ExecGraphAction(opType, name, args); | |||||
| restore_mode(); | restore_mode(); | ||||
| return result; | return result; | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| var result = graphAction(); | |||||
| var result = ExecGraphAction(opType, name, args); | |||||
| if (tf.Runner.MustRecordGradient()) | if (tf.Runner.MustRecordGradient()) | ||||
| { | { | ||||
| var op = result[0].op; | var op = result[0].op; | ||||
| @@ -92,14 +92,14 @@ namespace Tensorflow.Contexts | |||||
| args1[i + 1] = arg.Value; | args1[i + 1] = arg.Value; | ||||
| i += 2; | i += 2; | ||||
| } | } | ||||
| tf.Runner.RecordGradient(OpType, op.inputs, args1, op.outputs); | |||||
| tf.Runner.RecordGradient(opType, op.inputs, args1, op.outputs); | |||||
| } | } | ||||
| return result; | return result; | ||||
| } | } | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| return eagerAction(); | |||||
| return ExecEagerAction(opType, name, args); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -70,6 +70,12 @@ namespace Tensorflow | |||||
| num_parallel_calls: num_parallel_calls, | num_parallel_calls: num_parallel_calls, | ||||
| preserve_cardinality: true); | preserve_cardinality: true); | ||||
| public IDatasetV2 filter(Func<Tensors, Tensors> predicate_func) | |||||
| => new FilterDataset(this, predicate_func); | |||||
| public IDatasetV2 filter(Func<Tensor, bool> predicate_func) | |||||
| => new FilterDataset(this, predicate_func); | |||||
| public OwnedIterator make_one_shot_iterator() | public OwnedIterator make_one_shot_iterator() | ||||
| { | { | ||||
| if (tf.Context.executing_eagerly()) | if (tf.Context.executing_eagerly()) | ||||
| @@ -105,13 +111,15 @@ namespace Tensorflow | |||||
| // (3) Apply graph rewrite options | // (3) Apply graph rewrite options | ||||
| var graph_rewrites = new[] | var graph_rewrites = new[] | ||||
| { | { | ||||
| "noop_elimination", | |||||
| "map_and_batch_fusion", | "map_and_batch_fusion", | ||||
| "map_parallelization", | |||||
| "noop_elimination", | |||||
| "shuffle_and_repeat_fusion" | "shuffle_and_repeat_fusion" | ||||
| }; | }; | ||||
| var graph_rewrite_configs = new string[] | var graph_rewrite_configs = new string[] | ||||
| { | { | ||||
| "autotune_buffer_sizes:autotune:true", | "autotune_buffer_sizes:autotune:true", | ||||
| "batch_parallelization:autotune:true", | |||||
| "disable_prefetch_legacy_autotune:autotune:true", | "disable_prefetch_legacy_autotune:autotune:true", | ||||
| "enable_gradient_descent:autotune:true", | "enable_gradient_descent:autotune:true", | ||||
| "map_parallelization:autotune:true" | "map_parallelization:autotune:true" | ||||
| @@ -124,7 +132,7 @@ namespace Tensorflow | |||||
| return dataset; | return dataset; | ||||
| } | } | ||||
| public Tensor dataset_cardinality(string name = null) | |||||
| public Tensor cardinality(string name = null) | |||||
| => tf.Context.ExecuteOp("DatasetCardinality", name, new ExecuteOpArgs(variant_tensor)); | => tf.Context.ExecuteOp("DatasetCardinality", name, new ExecuteOpArgs(variant_tensor)); | ||||
| public override string ToString() | public override string ToString() | ||||
| @@ -0,0 +1,58 @@ | |||||
| using System; | |||||
| using Tensorflow.Functions; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow | |||||
| { | |||||
| /// <summary> | |||||
| /// A `Dataset` that filters its input according to a predicate function. | |||||
| /// </summary> | |||||
| public class FilterDataset : UnaryDataset | |||||
| { | |||||
| public FilterDataset(IDatasetV2 input_dataset, | |||||
| Func<Tensor, bool> predicate_func) : base(input_dataset) | |||||
| { | |||||
| Func<Tensors, Tensors> predicate_func_update = x => | |||||
| { | |||||
| var result = predicate_func(x); | |||||
| return constant_op.constant(result); | |||||
| }; | |||||
| var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}"); | |||||
| func.Enter(); | |||||
| var inputs = new Tensors(); | |||||
| foreach (var input in input_dataset.element_spec) | |||||
| inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg")); | |||||
| var outputs = predicate_func_update(inputs); | |||||
| func.ToGraph(inputs, outputs); | |||||
| func.Exit(); | |||||
| structure = func.OutputStructure; | |||||
| variant_tensor = ops.filter_dataset(input_dataset.variant_tensor, | |||||
| func, | |||||
| output_types, | |||||
| output_shapes); | |||||
| } | |||||
| public FilterDataset(IDatasetV2 input_dataset, | |||||
| Func<Tensors, Tensors> predicate_func) : base(input_dataset) | |||||
| { | |||||
| var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}"); | |||||
| func.Enter(); | |||||
| var inputs = new Tensors(); | |||||
| foreach (var input in input_dataset.element_spec) | |||||
| inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg")); | |||||
| var outputs = predicate_func(inputs); | |||||
| func.ToGraph(inputs, outputs); | |||||
| func.Exit(); | |||||
| structure = func.OutputStructure; | |||||
| variant_tensor = ops.filter_dataset(input_dataset.variant_tensor, | |||||
| func, | |||||
| output_types, | |||||
| output_shapes); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -72,6 +72,9 @@ namespace Tensorflow | |||||
| IDatasetV2 map(Func<Tensors, Tensors> map_func, | IDatasetV2 map(Func<Tensors, Tensors> map_func, | ||||
| int num_parallel_calls); | int num_parallel_calls); | ||||
| IDatasetV2 filter(Func<Tensors, Tensors> map_func); | |||||
| IDatasetV2 filter(Func<Tensor, bool> map_func); | |||||
| OwnedIterator make_one_shot_iterator(); | OwnedIterator make_one_shot_iterator(); | ||||
| IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func); | IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func); | ||||
| @@ -91,6 +94,6 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| Tensor dataset_cardinality(string name = null); | |||||
| Tensor cardinality(string name = null); | |||||
| } | } | ||||
| } | } | ||||
| @@ -249,6 +249,25 @@ namespace Tensorflow | |||||
| preserve_cardinality | preserve_cardinality | ||||
| })); | })); | ||||
| /// <summary> | |||||
| /// Creates a dataset containing elements of `input_dataset` matching `predicate`. | |||||
| /// </summary> | |||||
| /// <param name="dataset"></param> | |||||
| /// <param name="predicate"></param> | |||||
| /// <param name="output_types"></param> | |||||
| /// <param name="output_shapes"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor filter_dataset(Tensor dataset, ConcreteFunction predicate, TF_DataType[] output_types, TensorShape[] output_shapes, | |||||
| string name = null) | |||||
| => tf.Context.ExecuteOp("FilterDataset", name, new ExecuteOpArgs(dataset, new Tensor[0]) | |||||
| .SetAttributes(new | |||||
| { | |||||
| predicate, | |||||
| output_types, | |||||
| output_shapes | |||||
| })); | |||||
| /// <summary> | /// <summary> | ||||
| /// Creates a dataset that applies `f` to the outputs of `input_dataset`. | /// Creates a dataset that applies `f` to the outputs of `input_dataset`. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -0,0 +1,13 @@ | |||||
| using System; | |||||
| using System.Runtime.CompilerServices; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public partial class Tensor | |||||
| { | |||||
| public static Tensor operator !=(Tensor x, int y) | |||||
| => gen_math_ops.not_equal(x, math_ops.cast(y, dtype: x.dtype)); | |||||
| public static Tensor operator ==(Tensor x, int y) | |||||
| => gen_math_ops.equal(x, math_ops.cast(y, dtype: x.dtype)); | |||||
| } | |||||
| } | |||||
| @@ -144,6 +144,12 @@ namespace Tensorflow | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| else if (dtype != TF_DataType.DtInvalid && | |||||
| value is NDArray nd && | |||||
| dtypes.as_dtype(nd.dtype) != dtype) | |||||
| { | |||||
| value = nd.astype(dtype.as_numpy_dtype()); | |||||
| } | |||||
| if (dtype == TF_DataType.TF_STRING && value is byte[] bytes) | if (dtype == TF_DataType.TF_STRING && value is byte[] bytes) | ||||
| return new EagerTensor(bytes, ctx.DeviceName, TF_DataType.TF_STRING); | return new EagerTensor(bytes, ctx.DeviceName, TF_DataType.TF_STRING); | ||||
| @@ -87,7 +87,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| if (adapter_steps > -1) | if (adapter_steps > -1) | ||||
| return adapter_steps; | return adapter_steps; | ||||
| var size = dataset.dataset_cardinality(); | |||||
| var size = dataset.cardinality(); | |||||
| return size.numpy(); | return size.numpy(); | ||||
| } | } | ||||
| @@ -147,10 +147,10 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
| public void Cardinality() | public void Cardinality() | ||||
| { | { | ||||
| var dataset = tf.data.Dataset.range(10); | var dataset = tf.data.Dataset.range(10); | ||||
| var cardinality = dataset.dataset_cardinality(); | |||||
| var cardinality = dataset.cardinality(); | |||||
| Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | ||||
| dataset = dataset.map(x => x[0] + 1); | dataset = dataset.map(x => x[0] + 1); | ||||
| cardinality = dataset.dataset_cardinality(); | |||||
| cardinality = dataset.cardinality(); | |||||
| Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | ||||
| } | } | ||||
| @@ -159,10 +159,23 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
| { | { | ||||
| var dataset = tf.data.Dataset.range(10); | var dataset = tf.data.Dataset.range(10); | ||||
| dataset = dataset.map(x => x, num_parallel_calls: -1); | dataset = dataset.map(x => x, num_parallel_calls: -1); | ||||
| var cardinality = dataset.dataset_cardinality(); | |||||
| var cardinality = dataset.cardinality(); | |||||
| Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void CardinalityWithRepeat() | |||||
| { | |||||
| var dataset = tf.data.Dataset.range(10); | |||||
| dataset = dataset.repeat(); | |||||
| var cardinality = dataset.cardinality(); | |||||
| Assert.IsTrue((cardinality == tf.data.INFINITE_CARDINALITY).numpy()); | |||||
| dataset = dataset.filter(x => true); | |||||
| cardinality = dataset.cardinality(); | |||||
| Assert.IsTrue((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy()); | |||||
| } | |||||
| [TestMethod] | [TestMethod] | ||||
| public void Shuffle() | public void Shuffle() | ||||
| { | { | ||||