Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9551089 * support distributed training (target branch: master)
@@ -24,7 +24,7 @@ class SequenceClassificationMetric(Metric):
         self.labels = []

     def add(self, outputs: Dict, inputs: Dict):
-        ground_truths = inputs[SequenceClassificationMetric.label_name]
+        ground_truths = inputs[self.label_name]
         eval_results = outputs[OutputKeys.LOGITS]
         self.preds.append(
             torch_nested_numpify(torch_nested_detach(eval_results)))
@@ -424,7 +424,7 @@ class SingleBackboneTaskModelBase(BaseTaskModel):

     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
         """default forward method is the backbone-only forward"""
-        if if_func_receive_dict_inputs(self.backbone.forward, input):
+        if if_func_receive_dict_inputs(self.backbone.forward):
             outputs = self.backbone.forward(input)
         else:
             outputs = self.backbone.forward(**input)
@@ -472,13 +472,13 @@ class EncoderDecoderTaskModelBase(BaseTaskModel):
         return getattr(self, self._decoder_prefix)

     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
-        if if_func_receive_dict_inputs(self.encoder_.forward, input):
+        if if_func_receive_dict_inputs(self.encoder_.forward):
             encoder_outputs = self.encoder_.forward(input)
         else:
             encoder_outputs = self.encoder_.forward(**input)
         decoder_inputs = self.project_decoder_inputs_and_mediate(
             input, encoder_outputs)
-        if if_func_receive_dict_inputs(self.decoder_.forward, input):
+        if if_func_receive_dict_inputs(self.decoder_.forward):
             outputs = self.decoder_.forward(decoder_inputs)
         else:
             outputs = self.decoder_.forward(**decoder_inputs)
@@ -5,7 +5,7 @@ from modelscope import __version__
 from modelscope.utils.checkpoint import save_checkpoint
 from modelscope.utils.constant import LogKeys
 from modelscope.utils.logger import get_logger
-from modelscope.utils.torch_utils import get_dist_info
+from modelscope.utils.torch_utils import is_master
 from .builder import HOOKS
 from .hook import Hook
 from .priority import Priority
@@ -47,15 +47,18 @@ class CheckpointHook(Hook):
         else:
             self.logger = trainer.logger

-        self.logger.info(f'Checkpoints will be saved to {self.save_dir}')
+        if is_master():
+            self.logger.info(f'Checkpoints will be saved to {self.save_dir}')

     def after_train_epoch(self, trainer):
         if not self.by_epoch:
             return
         if self._should_save(trainer):
-            self.logger.info(f'Saving checkpoint at {trainer.epoch + 1} epoch')
-            self._save_checkpoint(trainer)
+            if is_master():
+                self.logger.info(
+                    f'Saving checkpoint at {trainer.epoch + 1} epoch')
+                self._save_checkpoint(trainer)
     def _save_checkpoint(self, trainer):
         if self.by_epoch:
@@ -65,18 +68,17 @@ class CheckpointHook(Hook):
             cur_save_name = os.path.join(
                 self.save_dir, f'{LogKeys.ITER}_{trainer.iter + 1}.pth')

-        rank, _ = get_dist_info()
-        if rank == 0:
-            save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
+        save_checkpoint(trainer.model, cur_save_name, trainer.optimizer)
     def after_train_iter(self, trainer):
         if self.by_epoch:
             return
         if self._should_save(trainer):
-            self.logger.info(
-                f'Saving checkpoint at {trainer.iter + 1} iterations')
-            self._save_checkpoint(trainer)
+            if is_master():
+                self.logger.info(
+                    f'Saving checkpoint at {trainer.iter + 1} iterations')
+                self._save_checkpoint(trainer)

     def _should_save(self, trainer):
         if self.by_epoch:
@@ -11,7 +11,7 @@ from torch import distributed as dist
 from modelscope.trainers.hooks.builder import HOOKS
 from modelscope.trainers.hooks.logger.base import LoggerHook
 from modelscope.utils.constant import LogKeys, ModeKeys
-from modelscope.utils.torch_utils import get_dist_info
+from modelscope.utils.torch_utils import get_dist_info, is_master


 @HOOKS.register_module()
@@ -130,7 +130,8 @@ class TextLoggerHook(LoggerHook):
             log_items.append(f'{name}: {val}')

         log_str += ', '.join(log_items)
-        trainer.logger.info(log_str)
+        if is_master():
+            trainer.logger.info(log_str)
     def _dump_log(self, log_dict):
         # dump log in json format
@@ -138,8 +139,7 @@ class TextLoggerHook(LoggerHook):
         for k, v in log_dict.items():
             json_log[k] = self._round_float(v)

-        rank, _ = get_dist_info()
-        if rank == 0:
+        if is_master():
             with open(self.json_log_path, 'a+') as f:
                 json.dump(json_log, f)
                 f.write('\n')
@@ -0,0 +1,2 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from .builder import PARALLEL
@@ -0,0 +1,20 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from torch.nn.parallel.distributed import DistributedDataParallel
+
+from modelscope.utils.config import ConfigDict
+from modelscope.utils.registry import Registry, build_from_cfg
+
+PARALLEL = Registry('parallel')
+PARALLEL.register_module(
+    module_name='DistributedDataParallel', module_cls=DistributedDataParallel)
+
+
+def build_parallel(cfg: ConfigDict, default_args: dict = None):
+    """ build parallel
+
+    Args:
+        cfg (:obj:`ConfigDict`): config dict for parallel object.
+        default_args (dict, optional): Default initialization arguments.
+    """
+    return build_from_cfg(cfg, PARALLEL, default_args=default_args)
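For context, the new PARALLEL registry can be exercised directly as sketched below. This is not part of the change: the import path follows the `from .parallel.builder import build_parallel` line added to trainer.py, it assumes `build_from_cfg` accepts a plain dict with a `type` key like the other ModelScope registries, and the single-process gloo group is only there so DistributedDataParallel can be constructed outside a real launch.

# Rough usage sketch of build_parallel() (hypothetical, single-process demo).
import os

import torch
from torch import distributed as dist

from modelscope.trainers.parallel.builder import build_parallel

os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

cfg = dict(
    type='DistributedDataParallel',  # registered module name
    module=torch.nn.Linear(4, 2))    # remaining keys go to the wrapper constructor
ddp_model = build_parallel(cfg)
print(type(ddp_model).__name__)      # DistributedDataParallel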
@@ -0,0 +1,23 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from .builder import PARALLEL
+
+
+def is_parallel(module):
+    """Check if a module is wrapped by a parallel object.
+
+    The following modules are regarded as parallel objects:
+     - torch.nn.parallel.DataParallel
+     - torch.nn.parallel.distributed.DistributedDataParallel
+
+    You may add your own parallel object by registering it to `modelscope.parallel.PARALLEL`.
+
+    Args:
+        module (nn.Module): The module to be checked.
+
+    Returns:
+        bool: True if the module is wrapped by a parallel object.
+    """
+    module_wrappers = []
+    for group, module_dict in PARALLEL.modules.items():
+        module_wrappers.extend(list(module_dict.values()))
+    return isinstance(module, tuple(module_wrappers))
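The extension point mentioned in the docstring can be checked with a toy wrapper; the class below is hypothetical and only mirrors the `register_module(module_name=..., module_cls=...)` call used in builder.py above.

# Hypothetical sketch: any module registered to PARALLEL is recognized by is_parallel().
import torch.nn as nn

from modelscope.trainers.parallel.builder import PARALLEL
from modelscope.trainers.parallel.utils import is_parallel


class FakeParallel(nn.Module):
    """Toy wrapper used only for this illustration."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


PARALLEL.register_module(module_name='FakeParallel', module_cls=FakeParallel)

print(is_parallel(nn.Linear(4, 2)))                # False
print(is_parallel(FakeParallel(nn.Linear(4, 2))))  # True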
@@ -1,5 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-import os.path
+import os
 import random
 import time
 from collections.abc import Mapping
@@ -32,12 +32,15 @@ from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Hubs, ModeKeys,
 from modelscope.utils.logger import get_logger
 from modelscope.utils.registry import build_from_cfg
 from modelscope.utils.tensor_utils import torch_default_data_collator
-from modelscope.utils.torch_utils import create_device, get_dist_info
+from modelscope.utils.torch_utils import (broadcast, create_device,
+                                          get_dist_info, init_dist)
 from modelscope.utils.utils import if_func_receive_dict_inputs
 from .base import BaseTrainer
 from .builder import TRAINERS
 from .default_config import DEFAULT_CONFIG
 from .hooks.hook import Hook
+from .parallel.builder import build_parallel
+from .parallel.utils import is_parallel


 @TRAINERS.register_module()
@@ -150,11 +153,16 @@ class EpochBasedTrainer(BaseTrainer):
         # TODO @wenmeng.zwm add seed init fn
         self._seed = 0

+        if kwargs.get('launcher', None) is not None:
+            init_dist(kwargs['launcher'])
+
         self._dist = get_dist_info()[1] > 1

         # model placement
         if self.device.type == 'cuda':
             self.model.to(self.device)
+            if not is_parallel(self.model) and self._dist:
+                self.model = self.to_parallel(self.model)

     @property
     def mode(self):
@@ -287,7 +295,10 @@ class EpochBasedTrainer(BaseTrainer):
             self.train_dataloader = self.get_train_dataloader()
         else:
             self.train_dataloader = self._build_dataloader_with_dataset(
-                self.train_dataset, **self.cfg.train.get('dataloader', {}))
+                self.train_dataset,
+                dist=self._dist,
+                seed=self._seed,
+                **self.cfg.train.get('dataloader', {}))
         self.data_loader = self.train_dataloader
         self.register_optimizers_hook()
@@ -303,15 +314,21 @@ class EpochBasedTrainer(BaseTrainer):
             self.eval_dataloader = self.get_eval_data_loader()
         else:
             self.eval_dataloader = self._build_dataloader_with_dataset(
-                self.eval_dataset, **self.cfg.evaluation.get('dataloader', {}))
+                self.eval_dataset,
+                dist=self._dist,
+                seed=self._seed,
+                **self.cfg.evaluation.get('dataloader', {}))
         self.data_loader = self.eval_dataloader
         metric_classes = [build_metric(metric) for metric in self.metrics]
         self.evaluation_loop(self.eval_dataloader, checkpoint_path,
                              metric_classes)

+        rank, world_size = get_dist_info()
         metric_values = {}
-        for metric_cls in metric_classes:
-            metric_values.update(metric_cls.evaluate())
+        if rank == 0:
+            for metric_cls in metric_classes:
+                metric_values.update(metric_cls.evaluate())
+        if world_size > 1:
+            metric_values = broadcast(metric_values, 0)
         return metric_values

     def build_model(self) -> Union[nn.Module, TorchModel]:
@@ -328,6 +345,20 @@ class EpochBasedTrainer(BaseTrainer):
         elif isinstance(model, nn.Module):
             return model

+    def to_parallel(self, model) -> Union[nn.Module, TorchModel]:
+        # a custom DDP wrapper can be specified via the 'parallel' config section
+        if self.cfg.get('parallel', None) is not None:
+            self.cfg.parallel.update(
+                dict(module=model, device_ids=[torch.cuda.current_device()]))
+            return build_parallel(self.cfg.parallel)
+
+        dp_cfg = dict(
+            type='DistributedDataParallel',
+            module=model,
+            device_ids=[torch.cuda.current_device()])
+
+        return build_parallel(dp_cfg)
     def collate_fn(self, data):
         """Prepare the input just before the forward function.
         This method will move the tensors to the right device.
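Since `to_parallel` reads an optional `parallel` section from the trainer config, a hypothetical override could look like the sketch below; `type` plus the injected `module`/`device_ids` come from the code above, while `find_unused_parameters` is just an example of an extra kwarg forwarded to the wrapper constructor.

# Hypothetical `parallel` section of a trainer configuration (not part of the diff).
from modelscope.utils.config import ConfigDict

parallel_cfg = ConfigDict(
    type='DistributedDataParallel',  # any wrapper registered in PARALLEL
    find_unused_parameters=True)     # example extra kwarg passed through to DDP
# to_parallel() injects module=model and device_ids=[torch.cuda.current_device()]
# into this dict before calling build_parallel().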
@@ -378,8 +409,9 @@ class EpochBasedTrainer(BaseTrainer):
         self._mode = ModeKeys.TRAIN
         inputs = self.collate_fn(inputs)
         # call model forward but not __call__ to skip postprocess
-        if isinstance(inputs, Mapping) and not if_func_receive_dict_inputs(
-                model.forward, inputs):
+        if isinstance(
+                inputs,
+                Mapping) and not if_func_receive_dict_inputs(model.forward):
             train_outputs = model.forward(**inputs)
         else:
             train_outputs = model.forward(inputs)
@@ -444,7 +476,10 @@ class EpochBasedTrainer(BaseTrainer):
                 train_data, mode=ModeKeys.TRAIN)

         data_loader = self._build_dataloader_with_dataset(
-            self.train_dataset, **self.cfg.train.get('dataloader', {}))
+            self.train_dataset,
+            dist=self._dist,
+            seed=self._seed,
+            **self.cfg.train.get('dataloader', {}))
         return data_loader

     def get_eval_data_loader(self):
@@ -594,7 +629,7 @@ class EpochBasedTrainer(BaseTrainer):
         if dist:
             sampler = DistributedSampler(
-                dataset, world_size, rank, shuffle=shuffle, seed=seed)
+                dataset, num_replicas=world_size, rank=rank, shuffle=shuffle)
         else:
             sampler = None
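For reference, the keyword form above partitions indices round-robin across ranks; the standalone sketch below (plain PyTorch, no ModelScope code) shows the split when num_replicas and rank are passed explicitly, which works without an initialized process group.

# Standalone illustration of the DistributedSampler call used above.
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))
for rank in range(2):
    sampler = DistributedSampler(
        dataset, num_replicas=2, rank=rank, shuffle=False)
    print(rank, list(iter(sampler)))  # 0 -> [0, 2, 4, 6], 1 -> [1, 3, 5, 7]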
@@ -3,7 +3,6 @@
 import os
 import pickle
 import shutil
-import tempfile
 import time
 from collections.abc import Mapping
@@ -11,8 +10,7 @@ import torch
 from torch import distributed as dist
 from tqdm import tqdm

-from modelscope.models.base import Model
-from modelscope.utils.torch_utils import get_dist_info
+from modelscope.utils.torch_utils import get_dist_info, is_master, make_tmp_dir
 from modelscope.utils.utils import if_func_receive_dict_inputs
@@ -40,7 +38,7 @@ def single_gpu_test(model,
         with torch.no_grad():
             if isinstance(data,
                           Mapping) and not if_func_receive_dict_inputs(
-                              model.forward, data):
+                              model.forward):
                 result = model(**data)
             else:
@@ -82,25 +80,28 @@ def multi_gpu_test(model,
     """
     model.eval()
     results = []
+    data_list = []
     dataset = data_loader.dataset
     time.sleep(2)  # This line can prevent deadlock problem in some cases.
-    rank, world_size = get_dist_info()
     count = 0
     with tqdm(total=len(dataset), desc='test samples with multi gpus') as pbar:
         for _, data in enumerate(data_loader):
             if data_collate_fn is not None:
                 data = data_collate_fn(data)
+            data_list.append(data)
             with torch.no_grad():
                 if isinstance(data,
                               Mapping) and not if_func_receive_dict_inputs(
-                                  model.forward, data):
+                                  model.forward):
                     result = model(**data)
                 else:
                     result = model(data)
-            results.extend(result)
+            results.append(result)

+            rank, world_size = get_dist_info()
             if rank == 0:
                 batch_size = len(result)
                 batch_size_all = batch_size * world_size
@@ -110,15 +111,26 @@ def multi_gpu_test(model,
                 for _ in range(batch_size_all):
                     pbar.update()

-    # collect results from all ranks
+    # TODO: allgather data list may cost a lot of memory and needs to be redesigned
+    # collect results and data from all ranks
     if gpu_collect:
         results = collect_results_gpu(results, len(dataset))
+        data_list = collect_results_gpu(data_list, len(dataset))
     else:
-        results = collect_results_cpu(results, len(dataset), tmpdir)
-        ground_truths = [dataset[i] for i in range(len(dataset))]
-        if metric_classes is not None:
-            for metric_cls in metric_classes:
-                metric_cls.add(results, ground_truths)
+        if tmpdir is None:
+            tmpdir = make_tmp_dir()
+        results = collect_results_cpu(results, len(dataset),
+                                      os.path.join(tmpdir, 'predict'))
+        data_list = collect_results_cpu(data_list, len(dataset),
+                                        os.path.join(tmpdir, 'groundtruth'))
+
+    if is_master():
+        assert len(data_list) == len(
+            results), f'size mismatch {len(data_list)} and {len(results)}'
+        if metric_classes is not None:
+            for i in range(len(data_list)):
+                for metric_cls in metric_classes:
+                    metric_cls.add(results[i], data_list[i])


 def collect_results_cpu(result_part, size, tmpdir=None):
@@ -140,13 +152,15 @@ def collect_results_cpu(result_part, size, tmpdir=None):
         list: The collected results.
     """
     rank, world_size = get_dist_info()
-    # TODO create a random tmp dir if it is not specified
     if tmpdir is None:
-        tmpdir = tempfile.gettempdir()
-    if not os.path.exists(tmpdir):
+        tmpdir = make_tmp_dir()
+    if not os.path.exists(tmpdir) and is_master():
         os.makedirs(tmpdir)
+    dist.barrier()
     # dump the part result to the dir
-    pickle.dump(result_part, os.path.join(tmpdir, f'part_{rank}.pkl'))
+    with open(os.path.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f:
+        pickle.dump(result_part, f)
     dist.barrier()
     # collect all parts
     if rank != 0:
@@ -156,7 +170,8 @@ def collect_results_cpu(result_part, size, tmpdir=None):
         part_list = []
         for i in range(world_size):
             part_file = os.path.join(tmpdir, f'part_{i}.pkl')
-            part_result = pickle.load(part_file)
+            with open(part_file, 'rb') as f:
+                part_result = pickle.load(f)
             # When data is severely insufficient, an empty part_result
             # on a certain gpu could makes the overall outputs empty.
             if part_result:
@@ -1,16 +1,23 @@
 #!/usr/bin/env python
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import copy
 import os
+import pickle
+import shutil
+import socket
+import subprocess
+import sys
 import tarfile
+import tempfile
 import unittest

+import numpy as np
 import requests
 from datasets import Dataset
 from datasets.config import TF_AVAILABLE, TORCH_AVAILABLE

 from modelscope.msdatasets import MsDataset
+from .torch_utils import _find_free_port

 TEST_LEVEL = 2
 TEST_LEVEL_STR = 'TEST_LEVEL'
@@ -62,3 +69,167 @@ def download_and_untar(fpath, furl, dst) -> str:
         t.extractall(path=dst)

     return target_dir_path
+
+
+_DIST_SCRIPT_TEMPLATE = """
+import ast
+import argparse
+import pickle
+import torch
+from torch import distributed as dist
+from modelscope.utils.torch_utils import get_dist_info
+import {}
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--save_all_ranks', type=ast.literal_eval, help='save all ranks results')
+parser.add_argument('--save_file', type=str, help='save file')
+parser.add_argument('--local_rank', type=int, default=0)
+args = parser.parse_args()
+
+
+def main():
+    results = {}.{}({})  # module.func(params)
+    if args.save_all_ranks:
+        save_file = args.save_file + str(dist.get_rank())
+        with open(save_file, 'wb') as f:
+            pickle.dump(results, f)
+    else:
+        rank, _ = get_dist_info()
+        if rank == 0:
+            with open(args.save_file, 'wb') as f:
+                pickle.dump(results, f)
+
+
+if __name__ == '__main__':
+    main()
+"""
+class DistributedTestCase(unittest.TestCase):
+    """Distributed TestCase for testing a function in distributed mode.
+
+    Examples:
+        import torch
+        from torch import distributed as dist
+        from modelscope.utils.torch_utils import init_dist
+
+        def _test_func(*args, **kwargs):
+            init_dist(launcher='pytorch')
+            rank = dist.get_rank()
+            if rank == 0:
+                value = torch.tensor(1.0).cuda()
+            else:
+                value = torch.tensor(2.0).cuda()
+            dist.all_reduce(value)
+            return value.cpu().numpy()
+
+        class DistTest(DistributedTestCase):
+
+            def test_function_dist(self):
+                args = ()  # args should be python builtin type
+                kwargs = {}  # kwargs should be python builtin type
+                self.start(
+                    _test_func,
+                    num_gpus=2,
+                    assert_callback=lambda x: self.assertEqual(x, 3.0),
+                    *args,
+                    **kwargs,
+                )
+    """
+    def _start(self,
+               dist_start_cmd,
+               func,
+               num_gpus,
+               assert_callback=None,
+               save_all_ranks=False,
+               *args,
+               **kwargs):
+        script_path = func.__code__.co_filename
+        script_dir, script_name = os.path.split(script_path)
+        script_name = os.path.splitext(script_name)[0]
+        func_name = func.__qualname__
+
+        func_params = []
+        for arg in args:
+            if isinstance(arg, str):
+                arg = ('\'{}\''.format(arg))
+            func_params.append(str(arg))
+        for k, v in kwargs.items():
+            if isinstance(v, str):
+                v = ('\'{}\''.format(v))
+            func_params.append('{}={}'.format(k, v))
+        func_params = ','.join(func_params).strip(',')
+
+        tmp_run_file = tempfile.NamedTemporaryFile(suffix='.py').name
+        tmp_res_file = tempfile.NamedTemporaryFile(suffix='.pkl').name
+        with open(tmp_run_file, 'w') as f:
+            print('save temporary run file to : {}'.format(tmp_run_file))
+            print('save results to : {}'.format(tmp_res_file))
+            run_file_content = _DIST_SCRIPT_TEMPLATE.format(
+                script_name, script_name, func_name, func_params)
+            f.write(run_file_content)
+
+        tmp_res_files = []
+        if save_all_ranks:
+            for i in range(num_gpus):
+                tmp_res_files.append(tmp_res_file + str(i))
+        else:
+            tmp_res_files = [tmp_res_file]
+        self.addCleanup(self.clean_tmp, [tmp_run_file] + tmp_res_files)
+
+        tmp_env = copy.deepcopy(os.environ)
+        tmp_env['PYTHONPATH'] = ':'.join(
+            (tmp_env.get('PYTHONPATH', ''), script_dir)).lstrip(':')
+        script_params = '--save_all_ranks=%s --save_file=%s' % (save_all_ranks,
+                                                                tmp_res_file)
+        script_cmd = '%s %s %s' % (dist_start_cmd, tmp_run_file, script_params)
+        print('script command: %s' % script_cmd)
+        res = subprocess.call(script_cmd, shell=True, env=tmp_env)
+
+        script_res = []
+        for res_file in tmp_res_files:
+            with open(res_file, 'rb') as f:
+                script_res.append(pickle.load(f))
+        if not save_all_ranks:
+            script_res = script_res[0]
+
+        if assert_callback:
+            assert_callback(script_res)
+
+        self.assertEqual(
+            res,
+            0,
+            msg='The test function ``{}`` in ``{}`` run failed!'.format(
+                func_name, script_name))
+        return script_res
+    def start(self,
+              func,
+              num_gpus,
+              assert_callback=None,
+              save_all_ranks=False,
+              *args,
+              **kwargs):
+        ip = socket.gethostbyname(socket.gethostname())
+        dist_start_cmd = '%s -m torch.distributed.launch --nproc_per_node=%d --master_addr=\'%s\' --master_port=%s' % (
+            sys.executable, num_gpus, ip, _find_free_port())
+
+        return self._start(
+            dist_start_cmd=dist_start_cmd,
+            func=func,
+            num_gpus=num_gpus,
+            assert_callback=assert_callback,
+            save_all_ranks=save_all_ranks,
+            *args,
+            **kwargs)
+
+    def clean_tmp(self, tmp_file_list):
+        for file in tmp_file_list:
+            if os.path.exists(file):
+                if os.path.isdir(file):
+                    shutil.rmtree(file)
+                else:
+                    os.remove(file)
@@ -1,11 +1,11 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Following code is partialy borrowed from openmmlab/mmcv
 import functools
 import os
+import pickle
 import socket
 import subprocess
-from collections import OrderedDict
+import tempfile
 from typing import Callable, List, Optional, Tuple

 import torch
@@ -116,6 +116,11 @@ def get_dist_info() -> Tuple[int, int]:
     return rank, world_size


+def is_master():
+    rank, _ = get_dist_info()
+    return rank == 0
+
+
 def master_only(func: Callable) -> Callable:

     @functools.wraps(func)
@@ -136,3 +141,53 @@ def create_device(cpu: bool = False) -> torch.DeviceObjType:
         device = torch.device('cpu')

     return device
+
+
+def make_tmp_dir():
+    """Make sure each rank uses the same temporary directory in distributed mode.
+    """
+    rank, world_size = get_dist_info()
+    if world_size <= 1:
+        return tempfile.mkdtemp()
+
+    tmpdir = None
+    if rank == 0:
+        tmpdir = tempfile.mkdtemp()
+
+    dist.barrier()
+    tmpdir = broadcast(tmpdir, 0)
+
+    return tmpdir
+
+
+def broadcast(inputs, src):
+    """
+    Broadcasts the inputs to all ranks.
+
+    Arguments:
+        inputs : Any objects that can be serialized by pickle.
+        src (int): Source rank.
+
+    Returns:
+        Each rank returns the same value as src.
+    """
+    rank, _ = get_dist_info()
+    shape_tensor = torch.tensor([0], device='cuda')
+
+    if rank == src:
+        inputs_tensor = torch.tensor(
+            bytearray(pickle.dumps(inputs)), dtype=torch.uint8, device='cuda')
+        shape_tensor = torch.tensor(inputs_tensor.shape, device='cuda')
+
+    dist.barrier()
+    dist.broadcast(shape_tensor, src)
+
+    if rank != src:
+        inputs_tensor = torch.full((shape_tensor.item(), ),
+                                   0,
+                                   dtype=torch.uint8,
+                                   device='cuda')
+
+    dist.barrier()
+    dist.broadcast(inputs_tensor, src)
+
+    return pickle.loads(inputs_tensor.cpu().numpy().tobytes())
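broadcast() serializes arbitrary picklable objects through a CUDA uint8 tensor, so it needs an already initialized process group with a CUDA-capable backend. A minimal usage sketch mirroring the evaluate() change in trainer.py (values are illustrative):

# Every rank ends up with the object computed on rank 0.
from modelscope.utils.torch_utils import broadcast, get_dist_info

rank, world_size = get_dist_info()
metric_values = {'accuracy': 0.91} if rank == 0 else {}
if world_size > 1:
    metric_values = broadcast(metric_values, 0)  # identical dict on every rank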
@@ -4,30 +4,30 @@ import inspect
 import os


-def if_func_receive_dict_inputs(func, inputs):
+# TODO: remove this api, unify to flattened args
+def if_func_receive_dict_inputs(func):
     """to decide if a func could recieve dict inputs or not

     Args:
         func (class): the target function to be inspected
-        inputs (dicts): the inputs that will send to the function

     Returns:
-        bool: if func recieve dict, then recieve True
-
-    Examples:
-        input = {"input_dict":xxx, "attention_masked":xxx},
-        function(self, inputs) then return True
-        function(inputs) then return True
-        function(self, input_dict, attention_masked) then return False
+        bool: if func only has one arg ``input`` or ``inputs``, return True, else return False
     """
-    signature = inspect.signature(func)
-    func_inputs = list(signature.parameters.keys() - set(['self']))
-    mismatched_inputs = list(set(func_inputs) - set(inputs))
-    if len(func_inputs) == len(mismatched_inputs):
-        return True
-    else:
+    full_args_spec = inspect.getfullargspec(func)
+    varargs = full_args_spec.varargs
+    varkw = full_args_spec.varkw
+    if not (varargs is None and varkw is None):
         return False
+
+    args = [] if not full_args_spec.args else full_args_spec.args
+    args.pop(0) if (args and args[0] in ['self', 'cls']) else args
+    if len(args) == 1 and args[0] in ['input', 'inputs']:
+        return True
+    return False


 def get_default_cache_dir():
     """
@@ -0,0 +1,264 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import glob
+import os
+import shutil
+import tempfile
+import unittest
+
+import json
+import numpy as np
+import torch
+from torch import nn
+from torch.optim import SGD
+from torch.optim.lr_scheduler import StepLR
+
+from modelscope.metrics.builder import MetricKeys
+from modelscope.trainers import build_trainer
+from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile
+from modelscope.utils.test_utils import (DistributedTestCase,
+                                         create_dummy_test_dataset,
+                                         test_level)
+
+
+class DummyMetric:
+
+    def __call__(self, ground_truth, predict_results):
+        return {'accuracy': 0.5}
+dummy_dataset_small = create_dummy_test_dataset(
+    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)
+
+dummy_dataset_big = create_dummy_test_dataset(
+    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40)
+
+
+class DummyModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(5, 4)
+        self.bn = nn.BatchNorm1d(4)
+
+    def forward(self, feat, labels):
+        x = self.linear(feat)
+        x = self.bn(x)
+        loss = torch.sum(x)
+        return dict(logits=x, loss=loss)
+def train_func(work_dir, dist=False):
+    json_cfg = {
+        'train': {
+            'work_dir': work_dir,
+            'dataloader': {
+                'batch_size_per_gpu': 2,
+                'workers_per_gpu': 1
+            },
+            'hooks': [{
+                'type': 'EvaluationHook',
+                'interval': 1
+            }]
+        },
+        'evaluation': {
+            'dataloader': {
+                'batch_size_per_gpu': 1,
+                'workers_per_gpu': 1,
+                'shuffle': False
+            },
+            'metrics': ['seq_cls_metric']
+        }
+    }
+
+    config_path = os.path.join(work_dir, ModelFile.CONFIGURATION)
+    with open(config_path, 'w') as f:
+        json.dump(json_cfg, f)
+
+    model = DummyModel()
+    optimizer = SGD(model.parameters(), lr=0.01)
+    lr_scheduler = StepLR(optimizer, 2)
+    trainer_name = 'EpochBasedTrainer'
+    kwargs = dict(
+        cfg_file=config_path,
+        model=model,
+        data_collator=None,
+        train_dataset=dummy_dataset_big,
+        eval_dataset=dummy_dataset_small,
+        optimizers=(optimizer, lr_scheduler),
+        max_epochs=3,
+        device='gpu',
+        launcher='pytorch' if dist else None)
+
+    trainer = build_trainer(trainer_name, kwargs)
+    trainer.train()
+@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest')
+class TrainerTestSingleGpu(unittest.TestCase):
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+
+    def tearDown(self):
+        super().tearDown()
+        shutil.rmtree(self.tmp_dir)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_single_gpu(self):
+        train_func(self.tmp_dir)
+
+        results_files = os.listdir(self.tmp_dir)
+        json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
+        self.assertEqual(len(json_files), 1)
+        with open(json_files[0], 'r') as f:
+            lines = [i.strip() for i in f.readlines()]
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.TRAIN,
+                LogKeys.EPOCH: 1,
+                LogKeys.ITER: 10,
+                LogKeys.LR: 0.01
+            }, json.loads(lines[0]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.TRAIN,
+                LogKeys.EPOCH: 1,
+                LogKeys.ITER: 20,
+                LogKeys.LR: 0.01
+            }, json.loads(lines[1]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.EVAL,
+                LogKeys.EPOCH: 1,
+                LogKeys.ITER: 20
+            }, json.loads(lines[2]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.TRAIN,
+                LogKeys.EPOCH: 2,
+                LogKeys.ITER: 10,
+                LogKeys.LR: 0.001
+            }, json.loads(lines[3]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.TRAIN,
+                LogKeys.EPOCH: 2,
+                LogKeys.ITER: 20,
+                LogKeys.LR: 0.001
+            }, json.loads(lines[4]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.EVAL,
+                LogKeys.EPOCH: 2,
+                LogKeys.ITER: 20
+            }, json.loads(lines[5]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.TRAIN,
+                LogKeys.EPOCH: 3,
+                LogKeys.ITER: 10,
+                LogKeys.LR: 0.001
+            }, json.loads(lines[6]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.TRAIN,
+                LogKeys.EPOCH: 3,
+                LogKeys.ITER: 20,
+                LogKeys.LR: 0.001
+            }, json.loads(lines[7]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.EVAL,
+                LogKeys.EPOCH: 3,
+                LogKeys.ITER: 20
+            }, json.loads(lines[8]))
+        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
+        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
+        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
+
+        for i in [0, 1, 3, 4, 6, 7]:
+            self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
+            self.assertIn(LogKeys.ITER_TIME, lines[i])
+        for i in [2, 5, 8]:
+            self.assertIn(MetricKeys.ACCURACY, lines[i])
+@unittest.skipIf(not torch.cuda.is_available()
+                 or torch.cuda.device_count() <= 1, 'distributed unittest')
+class TrainerTestMultiGpus(DistributedTestCase):
+
+    def setUp(self):
+        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
+        self.tmp_dir = tempfile.TemporaryDirectory().name
+        if not os.path.exists(self.tmp_dir):
+            os.makedirs(self.tmp_dir)
+
+    def tearDown(self):
+        super().tearDown()
+        shutil.rmtree(self.tmp_dir)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_multi_gpus(self):
+        self.start(train_func, num_gpus=2, work_dir=self.tmp_dir, dist=True)
+
+        results_files = os.listdir(self.tmp_dir)
+        json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
+        self.assertEqual(len(json_files), 1)
+        with open(json_files[0], 'r') as f:
+            lines = [i.strip() for i in f.readlines()]
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.TRAIN,
+                LogKeys.EPOCH: 1,
+                LogKeys.ITER: 10,
+                LogKeys.LR: 0.01
+            }, json.loads(lines[0]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.EVAL,
+                LogKeys.EPOCH: 1,
+                LogKeys.ITER: 10
+            }, json.loads(lines[1]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.TRAIN,
+                LogKeys.EPOCH: 2,
+                LogKeys.ITER: 10,
+                LogKeys.LR: 0.001
+            }, json.loads(lines[2]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.EVAL,
+                LogKeys.EPOCH: 2,
+                LogKeys.ITER: 10
+            }, json.loads(lines[3]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.TRAIN,
+                LogKeys.EPOCH: 3,
+                LogKeys.ITER: 10,
+                LogKeys.LR: 0.001
+            }, json.loads(lines[4]))
+        self.assertDictContainsSubset(
+            {
+                LogKeys.MODE: ModeKeys.EVAL,
+                LogKeys.EPOCH: 3,
+                LogKeys.ITER: 10
+            }, json.loads(lines[5]))
+        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
+        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
+        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
+
+        for i in [0, 2, 4]:
+            self.assertIn(LogKeys.DATA_LOAD_TIME, lines[i])
+            self.assertIn(LogKeys.ITER_TIME, lines[i])
+        for i in [1, 3, 5]:
+            self.assertIn(MetricKeys.ACCURACY, lines[i])
+
+
+if __name__ == '__main__':
+    unittest.main()