| @@ -920,7 +920,7 @@ class HSwishQuant(_QuantActivation): | |||||
| symmetric=symmetric, | symmetric=symmetric, | ||||
| narrow_range=narrow_range, | narrow_range=narrow_range, | ||||
| quant_delay=quant_delay) | quant_delay=quant_delay) | ||||
| if isinstance(activation, nn.HSwish): | |||||
| if issubclass(activation, nn.HSwish): | |||||
| self.act = activation() | self.act = activation() | ||||
| else: | else: | ||||
| raise ValueError("Activation should be `nn.HSwish`") | raise ValueError("Activation should be `nn.HSwish`") | ||||
| @@ -989,7 +989,7 @@ class HSigmoidQuant(_QuantActivation): | |||||
| symmetric=symmetric, | symmetric=symmetric, | ||||
| narrow_range=narrow_range, | narrow_range=narrow_range, | ||||
| quant_delay=quant_delay) | quant_delay=quant_delay) | ||||
| if isinstance(activation, nn.HSwish): | |||||
| if issubclass(activation, nn.HSigmoid): | |||||
| self.act = activation() | self.act = activation() | ||||
| else: | else: | ||||
| raise ValueError("Activation should be `nn.HSigmoid`") | raise ValueError("Activation should be `nn.HSigmoid`") | ||||
| @@ -18,6 +18,7 @@ import time | |||||
| import argparse | import argparse | ||||
| import random | import random | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore import nn | from mindspore import nn | ||||
| @@ -32,8 +33,9 @@ from mindspore.train.model import Model, ParallelMode | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback | from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback | ||||
| from mindspore.train.loss_scale_manager import FixedLossScaleManager | from mindspore.train.loss_scale_manager import FixedLossScaleManager | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| from mindspore.communication.management import init, get_group_size | |||||
| from mindspore.communication.management import init, get_group_size, get_rank | |||||
| import mindspore.dataset.engine as de | import mindspore.dataset.engine as de | ||||
| from src.dataset import create_dataset | from src.dataset import create_dataset | ||||
| from src.lr_generator import get_lr | from src.lr_generator import get_lr | ||||
| from src.config import config_gpu, config_ascend | from src.config import config_gpu, config_ascend | ||||
| @@ -60,9 +62,14 @@ if args_opt.platform == "Ascend": | |||||
| device_id=device_id, save_graphs=False) | device_id=device_id, save_graphs=False) | ||||
| elif args_opt.platform == "GPU": | elif args_opt.platform == "GPU": | ||||
| context.set_context(mode=context.GRAPH_MODE, | context.set_context(mode=context.GRAPH_MODE, | ||||
| device_target="GPU", save_graphs=False) | |||||
| device_target="GPU", | |||||
| save_graphs=False) | |||||
| init("nccl") | |||||
| context.set_auto_parallel_context(device_num=get_group_size(), | |||||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| mirror_mean=True) | |||||
| else: | else: | ||||
| raise ValueError("Unsupport platform.") | |||||
| raise ValueError("Unsupported device target.") | |||||
| class CrossEntropyWithLabelSmooth(_Loss): | class CrossEntropyWithLabelSmooth(_Loss): | ||||
| @@ -155,12 +162,8 @@ class Monitor(Callback): | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| if args_opt.platform == "GPU": | if args_opt.platform == "GPU": | ||||
| # train on gpu | # train on gpu | ||||
| print("train args: ", args_opt, "\ncfg: ", config_gpu) | |||||
| init('nccl') | |||||
| context.set_auto_parallel_context(parallel_mode="data_parallel", | |||||
| mirror_mean=True, | |||||
| device_num=get_group_size()) | |||||
| print("train args: ", args_opt) | |||||
| print("cfg: ", config_gpu) | |||||
| # define net | # define net | ||||
| net = mobilenet_v2(num_classes=config_gpu.num_classes, platform="GPU") | net = mobilenet_v2(num_classes=config_gpu.num_classes, platform="GPU") | ||||
| @@ -201,13 +204,13 @@ if __name__ == '__main__': | |||||
| loss_scale_manager=loss_scale) | loss_scale_manager=loss_scale) | ||||
| cb = [Monitor(lr_init=lr.asnumpy())] | cb = [Monitor(lr_init=lr.asnumpy())] | ||||
| ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" | |||||
| if config_gpu.save_checkpoint: | if config_gpu.save_checkpoint: | ||||
| config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, | config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, | ||||
| keep_checkpoint_max=config_gpu.keep_checkpoint_max) | keep_checkpoint_max=config_gpu.keep_checkpoint_max) | ||||
| ckpt_cb = ModelCheckpoint( | |||||
| prefix="mobilenetV2", directory=config_gpu.save_checkpoint_path, config=config_ck) | |||||
| ckpt_cb = ModelCheckpoint(prefix="mobilenetV2", directory=ckpt_save_dir, config=config_ck) | |||||
| cb += [ckpt_cb] | cb += [ckpt_cb] | ||||
| # begine train | |||||
| # begin train | |||||
| model.train(epoch_size, dataset, callbacks=cb) | model.train(epoch_size, dataset, callbacks=cb) | ||||
| elif args_opt.platform == "Ascend": | elif args_opt.platform == "Ascend": | ||||
| # train on ascend | # train on ascend | ||||
| @@ -18,6 +18,7 @@ import time | |||||
| import argparse | import argparse | ||||
| import random | import random | ||||
| import numpy as np | import numpy as np | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore import Tensor | from mindspore import Tensor | ||||
| from mindspore import nn | from mindspore import nn | ||||
| @@ -33,7 +34,8 @@ from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback | |||||
| from mindspore.train.loss_scale_manager import FixedLossScaleManager | from mindspore.train.loss_scale_manager import FixedLossScaleManager | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| import mindspore.dataset.engine as de | import mindspore.dataset.engine as de | ||||
| from mindspore.communication.management import init, get_group_size | |||||
| from mindspore.communication.management import init, get_group_size, get_rank | |||||
| from src.dataset import create_dataset | from src.dataset import create_dataset | ||||
| from src.lr_generator import get_lr | from src.lr_generator import get_lr | ||||
| from src.config import config_gpu, config_ascend | from src.config import config_gpu, config_ascend | ||||
| @@ -57,10 +59,16 @@ if args_opt.platform == "Ascend": | |||||
| device_id = int(os.getenv('DEVICE_ID')) | device_id = int(os.getenv('DEVICE_ID')) | ||||
| context.set_context(mode=context.GRAPH_MODE, | context.set_context(mode=context.GRAPH_MODE, | ||||
| device_target="Ascend", | device_target="Ascend", | ||||
| device_id=device_id, save_graphs=False) | |||||
| device_id=device_id, | |||||
| save_graphs=False) | |||||
| elif args_opt.platform == "GPU": | elif args_opt.platform == "GPU": | ||||
| context.set_context(mode=context.GRAPH_MODE, | context.set_context(mode=context.GRAPH_MODE, | ||||
| device_target="GPU", save_graphs=False) | |||||
| device_target="GPU", | |||||
| save_graphs=False) | |||||
| init("nccl") | |||||
| context.set_auto_parallel_context(device_num=get_group_size(), | |||||
| parallel_mode=ParallelMode.DATA_PARALLEL, | |||||
| mirror_mean=True) | |||||
| else: | else: | ||||
| raise ValueError("Unsupported platform.") | raise ValueError("Unsupported platform.") | ||||
| @@ -155,12 +163,8 @@ class Monitor(Callback): | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| if args_opt.platform == "GPU": | if args_opt.platform == "GPU": | ||||
| # train on gpu | # train on gpu | ||||
| print("train args: ", args_opt, "\ncfg: ", config_gpu) | |||||
| init('nccl') | |||||
| context.set_auto_parallel_context(parallel_mode="data_parallel", | |||||
| mirror_mean=True, | |||||
| device_num=get_group_size()) | |||||
| print("train args: ", args_opt) | |||||
| print("cfg: ", config_gpu) | |||||
| # define net | # define net | ||||
| net = mobilenet_v3_large(num_classes=config_gpu.num_classes) | net = mobilenet_v3_large(num_classes=config_gpu.num_classes) | ||||
| @@ -201,11 +205,11 @@ if __name__ == '__main__': | |||||
| loss_scale_manager=loss_scale) | loss_scale_manager=loss_scale) | ||||
| cb = [Monitor(lr_init=lr.asnumpy())] | cb = [Monitor(lr_init=lr.asnumpy())] | ||||
| ckpt_save_dir = config_gpu.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" | |||||
| if config_gpu.save_checkpoint: | if config_gpu.save_checkpoint: | ||||
| config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, | config_ck = CheckpointConfig(save_checkpoint_steps=config_gpu.save_checkpoint_epochs * step_size, | ||||
| keep_checkpoint_max=config_gpu.keep_checkpoint_max) | keep_checkpoint_max=config_gpu.keep_checkpoint_max) | ||||
| ckpt_cb = ModelCheckpoint( | |||||
| prefix="mobilenetV3", directory=config_gpu.save_checkpoint_path, config=config_ck) | |||||
| ckpt_cb = ModelCheckpoint(prefix="mobilenetV3", directory=ckpt_save_dir, config=config_ck) | |||||
| cb += [ckpt_cb] | cb += [ckpt_cb] | ||||
| # begin train | # begin train | ||||
| model.train(epoch_size, dataset, callbacks=cb) | model.train(epoch_size, dataset, callbacks=cb) | ||||