|
|
|
@@ -79,11 +79,15 @@ def _train(model, config: TransformerConfig, |
|
|
|
|
|
|
|
if pre_training_dataset is not None: |
|
|
|
print(" | Start pre-training job.") |
|
|
|
epoch_size = pre_training_dataset.get_repeat_count() |
|
|
|
epoch_size = config.epochs * pre_training_dataset.get_dataset_size() // config.dataset_sink_step |
|
|
|
|
|
|
|
if os.getenv("RANK_SIZE") is not None and int(os.getenv("RANK_SIZE")) > 1: |
|
|
|
print(f" | Rank {MultiAscend.get_rank()} Call model train.") |
|
|
|
|
|
|
|
model.train(epoch_size, pre_training_dataset, |
|
|
|
callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode) |
|
|
|
callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode, |
|
|
|
sink_size=config.dataset_sink_step) |
|
|
|
|
|
|
|
# Test the accuracy of the model. |
|
|
|
if test_dataset is not None: |
|
|
|
print(" | Start test job.") |
|
|
|
@@ -93,10 +97,11 @@ def _train(model, config: TransformerConfig, |
|
|
|
|
|
|
|
if fine_tune_dataset is not None: |
|
|
|
print(" | Start fine-tuning job.") |
|
|
|
epoch_size = fine_tune_dataset.get_repeat_count() |
|
|
|
epoch_size = config.epochs * fine_tune_dataset.get_dataset_size() // config.dataset_sink_step |
|
|
|
|
|
|
|
model.train(epoch_size, fine_tune_dataset, |
|
|
|
callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode) |
|
|
|
callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode, |
|
|
|
sink_size=config.dataset_sink_step) |
|
|
|
|
|
|
|
# Test the accuracy of the model. |
|
|
|
if test_dataset is not None: |
|
|
|
|