1. Add the child-tuning optimizer (ChildTuningAdamW) and a unit test for it (usage sketch below).
2. Fix a training bug that can interrupt training after cross-evaluation (the trainer mode is now reset to ModeKeys.TRAIN after each training iteration).
3. Move the model parameters from cfg.params to default_args in build_optimizer, so that the params are no longer persisted by save_pretrained.
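For reference, a rough usage sketch of the new optimizer (not part of the diff): the toy Linear model and the commented ChildTuning-D wiring are illustrative placeholders, and the real trainer-side integration is in the new unit test below.

```python
import torch

from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import (
    ChildTuningAdamW, calculate_fisher)

# Toy stand-in model, just to make the sketch self-contained.
model = torch.nn.Linear(8, 2)

# ChildTuning-F: no task-specific mask; gradients are masked stochastically
# with a Bernoulli(reserve_p) sample at every step.
optimizer = ChildTuningAdamW(
    model.parameters(), lr=1e-5, reserve_p=0.2, mode='ChildTuning-F')

loss = model(torch.randn(4, 8)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()

# ChildTuning-D: the mask is fixed and derived from Fisher information
# estimated on task data; `data_loader` and `forward_step(model, inputs) -> loss`
# are placeholders here.
# gradient_mask = calculate_fisher(model, data_loader, forward_step, reserve_p=0.2)
# optimizer = ChildTuningAdamW(
#     model.parameters(), lr=1e-5, reserve_p=0.2, mode='ChildTuning-D')
# optimizer.set_gradient_mask(gradient_mask)
```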
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9891963
Target branch: master
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .builder import OPTIMIZERS, build_optimizer
+from .child_tuning_adamw_optimizer import ChildTuningAdamW
 
-__all__ = ['OPTIMIZERS', 'build_optimizer']
+__all__ = ['OPTIMIZERS', 'build_optimizer', 'ChildTuningAdamW']
@@ -20,7 +20,10 @@ def build_optimizer(model: torch.nn.Module,
     """
     if hasattr(model, 'module'):
         model = model.module
-    cfg.params = model.parameters()
+    if default_args is None:
+        default_args = {}
+    default_args['params'] = model.parameters()
     return build_from_cfg(
         cfg, OPTIMIZERS, group_key=default_group, default_args=default_args)
@@ -0,0 +1,188 @@
+# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+import types
+from typing import Callable, Iterable, Tuple
+
+import numpy as np
+import torch
+from torch.distributions.bernoulli import Bernoulli
+from torch.optim import Optimizer
+
+from modelscope.utils.logger import get_logger
+
+from .builder import OPTIMIZERS, default_group
+
+logger = get_logger(__name__)
+
+__all__ = ['calculate_fisher', 'ChildTuningAdamW']
+
+
+def calculate_fisher(model: torch.nn.Module,
+                     data_loader,
+                     forward_step,
+                     reserve_p,
+                     grad_clip=None):
+    gradient_mask = dict()
+    model.train()
+    for name, params in model.named_parameters():
+        if 'layer' in name:
+            gradient_mask[params] = params.new_zeros(params.size())
+
+    iters = len(data_loader)
+    for inputs in data_loader:
+        loss = forward_step(model, inputs)
+        loss.backward()
+        for name, params in model.named_parameters():
+            if 'layer' in name:
+                if grad_clip is not None:
+                    torch.nn.utils.clip_grad_norm_(params, **grad_clip)
+                gradient_mask[params] += (params.grad**2) / iters
+        model.zero_grad()
+
+    logger.info('Calculate Fisher Information...')
+
+    # Numpy
+    r = None
+    for k, v in gradient_mask.items():
+        v = v.view(-1).cpu().numpy()
+        if r is None:
+            r = v
+        else:
+            r = np.append(r, v)
+    polar = np.percentile(r, (1 - reserve_p) * 100)
+    for k in gradient_mask:
+        gradient_mask[k] = gradient_mask[k] >= polar
+    print('Polar => {}'.format(polar))
+
+    # TODO: pytorch: torch.kthvalue
+    return gradient_mask
+
+
+@OPTIMIZERS.register_module(
+    group_key=default_group, module_name='ChildTuningAdamW')
+class ChildTuningAdamW(Optimizer):
+
+    def __init__(self,
+                 params: Iterable[torch.nn.parameter.Parameter],
+                 lr: float = 1e-3,
+                 betas: Tuple[float, float] = (0.9, 0.999),
+                 eps: float = 1e-6,
+                 weight_decay: float = 0.0,
+                 correct_bias: bool = True,
+                 reserve_p=1.0,
+                 mode=None):
+        if lr < 0.0:
+            raise ValueError(
+                'Invalid learning rate: {} - should be >= 0.0'.format(lr))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError(
+                'Invalid beta parameter: {} - should be in [0.0, 1.0['.format(
+                    betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError(
+                'Invalid beta parameter: {} - should be in [0.0, 1.0['.format(
+                    betas[1]))
+        if not 0.0 <= eps:
+            raise ValueError(
+                'Invalid epsilon value: {} - should be >= 0.0'.format(eps))
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            eps=eps,
+            weight_decay=weight_decay,
+            correct_bias=correct_bias)
+        super().__init__(params, defaults)
+
+        self.gradient_mask = None
+        self.reserve_p = reserve_p
+        self.mode = mode
+
+    def set_gradient_mask(self, gradient_mask):
+        self.gradient_mask = gradient_mask
+
+    def step(self, closure: Callable = None):
+        """
+        Performs a single optimization step.
+
+        Arguments:
+            closure (:obj:`Callable`, `optional`): A closure that reevaluates the model and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError(
+                        'Adam does not support sparse gradients, please consider SparseAdam instead'
+                    )
+
+                # ChildTuning code
+                if self.mode is not None:
+                    if self.mode == 'ChildTuning-D':
+                        if p in self.gradient_mask:
+                            grad *= self.gradient_mask[p]
+                    else:
+                        # ChildTuning-F
+                        grad_mask = Bernoulli(
+                            grad.new_full(
+                                size=grad.size(), fill_value=self.reserve_p))
+                        grad *= grad_mask.sample() / self.reserve_p
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+
+                # Decay the first and second moment running average coefficient
+                # In-place operations to update the averages at the same time
+                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
+                denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+                step_size = group['lr']
+                if group['correct_bias']:  # No bias correction for Bert
+                    bias_correction1 = 1.0 - beta1**state['step']
+                    bias_correction2 = 1.0 - beta2**state['step']
+                    step_size = step_size * math.sqrt(
+                        bias_correction2) / bias_correction1
+
+                p.data.addcdiv_(exp_avg, denom, value=-step_size)
+
+                # Just adding the square of the weights to the loss function is *not*
+                # the correct way of using L2 regularization/weight decay with Adam,
+                # since that will interact with the m and v parameters in strange ways.
+                #
+                # Instead we want to decay the weights in a manner that doesn't interact
+                # with the m/v parameters. This is equivalent to adding the square
+                # of the weights to the loss with plain (non-momentum) SGD.
+                # Add weight decay at the end (fixed version)
+                p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay'])
+
+        return loss
@@ -800,6 +800,7 @@ class EpochBasedTrainer(BaseTrainer):
                 self.invoke_hook(TrainerStages.after_train_iter)
                 del self.data_batch
                 self._iter += 1
+                self._mode = ModeKeys.TRAIN
 
                 if i + 1 >= self.iters_per_epoch:
                     break
@@ -6,9 +6,15 @@ import unittest
 
 from modelscope.metainfo import Preprocessors, Trainers
 from modelscope.models import Model
+from modelscope.msdatasets import MsDataset
 from modelscope.pipelines import pipeline
 from modelscope.trainers import build_trainer
+from modelscope.trainers.hooks import Hook
+from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
+from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \
+    calculate_fisher
 from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.data_utils import to_device
 
 
 class TestFinetuneSequenceClassification(unittest.TestCase):
@@ -69,6 +75,10 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
 
     @unittest.skip
     def test_finetune_afqmc(self):
+        """This unittest is used to reproduce the clue:afqmc dataset + structbert model training results.
+
+        Users can train a custom dataset by modifying this piece of code and commenting out the @unittest.skip.
+        """
 
         def cfg_modify_fn(cfg):
             cfg.task = Tasks.sentence_similarity
@@ -114,7 +124,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
         dc.local_files_only = True
         dataset = load_dataset('clue', 'afqmc', download_config=dc)
         self.finetune(
-            model_id='damo/nlp_structbert_backbone_tiny_std',
+            model_id='damo/nlp_structbert_backbone_base_std',
             train_dataset=dataset['train'],
             eval_dataset=dataset['validation'],
             cfg_modify_fn=cfg_modify_fn)
@@ -124,6 +134,10 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
 
     @unittest.skip
     def test_finetune_tnews(self):
+        """This unittest is used to reproduce the clue:tnews dataset + structbert model training results.
+
+        Users can train a custom dataset by modifying this piece of code and commenting out the @unittest.skip.
+        """
 
         def cfg_modify_fn(cfg):
             # TODO no proper task for tnews
@@ -175,13 +189,21 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
         dataset = load_dataset('clue', 'tnews', download_config=dc)
         self.finetune(
-            model_id='damo/nlp_structbert_backbone_tiny_std',
+            model_id='damo/nlp_structbert_backbone_base_std',
             train_dataset=dataset['train'],
             eval_dataset=dataset['validation'],
             cfg_modify_fn=cfg_modify_fn)
 
     @unittest.skip
     def test_veco_xnli(self):
+        """This unittest is used to reproduce the xnli dataset + veco model training results.
+
+        Here we follow the training scenario listed in the AliceMind open source project:
+        https://github.com/alibaba/AliceMind/tree/main/VECO
+        by training on the English language subset.
+
+        Users can train a custom dataset by modifying this piece of code and commenting out the @unittest.skip.
+        """
         from datasets import load_dataset
         langs = ['en']
         langs_eval = ['en']
@@ -267,6 +289,112 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
             name=Trainers.nlp_veco_trainer,
             cfg_modify_fn=cfg_modify_fn)
 
+    @unittest.skip
+    def test_finetune_cluewsc(self):
+        """This unittest is used to reproduce the clue:wsc dataset + structbert model training results.
+
+        A runnable sample of child-tuning is also shown here.
+        Users can train a custom dataset by modifying this piece of code and commenting out the @unittest.skip.
+        """
+        child_tuning_type = 'ChildTuning-F'
+        mode = {}
+        if child_tuning_type is not None:
+            mode = {'mode': child_tuning_type, 'reserve_p': 0.2}
+
+        def cfg_modify_fn(cfg):
+            cfg.task = 'nli'
+            cfg['preprocessor'] = {'type': 'nli-tokenizer'}
+            cfg['dataset'] = {
+                'train': {
+                    'labels': ['0', '1'],
+                    'first_sequence': 'text',
+                    'second_sequence': 'text2',
+                    'label': 'label',
+                }
+            }
+            cfg.train.dataloader.batch_size_per_gpu = 16
+            cfg.train.max_epochs = 30
+            cfg.train.optimizer = {
+                'type':
+                'AdamW' if child_tuning_type is None else 'ChildTuningAdamW',
+                'lr': 1e-5,
+                'options': {},
+                **mode,
+            }
+            cfg.train.lr_scheduler = {
+                'type': 'LinearLR',
+                'start_factor': 1.0,
+                'end_factor': 0.0,
+                'total_iters':
+                int(len(dataset['train'])
+                    / cfg.train.dataloader.batch_size_per_gpu)
+                * cfg.train.max_epochs,
+                'options': {
+                    'by_epoch': False
+                }
+            }
+            cfg.train.hooks = [{
+                'type': 'CheckpointHook',
+                'interval': 1
+            }, {
+                'type': 'TextLoggerHook',
+                'interval': 1
+            }, {
+                'type': 'IterTimerHook'
+            }, {
+                'type': 'EvaluationHook',
+                'by_epoch': False,
+                'interval': 30
+            }]
+            return cfg
+
+        def add_sentence2(features):
+            return {
+                'text2':
+                features['target']['span2_text'] + '指代'
+                + features['target']['span1_text']
+            }
+
+        dataset = MsDataset.load('clue', subset_name='cluewsc2020')
+        dataset = {
+            k: v.to_hf_dataset().map(add_sentence2)
+            for k, v in dataset.items()
+        }
+
+        kwargs = dict(
+            model='damo/nlp_structbert_backbone_base_std',
+            train_dataset=dataset['train'],
+            eval_dataset=dataset['validation'],
+            work_dir=self.tmp_dir,
+            cfg_modify_fn=cfg_modify_fn)
+
+        os.environ['LOCAL_RANK'] = '0'
+        trainer: NlpEpochBasedTrainer = build_trainer(
+            name=Trainers.nlp_base_trainer, default_args=kwargs)
+
+        class CalculateFisherHook(Hook):
+
+            @staticmethod
+            def forward_step(model, inputs):
+                inputs = to_device(inputs, trainer.device)
+                trainer.train_step(model, inputs)
+                return trainer.train_outputs['loss']
+
+            def before_run(self, trainer: NlpEpochBasedTrainer):
+                v = calculate_fisher(trainer.model, trainer.train_dataloader,
+                                     self.forward_step, 0.2)
+                trainer.optimizer.set_gradient_mask(v)
+
+        if child_tuning_type == 'ChildTuning-D':
+            trainer.register_hook(CalculateFisherHook())
+
+        trainer.train()
+
 
 if __name__ == '__main__':
     unittest.main()
@@ -47,6 +47,11 @@ class TestFinetuneTokenClassification(unittest.TestCase):
 
     @unittest.skip
    def test_word_segmentation(self):
+        """This unittest is used to reproduce the icwb2:pku dataset + structbert model training results.
+
+        Users can train a custom dataset by modifying this piece of code and commenting out the @unittest.skip.
+        """
+
        os.system(
            f'curl http://sighan.cs.uchicago.edu/bakeoff2005/data/icwb2-data.zip > {self.tmp_dir}/icwb2-data.zip'
        )
@@ -114,7 +119,7 @@ class TestFinetuneTokenClassification(unittest.TestCase):
             return cfg
 
         self.finetune(
-            'damo/nlp_structbert_backbone_tiny_std',
+            'damo/nlp_structbert_backbone_base_std',
             train_dataset,
             dev_dataset,
             cfg_modify_fn=cfg_modify_fn)