diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs index 00b04edf..850211d1 100644 --- a/src/TensorFlowNET.Core/Data/DatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs @@ -60,7 +60,7 @@ namespace Tensorflow preserve_cardinality: preserve_cardinality, use_legacy_function: use_legacy_function); - public IDatasetV2 map(Func map_func, int num_parallel_calls = -1) + public IDatasetV2 map(Func map_func, int num_parallel_calls = -1) => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls); public IDatasetV2 flat_map(Func map_func) diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs index 0f38d531..5240f550 100644 --- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs @@ -60,7 +60,7 @@ namespace Tensorflow bool preserve_cardinality = true, bool use_legacy_function = false); - IDatasetV2 map(Func map_func, + IDatasetV2 map(Func map_func, int num_parallel_calls = -1); IDatasetV2 flat_map(Func map_func); diff --git a/src/TensorFlowNET.Core/Data/ParallelMapDataset.cs b/src/TensorFlowNET.Core/Data/ParallelMapDataset.cs index cd9cd3df..2a2e823b 100644 --- a/src/TensorFlowNET.Core/Data/ParallelMapDataset.cs +++ b/src/TensorFlowNET.Core/Data/ParallelMapDataset.cs @@ -9,7 +9,7 @@ namespace Tensorflow public class ParallelMapDataset : UnaryDataset { public ParallelMapDataset(IDatasetV2 input_dataset, - Func map_func, + Func map_func, int num_parallel_calls = -1, bool use_inter_op_parallelism = true, bool preserve_cardinality = false, diff --git a/src/TensorFlowNET.Core/Data/TensorDataset.cs b/src/TensorFlowNET.Core/Data/TensorDataset.cs index 2c9c6a44..db0e65dd 100644 --- a/src/TensorFlowNET.Core/Data/TensorDataset.cs +++ b/src/TensorFlowNET.Core/Data/TensorDataset.cs @@ -12,8 +12,7 @@ namespace Tensorflow public TensorDataset(Tensors elements) { _tensors = elements; - var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); - structure = batched_spec.Select(x => x._unbatch()).ToArray(); + structure = _tensors.Select(x => x.ToTensorSpec()).ToArray(); variant_tensor = ops.tensor_dataset(_tensors, output_shapes); } diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index 10bac1fc..18cd74a9 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -83,7 +83,7 @@ namespace Tensorflow.Functions graph.Exit(); } - public ConcreteFunction(Func func, + public ConcreteFunction(Func func, TF_DataType[] dtypes, TensorShape[] shapes) { string func_name = $"{func.Method.Name}_{Guid.NewGuid()}"; @@ -92,19 +92,14 @@ namespace Tensorflow.Functions using var graph = new FuncGraph(func_name); graph.as_default(); - var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args"); - var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args"); - var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args"); - var outputs = func(input1, (input2, input3)); - - Outputs = new[] { outputs.Item1, outputs.Item2 }; - OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() }; + var inputs = new Tensors(); + foreach(var (i, dtype) in enumerate(dtypes)) + inputs.Add(tf.placeholder(dtypes[i], shape: shapes[i], name: "args")); + Outputs = func(inputs); + OutputStructure = Outputs.Select(x => x.ToTensorSpec()).ToArray(); var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); - _handle = graph.ToGraph(opers, - new[] { input1, input2, input3 }, - new[] { outputs.Item1, outputs.Item2 }, - null); + _handle = graph.ToGraph(opers, inputs, Outputs, null); graph.Exit(); } diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs index 94aef370..3d9306f5 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs @@ -1,4 +1,5 @@ using System; +using System.Linq; using Tensorflow.Keras.ArgsDefinition; using static Tensorflow.Binding; @@ -21,7 +22,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters num_samples = args.X.shape[0]; var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize; _batch_size = batch_size; - _size = Convert.ToInt32(Math.Ceiling(num_samples / (batch_size + 0f))); + _size = Convert.ToInt32(Math.Floor(num_samples / (batch_size + 0f))); num_full_batches = num_samples / batch_size; var _partial_batch_size = num_samples % batch_size; @@ -29,12 +30,12 @@ namespace Tensorflow.Keras.Engine.DataAdapters indices_dataset = indices_dataset.repeat(args.Epochs); indices_dataset = indices_dataset.map(permutation).prefetch(1); indices_dataset = indices_dataset.flat_map(slice_batch_indices); - var elements = new Tensors(); + var inputs = new Tensors(); if (args.X != null) - elements.Add(args.X); + inputs.Add(args.X); if (args.Y != null) - elements.Add(args.Y); - dataset = slice_inputs(indices_dataset, elements); + inputs.Add(args.Y); + dataset = slice_inputs(indices_dataset, inputs); } Tensor permutation(Tensor tensor) @@ -64,11 +65,13 @@ namespace Tensorflow.Keras.Engine.DataAdapters var dataset2 = tf.data.Dataset.from_tensor(elements).repeat(); var dataset = tf.data.Dataset.zip(indices_dataset, dataset2); - dataset = dataset.map((batch, data) => + dataset = dataset.map(inputs => { - var x = gen_array_ops.gather_v2(data.Item1, batch, 0); - var y = gen_array_ops.gather_v2(data.Item2, batch, 0); - return (x, y); + var indices = inputs[0]; + var results = inputs.Skip(1) + .Select(x => gen_array_ops.gather_v2(x, indices, 0)) + .ToArray(); + return new Tensors(results); }); dataset = dataset.with_options(new DatasetOptions { }); diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index 8c395281..77039fae 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -97,7 +97,7 @@ namespace Tensorflow.Keras.Engine // callbacks.on_train_batch_begin(step) var results = step_function(iterator); var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); - Console.WriteLine($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]"); + Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}"); } } }