@@ -1,6 +1,6 @@
+import math
 import os
 from functools import partial
 from typing import Dict, Optional
 
 from datasets import load_dataset
-from torch import distributed as dist
@@ -27,13 +27,7 @@ class OFATrainer(EpochBasedTrainer):
         model_dir = model.model_dir
         cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
         cfg = Config.from_file(cfg_file)
-        dataset = load_dataset(
-            cfg.dataset.script,
-            data_files=cfg.dataset.hf_dataset,
-            sep=cfg.dataset.sep,
-        )
-        dataset = MsDataset.from_hf_dataset(
-            dataset.rename_columns(cfg.dataset.column_map))
+        dataset = self._build_dataset_with_config(cfg)
         preprocessor = {
             ConfigKeys.train:
             OfaPreprocessor(
@@ -42,9 +36,11 @@ class OFATrainer(EpochBasedTrainer):
             OfaPreprocessor(
                 model_dir=model_dir, mode=ModeKeys.EVAL, no_collate=True),
         }
-        epoch_steps = len(dataset['train']) // (
-            cfg.train.optimizer_hook.cumulative_iters
-            * cfg.train.dataloader.batch_size_per_gpu)
+        # use torchrun launch
+        world_size = int(os.environ.get('WORLD_SIZE', 1))
+        epoch_steps = math.ceil(
+            len(dataset['train']) /  # noqa
+            (cfg.train.dataloader.batch_size_per_gpu * world_size))  # noqa
         cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs
         cfg.train.criterion.tokenizer = model.tokenizer
         self.criterion = AdjustLabelSmoothedCrossEntropyCriterion(
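
Note on the hunk above: the old step count divided by optimizer_hook.cumulative_iters (gradient accumulation), while the new one divides the epoch's samples across workers using the WORLD_SIZE variable that torchrun exports (e.g. `torchrun --nproc_per_node=4 finetune.py` starts four workers with WORLD_SIZE=4). A minimal sketch of the new arithmetic, with all numbers below assumed for illustration:

    import math

    # Illustrative stand-ins for the values read from cfg above, not repo values.
    world_size = 4              # WORLD_SIZE exported by torchrun --nproc_per_node=4
    num_train_samples = 10_000  # stand-in for len(dataset['train'])
    batch_size_per_gpu = 8      # stand-in for cfg.train.dataloader.batch_size_per_gpu
    max_epochs = 3              # stand-in for cfg.train.max_epochs

    # math.ceil keeps a trailing partial batch as a full step, so the lr
    # schedule never ends before the last optimizer step of training.
    epoch_steps = math.ceil(num_train_samples / (batch_size_per_gpu * world_size))
    num_train_steps = epoch_steps * max_epochs
    print(epoch_steps, num_train_steps)  # -> 313 939
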
@@ -104,3 +100,24 @@ class OFATrainer(EpochBasedTrainer):
         else:
             self.log_buffer.update(train_outputs['log_vars'])
         self.train_outputs = train_outputs
+
+    def _build_dataset_with_config(self, cfg):
+        if hasattr(cfg.dataset, 'hf_dataset'):
+            dataset = load_dataset(
+                cfg.dataset.script,
+                data_files=cfg.dataset.hf_dataset,
+                sep=cfg.dataset.sep,
+            )
+            dataset = MsDataset.from_hf_dataset(
+                dataset.rename_columns(cfg.dataset.column_map))
+            return dataset
+        elif hasattr(cfg.dataset, 'ms_dataset'):
+            dataset_d = dict()
+            for key in cfg.dataset.ms_dataset.keys():
+                dataset_d[key] = MsDataset.load(**cfg.dataset.ms_dataset[key])
+                dataset_d[key] = MsDataset.from_hf_dataset(
+                    dataset_d[key]._hf_ds.rename_columns(
+                        cfg.dataset.column_map))
+            return dataset_d
+        else:
+            raise NotImplementedError
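
Note on _build_dataset_with_config: it dispatches on whether the `dataset` section of the model's configuration carries an `hf_dataset` entry (forwarded to `datasets.load_dataset`) or an `ms_dataset` entry (each per-split value unpacked into `MsDataset.load`), renaming columns via `column_map` in both branches. A hedged sketch of the two config shapes, written as Python dicts; every key and value below is an assumption inferred from the attribute accesses in the hunk, not a documented schema:

    # Hypothetical `dataset` sections for the two branches.

    # hf_dataset branch: script, hf_dataset and sep are forwarded to
    # datasets.load_dataset(script, data_files=hf_dataset, sep=sep).
    hf_style = {
        'script': 'csv',                      # builder name or loading script
        'hf_dataset': {'train': 'train.tsv',  # becomes the data_files mapping
                       'test': 'test.tsv'},
        'sep': '\t',                          # extra kwarg for the csv builder
        'column_map': {'caption': 'text'},    # old -> new column names
    }

    # ms_dataset branch: each per-split entry becomes MsDataset.load(**kwargs).
    ms_style = {
        'ms_dataset': {
            'train': {'dataset_name': 'some_dataset', 'split': 'train'},
            'test': {'dataset_name': 'some_dataset', 'split': 'test'},
        },
        'column_map': {'caption': 'text'},
    }

One design observation: the ms_dataset branch renames columns through the private `_hf_ds` attribute, so it depends on MsDataset internals rather than a public renaming API.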