- """train finetune"""
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
import os

import mindspore.dataset as ds
from mindspore import context
from mindspore.common import set_seed
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.args import args
from src.data.bicubic import bicubic
from src.data.div2k import DIV2K
from src.data.imagenet import ImgData
from src.data.srdata import SRData
from src.ipt_model import IPT
from src.utils import Trainer


def train_net(distribute, imagenet):
    """Train the IPT network, optionally finetuning from a pretrained checkpoint."""
    set_seed(1)
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)

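    # Select the training dataset: ImageNet-style data for pretraining,
    # DIV2K for super-resolution finetuning, or SRData for the deraining task.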
    if imagenet == 1:
        train_dataset = ImgData(args)
    elif not args.derain:
        train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
        train_dataset.set_scale(args.task_id)
    else:
        train_dataset = SRData(args, name=args.data_train, train=True, benchmark=False)
        train_dataset.set_scale(args.task_id)

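    # Wrap the raw dataset in a MindSpore GeneratorDataset; under distributed
    # training, initialize communication and shard the data across ranks.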
    imagenet_columns = ["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"]
    sr_columns = ["LR", "HR", "idx", "filename"]
    if distribute:
        init()
        rank_id = get_rank()
        rank_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=rank_size, gradients_mean=True)
        print('Rank {}, group_size {}'.format(rank_id, rank_size))
        if imagenet == 1:
            train_de_dataset = ds.GeneratorDataset(train_dataset, imagenet_columns,
                                                   num_shards=rank_size, shard_id=rank_id, shuffle=True)
        else:
            train_de_dataset = ds.GeneratorDataset(train_dataset, sr_columns,
                                                   num_shards=rank_size, shard_id=rank_id, shuffle=True)
    else:
        if imagenet == 1:
            train_de_dataset = ds.GeneratorDataset(train_dataset, imagenet_columns, shuffle=True)
        else:
            train_de_dataset = ds.GeneratorDataset(train_dataset, sr_columns, shuffle=True)

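    # Batch the dataset. For ImageNet pretraining, the bicubic per-batch map
    # collapses the multi-scale columns into a single (LR, HR) pair per sample,
    # presumably resizing according to the sampled scale.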
    if imagenet == 1:
        resize_func = bicubic()
        train_de_dataset = train_de_dataset.batch(
            args.batch_size,
            input_columns=imagenet_columns,
            output_columns=sr_columns, drop_remainder=True,
            per_batch_map=resize_func.forward)
    else:
        train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)

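    # Iterate over batches as NumPy dicts and build the IPT model.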
    train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
    net_m = IPT(args)
    print("Initialized network weights successfully")

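    # Optionally warm-start from a pretrained checkpoint (judging by the flag
    # name, likely a checkpoint converted from a PyTorch .pth file).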
    if args.pth_path:
        param_dict = load_checkpoint(args.pth_path)
        load_param_into_net(net_m, param_dict)
        print("Loaded pretrained weights successfully")

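    # The Trainer helper from src.utils drives optimization: train one epoch
    # at a time, refreshing the learning-rate schedule before each epoch.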
    train_func = Trainer(args, train_loader, net_m)

    for epoch in range(args.epochs):
        train_func.update_learning_rate(epoch)
        train_func.train()


if __name__ == "__main__":
    train_net(distribute=args.distribute, imagenet=args.imagenet)
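# Example launch (hypothetical script name and flag spellings; the actual
# argument names and types are defined by the parser in src.args):
#   DEVICE_ID=0 python train_finetune.py --imagenet 1 --batch_size 32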