
[to #43627720] support distributed training

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9551089

    * support distributed training
master
jiangnana.jnn, 3 years ago
parent commit 21437650f1
13 changed files with 652 additions and 65 deletions
  1. +1 -1     modelscope/metrics/sequence_classification_metric.py
  2. +3 -3     modelscope/models/nlp/task_model.py
  3. +12 -10   modelscope/trainers/hooks/checkpoint_hook.py
  4. +4 -4     modelscope/trainers/hooks/logger/text_logger_hook.py
  5. +2 -0     modelscope/trainers/parallel/__init__.py
  6. +20 -0    modelscope/trainers/parallel/builder.py
  7. +23 -0    modelscope/trainers/parallel/utils.py
  8. +46 -11   modelscope/trainers/trainer.py
  9. +33 -18   modelscope/trainers/utils/inference.py
  10. +172 -1  modelscope/utils/test_utils.py
  11. +57 -2   modelscope/utils/torch_utils.py
  12. +15 -15  modelscope/utils/utils.py
  13. +264 -0  tests/trainers/test_trainer_gpu.py
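
Taken together, these changes let EpochBasedTrainer run distributed training: passing launcher='pytorch' triggers init_dist, the model is wrapped through the new PARALLEL registry, dataloaders get a DistributedSampler, and evaluation results are collected and broadcast across ranks. A minimal launch sketch (not part of the diff; the config path, model and datasets are placeholders to be supplied by the caller):

# train.py -- sketch only; cfg_file, my_model and the datasets are placeholders.
# Launch with: python -m torch.distributed.launch --nproc_per_node=2 train.py
from modelscope.trainers import build_trainer

kwargs = dict(
    cfg_file='configuration.json',   # placeholder configuration path
    model=my_model,                  # placeholder nn.Module / TorchModel
    train_dataset=train_dataset,     # placeholder datasets
    eval_dataset=eval_dataset,
    device='gpu',
    launcher='pytorch')              # new in this commit: initializes torch.distributed

trainer = build_trainer('EpochBasedTrainer', kwargs)
trainer.train()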

+1 -1   modelscope/metrics/sequence_classification_metric.py   View File

@@ -24,7 +24,7 @@ class SequenceClassificationMetric(Metric):
self.labels = []

def add(self, outputs: Dict, inputs: Dict):
ground_truths = inputs[SequenceClassificationMetric.label_name]
ground_truths = inputs[self.label_name]
eval_results = outputs[OutputKeys.LOGITS]
self.preds.append(
torch_nested_numpify(torch_nested_detach(eval_results)))


+3 -3   modelscope/models/nlp/task_model.py   View File

@@ -424,7 +424,7 @@ class SingleBackboneTaskModelBase(BaseTaskModel):

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
"""default forward method is the backbone-only forward"""
if if_func_receive_dict_inputs(self.backbone.forward, input):
if if_func_receive_dict_inputs(self.backbone.forward):
outputs = self.backbone.forward(input)
else:
outputs = self.backbone.forward(**input)
@@ -472,13 +472,13 @@ class EncoderDecoderTaskModelBase(BaseTaskModel):
return getattr(self, self._decoder_prefix)

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
if if_func_receive_dict_inputs(self.encoder_.forward, input):
if if_func_receive_dict_inputs(self.encoder_.forward):
encoder_outputs = self.encoder_.forward(input)
else:
encoder_outputs = self.encoder_.forward(**input)
decoder_inputs = self.project_decoder_inputs_and_mediate(
input, encoder_outputs)
if if_func_receive_dict_inputs(self.decoder_.forward, input):
if if_func_receive_dict_inputs(self.decoder_.forward):
outputs = self.decoder_.forward(decoder_inputs)
else:
outputs = self.decoder_.forward(**decoder_inputs)


+12 -10   modelscope/trainers/hooks/checkpoint_hook.py   View File

@@ -5,7 +5,7 @@ from modelscope import __version__
from modelscope.utils.checkpoint import save_checkpoint
from modelscope.utils.constant import LogKeys
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import get_dist_info
from modelscope.utils.torch_utils import is_master
from .builder import HOOKS
from .hook import Hook
from .priority import Priority
@@ -47,15 +47,18 @@ class CheckpointHook(Hook):
else:
self.logger = trainer.logger

self.logger.info(f'Checkpoints will be saved to {self.save_dir}')
if is_master():
self.logger.info(f'Checkpoints will be saved to {self.save_dir}')

def after_train_epoch(self, trainer):
if not self.by_epoch:
return

if self._should_save(trainer):
self.logger.info(f'Saving checkpoint at {trainer.epoch + 1} epoch')
self._save_checkpoint(trainer)
if is_master():
self.logger.info(
f'Saving checkpoint at {trainer.epoch + 1} epoch')
self._save_checkpoint(trainer)

def _save_checkpoint(self, trainer):
if self.by_epoch:
@@ -65,18 +68,17 @@ class CheckpointHook(Hook):
cur_save_name = os.path.join(
self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth')

rank, _ = get_dist_info()
if rank == 0:
save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)

def after_train_iter(self, trainer):
if self.by_epoch:
return

if self._should_save(trainer):
self.logger.info(
f'Saving checkpoint at {trainer.iter + 1} iterations')
self._save_checkpoint(trainer)
if is_master():
self.logger.info(
f'Saving checkpoint at {trainer.iter + 1} iterations')
self._save_checkpoint(trainer)

def _should_save(self, trainer):
if self.by_epoch:


+4 -4   modelscope/trainers/hooks/logger/text_logger_hook.py   View File

@@ -11,7 +11,7 @@ from torch import distributed as dist
from modelscope.trainers.hooks.builder import HOOKS
from modelscope.trainers.hooks.logger.base import LoggerHook
from modelscope.utils.constant import LogKeys, ModeKeys
from modelscope.utils.torch_utils import get_dist_info
from modelscope.utils.torch_utils import get_dist_info, is_master


@HOOKS.register_module()
@@ -130,7 +130,8 @@ class TextLoggerHook(LoggerHook):
log_items.append(f'{name}: {val}')
log_str += ', '.join(log_items)

trainer.logger.info(log_str)
if is_master():
trainer.logger.info(log_str)

def _dump_log(self, log_dict):
# dump log in json format
@@ -138,8 +139,7 @@ class TextLoggerHook(LoggerHook):
for k, v in log_dict.items():
json_log[k] = self._round_float(v)

rank, _ = get_dist_info()
if rank == 0:
if is_master():
with open(self.json_log_path, 'a+') as f:
json.dump(json_log, f)
f.write('\n')


+2 -0   modelscope/trainers/parallel/__init__.py   View File

@@ -0,0 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .builder import PARALLEL

+20 -0   modelscope/trainers/parallel/builder.py   View File

@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from torch.nn.parallel.distributed import DistributedDataParallel

from modelscope.utils.config import ConfigDict
from modelscope.utils.registry import Registry, build_from_cfg

PARALLEL = Registry('parallel')
PARALLEL.register_module(
module_name='DistributedDataParallel', module_cls=DistributedDataParallel)


def build_parallel(cfg: ConfigDict, default_args: dict = None):
""" build parallel

Args:
cfg (:obj:`ConfigDict`): config dict for parallel object.
default_args (dict, optional): Default initialization arguments.
"""
return build_from_cfg(cfg, PARALLEL, default_args=default_args)
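
A minimal usage sketch of build_parallel (not part of the diff), mirroring what EpochBasedTrainer.to_parallel passes in; it assumes a distributed process group is already initialized and the current rank's CUDA device is set:

import torch
from torch import nn
from modelscope.trainers.parallel.builder import build_parallel

model = nn.Linear(5, 4).cuda()
ddp_cfg = dict(
    type='DistributedDataParallel',
    module=model,
    device_ids=[torch.cuda.current_device()])
ddp_model = build_parallel(ddp_cfg)   # returns a DistributedDataParallel wrapping ``model``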

+23 -0   modelscope/trainers/parallel/utils.py   View File

@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .builder import PARALLEL


def is_parallel(module):
"""Check if a module is wrapped by parallel object.

The following modules are regarded as parallel object:
- torch.nn.parallel.DataParallel
- torch.nn.parallel.distributed.DistributedDataParallel
You may add your own parallel object by registering it to `modelscope.parallel.PARALLEL`.

Args:
module (nn.Module): The module to be checked.

Returns:
bool: True if the module is wrapped by a parallel object.
"""
module_wrappers = []
for group, module_dict in PARALLEL.modules.items():
module_wrappers.extend(list(module_dict.values()))

return isinstance(module, tuple(module_wrappers))
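
A small sketch of the check (not part of the diff). Note that builder.py only registers DistributedDataParallel in this change, so only that wrapper is detected unless others are registered as well:

from torch import nn
from modelscope.trainers.parallel.utils import is_parallel

model = nn.Linear(5, 4)
print(is_parallel(model))   # False: a plain module is not a registered wrapper
# After init_dist and moving the model to a GPU, a DDP-wrapped model is detected:
# ddp = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[0])
# print(is_parallel(ddp))   # True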

+46 -11   modelscope/trainers/trainer.py   View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path
import os
import random
import time
from collections.abc import Mapping
@@ -32,12 +32,15 @@ from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Hubs, ModeKeys,
from modelscope.utils.logger import get_logger
from modelscope.utils.registry import build_from_cfg
from modelscope.utils.tensor_utils import torch_default_data_collator
from modelscope.utils.torch_utils import create_device, get_dist_info
from modelscope.utils.torch_utils import (broadcast, create_device,
get_dist_info, init_dist)
from modelscope.utils.utils import if_func_receive_dict_inputs
from .base import BaseTrainer
from .builder import TRAINERS
from .default_config import DEFAULT_CONFIG
from .hooks.hook import Hook
from .parallel.builder import build_parallel
from .parallel.utils import is_parallel


@TRAINERS.register_module()
@@ -150,11 +153,16 @@ class EpochBasedTrainer(BaseTrainer):
# TODO @wenmeng.zwm add seed init fn
self._seed = 0

if kwargs.get('launcher', None) is not None:
init_dist(kwargs['launcher'])

self._dist = get_dist_info()[1] > 1

# model placement
if self.device.type == 'cuda':
self.model.to(self.device)
if not is_parallel(self.model) and self._dist:
self.model = self.to_parallel(self.model)

@property
def mode(self):
@@ -287,7 +295,10 @@ class EpochBasedTrainer(BaseTrainer):
self.train_dataloader = self.get_train_dataloader()
else:
self.train_dataloader = self._build_dataloader_with_dataset(
self.train_dataset, **self.cfg.train.get('dataloader', {}))
self.train_dataset,
dist=self._dist,
seed=self._seed,
**self.cfg.train.get('dataloader', {}))
self.data_loader = self.train_dataloader

self.register_optimizers_hook()
@@ -303,15 +314,21 @@ class EpochBasedTrainer(BaseTrainer):
self.eval_dataloader = self.get_eval_data_loader()
else:
self.eval_dataloader = self._build_dataloader_with_dataset(
self.eval_dataset, **self.cfg.evaluation.get('dataloader', {}))
self.eval_dataset,
dist=self._dist,
seed=self._seed,
**self.cfg.evaluation.get('dataloader', {}))
self.data_loader = self.eval_dataloader
metric_classes = [build_metric(metric) for metric in self.metrics]
self.evaluation_loop(self.eval_dataloader, checkpoint_path,
metric_classes)
rank, world_size = get_dist_info()
metric_values = {}
for metric_cls in metric_classes:
metric_values.update(metric_cls.evaluate())
if rank == 0:
for metric_cls in metric_classes:
metric_values.update(metric_cls.evaluate())
if world_size > 1:
metric_values = broadcast(metric_values, 0)
return metric_values

def build_model(self) -> Union[nn.Module, TorchModel]:
@@ -328,6 +345,20 @@ class EpochBasedTrainer(BaseTrainer):
elif isinstance(model, nn.Module):
return model

def to_parallel(self, model) -> Union[nn.Module, TorchModel]:
# a custom ``parallel`` config section, if present, takes precedence over the default DDP
if self.cfg.get('parallel', None) is not None:
self.cfg.parallel.update(
dict(module=model, device_ids=[torch.cuda.current_device()]))
return build_parallel(self.cfg.parallel)

dp_cfg = dict(
type='DistributedDataParallel',
module=model,
device_ids=[torch.cuda.current_device()])

return build_parallel(dp_cfg)

def collate_fn(self, data):
"""Prepare the input just before the forward function.
This method will move the tensors to the right device.
@@ -378,8 +409,9 @@ class EpochBasedTrainer(BaseTrainer):
self._mode = ModeKeys.TRAIN
inputs = self.collate_fn(inputs)
# call model forward but not __call__ to skip postprocess
if isinstance(inputs, Mapping) and not if_func_receive_dict_inputs(
model.forward, inputs):
if isinstance(
inputs,
Mapping) and not if_func_receive_dict_inputs(model.forward):
train_outputs = model.forward(**inputs)
else:
train_outputs = model.forward(inputs)
@@ -444,7 +476,10 @@ class EpochBasedTrainer(BaseTrainer):
train_data, mode=ModeKeys.TRAIN)

data_loader = self._build_dataloader_with_dataset(
self.train_dataset, **self.cfg.train.get('dataloader', {}))
self.train_dataset,
dist=self._dist,
seed=self._seed,
**self.cfg.train.get('dataloader', {}))
return data_loader

def get_eval_data_loader(self):
@@ -594,7 +629,7 @@ class EpochBasedTrainer(BaseTrainer):

if dist:
sampler = DistributedSampler(
dataset, world_size, rank, shuffle=shuffle, seed=seed)
dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)
else:
sampler = None
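
For reference, the new to_parallel method first looks for a ``parallel`` section in the configuration and only falls back to a plain DistributedDataParallel wrapper when none is given. A hypothetical config fragment (the extra keys are ordinary DistributedDataParallel kwargs shown only as an illustration; ``module`` and ``device_ids`` are filled in by the trainer itself):

cfg = dict(
    parallel=dict(
        type='DistributedDataParallel',
        find_unused_parameters=True,   # illustrative DDP kwargs
        broadcast_buffers=False))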



+33 -18   modelscope/trainers/utils/inference.py   View File

@@ -3,7 +3,6 @@
import os
import pickle
import shutil
import tempfile
import time
from collections.abc import Mapping

@@ -11,8 +10,7 @@ import torch
from torch import distributed as dist
from tqdm import tqdm

from modelscope.models.base import Model
from modelscope.utils.torch_utils import get_dist_info
from modelscope.utils.torch_utils import get_dist_info, is_master, make_tmp_dir
from modelscope.utils.utils import if_func_receive_dict_inputs


@@ -40,7 +38,7 @@ def single_gpu_test(model,
with torch.no_grad():
if isinstance(data,
Mapping) and not if_func_receive_dict_inputs(
model.forward, data):
model.forward):

result = model(**data)
else:
@@ -82,25 +80,28 @@ def multi_gpu_test(model,
"""
model.eval()
results = []
data_list = []
dataset = data_loader.dataset

time.sleep(2) # This line can prevent deadlock problem in some cases.

rank, world_size = get_dist_info()

count = 0
with tqdm(total=len(dataset), desc='test samples with multi gpus') as pbar:
for _, data in enumerate(data_loader):
if data_collate_fn is not None:
data = data_collate_fn(data)
data_list.append(data)
with torch.no_grad():
if isinstance(data,
Mapping) and not if_func_receive_dict_inputs(
model.forward, data):
model.forward):
result = model(**data)
else:
result = model(data)
results.extend(result)
results.append(result)

rank, world_size = get_dist_info()
if rank == 0:
batch_size = len(result)
batch_size_all = batch_size * world_size
@@ -110,15 +111,26 @@ def multi_gpu_test(model,
for _ in range(batch_size_all):
pbar.update()

# collect results from all ranks
# TODO: allgather data list may cost a lot of memory and needs to be redesigned
# collect results and data from all ranks
if gpu_collect:
results = collect_results_gpu(results, len(dataset))
data_list = collect_results_gpu(data_list, len(dataset))
else:
results = collect_results_cpu(results, len(dataset), tmpdir)
ground_truths = [dataset[i] for i in range(len(dataset))]
if metric_classes is not None:
for metric_cls in metric_classes:
metric_cls.add(results, ground_truths)
if tmpdir is None:
tmpdir = make_tmp_dir()
results = collect_results_cpu(results, len(dataset),
os.path.join(tmpdir, 'predict'))
data_list = collect_results_cpu(data_list, len(dataset),
os.path.join(tmpdir, 'groundtruth'))

if is_master():
assert len(data_list) == len(
results), f'size mismatch {len(data_list)} and {len(results)}'
if metric_classes is not None:
for i in range(len(data_list)):
for metric_cls in metric_classes:
metric_cls.add(results[i], data_list[i])


def collect_results_cpu(result_part, size, tmpdir=None):
@@ -140,13 +152,15 @@ def collect_results_cpu(result_part, size, tmpdir=None):
list: The collected results.
"""
rank, world_size = get_dist_info()
# TODO create a random tmp dir if it is not specified
if tmpdir is None:
tmpdir = tempfile.gettempdir()
if not os.path.exists(tmpdir):
tmpdir = make_tmp_dir()
if not os.path.exists(tmpdir) and is_master():
os.makedirs(tmpdir)
dist.barrier()

# dump the part result to the dir
pickle.dump(result_part, os.path.join(tmpdir, f'part_{rank}.pkl'))
with open(os.path.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f:
pickle.dump(result_part, f)
dist.barrier()
# collect all parts
if rank != 0:
@@ -156,7 +170,8 @@ def collect_results_cpu(result_part, size, tmpdir=None):
part_list = []
for i in range(world_size):
part_file = os.path.join(tmpdir, f'part_{i}.pkl')
part_result = pickle.load(part_file)
with open(part_file, 'rb') as f:
part_result = pickle.load(f)
# When data is severely insufficient, an empty part_result
# on a certain gpu could make the overall outputs empty.
if part_result:
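
A sketch of the CPU collection path (not part of the diff), following the mmcv-style behaviour this module borrows: each rank dumps its partial results to a shared temporary directory, rank 0 merges them in dataset order, and the other ranks receive None. It assumes the script is started with torch.distributed.launch on GPUs:

from modelscope.trainers.utils.inference import collect_results_cpu
from modelscope.utils.torch_utils import get_dist_info, init_dist

init_dist(launcher='pytorch')
rank, world_size = get_dist_info()

part = [f'rank{rank}_sample{i}' for i in range(5)]   # 5 results per rank
merged = collect_results_cpu(part, 5 * world_size)   # tmpdir defaults to make_tmp_dir()
if rank == 0:
    print(len(merged))   # 5 * world_size items, interleaved across ranks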


+172 -1   modelscope/utils/test_utils.py   View File

@@ -1,16 +1,23 @@
#!/usr/bin/env python
# Copyright (c) Alibaba, Inc. and its affiliates.

import copy
import os
import pickle
import shutil
import socket
import subprocess
import sys
import tarfile
import tempfile
import unittest

import numpy as np
import requests
from datasets import Dataset
from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE

from modelscope.msdatasets import MsDataset
from .torch_utils import _find_free_port

TEST_LEVEL = 2
TEST_LEVEL_STR = 'TEST_LEVEL'
@@ -62,3 +69,167 @@ def download_and_untar(fpath, furl, dst) -> str:
t.extractall(path=dst)

return target_dir_path


_DIST_SCRIPT_TEMPLATE = """
import ast
import argparse
import pickle
import torch
from torch import distributed as dist
from modelscope.utils.torch_utils import get_dist_info
import {}

parser = argparse.ArgumentParser()
parser.add_argument('--save_all_ranks', type=ast.literal_eval, help='save all ranks results')
parser.add_argument('--save_file', type=str, help='save file')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()


def main():
results = {}.{}({}) # module.func(params)
if args.save_all_ranks:
save_file = args.save_file + str(dist.get_rank())
with open(save_file, 'wb') as f:
pickle.dump(results, f)
else:
rank, _ = get_dist_info()
if rank == 0:
with open(args.save_file, 'wb') as f:
pickle.dump(results, f)


if __name__ == '__main__':
main()
"""


class DistributedTestCase(unittest.TestCase):
"""Distributed TestCase for test function with distributed mode.
Examples:
import torch
from torch import distributed as dist
from modelscope.utils.torch_utils import init_dist

def _test_func(*args, **kwargs):
init_dist(launcher='pytorch')
rank = dist.get_rank()
if rank == 0:
value = torch.tensor(1.0).cuda()
else:
value = torch.tensor(2.0).cuda()
dist.all_reduce(value)
return value.cpu().numpy()

class DistTest(DistributedTestCase):
def test_function_dist(self):
args = () # args should be python builtin type
kwargs = {} # kwargs should be python builtin type
self.start(
_test_func,
num_gpus=2,
assert_callback=lambda x: self.assertEqual(x, 3.0),
*args,
**kwargs,
)
"""

def _start(self,
dist_start_cmd,
func,
num_gpus,
assert_callback=None,
save_all_ranks=False,
*args,
**kwargs):
script_path = func.__code__.co_filename
script_dir, script_name = os.path.split(script_path)
script_name = os.path.splitext(script_name)[0]
func_name = func.__qualname__

func_params = []
for arg in args:
if isinstance(arg, str):
arg = ('\'{}\''.format(arg))
func_params.append(str(arg))

for k, v in kwargs.items():
if isinstance(v, str):
v = ('\'{}\''.format(v))
func_params.append('{}={}'.format(k, v))

func_params = ','.join(func_params).strip(',')

tmp_run_file = tempfile.NamedTemporaryFile(suffix='.py').name
tmp_res_file = tempfile.NamedTemporaryFile(suffix='.pkl').name

with open(tmp_run_file, 'w') as f:
print('save temporary run file to : {}'.format(tmp_run_file))
print('save results to : {}'.format(tmp_res_file))
run_file_content = _DIST_SCRIPT_TEMPLATE.format(
script_name, script_name, func_name, func_params)
f.write(run_file_content)

tmp_res_files = []
if save_all_ranks:
for i in range(num_gpus):
tmp_res_files.append(tmp_res_file + str(i))
else:
tmp_res_files = [tmp_res_file]
self.addCleanup(self.clean_tmp, [tmp_run_file] + tmp_res_files)

tmp_env = copy.deepcopy(os.environ)
tmp_env['PYTHONPATH'] = ':'.join(
(tmp_env.get('PYTHONPATH', ''), script_dir)).lstrip(':')
script_params = '--save_all_ranks=%s --save_file=%s' % (save_all_ranks,
tmp_res_file)
script_cmd = '%s %s %s' % (dist_start_cmd, tmp_run_file, script_params)
print('script command: %s' % script_cmd)
res = subprocess.call(script_cmd, shell=True, env=tmp_env)

script_res = []
for res_file in tmp_res_files:
with open(res_file, 'rb') as f:
script_res.append(pickle.load(f))
if not save_all_ranks:
script_res = script_res[0]

if assert_callback:
assert_callback(script_res)

self.assertEqual(
res,
0,
msg='The test function ``{}`` in ``{}`` run failed!'.format(
func_name, script_name))

return script_res

def start(self,
func,
num_gpus,
assert_callback=None,
save_all_ranks=False,
*args,
**kwargs):
ip = socket.gethostbyname(socket.gethostname())
dist_start_cmd = '%s -m torch.distributed.launch --nproc_per_node=%d --master_addr=\'%s\' --master_port=%s' % (
sys.executable, num_gpus, ip, _find_free_port())

return self._start(
dist_start_cmd=dist_start_cmd,
func=func,
num_gpus=num_gpus,
assert_callback=assert_callback,
save_all_ranks=save_all_ranks,
*args,
**kwargs)

def clean_tmp(self, tmp_file_list):
for file in tmp_file_list:
if os.path.exists(file):
if os.path.isdir(file):
shutil.rmtree(file)
else:
os.remove(file)

+57 -2   modelscope/utils/torch_utils.py   View File

@@ -1,11 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Following code is partially borrowed from openmmlab/mmcv

import functools
import os
import pickle
import socket
import subprocess
from collections import OrderedDict
import tempfile
from typing import Callable, List, Optional, Tuple

import torch
@@ -116,6 +116,11 @@ def get_dist_info() -> Tuple[int, int]:
return rank, world_size


def is_master():
rank, _ = get_dist_info()
return rank == 0


def master_only(func: Callable) -> Callable:

@functools.wraps(func)
@@ -136,3 +141,53 @@ def create_device(cpu: bool = False) -> torch.DeviceObjType:
device = torch.device('cpu')

return device


def make_tmp_dir():
"""Make sure each rank has the same temporary directory on the distributed mode.
"""
rank, world_size = get_dist_info()
if world_size <= 1:
return tempfile.mkdtemp()

tmpdir = None
if rank == 0:
tmpdir = tempfile.mkdtemp()

dist.barrier()
tmpdir = broadcast(tmpdir, 0)

return tmpdir


def broadcast(inputs, src):
"""
Broadcasts the inputs to all ranks.

Arguments:
inputs : Any objects that can be serialized by pickle.
src (int): Source rank.
Returns:
Each rank returns the same value as src.
"""
rank, _ = get_dist_info()
shape_tensor = torch.tensor([0], device='cuda')

if rank == src:
inputs_tensor = torch.tensor(
bytearray(pickle.dumps(inputs)), dtype=torch.uint8, device='cuda')
shape_tensor = torch.tensor(inputs_tensor.shape, device='cuda')

dist.barrier()
dist.broadcast(shape_tensor, src)

if rank != src:
inputs_tensor = torch.full((shape_tensor.item(), ),
0,
dtype=torch.uint8,
device='cuda')

dist.barrier()
dist.broadcast(inputs_tensor, src)

return pickle.loads(inputs_tensor.cpu().numpy().tobytes())
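
A sketch of how the trainer uses these helpers during evaluation (not part of the diff): rank 0 computes the metric values and every other rank receives the same dict via broadcast. It assumes the script is started with torch.distributed.launch on GPUs:

from modelscope.utils.torch_utils import broadcast, get_dist_info, init_dist

init_dist(launcher='pytorch')
rank, world_size = get_dist_info()

metric_values = {'accuracy': 0.9} if rank == 0 else {}   # dummy metric on rank 0
if world_size > 1:
    metric_values = broadcast(metric_values, 0)
print(rank, metric_values)   # every rank prints {'accuracy': 0.9}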

+15 -15   modelscope/utils/utils.py   View File

@@ -4,30 +4,30 @@ import inspect
import os


def if_func_receive_dict_inputs(func, inputs):
# TODO: remove this api, unify to flattened args
def if_func_receive_dict_inputs(func):
"""to decide if a func could recieve dict inputs or not

Args:
func (class): the target function to be inspected
inputs (dicts): the inputs that will send to the function

Returns:
bool: if func recieve dict, then recieve True

Examples:
input = {"input_dict":xxx, "attention_masked":xxx},
function(self, inputs) then return True
function(inputs) then return True
function(self, input_dict, attention_masked) then return False
bool: if func only has one arg ``input`` or ``inputs``, return True, else return False
"""
signature = inspect.signature(func)
func_inputs = list(signature.parameters.keys() - set(['self']))
mismatched_inputs = list(set(func_inputs) - set(inputs))
if len(func_inputs) == len(mismatched_inputs):
return True
else:
full_args_spec = inspect.getfullargspec(func)
varargs = full_args_spec.varargs
varkw = full_args_spec.varkw
if not (varargs is None and varkw is None):
return False

args = [] if not full_args_spec.args else full_args_spec.args
args.pop(0) if (args and args[0] in ['self', 'cls']) else args

if len(args) == 1 and args[0] in ['input', 'inputs']:
return True

return False


def get_default_cache_dir():
"""


+264 -0   tests/trainers/test_trainer_gpu.py   View File

@@ -0,0 +1,264 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import os
import shutil
import tempfile
import unittest

import json
import numpy as np
import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

from modelscope.metrics.builder import MetricKeys
from modelscope.trainers import build_trainer
from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile
from modelscope.utils.test_utils import (DistributedTestCase,
create_dummy_test_dataset, test_level)


class DummyMetric:

def __call__(self, ground_truth, predict_results):
return {'accuracy': 0.5}


dummy_dataset_small = create_dummy_test_dataset(
np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)

dummy_dataset_big = create_dummy_test_dataset(
np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40)


class DummyModel(nn.Module):

def __init__(self):
super().__init__()
self.linear = nn.Linear(5, 4)
self.bn = nn.BatchNorm1d(4)

def forward(self, feat, labels):
x = self.linear(feat)

x = self.bn(x)
loss = torch.sum(x)
return dict(logits=x, loss=loss)


def train_func(work_dir, dist=False):
json_cfg = {
'train': {
'work_dir': work_dir,
'dataloader': {
'batch_size_per_gpu': 2,
'workers_per_gpu': 1
},
'hooks': [{
'type': 'EvaluationHook',
'interval': 1
}]
},
'evaluation': {
'dataloader': {
'batch_size_per_gpu': 1,
'workers_per_gpu': 1,
'shuffle': False
},
'metrics': ['seq_cls_metric']
}
}

config_path = os.path.join(work_dir, ModelFile.CONFIGURATION)
with open(config_path, 'w') as f:
json.dump(json_cfg, f)

model = DummyModel()
optimizer = SGD(model.parameters(), lr=0.01)
lr_scheduler = StepLR(optimizer, 2)
trainer_name = 'EpochBasedTrainer'
kwargs = dict(
cfg_file=config_path,
model=model,
data_collator=None,
train_dataset=dummy_dataset_big,
eval_dataset=dummy_dataset_small,
optimizers=(optimizer, lr_scheduler),
max_epochs=3,
device='gpu',
launcher='pytorch' if dist else None)

trainer = build_trainer(trainer_name, kwargs)
trainer.train()


@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest')
class TrainerTestSingleGpu(unittest.TestCase):

def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)

def tearDown(self):
super().tearDown()
shutil.rmtree(self.tmp_dir)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_single_gpu(self):
train_func(self.tmp_dir)

results_files = os.listdir(self.tmp_dir)
json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
self.assertEqual(len(json_files), 1)

with open(json_files[0], 'r') as f:
lines = [i.strip() for i in f.readlines()]
self.assertDictContainsSubset(
{
LogKeys.MODE: ModeKeys.TRAIN,
LogKeys.EPOCH: 1,
LogKeys.ITER: 10,
LogKeys.LR: 0.01
}, json.loads(lines[0]))
self.assertDictContainsSubset(
{
LogKeys.MODE: ModeKeys.TRAIN,
LogKeys.EPOCH: 1,
LogKeys.ITER: 20,
LogKeys.LR: 0.01
}, json.loads(lines[1]))
self.assertDictContainsSubset(
{
LogKeys.MODE: ModeKeys.EVAL,
LogKeys.EPOCH: 1,
LogKeys.ITER: 20
}, json.loads(lines[2]))
self.assertDictContainsSubset(
{
LogKeys.MODE: ModeKeys.TRAIN,
LogKeys.EPOCH: 2,
LogKeys.ITER: 10,
LogKeys.LR: 0.001
}, json.loads(lines[3]))
self.assertDictContainsSubset(
{
LogKeys.MODE: ModeKeys.TRAIN,
LogKeys.EPOCH: 2,
LogKeys.ITER: 20,
LogKeys.LR: 0.001
}, json.loads(lines[4]))
self.assertDictContainsSubset(
{
LogKeys.MODE: ModeKeys.EVAL,
LogKeys.EPOCH: 2,
LogKeys.ITER: 20
}, json.loads(lines[5]))
self.assertDictContainsSubset(
{
LogKeys.MODE: ModeKeys.TRAIN,
LogKeys.EPOCH: 3,
LogKeys.ITER: 10,
LogKeys.LR: 0.001
}, json.loads(lines[6]))
self.assertDictContainsSubset(
{
LogKeys.MODE: ModeKeys.TRAIN,
LogKeys.EPOCH: 3,
LogKeys.ITER: 20,
LogKeys.LR: 0.001
}, json.loads(lines[7]))
self.assertDictContainsSubset(
{
LogKeys.MODE: ModeKeys.EVAL,
LogKeys.EPOCH: 3,
LogKeys.ITER: 20
}, json.loads(lines[8]))
self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
for i in [0, 1, 3, 4, 6, 7]:
self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
self.assertIn(LogKeys.ITER_TIME, lines[i])
for i in [2, 5, 8]:
self.assertIn(MetricKeys.ACCURACY, lines[i])


@unittest.skipIf(not torch.cuda.is_available()
or torch.cuda.device_count() <= 1, 'distributed unittest')
class TrainerTestMultiGpus(DistributedTestCase):

def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)

def tearDown(self):
super().tearDown()
shutil.rmtree(self.tmp_dir)

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_multi_gpus(self):
self.start(train_func, num_gpus=2, work_dir=self.tmp_dir, dist=True)

results_files = os.listdir(self.tmp_dir)
json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
self.assertEqual(len(json_files), 1)

with open(json_files[0], 'r') as f:
lines = [i.strip() for i in f.readlines()]

self.assertDictContainsSubset(
{
LogKeys.MODE: ModeKeys.TRAIN,
LogKeys.EPOCH: 1,
LogKeys.ITER: 10,
LogKeys.LR: 0.01
}, json.loads(lines[0]))
self.assertDictContainsSubset(
{
LogKeys.MODE: ModeKeys.EVAL,
LogKeys.EPOCH: 1,
LogKeys.ITER: 10
}, json.loads(lines[1]))
self.assertDictContainsSubset(
{
LogKeys.MODE: ModeKeys.TRAIN,
LogKeys.EPOCH: 2,
LogKeys.ITER: 10,
LogKeys.LR: 0.001
}, json.loads(lines[2]))
self.assertDictContainsSubset(
{
LogKeys.MODE: ModeKeys.EVAL,
LogKeys.EPOCH: 2,
LogKeys.ITER: 10
}, json.loads(lines[3]))
self.assertDictContainsSubset(
{
LogKeys.MODE: ModeKeys.TRAIN,
LogKeys.EPOCH: 3,
LogKeys.ITER: 10,
LogKeys.LR: 0.001
}, json.loads(lines[4]))
self.assertDictContainsSubset(
{
LogKeys.MODE: ModeKeys.EVAL,
LogKeys.EPOCH: 3,
LogKeys.ITER: 10
}, json.loads(lines[5]))
self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
for i in [0, 2, 4]:
self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
self.assertIn(LogKeys.ITER_TIME, lines[i])
for i in [1, 3, 5]:
self.assertIn(MetricKeys.ACCURACY, lines[i])


if __name__ == '__main__':
unittest.main()
