
[to #44847108] add sparsity hook (pst algorithm)

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10198228

master · laiyin.lyc, 3 years ago · commit 09d2296f36
7 changed files with 482 additions and 1 deletion:

  1. modelscope/metainfo.py                                   +3    -0
  2. modelscope/trainers/hooks/__init__.py                    +3    -1
  3. modelscope/trainers/hooks/compression/__init__.py        +24   -0
  4. modelscope/trainers/hooks/compression/sparsity_hook.py   +131  -0
  5. modelscope/trainers/hooks/compression/utils.py           +208  -0
  6. tests/trainers/hooks/compression/__init__.py             +0    -0
  7. tests/trainers/hooks/compression/test_sparsity_hook.py   +113  -0
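Note: the new hook is enabled through the trainer config, as exercised in the added test at the end of this diff. The sketch below lists the keys that SparsityHook.__init__ actually reads; values are the code defaults except final_sparsity, which is illustrative.

# Sketch of a SparsityHook entry under cfg.train.hooks (not part of the commit).
sparsity_hook_cfg = {
    'type': 'SparsityHook',
    'pruning_method': 'pst',     # only 'pst' builds the low-rank weight/mask parameters
    'save_dir': None,            # defaults to trainer.work_dir
    'config': {
        'compress_module': [],   # empty list -> convert every supported linear layer
        'weight_rank': 8,        # rank of the additive weight update U @ V
        'weight_beta': 1,
        'mask_rank': 8,          # rank of the mask-score term A @ B
        'mask_alpha1': 1,
        'mask_alpha2': 1,
        'frequency': 1,          # granularity of the sparsity schedule
        'initial_warmup': 0.1,   # fraction of steps held at initial_sparsity
        'final_warmup': 0.3,     # fraction of steps held at final_sparsity
        'initial_sparsity': 0.0,
        'final_sparsity': 0.9,   # illustrative target; the code default is 0.0
    },
}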

modelscope/metainfo.py  (+3, -0)

@@ -404,6 +404,9 @@ class Hooks(object):
    IterTimerHook = 'IterTimerHook'
    EvaluationHook = 'EvaluationHook'

    # Compression
    SparsityHook = 'SparsityHook'


class LR_Schedulers(object):
    """learning rate scheduler is defined here


modelscope/trainers/hooks/__init__.py  (+3, -1)

@@ -6,10 +6,11 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    from .builder import HOOKS, build_hook
    from .checkpoint_hook import BestCkptSaverHook, CheckpointHook
    from .compression import SparsityHook
    from .evaluation_hook import EvaluationHook
    from .hook import Hook
    from .iter_timer_hook import IterTimerHook
    from .logger import TextLoggerHook, TensorboardHook
    from .logger import TensorboardHook, TextLoggerHook
    from .lr_scheduler_hook import LrSchedulerHook
    from .optimizer import (ApexAMPOptimizerHook, NoneOptimizerHook,
                            OptimizerHook, TorchAMPOptimizerHook)
@@ -19,6 +20,7 @@ else:
    _import_structure = {
        'builder': ['HOOKS', 'build_hook'],
        'checkpoint_hook': ['BestCkptSaverHook', 'CheckpointHook'],
        'compression': ['SparsityHook'],
        'evaluation_hook': ['EvaluationHook'],
        'hook': ['Hook'],
        'iter_timer_hook': ['IterTimerHook'],


modelscope/trainers/hooks/compression/__init__.py  (+24, -0)

@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .sparsity_hook import SparsityHook
    from .utils import SparseLinear, convert_sparse_network

else:
    _import_structure = {
        'sparsity_hook': ['SparsityHook'],
        'utils': ['convert_sparse_network', 'SparseLinear'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
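This mirrors the lazy-import layout used by the other hook packages: LazyImportModule defers loading the submodules listed in _import_structure until one of the exported names is first accessed. A minimal usage sketch:

# Names resolve through _import_structure on first access; the submodules under
# compression/ are not imported before that point.
from modelscope.trainers.hooks.compression import SparseLinear, SparsityHook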

modelscope/trainers/hooks/compression/sparsity_hook.py  (+131, -0)

@@ -0,0 +1,131 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

from modelscope import __version__
from modelscope.metainfo import Hooks
from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.hook import Hook
from modelscope.trainers.hooks.priority import Priority
from modelscope.utils.checkpoint import save_checkpoint
from modelscope.utils.torch_utils import is_master


@HOOKS.register_module(module_name=Hooks.SparsityHook)
class SparsityHook(Hook):

    PRIORITY = Priority.HIGHEST

    def __init__(self, pruning_method, config={}, save_dir=None):
        self.pruning_method = pruning_method
        self.save_dir = save_dir

        self.compress_module = config.get('compress_module', [])
        self.weight_rank = config.get('weight_rank', 8)
        self.weight_beta = config.get('weight_beta', 1)
        self.mask_rank = config.get('mask_rank', 8)
        self.mask_alpha1 = config.get('mask_alpha1', 1)
        self.mask_alpha2 = config.get('mask_alpha2', 1)

        self.step = 0
        self.total_step = 0
        self.frequency = config.get('frequency', 1)
        self.initial_warmup = config.get('initial_warmup', 0.1)
        self.final_warmup = config.get('final_warmup', 0.3)
        self.initial_sparsity = config.get('initial_sparsity', 0.0)
        self.final_sparsity = config.get('final_sparsity', 0.0)

    def before_run(self, trainer):
        import torch

        from .utils import SparseLinear, convert_sparse_network

        if self.save_dir is None:
            self.save_dir = trainer.work_dir

        if len(self.compress_module) == 0:
            convert_sparse_network(
                trainer.model,
                pruning_method=self.pruning_method,
                weight_rank=self.weight_rank,
                weight_beta=self.weight_beta,
                mask_rank=self.mask_rank,
                mask_alpha1=self.mask_alpha1,
                mask_alpha2=self.mask_alpha2,
                logger=trainer.logger,
            )
        else:
            for cm in self.compress_module:
                for name, module in trainer.model.named_modules():
                    if name != cm:
                        continue
                    convert_sparse_network(
                        module,
                        pruning_method=self.pruning_method,
                        weight_rank=self.weight_rank,
                        weight_beta=self.weight_beta,
                        mask_rank=self.mask_rank,
                        mask_alpha1=self.mask_alpha1,
                        mask_alpha2=self.mask_alpha2,
                        logger=trainer.logger,
                    )

        for i in range(len(trainer.optimizer.param_groups)):
            new_train_params = []
            for param in trainer.optimizer.param_groups[i]['params']:
                is_find = False
                for name, module in trainer.model.named_modules():
                    if isinstance(module, SparseLinear):
                        if torch.equal(param.half(),
                                       module.weight.data.half()):
                            is_find = True
                            break

                if not is_find:
                    new_train_params.append(param)

            trainer.optimizer.param_groups[i]['params'] = new_train_params

        new_params = []
        for name, module in trainer.model.named_modules():
            if isinstance(module, SparseLinear):
                new_params.extend(
                    [p for p in module.parameters() if p.requires_grad])

        trainer.optimizer.add_param_group({'params': new_params})

        self.total_step = trainer.iters_per_epoch * trainer._max_epochs

    def before_train_iter(self, trainer):
        from .utils import schedule_sparsity_ratio, update_network_sparsity

        cur_sparsity = schedule_sparsity_ratio(
            self.step,
            self.total_step,
            self.frequency,
            self.initial_warmup,
            self.final_warmup,
            self.initial_sparsity,
            self.final_sparsity,
        )

        update_network_sparsity(trainer.model, cur_sparsity)

        if is_master():
            trainer.logger.info(
                f'Step[{self.step}/{self.total_step}] current sparsity ratio = {cur_sparsity}'
            )

        self.step += 1

    def after_run(self, trainer):
        from .utils import generate_sparse_model

        generate_sparse_model(trainer.model, logger=trainer.logger)

        self._save_checkpoint(trainer)

    def _save_checkpoint(self, trainer):
        if is_master():
            trainer.logger.info('Saving checkpoint at final compress')
            cur_save_name = os.path.join(self.save_dir, 'compress_model.pth')
            save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
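For intuition: before_train_iter drives a cubic sparsity schedule (schedule_sparsity_ratio in utils.py below), flat at initial_sparsity for the first initial_warmup fraction of steps, flat at final_sparsity for the last final_warmup fraction, and interpolating cubically in between. A standalone restatement for illustration (not part of the commit):

def cubic_sparsity(step, total_step, frequency=1, initial_warmup=0.1,
                   final_warmup=0.3, initial_sparsity=0.0, final_sparsity=0.9):
    # Same arithmetic as utils.schedule_sparsity_ratio, written out flat.
    if step <= initial_warmup * total_step:
        return initial_sparsity
    if step > total_step - final_warmup * total_step:
        return final_sparsity
    warmup_steps = initial_warmup * total_step
    ramp_span = total_step - (initial_warmup + final_warmup) * total_step
    step = (step - warmup_steps) // frequency * frequency
    coeff = 1 - step / ramp_span
    return final_sparsity + (initial_sparsity - final_sparsity) * coeff**3

# With total_step=100 and final_sparsity=0.9: step 10 -> 0.0, step 40 -> 0.7875,
# step 70 -> 0.9, and every later step stays at 0.9.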

modelscope/trainers/hooks/compression/utils.py  (+208, -0)

@@ -0,0 +1,208 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
import torch.nn as nn

from modelscope.utils.torch_utils import is_master


class SparseBinarizer(torch.autograd.Function):

    @staticmethod
    def forward(ctx, mask_scores, sparsity):
        num_prune = int(mask_scores.numel() * sparsity)
        prune_indices = torch.argsort(mask_scores.reshape(-1))[:num_prune]
        mask = mask_scores.clone().fill_(1)
        mask.reshape(-1)[prune_indices] = 0.0
        return mask

    @staticmethod
    def backward(ctx, gradOutput):
        return gradOutput, None


class SparseLinear(nn.Module):
    """
    Fully Connected layer with on the fly adaptive mask.
    """

    def __init__(
        self,
        module,
        pruning_method='pst',
        weight_rank=8,
        weight_beta=1.0,
        mask_rank=8,
        mask_alpha1=1.0,
        mask_alpha2=1.0,
    ):
        super(SparseLinear, self).__init__()
        self.module = module
        out_features = self.module.weight.shape[0]
        in_features = self.module.weight.shape[1]

        self.weight = self.module.weight
        self.module.weight = None
        self.module._parameters.pop('weight')

        self.pruning_method = pruning_method

        self.cur_sparsity = 0.0

        if self.pruning_method == 'pst':
            self.weight_rank = weight_rank
            self.weight_beta = weight_beta
            self.mask_rank = mask_rank
            self.mask_alpha1 = mask_alpha1
            self.mask_alpha2 = mask_alpha2

            # create trainable params
            self.weight_U = nn.Parameter(
                torch.randn(out_features, self.weight_rank).to(
                    device=self.weight.device, dtype=self.weight.dtype))
            self.weight_V = nn.Parameter(
                torch.zeros(self.weight_rank, in_features).to(
                    device=self.weight.device, dtype=self.weight.dtype))

            self.mask_scores_A = nn.Parameter(
                torch.randn(out_features, self.mask_rank).to(
                    device=self.weight.device, dtype=self.weight.dtype))
            self.mask_scores_B = nn.Parameter(
                torch.zeros(self.mask_rank, in_features).to(
                    device=self.weight.device, dtype=self.weight.dtype))
            self.mask_scores_R = nn.Parameter(
                torch.zeros(out_features).to(
                    device=self.weight.device, dtype=self.weight.dtype))
            self.mask_scores_C = nn.Parameter(
                torch.zeros(in_features).to(
                    device=self.weight.device, dtype=self.weight.dtype))

        self.weight.requires_grad = False
        if self.module.bias is not None:
            self.module.bias.requires_grad = False

    def forward(self, *inputs):
        if self.pruning_method == 'pst':
            weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V
            mask_scores = (
                weight.abs()
                + self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B
                + self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1)
                                      + self.mask_scores_C.unsqueeze(0)))

            mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity)
            masked_weight = mask * weight

            self.module.weight = masked_weight
            return self.module(*inputs)
        else:
            return self.module(*inputs)

    def convert(self):
        if self.pruning_method == 'pst':
            weight = self.weight + self.weight_beta * self.weight_U @ self.weight_V
            mask_scores = (
                weight.abs()
                + self.mask_alpha1 * self.mask_scores_A @ self.mask_scores_B
                + self.mask_alpha2 * (self.mask_scores_R.unsqueeze(1)
                                      + self.mask_scores_C.unsqueeze(0)))

            mask = SparseBinarizer.apply(mask_scores, self.cur_sparsity)

            masked_weight = mask * weight
            self.module.weight = nn.Parameter(masked_weight.data)


def _setattr(model, name, module):
    name_list = name.split('.')
    for name in name_list[:-1]:
        model = getattr(model, name)
    setattr(model, name_list[-1], module)


def convert_sparse_network(
    model,
    pruning_method,
    weight_rank,
    weight_beta,
    mask_rank,
    mask_alpha1,
    mask_alpha2,
    logger=None,
):
    compress_module = [nn.Linear]
    try:
        from megatron import mpu
        compress_module.extend(
            [mpu.RowParallelLinear, mpu.ColumnParallelLinear])
    except ImportError:
        pass

    for name, module in model.named_modules():
        if type(module) in compress_module:
            new_module = SparseLinear(
                module,
                pruning_method,
                weight_rank,
                weight_beta,
                mask_rank,
                mask_alpha1,
                mask_alpha2,
            )

            # replace original module by new sparse module
            _setattr(model, name, new_module)

            if is_master():
                if logger:
                    logger.info(f'convert {name} to sparse module.')
                else:
                    print(f'convert {name} to sparse module.')


def update_network_sparsity(model, sparsity):
    for name, module in model.named_modules():
        if isinstance(module, SparseLinear):
            module.cur_sparsity = sparsity


def schedule_sparsity_ratio(
    step,
    total_step,
    frequency,
    initial_warmup,
    final_warmup,
    initial_sparsity,
    final_sparsity,
):
    if step <= initial_warmup * total_step:
        sparsity = initial_sparsity
    elif step > (total_step - final_warmup * total_step):
        sparsity = final_sparsity
    else:
        spars_warmup_steps = initial_warmup * total_step
        spars_schedu_steps = (final_warmup + initial_warmup) * total_step
        step = (step - spars_warmup_steps) // frequency * frequency
        mul_coeff = 1 - step / (total_step - spars_schedu_steps)
        sparsity = final_sparsity + (initial_sparsity - final_sparsity) * (
            mul_coeff**3)
    return sparsity


def generate_sparse_model(model, logger=None):
    # generate sparse weight for saving
    for name, module in model.named_modules():
        if isinstance(module, SparseLinear):
            module.convert()

            _setattr(model, name, module.module)

            if is_master():
                if logger:
                    logger.info(f'convert {name} weight to sparse weight, \
                        sparsity ratio={torch.mean(1.0*(module.module.weight==0)).item()}.'
                                )
                else:
                    print(f'convert {name} weight to sparse, \
                        sparsity ratio={torch.mean(1.0*(module.module.weight==0)).item()}.'
                          )
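As a rough sanity check of these utilities, the sketch below wraps a toy model, raises the sparsity, and bakes the mask back into a plain Linear. It is not part of the commit and assumes this branch is installed, running single-process on CPU.

import torch
import torch.nn as nn

from modelscope.trainers.hooks.compression.utils import (
    SparseLinear, convert_sparse_network, generate_sparse_model,
    update_network_sparsity)

# Wrap the only Linear layer of a toy model in a SparseLinear.
model = nn.Sequential(nn.Linear(16, 32))
convert_sparse_network(
    model, pruning_method='pst', weight_rank=4, weight_beta=1.0,
    mask_rank=4, mask_alpha1=1.0, mask_alpha2=1.0)
assert isinstance(model[0], SparseLinear)

# Raise the target sparsity and run a forward pass through the masked weight.
update_network_sparsity(model, 0.5)
out = model(torch.randn(2, 16))
print(out.shape)  # torch.Size([2, 32])

# Bake the mask back into a plain Linear for saving; half of the entries are pruned.
generate_sparse_model(model)
print(torch.mean(1.0 * (model[0].weight == 0)).item())  # 0.5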

tests/trainers/hooks/compression/__init__.py  (+0, -0)


tests/trainers/hooks/compression/test_sparsity_hook.py  (+113, -0)

@@ -0,0 +1,113 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

import json
import numpy as np
import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR

from modelscope.metainfo import Trainers
from modelscope.models.base import Model
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile, TrainerStages
from modelscope.utils.test_utils import create_dummy_test_dataset

dummy_dataset = create_dummy_test_dataset(
    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 10)


class DummyModel(nn.Module, Model):

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 10)
        self.bn = nn.BatchNorm1d(10)

    def forward(self, feat, labels):
        x = self.linear(feat)

        x = self.bn(x)
        loss = torch.sum(x)
        return dict(logits=x, loss=loss)


class SparsityHookTest(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir)

    def test_sparsity_hook(self):
        json_cfg = {
            'task': 'image_classification',
            'train': {
                'work_dir':
                self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'hooks': [{
                    'type': 'SparsityHook',
                    'pruning_method': 'pst',
                    'config': {
                        'weight_rank': 1,
                        'mask_rank': 1,
                        'final_sparsity': 0.9,
                        'frequency': 1,
                    },
                }],
            },
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        model = DummyModel()
        optimizer = SGD(model.parameters(), lr=0.01)
        lr_scheduler = MultiStepLR(optimizer, milestones=[2, 4])
        trainer_name = Trainers.default
        kwargs = dict(
            cfg_file=config_path,
            model=model,
            train_dataset=dummy_dataset,
            optimizers=(optimizer, lr_scheduler),
            max_epochs=5,
            device='cpu',
        )

        trainer = build_trainer(trainer_name, kwargs)
        train_dataloader = trainer._build_dataloader_with_dataset(
            trainer.train_dataset, **trainer.cfg.train.get('dataloader', {}))
        trainer.register_optimizers_hook()
        trainer.register_hook_from_cfg(trainer.cfg.train.hooks)
        trainer.train_dataloader = train_dataloader
        trainer.data_loader = train_dataloader
        trainer.invoke_hook(TrainerStages.before_run)
        for i in range(trainer._epoch, trainer._max_epochs):
            trainer.invoke_hook(TrainerStages.before_train_epoch)
            for _, data_batch in enumerate(train_dataloader):
                trainer.invoke_hook(TrainerStages.before_train_iter)
                trainer.train_step(trainer.model, data_batch)
                trainer.invoke_hook(TrainerStages.after_train_iter)
            trainer.invoke_hook(TrainerStages.after_train_epoch)
        trainer.invoke_hook(TrainerStages.after_run)

        self.assertEqual(
            torch.mean(1.0 * (trainer.model.linear.weight == 0)), 0.9)


if __name__ == '__main__':
    unittest.main()
