|
|
|
@@ -37,6 +37,7 @@ from mindspore.train.model import Model, ParallelMode |
|
|
|
|
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, Callback |
|
|
|
from mindspore.train.loss_scale_manager import FixedLossScaleManager |
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net |
|
|
|
import mindspore.dataset.engine as de |
|
|
|
from mindspore.communication.management import init |
|
|
|
|
|
|
|
@@ -46,6 +47,7 @@ de.config.set_seed(1) |
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='Image classification') |
|
|
|
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path') |
|
|
|
parser.add_argument('--pre_trained', type=str, default=None, help='Pretrained checkpoint path') |
|
|
|
args_opt = parser.parse_args() |
|
|
|
|
|
|
|
device_id = int(os.getenv('DEVICE_ID')) |
|
|
|
@@ -165,6 +167,9 @@ if __name__ == '__main__': |
|
|
|
dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=True, |
|
|
|
repeat_num=epoch_size, batch_size=config.batch_size) |
|
|
|
step_size = dataset.get_dataset_size() |
|
|
|
if args_opt.pre_trained: |
|
|
|
param_dict = load_checkpoint(args_opt.pre_trained) |
|
|
|
load_param_into_net(net, param_dict) |
|
|
|
|
|
|
|
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) |
|
|
|
lr = Tensor(get_lr(global_step=0, lr_init=0, lr_end=0, lr_max=config.lr, |
|
|
|
|