feat: add the implementation of class_weight in model.fit (tag: v0.150.0-BERT-Model)
| @@ -3,6 +3,8 @@ using System.Collections.Generic; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| using Tensorflow.Util; | |||||
| using Tensorflow.Framework; | |||||
| namespace Tensorflow.Keras.Engine.DataAdapters | namespace Tensorflow.Keras.Engine.DataAdapters | ||||
| { | { | ||||
| @@ -24,6 +26,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| long _steps_per_execution_value; | long _steps_per_execution_value; | ||||
| int _initial_epoch => args.InitialEpoch; | int _initial_epoch => args.InitialEpoch; | ||||
| int _epochs => args.Epochs; | int _epochs => args.Epochs; | ||||
| NDArray _sample_weight => args.SampleWeight; | |||||
| IVariableV1 _steps_per_execution; | IVariableV1 _steps_per_execution; | ||||
| public DataHandler(DataHandlerArgs args) | public DataHandler(DataHandlerArgs args) | ||||
| @@ -75,10 +78,75 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| } | } | ||||
| _dataset = _adapter.GetDataset(); | _dataset = _adapter.GetDataset(); | ||||
| _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset); | |||||
| _current_step = 0; | _current_step = 0; | ||||
| _step_increment = _steps_per_execution_value - 1; | _step_increment = _steps_per_execution_value - 1; | ||||
| _insufficient_data = false; | _insufficient_data = false; | ||||
| _configure_dataset_and_inferred_steps(args.X, args.ClassWeight); | |||||
| } | |||||
/// <summary>
/// Finalizes the dataset for training: materializes it if needed, applies the
/// class-weight mapping, and infers the number of steps per epoch.
/// </summary>
/// <param name="x">Input tensors (unused here; kept for parity with the Python Keras signature).</param>
/// <param name="class_weight">Optional mapping from class id to loss weight; null to skip weighting.</param>
void _configure_dataset_and_inferred_steps(Tensors x, Dictionary<int, float> class_weight)
{
    // Lazily materialize the dataset if the constructor has not done so yet.
    if (_dataset == null)
    {
        _dataset = _adapter.GetDataset();
    }
    // Re-key each (x, y) element to (x, y, sample_weight) derived from class weights.
    if (class_weight != null)
    {
        _dataset = _dataset.map(_make_class_weight_map_fn(class_weight));
    }
    // Infer steps once, after any mapping, so the count reflects the final dataset.
    // (The original also computed this inside the null-guard above; that result was
    // always overwritten here, so the redundant call is removed.)
    _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
}
/// <summary>
/// Builds a dataset map function that converts (x, y) elements into
/// (x, y, sample_weight), where each example's weight is looked up from
/// <paramref name="class_weight"/> by its class id. Mirrors Keras'
/// <c>_make_class_weight_map_fn</c>.
/// </summary>
/// <param name="class_weight">Mapping from class id to weight. Keys must be exactly 0..N-1,
/// because the weight tensor below is indexed directly by class id.</param>
/// <returns>A map function suitable for <c>IDatasetV2.map</c>.</returns>
/// <exception cref="ValueError">If keys are not contiguous starting at 0, or targets have rank &gt; 2.</exception>
Func<Tensors, Tensors> _make_class_weight_map_fn(Dictionary<int, float> class_weight)
{
    var class_ids = class_weight.Keys.OrderBy(key => key).ToList();
    // BUGFIX: keys must start at 0, not merely be contiguous. gather() below indexes
    // class_weight_tensor by raw class id, so keys like {5, 6} would previously pass
    // this check and then index out of range (or silently pick wrong weights).
    var expected_class_ids = range(0, class_ids.Count);
    if (!class_ids.SequenceEqual(expected_class_ids))
    {
        // BUGFIX: interpolating the dictionary printed its type name; show the keys instead.
        throw new ValueError("Expected `class_weight` to be a dict with keys from 0 to one less " +
            $"than the number of classes, found keys: {string.Join(", ", class_ids)}");
    }
    // Dense weight lookup table: index i holds the weight for class i.
    var class_weight_list = new List<float>();
    foreach (var class_id in class_ids)
    {
        class_weight_list.Add(class_weight[class_id]);
    }
    var class_weight_tensor = tf.convert_to_tensor(class_weight_list.ToArray());
    Func<Tensors, Tensors> _class_weight_map_fn = (Tensors data) =>
    {
        var x = data[0];
        var y = data[1];
        // Pre-existing per-example weights (if any) are combined with the class weights.
        var sw = _sample_weight == null ? null : ops.convert_to_tensor(_sample_weight);
        if (y.shape.rank > 2)
        {
            throw new ValueError("`class_weight` not supported for 3+ dimensional targets.");
        }
        // One-hot targets (rank 2, width > 1) -> argmax; sparse targets -> flatten and cast.
        var y_classes = smart_module.smart_cond(
            y.shape.rank == 2 && y.shape[1] > 1,
            () => math_ops.argmax(y, dimension: 1),
            () => math_ops.cast(tf.reshape(y, (-1)), TF_DataType.TF_INT64));
        var cw = array_ops.gather(class_weight_tensor, y_classes);
        if (sw != null)
        {
            // Multiply class weights into the caller-supplied sample weights.
            cw = tf.cast(cw, sw.dtype);
            cw *= sw;
        }
        else
        {
            sw = cw;
        }
        return new Tensors { x, y, sw };
    };
    return _class_weight_map_fn;
}
| long _infer_steps(int steps_per_epoch, IDatasetV2 dataset) | long _infer_steps(int steps_per_epoch, IDatasetV2 dataset) | ||||
| @@ -164,11 +164,20 @@ namespace Tensorflow.Keras.Engine | |||||
| } | } | ||||
// Runs one evaluation step without sample weights: expand targets, forward pass
// in inference mode, accumulate loss and metrics, and report current metric values.
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
{
    (x, y) = data_handler.DataAdapter.Expand1d(x, y);
    var predictions = Apply(x, training: false);
    compiled_loss.Call(y, predictions);
    compiled_metrics.update_state(y, predictions);
    var results = new Dictionary<string, float>();
    foreach (var metric in metrics)
    {
        results[metric.Name] = (float)metric.result();
    }
    return results;
}
// Runs one evaluation step with explicit per-example sample weights, which are
// forwarded to the compiled loss so weighted losses are reported correctly.
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight)
{
    (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
    var predictions = Apply(x, training: false);
    compiled_loss.Call(y, predictions, sample_weight: sample_weight);
    compiled_metrics.update_state(y, predictions);
    var results = new Dictionary<string, float>();
    foreach (var metric in metrics)
    {
        results[metric.Name] = (float)metric.result();
    }
    return results;
}
| @@ -63,12 +63,6 @@ namespace Tensorflow.Keras.Engine | |||||
| ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split); | ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split); | ||||
| } | } | ||||
| // TODO(Wanglongzhi2001) | |||||
| if (class_weight != null) | |||||
| { | |||||
| throw new NotImplementedException("class_weight is not implemented"); | |||||
| } | |||||
| var data_handler = new DataHandler(new DataHandlerArgs | var data_handler = new DataHandler(new DataHandlerArgs | ||||
| { | { | ||||
| X = x, | X = x, | ||||
| @@ -78,6 +72,7 @@ namespace Tensorflow.Keras.Engine | |||||
| InitialEpoch = initial_epoch, | InitialEpoch = initial_epoch, | ||||
| Epochs = epochs, | Epochs = epochs, | ||||
| Shuffle = shuffle, | Shuffle = shuffle, | ||||
| ClassWeight = class_weight, | |||||
| MaxQueueSize = max_queue_size, | MaxQueueSize = max_queue_size, | ||||
| Workers = workers, | Workers = workers, | ||||
| UseMultiprocessing = use_multiprocessing, | UseMultiprocessing = use_multiprocessing, | ||||
| @@ -126,11 +121,12 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| X = new Tensors(x.ToArray()), | X = new Tensors(x.ToArray()), | ||||
| Y = y, | Y = y, | ||||
| SampleWeight = sample_weight, | |||||
| BatchSize = batch_size, | BatchSize = batch_size, | ||||
| InitialEpoch = initial_epoch, | InitialEpoch = initial_epoch, | ||||
| Epochs = epochs, | Epochs = epochs, | ||||
| Shuffle = shuffle, | Shuffle = shuffle, | ||||
| SampleWeight = sample_weight, | |||||
| ClassWeight = class_weight, | |||||
| MaxQueueSize = max_queue_size, | MaxQueueSize = max_queue_size, | ||||
| Workers = workers, | Workers = workers, | ||||
| UseMultiprocessing = use_multiprocessing, | UseMultiprocessing = use_multiprocessing, | ||||
| @@ -174,6 +170,7 @@ namespace Tensorflow.Keras.Engine | |||||
| InitialEpoch = initial_epoch, | InitialEpoch = initial_epoch, | ||||
| Epochs = epochs, | Epochs = epochs, | ||||
| Shuffle = shuffle, | Shuffle = shuffle, | ||||
| SampleWeight = sample_weight, | |||||
| MaxQueueSize = max_queue_size, | MaxQueueSize = max_queue_size, | ||||
| Workers = workers, | Workers = workers, | ||||
| UseMultiprocessing = use_multiprocessing, | UseMultiprocessing = use_multiprocessing, | ||||