| @@ -1,4 +1,5 @@ | |||
| using Tensorflow.NumPy; | |||
| using OneOf; | |||
| using Tensorflow.NumPy; | |||
| namespace Tensorflow.Util | |||
| { | |||
| @@ -8,10 +9,10 @@ namespace Tensorflow.Util | |||
| /// </summary> | |||
| public class ValidationDataPack | |||
| { | |||
| public NDArray val_x; | |||
| public OneOf<NDArray, NDArray[]> val_x; | |||
| public NDArray val_y; | |||
| public NDArray val_sample_weight = null; | |||
| public bool val_x_is_array = false; | |||
| public ValidationDataPack((NDArray, NDArray) validation_data) | |||
| { | |||
| this.val_x = validation_data.Item1; | |||
| @@ -27,15 +28,17 @@ namespace Tensorflow.Util | |||
| public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data) | |||
| { | |||
| this.val_x = validation_data.Item1.ToArray()[0]; | |||
| this.val_x = validation_data.Item1.ToArray(); | |||
| this.val_y = validation_data.Item2; | |||
| val_x_is_array = true; | |||
| } | |||
| public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data) | |||
| { | |||
| this.val_x = validation_data.Item1.ToArray()[0]; | |||
| this.val_x = validation_data.Item1.ToArray(); | |||
| this.val_y = validation_data.Item2; | |||
| this.val_sample_weight = validation_data.Item3; | |||
| val_x_is_array = true; | |||
| } | |||
| public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data) | |||
| @@ -52,15 +55,24 @@ namespace Tensorflow.Util | |||
| public void Deconstruct(out NDArray val_x, out NDArray val_y) | |||
| { | |||
| val_x = this.val_x; | |||
| val_x = this.val_x.AsT0; | |||
| val_y = this.val_y; | |||
| } | |||
| public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) | |||
| { | |||
| val_x = this.val_x; | |||
| val_x = this.val_x.AsT0; | |||
| val_y = this.val_y; | |||
| val_sample_weight = this.val_sample_weight; | |||
| } | |||
| // add a unuse parameter to make it different from Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) | |||
| public void Deconstruct(out NDArray[] val_x_array, out NDArray val_y, out NDArray val_sample_weight, out NDArray unuse) | |||
| { | |||
| val_x_array = this.val_x.AsT1; | |||
| val_y = this.val_y; | |||
| val_sample_weight = this.val_sample_weight; | |||
| unuse = null; | |||
| } | |||
| } | |||
| } | |||
| @@ -92,9 +92,17 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
| var train_y = y[new Slice(0, train_count)]; | |||
| var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); | |||
| var val_y = y[new Slice(train_count)]; | |||
| NDArray tmp_sample_weight = sample_weight; | |||
| sample_weight = sample_weight[new Slice(0, train_count)]; | |||
| ValidationDataPack validation_data = (val_x, val_y, tmp_sample_weight[new Slice(train_count)]); | |||
| ValidationDataPack validation_data; | |||
| if (sample_weight != null) | |||
| { | |||
| validation_data = (val_x, val_y, sample_weight[new Slice(train_count)]); | |||
| sample_weight = sample_weight[new Slice(0, train_count)]; | |||
| } | |||
| else | |||
| { | |||
| validation_data = (val_x, val_y); | |||
| } | |||
| return ((train_x, train_y, sample_weight), validation_data); | |||
| } | |||
| } | |||
| @@ -70,13 +70,19 @@ namespace Tensorflow.Keras.Engine | |||
| return evaluate(data_handler, callbacks, is_val, test_function); | |||
| } | |||
| public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false) | |||
| public Dictionary<string, float> evaluate( | |||
| IEnumerable<Tensor> x, | |||
| Tensor y, | |||
| int verbose = 1, | |||
| NDArray sample_weight = null, | |||
| bool is_val = false) | |||
| { | |||
| var data_handler = new DataHandler(new DataHandlerArgs | |||
| { | |||
| X = new Tensors(x.ToArray()), | |||
| Y = y, | |||
| Model = this, | |||
| SampleWeight = sample_weight, | |||
| StepsPerExecution = _steps_per_execution | |||
| }); | |||
| @@ -7,6 +7,7 @@ using Tensorflow.Keras.Engine.DataAdapters; | |||
| using System.Diagnostics; | |||
| using Tensorflow.Keras.Callbacks; | |||
| using Tensorflow.Util; | |||
| using OneOf; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| @@ -287,10 +288,24 @@ namespace Tensorflow.Keras.Engine | |||
| if (validation_data != null) | |||
| { | |||
| // Because evaluate calls call_test_batch_end, this interferes with our output on the screen | |||
| // so we need to pass a is_val parameter to stop on_test_batch_end | |||
| var (val_x, val_y, val_sample_weight) = validation_data; | |||
| var val_logs = evaluate(val_x, val_y, sample_weight:val_sample_weight, is_val:true); | |||
| NDArray val_x; | |||
| NDArray[] val_x_array; | |||
| NDArray val_y; | |||
| NDArray val_sample_weight; | |||
| Dictionary<string, float> val_logs; | |||
| if (!validation_data.val_x_is_array) | |||
| { | |||
| (val_x, val_y, val_sample_weight) = validation_data; | |||
| // Because evaluate calls call_test_batch_end, this interferes with our output on the screen | |||
| // so we need to pass a is_val parameter to stop on_test_batch_end | |||
| val_logs = evaluate(val_x, val_y, sample_weight: val_sample_weight, is_val: true); | |||
| } | |||
| else | |||
| { | |||
| (val_x_array, val_y, val_sample_weight, _) = validation_data; | |||
| val_logs = evaluate(val_x_array, val_y, sample_weight: val_sample_weight, is_val: true); | |||
| } | |||
| foreach (var log in val_logs) | |||
| { | |||
| logs["val_" + log.Key] = log.Value; | |||