| @@ -12,4 +12,6 @@ public interface ICallback | |||||
| void on_predict_batch_begin(long step); | void on_predict_batch_begin(long step); | ||||
| void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs); | void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs); | ||||
| void on_predict_end(); | void on_predict_end(); | ||||
| void on_test_begin(); | |||||
| void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs); | |||||
| } | } | ||||
| @@ -20,7 +20,10 @@ public class CallbackList | |||||
| { | { | ||||
| callbacks.ForEach(x => x.on_train_begin()); | callbacks.ForEach(x => x.on_train_begin()); | ||||
| } | } | ||||
| public void on_test_begin() | |||||
| { | |||||
| callbacks.ForEach(x => x.on_test_begin()); | |||||
| } | |||||
| public void on_epoch_begin(int epoch) | public void on_epoch_begin(int epoch) | ||||
| { | { | ||||
| callbacks.ForEach(x => x.on_epoch_begin(epoch)); | callbacks.ForEach(x => x.on_epoch_begin(epoch)); | ||||
| @@ -60,4 +63,13 @@ public class CallbackList | |||||
| { | { | ||||
| callbacks.ForEach(x => x.on_predict_end()); | callbacks.ForEach(x => x.on_predict_end()); | ||||
| } | } | ||||
| public void on_test_batch_begin(long step) | |||||
| { | |||||
| callbacks.ForEach(x => x.on_train_batch_begin(step)); | |||||
| } | |||||
| public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||||
| { | |||||
| callbacks.ForEach(x => x.on_test_batch_end(end_step, logs)); | |||||
| } | |||||
| } | } | ||||
| @@ -18,7 +18,11 @@ public class History : ICallback | |||||
| epochs = new List<int>(); | epochs = new List<int>(); | ||||
| history = new Dictionary<string, List<float>>(); | history = new Dictionary<string, List<float>>(); | ||||
| } | } | ||||
| public void on_test_begin() | |||||
| { | |||||
| epochs = new List<int>(); | |||||
| history = new Dictionary<string, List<float>>(); | |||||
| } | |||||
| public void on_epoch_begin(int epoch) | public void on_epoch_begin(int epoch) | ||||
| { | { | ||||
| @@ -26,7 +30,7 @@ public class History : ICallback | |||||
| public void on_train_batch_begin(long step) | public void on_train_batch_begin(long step) | ||||
| { | { | ||||
| } | } | ||||
| public void on_train_batch_end(long end_step, Dictionary<string, float> logs) | public void on_train_batch_end(long end_step, Dictionary<string, float> logs) | ||||
| @@ -55,16 +59,25 @@ public class History : ICallback | |||||
| public void on_predict_batch_begin(long step) | public void on_predict_batch_begin(long step) | ||||
| { | { | ||||
| } | } | ||||
| public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) | public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) | ||||
| { | { | ||||
| } | } | ||||
| public void on_predict_end() | public void on_predict_end() | ||||
| { | { | ||||
| } | |||||
| public void on_test_batch_begin(long step) | |||||
| { | |||||
| } | |||||
| public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||||
| { | |||||
| } | } | ||||
| } | } | ||||
| @@ -22,7 +22,10 @@ namespace Tensorflow.Keras.Callbacks | |||||
| _called_in_fit = true; | _called_in_fit = true; | ||||
| _sw = new Stopwatch(); | _sw = new Stopwatch(); | ||||
| } | } | ||||
| public void on_test_begin() | |||||
| { | |||||
| _sw = new Stopwatch(); | |||||
| } | |||||
| public void on_epoch_begin(int epoch) | public void on_epoch_begin(int epoch) | ||||
| { | { | ||||
| _reset_progbar(); | _reset_progbar(); | ||||
| @@ -44,7 +47,7 @@ namespace Tensorflow.Keras.Callbacks | |||||
| var progress = ""; | var progress = ""; | ||||
| var length = 30.0 / _parameters.Steps; | var length = 30.0 / _parameters.Steps; | ||||
| for (int i = 0; i < Math.Floor(end_step * length - 1); i++) | for (int i = 0; i < Math.Floor(end_step * length - 1); i++) | ||||
| progress += "="; | |||||
| progress += "="; | |||||
| if (progress.Length < 28) | if (progress.Length < 28) | ||||
| progress += ">"; | progress += ">"; | ||||
| else | else | ||||
| @@ -84,17 +87,35 @@ namespace Tensorflow.Keras.Callbacks | |||||
| public void on_predict_batch_begin(long step) | public void on_predict_batch_begin(long step) | ||||
| { | { | ||||
| } | } | ||||
| public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) | public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) | ||||
| { | { | ||||
| } | } | ||||
| public void on_predict_end() | public void on_predict_end() | ||||
| { | { | ||||
| } | |||||
| public void on_test_batch_begin(long step) | |||||
| { | |||||
| _sw.Restart(); | |||||
| } | } | ||||
| public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||||
| { | |||||
| _sw.Stop(); | |||||
| var elapse = _sw.ElapsedMilliseconds; | |||||
| var results = string.Join(" - ", logs.Select(x => $"{x.Item1}: {(float)x.Item2.numpy():F6}")); | |||||
| Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} - {elapse}ms/step - {results}"); | |||||
| if (!Console.IsOutputRedirected) | |||||
| { | |||||
| Console.CursorLeft = 0; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -5,6 +5,10 @@ using System.Linq; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine.DataAdapters; | using Tensorflow.Keras.Engine.DataAdapters; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Keras.Layers; | |||||
| using Tensorflow.Keras.Utils; | |||||
| using Tensorflow; | |||||
| using Tensorflow.Keras.Callbacks; | |||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| { | { | ||||
| @@ -31,6 +35,11 @@ namespace Tensorflow.Keras.Engine | |||||
| bool use_multiprocessing = false, | bool use_multiprocessing = false, | ||||
| bool return_dict = false) | bool return_dict = false) | ||||
| { | { | ||||
| if (x.dims[0] != y.dims[0]) | |||||
| { | |||||
| throw new InvalidArgumentError( | |||||
| $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); | |||||
| } | |||||
| var data_handler = new DataHandler(new DataHandlerArgs | var data_handler = new DataHandler(new DataHandlerArgs | ||||
| { | { | ||||
| X = x, | X = x, | ||||
| @@ -46,18 +55,31 @@ namespace Tensorflow.Keras.Engine | |||||
| StepsPerExecution = _steps_per_execution | StepsPerExecution = _steps_per_execution | ||||
| }); | }); | ||||
| var callbacks = new CallbackList(new CallbackParams | |||||
| { | |||||
| Model = this, | |||||
| Verbose = verbose, | |||||
| Steps = data_handler.Inferredsteps | |||||
| }); | |||||
| callbacks.on_test_begin(); | |||||
| foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | ||||
| { | { | ||||
| reset_metrics(); | reset_metrics(); | ||||
| // callbacks.on_epoch_begin(epoch) | |||||
| //callbacks.on_epoch_begin(epoch); | |||||
| // data_handler.catch_stop_iteration(); | // data_handler.catch_stop_iteration(); | ||||
| IEnumerable<(string, Tensor)> results = null; | |||||
| IEnumerable<(string, Tensor)> logs = null; | |||||
| foreach (var step in data_handler.steps()) | foreach (var step in data_handler.steps()) | ||||
| { | { | ||||
| // callbacks.on_train_batch_begin(step) | |||||
| results = test_function(data_handler, iterator); | |||||
| callbacks.on_train_batch_begin(step); | |||||
| logs = test_function(data_handler, iterator); | |||||
| var end_step = step + data_handler.StepIncrement; | |||||
| callbacks.on_test_batch_end(end_step, logs); | |||||
| } | } | ||||
| } | } | ||||
| GC.Collect(); | |||||
| GC.WaitForPendingFinalizers(); | |||||
| } | } | ||||
| public KeyValuePair<string, float>[] evaluate(IDatasetV2 x) | public KeyValuePair<string, float>[] evaluate(IDatasetV2 x) | ||||
| @@ -75,7 +97,8 @@ namespace Tensorflow.Keras.Engine | |||||
| reset_metrics(); | reset_metrics(); | ||||
| // callbacks.on_epoch_begin(epoch) | // callbacks.on_epoch_begin(epoch) | ||||
| // data_handler.catch_stop_iteration(); | // data_handler.catch_stop_iteration(); | ||||
| foreach (var step in data_handler.steps()) | foreach (var step in data_handler.steps()) | ||||
| { | { | ||||
| // callbacks.on_train_batch_begin(step) | // callbacks.on_train_batch_begin(step) | ||||