diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs index 11f57bfc..0ae6187a 100644 --- a/src/TensorFlowNET.Core/Data/DatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs @@ -66,7 +66,9 @@ namespace Tensorflow use_legacy_function: use_legacy_function); public IDatasetV2 map(Func map_func, int num_parallel_calls) - => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls); + => new ParallelMapDataset(this, map_func, + num_parallel_calls: num_parallel_calls, + preserve_cardinality: true); public OwnedIterator make_one_shot_iterator() { diff --git a/src/TensorFlowNET.Core/Data/ParallelMapDataset.cs b/src/TensorFlowNET.Core/Data/ParallelMapDataset.cs index 2a2e823b..6deb30bd 100644 --- a/src/TensorFlowNET.Core/Data/ParallelMapDataset.cs +++ b/src/TensorFlowNET.Core/Data/ParallelMapDataset.cs @@ -15,18 +15,26 @@ namespace Tensorflow bool preserve_cardinality = false, bool use_legacy_function = false) : base(input_dataset) { - var func = new ConcreteFunction(map_func, - input_dataset.element_spec.Select(x => x.dtype).ToArray(), - input_dataset.element_spec.Select(x => x.shape).ToArray()); + var func = new ConcreteFunction($"{map_func.Method.Name}_{Tensorflow.ops.uid_function()}"); + func.Enter(); + var inputs = new Tensors(); + foreach (var input in input_dataset.element_spec) + inputs.Add(tf.placeholder(input.dtype, shape: input.shape, name: "arg")); + var outputs = map_func(inputs); + func.ToGraph(inputs, outputs); + func.Exit(); structure = func.OutputStructure; + var _num_parallel_calls = tf.convert_to_tensor(num_parallel_calls, dtype: tf.int64, name: "num_parallel_calls"); variant_tensor = ops.parallel_map_dataset_v2(input_dataset.variant_tensor, _num_parallel_calls, func, output_types, - output_shapes); + output_shapes, + use_inter_op_parallelism: use_inter_op_parallelism, + preserve_cardinality: preserve_cardinality); } } } diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index 1ee09a0a..c1f9788c 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -71,7 +71,7 @@ namespace Tensorflow.Functions func_graph.Exit(); } - public ConcreteFunction(Func func, + /*public ConcreteFunction(Func func, TF_DataType[] dtypes, TensorShape[] shapes) { string func_name = $"{func.Method.Name}_{ops.uid_function()}"; @@ -89,7 +89,7 @@ namespace Tensorflow.Functions var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); func_graph.ToGraph(opers, inputs, Outputs, null); func_graph.Exit(); - } + }*/ public void ToGraph(Tensors inputs, Tensors outputs) { diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index 1f23fc44..04f21bb0 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -38,6 +38,8 @@ namespace Tensorflow } } + public Tensor this[params string[] slices] + => items.First()[slices]; public Tensors(params Tensor[] tensors) { items.AddRange(tensors); diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 048106bb..2bd25da0 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -585,6 +585,10 @@ would not be rank 1.", tensor.op.get_attr("axis"))); else return $"['{string.Join("', '", tensor.StringData().Take(25))}']"; } + else if(dtype == TF_DataType.TF_VARIANT) + { + return ""; + } var nd = tensor.numpy(); diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs index 0ff5c296..a5b26e2c 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs @@ -100,6 +100,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters using var data_iterator = new OwnedIterator(_dataset); yield return (epoch, data_iterator); } + // _adapter.on_epoch_end() } public IEnumerable steps() diff --git a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs index 6f1b7790..71bd2f38 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs @@ -33,22 +33,22 @@ namespace Tensorflow.Keras.Engine public void compile(string optimizer, string loss, string[] metrics) { - switch (optimizer) + var _optimizer = optimizer switch { - case "rmsprop": - this.optimizer = new RMSprop(new RMSpropArgs - { + "rmsprop" => new RMSprop(new RMSpropArgs + { - }); - break; - } + }), + _ => throw new NotImplementedException("") + }; - int experimental_steps_per_execution = 1; - _configure_steps_per_execution(experimental_steps_per_execution); - - _reset_compile_cache(); + var _loss = loss switch + { + "mse" => new MeanSquaredError(), + _ => throw new NotImplementedException("") + }; - _is_compiled = true; + compile(optimizer: _optimizer, loss: _loss, metrics: metrics); } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs index 11910db4..7f48f8ab 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs @@ -49,7 +49,32 @@ namespace Tensorflow.Keras.Engine Binding.tf_output_redirect.WriteLine($"Testing..."); foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) { - // reset_metrics(); + reset_metrics(); + // callbacks.on_epoch_begin(epoch) + // data_handler.catch_stop_iteration(); + IEnumerable<(string, Tensor)> results = null; + foreach (var step in data_handler.steps()) + { + // callbacks.on_train_batch_begin(step) + results = test_function(iterator); + } + Binding.tf_output_redirect.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); + } + } + + public void evaluate(IDatasetV2 x) + { + data_handler = new DataHandler(new DataHandlerArgs + { + Dataset = x, + Model = this, + StepsPerExecution = _steps_per_execution + }); + + Binding.tf_output_redirect.WriteLine($"Testing..."); + foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) + { + reset_metrics(); // callbacks.on_epoch_begin(epoch) // data_handler.catch_stop_iteration(); IEnumerable<(string, Tensor)> results = null; diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs index f820da9a..d34c97b3 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs @@ -124,10 +124,11 @@ namespace Tensorflow.Keras var start_positions_tensor = tf.constant(start_positions); var positions_ds = tf.data.Dataset.from_tensors(start_positions_tensor).repeat(); - var z = tf.data.Dataset.zip(tf.data.Dataset.range(len(start_positions)), positions_ds); + var r = tf.data.Dataset.range(len(start_positions)); + var z = tf.data.Dataset.zip(r, positions_ds); var indices = z.map(m => { - var (i, positions) = (m[0], m[1]); + var (i, positions) = m; return tf.range(positions[i], positions[i] + sequence_length_tensor * sampling_rate_tensor, sampling_rate_tensor); }, num_parallel_calls: -1); var dataset = sequences_from_indices(data, indices, start_index, end_index); @@ -142,7 +143,11 @@ namespace Tensorflow.Keras { var dataset = tf.data.Dataset.from_tensors(array[new Slice(start: start_index, stop: end_index)]); dataset = tf.data.Dataset.zip(dataset.repeat(), indices_ds) - .map(x => array_ops.gather(x[0], x[1]), num_parallel_calls: -1); + .map(x => + { + var (steps, indx) = x; + return array_ops.gather(steps, indx); + }, num_parallel_calls: -1); return dataset; } } diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs index b3729abc..f624476c 100644 --- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs +++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs @@ -147,7 +147,18 @@ namespace TensorFlowNET.UnitTest.Dataset public void Cardinality() { var dataset = tf.data.Dataset.range(10); + var cardinality = dataset.dataset_cardinality(); + Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); dataset = dataset.map(x => x[0] + 1); + cardinality = dataset.dataset_cardinality(); + Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); + } + + [TestMethod] + public void CardinalityWithAutoTune() + { + var dataset = tf.data.Dataset.range(10); + dataset = dataset.map(x => x, num_parallel_calls: -1); var cardinality = dataset.dataset_cardinality(); Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); }