You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

train.py 7.1 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. """train"""
  2. # Copyright 2021 Huawei Technologies Co., Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # ============================================================================
  16. import os
  17. import math
  18. import mindspore.dataset as ds
  19. from mindspore import Parameter, set_seed, context
  20. from mindspore.context import ParallelMode
  21. from mindspore.common.initializer import initializer, HeUniform, XavierUniform, Uniform, Normal, Zero
  22. from mindspore.communication.management import init, get_rank, get_group_size
  23. from mindspore.train.serialization import load_checkpoint, load_param_into_net
  24. from src.args import args
  25. from src.data.bicubic import bicubic
  26. from src.data.imagenet import ImgData
  27. from src.ipt_model import IPT
  28. from src.utils import Trainer
  29. def _calculate_fan_in_and_fan_out(shape):
  30. """
  31. calculate fan_in and fan_out
  32. Args:
  33. shape (tuple): input shape.
  34. Returns:
  35. Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
  36. """
  37. dimensions = len(shape)
  38. if dimensions < 2:
  39. raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
  40. if dimensions == 2:
  41. fan_in = shape[1]
  42. fan_out = shape[0]
  43. else:
  44. num_input_fmaps = shape[1]
  45. num_output_fmaps = shape[0]
  46. receptive_field_size = 1
  47. if dimensions > 2:
  48. receptive_field_size = shape[2] * shape[3]
  49. fan_in = num_input_fmaps * receptive_field_size
  50. fan_out = num_output_fmaps * receptive_field_size
  51. return fan_in, fan_out
def init_weights(net, init_type='normal', init_gain=0.02):
    """
    Initialize network weights.

    :param net: network to be initialized
    :type net: nn.Module
    :param init_type: the name of an initialization method: normal | xavier | he
    :type init_type: str
    :param init_gain: scaling factor for normal and xavier initialization.
    :type init_gain: float
    """
    for _, cell in net.cells_and_names():
        classname = cell.__class__.__name__
        # Attention in-projection weights are always re-initialized with
        # HeUniform, independent of init_type.
        if hasattr(cell, 'in_proj_layer'):
            cell.in_proj_layer = Parameter(initializer(HeUniform(negative_slope=math.sqrt(5)), cell.in_proj_layer.shape,
                                                       cell.in_proj_layer.dtype), name=cell.in_proj_layer.name)
        if hasattr(cell, 'weight'):
            if init_type == 'normal':
                cell.weight = Parameter(initializer(Normal(init_gain), cell.weight.shape,
                                                    cell.weight.dtype), name=cell.weight.name)
            elif init_type == 'xavier':
                cell.weight = Parameter(initializer(XavierUniform(init_gain), cell.weight.shape,
                                                    cell.weight.dtype), name=cell.weight.name)
            elif init_type == "he":
                cell.weight = Parameter(initializer(HeUniform(negative_slope=math.sqrt(5)), cell.weight.shape,
                                                    cell.weight.dtype), name=cell.weight.name)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            # Bias follows the Kaiming-uniform convention:
            # U(-1/sqrt(fan_in), 1/sqrt(fan_in)), with fan_in taken from the
            # matching weight tensor.
            if hasattr(cell, 'bias') and cell.bias is not None:
                fan_in, _ = _calculate_fan_in_and_fan_out(cell.weight.shape)
                bound = 1 / math.sqrt(fan_in)
                cell.bias = Parameter(initializer(Uniform(bound), cell.bias.shape, cell.bias.dtype),
                                      name=cell.bias.name)
        elif classname.find('BatchNorm2d') != -1:
            # NOTE(review): `.default_input.shape()` calls shape as a method;
            # on current MindSpore releases `shape` is an attribute and
            # `default_input` is deprecated — verify this branch against the
            # MindSpore version this project pins before relying on it.
            cell.gamma = Parameter(initializer(Normal(1.0), cell.gamma.default_input.shape()), name=cell.gamma.name)
            cell.beta = Parameter(initializer(Zero(), cell.beta.default_input.shape()), name=cell.beta.name)
    print('initialize network weight with %s' % init_type)
  88. def train_net(distribute, imagenet, epochs):
  89. """Train net"""
  90. set_seed(1)
  91. device_id = int(os.getenv('DEVICE_ID', '0'))
  92. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
  93. if imagenet == 1:
  94. train_dataset = ImgData(args)
  95. else:
  96. train_dataset = data.Data(args).loader_train
  97. if distribute:
  98. init()
  99. rank_id = get_rank()
  100. rank_size = get_group_size()
  101. parallel_mode = ParallelMode.DATA_PARALLEL
  102. context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=rank_size, gradients_mean=True)
  103. print('Rank {}, rank_size {}'.format(rank_id, rank_size))
  104. if imagenet == 1:
  105. train_de_dataset = ds.GeneratorDataset(train_dataset,
  106. ["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
  107. num_shards=rank_size, shard_id=args.rank, shuffle=True)
  108. else:
  109. train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], num_shards=rank_size,
  110. shard_id=rank_id, shuffle=True)
  111. else:
  112. if imagenet == 1:
  113. train_de_dataset = ds.GeneratorDataset(train_dataset,
  114. ["HR", "Rain", "LRx2", "LRx3", "LRx4", "scales", "filename"],
  115. shuffle=True)
  116. else:
  117. train_de_dataset = ds.GeneratorDataset(train_dataset, ["LR", "HR"], shuffle=True)
  118. resize_fuc = bicubic()
  119. train_de_dataset = train_de_dataset.project(columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "filename"])
  120. train_de_dataset = train_de_dataset.batch(args.batch_size,
  121. input_columns=["HR", "Rain", "LRx2", "LRx3", "LRx4", "filename"],
  122. output_columns=["LR", "HR", "idx", "filename"],
  123. drop_remainder=True, per_batch_map=resize_fuc.forward)
  124. train_loader = train_de_dataset.create_dict_iterator(output_numpy=True)
  125. net_work = IPT(args)
  126. init_weights(net_work, init_type='he', init_gain=1.0)
  127. print("Init net weight successfully")
  128. if args.pth_path:
  129. param_dict = load_checkpoint(args.pth_path)
  130. load_param_into_net(net_work, param_dict)
  131. print("Load net weight successfully")
  132. train_func = Trainer(args, train_loader, net_work)
  133. for epoch in range(0, epochs):
  134. train_func.update_learning_rate(epoch)
  135. train_func.train()
if __name__ == '__main__':
    # Script entry point; all options come from the CLI parser in src.args.
    train_net(distribute=args.distribute, imagenet=args.imagenet, epochs=args.epochs)