| @@ -12,4 +12,6 @@ public interface ICallback | |||
| void on_predict_batch_begin(long step); | |||
| void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs); | |||
| void on_predict_end(); | |||
| void on_test_begin(); | |||
| void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs); | |||
| } | |||
| @@ -20,7 +20,10 @@ public class CallbackList | |||
| { | |||
| callbacks.ForEach(x => x.on_train_begin()); | |||
| } | |||
| public void on_test_begin() | |||
| { | |||
| callbacks.ForEach(x => x.on_test_begin()); | |||
| } | |||
| public void on_epoch_begin(int epoch) | |||
| { | |||
| callbacks.ForEach(x => x.on_epoch_begin(epoch)); | |||
| @@ -60,4 +63,13 @@ public class CallbackList | |||
| { | |||
| callbacks.ForEach(x => x.on_predict_end()); | |||
| } | |||
| public void on_test_batch_begin(long step) | |||
| { | |||
| callbacks.ForEach(x => x.on_train_batch_begin(step)); | |||
| } | |||
| public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||
| { | |||
| callbacks.ForEach(x => x.on_test_batch_end(end_step, logs)); | |||
| } | |||
| } | |||
| @@ -18,7 +18,11 @@ public class History : ICallback | |||
| epochs = new List<int>(); | |||
| history = new Dictionary<string, List<float>>(); | |||
| } | |||
| public void on_test_begin() | |||
| { | |||
| epochs = new List<int>(); | |||
| history = new Dictionary<string, List<float>>(); | |||
| } | |||
| public void on_epoch_begin(int epoch) | |||
| { | |||
| @@ -26,7 +30,7 @@ public class History : ICallback | |||
| public void on_train_batch_begin(long step) | |||
| { | |||
| } | |||
| public void on_train_batch_end(long end_step, Dictionary<string, float> logs) | |||
| @@ -55,16 +59,25 @@ public class History : ICallback | |||
| public void on_predict_batch_begin(long step) | |||
| { | |||
| } | |||
| public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) | |||
| { | |||
| } | |||
| public void on_predict_end() | |||
| { | |||
| } | |||
| public void on_test_batch_begin(long step) | |||
| { | |||
| } | |||
| public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||
| { | |||
| } | |||
| } | |||
| @@ -22,7 +22,10 @@ namespace Tensorflow.Keras.Callbacks | |||
| _called_in_fit = true; | |||
| _sw = new Stopwatch(); | |||
| } | |||
| public void on_test_begin() | |||
| { | |||
| _sw = new Stopwatch(); | |||
| } | |||
| public void on_epoch_begin(int epoch) | |||
| { | |||
| _reset_progbar(); | |||
| @@ -44,7 +47,7 @@ namespace Tensorflow.Keras.Callbacks | |||
| var progress = ""; | |||
| var length = 30.0 / _parameters.Steps; | |||
| for (int i = 0; i < Math.Floor(end_step * length - 1); i++) | |||
| progress += "="; | |||
| progress += "="; | |||
| if (progress.Length < 28) | |||
| progress += ">"; | |||
| else | |||
| @@ -84,17 +87,35 @@ namespace Tensorflow.Keras.Callbacks | |||
| public void on_predict_batch_begin(long step) | |||
| { | |||
| } | |||
| public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) | |||
| { | |||
| } | |||
| public void on_predict_end() | |||
| { | |||
| } | |||
| public void on_test_batch_begin(long step) | |||
| { | |||
| _sw.Restart(); | |||
| } | |||
| public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) | |||
| { | |||
| _sw.Stop(); | |||
| var elapse = _sw.ElapsedMilliseconds; | |||
| var results = string.Join(" - ", logs.Select(x => $"{x.Item1}: {(float)x.Item2.numpy():F6}")); | |||
| Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} - {elapse}ms/step - {results}"); | |||
| if (!Console.IsOutputRedirected) | |||
| { | |||
| Console.CursorLeft = 0; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -5,6 +5,10 @@ using System.Linq; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine.DataAdapters; | |||
| using static Tensorflow.Binding; | |||
| using Tensorflow.Keras.Layers; | |||
| using Tensorflow.Keras.Utils; | |||
| using Tensorflow; | |||
| using Tensorflow.Keras.Callbacks; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| @@ -31,6 +35,11 @@ namespace Tensorflow.Keras.Engine | |||
| bool use_multiprocessing = false, | |||
| bool return_dict = false) | |||
| { | |||
| if (x.dims[0] != y.dims[0]) | |||
| { | |||
| throw new InvalidArgumentError( | |||
| $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}"); | |||
| } | |||
| var data_handler = new DataHandler(new DataHandlerArgs | |||
| { | |||
| X = x, | |||
| @@ -46,18 +55,31 @@ namespace Tensorflow.Keras.Engine | |||
| StepsPerExecution = _steps_per_execution | |||
| }); | |||
| var callbacks = new CallbackList(new CallbackParams | |||
| { | |||
| Model = this, | |||
| Verbose = verbose, | |||
| Steps = data_handler.Inferredsteps | |||
| }); | |||
| callbacks.on_test_begin(); | |||
| foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||
| { | |||
| reset_metrics(); | |||
| // callbacks.on_epoch_begin(epoch) | |||
| //callbacks.on_epoch_begin(epoch); | |||
| // data_handler.catch_stop_iteration(); | |||
| IEnumerable<(string, Tensor)> results = null; | |||
| IEnumerable<(string, Tensor)> logs = null; | |||
| foreach (var step in data_handler.steps()) | |||
| { | |||
| // callbacks.on_train_batch_begin(step) | |||
| results = test_function(data_handler, iterator); | |||
| callbacks.on_train_batch_begin(step); | |||
| logs = test_function(data_handler, iterator); | |||
| var end_step = step + data_handler.StepIncrement; | |||
| callbacks.on_test_batch_end(end_step, logs); | |||
| } | |||
| } | |||
| GC.Collect(); | |||
| GC.WaitForPendingFinalizers(); | |||
| } | |||
| public KeyValuePair<string, float>[] evaluate(IDatasetV2 x) | |||
| @@ -75,7 +97,8 @@ namespace Tensorflow.Keras.Engine | |||
| reset_metrics(); | |||
| // callbacks.on_epoch_begin(epoch) | |||
| // data_handler.catch_stop_iteration(); | |||
| foreach (var step in data_handler.steps()) | |||
| { | |||
| // callbacks.on_train_batch_begin(step) | |||