From 09d2296f36f1a301dc5e144e00a692dfce2675ee Mon Sep 17 00:00:00 2001 From: "laiyin.lyc" Date: Tue, 11 Oct 2022 16:05:20 +0800 Subject: [PATCH] [to #44847108] add sparsity hook (pst algorithm) Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10198228 * [to #44847108] add sparsity hook (pst algorithm) --- modelscope/metainfo.py | 3 + modelscope/trainers/hooks/__init__.py | 4 +- .../trainers/hooks/compression/__init__.py | 24 ++ .../hooks/compression/sparsity_hook.py | 131 +++++++++++ .../trainers/hooks/compression/utils.py | 208 ++++++++++++++++++ tests/trainers/hooks/compression/__init__.py | 0 .../hooks/compression/test_sparsity_hook.py | 113 ++++++++++ 7 files changed, 482 insertions(+), 1 deletion(-) create mode 100644 modelscope/trainers/hooks/compression/__init__.py create mode 100644 modelscope/trainers/hooks/compression/sparsity_hook.py create mode 100644 modelscope/trainers/hooks/compression/utils.py create mode 100644 tests/trainers/hooks/compression/__init__.py create mode 100644 tests/trainers/hooks/compression/test_sparsity_hook.py diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 1b8c4720..77627abc 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -404,6 +404,9 @@ class Hooks(object): IterTimerHook = 'IterTimerHook' EvaluationHook = 'EvaluationHook' + # Compression + SparsityHook = 'SparsityHook' + class LR_Schedulers(object): """learning rate scheduler is defined here diff --git a/modelscope/trainers/hooks/__init__.py b/modelscope/trainers/hooks/__init__.py index f133041b..a2e0cf4b 100644 --- a/modelscope/trainers/hooks/__init__.py +++ b/modelscope/trainers/hooks/__init__.py @@ -6,10 +6,11 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .builder import HOOKS, build_hook from .checkpoint_hook import BestCkptSaverHook, CheckpointHook + from .compression import SparsityHook from .evaluation_hook import EvaluationHook from .hook import Hook from .iter_timer_hook import IterTimerHook - from .logger import TextLoggerHook, TensorboardHook + from .logger import TensorboardHook, TextLoggerHook from .lr_scheduler_hook import LrSchedulerHook from .optimizer import (ApexAMPOptimizerHook, NoneOptimizerHook, OptimizerHook, TorchAMPOptimizerHook) @@ -19,6 +20,7 @@ else: _import_structure = { 'builder': ['HOOKS', 'build_hook'], 'checkpoint_hook': ['BestCkptSaverHook', 'CheckpointHook'], + 'compression': ['SparsityHook'], 'evaluation_hook': ['EvaluationHook'], 'hook': ['Hook'], 'iter_timer_hook': ['IterTimerHook'], diff --git a/modelscope/trainers/hooks/compression/__init__.py b/modelscope/trainers/hooks/compression/__init__.py new file mode 100644 index 00000000..f755b2ca --- /dev/null +++ b/modelscope/trainers/hooks/compression/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .sparsity_hook import SparsityHook
+    from .utils import SparseLinear, convert_sparse_network
+
+else:
+    _import_structure = {
+        'sparsity_hook': ['SparsityHook'],
+        'utils': ['convert_sparse_network', 'SparseLinear'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
diff --git a/modelscope/trainers/hooks/compression/sparsity_hook.py b/modelscope/trainers/hooks/compression/sparsity_hook.py
new file mode 100644
index 00000000..993488d8
--- /dev/null
+++ b/modelscope/trainers/hooks/compression/sparsity_hook.py
@@ -0,0 +1,131 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+
+from modelscope.metainfo import Hooks
+from modelscope.trainers.hooks.builder import HOOKS
+from modelscope.trainers.hooks.hook import Hook
+from modelscope.trainers.hooks.priority import Priority
+from modelscope.utils.checkpoint import save_checkpoint
+from modelscope.utils.torch_utils import is_master
+
+
+@HOOKS.register_module(module_name=Hooks.SparsityHook)
+class SparsityHook(Hook):
+
+    PRIORITY = Priority.HIGHEST
+
+    def __init__(self, pruning_method, config=None, save_dir=None):
+        config = config or {}
+        self.pruning_method = pruning_method
+        self.save_dir = save_dir
+
+        self.compress_module = config.get('compress_module', [])
+        self.weight_rank = config.get('weight_rank', 8)
+        self.weight_beta = config.get('weight_beta', 1)
+        self.mask_rank = config.get('mask_rank', 8)
+        self.mask_alpha1 = config.get('mask_alpha1', 1)
+        self.mask_alpha2 = config.get('mask_alpha2', 1)
+
+        self.step = 0
+        self.total_step = 0
+        self.frequency = config.get('frequency', 1)
+        self.initial_warmup = config.get('initial_warmup', 0.1)
+        self.final_warmup = config.get('final_warmup', 0.3)
+        self.initial_sparsity = config.get('initial_sparsity', 0.0)
+        self.final_sparsity = config.get('final_sparsity', 0.0)
+
+    def before_run(self, trainer):
+        import torch
+
+        from .utils import SparseLinear, convert_sparse_network
+
+        if self.save_dir is None:
+            self.save_dir = trainer.work_dir
+
+        if len(self.compress_module) == 0:
+            convert_sparse_network(
+                trainer.model,
+                pruning_method=self.pruning_method,
+                weight_rank=self.weight_rank,
+                weight_beta=self.weight_beta,
+                mask_rank=self.mask_rank,
+                mask_alpha1=self.mask_alpha1,
+                mask_alpha2=self.mask_alpha2,
+                logger=trainer.logger,
+            )
+        else:
+            for cm in self.compress_module:
+                for name, module in trainer.model.named_modules():
+                    if name != cm:
+                        continue
+                    convert_sparse_network(
+                        module,
+                        pruning_method=self.pruning_method,
+                        weight_rank=self.weight_rank,
+                        weight_beta=self.weight_beta,
+                        mask_rank=self.mask_rank,
+                        mask_alpha1=self.mask_alpha1,
+                        mask_alpha2=self.mask_alpha2,
+                        logger=trainer.logger,
+                    )
+
+        for i in range(len(trainer.optimizer.param_groups)):
+            new_train_params = []
+            for param in trainer.optimizer.param_groups[i]['params']:
+                is_find = False
+                for name, module in trainer.model.named_modules():
+                    if isinstance(module, SparseLinear):
+                        if torch.equal(param.half(),
+                                       module.weight.data.half()):
+                            is_find = True
+                            break
+
+                if not is_find:
+                    new_train_params.append(param)
+
+            trainer.optimizer.param_groups[i]['params'] = new_train_params
+
+        new_params = []
+        for name, module in trainer.model.named_modules():
+            if isinstance(module, SparseLinear):
+                new_params.extend(
+                    [p for p in module.parameters() if p.requires_grad])
+
+        trainer.optimizer.add_param_group({'params': new_params})
+
+        self.total_step = trainer.iters_per_epoch * trainer._max_epochs
+
+    def before_train_iter(self, trainer):
+        from .utils import schedule_sparsity_ratio, update_network_sparsity
+
+        cur_sparsity = schedule_sparsity_ratio(
+            self.step,
+            self.total_step,
+            self.frequency,
+            self.initial_warmup,
+            self.final_warmup,
+            self.initial_sparsity,
+            self.final_sparsity,
+        )
+
+        update_network_sparsity(trainer.model, cur_sparsity)
+
+        if is_master():
+            trainer.logger.info(
+                f'Step[{self.step}/{self.total_step}] current sparsity ratio = {cur_sparsity}'
+            )
+
+        self.step += 1
+
+    def after_run(self, trainer):
+        from .utils import generate_sparse_model
+
+        generate_sparse_model(trainer.model, logger=trainer.logger)
+
+        self._save_checkpoint(trainer)
+
+    def _save_checkpoint(self, trainer):
+        if is_master():
+            trainer.logger.info('Saving the final compressed checkpoint')
+            cur_save_name = os.path.join(self.save_dir, 'compress_model.pth')
+            save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
diff --git a/modelscope/trainers/hooks/compression/utils.py b/modelscope/trainers/hooks/compression/utils.py
new file mode 100644
index 00000000..59418201
--- /dev/null
+++ b/modelscope/trainers/hooks/compression/utils.py
@@ -0,0 +1,208 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import torch
+import torch.nn as nn
+
+from modelscope.utils.torch_utils import is_master
+
+
+class SparseBinarizer(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, mask_scores, sparsity):
+        num_prune = int(mask_scores.numel() * sparsity)
+        prune_indices = torch.argsort(mask_scores.reshape(-1))[:num_prune]
+        mask = mask_scores.clone().fill_(1)
+        mask.reshape(-1)[prune_indices] = 0.0
+        return mask
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None
+
+
+class SparseLinear(nn.Module):
+    """
+    Fully connected layer with an on-the-fly adaptive mask.
+    """
+
+    def __init__(
+        self,
+        module,
+        pruning_method='pst',
+        weight_rank=8,
+        weight_beta=1.0,
+        mask_rank=8,
+        mask_alpha1=1.0,
+        mask_alpha2=1.0,
+    ):
+        super(SparseLinear, self).__init__()
+        self.module = module
+        out_features = self.module.weight.shape[0]
+        in_features = self.module.weight.shape[1]
+
+        self.weight = self.module.weight
+        self.module.weight = None
+        self.module._parameters.pop('weight')
+
+        self.pruning_method = pruning_method
+
+        self.cur_sparsity = 0.0
+
+        if self.pruning_method == 'pst':
+            self.weight_rank = weight_rank
+            self.weight_beta = weight_beta
+            self.mask_rank = mask_rank
+            self.mask_alpha1 = mask_alpha1
+            self.mask_alpha2 = mask_alpha2
+
+            # create trainable params
+            self.weight_U = nn.Parameter(
+                torch.randn(out_features, self.weight_rank).to(
+                    device=self.weight.device, dtype=self.weight.dtype))
+            self.weight_V = nn.Parameter(
+                torch.zeros(self.weight_rank, in_features).to(
+                    device=self.weight.device, dtype=self.weight.dtype))
+
+            self.mask_scores_A = nn.Parameter(
+                torch.randn(out_features, self.mask_rank).to(
+                    device=self.weight.device, dtype=self.weight.dtype))
+            self.mask_scores_B = nn.Parameter(
+                torch.zeros(self.mask_rank, in_features).to(
+                    device=self.weight.device, dtype=self.weight.dtype))
+            self.mask_scores_R = nn.Parameter(
+                torch.zeros(out_features).to(
+                    device=self.weight.device, dtype=self.weight.dtype))
+            self.mask_scores_C = nn.Parameter(
+                torch.zeros(in_features).to(
+                    device=self.weight.device, dtype=self.weight.dtype))
+
+            self.weight.requires_grad = False
+            if self.module.bias is not None:
+                self.module.bias.requires_grad = False
+
+    def forward(self, *inputs):
+        if self.pruning_method == 'pst':
+            weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V
+            mask_scores = (
+                weight.abs()
+                + self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B
+                + self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1)
+                                      + self.mask_scores_C.unsqueeze(0)))
+
+            mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity)
+            masked_weight = mask * weight
+
+            self.module.weight = masked_weight
+            return self.module(*inputs)
+        else:
+            return self.module(*inputs)
+
+    def convert(self):
+        if self.pruning_method == 'pst':
+            weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V
+            mask_scores = (
+                weight.abs()
+                + self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B
+                + self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1)
+                                      + self.mask_scores_C.unsqueeze(0)))
+
+            mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity)
+
+            masked_weight = mask * weight
+            self.module.weight = nn.Parameter(masked_weight.data)
+
+
+def _setattr(model, name, module):
+    name_list = name.split('.')
+    for n in name_list[:-1]:
+        model = getattr(model, n)
+    setattr(model, name_list[-1], module)
+
+
+def convert_sparse_network(
+    model,
+    pruning_method,
+    weight_rank,
+    weight_beta,
+    mask_rank,
+    mask_alpha1,
+    mask_alpha2,
+    logger=None,
+):
+    compress_module = [nn.Linear]
+    try:
+        from megatron import mpu
+        compress_module.extend(
+            [mpu.RowParallelLinear, mpu.ColumnParallelLinear])
+    except ImportError:
+        pass
+
+    for name, module in model.named_modules():
+        if type(module) in compress_module:
+            new_module = SparseLinear(
+                module,
+                pruning_method,
+                weight_rank,
+                weight_beta,
+                mask_rank,
+                mask_alpha1,
+                mask_alpha2,
+            )
+
+            # replace original module by new sparse module
+            _setattr(model, name, new_module)
+
+            if is_master():
+                if logger:
+                    logger.info(f'convert {name} to sparse module.')
+                else:
+                    print(f'convert {name} to sparse module.')
+
+
+def update_network_sparsity(model, sparsity):
+    for name, module in model.named_modules():
+        if isinstance(module, SparseLinear):
+            module.cur_sparsity = sparsity
+
+
+def schedule_sparsity_ratio(
+    step,
+    total_step,
+    frequency,
+    initial_warmup,
+    final_warmup,
+    initial_sparsity,
+    final_sparsity,
+):
+    if step <= initial_warmup * total_step:
+        sparsity = initial_sparsity
+    elif step > (total_step - final_warmup * total_step):
+        sparsity = final_sparsity
+    else:
+        spars_warmup_steps = initial_warmup * total_step
+        spars_schedule_steps = (final_warmup + initial_warmup) * total_step
+        step = (step - spars_warmup_steps) // frequency * frequency
+        mul_coeff = 1 - step / (total_step - spars_schedule_steps)
+        sparsity = final_sparsity + (initial_sparsity - final_sparsity) * (
+            mul_coeff**3)
+    return sparsity
+
+
+def generate_sparse_model(model, logger=None):
+    # generate sparse weight for saving
+    for name, module in model.named_modules():
+        if isinstance(module, SparseLinear):
+            module.convert()
+
+            _setattr(model, name, module.module)
+
+            if is_master():
+                sparsity_ratio = torch.mean(
+                    1.0 * (module.module.weight == 0)).item()
+                msg = (f'convert {name} weight to sparse weight, '
+                       f'sparsity ratio={sparsity_ratio}.')
+                if logger:
+                    logger.info(msg)
+                else:
+                    print(msg)
diff --git a/tests/trainers/hooks/compression/__init__.py b/tests/trainers/hooks/compression/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/trainers/hooks/compression/test_sparsity_hook.py b/tests/trainers/hooks/compression/test_sparsity_hook.py
new file mode 100644
index 00000000..4af4dcdb
--- /dev/null
+++ b/tests/trainers/hooks/compression/test_sparsity_hook.py
@@ -0,0 +1,113 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import shutil
+import tempfile
+import unittest
+
+import json
+import numpy as np
+import torch
+from torch import nn
+from torch.optim import SGD
+from torch.optim.lr_scheduler import MultiStepLR
+
+from modelscope.metainfo import Trainers
+from modelscope.models.base import Model
+from modelscope.trainers import build_trainer
+from modelscope.utils.constant import ModelFile, TrainerStages
+from modelscope.utils.test_utils import create_dummy_test_dataset
+
+dummy_dataset = create_dummy_test_dataset(
+    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10)
+
+
+class DummyModel(nn.Module, Model):
+
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(5, 10)
+        self.bn = nn.BatchNorm1d(10)
+
+    def forward(self, feat, labels):
+        x = self.linear(feat)
+
+        x = self.bn(x)
+        loss = torch.sum(x)
+        return dict(logits=x, loss=loss)
+
+
+class SparsityHookTest(unittest.TestCase):
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+
+    def tearDown(self):
+        super().tearDown()
+        shutil.rmtree(self.tmp_dir)
+
+    def test_sparsity_hook(self):
+        json_cfg = {
+            'task': 'image_classification',
+            'train': {
+                'work_dir':
+                self.tmp_dir,
+                'dataloader': {
+                    'batch_size_per_gpu': 2,
+                    'workers_per_gpu': 1
+                },
+                'hooks': [{
+                    'type': 'SparsityHook',
+                    'pruning_method': 'pst',
+                    'config': {
+                        'weight_rank': 1,
+                        'mask_rank': 1,
+                        'final_sparsity': 0.9,
+                        'frequency': 1,
+                    },
+                }],
+            },
+        }
+
+        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
+        with open(config_path, 'w') as f:
+            json.dump(json_cfg, f)
+
+        model = DummyModel()
+        optimizer = SGD(model.parameters(), lr=0.01)
+        lr_scheduler = MultiStepLR(optimizer, milestones=[2, 4])
+        trainer_name = Trainers.default
+        kwargs = dict(
+            cfg_file=config_path,
+            model=model,
+            train_dataset=dummy_dataset,
+            optimizers=(optimizer, lr_scheduler),
+            max_epochs=5,
+            device='cpu',
+        )
+
+        trainer = build_trainer(trainer_name, kwargs)
+        train_dataloader = trainer._build_dataloader_with_dataset(
+            trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
+        trainer.register_optimizers_hook()
+        trainer.register_hook_from_cfg(trainer.cfg.train.hooks)
+        trainer.train_dataloader = train_dataloader
+        trainer.data_loader = train_dataloader
+        trainer.invoke_hook(TrainerStages.before_run)
+        for i in range(trainer._epoch, trainer._max_epochs):
+            trainer.invoke_hook(TrainerStages.before_train_epoch)
+            for _, data_batch in enumerate(train_dataloader):
+                trainer.invoke_hook(TrainerStages.before_train_iter)
+                trainer.train_step(trainer.model, data_batch)
+                trainer.invoke_hook(TrainerStages.after_train_iter)
+            trainer.invoke_hook(TrainerStages.after_train_epoch)
+        trainer.invoke_hook(TrainerStages.after_run)
+
+        self.assertAlmostEqual(
+            torch.mean(1.0 * (trainer.model.linear.weight == 0)).item(), 0.9)
+
+
+if __name__ == '__main__':
+    unittest.main()
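--
A minimal sketch of how this hook is enabled through a trainer configuration,
mirroring test_sparsity_hook above. The work_dir value is a placeholder; the
hook type comes from Hooks.SparsityHook in metainfo.py, and the config keys
mirror the defaults read in SparsityHook.__init__:

    cfg = {
        'task': 'image_classification',
        'train': {
            'work_dir': './work_dir',
            'dataloader': {'batch_size_per_gpu': 2, 'workers_per_gpu': 1},
            'hooks': [{
                'type': 'SparsityHook',
                'pruning_method': 'pst',  # the only method SparseLinear handles
                'config': {
                    'weight_rank': 8,       # rank of the U @ V weight update
                    'mask_rank': 8,         # rank of the A @ B mask-score term
                    'initial_sparsity': 0.0,
                    'final_sparsity': 0.9,  # target fraction of zeroed weights
                    'frequency': 1,         # schedule granularity in steps
                },
            }],
        },
    }

Leaving 'compress_module' unset converts every nn.Linear in the model, as the
len(self.compress_module) == 0 branch of before_run shows; otherwise only the
named submodules are converted.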
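For reference, SparseLinear freezes the pretrained weight and trains only the
low-rank terms. A standalone sketch of the PST importance score and top-k
binarization that SparseLinear.forward implements (the shapes and the 0.5
sparsity value are illustrative):

    import torch

    out_f, in_f, r = 4, 3, 2
    weight = torch.randn(out_f, in_f)                   # frozen dense weight
    U, V = torch.randn(out_f, r), torch.zeros(r, in_f)  # low-rank weight update
    A, B = torch.randn(out_f, r), torch.zeros(r, in_f)  # low-rank mask scores
    R, C = torch.zeros(out_f), torch.zeros(in_f)        # row/column mask scores

    w = weight + 1.0 * U @ V  # weight_beta = 1.0
    scores = w.abs() + 1.0 * A @ B + 1.0 * (R.unsqueeze(1) + C.unsqueeze(0))

    sparsity = 0.5
    num_prune = int(scores.numel() * sparsity)
    idx = torch.argsort(scores.reshape(-1))[:num_prune]  # lowest-scoring entries
    mask = torch.ones_like(scores)
    mask.reshape(-1)[idx] = 0.0
    masked_w = mask * w  # what the wrapped module computes with

Because V, B, R and C start at zero, the initial score reduces to |W|, so the
schedule begins from plain magnitude pruning and the learned terms refine it
during training.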
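A quick boundary check of schedule_sparsity_ratio under assumed settings
(total_step=100, frequency=1, initial_warmup=0.1, final_warmup=0.3,
initial_sparsity=0.0, final_sparsity=0.9): steps 0-10 return 0.0, steps above
70 return 0.9, and in between the ratio follows the cubic ramp, e.g. at
step 40 it is 0.9 - 0.9 * (1 - 30 / 60) ** 3 = 0.7875.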