Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10198228
* [to #44847108] add sparsity hook (pst algorithm)
Target branch: master
@@ -404,6 +404,9 @@ class Hooks(object):
    IterTimerHook = 'IterTimerHook'
    EvaluationHook = 'EvaluationHook'

    # Compression
    SparsityHook = 'SparsityHook'


class LR_Schedulers(object):
    """learning rate scheduler is defined here
@@ -6,10 +6,11 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    from .builder import HOOKS, build_hook
    from .checkpoint_hook import BestCkptSaverHook, CheckpointHook
    from .compression import SparsityHook
    from .evaluation_hook import EvaluationHook
    from .hook import Hook
    from .iter_timer_hook import IterTimerHook
-   from .logger import TextLoggerHook, TensorboardHook
+   from .logger import TensorboardHook, TextLoggerHook
    from .lr_scheduler_hook import LrSchedulerHook
    from .optimizer import (ApexAMPOptimizerHook, NoneOptimizerHook,
                            OptimizerHook, TorchAMPOptimizerHook)
@@ -19,6 +20,7 @@ else:
    _import_structure = {
        'builder': ['HOOKS', 'build_hook'],
        'checkpoint_hook': ['BestCkptSaverHook', 'CheckpointHook'],
        'compression': ['SparsityHook'],
        'evaluation_hook': ['EvaluationHook'],
        'hook': ['Hook'],
        'iter_timer_hook': ['IterTimerHook'],
@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .sparsity_hook import SparsityHook
    from .utils import SparseLinear, convert_sparse_network
else:
    _import_structure = {
        'sparsity_hook': ['SparsityHook'],
        'utils': ['convert_sparse_network', 'SparseLinear'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
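Reviewer note: as a quick sanity check of the lazy-import wiring above, together with the 'compression' entry added to the hooks package __init__ earlier in this diff, the new symbols should be reachable as below. This is a sketch based on my reading of the _import_structure entries, not something the diff states explicitly:

# Assumed import paths; both should resolve lazily via LazyImportModule.
from modelscope.trainers.hooks import SparsityHook
from modelscope.trainers.hooks.compression import SparseLinear, convert_sparse_network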
@@ -0,0 +1,131 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

from modelscope import __version__
from modelscope.metainfo import Hooks
from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.hook import Hook
from modelscope.trainers.hooks.priority import Priority
from modelscope.utils.checkpoint import save_checkpoint
from modelscope.utils.torch_utils import is_master


@HOOKS.register_module(module_name=Hooks.SparsityHook)
class SparsityHook(Hook):
    PRIORITY = Priority.HIGHEST

    def __init__(self, pruning_method, config={}, save_dir=None):
        self.pruning_method = pruning_method
        self.save_dir = save_dir

        self.compress_module = config.get('compress_module', [])
        self.weight_rank = config.get('weight_rank', 8)
        self.weight_beta = config.get('weight_beta', 1)
        self.mask_rank = config.get('mask_rank', 8)
        self.mask_alpha1 = config.get('mask_alpha1', 1)
        self.mask_alpha2 = config.get('mask_alpha2', 1)

        self.step = 0
        self.total_step = 0
        self.frequency = config.get('frequency', 1)
        self.initial_warmup = config.get('initial_warmup', 0.1)
        self.final_warmup = config.get('final_warmup', 0.3)
        self.initial_sparsity = config.get('initial_sparsity', 0.0)
        self.final_sparsity = config.get('final_sparsity', 0.0)

    def before_run(self, trainer):
        import torch

        from .utils import SparseLinear, convert_sparse_network

        if self.save_dir is None:
            self.save_dir = trainer.work_dir

        if len(self.compress_module) == 0:
            convert_sparse_network(
                trainer.model,
                pruning_method=self.pruning_method,
                weight_rank=self.weight_rank,
                weight_beta=self.weight_beta,
                mask_rank=self.mask_rank,
                mask_alpha1=self.mask_alpha1,
                mask_alpha2=self.mask_alpha2,
                logger=trainer.logger,
            )
        else:
            for cm in self.compress_module:
                for name, module in trainer.model.named_modules():
                    if name != cm:
                        continue
                    convert_sparse_network(
                        module,
                        pruning_method=self.pruning_method,
                        weight_rank=self.weight_rank,
                        weight_beta=self.weight_beta,
                        mask_rank=self.mask_rank,
                        mask_alpha1=self.mask_alpha1,
                        mask_alpha2=self.mask_alpha2,
                        logger=trainer.logger,
                    )

        for i in range(len(trainer.optimizer.param_groups)):
            new_train_params = []
            for param in trainer.optimizer.param_groups[i]['params']:
                is_find = False
                for name, module in trainer.model.named_modules():
                    if isinstance(module, SparseLinear):
                        if torch.equal(param.half(),
                                       module.weight.data.half()):
                            is_find = True
                            break
                if not is_find:
                    new_train_params.append(param)
            trainer.optimizer.param_groups[i]['params'] = new_train_params

        new_params = []
        for name, module in trainer.model.named_modules():
            if isinstance(module, SparseLinear):
                new_params.extend(
                    [p for p in module.parameters() if p.requires_grad])
        trainer.optimizer.add_param_group({'params': new_params})

        self.total_step = trainer.iters_per_epoch * trainer._max_epochs

    def before_train_iter(self, trainer):
        from .utils import schedule_sparsity_ratio, update_network_sparsity

        cur_sparsity = schedule_sparsity_ratio(
            self.step,
            self.total_step,
            self.frequency,
            self.initial_warmup,
            self.final_warmup,
            self.initial_sparsity,
            self.final_sparsity,
        )
        update_network_sparsity(trainer.model, cur_sparsity)
        if is_master():
            trainer.logger.info(
                f'Step[{self.step}/{self.total_step}] current sparsity ratio = {cur_sparsity}'
            )
        self.step += 1

    def after_run(self, trainer):
        from .utils import generate_sparse_model

        generate_sparse_model(trainer.model, logger=trainer.logger)
        self._save_checkpoint(trainer)

    def _save_checkpoint(self, trainer):
        if is_master():
            trainer.logger.info('Saving checkpoint at final compress')
            cur_save_name = os.path.join(self.save_dir, 'compress_model.pth')
            save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
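Reviewer note: for anyone trying the hook, a minimal way to enable it from a trainer configuration, mirroring the unit test at the end of this review. The keys under 'config' map to the __init__ defaults above; the values here are illustrative, not recommended settings:

'hooks': [{
    'type': 'SparsityHook',
    'pruning_method': 'pst',
    'config': {
        'compress_module': [],   # empty list means every supported Linear gets sparsified
        'weight_rank': 8,
        'mask_rank': 8,
        'initial_sparsity': 0.0,
        'final_sparsity': 0.9,
        'initial_warmup': 0.1,
        'final_warmup': 0.3,
        'frequency': 1,
    },
}]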
@@ -0,0 +1,208 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn

from modelscope.utils.torch_utils import is_master


class SparseBinarizer(torch.autograd.Function):

    @staticmethod
    def forward(ctx, mask_scores, sparsity):
        num_prune = int(mask_scores.numel() * sparsity)
        prune_indices = torch.argsort(mask_scores.reshape(-1))[:num_prune]
        mask = mask_scores.clone().fill_(1)
        mask.reshape(-1)[prune_indices] = 0.0
        return mask

    @staticmethod
    def backward(ctx, gradOutput):
        return gradOutput, None


class SparseLinear(nn.Module):
    """
    Fully Connected layer with on the fly adaptive mask.
    """

    def __init__(
        self,
        module,
        pruning_method='pst',
        weight_rank=8,
        weight_beta=1.0,
        mask_rank=8,
        mask_alpha1=1.0,
        mask_alpha2=1.0,
    ):
        super(SparseLinear, self).__init__()
        self.module = module
        out_features = self.module.weight.shape[0]
        in_features = self.module.weight.shape[1]

        self.weight = self.module.weight
        self.module.weight = None
        self.module._parameters.pop('weight')

        self.pruning_method = pruning_method
        self.cur_sparsity = 0.0

        if self.pruning_method == 'pst':
            self.weight_rank = weight_rank
            self.weight_beta = weight_beta
            self.mask_rank = mask_rank
            self.mask_alpha1 = mask_alpha1
            self.mask_alpha2 = mask_alpha2

            # create trainable params
            self.weight_U = nn.Parameter(
                torch.randn(out_features, self.weight_rank).to(
                    device=self.weight.device, dtype=self.weight.dtype))
            self.weight_V = nn.Parameter(
                torch.zeros(self.weight_rank, in_features).to(
                    device=self.weight.device, dtype=self.weight.dtype))

            self.mask_scores_A = nn.Parameter(
                torch.randn(out_features, self.mask_rank).to(
                    device=self.weight.device, dtype=self.weight.dtype))
            self.mask_scores_B = nn.Parameter(
                torch.zeros(self.mask_rank, in_features).to(
                    device=self.weight.device, dtype=self.weight.dtype))
            self.mask_scores_R = nn.Parameter(
                torch.zeros(out_features).to(
                    device=self.weight.device, dtype=self.weight.dtype))
            self.mask_scores_C = nn.Parameter(
                torch.zeros(in_features).to(
                    device=self.weight.device, dtype=self.weight.dtype))

            self.weight.requires_grad = False
            if self.module.bias is not None:
                self.module.bias.requires_grad = False

    def forward(self, *inputs):
        if self.pruning_method == 'pst':
            weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V
            mask_scores = (
                weight.abs()
                + self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B
                + self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1)
                                      + self.mask_scores_C.unsqueeze(0)))
            mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity)
            masked_weight = mask * weight

            self.module.weight = masked_weight
            return self.module(*inputs)
        else:
            return self.module(*inputs)

    def convert(self):
        if self.pruning_method == 'pst':
            weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V
            mask_scores = (
                weight.abs()
                + self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B
                + self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1)
                                      + self.mask_scores_C.unsqueeze(0)))
            mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity)
            masked_weight = mask * weight

            self.module.weight = nn.Parameter(masked_weight.data)


def _setattr(model, name, module):
    name_list = name.split('.')
    for name in name_list[:-1]:
        model = getattr(model, name)
    setattr(model, name_list[-1], module)


def convert_sparse_network(
    model,
    pruning_method,
    weight_rank,
    weight_beta,
    mask_rank,
    mask_alpha1,
    mask_alpha2,
    logger=None,
):
    compress_module = [nn.Linear]
    try:
        from megatron import mpu
        compress_module.extend(
            [mpu.RowParallelLinear, mpu.ColumnParallelLinear])
    except ImportError:
        pass

    for name, module in model.named_modules():
        if type(module) in compress_module:
            new_module = SparseLinear(
                module,
                pruning_method,
                weight_rank,
                weight_beta,
                mask_rank,
                mask_alpha1,
                mask_alpha2,
            )
            # replace original module by new sparse module
            _setattr(model, name, new_module)
            if is_master():
                if logger:
                    logger.info(f'convert {name} to sparse module.')
                else:
                    print(f'convert {name} to sparse module.')


def update_network_sparsity(model, sparsity):
    for name, module in model.named_modules():
        if isinstance(module, SparseLinear):
            module.cur_sparsity = sparsity


def schedule_sparsity_ratio(
    step,
    total_step,
    frequency,
    initial_warmup,
    final_warmup,
    initial_sparsity,
    final_sparsity,
):
    if step <= initial_warmup * total_step:
        sparsity = initial_sparsity
    elif step > (total_step - final_warmup * total_step):
        sparsity = final_sparsity
    else:
        spars_warmup_steps = initial_warmup * total_step
        spars_schedu_steps = (final_warmup + initial_warmup) * total_step
        step = (step - spars_warmup_steps) // frequency * frequency
        mul_coeff = 1 - step / (total_step - spars_schedu_steps)
        sparsity = final_sparsity + (initial_sparsity - final_sparsity) * (
            mul_coeff**3)
    return sparsity
def generate_sparse_model(model, logger=None):
    # generate sparse weight for saving
    for name, module in model.named_modules():
        if isinstance(module, SparseLinear):
            module.convert()
            _setattr(model, name, module.module)
            if is_master():
                sparsity_ratio = torch.mean(
                    1.0 * (module.module.weight == 0)).item()
                if logger:
                    logger.info(f'convert {name} weight to sparse weight, '
                                f'sparsity ratio={sparsity_ratio}.')
                else:
                    print(f'convert {name} weight to sparse weight, '
                          f'sparsity ratio={sparsity_ratio}.')
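Reviewer note: to make the cubic sparsity schedule easier to review, a standalone sketch that prints the target ratio over a short run. The import path is my assumption from the new files in this PR, and the numbers are purely illustrative:

# Illustrative only: walk the schedule over 100 steps with 10% initial warmup,
# 30% final warmup, and a 0.9 target sparsity (module path assumed from this diff).
from modelscope.trainers.hooks.compression.utils import schedule_sparsity_ratio

total_step = 100
for step in (0, 10, 30, 50, 70, 90):
    ratio = schedule_sparsity_ratio(
        step, total_step, frequency=1, initial_warmup=0.1, final_warmup=0.3,
        initial_sparsity=0.0, final_sparsity=0.9)
    print(step, round(ratio, 3))
# Expected behaviour: 0.0 for the first 10% of steps, a cubic ramp in between,
# and 0.9 for the final 30% of steps.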
@@ -0,0 +1,113 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

import json
import numpy as np
import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR

from modelscope.metainfo import Trainers
from modelscope.models.base import Model
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile, TrainerStages
from modelscope.utils.test_utils import create_dummy_test_dataset

dummy_dataset = create_dummy_test_dataset(
    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10)


class DummyModel(nn.Module, Model):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 10)
        self.bn = nn.BatchNorm1d(10)

    def forward(self, feat, labels):
        x = self.linear(feat)
        x = self.bn(x)
        loss = torch.sum(x)
        return dict(logits=x, loss=loss)


class SparsityHookTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir)

    def test_sparsity_hook(self):
        json_cfg = {
            'task': 'image_classification',
            'train': {
                'work_dir': self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'hooks': [{
                    'type': 'SparsityHook',
                    'pruning_method': 'pst',
                    'config': {
                        'weight_rank': 1,
                        'mask_rank': 1,
                        'final_sparsity': 0.9,
                        'frequency': 1,
                    },
                }],
            },
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        model = DummyModel()
        optimizer = SGD(model.parameters(), lr=0.01)
        lr_scheduler = MultiStepLR(optimizer, milestones=[2, 4])
        trainer_name = Trainers.default
        kwargs = dict(
            cfg_file=config_path,
            model=model,
            train_dataset=dummy_dataset,
            optimizers=(optimizer, lr_scheduler),
            max_epochs=5,
            device='cpu',
        )

        trainer = build_trainer(trainer_name, kwargs)
        train_dataloader = trainer._build_dataloader_with_dataset(
            trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
        trainer.register_optimizers_hook()
        trainer.register_hook_from_cfg(trainer.cfg.train.hooks)
        trainer.train_dataloader = train_dataloader
        trainer.data_loader = train_dataloader

        trainer.invoke_hook(TrainerStages.before_run)
        for i in range(trainer._epoch, trainer._max_epochs):
            trainer.invoke_hook(TrainerStages.before_train_epoch)
            for _, data_batch in enumerate(train_dataloader):
                trainer.invoke_hook(TrainerStages.before_train_iter)
                trainer.train_step(trainer.model, data_batch)
                trainer.invoke_hook(TrainerStages.after_train_iter)
            trainer.invoke_hook(TrainerStages.after_train_epoch)
        trainer.invoke_hook(TrainerStages.after_run)

        self.assertEqual(
            torch.mean(1.0 * (trainer.model.linear.weight == 0)), 0.9)


if __name__ == '__main__':
    unittest.main()