|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """srcnn training"""
-
import os
import argparse
import ast

import mindspore as ms
import mindspore.nn as nn
from mindspore import context, Tensor
from mindspore.common import set_seed
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
from mindspore.train.model import Model, ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import srcnn_cfg as config
from src.dataset import create_train_dataset
from src.srcnn import SRCNN
-
- set_seed(1)
-
def filter_checkpoint_parameter_by_list(origin_dict, param_filter):
    """Remove in place every checkpoint entry whose key contains a filtered name.

    Args:
        origin_dict (dict): parameter dict loaded from a checkpoint; mutated in place.
        param_filter (iterable of str): substrings marking parameters to discard
            (e.g. head-layer names when fine-tuning).
    """
    # Snapshot the keys first: the dict shrinks while we delete from it.
    for param_name in list(origin_dict):
        hit = next((pattern for pattern in param_filter if pattern in param_name), None)
        if hit is None:
            continue
        print("Delete parameter from checkpoint: ", param_name)
        del origin_dict[param_name]
-
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="srcnn training")
    parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    parser.add_argument('--device_num', type=int, default=1, help='Device num.')
    # BUGFIX: choices=("GPU") was a plain string, not a tuple, so argparse's
    # membership test ("x in 'GPU'") also accepted substrings like "G" or "GP".
    parser.add_argument('--device_target', type=str, default='GPU', choices=("GPU",),
                        help="Device target, support GPU.")
    parser.add_argument('--pre_trained', type=str, default='', help='model_path, local pretrained model to load')
    parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
                        help="Run distribute, default: false.")
    parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
                        help="Filter head weight parameters, default is False.")
    args, _ = parser.parse_known_args()

    # Only GPU is supported by this script; fail fast on anything else.
    if args.device_target == "GPU":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args.device_target,
                            save_graphs=False)
    else:
        raise ValueError("Unsupported device target.")

    # Single-device defaults; overridden below when running distributed.
    rank = 0
    device_num = 1
    if args.run_distribute:
        init()
        rank = get_rank()
        device_num = get_group_size()
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL)

    # Each rank reads its own shard of the training data.
    train_dataset = create_train_dataset(args.dataset_path, batch_size=config.batch_size,
                                         shard_id=rank, num_shard=device_num)

    step_size = train_dataset.get_dataset_size()

    # define net
    net = SRCNN()

    # init weight from a local checkpoint, optionally dropping the head
    # (end_point) parameters so fine-tuning can relearn them.
    # BUGFIX: load_checkpoint / load_param_into_net were never imported,
    # which raised NameError whenever --pre_trained was given.
    if args.pre_trained:
        param_dict = load_checkpoint(args.pre_trained)
        if args.filter_weight:
            filter_list = [x.name for x in net.end_point.get_parameters()]
            filter_checkpoint_parameter_by_list(param_dict, filter_list)
        load_param_into_net(net, param_dict)

    lr = Tensor(config.lr, ms.float32)

    opt = nn.Adam(params=net.trainable_params(), learning_rate=lr, eps=1e-07)
    loss = nn.MSELoss(reduction='mean')
    model = Model(net, loss_fn=loss, optimizer=opt)

    # define callbacks; only rank 0 writes checkpoints to avoid clobbering
    # the same files from multiple workers.
    callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
    if config.save_checkpoint and rank == 0:
        config_ck = CheckpointConfig(save_checkpoint_steps=step_size,
                                     keep_checkpoint_max=config.keep_checkpoint_max)
        save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
        ckpt_cb = ModelCheckpoint(prefix="srcnn", directory=save_ckpt_path, config=config_ck)
        callbacks.append(ckpt_cb)

    model.train(config.epoch_size, train_dataset, callbacks=callbacks)
|