feat: add the implementation of sample_weight in model.fit
tags/v0.150.0-BERT-Model
@@ -1,5 +1,6 @@
 using Tensorflow.Keras.Engine;
 using Tensorflow.Keras.Saving;
+using Tensorflow.NumPy;
 namespace Tensorflow.Keras.ArgsDefinition
 {
@@ -16,5 +17,7 @@ namespace Tensorflow.Keras.ArgsDefinition
         public int Worker { get; set; }
         public bool UseMultiprocessing { get; set; }
         public IModel Model { get; set; }
+        public Dictionary<int, float> ClassWeight = null;
+        public NDArray SampleWeight = null;
     }
 }
@@ -1,5 +1,6 @@
 using Tensorflow.Keras.Engine;
 using Tensorflow.Keras.Saving;
+using Tensorflow.NumPy;
 namespace Tensorflow.Keras.ArgsDefinition
 {
@@ -18,5 +19,7 @@ namespace Tensorflow.Keras.ArgsDefinition
         public bool UseMultiprocessing { get; set; } = false;
         public IModel Model { get; set; }
         public IVariableV1 StepsPerExecution { get; set; }
+        public Dictionary<int, float> ClassWeight = null;
+        public NDArray SampleWeight = null;
     }
 }
@@ -3,6 +3,7 @@ using Tensorflow.Keras.Losses;
 using Tensorflow.Keras.Metrics;
 using Tensorflow.Keras.Saving;
 using Tensorflow.NumPy;
+using Tensorflow.Util;
 namespace Tensorflow.Keras.Engine;
@@ -22,8 +23,10 @@ public interface IModel : ILayer
        int verbose = 1,
        List<ICallback> callbacks = null,
        float validation_split = 0f,
-       (NDArray val_x, NDArray val_y)? validation_data = null,
+       ValidationDataPack validation_data = null,
        bool shuffle = true,
+       Dictionary<int, float> class_weight = null,
+       NDArray sample_weight = null,
        int initial_epoch = 0,
        int max_queue_size = 10,
        int workers = 1,
@@ -35,8 +38,10 @@ public interface IModel : ILayer
        int verbose = 1,
        List<ICallback> callbacks = null,
        float validation_split = 0f,
-       (IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
+       ValidationDataPack validation_data = null,
        bool shuffle = true,
+       Dictionary<int, float> class_weight = null,
+       NDArray sample_weight = null,
        int initial_epoch = 0,
        int max_queue_size = 10,
        int workers = 1,
@@ -63,6 +68,8 @@ public interface IModel : ILayer
    Dictionary<string, float> evaluate(NDArray x, NDArray y,
        int batch_size = -1,
        int verbose = 1,
+       NDArray sample_weight = null,
        int steps = -1,
        int max_queue_size = 10,
        int workers = 1,
@@ -0,0 +1,66 @@
+using Tensorflow.NumPy;
+namespace Tensorflow.Util
+{
+    /// <summary>
+    /// ValidationDataPack is used to pass validation data into the fit method.
+    /// It can receive data as a tuple `(x_val, y_val)` or `(x_val, y_val, sample_weight_val)` of NumPy arrays.
+    /// </summary>
+    public class ValidationDataPack
+    {
+        public NDArray val_x;
+        public NDArray val_y;
+        public NDArray val_sample_weight = null;
+        public ValidationDataPack((NDArray, NDArray) validation_data)
+        {
+            this.val_x = validation_data.Item1;
+            this.val_y = validation_data.Item2;
+        }
+        public ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
+        {
+            this.val_x = validation_data.Item1;
+            this.val_y = validation_data.Item2;
+            this.val_sample_weight = validation_data.Item3;
+        }
+        public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
+        {
+            this.val_x = validation_data.Item1.ToArray()[0];
+            this.val_y = validation_data.Item2;
+        }
+        public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
+        {
+            this.val_x = validation_data.Item1.ToArray()[0];
+            this.val_y = validation_data.Item2;
+            this.val_sample_weight = validation_data.Item3;
+        }
+        public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data)
+            => new ValidationDataPack(validation_data);
+        public static implicit operator ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
+            => new ValidationDataPack(validation_data);
+        public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
+            => new ValidationDataPack(validation_data);
+        public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
+            => new ValidationDataPack(validation_data);
+        public void Deconstruct(out NDArray val_x, out NDArray val_y)
+        {
+            val_x = this.val_x;
+            val_y = this.val_y;
+        }
+        public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight)
+        {
+            val_x = this.val_x;
+            val_y = this.val_y;
+            val_sample_weight = this.val_sample_weight;
+        }
+    }
+}
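A minimal usage sketch of the class above (x_val, y_val and w_val are illustrative NDArrays, not part of the diff):

// The implicit operators let callers pass plain tuples wherever a ValidationDataPack is expected.
ValidationDataPack pack = (x_val, y_val);
ValidationDataPack weighted = (x_val, y_val, w_val);
// Deconstruct recovers the pieces; val_sample_weight stays null when no weights were packed.
var (vx, vy) = pack;
var (wx, wy, ww) = weighted;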
@@ -2,6 +2,7 @@
 using System.Collections.Generic;
 using System.Text;
 using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Util;
 namespace Tensorflow.Keras.Engine.DataAdapters
 {
@@ -34,9 +35,67 @@ namespace Tensorflow.Keras.Engine.DataAdapters
             return (x, y);
         }
+        public virtual (Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight)
+        {
+            for (int i = 0; i < x.Length; i++)
+            {
+                if (x[i].shape.ndim == 1)
+                    x[i] = array_ops.expand_dims(x[i], axis: -1);
+            }
+            for (int i = 0; i < y.Length; i++)
+            {
+                if (y[i].shape.ndim == 1)
+                    y[i] = array_ops.expand_dims(y[i], axis: -1);
+            }
+            // sample_weight is null when no per-sample weights were supplied.
+            if (sample_weight != null)
+            {
+                for (int i = 0; i < sample_weight.Length; i++)
+                {
+                    if (sample_weight[i].shape.ndim == 1)
+                        sample_weight[i] = array_ops.expand_dims(sample_weight[i], axis: -1);
+                }
+            }
+            return (x, y, sample_weight);
+        }
         public virtual bool ShouldRecreateIterator()
         {
             return true;
         }
+        public static ((NDArray, NDArray, NDArray), ValidationDataPack) train_validation_split((NDArray, NDArray, NDArray) x_y_sample_weight, float validation_split)
+        {
+            var x = x_y_sample_weight.Item1;
+            var y = x_y_sample_weight.Item2;
+            var sample_weight = x_y_sample_weight.Item3;
+            int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
+            var train_x = x[new Slice(0, train_count)];
+            var train_y = y[new Slice(0, train_count)];
+            ValidationDataPack validation_data;
+            if (sample_weight != null)
+            {
+                validation_data = (x[new Slice(train_count)], y[new Slice(train_count)], sample_weight[new Slice(train_count)]);
+                sample_weight = sample_weight[new Slice(0, train_count)];
+            }
+            else
+            {
+                validation_data = (x[new Slice(train_count)], y[new Slice(train_count)]);
+            }
+            return ((train_x, train_y, sample_weight), validation_data);
+        }
+        public static ((IEnumerable<NDArray>, NDArray, NDArray), ValidationDataPack) train_validation_split((IEnumerable<NDArray>, NDArray, NDArray) x_y_sample_weight, float validation_split)
+        {
+            var x = x_y_sample_weight.Item1;
+            var y = x_y_sample_weight.Item2;
+            var sample_weight = x_y_sample_weight.Item3;
+            int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
+            var train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray);
+            var train_y = y[new Slice(0, train_count)];
+            var val_x = x.Select(x => x[new Slice(train_count)] as NDArray);
+            var val_y = y[new Slice(train_count)];
+            ValidationDataPack validation_data;
+            // sample_weight is null when the caller did not pass per-sample weights.
+            if (sample_weight != null)
+            {
+                validation_data = (val_x, val_y, sample_weight[new Slice(train_count)]);
+                sample_weight = sample_weight[new Slice(0, train_count)];
+            }
+            else
+            {
+                validation_data = (val_x, val_y);
+            }
+            return ((train_x, train_y, sample_weight), validation_data);
+        }
     }
 }
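For orientation, a sketch of how train_validation_split behaves, assuming x, y and sample_weight are NDArrays with 100 rows and validation_split = 0.2f:

// train_count = Convert.ToInt32(100 * (1 - 0.2)) = 80
// rows [0, 80) stay in the training tuple; rows [80, 100) go into the ValidationDataPack
var ((train_x, train_y, train_w), val_pack) =
    DataAdapter.train_validation_split((x, y, sample_weight), validation_split: 0.2f);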
@@ -2,6 +2,7 @@
 using System.Collections.Generic;
 using Tensorflow.Keras.ArgsDefinition;
 using static Tensorflow.Binding;
+using Tensorflow.Keras.Utils;
 namespace Tensorflow.Keras.Engine.DataAdapters
 {
@@ -28,6 +29,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
         public DataHandler(DataHandlerArgs args)
         {
             this.args = args;
             if (args.StepsPerExecution == null)
             {
                 _steps_per_execution = tf.Variable(1L);
@@ -48,6 +50,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
                 BatchSize = args.BatchSize,
                 Steps = args.StepsPerEpoch,
                 Epochs = args.Epochs - args.InitialEpoch,
+                SampleWeight = args.SampleWeight,
                 Shuffle = args.Shuffle,
                 MaxQueueSize = args.MaxQueueSize,
                 Worker = args.Workers,
@@ -17,6 +17,8 @@
         IDatasetV2 GetDataset();
         int GetSize();
         (Tensors, Tensors) Expand1d(Tensors x, Tensors y);
+        (Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight);
         bool ShouldRecreateIterator();
     }
 }
@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
         public TensorLikeDataAdapter(DataAdapterArgs args)
         {
             this.args = args;
-            _process_tensorlike();
+            Tensor sample_weight_tensor = args.SampleWeight != null ? _process_tensorlike(args.SampleWeight) : null;
             num_samples = (int)args.X.shape[0];
             var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
             _batch_size = batch_size;
@@ -37,6 +37,8 @@ namespace Tensorflow.Keras.Engine.DataAdapters
             inputs.AddRange(args.X);
             if (args.Y != null)
                 inputs.AddRange(args.Y);
+            if (sample_weight_tensor != null)
+                inputs.Add(sample_weight_tensor);
             dataset = slice_inputs(indices_dataset, inputs);
             dataset.FirstInputTensorCount = args.X.Length;
         }
@@ -94,8 +96,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters
         public override bool ShouldRecreateIterator() => false;
-        void _process_tensorlike()
+        Tensor _process_tensorlike(NDArray sample_weights)
         {
+            return tf.convert_to_tensor(sample_weights);
         }
     }
 }
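The element layout this produces is worth spelling out (inferred from the code above, not stated explicitly in the diff):

// slice_inputs now yields per batch:
//   [x_0 ... x_{n-1}, y_0 ... y_{m-1}, sample_weight]   // the weight tensor only when supplied
// dataset.FirstInputTensorCount == args.X.Length, so consumers can split the inputs off again.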
@@ -26,11 +26,11 @@ namespace Tensorflow.Keras.Engine
         /// </summary>
         /// <param name="y_true"></param>
         /// <param name="y_pred"></param>
-        public Tensor Call(Tensor y_true, Tensor y_pred)
+        public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
         {
             if (!_built)
                 Build(y_pred);
-            var loss_value = _losses.Call(y_true, y_pred);
+            var loss_value = _losses.Call(y_true, y_pred, sample_weight: sample_weight);
             var loss_metric_value = loss_value;
             var batch_dim = array_ops.shape(y_true)[0];
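The weighting itself happens inside _losses.Call and is not shown in this diff; for reference, the usual Keras semantics are:

// Typical Keras sample-weight semantics (for orientation; the actual reduction lives in _losses.Call):
//   per_sample = loss_fn(y_true, y_pred)       // shape (batch,)
//   weighted   = per_sample * sample_weight    // element-wise scaling
//   loss       = reduction(weighted)           // e.g. mean over the batch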
@@ -30,6 +30,7 @@ namespace Tensorflow.Keras.Engine
         public Dictionary<string, float> evaluate(NDArray x, NDArray y,
             int batch_size = -1,
             int verbose = 1,
+            NDArray sample_weight = null,
             int steps = -1,
             int max_queue_size = 10,
             int workers = 1,
@@ -51,6 +52,7 @@ namespace Tensorflow.Keras.Engine
                 StepsPerEpoch = steps,
                 InitialEpoch = 0,
                 Epochs = 1,
+                SampleWeight = sample_weight,
                 MaxQueueSize = max_queue_size,
                 Workers = workers,
                 UseMultiprocessing = use_multiprocessing,
@@ -140,7 +142,8 @@ namespace Tensorflow.Keras.Engine
         Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
         {
             var data = iterator.next();
-            var outputs = test_step(data_handler, data[0], data[1]);
+            var outputs = data.Length == 2 ? test_step(data_handler, data[0], data[1]) :
+                test_step(data_handler, data[0], data[1], data[2]);
             tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
             return outputs;
         }
@@ -149,17 +152,23 @@
         {
             var data = iterator.next();
             var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
-            var outputs = test_step(data_handler, data.Take(x_size).ToArray(), data.Skip(x_size).ToArray());
+            var outputs = data.Length == 2 ?
+                test_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) :
+                test_step(
+                    data_handler,
+                    new Tensors(data.Take(x_size).ToArray()),
+                    new Tensors(data.Skip(x_size).Take(x_size).ToArray()),
+                    new Tensors(data.Skip(2 * x_size).ToArray()));
             tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
             return outputs;
         }
-        Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
+        Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
         {
-            (x, y) = data_handler.DataAdapter.Expand1d(x, y);
+            (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
             var y_pred = Apply(x, training: false);
-            var loss = compiled_loss.Call(y, y_pred);
+            var loss = compiled_loss.Call(y, y_pred, sample_weight: sample_weight);
             compiled_metrics.update_state(y, y_pred);
             return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
         }
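The Take/Skip arithmetic above assumes a fixed packing of the iterator output, with equal numbers of input and target tensors in the weighted case; sketched for x_size = 2:

// data    = [x0, x1, y0, y1, w0, w1]
// x       = data.Take(x_size)                  -> [x0, x1]
// y       = data.Skip(x_size).Take(x_size)     -> [y0, y1]
// weights = data.Skip(2 * x_size)              -> [w0, w1]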
@@ -6,10 +6,12 @@ using Tensorflow.Keras.ArgsDefinition;
 using Tensorflow.Keras.Engine.DataAdapters;
 using System.Diagnostics;
 using Tensorflow.Keras.Callbacks;
+using System.Data;
+using Tensorflow.Util;
 namespace Tensorflow.Keras.Engine
 {
     public partial class Model
     {
         /// <summary>
@@ -19,19 +21,29 @@ namespace Tensorflow.Keras.Engine
         /// <param name="y"></param>
         /// <param name="batch_size"></param>
         /// <param name="epochs"></param>
-        /// <param name="callbacks"></param>
         /// <param name="verbose"></param>
+        /// <param name="callbacks"></param>
         /// <param name="validation_split"></param>
+        /// <param name="validation_data"></param>
         /// <param name="shuffle"></param>
+        /// <param name="class_weight"></param>
+        /// <param name="sample_weight"></param>
         /// <param name="initial_epoch"></param>
         /// <param name="max_queue_size"></param>
         /// <param name="workers"></param>
+        /// <param name="use_multiprocessing"></param>
         /// <returns></returns>
+        /// <exception cref="InvalidArgumentError"></exception>
         public ICallback fit(NDArray x, NDArray y,
             int batch_size = -1,
             int epochs = 1,
             int verbose = 1,
             List<ICallback> callbacks = null,
             float validation_split = 0f,
-            (NDArray val_x, NDArray val_y)? validation_data = null,
+            ValidationDataPack validation_data = null,
             bool shuffle = true,
+            Dictionary<int, float> class_weight = null,
+            NDArray sample_weight = null,
             int initial_epoch = 0,
             int max_queue_size = 10,
             int workers = 1,
@@ -43,21 +55,25 @@ namespace Tensorflow.Keras.Engine
                     $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
             }
-            var train_x = x;
-            var train_y = y;
+            // The default dtype of NDArray is double, so cast sample_weight to float so it can be multiplied with the loss, whose dtype is float.
+            sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT);
             if (validation_split != 0f && validation_data == null)
             {
-                int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
-                train_x = x[new Slice(0, train_count)];
-                train_y = y[new Slice(0, train_count)];
-                validation_data = (val_x: x[new Slice(train_count)], val_y: y[new Slice(train_count)]);
+                ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
             }
+            // TODO(Wanglongzhi2001)
+            if (class_weight != null)
+            {
+                throw new NotImplementedException("class_weight is not implemented");
+            }
             var data_handler = new DataHandler(new DataHandlerArgs
             {
-                X = train_x,
-                Y = train_y,
+                X = x,
+                Y = y,
+                SampleWeight = sample_weight,
                 BatchSize = batch_size,
                 InitialEpoch = initial_epoch,
                 Epochs = epochs,
@@ -73,14 +89,17 @@ namespace Tensorflow.Keras.Engine
                 train_step_func: train_step_function);
         }
         public ICallback fit(IEnumerable<NDArray> x, NDArray y,
             int batch_size = -1,
             int epochs = 1,
             int verbose = 1,
             List<ICallback> callbacks = null,
             float validation_split = 0f,
-            (IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
+            ValidationDataPack validation_data = null,
             bool shuffle = true,
+            Dictionary<int, float> class_weight = null,
+            NDArray sample_weight = null,
             int initial_epoch = 0,
             int max_queue_size = 10,
             int workers = 1,
@@ -95,27 +114,23 @@ namespace Tensorflow.Keras.Engine
                 }
             }
-            var train_x = x;
-            var train_y = y;
+            sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT);
             if (validation_split != 0f && validation_data == null)
             {
-                int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
-                train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray);
-                train_y = y[new Slice(0, train_count)];
-                var val_x = x.Select(x => x[new Slice(train_count)] as NDArray);
-                var val_y = y[new Slice(train_count)];
-                validation_data = (val_x, val_y);
+                ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
             }
             var data_handler = new DataHandler(new DataHandlerArgs
             {
-                X = new Tensors(train_x.ToArray()),
-                Y = train_y,
+                X = new Tensors(x.ToArray()),
+                Y = y,
                 BatchSize = batch_size,
                 InitialEpoch = initial_epoch,
                 Epochs = epochs,
                 Shuffle = shuffle,
+                SampleWeight = sample_weight,
                 MaxQueueSize = max_queue_size,
                 Workers = workers,
                 UseMultiprocessing = use_multiprocessing,
@@ -142,8 +157,10 @@ namespace Tensorflow.Keras.Engine
             int verbose = 1,
             List<ICallback> callbacks = null,
             IDatasetV2 validation_data = null,
-            int validation_step = 10, // how many epochs between two validation runs
+            int validation_step = 10,
             bool shuffle = true,
+            Dictionary<int, float> class_weight = null,
+            NDArray sample_weight = null,
             int initial_epoch = 0,
             int max_queue_size = 10,
             int workers = 1,
@@ -210,7 +227,7 @@ namespace Tensorflow.Keras.Engine
             {
                 if (validation_step > 0 && epoch == 0 || epoch % validation_step != 0)
                     continue;
                 var val_logs = evaluate(validation_data);
                 foreach (var log in val_logs)
                 {
@@ -233,7 +250,7 @@ namespace Tensorflow.Keras.Engine
             return callbacks.History;
         }
-        History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (NDArray, NDArray)? validation_data,
+        History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, ValidationDataPack validation_data,
             Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
         {
             stop_training = false;
@@ -274,7 +291,8 @@ namespace Tensorflow.Keras.Engine
                 {
                     // Because evaluate calls on_test_batch_end, it interferes with our output on the screen,
                     // so we pass an is_val parameter to suppress on_test_batch_end.
-                    var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val: true);
+                    var (val_x, val_y, val_sample_weight) = validation_data;
+                    var val_logs = evaluate(val_x, val_y, sample_weight: val_sample_weight, is_val: true);
                     foreach (var log in val_logs)
                     {
                         logs["val_" + log.Key] = log.Value;
@@ -296,64 +314,5 @@ namespace Tensorflow.Keras.Engine
             return callbacks.History;
         }
-        History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (IEnumerable<Tensor>, NDArray)? validation_data,
-            Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
-        {
-            stop_training = false;
-            _train_counter.assign(0);
-            var callbacks = new CallbackList(new CallbackParams
-            {
-                Model = this,
-                Verbose = verbose,
-                Epochs = epochs,
-                Steps = data_handler.Inferredsteps
-            });
-            if (callbackList != null)
-            {
-                foreach (var callback in callbackList)
-                    callbacks.callbacks.add(callback);
-            }
-            callbacks.on_train_begin();
-            foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
-            {
-                reset_metrics();
-                callbacks.on_epoch_begin(epoch);
-                // data_handler.catch_stop_iteration();
-                var logs = new Dictionary<string, float>();
-                long End_step = 0;
-                foreach (var step in data_handler.steps())
-                {
-                    callbacks.on_train_batch_begin(step);
-                    logs = train_step_func(data_handler, iterator);
-                    var end_step = step + data_handler.StepIncrement;
-                    End_step = end_step;
-                    callbacks.on_train_batch_end(end_step, logs);
-                }
-                if (validation_data != null)
-                {
-                    var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2);
-                    foreach (var log in val_logs)
-                    {
-                        logs["val_" + log.Key] = log.Value;
-                        callbacks.on_train_batch_end(End_step, logs);
-                    }
-                }
-                callbacks.on_epoch_end(epoch, logs);
-                GC.Collect();
-                GC.WaitForPendingFinalizers();
-                if (stop_training)
-                {
-                    break;
-                }
-            }
-            return callbacks.History;
-        }
     }
 }
@@ -12,7 +12,9 @@ namespace Tensorflow.Keras.Engine
         Dictionary<string, float> train_step_function(DataHandler data_handler, OwnedIterator iterator)
         {
             var data = iterator.next();
-            var outputs = train_step(data_handler, data[0], data[1]);
+            // Dispatch on whether the batch also carries a sample_weight tensor.
+            var outputs = data.Length == 2 ? train_step(data_handler, data[0], data[1]) :
+                train_step(data_handler, data[0], data[1], data[2]);
             tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
             return outputs;
         }
@@ -21,7 +23,13 @@ namespace Tensorflow.Keras.Engine
         {
             var data = iterator.next();
             var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
-            var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray()));
+            var outputs = data.Length == 2 ?
+                train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) :
+                train_step(
+                    data_handler,
+                    new Tensors(data.Take(x_size).ToArray()),
+                    new Tensors(data.Skip(x_size).Take(x_size).ToArray()),
+                    new Tensors(data.Skip(2 * x_size).ToArray()));
             tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
             return outputs;
         }
@@ -61,6 +69,34 @@ namespace Tensorflow.Keras.Engine
             });
             return dict;
         }
+        Dictionary<string, float> train_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
+        {
+            (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
+            using var tape = tf.GradientTape();
+            var y_pred = Apply(x, training: true);
+            var loss = compiled_loss.Call(y, y_pred, sample_weight: sample_weight);
+            // For custom training steps, users can just write:
+            //   trainable_variables = self.trainable_variables
+            //   gradients = tape.gradient(loss, trainable_variables)
+            //   self.optimizer.apply_gradients(zip(gradients, trainable_variables))
+            // The _minimize call does a few extra steps unnecessary in most cases,
+            // such as loss scaling and gradient clipping.
+            _minimize(tape, optimizer, loss, TrainableVariables);
+            compiled_metrics.update_state(y, y_pred);
+            var dict = new Dictionary<string, float>();
+            metrics.ToList().ForEach(x =>
+            {
+                var r = x.result();
+                if (r.ndim > 0)
+                {
+                    r = tf.reduce_mean(r);
+                }
+                dict[x.Name] = (float)r;
+            });
+            return dict;
+        }
         void _minimize(GradientTape tape, IOptimizer optimizer, Tensor loss, List<IVariableV1> trainable_variables)
         {
@@ -74,8 +74,8 @@ namespace Tensorflow.Keras.UnitTest.Layers
                 OneHot = true,
                 ValidationSize = 55000,
             }).Result;
-            model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1);
+            var sample_weight = np.ones((int)dataset.Train.Data.shape[0]);
+            model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1, sample_weight: sample_weight);
         }

         [TestMethod]
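Taken together, a minimal end-to-end sketch of the new API (model, x_train, y_train, x_val, y_val and w_val are illustrative, not from the diff):

// One weight per training sample; fit casts it to float internally to match the loss dtype.
var w = np.ones((int)x_train.shape[0]);
model.fit(x_train, y_train, batch_size: 16, epochs: 1, sample_weight: w);
// Validation data, optionally with its own weights, rides along as a tuple and is
// converted implicitly to a ValidationDataPack:
model.fit(x_train, y_train, epochs: 1, sample_weight: w, validation_data: (x_val, y_val, w_val));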