| @@ -14,7 +14,7 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| import argparse | import argparse | ||||
| import ast | |||||
| def launch_parse_args(): | def launch_parse_args(): | ||||
| @@ -43,6 +43,7 @@ def train_parse_args(): | |||||
| help='run platform, only support CPU, GPU and Ascend') | help='run platform, only support CPU, GPU and Ascend') | ||||
| train_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \ | train_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \ | ||||
| for fine tune or incremental learning') | for fine tune or incremental learning') | ||||
| train_parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute') | |||||
| train_parser.add_argument('--train_method', type=str, choices=("train", "fine_tune", "incremental_learn"), \ | train_parser.add_argument('--train_method', type=str, choices=("train", "fine_tune", "incremental_learn"), \ | ||||
| help="\"fine_tune\"or \"incremental_learn\" if to fine tune the net after loading the ckpt, \"train\" to \ | help="\"fine_tune\"or \"incremental_learn\" if to fine tune the net after loading the ckpt, \"train\" to \ | ||||
| train from initialization model") | train from initialization model") | ||||
| @@ -59,6 +59,7 @@ def set_config(args): | |||||
| "save_checkpoint_path": "./checkpoint", | "save_checkpoint_path": "./checkpoint", | ||||
| "platform": args.platform, | "platform": args.platform, | ||||
| "ccl": "nccl", | "ccl": "nccl", | ||||
| "run_distribute": args.run_distribute | |||||
| }) | }) | ||||
| config_ascend = ed({ | config_ascend = ed({ | ||||
| "num_classes": 1000, | "num_classes": 1000, | ||||
| @@ -51,9 +51,12 @@ def create_dataset(dataset_path, do_train, config, repeat_num=1): | |||||
| num_shards=rank_size, shard_id=rank_id) | num_shards=rank_size, shard_id=rank_id) | ||||
| elif config.platform == "GPU": | elif config.platform == "GPU": | ||||
| if do_train: | if do_train: | ||||
| from mindspore.communication.management import get_rank, get_group_size | |||||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, | |||||
| num_shards=get_group_size(), shard_id=get_rank()) | |||||
| if config.run_distribute: | |||||
| from mindspore.communication.management import get_rank, get_group_size | |||||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, | |||||
| num_shards=get_group_size(), shard_id=get_rank()) | |||||
| else: | |||||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) | |||||
| else: | else: | ||||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) | ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) | ||||
| elif config.platform == "CPU": | elif config.platform == "CPU": | ||||
| @@ -22,6 +22,7 @@ from mindspore.communication.management import get_rank, init, get_group_size | |||||
| from src.models import Monitor | from src.models import Monitor | ||||
| def switch_precision(net, data_type, config): | def switch_precision(net, data_type, config): | ||||
| if config.platform == "Ascend": | if config.platform == "Ascend": | ||||
| net.to_float(data_type) | net.to_float(data_type) | ||||
| @@ -29,17 +30,18 @@ def switch_precision(net, data_type, config): | |||||
| if isinstance(cell, nn.Dense): | if isinstance(cell, nn.Dense): | ||||
| cell.to_float(mstype.float32) | cell.to_float(mstype.float32) | ||||
| def context_device_init(config): | |||||
| def context_device_init(config): | |||||
| if config.platform == "CPU": | if config.platform == "CPU": | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False) | context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False) | ||||
| elif config.platform == "GPU": | elif config.platform == "GPU": | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False) | context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False) | ||||
| init("nccl") | |||||
| context.set_auto_parallel_context(device_num=get_group_size(), | |||||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| gradients_mean=True) | |||||
| if config.run_distribute: | |||||
| init("nccl") | |||||
| context.set_auto_parallel_context(device_num=get_group_size(), | |||||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| gradients_mean=True) | |||||
| elif config.platform == "Ascend": | elif config.platform == "Ascend": | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, device_id=config.device_id, | context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, device_id=config.device_id, | ||||
| @@ -53,6 +55,7 @@ def context_device_init(config): | |||||
| else: | else: | ||||
| raise ValueError("Only support CPU, GPU and Ascend.") | raise ValueError("Only support CPU, GPU and Ascend.") | ||||
| def set_context(config): | def set_context(config): | ||||
| if config.platform == "CPU": | if config.platform == "CPU": | ||||
| context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, | context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, | ||||
| @@ -64,6 +67,7 @@ def set_context(config): | |||||
| context.set_context(mode=context.GRAPH_MODE, | context.set_context(mode=context.GRAPH_MODE, | ||||
| device_target=config.platform, save_graphs=False) | device_target=config.platform, save_graphs=False) | ||||
| def config_ckpoint(config, lr, step_size): | def config_ckpoint(config, lr, step_size): | ||||
| cb = None | cb = None | ||||
| if config.platform in ("CPU", "GPU") or config.rank_id == 0: | if config.platform in ("CPU", "GPU") or config.rank_id == 0: | ||||
| @@ -75,7 +79,10 @@ def config_ckpoint(config, lr, step_size): | |||||
| ckpt_save_dir = config.save_checkpoint_path | ckpt_save_dir = config.save_checkpoint_path | ||||
| if config.platform == "GPU": | if config.platform == "GPU": | ||||
| ckpt_save_dir += "ckpt_" + str(get_rank()) + "/" | |||||
| if config.run_distribute: | |||||
| ckpt_save_dir += "ckpt_" + str(get_rank()) + "/" | |||||
| else: | |||||
| ckpt_save_dir += "ckpt_" + "/" | |||||
| ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck) | ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck) | ||||
| cb += [ckpt_cb] | cb += [ckpt_cb] | ||||
| @@ -21,7 +21,7 @@ import mindspore.dataset.vision.c_transforms as C | |||||
| import mindspore.dataset.transforms.c_transforms as C2 | import mindspore.dataset.transforms.c_transforms as C2 | ||||
| def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32): | |||||
| def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32, run_distribute=False): | |||||
| """ | """ | ||||
| create a train or eval dataset | create a train or eval dataset | ||||
| @@ -36,9 +36,12 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, | |||||
| """ | """ | ||||
| if device_target == "GPU": | if device_target == "GPU": | ||||
| if do_train: | if do_train: | ||||
| from mindspore.communication.management import get_rank, get_group_size | |||||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, | |||||
| num_shards=get_group_size(), shard_id=get_rank()) | |||||
| if run_distribute: | |||||
| from mindspore.communication.management import get_rank, get_group_size | |||||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, | |||||
| num_shards=get_group_size(), shard_id=get_rank()) | |||||
| else: | |||||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) | |||||
| else: | else: | ||||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) | ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) | ||||
| else: | else: | ||||
| @@ -56,7 +59,8 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, | |||||
| resize_op = C.Resize(256) | resize_op = C.Resize(256) | ||||
| center_crop = C.CenterCrop(resize_width) | center_crop = C.CenterCrop(resize_width) | ||||
| rescale_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) | rescale_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) | ||||
| normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255]) | |||||
| normalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], | |||||
| std=[0.229 * 255, 0.224 * 255, 0.225 * 255]) | |||||
| change_swap_op = C.HWC2CHW() | change_swap_op = C.HWC2CHW() | ||||
| if do_train: | if do_train: | ||||
| @@ -16,6 +16,7 @@ | |||||
| import time | import time | ||||
| import argparse | import argparse | ||||
| import ast | |||||
| import numpy as np | import numpy as np | ||||
| from mindspore import context | from mindspore import context | ||||
| @@ -46,16 +47,18 @@ parser = argparse.ArgumentParser(description='Image classification') | |||||
| parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') | parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') | ||||
| parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') | parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') | ||||
| parser.add_argument('--device_target', type=str, default="GPU", help='run device_target') | parser.add_argument('--device_target', type=str, default="GPU", help='run device_target') | ||||
| parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute') | |||||
| args_opt = parser.parse_args() | args_opt = parser.parse_args() | ||||
| if args_opt.device_target == "GPU": | if args_opt.device_target == "GPU": | ||||
| context.set_context(mode=context.GRAPH_MODE, | context.set_context(mode=context.GRAPH_MODE, | ||||
| device_target="GPU", | device_target="GPU", | ||||
| save_graphs=False) | save_graphs=False) | ||||
| init() | |||||
| context.set_auto_parallel_context(device_num=get_group_size(), | |||||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| gradients_mean=True) | |||||
| if args_opt.run_distribute: | |||||
| init() | |||||
| context.set_auto_parallel_context(device_num=get_group_size(), | |||||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| gradients_mean=True) | |||||
| else: | else: | ||||
| raise ValueError("Unsupported device_target.") | raise ValueError("Unsupported device_target.") | ||||
| @@ -168,7 +171,8 @@ if __name__ == '__main__': | |||||
| config=config_gpu, | config=config_gpu, | ||||
| device_target=args_opt.device_target, | device_target=args_opt.device_target, | ||||
| repeat_num=1, | repeat_num=1, | ||||
| batch_size=config_gpu.batch_size) | |||||
| batch_size=config_gpu.batch_size, | |||||
| run_distribute=args_opt.run_distribute) | |||||
| step_size = dataset.get_dataset_size() | step_size = dataset.get_dataset_size() | ||||
| # resume | # resume | ||||
| if args_opt.pre_trained: | if args_opt.pre_trained: | ||||
| @@ -191,7 +195,10 @@ if __name__ == '__main__': | |||||
| loss_scale_manager=loss_scale) | loss_scale_manager=loss_scale) | ||||
| cb = [Monitor(lr_init=lr.asnumpy())] | cb = [Monitor(lr_init=lr.asnumpy())] | ||||
| ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" | |||||
| if args_opt.run_distribute: | |||||
| ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" | |||||
| else: | |||||
| ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + "/" | |||||
| if config_gpu.save_checkpoint: | if config_gpu.save_checkpoint: | ||||
| config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, | config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, | ||||
| keep_checkpoint_max=config_gpu.keep_checkpoint_max) | keep_checkpoint_max=config_gpu.keep_checkpoint_max) | ||||
| @@ -399,6 +399,6 @@ class PredictWithSigmoid(nn.Cell): | |||||
| self.sigmoid = P.Sigmoid() | self.sigmoid = P.Sigmoid() | ||||
| def construct(self, batch_ids, batch_wts, labels): | def construct(self, batch_ids, batch_wts, labels): | ||||
| logits, _, _, = self.network(batch_ids, batch_wts) | |||||
| logits, _, = self.network(batch_ids, batch_wts) | |||||
| pred_probs = self.sigmoid(logits) | pred_probs = self.sigmoid(logits) | ||||
| return logits, pred_probs, labels | return logits, pred_probs, labels | ||||