You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

train_finetune.py 4.1 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. """train finetune"""
  2. # Copyright 2021 Huawei Technologies Co., Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # ============================================================================
  16. import os
  17. from mindspore import context
  18. from mindspore.context import ParallelMode
  19. import mindspore.dataset as ds
  20. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  21. from mindspore.communication.management import init, get_rank, get_group_size
  22. from mindspore.common import set_seed
  23. from src.args import args
  24. from src.data.imagenet import ImgData
  25. from src.data.srdata import SRData
  26. from src.data.div2k import DIV2K
  27. from src.data.bicubic import bicubic
  28. from src.ipt_model import IPT
  29. from src.utils import Trainer
  30. def train_net(distribute, imagenet):
  31. """Train net with finetune"""
  32. set_seed(1)
  33. device_id = int(os.getenv('DEVICE_ID', '0'))
  34. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
  35. if imagenet == 1:
  36. train_dataset = ImgData(args)
  37. elif not args.derain:
  38. train_dataset = DIV2K(args, name=args.data_train, train=True, benchmark=False)
  39. train_dataset.set_scale(args.task_id)
  40. else:
  41. train_dataset = SRData(args, name=args.data_train, train=True, benchmark=False)
  42. train_dataset.set_scale(args.task_id)
  43. if distribute:
  44. init()
  45. rank_id = get_rank()
  46. rank_size = get_group_size()
  47. parallel_mode = ParallelMode.DATA_PARALLEL
  48. context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=rank_size, gradients_mean=True)
  49. print('Rank {}, group_size {}'.format(rank_id, rank_size))
  50. if imagenet == 1:
  51. train_de_dataset = ds.GeneratorDataset(train_dataset,
  52. ["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
  53. num_shards=rank_size, shard_id=rank_id, shuffle=True)
  54. else:
  55. train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR", "idx", "filename"],
  56. num_shards=rank_size, shard_id=rank_id, shuffle=True)
  57. else:
  58. if imagenet == 1:
  59. train_de_dataset = ds.GeneratorDataset(train_dataset,
  60. ["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
  61. shuffle=True)
  62. else:
  63. train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR", "idx", "filename"], shuffle=True)
  64. if args.imagenet == 1:
  65. resize_fuc = bicubic()
  66. train_de_dataset = train_de_dataset.batch(
  67. args.batch_size,
  68. input_columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
  69. output_columns=["LR", "HR", "idx", "filename"], drop_remainder=True,
  70. per_batch_map=resize_fuc.forward)
  71. else:
  72. train_de_dataset = train_de_dataset.batch(args.batch_size, drop_remainder=True)
  73. train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
  74. net_m = IPT(args)
  75. print("Init net weights successfully")
  76. if args.pth_path:
  77. param_dict = load_checkpoint(args.pth_path)
  78. load_param_into_net(net_m, param_dict)
  79. print("Load net weight successfully")
  80. train_func = Trainer(args, train_loader, net_m)
  81. for epoch in range(0, args.epochs):
  82. train_func.update_learning_rate(epoch)
  83. train_func.train()
  84. if __name__ == "__main__":
  85. train_net(distribute=args.distribute, imagenet=args.imagenet)