1. Add the Child-Tuning optimizer (ChildTuningAdamW) and a unit test for it.
2. Fix a training bug that could interrupt training after cross-evaluation.
3. Move the model parameters from cfg to default_args in build_optimizer, so they are not written into the configuration saved by save_pretrained.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9891963
master
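For background (not part of this diff): Child-Tuning (Xu et al., EMNLP 2021) updates only a "child network" of the backbone by masking gradients before the optimizer step. ChildTuning-F draws a fresh Bernoulli(reserve_p) mask at every step, while ChildTuning-D derives a fixed mask from Fisher information estimated on the task data. A minimal sketch of the task-free variant, for illustration only:

import torch


def child_tuning_f_mask(model: torch.nn.Module, reserve_p: float = 0.2):
    # Keep a Bernoulli(reserve_p) subset of every gradient and rescale by
    # 1 / reserve_p so the expected gradient magnitude is unchanged.
    for p in model.parameters():
        if p.grad is None:
            continue
        mask = torch.bernoulli(torch.full_like(p.grad, reserve_p))
        p.grad.mul_(mask).div_(reserve_p)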
@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .builder import OPTIMIZERS, build_optimizer
+from .child_tuning_adamw_optimizer import ChildTuningAdamW

-__all__ = ['OPTIMIZERS', 'build_optimizer']
+__all__ = ['OPTIMIZERS', 'build_optimizer', 'ChildTuningAdamW']
@@ -20,7 +20,10 @@ def build_optimizer(model: torch.nn.Module,
    """
    if hasattr(model, 'module'):
        model = model.module
-    cfg.params = model.parameters()
+    if default_args is None:
+        default_args = {}
+    default_args['params'] = model.parameters()
    return build_from_cfg(
        cfg, OPTIMIZERS, group_key=default_group, default_args=default_args)
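As a rough illustration of item 3 (a hedged sketch: the import path and the 'AdamW' registry key are inferred from this diff, everything else is invented), the optimizer cfg now carries only serializable fields, while the parameter iterator is injected at build time through default_args, so save_pretrained can write the configuration back out unchanged:

import torch
from modelscope.trainers.optimizer import build_optimizer

model = torch.nn.Linear(8, 2)
cfg = {'type': 'AdamW', 'lr': 1e-5}      # stays JSON/YAML friendly
optimizer = build_optimizer(model, cfg)  # params are added internally via default_args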
@@ -0,0 +1,188 @@
# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import types
from typing import Callable, Iterable, Tuple

import numpy as np
import torch
from torch.distributions.bernoulli import Bernoulli
from torch.optim import Optimizer

from modelscope.utils.logger import get_logger
from .builder import OPTIMIZERS, default_group

logger = get_logger(__name__)

__all__ = ['calculate_fisher', 'ChildTuningAdamW']

def calculate_fisher(model: torch.nn.Module,
                     data_loader,
                     forward_step,
                     reserve_p,
                     grad_clip=None):
    """Estimate a Fisher-information gradient mask for ChildTuning-D.

    Parameters whose accumulated squared gradient lies in the top
    ``reserve_p`` fraction (among parameters whose name contains 'layer')
    are kept, i.e. their mask value is True.
    """
    gradient_mask = dict()
    model.train()
    for name, params in model.named_parameters():
        if 'layer' in name:
            gradient_mask[params] = params.new_zeros(params.size())

    iters = len(data_loader)
    for inputs in data_loader:
        loss = forward_step(model, inputs)
        loss.backward()
        for name, params in model.named_parameters():
            if 'layer' in name:
                if grad_clip is not None:
                    torch.nn.utils.clip_grad_norm_(params, **grad_clip)
                gradient_mask[params] += (params.grad**2) / iters
        model.zero_grad()

    logger.info('Calculate Fisher Information...')

    # Numpy
    r = None
    for k, v in gradient_mask.items():
        v = v.view(-1).cpu().numpy()
        if r is None:
            r = v
        else:
            r = np.append(r, v)
    polar = np.percentile(r, (1 - reserve_p) * 100)
    for k in gradient_mask:
        gradient_mask[k] = gradient_mask[k] >= polar
    logger.info('Polar => {}'.format(polar))

    # TODO: pytorch: torch.kthvalue
    return gradient_mask

@OPTIMIZERS.register_module(
    group_key=default_group, module_name='ChildTuningAdamW')
class ChildTuningAdamW(Optimizer):

    def __init__(self,
                 params: Iterable[torch.nn.parameter.Parameter],
                 lr: float = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-6,
                 weight_decay: float = 0.0,
                 correct_bias: bool = True,
                 reserve_p=1.0,
                 mode=None):
        if lr < 0.0:
            raise ValueError(
                'Invalid learning rate: {} - should be >= 0.0'.format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                'Invalid beta parameter: {} - should be in [0.0, 1.0)'.format(
                    betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                'Invalid beta parameter: {} - should be in [0.0, 1.0)'.format(
                    betas[1]))
        if not 0.0 <= eps:
            raise ValueError(
                'Invalid epsilon value: {} - should be >= 0.0'.format(eps))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            correct_bias=correct_bias)
        super().__init__(params, defaults)

        self.gradient_mask = None
        self.reserve_p = reserve_p
        self.mode = mode

    def set_gradient_mask(self, gradient_mask):
        self.gradient_mask = gradient_mask
    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (:obj:`Callable`, `optional`): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        'Adam does not support sparse gradients, please consider SparseAdam instead'
                    )

                # ChildTuning code
                if self.mode is not None:
                    if self.mode == 'ChildTuning-D':
                        if p in self.gradient_mask:
                            grad *= self.gradient_mask[p]
                    else:
                        # ChildTuning-F
                        grad_mask = Bernoulli(
                            grad.new_full(
                                size=grad.size(), fill_value=self.reserve_p))
                        grad *= grad_mask.sample() / self.reserve_p

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                denom = exp_avg_sq.sqrt().add_(group['eps'])

                step_size = group['lr']
                if group['correct_bias']:  # No bias correction for Bert
                    bias_correction1 = 1.0 - beta1**state['step']
                    bias_correction2 = 1.0 - beta2**state['step']
                    step_size = step_size * math.sqrt(
                        bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                # Add weight decay at the end (fixed version)
                p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay'])

        return loss
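For reference, a hedged usage sketch of both modes. The import path and the calculate_fisher / set_gradient_mask signatures come from this diff; TinyModel, forward_step and the random data are invented for illustration (note that calculate_fisher only builds masks for parameters whose names contain 'layer'):

import torch
from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import (
    ChildTuningAdamW, calculate_fisher)


class TinyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # the attribute is named "layer" so calculate_fisher picks it up
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)


model = TinyModel()

# ChildTuning-F: a fresh Bernoulli(reserve_p) mask is drawn at every step.
optimizer = ChildTuningAdamW(
    model.parameters(), lr=1e-5, reserve_p=0.2, mode='ChildTuning-F')


# ChildTuning-D: derive a fixed mask from Fisher information first.
def forward_step(model, inputs):
    x, y = inputs
    return torch.nn.functional.mse_loss(model(x), y)


data_loader = [(torch.randn(8, 4), torch.randn(8, 2)) for _ in range(4)]
mask = calculate_fisher(model, data_loader, forward_step, reserve_p=0.2)
optimizer_d = ChildTuningAdamW(
    model.parameters(), lr=1e-5, reserve_p=0.2, mode='ChildTuning-D')
optimizer_d.set_gradient_mask(mask)

# one training step with the task-free optimizer
loss = forward_step(model, data_loader[0])
loss.backward()
optimizer.step()
optimizer.zero_grad()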
@@ -800,6 +800,7 @@ class EpochBasedTrainer(BaseTrainer):
                self.invoke_hook(TrainerStages.after_train_iter)
                del self.data_batch
                self._iter += 1
+                self._mode = ModeKeys.TRAIN
                if i + 1 >= self.iters_per_epoch:
                    break
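The single added line is fix 2. The trainer internals are not shown in this diff, so the following toy sketch (hypothetical names, not ModelScope code) only illustrates the assumed failure mode: a cross-evaluation fired mid-epoch flips the trainer into EVAL mode, and without restoring TRAIN mode the next iteration starts in the wrong mode and training is interrupted.

class ToyTrainer:
    """Toy stand-in for an epoch-based trainer (illustrative only)."""

    def __init__(self):
        self._mode = 'train'

    def evaluate(self):
        # a cross-evaluation (e.g. an evaluation hook with by_epoch=False)
        # switches the trainer into eval mode
        self._mode = 'eval'

    def train_loop(self, num_iters, eval_interval=30):
        for i in range(num_iters):
            assert self._mode == 'train', 'iteration started in eval mode'
            # ... forward/backward for one batch ...
            if (i + 1) % eval_interval == 0:
                self.evaluate()
            self._mode = 'train'  # the fix: restore TRAIN before the next iteration


ToyTrainer().train_loop(60)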
@@ -6,9 +6,15 @@ import unittest
from modelscope.metainfo import Preprocessors, Trainers
from modelscope.models import Model
from modelscope.msdatasets import MsDataset
from modelscope.pipelines import pipeline
from modelscope.trainers import build_trainer
from modelscope.trainers.hooks import Hook
from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer
from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \
    calculate_fisher
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.data_utils import to_device


class TestFinetuneSequenceClassification(unittest.TestCase):
@@ -69,6 +75,10 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
    @unittest.skip
    def test_finetune_afqmc(self):
        """This unittest is used to reproduce the clue:afqmc dataset + structbert model training results.

        Users can train on a custom dataset by modifying this piece of code and commenting out the
        @unittest.skip decorator.
        """

        def cfg_modify_fn(cfg):
            cfg.task = Tasks.sentence_similarity
@@ -114,7 +124,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
        dc.local_files_only = True
        dataset = load_dataset('clue', 'afqmc', download_config=dc)
        self.finetune(
-            model_id='damo/nlp_structbert_backbone_tiny_std',
+            model_id='damo/nlp_structbert_backbone_base_std',
            train_dataset=dataset['train'],
            eval_dataset=dataset['validation'],
            cfg_modify_fn=cfg_modify_fn)
@@ -124,6 +134,10 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
    @unittest.skip
    def test_finetune_tnews(self):
        """This unittest is used to reproduce the clue:tnews dataset + structbert model training results.

        Users can train on a custom dataset by modifying this piece of code and commenting out the
        @unittest.skip decorator.
        """

        def cfg_modify_fn(cfg):
            # TODO no proper task for tnews
@@ -175,13 +189,21 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
        dataset = load_dataset('clue', 'tnews', download_config=dc)
        self.finetune(
-            model_id='damo/nlp_structbert_backbone_tiny_std',
+            model_id='damo/nlp_structbert_backbone_base_std',
            train_dataset=dataset['train'],
            eval_dataset=dataset['validation'],
            cfg_modify_fn=cfg_modify_fn)

    @unittest.skip
    def test_veco_xnli(self):
        """This unittest is used to reproduce the xnli dataset + veco model training results.

        Here we follow the training scenario listed in the AliceMind open source project
        (https://github.com/alibaba/AliceMind/tree/main/VECO) and train on the English language subset.
        Users can train on a custom dataset by modifying this piece of code and commenting out the
        @unittest.skip decorator.
        """
        from datasets import load_dataset
        langs = ['en']
        langs_eval = ['en']
@@ -267,6 +289,112 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
            name=Trainers.nlp_veco_trainer,
            cfg_modify_fn=cfg_modify_fn)

    @unittest.skip
    def test_finetune_cluewsc(self):
        """This unittest is used to reproduce the clue:wsc dataset + structbert model training results.

        A runnable example of child-tuning is also shown here.
        Users can train on a custom dataset by modifying this piece of code and commenting out the
        @unittest.skip decorator.
        """
        child_tuning_type = 'ChildTuning-F'
        mode = {}
        if child_tuning_type is not None:
            mode = {'mode': child_tuning_type, 'reserve_p': 0.2}

        def cfg_modify_fn(cfg):
            cfg.task = 'nli'
            cfg['preprocessor'] = {'type': 'nli-tokenizer'}
            cfg['dataset'] = {
                'train': {
                    'labels': ['0', '1'],
                    'first_sequence': 'text',
                    'second_sequence': 'text2',
                    'label': 'label',
                }
            }
            cfg.train.dataloader.batch_size_per_gpu = 16
            cfg.train.max_epochs = 30
            cfg.train.optimizer = {
                'type':
                'AdamW' if child_tuning_type is None else 'ChildTuningAdamW',
                'lr': 1e-5,
                'options': {},
                **mode,
            }
            cfg.train.lr_scheduler = {
                'type': 'LinearLR',
                'start_factor': 1.0,
                'end_factor': 0.0,
                'total_iters':
                int(len(dataset['train'])
                    / cfg.train.dataloader.batch_size_per_gpu)
                * cfg.train.max_epochs,
                'options': {
                    'by_epoch': False
                }
            }
            cfg.train.hooks = [{
                'type': 'CheckpointHook',
                'interval': 1
            }, {
                'type': 'TextLoggerHook',
                'interval': 1
            }, {
                'type': 'IterTimerHook'
            }, {
                'type': 'EvaluationHook',
                'by_epoch': False,
                'interval': 30
            }]
            return cfg

        def add_sentence2(features):
            # '指代' means "refers to": the hypothesis becomes
            # "<span2_text> refers to <span1_text>" for the NLI formulation.
            return {
                'text2':
                features['target']['span2_text'] + '指代'
                + features['target']['span1_text']
            }

        dataset = MsDataset.load('clue', subset_name='cluewsc2020')
        dataset = {
            k: v.to_hf_dataset().map(add_sentence2)
            for k, v in dataset.items()
        }

        kwargs = dict(
            model='damo/nlp_structbert_backbone_base_std',
            train_dataset=dataset['train'],
            eval_dataset=dataset['validation'],
            work_dir=self.tmp_dir,
            cfg_modify_fn=cfg_modify_fn)

        os.environ['LOCAL_RANK'] = '0'
        trainer: NlpEpochBasedTrainer = build_trainer(
            name=Trainers.nlp_base_trainer, default_args=kwargs)

        class CalculateFisherHook(Hook):

            @staticmethod
            def forward_step(model, inputs):
                inputs = to_device(inputs, trainer.device)
                trainer.train_step(model, inputs)
                return trainer.train_outputs['loss']

            def before_run(self, trainer: NlpEpochBasedTrainer):
                v = calculate_fisher(trainer.model, trainer.train_dataloader,
                                     self.forward_step, 0.2)
                trainer.optimizer.set_gradient_mask(v)

        if child_tuning_type == 'ChildTuning-D':
            trainer.register_hook(CalculateFisherHook())

        trainer.train()


if __name__ == '__main__':
    unittest.main()
@@ -47,6 +47,11 @@ class TestFinetuneTokenClassification(unittest.TestCase):
    @unittest.skip
    def test_word_segmentation(self):
        """This unittest is used to reproduce the icwb2:pku dataset + structbert model training results.

        Users can train on a custom dataset by modifying this piece of code and commenting out the
        @unittest.skip decorator.
        """
        os.system(
            f'curl http://sighan.cs.uchicago.edu/bakeoff2005/data/icwb2-data.zip > {self.tmp_dir}/icwb2-data.zip'
        )
@@ -114,7 +119,7 @@ class TestFinetuneTokenClassification(unittest.TestCase):
            return cfg

        self.finetune(
-            'damo/nlp_structbert_backbone_tiny_std',
+            'damo/nlp_structbert_backbone_base_std',
            train_dataset,
            dev_dataset,
            cfg_modify_fn=cfg_modify_fn)