| @@ -1,4 +1,5 @@ | |||||
| using Tensorflow.NumPy; | |||||
| using OneOf; | |||||
| using Tensorflow.NumPy; | |||||
| namespace Tensorflow.Util | namespace Tensorflow.Util | ||||
| { | { | ||||
| @@ -8,10 +9,10 @@ namespace Tensorflow.Util | |||||
| /// </summary> | /// </summary> | ||||
| public class ValidationDataPack | public class ValidationDataPack | ||||
| { | { | ||||
| public NDArray val_x; | |||||
| public OneOf<NDArray, NDArray[]> val_x; | |||||
| public NDArray val_y; | public NDArray val_y; | ||||
| public NDArray val_sample_weight = null; | public NDArray val_sample_weight = null; | ||||
| public bool val_x_is_array = false; | |||||
| public ValidationDataPack((NDArray, NDArray) validation_data) | public ValidationDataPack((NDArray, NDArray) validation_data) | ||||
| { | { | ||||
| this.val_x = validation_data.Item1; | this.val_x = validation_data.Item1; | ||||
| @@ -27,15 +28,17 @@ namespace Tensorflow.Util | |||||
| public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data) | public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data) | ||||
| { | { | ||||
| this.val_x = validation_data.Item1.ToArray()[0]; | |||||
| this.val_x = validation_data.Item1.ToArray(); | |||||
| this.val_y = validation_data.Item2; | this.val_y = validation_data.Item2; | ||||
| val_x_is_array = true; | |||||
| } | } | ||||
| public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data) | public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data) | ||||
| { | { | ||||
| this.val_x = validation_data.Item1.ToArray()[0]; | |||||
| this.val_x = validation_data.Item1.ToArray(); | |||||
| this.val_y = validation_data.Item2; | this.val_y = validation_data.Item2; | ||||
| this.val_sample_weight = validation_data.Item3; | this.val_sample_weight = validation_data.Item3; | ||||
| val_x_is_array = true; | |||||
| } | } | ||||
| public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data) | public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data) | ||||
| @@ -52,15 +55,24 @@ namespace Tensorflow.Util | |||||
| public void Deconstruct(out NDArray val_x, out NDArray val_y) | public void Deconstruct(out NDArray val_x, out NDArray val_y) | ||||
| { | { | ||||
| val_x = this.val_x; | |||||
| val_x = this.val_x.AsT0; | |||||
| val_y = this.val_y; | val_y = this.val_y; | ||||
| } | } | ||||
| public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) | public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) | ||||
| { | { | ||||
| val_x = this.val_x; | |||||
| val_x = this.val_x.AsT0; | |||||
| val_y = this.val_y; | |||||
| val_sample_weight = this.val_sample_weight; | |||||
| } | |||||
| // add a unuse parameter to make it different from Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight) | |||||
| public void Deconstruct(out NDArray[] val_x_array, out NDArray val_y, out NDArray val_sample_weight, out NDArray unuse) | |||||
| { | |||||
| val_x_array = this.val_x.AsT1; | |||||
| val_y = this.val_y; | val_y = this.val_y; | ||||
| val_sample_weight = this.val_sample_weight; | val_sample_weight = this.val_sample_weight; | ||||
| unuse = null; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -92,9 +92,17 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| var train_y = y[new Slice(0, train_count)]; | var train_y = y[new Slice(0, train_count)]; | ||||
| var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); | var val_x = x.Select(x => x[new Slice(train_count)] as NDArray); | ||||
| var val_y = y[new Slice(train_count)]; | var val_y = y[new Slice(train_count)]; | ||||
| NDArray tmp_sample_weight = sample_weight; | |||||
| sample_weight = sample_weight[new Slice(0, train_count)]; | |||||
| ValidationDataPack validation_data = (val_x, val_y, tmp_sample_weight[new Slice(train_count)]); | |||||
| ValidationDataPack validation_data; | |||||
| if (sample_weight != null) | |||||
| { | |||||
| validation_data = (val_x, val_y, sample_weight[new Slice(train_count)]); | |||||
| sample_weight = sample_weight[new Slice(0, train_count)]; | |||||
| } | |||||
| else | |||||
| { | |||||
| validation_data = (val_x, val_y); | |||||
| } | |||||
| return ((train_x, train_y, sample_weight), validation_data); | return ((train_x, train_y, sample_weight), validation_data); | ||||
| } | } | ||||
| } | } | ||||
| @@ -70,13 +70,19 @@ namespace Tensorflow.Keras.Engine | |||||
| return evaluate(data_handler, callbacks, is_val, test_function); | return evaluate(data_handler, callbacks, is_val, test_function); | ||||
| } | } | ||||
| public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false) | |||||
| public Dictionary<string, float> evaluate( | |||||
| IEnumerable<Tensor> x, | |||||
| Tensor y, | |||||
| int verbose = 1, | |||||
| NDArray sample_weight = null, | |||||
| bool is_val = false) | |||||
| { | { | ||||
| var data_handler = new DataHandler(new DataHandlerArgs | var data_handler = new DataHandler(new DataHandlerArgs | ||||
| { | { | ||||
| X = new Tensors(x.ToArray()), | X = new Tensors(x.ToArray()), | ||||
| Y = y, | Y = y, | ||||
| Model = this, | Model = this, | ||||
| SampleWeight = sample_weight, | |||||
| StepsPerExecution = _steps_per_execution | StepsPerExecution = _steps_per_execution | ||||
| }); | }); | ||||
| @@ -7,6 +7,7 @@ using Tensorflow.Keras.Engine.DataAdapters; | |||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using Tensorflow.Keras.Callbacks; | using Tensorflow.Keras.Callbacks; | ||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| using OneOf; | |||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| { | { | ||||
| @@ -287,10 +288,24 @@ namespace Tensorflow.Keras.Engine | |||||
| if (validation_data != null) | if (validation_data != null) | ||||
| { | { | ||||
| // Because evaluate calls call_test_batch_end, this interferes with our output on the screen | |||||
| // so we need to pass a is_val parameter to stop on_test_batch_end | |||||
| var (val_x, val_y, val_sample_weight) = validation_data; | |||||
| var val_logs = evaluate(val_x, val_y, sample_weight:val_sample_weight, is_val:true); | |||||
| NDArray val_x; | |||||
| NDArray[] val_x_array; | |||||
| NDArray val_y; | |||||
| NDArray val_sample_weight; | |||||
| Dictionary<string, float> val_logs; | |||||
| if (!validation_data.val_x_is_array) | |||||
| { | |||||
| (val_x, val_y, val_sample_weight) = validation_data; | |||||
| // Because evaluate calls call_test_batch_end, this interferes with our output on the screen | |||||
| // so we need to pass a is_val parameter to stop on_test_batch_end | |||||
| val_logs = evaluate(val_x, val_y, sample_weight: val_sample_weight, is_val: true); | |||||
| } | |||||
| else | |||||
| { | |||||
| (val_x_array, val_y, val_sample_weight, _) = validation_data; | |||||
| val_logs = evaluate(val_x_array, val_y, sample_weight: val_sample_weight, is_val: true); | |||||
| } | |||||
| foreach (var log in val_logs) | foreach (var log in val_logs) | ||||
| { | { | ||||
| logs["val_" + log.Key] = log.Value; | logs["val_" + log.Key] = log.Value; | ||||