@@ -287,5 +287,7 @@ class OfaForAllTasks(TorchModel):

     def load_ans2label(self):
         if self.cfg.model.get('answer2label', None):
-            filename = osp.join(self.model_dir, self.cfg.model.answer2label)
-            self.ans2label_dict = json.load(open(filename))
+            ans2label_file = osp.join(self.model_dir,
+                                      self.cfg.model.answer2label)
+            with open(ans2label_file, 'r') as reader:
+                self.ans2label_dict = json.load(reader)
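
For reference, the pattern this hunk adopts, as a standalone sketch ('labels.json' is an invented path, not a file in this repo): json.load(open(path)) leaves the file object open until the garbage collector reclaims it, while a with block closes it deterministically, even if json.load raises.

    import json

    def load_labels(path):
        # Handle is closed on block exit, success or exception.
        with open(path, 'r') as reader:
            return json.load(reader)

    # labels = load_labels('labels.json')
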
@@ -61,7 +61,8 @@ class OfaBasePreprocessor:
         self.index2ans = {}
         if self.cfg.model.get('answer2label', False):
            ans2label_file = osp.join(model_dir, self.cfg.model.answer2label)
-            ans2label_dict = json.load(open(ans2label_file, 'r'))
+            with open(ans2label_file, 'r') as reader:
+                ans2label_dict = json.load(reader)
             self.constraint_trie = Trie(tokenizer.eos_token_id)
             for i, answer in enumerate(ans2label_dict.keys()):
                 answer_item = tokenizer(
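
The constraint trie built above restricts decoding to known answers. A toy sketch of the idea (the real Trie API in modelscope may differ; ToyTrie and the token ids below are invented):

    class ToyTrie:
        def __init__(self, eos_id):
            self.root = {}
            self.eos_id = eos_id

        def insert(self, token_ids):
            # Store one tokenized answer, terminated by EOS.
            node = self.root
            for tid in token_ids + [self.eos_id]:
                node = node.setdefault(tid, {})

        def allowed_next(self, prefix_ids):
            # Tokens that can legally follow the decoded prefix.
            node = self.root
            for tid in prefix_ids:
                node = node.get(tid, {})
            return sorted(node)

    trie = ToyTrie(eos_id=2)
    trie.insert([15, 7])            # e.g. ids for 'yes'
    trie.insert([15, 9])            # e.g. ids for 'no'
    print(trie.allowed_next([15]))  # -> [7, 9]
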
@@ -79,7 +79,7 @@ class OFAFileDataset:
                 self.total_row_count += 1
                 offset += len(line.encode('utf-8'))
             pickle.dump(self.lineid_to_offset,
-                        open('{}.index'.format(self.file_path), 'rb'))
+                        open('{}.index'.format(self.file_path), 'wb'))
             self._compute_start_pos_and_row_count()
             print(
                 'local datafile {} slice_id {} finished initializing row_count and line_idx-to-offset mapping'
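
The one-character mode fix above matters because pickle.dump writes bytes: with 'rb' the call raises io.UnsupportedOperation on a read-only handle. A standalone sketch with made-up offsets and path; note that a with block, as used in the other hunks, would additionally guarantee the index is flushed and closed:

    import pickle

    lineid_to_offset = {0: 0, 1: 137, 2: 301}
    with open('train.tsv.index', 'wb') as writer:  # 'wb', not 'rb'
        pickle.dump(lineid_to_offset, writer)
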
@@ -1,120 +1,84 @@
 import os
-from os import path as osp
 from typing import Dict, Optional

-import torch
-import torch.distributed as dist
-import transformers
-from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
+from datasets import load_dataset

 from modelscope.metainfo import Trainers
 from modelscope.models.base import Model
+from modelscope.msdatasets.ms_dataset import MsDataset
 from modelscope.preprocessors.multi_modal import OfaPreprocessor
 from modelscope.preprocessors.ofa.utils.collate import collate_fn
-from modelscope.trainers.base import BaseTrainer
+from modelscope.trainers import EpochBasedTrainer
 from modelscope.trainers.builder import TRAINERS
+from modelscope.trainers.optimizer.builder import build_optimizer
+from modelscope.utils.config import Config
 from modelscope.utils.constant import ModeKeys, ModelFile
-from modelscope.utils.logger import get_logger
-from modelscope.utils.torch_utils import init_dist
 from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
                                 OFADataset, get_schedule)

-logger = get_logger()

 @TRAINERS.register_module(module_name=Trainers.ofa_tasks)
-class OFATrainer(BaseTrainer):
+class OFATrainer(EpochBasedTrainer):

     def __init__(self, model: str, *args, **kwargs):
         model = Model.from_pretrained(model)
-        super().__init__(osp.join(model.model_dir, ModelFile.CONFIGURATION))
-        self.model_dir = model.model_dir
-        self.model = model.model
-        self.device_id = 0
-        self.total_epoch = self.cfg.train.epoch
-        self.train_batch_size = self.cfg.train.batch_size
-        self.val_batch_size = self.cfg.evaluation.batch_size
-        self.save_dir = self.cfg.train.save_dir
-        init_dist(launcher='pytorch')
-        self.train_dataset = OFADataset(
-            file_path=self.cfg.dataset.train_set,
-            selected_id_keys=self.cfg.dataset.selected_id_keys,
-            preprocessor=OfaPreprocessor(
-                model_dir=self.model_dir, split=ModeKeys.TRAIN),
-        )
-        self.val_dataset = OFADataset(
-            file_path=self.cfg.dataset.valid_set,
-            selected_id_keys=self.cfg.dataset.selected_id_keys,
-            preprocessor=OfaPreprocessor(
-                model_dir=self.model_dir, split=ModeKeys.EVAL),
+        model_dir = model.model_dir
+        cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
+        cfg = Config.from_file(cfg_file)
+        dataset = load_dataset(
+            cfg.dataset.script,
+            data_files=cfg.dataset.hf_dataset,
+            sep=cfg.dataset.sep,
         )
-        epoch_steps = len(
-            self.train_dataset) // self.cfg.train.gradient_accumulation_steps
-        self.cfg.train.num_train_steps = epoch_steps * self.cfg.train.epoch
+        ms_dataset = MsDataset.from_hf_dataset(dataset)
+        # train_dataset = OFADataset(
+        #     file_path=cfg.dataset.train_set,
+        #     selected_id_keys=cfg.dataset.selected_id_keys,
+        #     preprocessor=OfaPreprocessor(
+        #         model_dir=model_dir, mode=ModeKeys.TRAIN),
+        # )
+        # val_dataset = OFADataset(
+        #     file_path=cfg.dataset.valid_set,
+        #     selected_id_keys=cfg.dataset.selected_id_keys,
+        #     preprocessor=OfaPreprocessor(
+        #         model_dir=model_dir, mode=ModeKeys.EVAL),
+        # )
+        epoch_steps = len(ms_dataset['train']) // (
+            cfg.train.gradient_accumulation_steps
+            * cfg.train.dataloader.batch_size_per_gpu)
+        cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs
+        cfg.train.criterion.tokenizer = model.tokenizer
         self.criterion = AdjustLabelSmoothedCrossEntropyCriterion(
-            self.cfg.train.criterion)
-
-    def train(self, *args, **kwargs):
-        assert dist.is_initialized()
-        self.model.train()
-        self.model.to(self.device_id)
-        ddp_model = torch.nn.parallel.DistributedDataParallel(
-            self.model, device_ids=[
-                self.device_id,
-            ])
-        optimizer = transformers.AdamW(
-            self.model.parameters(),
-            lr=self.cfg.train.lr,
-            weight_decay=self.cfg.train.weight_decay,
-            correct_bias=False,
-        )
-        scheduler_class, scheduler_args = get_schedule(self.cfg.train)
+            cfg.train.criterion)
+        optimizer = build_optimizer(model, cfg=cfg.train.optimizer)
+        scheduler_class, scheduler_args = get_schedule(cfg.train.lr_scheduler)
         if scheduler_class is not None:
             lr_scheduler = scheduler_class(**{'optimizer': optimizer},
                                            **scheduler_args)
         else:
             lr_scheduler = None
-        for epoch in range(self.total_epoch):
-            train_sampler = DistributedSampler(
-                dataset=self.train_dataset, shuffle=True)
-            train_sampler.set_epoch(epoch)
-            train_params = {
-                'pin_memory': True,
-                'collate_fn': collate_fn,
-                'batch_size': self.train_batch_size,
-                'shuffle': False,
-                'drop_last': True,
-                'sampler': train_sampler,
-                'num_workers': 2,
-            }
-            train_loader = DataLoader(self.train_dataset, **train_params)
+        super().__init__(
+            cfg_file=cfg_file,
+            model=model,
+            data_collator=collate_fn,
+            train_dataset=dataset['train'],
+            eval_dataset=dataset['valid'],
+            optimizers=(optimizer, lr_scheduler),
+            work_dir=cfg.train.work_dir,
+            *args,
+            **kwargs,
+        )
-            for idx, batch in enumerate(train_loader, start=1):
-                model_outputs = ddp_model(**batch)
-                loss, sample_size, logging_output = self.criterion(
-                    model_outputs, batch)
-                loss.backward()
-                optimizer.zero_grad()
-                if lr_scheduler is not None:
-                    lr_scheduler.step()
-                optimizer.step()
-                optimizer.zero_grad()
-                if idx % 10 == 0:
-                    logger.info(
-                        'epoch: {}, train batch {}/{}, loss={:.5f}'.format(
-                            epoch, idx, len(train_loader), loss.item()))
-            if dist.get_rank() == 0:
-                os.makedirs(self.ckpt_dir, exist_ok=True)
-                torch.save(ddp_model.module.state_dict(),
-                           f'{self.ckpt_dir}/epoch{epoch}.bin')
+
+    def train(self, *args, **kwargs):
+        pass

     def evaluate(self,
                  checkpoint_path: Optional[str] = None,
                  *args,
                  **kwargs) -> Dict[str, float]:
         pass
+
+    def prediction_step(self, model, inputs):
+        pass
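
The rewrite above moves OFATrainer onto EpochBasedTrainer: the constructor now only prepares the config, datasets, criterion, optimizer, and scheduler, hands them to super().__init__, and overrides train(), evaluate(), and prediction_step() with stubs. The scheduler horizon comes from the step arithmetic in __init__; a worked example with invented numbers (note the formula counts optimizer updates per worker and does not divide by world size):

    dataset_len = 10000     # len(ms_dataset['train'])
    grad_accum = 4          # cfg.train.gradient_accumulation_steps
    batch_per_gpu = 8       # cfg.train.dataloader.batch_size_per_gpu
    max_epochs = 3          # cfg.train.max_epochs

    epoch_steps = dataset_len // (grad_accum * batch_per_gpu)  # 312
    num_train_steps = epoch_steps * max_epochs                 # 936
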
@@ -0,0 +1,120 @@
+import os
+from os import path as osp
+from typing import Dict, Optional
+
+import torch
+import torch.distributed as dist
+import transformers
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from modelscope.metainfo import Trainers
+from modelscope.models.base import Model
+from modelscope.preprocessors.multi_modal import OfaPreprocessor
+from modelscope.preprocessors.ofa.utils.collate import collate_fn
+from modelscope.trainers.base import BaseTrainer
+from modelscope.trainers.builder import TRAINERS
+from modelscope.utils.constant import ModeKeys, ModelFile
+from modelscope.utils.logger import get_logger
+from modelscope.utils.torch_utils import init_dist
+from .ofa_trainer_utils import (AdjustLabelSmoothedCrossEntropyCriterion,
+                                OFADataset, get_schedule)
+
+logger = get_logger()
+
+
+@TRAINERS.register_module(module_name=Trainers.ofa_tasks)
+class OFAOldTrainer(BaseTrainer):
+
+    def __init__(self, model: str, *args, **kwargs):
+        model = Model.from_pretrained(model)
+        super().__init__(osp.join(model.model_dir, ModelFile.CONFIGURATION))
+        self.model_dir = model.model_dir
+        self.model = model.model
+        self.device_id = 0
+        self.total_epoch = self.cfg.train.epoch
+        self.train_batch_size = self.cfg.train.batch_size
+        self.val_batch_size = self.cfg.evaluation.batch_size
+        self.save_dir = self.cfg.train.save_dir
+        init_dist(launcher='pytorch')
+        self.train_dataset = OFADataset(
+            file_path=self.cfg.dataset.train_set,
+            selected_id_keys=self.cfg.dataset.selected_id_keys,
+            preprocessor=OfaPreprocessor(
+                model_dir=self.model_dir, split=ModeKeys.TRAIN),
+        )
+        self.val_dataset = OFADataset(
+            file_path=self.cfg.dataset.valid_set,
+            selected_id_keys=self.cfg.dataset.selected_id_keys,
+            preprocessor=OfaPreprocessor(
+                model_dir=self.model_dir, split=ModeKeys.EVAL),
+        )
+        epoch_steps = len(
+            self.train_dataset) // self.cfg.train.gradient_accumulation_steps
+        self.cfg.train.num_train_steps = epoch_steps * self.cfg.train.epoch
+        self.criterion = AdjustLabelSmoothedCrossEntropyCriterion(
+            self.cfg.train.criterion)
+
+    def train(self, *args, **kwargs):
+        assert dist.is_initialized()
+        self.model.train()
+        self.model.to(self.device_id)
+        ddp_model = torch.nn.parallel.DistributedDataParallel(
+            self.model, device_ids=[
+                self.device_id,
+            ])
+        optimizer = transformers.AdamW(
+            self.model.parameters(),
+            lr=self.cfg.train.lr,
+            weight_decay=self.cfg.train.weight_decay,
+            correct_bias=False,
+        )
+        scheduler_class, scheduler_args = get_schedule(self.cfg.train)
+        if scheduler_class is not None:
+            lr_scheduler = scheduler_class(**{'optimizer': optimizer},
+                                           **scheduler_args)
+        else:
+            lr_scheduler = None
+        for epoch in range(self.total_epoch):
+            train_sampler = DistributedSampler(
+                dataset=self.train_dataset, shuffle=True)
+            train_sampler.set_epoch(epoch)
+            train_params = {
+                'pin_memory': True,
+                'collate_fn': collate_fn,
+                'batch_size': self.train_batch_size,
+                'shuffle': False,
+                'drop_last': True,
+                'sampler': train_sampler,
+                'num_workers': 2,
+            }
+            train_loader = DataLoader(self.train_dataset, **train_params)
+            for idx, batch in enumerate(train_loader, start=1):
+                model_outputs = ddp_model(**batch)
+                loss, sample_size, logging_output = self.criterion(
+                    model_outputs, batch)
+                loss.backward()
+                optimizer.zero_grad()
+                if lr_scheduler is not None:
+                    lr_scheduler.step()
+                optimizer.step()
+                optimizer.zero_grad()
+                if idx % 10 == 0:
+                    logger.info(
+                        'epoch: {}, train batch {}/{}, loss={:.5f}'.format(
+                            epoch, idx, len(train_loader), loss.item()))
+            if dist.get_rank() == 0:
+                os.makedirs(self.ckpt_dir, exist_ok=True)
+                torch.save(ddp_model.module.state_dict(),
+                           f'{self.ckpt_dir}/epoch{epoch}.bin')
+
+    def evaluate(self,
+                 checkpoint_path: Optional[str] = None,
+                 *args,
+                 **kwargs) -> Dict[str, float]:
+        pass
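
Two notes on the legacy trainer preserved above, with a minimal sketch of the conventional update order (plain SGD and a toy model stand in; nothing below is modelscope API). First, its inner loop calls optimizer.zero_grad() between loss.backward() and optimizer.step(), so the step is applied to zeroed gradients. Second, both OFATrainer and OFAOldTrainer register under the same Trainers.ofa_tasks key, which a module registry will typically reject or resolve to whichever class registers last.

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(3):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()        # accumulate gradients
        optimizer.step()       # apply them
        optimizer.zero_grad()  # clear only after stepping
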
@@ -35,7 +35,7 @@ class OFADataset(Dataset):
         self.dataset = OFAFileDataset(
             file_path=file_path,
-            selected_col_ids=selected_col_ids,
+            selected_col_ids=','.join(selected_col_ids),
             dtypes=dtypes,
             separator=separator,
             cached_index=cached_index)
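
The join above suggests OFAFileDataset expects selected_col_ids as a comma-separated string while callers now pass a list. An illustration with invented column ids:

    selected_col_ids = ['0', '2', '3']
    print(','.join(selected_col_ids))  # -> '0,2,3'
    # str.join needs string elements; integer ids would require
    # ','.join(str(i) for i in selected_col_ids).
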
@@ -157,7 +157,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         self.constraint_start = None
         self.constraint_end = None
-        if args.constraint_range is not None:
+        if args.constraint_range:
             constraint_start, constraint_end = args.constraint_range.split(',')
             self.constraint_start = int(constraint_start)
             self.constraint_end = int(constraint_end)
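
The looser truthiness check covers configs that materialize an unset constraint_range as an empty string: '' is not None, yet it cannot be split into two ints. A toy illustration:

    for constraint_range in (None, '', '4,16'):
        if constraint_range:  # false for both None and ''
            start, end = constraint_range.split(',')
            print(int(start), int(end))  # reached only for '4,16'
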
@@ -280,35 +280,39 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         return loss, nll_loss, ntokens


-def get_schedule(args):
+def get_schedule(scheduler):

-    if args.schedule == 'const':
+    if scheduler.name == 'const':
         scheduler_class = transformers.get_constant_schedule_with_warmup
         scheduler_args = {
             'num_warmup_steps':
-            int(args.warmup_proportion * args.num_train_steps)
+            int(scheduler.warmup_proportion * scheduler.num_train_steps)
         }
-    elif args.schedule == 'linear':
+    elif scheduler.name == 'linear':
         scheduler_class = transformers.get_linear_schedule_with_warmup
         scheduler_args = {
             'num_warmup_steps':
-            int(args.warmup_proportion * args.num_train_steps),
-            'num_training_steps': args.num_train_steps
+            int(scheduler.warmup_proportion * scheduler.num_train_steps),
+            'num_training_steps':
+            scheduler.num_train_steps
         }
-    elif args.schedule == 'cosine':
+    elif scheduler.name == 'cosine':
         scheduler_class = transformers.get_cosine_schedule_with_warmup
         scheduler_args = {
             'num_warmup_steps':
-            int(args.warmup_proportion * args.num_train_steps),
-            'num_training_steps': args.num_train_steps
+            int(scheduler.warmup_proportion * scheduler.num_train_steps),
+            'num_training_steps':
+            scheduler.num_train_steps
         }
-    elif args.schedule == 'polynomial_decay':
+    elif scheduler.name == 'polynomial_decay':
         scheduler_class = transformers.get_polynomial_decay_schedule_with_warmup
         scheduler_args = {
             'num_warmup_steps':
-            int(args.warmup_proportion * args.num_train_steps),
-            'num_training_steps': args.num_train_steps,
-            'lr_end': args.lr_end
+            int(scheduler.warmup_proportion * scheduler.num_train_steps),
+            'num_training_steps':
+            scheduler.num_train_steps,
+            'lr_end':
+            scheduler.lr_end
         }
     else:
         raise NotImplementedError
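
After the rename, get_schedule reads its fields from the lr_scheduler node of the config instead of a flat args namespace. A hedged usage sketch, with SimpleNamespace standing in for the Config node and invented values:

    from types import SimpleNamespace
    # assuming get_schedule from ofa_trainer_utils is importable

    lr_cfg = SimpleNamespace(
        name='linear', warmup_proportion=0.1, num_train_steps=936)
    scheduler_class, scheduler_args = get_schedule(lr_cfg)
    # scheduler_args == {'num_warmup_steps': 93, 'num_training_steps': 936}
    # lr_scheduler = scheduler_class(optimizer=optimizer, **scheduler_args)
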
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 import shutil
 import unittest

@@ -13,7 +14,8 @@ class TestOfaTrainer(unittest.TestCase):
         model_id = '/apsarapangu/disk2/yichang.zyc/ckpt/MaaS/ofa_text-classification_mnli_large_en'
         self.trainer = OFATrainer(model_id)
         self.trainer.train()
-        shutil.rmtree(self.trainer.save_dir)
+        if os.path.exists(self.trainer.work_dir):
+            shutil.rmtree(self.trainer.work_dir)


 if __name__ == '__main__':
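
Since the rewritten OFATrainer.train() is currently a stub, the work_dir may never be created, hence the existence check before cleanup. An equivalent one-liner, if swallowing all removal errors is acceptable:

    shutil.rmtree(self.trainer.work_dir, ignore_errors=True)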