Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10198228
[to #44847108] add sparsity hook (pst algorithm)
Target branch: master
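The diff registers a new `SparsityHook` in `Hooks`, wires a new `compression` subpackage into the lazy-import setup of `modelscope.trainers.hooks`, adds the PST sparse-training implementation itself (the `SparsityHook` trainer hook plus the `SparseLinear` wrapper and its conversion/scheduling utilities), and adds a unit test that prunes a dummy model to a final sparsity of 0.9.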
@@ -404,6 +404,9 @@ class Hooks(object):
     IterTimerHook = 'IterTimerHook'
     EvaluationHook = 'EvaluationHook'
+    # Compression
+    SparsityHook = 'SparsityHook'

 class LR_Schedulers(object):
     """learning rate scheduler is defined here
@@ -6,10 +6,11 @@ from modelscope.utils.import_utils import LazyImportModule
 if TYPE_CHECKING:
     from .builder import HOOKS, build_hook
     from .checkpoint_hook import BestCkptSaverHook, CheckpointHook
+    from .compression import SparsityHook
     from .evaluation_hook import EvaluationHook
     from .hook import Hook
     from .iter_timer_hook import IterTimerHook
-    from .logger import TextLoggerHook, TensorboardHook
+    from .logger import TensorboardHook, TextLoggerHook
     from .lr_scheduler_hook import LrSchedulerHook
     from .optimizer import (ApexAMPOptimizerHook, NoneOptimizerHook,
                             OptimizerHook, TorchAMPOptimizerHook)
@@ -19,6 +20,7 @@ else:
     _import_structure = {
         'builder': ['HOOKS', 'build_hook'],
         'checkpoint_hook': ['BestCkptSaverHook', 'CheckpointHook'],
+        'compression': ['SparsityHook'],
         'evaluation_hook': ['EvaluationHook'],
         'hook': ['Hook'],
         'iter_timer_hook': ['IterTimerHook'],
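With both entries, the new hook is exposed through the same lazy-import machinery as the existing hooks and can be referenced by its registered name in training configs. A small sketch (assuming the usual `type`-keyed registry config used elsewhere in modelscope):

    # Both symbols resolve lazily after this change:
    from modelscope.trainers.hooks import SparsityHook, build_hook

    # and the HOOKS registry can build the hook from a config entry by its registered name:
    hook = build_hook(dict(type='SparsityHook', pruning_method='pst'))
    assert isinstance(hook, SparsityHook)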
@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .sparsity_hook import SparsityHook
    from .utils import SparseLinear, convert_sparse_network

else:
    _import_structure = {
        'sparsity_hook': ['SparsityHook'],
        'utils': ['convert_sparse_network', 'SparseLinear'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
@@ -0,0 +1,131 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

from modelscope import __version__
from modelscope.metainfo import Hooks
from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.hook import Hook
from modelscope.trainers.hooks.priority import Priority
from modelscope.utils.checkpoint import save_checkpoint
from modelscope.utils.torch_utils import is_master


@HOOKS.register_module(module_name=Hooks.SparsityHook)
class SparsityHook(Hook):

    PRIORITY = Priority.HIGHEST

    def __init__(self, pruning_method, config={}, save_dir=None):
        self.pruning_method = pruning_method
        self.save_dir = save_dir

        self.compress_module = config.get('compress_module', [])
        self.weight_rank = config.get('weight_rank', 8)
        self.weight_beta = config.get('weight_beta', 1)
        self.mask_rank = config.get('mask_rank', 8)
        self.mask_alpha1 = config.get('mask_alpha1', 1)
        self.mask_alpha2 = config.get('mask_alpha2', 1)

        self.step = 0
        self.total_step = 0
        self.frequency = config.get('frequency', 1)
        self.initial_warmup = config.get('initial_warmup', 0.1)
        self.final_warmup = config.get('final_warmup', 0.3)
        self.initial_sparsity = config.get('initial_sparsity', 0.0)
        self.final_sparsity = config.get('final_sparsity', 0.0)

    def before_run(self, trainer):
        import torch

        from .utils import SparseLinear, convert_sparse_network

        if self.save_dir is None:
            self.save_dir = trainer.work_dir

        if len(self.compress_module) == 0:
            convert_sparse_network(
                trainer.model,
                pruning_method=self.pruning_method,
                weight_rank=self.weight_rank,
                weight_beta=self.weight_beta,
                mask_rank=self.mask_rank,
                mask_alpha1=self.mask_alpha1,
                mask_alpha2=self.mask_alpha2,
                logger=trainer.logger,
            )
        else:
            for cm in self.compress_module:
                for name, module in trainer.model.named_modules():
                    if name != cm:
                        continue
                    convert_sparse_network(
                        module,
                        pruning_method=self.pruning_method,
                        weight_rank=self.weight_rank,
                        weight_beta=self.weight_beta,
                        mask_rank=self.mask_rank,
                        mask_alpha1=self.mask_alpha1,
                        mask_alpha2=self.mask_alpha2,
                        logger=trainer.logger,
                    )

        for i in range(len(trainer.optimizer.param_groups)):
            new_train_params = []
            for param in trainer.optimizer.param_groups[i]['params']:
                is_find = False
                for name, module in trainer.model.named_modules():
                    if isinstance(module, SparseLinear):
                        if torch.equal(param.half(),
                                       module.weight.data.half()):
                            is_find = True
                            break
                if not is_find:
                    new_train_params.append(param)
            trainer.optimizer.param_groups[i]['params'] = new_train_params

        new_params = []
        for name, module in trainer.model.named_modules():
            if isinstance(module, SparseLinear):
                new_params.extend(
                    [p for p in module.parameters() if p.requires_grad])
        trainer.optimizer.add_param_group({'params': new_params})

        self.total_step = trainer.iters_per_epoch * trainer._max_epochs

    def before_train_iter(self, trainer):
        from .utils import schedule_sparsity_ratio, update_network_sparsity

        cur_sparsity = schedule_sparsity_ratio(
            self.step,
            self.total_step,
            self.frequency,
            self.initial_warmup,
            self.final_warmup,
            self.initial_sparsity,
            self.final_sparsity,
        )
        update_network_sparsity(trainer.model, cur_sparsity)
        if is_master():
            trainer.logger.info(
                f'Step[{self.step}/{self.total_step}] current sparsity ratio = {cur_sparsity}'
            )
        self.step += 1

    def after_run(self, trainer):
        from .utils import generate_sparse_model

        generate_sparse_model(trainer.model, logger=trainer.logger)
        self._save_checkpoint(trainer)

    def _save_checkpoint(self, trainer):
        if is_master():
            trainer.logger.info('Saving checkpoint at final compress')
        cur_save_name = os.path.join(self.save_dir, 'compress_model.pth')
        save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
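For reference, the hook is configured entirely through its constructor arguments, which the trainer forwards from the `hooks` section of the training config; the unit test at the end of this review uses exactly that route. A hedged sketch of such an entry, with each `config` key as implemented in `__init__` above (values are illustrative, defaults in parentheses):

    # Illustrative entry for cfg.train.hooks
    {
        'type': 'SparsityHook',
        'pruning_method': 'pst',
        'config': {
            'compress_module': [],     # submodule names to convert; [] converts every supported Linear ([])
            'weight_rank': 8,          # rank of the additive low-rank weight update U @ V (8)
            'weight_beta': 1,          # scale of that update (1)
            'mask_rank': 8,            # rank of the low-rank mask scores A @ B (8)
            'mask_alpha1': 1,          # scale of A @ B in the mask scores (1)
            'mask_alpha2': 1,          # scale of the row/column scores R, C (1)
            'initial_sparsity': 0.0,   # sparsity ratio at the start of training (0.0)
            'final_sparsity': 0.9,     # target ratio of zeroed weights (0.0)
            'initial_warmup': 0.1,     # fraction of steps held at initial_sparsity (0.1)
            'final_warmup': 0.3,       # fraction of steps held at final_sparsity (0.3)
            'frequency': 1,            # scheduled ratio only changes every `frequency` steps (1)
        },
    }

Either way, `before_run` then rebuilds the optimizer's parameter groups so the frozen dense weights drop out and the newly created low-rank factors are the ones being trained.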
@@ -0,0 +1,208 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn

from modelscope.utils.torch_utils import is_master


class SparseBinarizer(torch.autograd.Function):

    @staticmethod
    def forward(ctx, mask_scores, sparsity):
        num_prune = int(mask_scores.numel() * sparsity)
        prune_indices = torch.argsort(mask_scores.reshape(-1))[:num_prune]
        mask = mask_scores.clone().fill_(1)
        mask.reshape(-1)[prune_indices] = 0.0
        return mask

    @staticmethod
    def backward(ctx, gradOutput):
        return gradOutput, None


class SparseLinear(nn.Module):
    """
    Fully Connected layer with on the fly adaptive mask.
    """

    def __init__(
        self,
        module,
        pruning_method='pst',
        weight_rank=8,
        weight_beta=1.0,
        mask_rank=8,
        mask_alpha1=1.0,
        mask_alpha2=1.0,
    ):
        super(SparseLinear, self).__init__()
        self.module = module
        out_features = self.module.weight.shape[0]
        in_features = self.module.weight.shape[1]
        self.weight = self.module.weight
        self.module.weight = None
        self.module._parameters.pop('weight')

        self.pruning_method = pruning_method
        self.cur_sparsity = 0.0

        if self.pruning_method == 'pst':
            self.weight_rank = weight_rank
            self.weight_beta = weight_beta
            self.mask_rank = mask_rank
            self.mask_alpha1 = mask_alpha1
            self.mask_alpha2 = mask_alpha2

            # create trainable params
            self.weight_U = nn.Parameter(
                torch.randn(out_features, self.weight_rank).to(
                    device=self.weight.device, dtype=self.weight.dtype))
            self.weight_V = nn.Parameter(
                torch.zeros(self.weight_rank, in_features).to(
                    device=self.weight.device, dtype=self.weight.dtype))

            self.mask_scores_A = nn.Parameter(
                torch.randn(out_features, self.mask_rank).to(
                    device=self.weight.device, dtype=self.weight.dtype))
            self.mask_scores_B = nn.Parameter(
                torch.zeros(self.mask_rank, in_features).to(
                    device=self.weight.device, dtype=self.weight.dtype))
            self.mask_scores_R = nn.Parameter(
                torch.zeros(out_features).to(
                    device=self.weight.device, dtype=self.weight.dtype))
            self.mask_scores_C = nn.Parameter(
                torch.zeros(in_features).to(
                    device=self.weight.device, dtype=self.weight.dtype))

            self.weight.requires_grad = False
            if self.module.bias is not None:
                self.module.bias.requires_grad = False

    def forward(self, *inputs):
        if self.pruning_method == 'pst':
            weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V
            mask_scores = (
                weight.abs()
                + self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B
                + self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1)
                                      + self.mask_scores_C.unsqueeze(0)))
            mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity)
            masked_weight = mask * weight

            self.module.weight = masked_weight
            return self.module(*inputs)
        else:
            return self.module(*inputs)

    def convert(self):
        if self.pruning_method == 'pst':
            weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V
            mask_scores = (
                weight.abs()
                + self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B
                + self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1)
                                      + self.mask_scores_C.unsqueeze(0)))
            mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity)
            masked_weight = mask * weight

            self.module.weight = nn.Parameter(masked_weight.data)


def _setattr(model, name, module):
    name_list = name.split('.')
    for name in name_list[:-1]:
        model = getattr(model, name)
    setattr(model, name_list[-1], module)


def convert_sparse_network(
    model,
    pruning_method,
    weight_rank,
    weight_beta,
    mask_rank,
    mask_alpha1,
    mask_alpha2,
    logger=None,
):
    compress_module = [nn.Linear]
    try:
        from megatron import mpu
        compress_module.extend(
            [mpu.RowParallelLinear, mpu.ColumnParallelLinear])
    except ImportError:
        pass

    for name, module in model.named_modules():
        if type(module) in compress_module:
            new_module = SparseLinear(
                module,
                pruning_method,
                weight_rank,
                weight_beta,
                mask_rank,
                mask_alpha1,
                mask_alpha2,
            )
            # replace original module by new sparse module
            _setattr(model, name, new_module)
            if is_master():
                if logger:
                    logger.info(f'convert {name} to sparse module.')
                else:
                    print(f'convert {name} to sparse module.')


def update_network_sparsity(model, sparsity):
    for name, module in model.named_modules():
        if isinstance(module, SparseLinear):
            module.cur_sparsity = sparsity


def schedule_sparsity_ratio(
    step,
    total_step,
    frequency,
    initial_warmup,
    final_warmup,
    initial_sparsity,
    final_sparsity,
):
    if step <= initial_warmup * total_step:
        sparsity = initial_sparsity
    elif step > (total_step - final_warmup * total_step):
        sparsity = final_sparsity
    else:
        spars_warmup_steps = initial_warmup * total_step
        spars_schedu_steps = (final_warmup + initial_warmup) * total_step
        step = (step - spars_warmup_steps) // frequency * frequency
        mul_coeff = 1 - step / (total_step - spars_schedu_steps)
        sparsity = final_sparsity + (initial_sparsity - final_sparsity) * (
            mul_coeff**3)
    return sparsity


def generate_sparse_model(model, logger=None):
    # generate sparse weight for saving
    for name, module in model.named_modules():
        if isinstance(module, SparseLinear):
            module.convert()
            _setattr(model, name, module.module)
            if is_master():
                if logger:
                    logger.info(
                        f'convert {name} weight to sparse weight, sparsity '
                        f'ratio={torch.mean(1.0 * (module.module.weight == 0)).item()}.')
                else:
                    print(
                        f'convert {name} weight to sparse weight, sparsity '
                        f'ratio={torch.mean(1.0 * (module.module.weight == 0)).item()}.')
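Taken together, the effective weight during training is `mask(S) * (W + weight_beta * U @ V)`, with importance scores `S = |W + weight_beta * U @ V| + mask_alpha1 * A @ B + mask_alpha2 * (R[:, None] + C[None, :])`; only `U, V, A, B, R, C` receive gradients while the dense `W` stays frozen, and `SparseBinarizer` passes gradients straight through so the score factors can be trained despite the binary mask. A minimal standalone sketch of these utilities outside the trainer (the import path is an assumption based on the package layout in this diff):

    # Minimal sketch, not part of the change; module path assumed from this review.
    import torch
    import torch.nn as nn

    from modelscope.trainers.hooks.compression.utils import (
        SparseLinear, schedule_sparsity_ratio)

    layer = SparseLinear(nn.Linear(8, 8), pruning_method='pst',
                         weight_rank=2, mask_rank=2)

    # Only the low-rank factors and row/column scores are trainable;
    # the original dense weight (and bias) are frozen.
    print([n for n, p in layer.named_parameters() if p.requires_grad])
    # -> ['weight_U', 'weight_V', 'mask_scores_A', 'mask_scores_B',
    #     'mask_scores_R', 'mask_scores_C']

    # Drive the sparsity ratio the way SparsityHook does once per iteration.
    layer.cur_sparsity = schedule_sparsity_ratio(
        step=20, total_step=25, frequency=1, initial_warmup=0.1,
        final_warmup=0.3, initial_sparsity=0.0, final_sparsity=0.9)  # -> 0.9

    out = layer(torch.randn(4, 8))  # masked forward pass, shape (4, 8)
    print((layer.module.weight == 0).float().mean())
    # -> ~0.89 (int(64 * 0.9) = 57 of the 64 entries are masked)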
@@ -0,0 +1,113 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

import json
import numpy as np
import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR

from modelscope.metainfo import Trainers
from modelscope.models.base import Model
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile, TrainerStages
from modelscope.utils.test_utils import create_dummy_test_dataset

dummy_dataset = create_dummy_test_dataset(
    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10)


class DummyModel(nn.Module, Model):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 10)
        self.bn = nn.BatchNorm1d(10)

    def forward(self, feat, labels):
        x = self.linear(feat)
        x = self.bn(x)
        loss = torch.sum(x)
        return dict(logits=x, loss=loss)


class SparsityHookTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir)

    def test_sparsity_hook(self):
        json_cfg = {
            'task': 'image_classification',
            'train': {
                'work_dir': self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'hooks': [{
                    'type': 'SparsityHook',
                    'pruning_method': 'pst',
                    'config': {
                        'weight_rank': 1,
                        'mask_rank': 1,
                        'final_sparsity': 0.9,
                        'frequency': 1,
                    },
                }],
            },
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        model = DummyModel()
        optimizer = SGD(model.parameters(), lr=0.01)
        lr_scheduler = MultiStepLR(optimizer, milestones=[2, 4])
        trainer_name = Trainers.default
        kwargs = dict(
            cfg_file=config_path,
            model=model,
            train_dataset=dummy_dataset,
            optimizers=(optimizer, lr_scheduler),
            max_epochs=5,
            device='cpu',
        )

        trainer = build_trainer(trainer_name, kwargs)
        train_dataloader = trainer._build_dataloader_with_dataset(
            trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
        trainer.register_optimizers_hook()
        trainer.register_hook_from_cfg(trainer.cfg.train.hooks)
        trainer.train_dataloader = train_dataloader
        trainer.data_loader = train_dataloader
        trainer.invoke_hook(TrainerStages.before_run)

        for i in range(trainer._epoch, trainer._max_epochs):
            trainer.invoke_hook(TrainerStages.before_train_epoch)
            for _, data_batch in enumerate(train_dataloader):
                trainer.invoke_hook(TrainerStages.before_train_iter)
                trainer.train_step(trainer.model, data_batch)
                trainer.invoke_hook(TrainerStages.after_train_iter)
            trainer.invoke_hook(TrainerStages.after_train_epoch)
        trainer.invoke_hook(TrainerStages.after_run)

        self.assertEqual(
            torch.mean(1.0 * (trainer.model.linear.weight == 0)), 0.9)


if __name__ == '__main__':
    unittest.main()
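For the settings in this test (10 samples, batch size 2, 5 epochs, i.e. 25 optimizer steps), `schedule_sparsity_ratio` keeps the ratio at 0.0 for the first three steps (`step <= 0.1 * 25`), ramps it cubically towards 0.9 in the middle of training, and holds it at `final_sparsity = 0.9` for every step beyond 17.5 (`final_warmup = 0.3`). After `after_run`, `generate_sparse_model` bakes the final mask into `linear.weight`, zeroing `int(50 * 0.9) = 45` of the 10x5 weights, which is why the asserted mean of `weight == 0` is exactly 0.9.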