Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9551089 * support distributed training (branch: master)
| @@ -24,7 +24,7 @@ class SequenceClassificationMetric(Metric): | |||
| self.labels = [] | |||
| def add(self, outputs: Dict, inputs: Dict): | |||
| ground_truths = inputs[SequenceClassificationMetric.label_name] | |||
| ground_truths = inputs[self.label_name] | |||
| eval_results = outputs[OutputKeys.LOGITS] | |||
| self.preds.append( | |||
| torch_nested_numpify(torch_nested_detach(eval_results))) | |||
| @@ -424,7 +424,7 @@ class SingleBackboneTaskModelBase(BaseTaskModel): | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| """default forward method is the backbone-only forward""" | |||
| if if_func_receive_dict_inputs(self.backbone.forward, input): | |||
| if if_func_receive_dict_inputs(self.backbone.forward): | |||
| outputs = self.backbone.forward(input) | |||
| else: | |||
| outputs = self.backbone.forward(**input) | |||
| @@ -472,13 +472,13 @@ class EncoderDecoderTaskModelBase(BaseTaskModel): | |||
| return getattr(self, self._decoder_prefix) | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| if if_func_receive_dict_inputs(self.encoder_.forward, input): | |||
| if if_func_receive_dict_inputs(self.encoder_.forward): | |||
| encoder_outputs = self.encoder_.forward(input) | |||
| else: | |||
| encoder_outputs = self.encoder_.forward(**input) | |||
| decoder_inputs = self.project_decoder_inputs_and_mediate( | |||
| input, encoder_outputs) | |||
| if if_func_receive_dict_inputs(self.decoder_.forward, input): | |||
| if if_func_receive_dict_inputs(self.decoder_.forward): | |||
| outputs = self.decoder_.forward(decoder_inputs) | |||
| else: | |||
| outputs = self.decoder_.forward(**decoder_inputs) | |||
| @@ -5,7 +5,7 @@ from modelscope import __version__ | |||
| from modelscope.utils.checkpoint import save_checkpoint | |||
| from modelscope.utils.constant import LogKeys | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.torch_utils import get_dist_info | |||
| from modelscope.utils.torch_utils import is_master | |||
| from .builder import HOOKS | |||
| from .hook import Hook | |||
| from .priority import Priority | |||
| @@ -47,15 +47,18 @@ class CheckpointHook(Hook): | |||
| else: | |||
| self.logger = trainer.logger | |||
| self.logger.info(f'Checkpoints will be saved to {self.save_dir}') | |||
| if is_master(): | |||
| self.logger.info(f'Checkpoints will be saved to {self.save_dir}') | |||
| def after_train_epoch(self, trainer): | |||
| if not self.by_epoch: | |||
| return | |||
| if self._should_save(trainer): | |||
| self.logger.info(f'Saving checkpoint at {trainer.epoch + 1} epoch') | |||
| self._save_checkpoint(trainer) | |||
| if is_master(): | |||
| self.logger.info( | |||
| f'Saving checkpoint at {trainer.epoch + 1} epoch') | |||
| self._save_checkpoint(trainer) | |||
| def _save_checkpoint(self, trainer): | |||
| if self.by_epoch: | |||
| @@ -65,18 +68,17 @@ class CheckpointHook(Hook): | |||
| cur_save_name = os.path.join( | |||
| self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth') | |||
| rank, _ = get_dist_info() | |||
| if rank == 0: | |||
| save_checkpoint(trainer.model, cur_save_name, trainer.optimizer) | |||
| save_checkpoint(trainer.model, cur_save_name, trainer.optimizer) | |||
| def after_train_iter(self, trainer): | |||
| if self.by_epoch: | |||
| return | |||
| if self._should_save(trainer): | |||
| self.logger.info( | |||
| f'Saving checkpoint at {trainer.iter + 1} iterations') | |||
| self._save_checkpoint(trainer) | |||
| if is_master(): | |||
| self.logger.info( | |||
| f'Saving checkpoint at {trainer.iter + 1} iterations') | |||
| self._save_checkpoint(trainer) | |||
| def _should_save(self, trainer): | |||
| if self.by_epoch: | |||
| @@ -11,7 +11,7 @@ from torch import distributed as dist | |||
| from modelscope.trainers.hooks.builder import HOOKS | |||
| from modelscope.trainers.hooks.logger.base import LoggerHook | |||
| from modelscope.utils.constant import LogKeys, ModeKeys | |||
| from modelscope.utils.torch_utils import get_dist_info | |||
| from modelscope.utils.torch_utils import get_dist_info, is_master | |||
| @HOOKS.register_module() | |||
| @@ -130,7 +130,8 @@ class TextLoggerHook(LoggerHook): | |||
| log_items.append(f'{name}: {val}') | |||
| log_str += ', '.join(log_items) | |||
| trainer.logger.info(log_str) | |||
| if is_master(): | |||
| trainer.logger.info(log_str) | |||
| def _dump_log(self, log_dict): | |||
| # dump log in json format | |||
| @@ -138,8 +139,7 @@ class TextLoggerHook(LoggerHook): | |||
| for k, v in log_dict.items(): | |||
| json_log[k] = self._round_float(v) | |||
| rank, _ = get_dist_info() | |||
| if rank == 0: | |||
| if is_master(): | |||
| with open(self.json_log_path, 'a+') as f: | |||
| json.dump(json_log, f) | |||
| f.write('\n') | |||
| @@ -0,0 +1,2 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from .builder import PARALLEL | |||
| @@ -0,0 +1,20 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from torch.nn.parallel.distributed import DistributedDataParallel | |||
| from modelscope.utils.config import ConfigDict | |||
| from modelscope.utils.registry import Registry, build_from_cfg | |||
| PARALLEL = Registry('parallel') | |||
| PARALLEL.register_module( | |||
| module_name='DistributedDataParallel', module_cls=DistributedDataParallel) | |||
| def build_parallel(cfg: ConfigDict, default_args: dict = None): | |||
| """ build parallel | |||
| Args: | |||
| cfg (:obj:`ConfigDict`): config dict for parallel object. | |||
| default_args (dict, optional): Default initialization arguments. | |||
| """ | |||
| return build_from_cfg(cfg, PARALLEL, default_args=default_args) | |||
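Reviewer note: a minimal usage sketch of the new PARALLEL registry (the module, the import path inferred from the relative imports in this diff, and the device setup are illustrative assumptions, not part of this change):

    import torch
    from torch import nn
    from modelscope.trainers.parallel.builder import build_parallel  # path assumed

    model = nn.Linear(5, 4).cuda()           # illustrative module, already on the GPU
    cfg = dict(
        type='DistributedDataParallel',      # resolved against the PARALLEL registry above
        module=model,
        device_ids=[torch.cuda.current_device()])
    ddp_model = build_parallel(cfg)          # assumes an initialized process group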
| @@ -0,0 +1,23 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from .builder import PARALLEL | |||
| def is_parallel(module): | |||
| """Check if a module is wrapped by parallel object. | |||
| The following modules are regarded as parallel object: | |||
| - torch.nn.parallel.DataParallel | |||
| - torch.nn.parallel.distributed.DistributedDataParallel | |||
| You may add your own parallel object by registering it to `modelscope.parallel.PARALLEL`. | |||
| Args: | |||
| module (nn.Module): The module to be checked. | |||
| Returns: | |||
| bool: True if the module is wrapped by a parallel object. | |||
| """ | |||
| module_wrappers = [] | |||
| for group, module_dict in PARALLEL.modules.items(): | |||
| module_wrappers.extend(list(module_dict.values())) | |||
| return isinstance(module, tuple(module_wrappers)) | |||
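Reviewer note: sketch of the intended check-before-wrap pattern (mirrors how EpochBasedTrainer uses it further down; `model` and `world_size` are assumed to exist in the caller):

    if not is_parallel(model) and world_size > 1:
        model = build_parallel(dict(
            type='DistributedDataParallel',
            module=model,
            device_ids=[torch.cuda.current_device()]))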
| @@ -1,5 +1,5 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path | |||
| import os | |||
| import random | |||
| import time | |||
| from collections.abc import Mapping | |||
| @@ -32,12 +32,15 @@ from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Hubs, ModeKeys, | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.registry import build_from_cfg | |||
| from modelscope.utils.tensor_utils import torch_default_data_collator | |||
| from modelscope.utils.torch_utils import create_device, get_dist_info | |||
| from modelscope.utils.torch_utils import (broadcast, create_device, | |||
| get_dist_info, init_dist) | |||
| from modelscope.utils.utils import if_func_receive_dict_inputs | |||
| from .base import BaseTrainer | |||
| from .builder import TRAINERS | |||
| from .default_config import DEFAULT_CONFIG | |||
| from .hooks.hook import Hook | |||
| from .parallel.builder import build_parallel | |||
| from .parallel.utils import is_parallel | |||
| @TRAINERS.register_module() | |||
| @@ -150,11 +153,16 @@ class EpochBasedTrainer(BaseTrainer): | |||
| # TODO @wenmeng.zwm add seed init fn | |||
| self._seed = 0 | |||
| if kwargs.get('launcher', None) is not None: | |||
| init_dist(kwargs['launcher']) | |||
| self._dist = get_dist_info()[1] > 1 | |||
| # model placement | |||
| if self.device.type == 'cuda': | |||
| self.model.to(self.device) | |||
| if not is_parallel(self.model) and self._dist: | |||
| self.model = self.to_parallel(self.model) | |||
| @property | |||
| def mode(self): | |||
| @@ -287,7 +295,10 @@ class EpochBasedTrainer(BaseTrainer): | |||
| self.train_dataloader = self.get_train_dataloader() | |||
| else: | |||
| self.train_dataloader = self._build_dataloader_with_dataset( | |||
| self.train_dataset, **self.cfg.train.get('dataloader', {})) | |||
| self.train_dataset, | |||
| dist=self._dist, | |||
| seed=self._seed, | |||
| **self.cfg.train.get('dataloader', {})) | |||
| self.data_loader = self.train_dataloader | |||
| self.register_optimizers_hook() | |||
| @@ -303,15 +314,21 @@ class EpochBasedTrainer(BaseTrainer): | |||
| self.eval_dataloader = self.get_eval_data_loader() | |||
| else: | |||
| self.eval_dataloader = self._build_dataloader_with_dataset( | |||
| self.eval_dataset, **self.cfg.evaluation.get('dataloader', {})) | |||
| self.eval_dataset, | |||
| dist=self._dist, | |||
| seed=self._seed, | |||
| **self.cfg.evaluation.get('dataloader', {})) | |||
| self.data_loader = self.eval_dataloader | |||
| metric_classes = [build_metric(metric) for metric in self.metrics] | |||
| self.evaluation_loop(self.eval_dataloader, checkpoint_path, | |||
| metric_classes) | |||
| rank, world_size = get_dist_info() | |||
| metric_values = {} | |||
| for metric_cls in metric_classes: | |||
| metric_values.update(metric_cls.evaluate()) | |||
| if rank == 0: | |||
| for metric_cls in metric_classes: | |||
| metric_values.update(metric_cls.evaluate()) | |||
| if world_size > 1: | |||
| metric_values = broadcast(metric_values, 0) | |||
| return metric_values | |||
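Reviewer note: the change above evaluates metrics only on rank 0 and then broadcasts the result, so every rank returns an identical dict. A condensed sketch of the pattern (the metric value is illustrative):

    rank, world_size = get_dist_info()
    metric_values = {}
    if rank == 0:
        metric_values = {'accuracy': 0.5}            # normally from metric_cls.evaluate()
    if world_size > 1:
        metric_values = broadcast(metric_values, 0)  # broadcast() is added in torch_utils below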
| def build_model(self) -> Union[nn.Module, TorchModel]: | |||
| @@ -328,6 +345,20 @@ class EpochBasedTrainer(BaseTrainer): | |||
| elif isinstance(model, nn.Module): | |||
| return model | |||
| def to_parallel(self, model) -> Union[nn.Module, TorchModel]: | |||
| # config format to reserve custom ddp | |||
| if self.cfg.get('parallel', None) is not None: | |||
| self.cfg.parallel.update( | |||
| dict(module=model, device_ids=[torch.cuda.current_device()])) | |||
| return build_parallel(self.cfg.parallel) | |||
| dp_cfg = dict( | |||
| type='DistributedDataParallel', | |||
| module=model, | |||
| device_ids=[torch.cuda.current_device()]) | |||
| return build_parallel(dp_cfg) | |||
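Reviewer note: the `parallel` key in the trainer config reserves a custom wrapper. A hedged sketch of what such a section could look like, written as a Python dict for illustration (the extra kwarg is a standard DDP option, not something this diff requires):

    cfg.parallel = dict(
        type='DistributedDataParallel',
        find_unused_parameters=True)   # extra kwargs are forwarded to the wrapper
    # to_parallel() then injects module=model and device_ids=[current device]
    # before calling build_parallel(self.cfg.parallel).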
| def collate_fn(self, data): | |||
| """Prepare the input just before the forward function. | |||
| This method will move the tensors to the right device. | |||
| @@ -378,8 +409,9 @@ class EpochBasedTrainer(BaseTrainer): | |||
| self._mode = ModeKeys.TRAIN | |||
| inputs = self.collate_fn(inputs) | |||
| # call model forward but not __call__ to skip postprocess | |||
| if isinstance(inputs, Mapping) and not if_func_receive_dict_inputs( | |||
| model.forward, inputs): | |||
| if isinstance( | |||
| inputs, | |||
| Mapping) and not if_func_receive_dict_inputs(model.forward): | |||
| train_outputs = model.forward(**inputs) | |||
| else: | |||
| train_outputs = model.forward(inputs) | |||
| @@ -444,7 +476,10 @@ class EpochBasedTrainer(BaseTrainer): | |||
| train_data, mode=ModeKeys.TRAIN) | |||
| data_loader = self._build_dataloader_with_dataset( | |||
| self.train_dataset, **self.cfg.train.get('dataloader', {})) | |||
| self.train_dataset, | |||
| dist=self._dist, | |||
| seed=self._seed, | |||
| **self.cfg.train.get('dataloader', {})) | |||
| return data_loader | |||
| def get_eval_data_loader(self): | |||
| @@ -594,7 +629,7 @@ class EpochBasedTrainer(BaseTrainer): | |||
| if dist: | |||
| sampler = DistributedSampler( | |||
| dataset, world_size, rank, shuffle=shuffle, seed=seed) | |||
| dataset, num_replicas=world_size, rank=rank, shuffle=shuffle) | |||
| else: | |||
| sampler = None | |||
| @@ -3,7 +3,6 @@ | |||
| import os | |||
| import pickle | |||
| import shutil | |||
| import tempfile | |||
| import time | |||
| from collections.abc import Mapping | |||
| @@ -11,8 +10,7 @@ import torch | |||
| from torch import distributed as dist | |||
| from tqdm import tqdm | |||
| from modelscope.models.base import Model | |||
| from modelscope.utils.torch_utils import get_dist_info | |||
| from modelscope.utils.torch_utils import get_dist_info, is_master, make_tmp_dir | |||
| from modelscope.utils.utils import if_func_receive_dict_inputs | |||
| @@ -40,7 +38,7 @@ def single_gpu_test(model, | |||
| with torch.no_grad(): | |||
| if isinstance(data, | |||
| Mapping) and not if_func_receive_dict_inputs( | |||
| model.forward, data): | |||
| model.forward): | |||
| result = model(**data) | |||
| else: | |||
| @@ -82,25 +80,28 @@ def multi_gpu_test(model, | |||
| """ | |||
| model.eval() | |||
| results = [] | |||
| data_list = [] | |||
| dataset = data_loader.dataset | |||
| time.sleep(2) # This line can prevent deadlock problems in some cases. | |||
| rank, world_size = get_dist_info() | |||
| count = 0 | |||
| with tqdm(total=len(dataset), desc='test samples with multi gpus') as pbar: | |||
| for _, data in enumerate(data_loader): | |||
| if data_collate_fn is not None: | |||
| data = data_collate_fn(data) | |||
| data_list.append(data) | |||
| with torch.no_grad(): | |||
| if isinstance(data, | |||
| Mapping) and not if_func_receive_dict_inputs( | |||
| model.forward, data): | |||
| model.forward): | |||
| result = model(**data) | |||
| else: | |||
| result = model(data) | |||
| results.extend(result) | |||
| results.append(result) | |||
| rank, world_size = get_dist_info() | |||
| if rank == 0: | |||
| batch_size = len(result) | |||
| batch_size_all = batch_size * world_size | |||
| @@ -110,15 +111,26 @@ def multi_gpu_test(model, | |||
| for _ in range(batch_size_all): | |||
| pbar.update() | |||
| # collect results from all ranks | |||
| # TODO: allgather data list may cost a lot of memory and needs to be redesigned | |||
| # collect results and data from all ranks | |||
| if gpu_collect: | |||
| results = collect_results_gpu(results, len(dataset)) | |||
| data_list = collect_results_gpu(data_list, len(dataset)) | |||
| else: | |||
| results = collect_results_cpu(results, len(dataset), tmpdir) | |||
| ground_truths = [dataset[i] for i in range(len(dataset))] | |||
| if metric_classes is not None: | |||
| for metric_cls in metric_classes: | |||
| metric_cls.add(results, ground_truths) | |||
| if tmpdir is None: | |||
| tmpdir = make_tmp_dir() | |||
| results = collect_results_cpu(results, len(dataset), | |||
| os.path.join(tmpdir, 'predict')) | |||
| data_list = collect_results_cpu(data_list, len(dataset), | |||
| os.path.join(tmpdir, 'groundtruth')) | |||
| if is_master(): | |||
| assert len(data_list) == len( | |||
| results), f'size mismatch {len(data_list)} and {len(results)}' | |||
| if metric_classes is not None: | |||
| for i in range(len(data_list)): | |||
| for metric_cls in metric_classes: | |||
| metric_cls.add(results[i], data_list[i]) | |||
| def collect_results_cpu(result_part, size, tmpdir=None): | |||
| @@ -140,13 +152,15 @@ def collect_results_cpu(result_part, size, tmpdir=None): | |||
| list: The collected results. | |||
| """ | |||
| rank, world_size = get_dist_info() | |||
| # TODO create a random tmp dir if it is not specified | |||
| if tmpdir is None: | |||
| tmpdir = tempfile.gettempdir() | |||
| if not os.path.exists(tmpdir): | |||
| tmpdir = make_tmp_dir() | |||
| if not os.path.exists(tmpdir) and is_master(): | |||
| os.makedirs(tmpdir) | |||
| dist.barrier() | |||
| # dump the part result to the dir | |||
| pickle.dump(result_part, os.path.join(tmpdir, f'part_{rank}.pkl')) | |||
| with open(os.path.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f: | |||
| pickle.dump(result_part, f) | |||
| dist.barrier() | |||
| # collect all parts | |||
| if rank != 0: | |||
| @@ -156,7 +170,8 @@ def collect_results_cpu(result_part, size, tmpdir=None): | |||
| part_list = [] | |||
| for i in range(world_size): | |||
| part_file = os.path.join(tmpdir, f'part_{i}.pkl') | |||
| part_result = pickle.load(part_file) | |||
| with open(part_file, 'rb') as f: | |||
| part_result = pickle.load(f) | |||
| # When data is severely insufficient, an empty part_result | |||
| # on a certain gpu could make the overall outputs empty. | |||
| if part_result: | |||
| @@ -1,16 +1,23 @@ | |||
| #!/usr/bin/env python | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import copy | |||
| import os | |||
| import pickle | |||
| import shutil | |||
| import socket | |||
| import subprocess | |||
| import sys | |||
| import tarfile | |||
| import tempfile | |||
| import unittest | |||
| import numpy as np | |||
| import requests | |||
| from datasets import Dataset | |||
| from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE | |||
| from modelscope.msdatasets import MsDataset | |||
| from .torch_utils import _find_free_port | |||
| TEST_LEVEL = 2 | |||
| TEST_LEVEL_STR = 'TEST_LEVEL' | |||
| @@ -62,3 +69,167 @@ def download_and_untar(fpath, furl, dst) -> str: | |||
| t.extractall(path=dst) | |||
| return target_dir_path | |||
| _DIST_SCRIPT_TEMPLATE = """ | |||
| import ast | |||
| import argparse | |||
| import pickle | |||
| import torch | |||
| from torch import distributed as dist | |||
| from modelscope.utils.torch_utils import get_dist_info | |||
| import {} | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument('--save_all_ranks', type=ast.literal_eval, help='save all ranks results') | |||
| parser.add_argument('--save_file', type=str, help='save file') | |||
| parser.add_argument('--local_rank', type=int, default=0) | |||
| args = parser.parse_args() | |||
| def main(): | |||
| results = {}.{}({}) # module.func(params) | |||
| if args.save_all_ranks: | |||
| save_file = args.save_file + str(dist.get_rank()) | |||
| with open(save_file, 'wb') as f: | |||
| pickle.dump(results, f) | |||
| else: | |||
| rank, _ = get_dist_info() | |||
| if rank == 0: | |||
| with open(args.save_file, 'wb') as f: | |||
| pickle.dump(results, f) | |||
| if __name__ == '__main__': | |||
| main() | |||
| """ | |||
| class DistributedTestCase(unittest.TestCase): | |||
| """Distributed TestCase for test function with distributed mode. | |||
| Examples: | |||
| import torch | |||
| from torch import distributed as dist | |||
| from modelscope.utils.torch_utils import init_dist | |||
| def _test_func(*args, **kwargs): | |||
| init_dist(launcher='pytorch') | |||
| rank = dist.get_rank() | |||
| if rank == 0: | |||
| value = torch.tensor(1.0).cuda() | |||
| else: | |||
| value = torch.tensor(2.0).cuda() | |||
| dist.all_reduce(value) | |||
| return value.cpu().numpy() | |||
| class DistTest(DistributedTestCase): | |||
| def test_function_dist(self): | |||
| args = () # args should be python builtin type | |||
| kwargs = {} # kwargs should be python builtin type | |||
| self.start( | |||
| _test_func, | |||
| num_gpus=2, | |||
| assert_callback=lambda x: self.assertEqual(x, 3.0), | |||
| *args, | |||
| **kwargs, | |||
| ) | |||
| """ | |||
| def _start(self, | |||
| dist_start_cmd, | |||
| func, | |||
| num_gpus, | |||
| assert_callback=None, | |||
| save_all_ranks=False, | |||
| *args, | |||
| **kwargs): | |||
| script_path = func.__code__.co_filename | |||
| script_dir, script_name = os.path.split(script_path) | |||
| script_name = os.path.splitext(script_name)[0] | |||
| func_name = func.__qualname__ | |||
| func_params = [] | |||
| for arg in args: | |||
| if isinstance(arg, str): | |||
| arg = ('\'{}\''.format(arg)) | |||
| func_params.append(str(arg)) | |||
| for k, v in kwargs.items(): | |||
| if isinstance(v, str): | |||
| v = ('\'{}\''.format(v)) | |||
| func_params.append('{}={}'.format(k, v)) | |||
| func_params = ','.join(func_params).strip(',') | |||
| tmp_run_file = tempfile.NamedTemporaryFile(suffix='.py').name | |||
| tmp_res_file = tempfile.NamedTemporaryFile(suffix='.pkl').name | |||
| with open(tmp_run_file, 'w') as f: | |||
| print('save temporary run file to : {}'.format(tmp_run_file)) | |||
| print('save results to : {}'.format(tmp_res_file)) | |||
| run_file_content = _DIST_SCRIPT_TEMPLATE.format( | |||
| script_name, script_name, func_name, func_params) | |||
| f.write(run_file_content) | |||
| tmp_res_files = [] | |||
| if save_all_ranks: | |||
| for i in range(num_gpus): | |||
| tmp_res_files.append(tmp_res_file + str(i)) | |||
| else: | |||
| tmp_res_files = [tmp_res_file] | |||
| self.addCleanup(self.clean_tmp, [tmp_run_file] + tmp_res_files) | |||
| tmp_env = copy.deepcopy(os.environ) | |||
| tmp_env['PYTHONPATH'] = ':'.join( | |||
| (tmp_env.get('PYTHONPATH', ''), script_dir)).lstrip(':') | |||
| script_params = '--save_all_ranks=%s --save_file=%s' % (save_all_ranks, | |||
| tmp_res_file) | |||
| script_cmd = '%s %s %s' % (dist_start_cmd, tmp_run_file, script_params) | |||
| print('script command: %s' % script_cmd) | |||
| res = subprocess.call(script_cmd, shell=True, env=tmp_env) | |||
| script_res = [] | |||
| for res_file in tmp_res_files: | |||
| with open(res_file, 'rb') as f: | |||
| script_res.append(pickle.load(f)) | |||
| if not save_all_ranks: | |||
| script_res = script_res[0] | |||
| if assert_callback: | |||
| assert_callback(script_res) | |||
| self.assertEqual( | |||
| res, | |||
| 0, | |||
| msg='The test function ``{}`` in ``{}`` run failed!'.format( | |||
| func_name, script_name)) | |||
| return script_res | |||
| def start(self, | |||
| func, | |||
| num_gpus, | |||
| assert_callback=None, | |||
| save_all_ranks=False, | |||
| *args, | |||
| **kwargs): | |||
| ip = socket.gethostbyname(socket.gethostname()) | |||
| dist_start_cmd = '%s -m torch.distributed.launch --nproc_per_node=%d --master_addr=\'%s\' --master_port=%s' % ( | |||
| sys.executable, num_gpus, ip, _find_free_port()) | |||
| return self._start( | |||
| dist_start_cmd=dist_start_cmd, | |||
| func=func, | |||
| num_gpus=num_gpus, | |||
| assert_callback=assert_callback, | |||
| save_all_ranks=save_all_ranks, | |||
| *args, | |||
| **kwargs) | |||
| def clean_tmp(self, tmp_file_list): | |||
| for file in tmp_file_list: | |||
| if os.path.exists(file): | |||
| if os.path.isdir(file): | |||
| shutil.rmtree(file) | |||
| else: | |||
| os.remove(file) | |||
| @@ -1,11 +1,11 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| # The following code is partially borrowed from openmmlab/mmcv | |||
| import functools | |||
| import os | |||
| import pickle | |||
| import socket | |||
| import subprocess | |||
| from collections import OrderedDict | |||
| import tempfile | |||
| from typing import Callable, List, Optional, Tuple | |||
| import torch | |||
| @@ -116,6 +116,11 @@ def get_dist_info() -> Tuple[int, int]: | |||
| return rank, world_size | |||
| def is_master(): | |||
| rank, _ = get_dist_info() | |||
| return rank == 0 | |||
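Reviewer note: this helper gives the hooks a uniform guard; a sketch of the pattern adopted by CheckpointHook and TextLoggerHook above:

    if is_master():                 # true only on rank 0 (or when not distributed)
        trainer.logger.info('...')  # log / save once instead of once per rank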
| def master_only(func: Callable) -> Callable: | |||
| @functools.wraps(func) | |||
| @@ -136,3 +141,53 @@ def create_device(cpu: bool = False) -> torch.DeviceObjType: | |||
| device = torch.device('cpu') | |||
| return device | |||
| def make_tmp_dir(): | |||
| """Make sure each rank has the same temporary directory on the distributed mode. | |||
| """ | |||
| rank, world_size = get_dist_info() | |||
| if world_size <= 1: | |||
| return tempfile.mkdtemp() | |||
| tmpdir = None | |||
| if rank == 0: | |||
| tmpdir = tempfile.mkdtemp() | |||
| dist.barrier() | |||
| tmpdir = broadcast(tmpdir, 0) | |||
| return tmpdir | |||
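Reviewer note: rank 0 creates the directory and its path is broadcast, so every rank writes its shard into the same location. Sketch of how collect_results_cpu relies on this (names match the code above):

    tmpdir = make_tmp_dir()                               # identical path string on every rank
    part_file = os.path.join(tmpdir, f'part_{rank}.pkl')  # one shard per rank, no collisions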
| def broadcast(inputs, src): | |||
| """ | |||
| Broadcasts the inputs to all ranks. | |||
| Arguments: | |||
| inputs : Any objects that can be serialized by pickle. | |||
| src (int): Source rank. | |||
| Returns: | |||
| Each rank returns the same value as src. | |||
| """ | |||
| rank, _ = get_dist_info() | |||
| shape_tensor = torch.tensor([0], device='cuda') | |||
| if rank == src: | |||
| inputs_tensor = torch.tensor( | |||
| bytearray(pickle.dumps(inputs)), dtype=torch.uint8, device='cuda') | |||
| shape_tensor = torch.tensor(inputs_tensor.shape, device='cuda') | |||
| dist.barrier() | |||
| dist.broadcast(shape_tensor, src) | |||
| if rank != src: | |||
| inputs_tensor = torch.full((shape_tensor.item(), ), | |||
| 0, | |||
| dtype=torch.uint8, | |||
| device='cuda') | |||
| dist.barrier() | |||
| dist.broadcast(inputs_tensor, src) | |||
| return pickle.loads(inputs_tensor.cpu().numpy().tobytes()) | |||
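Reviewer note: broadcast() pickles the object into a uint8 CUDA tensor, shares its length first and then the payload, so it assumes an initialized default process group and a visible CUDA device. Minimal round-trip sketch (the value is illustrative):

    obj = {'best_epoch': 3} if get_dist_info()[0] == 0 else None  # only rank 0 has the value
    obj = broadcast(obj, 0)   # every rank now holds {'best_epoch': 3}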
| @@ -4,30 +4,30 @@ import inspect | |||
| import os | |||
| def if_func_receive_dict_inputs(func, inputs): | |||
| # TODO: remove this api, unify to flattened args | |||
| def if_func_receive_dict_inputs(func): | |||
| """to decide if a func could recieve dict inputs or not | |||
| Args: | |||
| func (class): the target function to be inspected | |||
| inputs (dicts): the inputs that will send to the function | |||
| Returns: | |||
| bool: if func recieve dict, then recieve True | |||
| Examples: | |||
| input = {"input_dict":xxx, "attention_masked":xxx}, | |||
| function(self, inputs) then return True | |||
| function(inputs) then return True | |||
| function(self, input_dict, attention_masked) then return False | |||
| bool: True if the func has only one arg named ``input`` or ``inputs``, else False | |||
| """ | |||
| signature = inspect.signature(func) | |||
| func_inputs = list(signature.parameters.keys() - set(['self'])) | |||
| mismatched_inputs = list(set(func_inputs) - set(inputs)) | |||
| if len(func_inputs) == len(mismatched_inputs): | |||
| return True | |||
| else: | |||
| full_args_spec = inspect.getfullargspec(func) | |||
| varargs = full_args_spec.varargs | |||
| varkw = full_args_spec.varkw | |||
| if not (varargs is None and varkw is None): | |||
| return False | |||
| args = [] if not full_args_spec.args else full_args_spec.args | |||
| args.pop(0) if (args and args[0] in ['self', 'cls']) else args | |||
| if len(args) == 1 and args[0] in ['input', 'inputs']: | |||
| return True | |||
| return False | |||
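Reviewer note: with the new signature the decision is made purely from the function's own argument list; a few illustrative cases under that rule:

    def fwd_a(self, inputs): ...            # -> True  (single arg named 'inputs')
    def fwd_b(self, input_ids, mask): ...   # -> False (several named args, so kwargs are unpacked)
    def fwd_c(self, **kwargs): ...          # -> False (has **kwargs)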
| def get_default_cache_dir(): | |||
| """ | |||
| @@ -0,0 +1,264 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import glob | |||
| import os | |||
| import shutil | |||
| import tempfile | |||
| import unittest | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from torch import nn | |||
| from torch.optim import SGD | |||
| from torch.optim.lr_scheduler import StepLR | |||
| from modelscope.metrics.builder import MetricKeys | |||
| from modelscope.trainers import build_trainer | |||
| from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile | |||
| from modelscope.utils.test_utils import (DistributedTestCase, | |||
| create_dummy_test_dataset, test_level) | |||
| class DummyMetric: | |||
| def __call__(self, ground_truth, predict_results): | |||
| return {'accuracy': 0.5} | |||
| dummy_dataset_small = create_dummy_test_dataset( | |||
| np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20) | |||
| dummy_dataset_big = create_dummy_test_dataset( | |||
| np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40) | |||
| class DummyModel(nn.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.linear = nn.Linear(5, 4) | |||
| self.bn = nn.BatchNorm1d(4) | |||
| def forward(self, feat, labels): | |||
| x = self.linear(feat) | |||
| x = self.bn(x) | |||
| loss = torch.sum(x) | |||
| return dict(logits=x, loss=loss) | |||
| def train_func(work_dir, dist=False): | |||
| json_cfg = { | |||
| 'train': { | |||
| 'work_dir': work_dir, | |||
| 'dataloader': { | |||
| 'batch_size_per_gpu': 2, | |||
| 'workers_per_gpu': 1 | |||
| }, | |||
| 'hooks': [{ | |||
| 'type': 'EvaluationHook', | |||
| 'interval': 1 | |||
| }] | |||
| }, | |||
| 'evaluation': { | |||
| 'dataloader': { | |||
| 'batch_size_per_gpu': 1, | |||
| 'workers_per_gpu': 1, | |||
| 'shuffle': False | |||
| }, | |||
| 'metrics': ['seq_cls_metric'] | |||
| } | |||
| } | |||
| config_path = os.path.join(work_dir, ModelFile.CONFIGURATION) | |||
| with open(config_path, 'w') as f: | |||
| json.dump(json_cfg, f) | |||
| model = DummyModel() | |||
| optimizer = SGD(model.parameters(), lr=0.01) | |||
| lr_scheduler = StepLR(optimizer, 2) | |||
| trainer_name = 'EpochBasedTrainer' | |||
| kwargs = dict( | |||
| cfg_file=config_path, | |||
| model=model, | |||
| data_collator=None, | |||
| train_dataset=dummy_dataset_big, | |||
| eval_dataset=dummy_dataset_small, | |||
| optimizers=(optimizer, lr_scheduler), | |||
| max_epochs=3, | |||
| device='gpu', | |||
| launcher='pytorch' if dist else None) | |||
| trainer = build_trainer(trainer_name, kwargs) | |||
| trainer.train() | |||
| @unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest') | |||
| class TrainerTestSingleGpu(unittest.TestCase): | |||
| def setUp(self): | |||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(self.tmp_dir): | |||
| os.makedirs(self.tmp_dir) | |||
| def tearDown(self): | |||
| super().tearDown() | |||
| shutil.rmtree(self.tmp_dir) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_single_gpu(self): | |||
| train_func(self.tmp_dir) | |||
| results_files = os.listdir(self.tmp_dir) | |||
| json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) | |||
| self.assertEqual(len(json_files), 1) | |||
| with open(json_files[0], 'r') as f: | |||
| lines = [i.strip() for i in f.readlines()] | |||
| self.assertDictContainsSubset( | |||
| { | |||
| LogKeys.MODE: ModeKeys.TRAIN, | |||
| LogKeys.EPOCH: 1, | |||
| LogKeys.ITER: 10, | |||
| LogKeys.LR: 0.01 | |||
| }, json.loads(lines[0])) | |||
| self.assertDictContainsSubset( | |||
| { | |||
| LogKeys.MODE: ModeKeys.TRAIN, | |||
| LogKeys.EPOCH: 1, | |||
| LogKeys.ITER: 20, | |||
| LogKeys.LR: 0.01 | |||
| }, json.loads(lines[1])) | |||
| self.assertDictContainsSubset( | |||
| { | |||
| LogKeys.MODE: ModeKeys.EVAL, | |||
| LogKeys.EPOCH: 1, | |||
| LogKeys.ITER: 20 | |||
| }, json.loads(lines[2])) | |||
| self.assertDictContainsSubset( | |||
| { | |||
| LogKeys.MODE: ModeKeys.TRAIN, | |||
| LogKeys.EPOCH: 2, | |||
| LogKeys.ITER: 10, | |||
| LogKeys.LR: 0.001 | |||
| }, json.loads(lines[3])) | |||
| self.assertDictContainsSubset( | |||
| { | |||
| LogKeys.MODE: ModeKeys.TRAIN, | |||
| LogKeys.EPOCH: 2, | |||
| LogKeys.ITER: 20, | |||
| LogKeys.LR: 0.001 | |||
| }, json.loads(lines[4])) | |||
| self.assertDictContainsSubset( | |||
| { | |||
| LogKeys.MODE: ModeKeys.EVAL, | |||
| LogKeys.EPOCH: 2, | |||
| LogKeys.ITER: 20 | |||
| }, json.loads(lines[5])) | |||
| self.assertDictContainsSubset( | |||
| { | |||
| LogKeys.MODE: ModeKeys.TRAIN, | |||
| LogKeys.EPOCH: 3, | |||
| LogKeys.ITER: 10, | |||
| LogKeys.LR: 0.001 | |||
| }, json.loads(lines[6])) | |||
| self.assertDictContainsSubset( | |||
| { | |||
| LogKeys.MODE: ModeKeys.TRAIN, | |||
| LogKeys.EPOCH: 3, | |||
| LogKeys.ITER: 20, | |||
| LogKeys.LR: 0.001 | |||
| }, json.loads(lines[7])) | |||
| self.assertDictContainsSubset( | |||
| { | |||
| LogKeys.MODE: ModeKeys.EVAL, | |||
| LogKeys.EPOCH: 3, | |||
| LogKeys.ITER: 20 | |||
| }, json.loads(lines[8])) | |||
| self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) | |||
| self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | |||
| self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files) | |||
| for i in [0, 1, 3, 4, 6, 7]: | |||
| self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i]) | |||
| self.assertIn(LogKeys.ITER_TIME, lines[i]) | |||
| for i in [2, 5, 8]: | |||
| self.assertIn(MetricKeys.ACCURACY, lines[i]) | |||
| @unittest.skipIf(not torch.cuda.is_available() | |||
| or torch.cuda.device_count() <= 1, 'distributed unittest') | |||
| class TrainerTestMultiGpus(DistributedTestCase): | |||
| def setUp(self): | |||
| print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) | |||
| self.tmp_dir = tempfile.TemporaryDirectory().name | |||
| if not os.path.exists(self.tmp_dir): | |||
| os.makedirs(self.tmp_dir) | |||
| def tearDown(self): | |||
| super().tearDown() | |||
| shutil.rmtree(self.tmp_dir) | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_multi_gpus(self): | |||
| self.start(train_func, num_gpus=2, work_dir=self.tmp_dir, dist=True) | |||
| results_files = os.listdir(self.tmp_dir) | |||
| json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json')) | |||
| self.assertEqual(len(json_files), 1) | |||
| with open(json_files[0], 'r') as f: | |||
| lines = [i.strip() for i in f.readlines()] | |||
| self.assertDictContainsSubset( | |||
| { | |||
| LogKeys.MODE: ModeKeys.TRAIN, | |||
| LogKeys.EPOCH: 1, | |||
| LogKeys.ITER: 10, | |||
| LogKeys.LR: 0.01 | |||
| }, json.loads(lines[0])) | |||
| self.assertDictContainsSubset( | |||
| { | |||
| LogKeys.MODE: ModeKeys.EVAL, | |||
| LogKeys.EPOCH: 1, | |||
| LogKeys.ITER: 10 | |||
| }, json.loads(lines[1])) | |||
| self.assertDictContainsSubset( | |||
| { | |||
| LogKeys.MODE: ModeKeys.TRAIN, | |||
| LogKeys.EPOCH: 2, | |||
| LogKeys.ITER: 10, | |||
| LogKeys.LR: 0.001 | |||
| }, json.loads(lines[2])) | |||
| self.assertDictContainsSubset( | |||
| { | |||
| LogKeys.MODE: ModeKeys.EVAL, | |||
| LogKeys.EPOCH: 2, | |||
| LogKeys.ITER: 10 | |||
| }, json.loads(lines[3])) | |||
| self.assertDictContainsSubset( | |||
| { | |||
| LogKeys.MODE: ModeKeys.TRAIN, | |||
| LogKeys.EPOCH: 3, | |||
| LogKeys.ITER: 10, | |||
| LogKeys.LR: 0.001 | |||
| }, json.loads(lines[4])) | |||
| self.assertDictContainsSubset( | |||
| { | |||
| LogKeys.MODE: ModeKeys.EVAL, | |||
| LogKeys.EPOCH: 3, | |||
| LogKeys.ITER: 10 | |||
| }, json.loads(lines[5])) | |||
| self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files) | |||
| self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files) | |||
| self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files) | |||
| for i in [0, 2, 4]: | |||
| self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i]) | |||
| self.assertIn(LogKeys.ITER_TIME, lines[i]) | |||
| for i in [1, 3, 5]: | |||
| self.assertIn(MetricKeys.ACCURACY, lines[i]) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||