| @@ -14,6 +14,38 @@ namespace Tensorflow.Keras.Engine | |||
| { | |||
| public partial class Model | |||
| { | |||
+        protected Dictionary<string, float> evaluate(CallbackList callbacks, DataHandler data_handler, bool is_val)
+        {
+            callbacks.on_test_begin();
+            //Dictionary<string, float>? logs = null;
+            var logs = new Dictionary<string, float>();
+            int x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
+            foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
+            {
+                reset_metrics();
+                callbacks.on_epoch_begin(epoch);
+                // data_handler.catch_stop_iteration();
+                foreach (var step in data_handler.steps())
+                {
+                    callbacks.on_test_batch_begin(step);
+                    var data = iterator.next();
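+                    // The first x_size tensors of the batch are the model inputs, the rest are the targets.
+                    // Evaluation currently reuses train_step here, mirroring the previous test_function implementation.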
+                    logs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
+                    tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _test_counter.assign_add(1));
+                    var end_step = step + data_handler.StepIncrement;
+                    if (!is_val)
+                        callbacks.on_test_batch_end(end_step, logs);
+                }
+            }
+            return logs;
+        }
         /// <summary>
         /// Returns the loss value & metrics values for the model in test mode.
         /// </summary>
| @@ -64,31 +96,8 @@ namespace Tensorflow.Keras.Engine | |||
                 Verbose = verbose,
                 Steps = data_handler.Inferredsteps
             });
-            callbacks.on_test_begin();
-            //Dictionary<string, float>? logs = null;
-            var logs = new Dictionary<string, float>();
-            foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
-            {
-                reset_metrics();
-                // data_handler.catch_stop_iteration();
-                foreach (var step in data_handler.steps())
-                {
-                    callbacks.on_test_batch_begin(step);
-                    logs = test_function(data_handler, iterator);
-                    var end_step = step + data_handler.StepIncrement;
-                    if (is_val == false)
-                        callbacks.on_test_batch_end(end_step, logs);
-                }
-            }
-            var results = new Dictionary<string, float>();
-            foreach (var log in logs)
-            {
-                results[log.Key] = log.Value;
-            }
-            return results;
+            return evaluate(callbacks, data_handler, is_val);
| } | |||
| public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false) | |||
| @@ -107,31 +116,8 @@ namespace Tensorflow.Keras.Engine | |||
                 Verbose = verbose,
                 Steps = data_handler.Inferredsteps
             });
-            callbacks.on_test_begin();
-            Dictionary<string, float> logs = null;
-            foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
-            {
-                reset_metrics();
-                callbacks.on_epoch_begin(epoch);
-                // data_handler.catch_stop_iteration();
-                foreach (var step in data_handler.steps())
-                {
-                    callbacks.on_test_batch_begin(step);
-                    logs = test_function(data_handler, iterator);
-                    var end_step = step + data_handler.StepIncrement;
-                    if (is_val == false)
-                        callbacks.on_test_batch_end(end_step, logs);
-                }
-            }
-            var results = new Dictionary<string, float>();
-            foreach (var log in logs)
-            {
-                results[log.Key] = log.Value;
-            }
-            return results;
+            return evaluate(callbacks, data_handler, is_val);
| } | |||
| @@ -150,51 +136,8 @@ namespace Tensorflow.Keras.Engine | |||
                 Verbose = verbose,
                 Steps = data_handler.Inferredsteps
             });
-            callbacks.on_test_begin();
-            Dictionary<string, float> logs = null;
-            foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
-            {
-                reset_metrics();
-                callbacks.on_epoch_begin(epoch);
-                // data_handler.catch_stop_iteration();
-                foreach (var step in data_handler.steps())
-                {
-                    callbacks.on_test_batch_begin(step);
-                    logs = test_function(data_handler, iterator);
-                    var end_step = step + data_handler.StepIncrement;
-                    if (is_val == false)
-                        callbacks.on_test_batch_end(end_step, logs);
-                }
-            }
-            var results = new Dictionary<string, float>();
-            foreach (var log in logs)
-            {
-                results[log.Key] = log.Value;
-            }
-            return results;
-        }
-
-        Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
-        {
-            var data = iterator.next();
-            var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
-            var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
-            tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
-            return outputs;
-        }
-
-        Dictionary<string, float> test_step(DataHandler data_handler, Tensor x, Tensor y)
-        {
-            (x, y) = data_handler.DataAdapter.Expand1d(x, y);
-            var y_pred = Apply(x, training: false);
-            var loss = compiled_loss.Call(y, y_pred);
-            compiled_metrics.update_state(y, y_pred);
-            return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
+            return evaluate(callbacks, data_handler, is_val);
| } | |||
| } | |||
| } | |||
| } | |||