From 88d0804dcd6cdc430ea96ede82796a48a92596c2 Mon Sep 17 00:00:00 2001
From: "yuze.zyz" 
Date: Tue, 30 Aug 2022 10:52:07 +0800
Subject: [PATCH] [to #42322933] Add S4: child-tuning

1. Add the child-tuning optimizer (ChildTuningAdamW) and its unit tests.
2. Fix a training bug that could break training after cross-evaluation:
   the trainer mode is now reset to train after every training iteration.
3. Move model.params from the cfg to default_args in build_optimizer, so
   that the params are no longer saved by save_pretrained.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9891963
---
 modelscope/trainers/optimizer/__init__.py | 3 +-
 modelscope/trainers/optimizer/builder.py | 5 +-
 .../optimizer/child_tuning_adamw_optimizer.py | 188 ++++++++++++++++++
 modelscope/trainers/trainer.py | 1 +
 .../test_finetune_sequence_classification.py | 132 +++++++++++-
 .../test_finetune_token_classificatin.py | 7 +-
 6 files changed, 331 insertions(+), 5 deletions(-)
 create mode 100644 modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py

diff --git a/modelscope/trainers/optimizer/__init__.py b/modelscope/trainers/optimizer/__init__.py
index 884f3043..9962c2c2 100644
--- a/modelscope/trainers/optimizer/__init__.py
+++ b/modelscope/trainers/optimizer/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from .builder import OPTIMIZERS, build_optimizer
+from .child_tuning_adamw_optimizer import ChildTuningAdamW
 
-__all__ = ['OPTIMIZERS', 'build_optimizer']
+__all__ = ['OPTIMIZERS', 'build_optimizer', 'ChildTuningAdamW']
diff --git a/modelscope/trainers/optimizer/builder.py b/modelscope/trainers/optimizer/builder.py
index 4d772dd9..f43768d6 100644
--- a/modelscope/trainers/optimizer/builder.py
+++ b/modelscope/trainers/optimizer/builder.py
@@ -20,7 +20,10 @@ def build_optimizer(model: torch.nn.Module,
     """
     if hasattr(model, 'module'):
         model = model.module
-    cfg.params = model.parameters()
+
+    if default_args is None:
+        default_args = {}
+    default_args['params'] = model.parameters()
 
     return build_from_cfg(
         cfg, OPTIMIZERS, group_key=default_group, default_args=default_args)
diff --git a/modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py b/modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py
new file mode 100644
index 00000000..d004071f
--- /dev/null
+++ b/modelscope/trainers/optimizer/child_tuning_adamw_optimizer.py
@@ -0,0 +1,188 @@
+# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
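+
+"""An AdamW variant implementing the Child-Tuning fine-tuning strategy.
+
+Two modes are supported:
+  * ChildTuning-F: at every step a Bernoulli(reserve_p) mask is sampled and
+    applied to the gradients, so only a random subset of entries is updated.
+  * ChildTuning-D: a fixed gradient mask, pre-computed from the Fisher
+    information via ``calculate_fisher``, is applied at every step.
+"""
+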
+import math +import types +from typing import Callable, Iterable, Tuple + +import numpy as np +import torch +from torch.distributions.bernoulli import Bernoulli +from torch.optim import Optimizer + +from modelscope.utils.logger import get_logger +from .builder import OPTIMIZERS, default_group + +logger = get_logger(__name__) + +__all__ = ['calculate_fisher', 'ChildTuningAdamW'] + + +def calculate_fisher(model: torch.nn.Module, + data_loader, + forward_step, + reserve_p, + grad_clip=None): + + gradient_mask = dict() + model.train() + for name, params in model.named_parameters(): + if 'layer' in name: + gradient_mask[params] = params.new_zeros(params.size()) + + iters = len(data_loader) + for inputs in data_loader: + loss = forward_step(model, inputs) + loss.backward() + for name, params in model.named_parameters(): + if 'layer' in name: + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_(params, **grad_clip) + gradient_mask[params] += (params.grad**2) / iters + model.zero_grad() + + logger.info('Calculate Fisher Information...') + + # Numpy + r = None + for k, v in gradient_mask.items(): + v = v.view(-1).cpu().numpy() + if r is None: + r = v + else: + r = np.append(r, v) + polar = np.percentile(r, (1 - reserve_p) * 100) + for k in gradient_mask: + gradient_mask[k] = gradient_mask[k] >= polar + print('Polar => {}'.format(polar)) + + # TODO: pytorch: torch.kthvalue + + return gradient_mask + + +@OPTIMIZERS.register_module( + group_key=default_group, module_name='ChildTuningAdamW') +class ChildTuningAdamW(Optimizer): + + def __init__(self, + params: Iterable[torch.nn.parameter.Parameter], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + correct_bias: bool = True, + reserve_p=1.0, + mode=None): + if lr < 0.0: + raise ValueError( + 'Invalid learning rate: {} - should be >= 0.0'.format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError( + 'Invalid beta parameter: {} - should be in [0.0, 1.0['.format( + betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError( + 'Invalid beta parameter: {} - should be in [0.0, 1.0['.format( + betas[1])) + if not 0.0 <= eps: + raise ValueError( + 'Invalid epsilon value: {} - should be >= 0.0'.format(eps)) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + correct_bias=correct_bias) + super().__init__(params, defaults) + + self.gradient_mask = None + self.reserve_p = reserve_p + self.mode = mode + + def set_gradient_mask(self, gradient_mask): + self.gradient_mask = gradient_mask + + def step(self, closure: Callable = None): + """ + Performs a single optimization step. + Arguments: + closure (:obj:`Callable`, `optional`): A closure that reevaluates the model and returns the loss. 
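+
+        Example (a minimal sketch; ``model`` and ``batch`` are placeholders
+        for the caller's model and input dict)::
+
+            optimizer = ChildTuningAdamW(
+                model.parameters(), lr=1e-5, mode='ChildTuning-F', reserve_p=0.2)
+            loss = model(**batch)['loss']
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()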
+ """ + loss = None + if closure is not None: + loss = closure() + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + 'Adam does not support sparse gradients, please consider SparseAdam instead' + ) + + # ChildTuning code + if self.mode is not None: + if self.mode == 'ChildTuning-D': + if p in self.gradient_mask: + grad *= self.gradient_mask[p] + else: + # ChildTuning-F + grad_mask = Bernoulli( + grad.new_full( + size=grad.size(), fill_value=self.reserve_p)) + grad *= grad_mask.sample() / self.reserve_p + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + denom = exp_avg_sq.sqrt().add_(group['eps']) + + step_size = group['lr'] + if group['correct_bias']: # No bias correction for Bert + bias_correction1 = 1.0 - beta1**state['step'] + bias_correction2 = 1.0 - beta2**state['step'] + step_size = step_size * math.sqrt( + bias_correction2) / bias_correction1 + + p.data.addcdiv_(exp_avg, denom, value=-step_size) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. 
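+                # (This is the decoupled weight decay used by AdamW: the decay
+                # is applied directly to p.data below and never enters the
+                # exp_avg / exp_avg_sq moment estimates.)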
+ # Add weight decay at the end (fixed version) + p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay']) + + return loss diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index dc8c5c09..290478cb 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -800,6 +800,7 @@ class EpochBasedTrainer(BaseTrainer): self.invoke_hook(TrainerStages.after_train_iter) del self.data_batch self._iter += 1 + self._mode = ModeKeys.TRAIN if i + 1 >= self.iters_per_epoch: break diff --git a/tests/trainers/test_finetune_sequence_classification.py b/tests/trainers/test_finetune_sequence_classification.py index 847e47ef..24f1a2fd 100644 --- a/tests/trainers/test_finetune_sequence_classification.py +++ b/tests/trainers/test_finetune_sequence_classification.py @@ -6,9 +6,15 @@ import unittest from modelscope.metainfo import Preprocessors, Trainers from modelscope.models import Model +from modelscope.msdatasets import MsDataset from modelscope.pipelines import pipeline from modelscope.trainers import build_trainer +from modelscope.trainers.hooks import Hook +from modelscope.trainers.nlp_trainer import NlpEpochBasedTrainer +from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \ + calculate_fisher from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.data_utils import to_device class TestFinetuneSequenceClassification(unittest.TestCase): @@ -69,6 +75,10 @@ class TestFinetuneSequenceClassification(unittest.TestCase): @unittest.skip def test_finetune_afqmc(self): + """This unittest is used to reproduce the clue:afqmc dataset + structbert model training results. + + User can train a custom dataset by modifying this piece of code and comment the @unittest.skip. + """ def cfg_modify_fn(cfg): cfg.task = Tasks.sentence_similarity @@ -114,7 +124,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): dc.local_files_only = True dataset = load_dataset('clue', 'afqmc', download_config=dc) self.finetune( - model_id='damo/nlp_structbert_backbone_tiny_std', + model_id='damo/nlp_structbert_backbone_base_std', train_dataset=dataset['train'], eval_dataset=dataset['validation'], cfg_modify_fn=cfg_modify_fn) @@ -124,6 +134,10 @@ class TestFinetuneSequenceClassification(unittest.TestCase): @unittest.skip def test_finetune_tnews(self): + """This unittest is used to reproduce the clue:tnews dataset + structbert model training results. + + User can train a custom dataset by modifying this piece of code and comment the @unittest.skip. + """ def cfg_modify_fn(cfg): # TODO no proper task for tnews @@ -175,13 +189,21 @@ class TestFinetuneSequenceClassification(unittest.TestCase): dataset = load_dataset('clue', 'tnews', download_config=dc) self.finetune( - model_id='damo/nlp_structbert_backbone_tiny_std', + model_id='damo/nlp_structbert_backbone_base_std', train_dataset=dataset['train'], eval_dataset=dataset['validation'], cfg_modify_fn=cfg_modify_fn) @unittest.skip def test_veco_xnli(self): + """This unittest is used to reproduce the xnli dataset + veco model training results. + + Here we follow the training scenario listed in the Alicemind open source project: + https://github.com/alibaba/AliceMind/tree/main/VECO + by training the english language subset. + User can train a custom dataset by modifying this piece of code and comment the @unittest.skip. 
+ """ + from datasets import load_dataset langs = ['en'] langs_eval = ['en'] @@ -267,6 +289,112 @@ class TestFinetuneSequenceClassification(unittest.TestCase): name=Trainers.nlp_veco_trainer, cfg_modify_fn=cfg_modify_fn) + @unittest.skip + def test_finetune_cluewsc(self): + """This unittest is used to reproduce the clue:wsc dataset + structbert model training results. + + A runnable sample of child-tuning is also showed here. + + User can train a custom dataset by modifying this piece of code and comment the @unittest.skip. + """ + + child_tuning_type = 'ChildTuning-F' + mode = {} + if child_tuning_type is not None: + mode = {'mode': child_tuning_type, 'reserve_p': 0.2} + + def cfg_modify_fn(cfg): + cfg.task = 'nli' + cfg['preprocessor'] = {'type': 'nli-tokenizer'} + cfg['dataset'] = { + 'train': { + 'labels': ['0', '1'], + 'first_sequence': 'text', + 'second_sequence': 'text2', + 'label': 'label', + } + } + cfg.train.dataloader.batch_size_per_gpu = 16 + cfg.train.max_epochs = 30 + cfg.train.optimizer = { + 'type': + 'AdamW' if child_tuning_type is None else 'ChildTuningAdamW', + 'lr': 1e-5, + 'options': {}, + **mode, + } + cfg.train.lr_scheduler = { + 'type': + 'LinearLR', + 'start_factor': + 1.0, + 'end_factor': + 0.0, + 'total_iters': + int( + len(dataset['train']) + / cfg.train.dataloader.batch_size_per_gpu) + * cfg.train.max_epochs, + 'options': { + 'by_epoch': False + } + } + cfg.train.hooks = [{ + 'type': 'CheckpointHook', + 'interval': 1 + }, { + 'type': 'TextLoggerHook', + 'interval': 1 + }, { + 'type': 'IterTimerHook' + }, { + 'type': 'EvaluationHook', + 'by_epoch': False, + 'interval': 30 + }] + return cfg + + def add_sentence2(features): + return { + 'text2': + features['target']['span2_text'] + '指代' + + features['target']['span1_text'] + } + + dataset = MsDataset.load('clue', subset_name='cluewsc2020') + dataset = { + k: v.to_hf_dataset().map(add_sentence2) + for k, v in dataset.items() + } + + kwargs = dict( + model='damo/nlp_structbert_backbone_base_std', + train_dataset=dataset['train'], + eval_dataset=dataset['validation'], + work_dir=self.tmp_dir, + cfg_modify_fn=cfg_modify_fn) + + os.environ['LOCAL_RANK'] = '0' + trainer: NlpEpochBasedTrainer = build_trainer( + name=Trainers.nlp_base_trainer, default_args=kwargs) + + class CalculateFisherHook(Hook): + + @staticmethod + def forward_step(model, inputs): + inputs = to_device(inputs, trainer.device) + trainer.train_step(model, inputs) + return trainer.train_outputs['loss'] + + def before_run(self, trainer: NlpEpochBasedTrainer): + v = calculate_fisher(trainer.model, trainer.train_dataloader, + self.forward_step, 0.2) + trainer.optimizer.set_gradient_mask(v) + + if child_tuning_type == 'ChildTuning-D': + trainer.register_hook(CalculateFisherHook()) + trainer.train() + if __name__ == '__main__': unittest.main() diff --git a/tests/trainers/test_finetune_token_classificatin.py b/tests/trainers/test_finetune_token_classificatin.py index 520d1a3c..c34410be 100644 --- a/tests/trainers/test_finetune_token_classificatin.py +++ b/tests/trainers/test_finetune_token_classificatin.py @@ -47,6 +47,11 @@ class TestFinetuneTokenClassification(unittest.TestCase): @unittest.skip def test_word_segmentation(self): + """This unittest is used to reproduce the icwb2:pku dataset + structbert model training results. + + User can train a custom dataset by modifying this piece of code and comment the @unittest.skip. 
+ """ + os.system( f'curl http://sighan.cs.uchicago.edu/bakeoff2005/data/icwb2-data.zip > {self.tmp_dir}/icwb2-data.zip' ) @@ -114,7 +119,7 @@ class TestFinetuneTokenClassification(unittest.TestCase): return cfg self.finetune( - 'damo/nlp_structbert_backbone_tiny_std', + 'damo/nlp_structbert_backbone_base_std', train_dataset, dev_dataset, cfg_modify_fn=cfg_modify_fn)
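
Usage sketch (ChildTuning-D, calling the optimizer directly instead of going through
the trainer config as in test_finetune_cluewsc above). The names `model`, `data_loader`
and `forward_step(model, inputs) -> loss` are placeholders for user-supplied objects:

    from modelscope.trainers.optimizer import ChildTuningAdamW
    from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \
        calculate_fisher

    reserve_p = 0.2
    optimizer = ChildTuningAdamW(
        model.parameters(), lr=1e-5, mode='ChildTuning-D', reserve_p=reserve_p)

    # Estimate the Fisher information of the 'layer' parameters on the training
    # data and keep the top `reserve_p` fraction of gradient entries as the mask.
    gradient_mask = calculate_fisher(model, data_loader, forward_step, reserve_p)
    optimizer.set_gradient_mask(gradient_mask)

    for inputs in data_loader:
        loss = forward_step(model, inputs)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()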