From d2e50dda558b6a30b6ad3e0b8380bffc07e93185 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 16 Jan 2021 08:08:15 -0600 Subject: [PATCH] Add keras model.predict. --- .../Engine/DataAdapters/DataHandler.cs | 3 +- .../DataAdapters/TensorLikeDataAdapter.cs | 14 ++++++-- src/TensorFlowNET.Keras/Engine/Functional.cs | 2 +- src/TensorFlowNET.Keras/Engine/Model.Fit.cs | 2 +- .../Engine/Model.Predict.cs | 35 +++++++++++++++++-- src/TensorFlowNET.Keras/Engine/Model.Train.cs | 2 +- .../Layers/LayersApi.Reshaping.cs | 6 ++++ .../Normalization/BatchNormalization.cs | 3 -- .../Layers/Reshaping/Reshape.cs | 20 +++++++---- .../Tensorflow.Keras.csproj | 10 +++--- 10 files changed, 74 insertions(+), 23 deletions(-) diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs index 42705dee..cb275166 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs @@ -18,6 +18,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters public int Inferredsteps => _inferred_steps; int _current_step; int _step_increment; + public int StepIncrement => _step_increment; bool _insufficient_data; int _steps_per_execution_value; int _initial_epoch => args.InitialEpoch; @@ -73,7 +74,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters _dataset = _adapter.GetDataset(); _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset); _current_step = 0; - _step_increment = args.StepsPerExecution.numpy() - 1; + _step_increment = _steps_per_execution_value - 1; _insufficient_data = false; } diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs index 3d9306f5..1741201b 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs @@ -14,6 +14,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters int _batch_size; int num_samples; int num_full_batches; + int _partial_batch_size; public TensorLikeDataAdapter(DataAdapterArgs args) { @@ -22,9 +23,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters num_samples = args.X.shape[0]; var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize; _batch_size = batch_size; - _size = Convert.ToInt32(Math.Floor(num_samples / (batch_size + 0f))); + _size = num_samples < batch_size ? num_samples % batch_size : num_samples / batch_size; num_full_batches = num_samples / batch_size; - var _partial_batch_size = num_samples % batch_size; + _partial_batch_size = num_samples % batch_size; var indices_dataset = tf.data.Dataset.range(1); indices_dataset = indices_dataset.repeat(args.Epochs); @@ -57,6 +58,15 @@ namespace Tensorflow.Keras.Engine.DataAdapters var first_k_indices = array_ops.slice(indices, new int[] { 0 }, new int[] { num_in_full_batch }); first_k_indices = array_ops.reshape(first_k_indices, new int[] { num_full_batches, _batch_size }); var flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices); + if (_partial_batch_size > 0) + { + var array = array_ops.slice(indices, + new[] { constant_op.constant(num_in_full_batch)}, + new[] { constant_op.constant(_partial_batch_size)}); + var index_remainder = tf.data.Dataset.from_tensor(array); + flat_dataset = flat_dataset.concatenate(index_remainder); + } + return flat_dataset; } diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 6c67f109..3409f682 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -340,7 +340,7 @@ namespace Tensorflow.Keras.Engine tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}"); var outputs = node.Layer.Apply(layer_inputs, is_training: training); foreach (var output in outputs.Where(x => x != null)) - tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}"); + tf.Logger.Information($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}"); // Update tensor_dict for next input foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) tensor_dict[x_id] = new Queue(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index 77039fae..ca117bb3 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -95,7 +95,7 @@ namespace Tensorflow.Keras.Engine foreach (var step in data_handler.steps()) { // callbacks.on_train_batch_begin(step) - var results = step_function(iterator); + var results = train_step_function(iterator); var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}"); } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs index b90d1672..ab2a0ec0 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs @@ -1,7 +1,10 @@ using NumSharp; using System; +using System.Collections.Generic; +using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine.DataAdapters; +using static Tensorflow.Binding; namespace Tensorflow.Keras.Engine { @@ -21,7 +24,7 @@ namespace Tensorflow.Keras.Engine /// /// /// - public Tensor predict(Tensor x, + public Tensors predict(Tensor x, int batch_size = -1, int verbose = 0, int steps = -1, @@ -43,7 +46,35 @@ namespace Tensorflow.Keras.Engine StepsPerExecution = _steps_per_execution }); - throw new NotImplementedException(""); + Tensors outputs = null; + _predict_counter.assign(0); + // callbacks.on_predict_begin() + foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) + { + foreach(var step in data_handler.steps()) + { + // callbacks.on_predict_batch_begin(step) + var batch_outputs = run_predict_step(iterator); + outputs = batch_outputs; + var end_step = step + data_handler.StepIncrement; + // callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs}) + } + } + // callbacks.on_predict_end() + return outputs; + } + + Tensors run_predict_step(OwnedIterator iterator) + { + var data = iterator.next(); + var outputs = predict_step(data[0]); + tf_with(ops.control_dependencies(new object[0]), ctl => _predict_counter.assign_add(1)); + return outputs; + } + + Tensors predict_step(Tensor data) + { + return Apply(data, is_training: false); } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Keras/Engine/Model.Train.cs index 4ea4dbfa..961405d5 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Train.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Train.cs @@ -8,7 +8,7 @@ namespace Tensorflow.Keras.Engine { public partial class Model { - IEnumerable<(string, Tensor)> step_function(OwnedIterator iterator) + IEnumerable<(string, Tensor)> train_step_function(OwnedIterator iterator) { var data = iterator.next(); var outputs = train_step(data[0], data[1]); diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs index c0bfa321..bbba0302 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs @@ -45,5 +45,11 @@ namespace Tensorflow.Keras.Layers { TargetShape = target_shape }); + + public Reshape Reshape(object[] target_shape) + => new Reshape(new ReshapeArgs + { + TargetShapeObjects = target_shape + }); } } diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs index bbbe495c..d4dbb3d7 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs @@ -142,9 +142,6 @@ namespace Tensorflow.Keras.Layers if (use_fused_avg_updates) exponential_avg_factor = 1.0f - momentum; - var beta = this.beta; - var gamma = this.gamma; - Func _fused_batch_norm_training = () => { return tf.nn.fused_batch_norm( diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs index dce2013c..68bd76af 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs @@ -21,11 +21,15 @@ namespace Tensorflow.Keras.Layers protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false) { - var shape_tensor = array_ops.shape(inputs); - var shape = new List { inputs.shape[0] }; - shape.AddRange(args.TargetShape.dims); + var shapes = new List(); + shapes.Add(array_ops.shape(inputs)[0]); + if (args.TargetShapeObjects != null) + shapes.AddRange(args.TargetShapeObjects); + if (args.TargetShape != null) + args.TargetShape.dims.ToList().ForEach(x => shapes.Add(x)); + var shape = ops.convert_to_tensor(shapes); - var result = array_ops.reshape(inputs, shape.ToArray()); + var result = array_ops.reshape(inputs, shape); if (!tf.Context.executing_eagerly()) result.set_shape(ComputeOutputShape(inputs.shape)); return result; @@ -33,14 +37,16 @@ namespace Tensorflow.Keras.Layers public override TensorShape ComputeOutputShape(TensorShape input_shape) { - if (input_shape.dims[0] == -1) + if (input_shape.dims[1..].Contains(-1)) + { + throw new NotImplementedException(""); + } + else { input_shape = input_shape.dims[0]; var output_shape = input_shape.concatenate(args.TargetShape.dims); return output_shape; } - else - throw new NotImplementedException(""); } } } diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index e4864e1d..e705b3d1 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -6,7 +6,7 @@ 8.0 Tensorflow.Keras AnyCPU;x64 - 0.3.0 + 0.4.0 Haiping Chen Keras for .NET Apache 2.0, Haiping Chen 2020 @@ -20,7 +20,8 @@ * Support Conv2D functional API. * Support BatchNormalization layer. * Building keras model in subclass, functional and sequential api -* Implemented backward_function. +* Implemented backward_function. +* Support model.load_weights. Keras for .NET Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent & simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear & actionable error messages. @@ -31,8 +32,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac Git true Open.snk - 0.3.0.0 - 0.3.0.0 + 0.4.0.0 + 0.4.0.0 LICENSE @@ -48,7 +49,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac -