train.py
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train NAML."""
import time
from mindspore import nn, load_checkpoint
import mindspore.common.dtype as mstype
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
from src.naml import NAML, NAMLWithLossCell
from src.option import get_args
from src.dataset import create_dataset, MINDPreprocess
from src.utils import process_data
from src.callback import Monitor

if __name__ == '__main__':
    args = get_args("train")
    set_seed(args.seed)
    # Build the word embedding table and the NAML network with its loss cell.
    word_embedding = process_data(args)
    net = NAML(args, word_embedding)
    net_with_loss = NAMLWithLossCell(net)
    if args.checkpoint_path is not None:
        # Optionally initialize the network from a pretrained checkpoint.
        load_checkpoint(args.pretrain_checkpoint, net_with_loss)
    # Shard the MIND training set across devices (rank / device_num) and batch it.
    mindpreprocess_train = MINDPreprocess(vars(args), dataset_path=args.train_dataset_path)
    dataset = create_dataset(mindpreprocess_train, batch_size=args.batch_size, rank=args.rank,
                             group_size=args.device_num)
    args.dataset_size = dataset.get_dataset_size()
    args.print_times = min(args.dataset_size, args.print_times)
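    # Optimizer: when weight decay is enabled, decay is applied per parameter
    # group, so 'weight' tensors are regularized while biases and the rest are
    # exempt; the 'order_params' entry keeps MindSpore's parameter update
    # order aligned with net.trainable_params().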
    if args.weight_decay:
        weight_params = list(filter(lambda x: 'weight' in x.name, net.trainable_params()))
        other_params = list(filter(lambda x: 'weight' not in x.name, net.trainable_params()))
        group_params = [{'params': weight_params, 'weight_decay': 1e-3},
                        {'params': other_params, 'weight_decay': 0.0},
                        {'order_params': net.trainable_params()}]
        opt = nn.AdamWeightDecay(group_params, args.lr, beta1=args.beta1, beta2=args.beta2, eps=args.epsilon)
    else:
        opt = nn.Adam(net.trainable_params(), args.lr, beta1=args.beta1, beta2=args.beta2, eps=args.epsilon)
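    # Mixed precision: run the network in float16 with dynamic loss scaling,
    # but keep numerically sensitive cells (embedding, softmax, and the loss)
    # in float32 to reduce the risk of overflow and underflow.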
    if args.mixed:
        loss_scale_manager = DynamicLossScaleManager(init_loss_scale=128.0, scale_factor=2, scale_window=10000)
        net_with_loss.to_float(mstype.float16)
        for _, cell in net_with_loss.cells_and_names():
            if isinstance(cell, (nn.Embedding, nn.Softmax, nn.SoftmaxCrossEntropyWithLogits)):
                cell.to_float(mstype.float32)
        model = Model(net_with_loss, optimizer=opt, loss_scale_manager=loss_scale_manager)
    else:
        model = Model(net_with_loss, optimizer=opt)
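    # Training loop: with dataset sinking enabled, model.train() counts one
    # "epoch" per sink_size steps, so the requested epoch count is converted
    # into sink iterations of args.print_times steps each.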
    cb = [Monitor(args)]
    epochs = args.epochs
    if args.sink_mode:
        epochs = int(args.epochs * args.dataset_size / args.print_times)
    start_time = time.time()
    print("======================= Start Train ==========================", flush=True)
    model.train(epochs, dataset, callbacks=cb, dataset_sink_mode=args.sink_mode, sink_size=args.print_times)
    end_time = time.time()
    print("==============================================================")
    print("processor_name: {}".format(args.platform))
    print("test_name: NAML")
    print(f"model_name: NAML MIND{args.dataset}")
    print("batch_size: {}".format(args.batch_size))
    print("latency: {} s".format(end_time - start_time))
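

# For intuition, a small self-contained sketch of how a dynamic loss-scale
# schedule with the settings above typically behaves. This is a hypothetical
# helper for illustration only, not part of the original script and not
# MindSpore's actual DynamicLossScaleManager implementation.
def _sketch_dynamic_loss_scale(overflow_flags, init_scale=128.0, factor=2, window=10000):
    """Replay a sequence of per-step overflow flags and return the scale history."""
    scale, good_steps, history = init_scale, 0, []
    for overflow in overflow_flags:
        if overflow:
            # Overflow detected: shrink the scale and restart the clean-step counter.
            scale, good_steps = max(scale / factor, 1.0), 0
        else:
            good_steps += 1
            if good_steps >= window:
                # `window` clean steps in a row: grow the scale again.
                scale, good_steps = scale * factor, 0
        history.append(scale)
    return history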