
train_mdnn.py 4.6 kB

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train"""
import argparse
import numpy as np
from src.mdnn import Mdnn
from mindspore import nn, Model, context
from mindspore import dataset as ds
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.callback import Callback
import mindspore.common.initializer as weight_init

parser = argparse.ArgumentParser(description='Mdnn Controller')
parser.add_argument('--i', type=str, default=None, help='Input radial and angular dat file')
parser.add_argument('--charge', type=str, default=None, help='Input charge dat file')
parser.add_argument('--device_id', type=int, default=0, help='GPU device id')
args_opt = parser.parse_args()

context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=args_opt.device_id, save_graphs=False)


class StepLossAccInfo(Callback):
    """Custom callback that records the loss value at every training step."""

    def __init__(self, models, eval_dataset, steploss):
        """init model"""
        self.model = models
        self.eval_dataset = eval_dataset
        self.steps_loss = steploss

    def step_end(self, run_context):
        """step end"""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        cur_step = (cur_epoch - 1) * 1875 + cb_params.cur_step_num
        self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
        self.steps_loss["step"].append(str(cur_step))


def get_data(inputdata, outputdata):
    """Yield (data, label) pairs from the paired input and output arrays."""
    for _, data in enumerate(zip(inputdata, outputdata)):
        yield data


def create_dataset(inputdata, outputdata, batchsize=32, repeat_size=1):
    """Build a batched, repeated GeneratorDataset from the numpy arrays."""
    input_data = ds.GeneratorDataset(list(get_data(inputdata, outputdata)), column_names=['data', 'label'])
    input_data = input_data.batch(batchsize)
    input_data = input_data.repeat(repeat_size)
    return input_data


def init_weight(nnet):
    """Initialize Conv2d and Dense weights with Xavier uniform."""
    for _, cell in nnet.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
                                                         cell.weight.shape,
                                                         cell.weight.dtype))
        if isinstance(cell, nn.Dense):
            cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
                                                         cell.weight.shape,
                                                         cell.weight.dtype))


if __name__ == '__main__':
    # read input files
    inputs = args_opt.i
    outputs = args_opt.charge
    radial_angular = np.fromfile(inputs, dtype=np.float32)
    radial_angular = radial_angular.reshape((-1, 258)).astype(np.float32)
    charge = np.fromfile(outputs, dtype=np.float32)
    charge = charge.reshape((-1, 129)).astype(np.float32)

    # define the model
    net = Mdnn()
    lr = 0.0001
    decay_rate = 0.8
    epoch_size = 1000
    batch_size = 500
    total_step = epoch_size * batch_size
    step_per_epoch = 100
    decay_epoch = epoch_size
    lr_rate = nn.exponential_decay_lr(lr, decay_rate, total_step, step_per_epoch, decay_epoch)
    net_loss = nn.loss.MSELoss(reduction='mean')
    net_opt = nn.Adam(net.trainable_params(), learning_rate=lr_rate)
    model = Model(net, net_loss, net_opt)
    ds_train = create_dataset(radial_angular, charge, batchsize=batch_size)
    model_params = net.trainable_params()
    net.set_train()
    init_weight(net)

    # config files
    path = './params/'
    config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix="mdnn_best", directory=path, config=config_ck)
    steps_loss = {"step": [], "loss_value": []}
    step_loss_acc_info = StepLossAccInfo(model, ds_train, steps_loss)

    # train the model
    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor(100)])
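
The script loads its inputs with np.fromfile and reshapes them to (-1, 258) and (-1, 129), so --i and --charge must point to raw float32 binaries whose flat length is a multiple of those row widths. Below is a minimal sketch, with synthetic data and illustrative file names (radial_angular.dat, charge.dat) that are not prescribed by the script, showing how matching input files could be produced and how training would then be launched.

    # Data-preparation sketch: synthetic inputs, file names are only examples.
    import numpy as np

    n_frames = 1000  # hypothetical number of samples
    features = np.random.rand(n_frames, 258).astype(np.float32)  # radial/angular descriptors
    charges = np.random.rand(n_frames, 129).astype(np.float32)   # target charges

    features.tofile("radial_angular.dat")  # read back via --i
    charges.tofile("charge.dat")           # read back via --charge

    # Then launch training on GPU 0:
    #   python train_mdnn.py --i radial_angular.dat --charge charge.dat --device_id 0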