@@ -21,7 +21,6 @@ import mindspore.dataset.engine as de
 import mindspore.dataset.transforms.vision.c_transforms as C
 import mindspore.dataset.transforms.c_transforms as C2
-
 
 def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch_size=32):
     """
     create a train or eval dataset
@@ -44,7 +43,9 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch_size=32):
             ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
                                          num_shards=rank_size, shard_id=rank_id)
     elif platform == "GPU":
-        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
+        from mindspore.communication.management import get_rank, get_group_size
+        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
+                                     num_shards=get_group_size(), shard_id=get_rank())
     else:
         raise ValueError("Unsupport platform.")
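Context for the change above: passing num_shards=get_group_size() and shard_id=get_rank() makes each GPU process read a disjoint 1/N slice of the dataset instead of the whole set. A minimal pure-Python sketch of a round-robin split of that kind (shard_indices is a hypothetical helper for illustration; MindSpore's internal partitioning strategy may differ):

    def shard_indices(num_samples, num_shards, shard_id):
        # Shard r of N keeps every N-th sample, starting at offset r.
        return list(range(shard_id, num_samples, num_shards))

    # Example: 10 samples split across 4 ranks.
    for rank in range(4):
        print(rank, shard_indices(10, 4, rank))
    # 0 [0, 4, 8]
    # 1 [1, 5, 9]
    # 2 [2, 6]
    # 3 [3, 7]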
@@ -32,7 +32,7 @@ from mindspore.train.model import Model, ParallelMode
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.communication.management import init
+from mindspore.communication.management import init, get_group_size
 import mindspore.dataset.engine as de
 from src.dataset import create_dataset
 from src.lr_generator import get_lr
@@ -157,6 +157,11 @@ if __name__ == '__main__':
         # train on gpu
         print("train args: ", args_opt, "\ncfg: ", config_gpu)
 
+        init('nccl')
+        context.set_auto_parallel_context(parallel_mode="data_parallel",
+                                          mirror_mean=True,
+                                          device_num=get_group_size())
+
         # define net
         net = mobilenet_v2(num_classes=config_gpu.num_classes, platform="GPU")
         # define loss
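The added block initializes the NCCL backend and sizes data parallelism to the process group. mirror_mean=True makes the gradient all-reduce return an average rather than a sum, so gradient magnitude stays independent of the device count. A hypothetical numeric sketch of that distinction (plain Python, not MindSpore API):

    device_grads = [0.2, 0.4, 0.6, 0.8]    # same weight's gradient on each of 4 GPUs
    summed = sum(device_grads)             # plain all-reduce: 2.0
    averaged = summed / len(device_grads)  # with mirror_mean=True: 0.5
    print(summed, averaged)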
@@ -223,7 +228,7 @@ if __name__ == '__main__':
                 cell.to_float(mstype.float32)
         if config_ascend.label_smooth > 0:
             loss = CrossEntropyWithLabelSmooth(
-                smooth_factor=config_ascend.label_smooth, num_classes=config.num_classes)
+                smooth_factor=config_ascend.label_smooth, num_classes=config_ascend.num_classes)
         else:
             loss = SoftmaxCrossEntropyWithLogits(
                 is_grad=False, sparse=True, reduction='mean')
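This hunk fixes a latent NameError: the bare name config is never imported in train.py (only config_ascend and config_gpu are), so the smoothing loss must read config_ascend.num_classes. For reference, label smoothing with factor ε commonly keeps 1-ε on the true class and spreads ε over the remaining classes; a sketch of that formulation (assumed here, not copied from CrossEntropyWithLabelSmooth):

    num_classes, smooth_factor, label = 4, 0.1, 2
    on_value = 1.0 - smooth_factor                 # weight kept on the true class
    off_value = smooth_factor / (num_classes - 1)  # spread over the other classes
    target = [on_value if i == label else off_value for i in range(num_classes)]
    print(target)  # [0.0333, 0.0333, 0.9, 0.0333] (approximately)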
@@ -24,7 +24,8 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.common import dtype as mstype
 from src.dataset import create_dataset
 from src.config import config_ascend, config_gpu
-from src.mobilenetV2 import mobilenet_v2
+from src.mobilenetV3 import mobilenet_v3_large
+
 
 parser = argparse.ArgumentParser(description='Image classification')
 parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
@@ -49,7 +50,7 @@ if __name__ == '__main__':
 
     loss = nn.SoftmaxCrossEntropyWithLogits(
         is_grad=False, sparse=True, reduction='mean')
-    net = mobilenet_v2(num_classes=config_platform.num_classes)
+    net = mobilenet_v3_large(num_classes=config_platform.num_classes)
 
     if args_opt.platform == "Ascend":
         net.to_float(mstype.float16)
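These two eval.py hunks switch evaluation from MobileNetV2 to MobileNetV3-Large. The surrounding script then restores weights through the serialization imports visible above; roughly the following pattern (the checkpoint-loading lines are the usual usage, not part of this diff):

    net = mobilenet_v3_large(num_classes=config_platform.num_classes)
    if args_opt.checkpoint_path:
        param_dict = load_checkpoint(args_opt.checkpoint_path)
        load_param_into_net(net, param_dict)
    net.set_train(False)  # inference mode before running evaluation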
@@ -44,7 +44,9 @@ def create_dataset(dataset_path, do_train, config, platform, repeat_num=1, batch_size=32):
             ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
                                          num_shards=rank_size, shard_id=rank_id)
     elif platform == "GPU":
-        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True)
+        from mindspore.communication.management import get_rank, get_group_size
+        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=8, shuffle=True,
+                                     num_shards=get_group_size(), shard_id=get_rank())
     else:
         raise ValueError("Unsupport platform.")
@@ -33,7 +33,7 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 import mindspore.dataset.engine as de
-from mindspore.communication.management import init
+from mindspore.communication.management import init, get_group_size
 from src.dataset import create_dataset
 from src.lr_generator import get_lr
 from src.config import config_gpu, config_ascend
@@ -157,6 +157,11 @@ if __name__ == '__main__':
         # train on gpu
         print("train args: ", args_opt, "\ncfg: ", config_gpu)
 
+        init('nccl')
+        context.set_auto_parallel_context(parallel_mode="data_parallel",
+                                          mirror_mean=True,
+                                          device_num=get_group_size())
+
         # define net
         net = mobilenet_v3_large(num_classes=config_gpu.num_classes)
         # define loss
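With init('nccl') and per-rank dataset sharding in place in both train.py files, a multi-GPU run is normally launched with one process per device via OpenMPI; something along these lines (flag names assumed from the surrounding scripts, path is a placeholder):

    mpirun -n 8 python train.py --platform GPU --dataset_path /path/to/imagenet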