| @@ -14,7 +14,7 @@ | |||
| # ============================================================================ | |||
| import argparse | |||
| import ast | |||
| def launch_parse_args(): | |||
| @@ -43,6 +43,7 @@ def train_parse_args(): | |||
| help='run platform, only support CPU, GPU and Ascend') | |||
| train_parser.add_argument('--pretrain_ckpt', type=str, default=None, help='Pretrained checkpoint path \ | |||
| for fine tune or incremental learning') | |||
| train_parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute') | |||
| train_parser.add_argument('--train_method', type=str, choices=("train", "fine_tune", "incremental_learn"), \ | |||
| help="\"fine_tune\"or \"incremental_learn\" if to fine tune the net after loading the ckpt, \"train\" to \ | |||
| train from initialization model") | |||
| @@ -59,6 +59,7 @@ def set_config(args): | |||
| "save_checkpoint_path": "./checkpoint", | |||
| "platform": args.platform, | |||
| "ccl": "nccl", | |||
| "run_distribute": args.run_distribute | |||
| }) | |||
| config_ascend = ed({ | |||
| "num_classes": 1000, | |||
| @@ -51,9 +51,12 @@ def create_dataset(dataset_path, do_train, config, repeat_num=1): | |||
| num_shards=rank_size, shard_id=rank_id) | |||
| elif config.platform == "GPU": | |||
| if do_train: | |||
| from mindspore.communication.management import get_rank, get_group_size | |||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, | |||
| num_shards=get_group_size(), shard_id=get_rank()) | |||
| if config.run_distribute: | |||
| from mindspore.communication.management import get_rank, get_group_size | |||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, | |||
| num_shards=get_group_size(), shard_id=get_rank()) | |||
| else: | |||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) | |||
| else: | |||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) | |||
| elif config.platform == "CPU": | |||
| @@ -22,6 +22,7 @@ from mindspore.communication.management import get_rank, init, get_group_size | |||
| from src.models import Monitor | |||
| def switch_precision(net, data_type, config): | |||
| if config.platform == "Ascend": | |||
| net.to_float(data_type) | |||
| @@ -29,17 +30,18 @@ def switch_precision(net, data_type, config): | |||
| if isinstance(cell, nn.Dense): | |||
| cell.to_float(mstype.float32) | |||
| def context_device_init(config): | |||
| def context_device_init(config): | |||
| if config.platform == "CPU": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False) | |||
| elif config.platform == "GPU": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False) | |||
| init("nccl") | |||
| context.set_auto_parallel_context(device_num=get_group_size(), | |||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| gradients_mean=True) | |||
| if config.run_distribute: | |||
| init("nccl") | |||
| context.set_auto_parallel_context(device_num=get_group_size(), | |||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| gradients_mean=True) | |||
| elif config.platform == "Ascend": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, device_id=config.device_id, | |||
| @@ -53,6 +55,7 @@ def context_device_init(config): | |||
| else: | |||
| raise ValueError("Only support CPU, GPU and Ascend.") | |||
| def set_context(config): | |||
| if config.platform == "CPU": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, | |||
| @@ -64,6 +67,7 @@ def set_context(config): | |||
| context.set_context(mode=context.GRAPH_MODE, | |||
| device_target=config.platform, save_graphs=False) | |||
| def config_ckpoint(config, lr, step_size): | |||
| cb = None | |||
| if config.platform in ("CPU", "GPU") or config.rank_id == 0: | |||
| @@ -75,7 +79,10 @@ def config_ckpoint(config, lr, step_size): | |||
| ckpt_save_dir = config.save_checkpoint_path | |||
| if config.platform == "GPU": | |||
| ckpt_save_dir += "ckpt_" + str(get_rank()) + "/" | |||
| if config.run_distribute: | |||
| ckpt_save_dir += "ckpt_" + str(get_rank()) + "/" | |||
| else: | |||
| ckpt_save_dir += "ckpt_" + "/" | |||
| ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck) | |||
| cb += [ckpt_cb] | |||
| @@ -21,7 +21,7 @@ import mindspore.dataset.vision.c_transforms as C | |||
| import mindspore.dataset.transforms.c_transforms as C2 | |||
| def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32): | |||
| def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, batch_size=32, run_distribute=False): | |||
| """ | |||
| create a train or eval dataset | |||
| @@ -36,9 +36,12 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, | |||
| """ | |||
| if device_target == "GPU": | |||
| if do_train: | |||
| from mindspore.communication.management import get_rank, get_group_size | |||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, | |||
| num_shards=get_group_size(), shard_id=get_rank()) | |||
| if run_distribute: | |||
| from mindspore.communication.management import get_rank, get_group_size | |||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True, | |||
| num_shards=get_group_size(), shard_id=get_rank()) | |||
| else: | |||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) | |||
| else: | |||
| ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=8, shuffle=True) | |||
| else: | |||
| @@ -56,7 +59,8 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1, | |||
| resize_op = C.Resize(256) | |||
| center_crop = C.CenterCrop(resize_width) | |||
| rescale_op = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) | |||
| normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255]) | |||
| normalize_op = C.Normalize(mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], | |||
| std=[0.229 * 255, 0.224 * 255, 0.225 * 255]) | |||
| change_swap_op = C.HWC2CHW() | |||
| if do_train: | |||
| @@ -16,6 +16,7 @@ | |||
| import time | |||
| import argparse | |||
| import ast | |||
| import numpy as np | |||
| from mindspore import context | |||
| @@ -46,16 +47,18 @@ parser = argparse.ArgumentParser(description='Image classification') | |||
| parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') | |||
| parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') | |||
| parser.add_argument('--device_target', type=str, default="GPU", help='run device_target') | |||
| parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distribute') | |||
| args_opt = parser.parse_args() | |||
| if args_opt.device_target == "GPU": | |||
| context.set_context(mode=context.GRAPH_MODE, | |||
| device_target="GPU", | |||
| save_graphs=False) | |||
| init() | |||
| context.set_auto_parallel_context(device_num=get_group_size(), | |||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| gradients_mean=True) | |||
| if args_opt.run_distribute: | |||
| init() | |||
| context.set_auto_parallel_context(device_num=get_group_size(), | |||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||
| gradients_mean=True) | |||
| else: | |||
| raise ValueError("Unsupported device_target.") | |||
| @@ -168,7 +171,8 @@ if __name__ == '__main__': | |||
| config=config_gpu, | |||
| device_target=args_opt.device_target, | |||
| repeat_num=1, | |||
| batch_size=config_gpu.batch_size) | |||
| batch_size=config_gpu.batch_size, | |||
| run_distribute=args_opt.run_distribute) | |||
| step_size = dataset.get_dataset_size() | |||
| # resume | |||
| if args_opt.pre_trained: | |||
| @@ -191,7 +195,10 @@ if __name__ == '__main__': | |||
| loss_scale_manager=loss_scale) | |||
| cb = [Monitor(lr_init=lr.asnumpy())] | |||
| ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" | |||
| if args_opt.run_distribute: | |||
| ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" | |||
| else: | |||
| ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + "/" | |||
| if config_gpu.save_checkpoint: | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, | |||
| keep_checkpoint_max=config_gpu.keep_checkpoint_max) | |||
| @@ -399,6 +399,6 @@ class PredictWithSigmoid(nn.Cell): | |||
| self.sigmoid = P.Sigmoid() | |||
| def construct(self, batch_ids, batch_wts, labels): | |||
| logits, _, _, = self.network(batch_ids, batch_wts) | |||
| logits, _, = self.network(batch_ids, batch_wts) | |||
| pred_probs = self.sigmoid(logits) | |||
| return logits, pred_probs, labels | |||