观察到的现象是:一些模型在增大 batch size 后,会在首个 epoch 的中途因显存不足而中断;只要撑过第一个 epoch,就能完整训练。而同样的模型在 Python 下可以把 batch size 设置得大得多。最后通过最小化训练代码分析出:每个 step 结束后,加载到显存里的图片数据没有被释放。在寻找显存释放接口无果后,发现直接调用 GC.Collect() 可以让显存被主动回收。因此当前的修复方案是:在每个 step 里都执行一次 GC.Collect(),用来释放显存资源。(注:GC.Collect() 只是触发托管对象的终结器间接释放原生显存,属于临时性的变通方案;根本的修复应是显式 Dispose 每个 step 产生的 Tensor。)

tag: v0.150.0-BERT-Model
| @@ -24,6 +24,7 @@ public interface IModel : ILayer | |||
| List<ICallback> callbacks = null, | |||
| float validation_split = 0f, | |||
| ValidationDataPack validation_data = null, | |||
| int validation_step = 10, | |||
| bool shuffle = true, | |||
| Dictionary<int, float> class_weight = null, | |||
| NDArray sample_weight = null, | |||
| @@ -47,6 +48,20 @@ public interface IModel : ILayer | |||
| int workers = 1, | |||
| bool use_multiprocessing = false); | |||
| public ICallback fit(IDatasetV2 dataset, | |||
| int batch_size = -1, | |||
| int epochs = 1, | |||
| int verbose = 1, | |||
| List<ICallback> callbacks = null, | |||
| IDatasetV2 validation_data = null, | |||
| int validation_step = 10, // 间隔多少次会进行一次验证 | |||
| bool shuffle = true, | |||
| Dictionary<int, float> class_weight = null, | |||
| int initial_epoch = 0, | |||
| int max_queue_size = 10, | |||
| int workers = 1, | |||
| bool use_multiprocessing = false); | |||
| void save(string filepath, | |||
| bool overwrite = true, | |||
| bool include_optimizer = true, | |||
| @@ -85,6 +100,14 @@ public interface IModel : ILayer | |||
| int workers = 1, | |||
| bool use_multiprocessing = false); | |||
| public Tensors predict(IDatasetV2 dataset, | |||
| int batch_size = -1, | |||
| int verbose = 0, | |||
| int steps = -1, | |||
| int max_queue_size = 10, | |||
| int workers = 1, | |||
| bool use_multiprocessing = false); | |||
| void summary(int line_length = -1, float[] positions = null); | |||
| IKerasConfig get_config(); | |||
| @@ -132,6 +132,7 @@ namespace Tensorflow.Keras.Engine | |||
| var end_step = step + data_handler.StepIncrement; | |||
| if (!is_val) | |||
| callbacks.on_test_batch_end(end_step, logs); | |||
| GC.Collect(); | |||
| } | |||
| } | |||
| callbacks.on_test_end(logs); | |||
| @@ -167,7 +168,9 @@ namespace Tensorflow.Keras.Engine | |||
| Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y) | |||
| { | |||
| (x,y) = data_handler.DataAdapter.Expand1d(x, y); | |||
| var y_pred = Apply(x, training: false); | |||
| var loss = compiled_loss.Call(y, y_pred); | |||
| compiled_metrics.update_state(y, y_pred); | |||
| return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2); | |||
| @@ -41,6 +41,7 @@ namespace Tensorflow.Keras.Engine | |||
| List<ICallback> callbacks = null, | |||
| float validation_split = 0f, | |||
| ValidationDataPack validation_data = null, | |||
| int validation_step = 10, | |||
| bool shuffle = true, | |||
| Dictionary<int, float> class_weight = null, | |||
| NDArray sample_weight = null, | |||
| @@ -147,7 +148,7 @@ namespace Tensorflow.Keras.Engine | |||
| } | |||
| } | |||
| public History fit(IDatasetV2 dataset, | |||
| public ICallback fit(IDatasetV2 dataset, | |||
| int batch_size = -1, | |||
| int epochs = 1, | |||
| int verbose = 1, | |||
| @@ -156,7 +157,6 @@ namespace Tensorflow.Keras.Engine | |||
| int validation_step = 10, | |||
| bool shuffle = true, | |||
| Dictionary<int, float> class_weight = null, | |||
| NDArray sample_weight = null, | |||
| int initial_epoch = 0, | |||
| int max_queue_size = 10, | |||
| int workers = 1, | |||
| @@ -170,7 +170,7 @@ namespace Tensorflow.Keras.Engine | |||
| InitialEpoch = initial_epoch, | |||
| Epochs = epochs, | |||
| Shuffle = shuffle, | |||
| SampleWeight = sample_weight, | |||
| ClassWeight = class_weight, | |||
| MaxQueueSize = max_queue_size, | |||
| Workers = workers, | |||
| UseMultiprocessing = use_multiprocessing, | |||
| @@ -218,6 +218,7 @@ namespace Tensorflow.Keras.Engine | |||
| var end_step = step + data_handler.StepIncrement; | |||
| End_step = end_step; | |||
| callbacks.on_train_batch_end(end_step, logs); | |||
| GC.Collect(); | |||
| } | |||
| if (validation_data != null) | |||
| @@ -233,11 +234,10 @@ namespace Tensorflow.Keras.Engine | |||
| callbacks.on_train_batch_end(End_step, logs); | |||
| } | |||
| GC.Collect(); | |||
| callbacks.on_epoch_end(epoch, logs); | |||
| GC.Collect(); | |||
| GC.WaitForPendingFinalizers(); | |||
| if (stop_training) | |||
| { | |||
| break; | |||
| @@ -282,6 +282,7 @@ namespace Tensorflow.Keras.Engine | |||
| var end_step = step + data_handler.StepIncrement; | |||
| End_step = end_step; | |||
| callbacks.on_train_batch_end(end_step, logs); | |||
| GC.Collect(); | |||
| } | |||
| if (validation_data != null) | |||
| @@ -301,7 +302,6 @@ namespace Tensorflow.Keras.Engine | |||
| callbacks.on_epoch_end(epoch, logs); | |||
| GC.Collect(); | |||
| GC.WaitForPendingFinalizers(); | |||
| if (stop_training) | |||
| { | |||
| break; | |||
| @@ -102,9 +102,9 @@ namespace Tensorflow.Keras.Engine | |||
| for (int i = 0; i < batch_outputs.Length; i++) | |||
| batch_outputs[i] = tf.concat(new Tensor[] { batch_outputs[i], tmp_batch_outputs[i] }, axis: 0); | |||
| } | |||
| var end_step = step + data_handler.StepIncrement; | |||
| callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } }); | |||
| GC.Collect(); | |||
| } | |||
| } | |||