| @@ -23,6 +23,8 @@ namespace Tensorflow | |||
| public class DataOps | |||
| { | |||
| public int AUTOTUNE = -1; | |||
| public int INFINITE_CARDINALITY = -1; | |||
| public int UNKNOWN_CARDINALITY = -2; | |||
| public DatasetManager Dataset { get; } = new DatasetManager(); | |||
| } | |||
| } | |||
| @@ -29,48 +29,48 @@ namespace Tensorflow.Contexts | |||
| /// </summary> | |||
| public sealed partial class Context | |||
| { | |||
| // [DebuggerStepThrough] | |||
| public Tensors ExecuteOp(string OpType, string Name, ExecuteOpArgs args) | |||
| Tensors ExecGraphAction(string OpType, string Name, ExecuteOpArgs args) | |||
| { | |||
| Func<Tensors> graphAction = () => | |||
| var keywords = new Dictionary<string, object>(); | |||
| if (args.OpInputArgs != null) | |||
| { | |||
| var keywords = new Dictionary<string, object>(); | |||
| if(args.OpInputArgs != null) | |||
| { | |||
| foreach (var (i, input) in enumerate(args.OpInputArgs)) | |||
| keywords[$"input_{i}"] = input; | |||
| } | |||
| foreach (var (i, input) in enumerate(args.OpInputArgs)) | |||
| keywords[$"input_{i}"] = input; | |||
| } | |||
| if(args.OpAttrs != null) | |||
| { | |||
| foreach (var attr in args.OpAttrs) | |||
| keywords[attr.Key] = attr.Value; | |||
| } | |||
| if (args.OpAttrs != null) | |||
| { | |||
| foreach (var attr in args.OpAttrs) | |||
| keywords[attr.Key] = attr.Value; | |||
| } | |||
| return tf.OpDefLib._apply_op_helper(OpType, Name, keywords).outputs; | |||
| }; | |||
| return tf.OpDefLib._apply_op_helper(OpType, Name, keywords).outputs; | |||
| } | |||
| Func<Tensors> eagerAction = () => | |||
| Tensors ExecEagerAction(string OpType, string Name, ExecuteOpArgs args) | |||
| { | |||
| var opExecInfo = new FastPathOpExecInfo(OpType, Name, args.OpInputArgs) | |||
| { | |||
| var opExecInfo = new FastPathOpExecInfo(OpType, Name, args.OpInputArgs) | |||
| { | |||
| attrs = args.OpAttrs | |||
| }; | |||
| return tf.Runner.TFE_FastPathExecute(opExecInfo); | |||
| attrs = args.OpAttrs | |||
| }; | |||
| return tf.Runner.TFE_FastPathExecute(opExecInfo); | |||
| } | |||
| // [DebuggerStepThrough] | |||
| public Tensors ExecuteOp(string opType, string name, ExecuteOpArgs args) | |||
| { | |||
| if (tf.Context.has_graph_arg(args.OpInputArgs)) | |||
| { | |||
| if (executing_eagerly()) | |||
| { | |||
| graph_mode(); | |||
| var result = graphAction(); | |||
| var result = ExecGraphAction(opType, name, args); | |||
| restore_mode(); | |||
| return result; | |||
| } | |||
| else | |||
| { | |||
| var result = graphAction(); | |||
| var result = ExecGraphAction(opType, name, args); | |||
| if (tf.Runner.MustRecordGradient()) | |||
| { | |||
| var op = result[0].op; | |||
| @@ -92,14 +92,14 @@ namespace Tensorflow.Contexts | |||
| args1[i + 1] = arg.Value; | |||
| i += 2; | |||
| } | |||
| tf.Runner.RecordGradient(OpType, op.inputs, args1, op.outputs); | |||
| tf.Runner.RecordGradient(opType, op.inputs, args1, op.outputs); | |||
| } | |||
| return result; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| return eagerAction(); | |||
| return ExecEagerAction(opType, name, args); | |||
| } | |||
| } | |||
| } | |||
| @@ -70,6 +70,12 @@ namespace Tensorflow | |||
| num_parallel_calls: num_parallel_calls, | |||
| preserve_cardinality: true); | |||
| public IDatasetV2 filter(Func<Tensors, Tensors> predicate_func) | |||
| => new FilterDataset(this, predicate_func); | |||
| public IDatasetV2 filter(Func<Tensor, bool> predicate_func) | |||
| => new FilterDataset(this, predicate_func); | |||
| public OwnedIterator make_one_shot_iterator() | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| @@ -105,13 +111,15 @@ namespace Tensorflow | |||
| // (3) Apply graph rewrite options | |||
| var graph_rewrites = new[] | |||
| { | |||
| "noop_elimination", | |||
| "map_and_batch_fusion", | |||
| "map_parallelization", | |||
| "noop_elimination", | |||
| "shuffle_and_repeat_fusion" | |||
| }; | |||
| var graph_rewrite_configs = new string[] | |||
| { | |||
| "autotune_buffer_sizes:autotune:true", | |||
| "batch_parallelization:autotune:true", | |||
| "disable_prefetch_legacy_autotune:autotune:true", | |||
| "enable_gradient_descent:autotune:true", | |||
| "map_parallelization:autotune:true" | |||
| @@ -124,7 +132,7 @@ namespace Tensorflow | |||
| return dataset; | |||
| } | |||
| public Tensor dataset_cardinality(string name = null) | |||
| public Tensor cardinality(string name = null) | |||
| => tf.Context.ExecuteOp("DatasetCardinality", name, new ExecuteOpArgs(variant_tensor)); | |||
| public override string ToString() | |||
| @@ -0,0 +1,58 @@ | |||
| using System; | |||
| using Tensorflow.Functions; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// A `Dataset` that filters its input according to a predicate function. | |||
| /// </summary> | |||
| public class FilterDataset : UnaryDataset | |||
| { | |||
| public FilterDataset(IDatasetV2 input_dataset, | |||
| Func<Tensor, bool> predicate_func) : base(input_dataset) | |||
| { | |||
| Func<Tensors, Tensors> predicate_func_update = x => | |||
| { | |||
| var result = predicate_func(x); | |||
| return constant_op.constant(result); | |||
| }; | |||
| var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}"); | |||
| func.Enter(); | |||
| var inputs = new Tensors(); | |||
| foreach (var input in input_dataset.element_spec) | |||
| inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg")); | |||
| var outputs = predicate_func_update(inputs); | |||
| func.ToGraph(inputs, outputs); | |||
| func.Exit(); | |||
| structure = func.OutputStructure; | |||
| variant_tensor = ops.filter_dataset(input_dataset.variant_tensor, | |||
| func, | |||
| output_types, | |||
| output_shapes); | |||
| } | |||
| public FilterDataset(IDatasetV2 input_dataset, | |||
| Func<Tensors, Tensors> predicate_func) : base(input_dataset) | |||
| { | |||
| var func = new ConcreteFunction($"{predicate_func.Method.Name}_{Tensorflow.ops.uid_function()}"); | |||
| func.Enter(); | |||
| var inputs = new Tensors(); | |||
| foreach (var input in input_dataset.element_spec) | |||
| inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg")); | |||
| var outputs = predicate_func(inputs); | |||
| func.ToGraph(inputs, outputs); | |||
| func.Exit(); | |||
| structure = func.OutputStructure; | |||
| variant_tensor = ops.filter_dataset(input_dataset.variant_tensor, | |||
| func, | |||
| output_types, | |||
| output_shapes); | |||
| } | |||
| } | |||
| } | |||
| @@ -72,6 +72,9 @@ namespace Tensorflow | |||
| IDatasetV2 map(Func<Tensors, Tensors> map_func, | |||
| int num_parallel_calls); | |||
| IDatasetV2 filter(Func<Tensors, Tensors> map_func); | |||
| IDatasetV2 filter(Func<Tensor, bool> map_func); | |||
| OwnedIterator make_one_shot_iterator(); | |||
| IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func); | |||
| @@ -91,6 +94,6 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| Tensor dataset_cardinality(string name = null); | |||
| Tensor cardinality(string name = null); | |||
| } | |||
| } | |||
| @@ -249,6 +249,25 @@ namespace Tensorflow | |||
| preserve_cardinality | |||
| })); | |||
| /// <summary> | |||
| /// Creates a dataset containing elements of `input_dataset` matching `predicate`. | |||
| /// </summary> | |||
| /// <param name="dataset"></param> | |||
| /// <param name="predicate"></param> | |||
| /// <param name="output_types"></param> | |||
| /// <param name="output_shapes"></param> | |||
| /// <param name="name"></param> | |||
| /// <returns></returns> | |||
| public Tensor filter_dataset(Tensor dataset, ConcreteFunction predicate, TF_DataType[] output_types, TensorShape[] output_shapes, | |||
| string name = null) | |||
| => tf.Context.ExecuteOp("FilterDataset", name, new ExecuteOpArgs(dataset, new Tensor[0]) | |||
| .SetAttributes(new | |||
| { | |||
| predicate, | |||
| output_types, | |||
| output_shapes | |||
| })); | |||
| /// <summary> | |||
| /// Creates a dataset that applies `f` to the outputs of `input_dataset`. | |||
| /// </summary> | |||
| @@ -0,0 +1,13 @@ | |||
| using System; | |||
| using System.Runtime.CompilerServices; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class Tensor | |||
| { | |||
| public static Tensor operator !=(Tensor x, int y) | |||
| => gen_math_ops.not_equal(x, math_ops.cast(y, dtype: x.dtype)); | |||
| public static Tensor operator ==(Tensor x, int y) | |||
| => gen_math_ops.equal(x, math_ops.cast(y, dtype: x.dtype)); | |||
| } | |||
| } | |||
| @@ -144,6 +144,12 @@ namespace Tensorflow | |||
| break; | |||
| } | |||
| } | |||
| else if (dtype != TF_DataType.DtInvalid && | |||
| value is NDArray nd && | |||
| dtypes.as_dtype(nd.dtype) != dtype) | |||
| { | |||
| value = nd.astype(dtype.as_numpy_dtype()); | |||
| } | |||
| if (dtype == TF_DataType.TF_STRING && value is byte[] bytes) | |||
| return new EagerTensor(bytes, ctx.DeviceName, TF_DataType.TF_STRING); | |||
| @@ -87,7 +87,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
| if (adapter_steps > -1) | |||
| return adapter_steps; | |||
| var size = dataset.dataset_cardinality(); | |||
| var size = dataset.cardinality(); | |||
| return size.numpy(); | |||
| } | |||
| @@ -147,10 +147,10 @@ namespace TensorFlowNET.UnitTest.Dataset | |||
| public void Cardinality() | |||
| { | |||
| var dataset = tf.data.Dataset.range(10); | |||
| var cardinality = dataset.dataset_cardinality(); | |||
| var cardinality = dataset.cardinality(); | |||
| Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | |||
| dataset = dataset.map(x => x[0] + 1); | |||
| cardinality = dataset.dataset_cardinality(); | |||
| cardinality = dataset.cardinality(); | |||
| Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | |||
| } | |||
| @@ -159,10 +159,23 @@ namespace TensorFlowNET.UnitTest.Dataset | |||
| { | |||
| var dataset = tf.data.Dataset.range(10); | |||
| dataset = dataset.map(x => x, num_parallel_calls: -1); | |||
| var cardinality = dataset.dataset_cardinality(); | |||
| var cardinality = dataset.cardinality(); | |||
| Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); | |||
| } | |||
| [TestMethod] | |||
| public void CardinalityWithRepeat() | |||
| { | |||
| var dataset = tf.data.Dataset.range(10); | |||
| dataset = dataset.repeat(); | |||
| var cardinality = dataset.cardinality(); | |||
| Assert.IsTrue((cardinality == tf.data.INFINITE_CARDINALITY).numpy()); | |||
| dataset = dataset.filter(x => true); | |||
| cardinality = dataset.cardinality(); | |||
| Assert.IsTrue((cardinality == tf.data.UNKNOWN_CARDINALITY).numpy()); | |||
| } | |||
| [TestMethod] | |||
| public void Shuffle() | |||
| { | |||