From c5cdf2c540922ac3abf38e53d1b00e4241802e35 Mon Sep 17 00:00:00 2001 From: Haiping Chen Date: Mon, 30 Jan 2023 20:08:40 -0600 Subject: [PATCH] Fixed model.fit return results. #927 --- .../NumPy/Numpy.Manipulation.cs | 4 + .../Callbacks/CallbackList.cs | 43 ++++++++++ .../Callbacks/CallbackParams.cs | 15 ++++ src/TensorFlowNET.Keras/Callbacks/History.cs | 52 ++++++++++++ .../Callbacks/ICallback.cs | 15 ++++ .../Callbacks/ProgbarLogger.cs | 81 +++++++++++++++++++ .../Engine/Model.Evaluate.cs | 18 ++--- src/TensorFlowNET.Keras/Engine/Model.Fit.cs | 79 ++++++++---------- src/TensorFlowNET.Keras/Engine/Model.Train.cs | 11 ++- src/TensorFlowNET.Keras/Engine/Model.cs | 1 - .../Layers/LayersTest.cs | 2 +- 11 files changed, 258 insertions(+), 63 deletions(-) create mode 100644 src/TensorFlowNET.Keras/Callbacks/CallbackList.cs create mode 100644 src/TensorFlowNET.Keras/Callbacks/CallbackParams.cs create mode 100644 src/TensorFlowNET.Keras/Callbacks/History.cs create mode 100644 src/TensorFlowNET.Keras/Callbacks/ICallback.cs create mode 100644 src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs index 091509fd..94085605 100644 --- a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs +++ b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs @@ -8,6 +8,10 @@ namespace Tensorflow.NumPy { public partial class np { + [AutoNumPy] + public static NDArray concatenate((NDArray, NDArray) tuple, int axis = 0) + => new NDArray(array_ops.concat(new[] { tuple.Item1, tuple.Item2 }, axis)); + [AutoNumPy] public static NDArray concatenate(NDArray[] arrays, int axis = 0) => new NDArray(array_ops.concat(arrays, axis)); diff --git a/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs b/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs new file mode 100644 index 00000000..bb3ed6ed --- /dev/null +++ b/src/TensorFlowNET.Keras/Callbacks/CallbackList.cs @@ -0,0 +1,43 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Callbacks +{ + public class CallbackList + { + List callbacks = new List(); + public History History => callbacks[0] as History; + + public CallbackList(CallbackParams parameters) + { + callbacks.Add(new History(parameters)); + callbacks.Add(new ProgbarLogger(parameters)); + } + + public void on_train_begin() + { + callbacks.ForEach(x => x.on_train_begin()); + } + + public void on_epoch_begin(int epoch) + { + callbacks.ForEach(x => x.on_epoch_begin(epoch)); + } + + public void on_train_batch_begin(long step) + { + callbacks.ForEach(x => x.on_train_batch_begin(step)); + } + + public void on_train_batch_end(long end_step, Dictionary logs) + { + callbacks.ForEach(x => x.on_train_batch_end(end_step, logs)); + } + + public void on_epoch_end(int epoch, Dictionary epoch_logs) + { + callbacks.ForEach(x => x.on_epoch_end(epoch, epoch_logs)); + } + } +} diff --git a/src/TensorFlowNET.Keras/Callbacks/CallbackParams.cs b/src/TensorFlowNET.Keras/Callbacks/CallbackParams.cs new file mode 100644 index 00000000..fe859c8a --- /dev/null +++ b/src/TensorFlowNET.Keras/Callbacks/CallbackParams.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Callbacks +{ + public class CallbackParams + { + public IModel Model { get; set; } + public int Verbose { get; set; } + public int Epochs { get; set; } + public long Steps { get; set; } + } +} diff --git a/src/TensorFlowNET.Keras/Callbacks/History.cs b/src/TensorFlowNET.Keras/Callbacks/History.cs new file mode 100644 index 00000000..02588b5e --- /dev/null +++ b/src/TensorFlowNET.Keras/Callbacks/History.cs @@ -0,0 +1,52 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Callbacks +{ + public class History : ICallback + { + List epochs; + CallbackParams _parameters; + public Dictionary> history { get; set; } + + public History(CallbackParams parameters) + { + _parameters = parameters; + } + + public void on_train_begin() + { + epochs = new List(); + history = new Dictionary>(); + } + + public void on_epoch_begin(int epoch) + { + + } + + public void on_train_batch_begin(long step) + { + + } + + public void on_train_batch_end(long end_step, Dictionary logs) + { + } + + public void on_epoch_end(int epoch, Dictionary epoch_logs) + { + epochs.Add(epoch); + + foreach (var log in epoch_logs) + { + if (!history.ContainsKey(log.Key)) + { + history[log.Key] = new List(); + } + history[log.Key].Add((float)log.Value); + } + } + } +} diff --git a/src/TensorFlowNET.Keras/Callbacks/ICallback.cs b/src/TensorFlowNET.Keras/Callbacks/ICallback.cs new file mode 100644 index 00000000..34763c55 --- /dev/null +++ b/src/TensorFlowNET.Keras/Callbacks/ICallback.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Callbacks +{ + public interface ICallback + { + void on_train_begin(); + void on_epoch_begin(int epoch); + void on_train_batch_begin(long step); + void on_train_batch_end(long end_step, Dictionary logs); + void on_epoch_end(int epoch, Dictionary epoch_logs); + } +} diff --git a/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs b/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs new file mode 100644 index 00000000..17e04101 --- /dev/null +++ b/src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs @@ -0,0 +1,81 @@ +using PureHDF; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Text; + +namespace Tensorflow.Keras.Callbacks +{ + public class ProgbarLogger : ICallback + { + bool _called_in_fit = false; + int seen = 0; + CallbackParams _parameters; + Stopwatch _sw; + + public ProgbarLogger(CallbackParams parameters) + { + _parameters = parameters; + } + + public void on_train_begin() + { + _called_in_fit = true; + _sw = new Stopwatch(); + } + + public void on_epoch_begin(int epoch) + { + _reset_progbar(); + _maybe_init_progbar(); + Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{_parameters.Epochs:D3}"); + } + + public void on_train_batch_begin(long step) + { + _sw.Restart(); + } + + public void on_train_batch_end(long end_step, Dictionary logs) + { + _sw.Stop(); + var elapse = _sw.ElapsedMilliseconds; + var results = string.Join(" - ", logs.Select(x => $"{x.Key}: {(float)x.Value:F6}")); + + var progress = ""; + var length = 30.0 / _parameters.Steps; + for (int i = 0; i < Math.Floor(end_step * length - 1); i++) + progress += "="; + if (progress.Length < 28) + progress += ">"; + else + progress += "="; + + var remaining = ""; + for (int i = 1; i < 30 - progress.Length; i++) + remaining += "."; + + Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} [{progress}{remaining}] - {elapse}ms/step - {results}"); + if (!Console.IsOutputRedirected) + { + Console.CursorLeft = 0; + } + } + + public void on_epoch_end(int epoch, Dictionary epoch_logs) + { + Console.WriteLine(); + } + + void _reset_progbar() + { + seen = 0; + } + + void _maybe_init_progbar() + { + + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs index 98e02ed3..c9d39833 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs @@ -31,7 +31,7 @@ namespace Tensorflow.Keras.Engine bool use_multiprocessing = false, bool return_dict = false) { - data_handler = new DataHandler(new DataHandlerArgs + var data_handler = new DataHandler(new DataHandlerArgs { X = x, Y = y, @@ -46,7 +46,6 @@ namespace Tensorflow.Keras.Engine StepsPerExecution = _steps_per_execution }); - Binding.tf_output_redirect.WriteLine($"Testing..."); foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) { reset_metrics(); @@ -56,22 +55,20 @@ namespace Tensorflow.Keras.Engine foreach (var step in data_handler.steps()) { // callbacks.on_train_batch_begin(step) - results = test_function(iterator); + results = test_function(data_handler, iterator); } - Binding.tf_output_redirect.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); } } public KeyValuePair[] evaluate(IDatasetV2 x) { - data_handler = new DataHandler(new DataHandlerArgs + var data_handler = new DataHandler(new DataHandlerArgs { Dataset = x, Model = this, StepsPerExecution = _steps_per_execution }); - Binding.tf_output_redirect.WriteLine($"Testing..."); IEnumerable<(string, Tensor)> logs = null; foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) { @@ -82,22 +79,21 @@ namespace Tensorflow.Keras.Engine foreach (var step in data_handler.steps()) { // callbacks.on_train_batch_begin(step) - logs = test_function(iterator); + logs = test_function(data_handler, iterator); } - Binding.tf_output_redirect.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", logs.Select(x => $"{x.Item1}: {(float)x.Item2}"))); } return logs.Select(x => new KeyValuePair(x.Item1, (float)x.Item2)).ToArray(); } - IEnumerable<(string, Tensor)> test_function(OwnedIterator iterator) + IEnumerable<(string, Tensor)> test_function(DataHandler data_handler, OwnedIterator iterator) { var data = iterator.next(); - var outputs = test_step(data[0], data[1]); + var outputs = test_step(data_handler, data[0], data[1]); tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1)); return outputs; } - List<(string, Tensor)> test_step(Tensor x, Tensor y) + List<(string, Tensor)> test_step(DataHandler data_handler, Tensor x, Tensor y) { (x, y) = data_handler.DataAdapter.Expand1d(x, y); var y_pred = Apply(x, training: false); diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index e0b4af78..bc2c2cea 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -5,6 +5,8 @@ using System.Linq; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine.DataAdapters; using System.Diagnostics; +using Tensorflow.Keras.Callbacks; +using System.Data; namespace Tensorflow.Keras.Engine { @@ -20,7 +22,7 @@ namespace Tensorflow.Keras.Engine /// /// /// - public void fit(NDArray x, NDArray y, + public History fit(NDArray x, NDArray y, int batch_size = -1, int epochs = 1, int verbose = 1, @@ -37,7 +39,7 @@ namespace Tensorflow.Keras.Engine var val_x = x[new Slice(train_count)]; var val_y = y[new Slice(train_count)]; - data_handler = new DataHandler(new DataHandlerArgs + var data_handler = new DataHandler(new DataHandlerArgs { X = train_x, Y = train_y, @@ -52,10 +54,10 @@ namespace Tensorflow.Keras.Engine StepsPerExecution = _steps_per_execution }); - FitInternal(epochs, verbose); + return FitInternal(data_handler, epochs, verbose); } - public void fit(IDatasetV2 dataset, + public History fit(IDatasetV2 dataset, IDatasetV2 validation_data = null, int batch_size = -1, int epochs = 1, @@ -67,7 +69,7 @@ namespace Tensorflow.Keras.Engine int workers = 1, bool use_multiprocessing = false) { - data_handler = new DataHandler(new DataHandlerArgs + var data_handler = new DataHandler(new DataHandlerArgs { Dataset = dataset, BatchSize = batch_size, @@ -81,67 +83,52 @@ namespace Tensorflow.Keras.Engine StepsPerExecution = _steps_per_execution }); - FitInternal(epochs, verbose); + return FitInternal(data_handler, epochs, verbose, validation_data: validation_data); } - void FitInternal(int epochs, int verbose) + History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data = null) { stop_training = false; _train_counter.assign(0); - Stopwatch sw = new Stopwatch(); + var callbacks = new CallbackList(new CallbackParams + { + Model = this, + Verbose = verbose, + Epochs = epochs, + Steps = data_handler.Inferredsteps + }); + callbacks.on_train_begin(); + foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) { reset_metrics(); - on_epoch_begin(epoch, epochs); + callbacks.on_epoch_begin(epoch); // data_handler.catch_stop_iteration(); + var logs = new Dictionary(); foreach (var step in data_handler.steps()) { - sw.Start(); - var results = train_step_function(iterator); - sw.Stop(); - on_train_batch_begin(verbose, step, sw.ElapsedMilliseconds, results); + callbacks.on_train_batch_begin(step); + logs = train_step_function(data_handler, iterator); + var end_step = step + data_handler.StepIncrement; + callbacks.on_train_batch_end(end_step, logs); + } - // recycle memory more frequency - if (sw.ElapsedMilliseconds > 100) + if (validation_data != null) + { + var val_logs = evaluate(validation_data); + foreach(var log in val_logs) { - GC.Collect(); + logs["val_" + log.Key] = log.Value; } - sw.Reset(); } - Console.WriteLine(); + + callbacks.on_epoch_end(epoch, logs); GC.Collect(); GC.WaitForPendingFinalizers(); } - } - - void on_epoch_begin(int epoch, int epochs) - { - Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}"); - } - - void on_train_batch_begin(int verbose, long step, long elapse, IEnumerable<(string, Tensor)> results) - { - if (verbose == 1) - { - var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); - var progress = ""; - for (int i = 0; i < step + 1; i++) - for (int j = 0; j < 30 / data_handler.Inferredsteps; j++) - progress += "="; - progress += ">"; - - var remaining = ""; - for (int i = 1; i < 30 - progress.Length; i++) - remaining += "."; - - Binding.tf_output_redirect.Write($"{step + 1:D4}/{data_handler.Inferredsteps:D4} [{progress}{remaining}] - {elapse}ms/step {result_pairs}"); - if (!Console.IsOutputRedirected) - { - Console.CursorLeft = 0; - } - } + return callbacks.History; } } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Keras/Engine/Model.Train.cs index f2ff68e9..0090b69e 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Train.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Train.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using System.Linq; using Tensorflow.Gradients; +using Tensorflow.Keras.Engine.DataAdapters; using Tensorflow.Keras.Optimizers; using static Tensorflow.Binding; @@ -8,10 +9,10 @@ namespace Tensorflow.Keras.Engine { public partial class Model { - IEnumerable<(string, Tensor)> train_step_function(OwnedIterator iterator) + Dictionary train_step_function(DataHandler data_handler, OwnedIterator iterator) { var data = iterator.next(); - var outputs = train_step(data[0], data[1]); + var outputs = train_step(data_handler, data[0], data[1]); tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); return outputs; } @@ -21,7 +22,7 @@ namespace Tensorflow.Keras.Engine /// /// /// - List<(string, Tensor)> train_step(Tensor x, Tensor y) + Dictionary train_step(DataHandler data_handler, Tensor x, Tensor y) { (x, y) = data_handler.DataAdapter.Expand1d(x, y); using var tape = tf.GradientTape(); @@ -37,7 +38,9 @@ namespace Tensorflow.Keras.Engine _minimize(tape, optimizer, loss, TrainableVariables); compiled_metrics.update_state(y, y_pred); - return metrics.Select(x => (x.Name, x.result())).ToList(); + var dict = new Dictionary(); + metrics.ToList().ForEach(x => dict[x.Name] = (float)x.result()); + return dict; } void _minimize(GradientTape tape, OptimizerV2 optimizer, Tensor loss, List trainable_variables) diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index 162d06c5..9bab9bd2 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -34,7 +34,6 @@ namespace Tensorflow.Keras.Engine IVariableV1 _predict_counter; bool _base_model_initialized; bool stop_training; - DataHandler data_handler; public Model(ModelArgs args) : base(args) diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs index d4ac4b90..029592c3 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs @@ -147,7 +147,7 @@ namespace TensorFlowNET.Keras.UnitTest Assert.AreEqual(expected_output, actual_output); } - [TestMethod] + [TestMethod, Ignore("WIP")] public void SimpleRNN() { tf.UseKeras();