Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10185634 * fix dist training (master)
@@ -37,8 +37,8 @@ from modelscope.utils.device import create_device, verify_device
 from modelscope.utils.file_utils import func_receive_dict_inputs
 from modelscope.utils.logger import get_logger
 from modelscope.utils.registry import build_from_cfg
-from modelscope.utils.torch_utils import (get_dist_info, init_dist,
-                                          set_random_seed)
+from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
+                                          init_dist, set_random_seed)
 from .base import BaseTrainer
 from .builder import TRAINERS
 from .default_config import DEFAULT_CONFIG
@@ -155,8 +155,17 @@ class EpochBasedTrainer(BaseTrainer):
         if self.eval_preprocessor is not None:
             self.eval_preprocessor.mode = ModeKeys.EVAL
+        if kwargs.get('launcher', None) is not None:
+            init_dist(kwargs['launcher'])
+        _, world_size = get_dist_info()
+        self._dist = world_size > 1
         device_name = kwargs.get('device', 'gpu')
         verify_device(device_name)
+        if self._dist:
+            local_rank = get_local_rank()
+            device_name = f'cuda:{local_rank}'
         self.device = create_device(device_name)
         self.train_dataset = self.to_task_dataset(
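This block initializes the process group as soon as a launcher is given and binds each worker to the GPU matching its local rank before create_device runs. A minimal standalone sketch of the same pattern in plain PyTorch, assuming a torchrun-style launcher that exports WORLD_SIZE and LOCAL_RANK (the helper name pick_device is invented for illustration):

import os

import torch
import torch.distributed as dist


def pick_device(default_name='cuda'):
    # Join the process group only when a launcher exported the rendezvous env vars.
    if 'WORLD_SIZE' in os.environ and not dist.is_initialized():
        dist.init_process_group(backend='nccl')
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if world_size > 1:
        # Distributed run: each process drives the GPU matching its node-local rank.
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        return torch.device(f'cuda:{local_rank}')
    # Single-process run: fall back to the requested device, or CPU if CUDA is absent.
    return torch.device(default_name if torch.cuda.is_available() else 'cpu')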
@@ -219,11 +228,6 @@ class EpochBasedTrainer(BaseTrainer):
         self.use_fp16 = kwargs.get('use_fp16', False)
-        if kwargs.get('launcher', None) is not None:
-            init_dist(kwargs['launcher'])
-        self._dist = get_dist_info()[1] > 1
         # model placement
         if self.device.type == 'cuda':
             self.model.to(self.device)
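With the duplicate init_dist call dropped here, the process group is already up and self.device already points at cuda:{local_rank} by the time the model is placed. A hedged sketch of why that ordering matters once the model gets wrapped for data parallelism (place_and_wrap is illustrative, not part of the PR):

from torch.nn.parallel import DistributedDataParallel as DDP


def place_and_wrap(model, device):
    # Parameters must already live on the per-rank device when DDP registers them,
    # otherwise every worker would allocate and synchronize on cuda:0.
    if device.type == 'cuda':
        model = model.to(device)
        model = DDP(model, device_ids=[device.index])
    return model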
@@ -531,8 +535,14 @@ class EpochBasedTrainer(BaseTrainer):
         model.train()
         self._mode = ModeKeys.TRAIN
         # call model forward but not __call__ to skip postprocess
-        if isinstance(inputs,
-                      Mapping) and not func_receive_dict_inputs(model.forward):
+        if is_parallel(model):
+            receive_dict_inputs = func_receive_dict_inputs(
+                model.module.forward)
+        else:
+            receive_dict_inputs = func_receive_dict_inputs(model.forward)
+        if isinstance(inputs, Mapping) and not receive_dict_inputs:
             train_outputs = model.forward(**inputs)
         else:
             train_outputs = model.forward(inputs)
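DataParallel/DistributedDataParallel wrappers expose the user model as model.module, so the signature inspection above has to unwrap the wrapper before deciding whether to spread the batch dict as keyword arguments. A rough approximation of the dispatch, with _is_parallel and _receives_dict standing in for the real is_parallel and func_receive_dict_inputs (whose exact heuristics may differ):

import inspect
from collections.abc import Mapping

from torch.nn.parallel import DataParallel, DistributedDataParallel


def _is_parallel(model):
    return isinstance(model, (DataParallel, DistributedDataParallel))


def _receives_dict(forward):
    # Treat a forward declared as `forward(self, inputs)` as taking the whole
    # batch dict through a single positional argument.
    return len(inspect.signature(forward).parameters) == 1


def call_forward(model, inputs):
    forward = model.module.forward if _is_parallel(model) else model.forward
    if isinstance(inputs, Mapping) and not _receives_dict(forward):
        return model.forward(**inputs)  # spread dict keys as keyword arguments
    return model.forward(inputs)        # hand the whole dict to forward(inputs)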
@@ -11,6 +11,7 @@ import torch
 from torch import distributed as dist
 from tqdm import tqdm
+from modelscope.trainers.parallel.utils import is_parallel
 from modelscope.utils.data_utils import to_device
 from modelscope.utils.file_utils import func_receive_dict_inputs
 from modelscope.utils.torch_utils import (broadcast, get_dist_info, is_master,
@@ -134,7 +135,10 @@ def multi_gpu_test(model,
     data_len = data_loader_iters_per_gpu * world_size
     desc = 'Total test iterations with multi gpus'
     time.sleep(2)  # This line can prevent deadlock problem in some cases.
+    if is_parallel(model):
+        receive_dict_inputs = func_receive_dict_inputs(model.module.forward)
+    else:
+        receive_dict_inputs = func_receive_dict_inputs(model.forward)
     count = 0
     with tqdm(total=data_len, desc=desc) as pbar:
@@ -142,8 +146,7 @@ def multi_gpu_test(model,
             data = to_device(data, device)
             data_list.append(data)
             with torch.no_grad():
-                if isinstance(data, Mapping) and not func_receive_dict_inputs(
-                        model.forward):
+                if isinstance(data, Mapping) and not receive_dict_inputs:
                     result = model.forward(**data)
                 else:
                     result = model.forward(data)
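Because the signature check depends only on the model and not on the batch, it is now resolved once before the loop rather than per iteration. A compact sketch of the resulting loop shape, reusing the modelscope helpers imported above (the surrounding run_inference function is invented for illustration):

from collections.abc import Mapping

import torch

from modelscope.trainers.parallel.utils import is_parallel
from modelscope.utils.data_utils import to_device
from modelscope.utils.file_utils import func_receive_dict_inputs


def run_inference(model, data_loader, device):
    forward = model.module.forward if is_parallel(model) else model.forward
    receive_dict_inputs = func_receive_dict_inputs(forward)  # constant per model
    results = []
    for data in data_loader:
        data = to_device(data, device)
        with torch.no_grad():
            if isinstance(data, Mapping) and not receive_dict_inputs:
                results.append(model.forward(**data))
            else:
                results.append(model.forward(data))
    return results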
@@ -115,6 +115,10 @@ def get_dist_info() -> Tuple[int, int]:
     return rank, world_size


+def get_local_rank():
+    return int(os.environ.get('LOCAL_RANK', 0))


 def is_master():
     rank, _ = get_dist_info()
     return rank == 0
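LOCAL_RANK is the per-node worker index exported by torch.distributed.launch / torchrun, so get_local_rank falls back to 0 for single-process runs. A small usage sketch under that assumption:

import torch

from modelscope.utils.torch_utils import get_local_rank, is_master

local_rank = get_local_rank()
if torch.cuda.is_available():
    # Pin this worker to its node-local GPU before any CUDA work happens.
    torch.cuda.set_device(local_rank)

if is_master():
    # Only the global rank-0 process logs or writes checkpoints.
    print(f'rank 0 running on cuda:{local_rank}')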
@@ -53,7 +53,18 @@ class DummyModel(nn.Module, Model):
         return dict(logits=x, loss=loss)


-def train_func(work_dir, dist=False, iterable_dataset=False, **kwargs):
+class DummyModelForwardInputs(DummyModel):
+
+    def forward(self, inputs):
+        feat, labels = inputs['feat'], inputs['labels']
+        return super().forward(feat, labels)
+
+
+def train_func(work_dir,
+               dist=False,
+               iterable_dataset=False,
+               forward_inputs=False,
+               **kwargs):
     json_cfg = {
         'task': Tasks.image_classification,
         'train': {
@@ -81,7 +92,10 @@ def train_func(work_dir, dist=False, iterable_dataset=False, **kwargs):
     with open(config_path, 'w') as f:
         json.dump(json_cfg, f)

-    model = DummyModel()
+    if forward_inputs:
+        model = DummyModelForwardInputs()
+    else:
+        model = DummyModel()
     optimmizer = SGD(model.parameters(), lr=0.01)
     lr_scheduler = StepLR(optimmizer, 2)
     trainer_name = Trainers.default
@@ -273,6 +287,22 @@ class TrainerTestMultiGpus(DistributedTestCase):
         for i in [1, 3, 5]:
             self.assertIn(MetricKeys.ACCURACY, lines[i])

+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_multi_gpus_forward_inputs(self):
+        self.start(
+            train_func,
+            num_gpus=2,
+            work_dir=self.tmp_dir,
+            dist=True,
+            forward_inputs=True)
+
+        results_files = os.listdir(self.tmp_dir)
+        json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
+        self.assertEqual(len(json_files), 1)
+        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
+        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
+        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)

     # TODO: support iters_per_epoch for dist mode
     @unittest.skipIf(True, 'need to adapt to DistributedSampler')
     def test_multi_gpus_with_iters_per_epoch(self):
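Background for the TODO above (an assumption about the needed adaptation, not part of this change): with a DistributedSampler each rank iterates only its own shard, so a fixed iters_per_epoch would have to be interpreted per rank and the sampler reshuffled every epoch. A minimal sketch of that sampler wiring:

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def build_dist_loader(dataset, batch_size, epoch):
    # Each rank sees roughly len(dataset) / world_size samples per epoch.
    sampler = DistributedSampler(dataset, shuffle=True)
    sampler.set_epoch(epoch)  # reseed so the shards differ across epochs
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)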