You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.train.callback.CheckpointConfig.txt 5.1 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. Class mindspore.train.callback.CheckpointConfig(save_checkpoint_steps=1, save_checkpoint_seconds=0, keep_checkpoint_max=5, keep_checkpoint_per_n_minutes=0, integrated_save=True, async_save=False, saved_network=None, append_info=None, enc_key=None, enc_mode='AES-GCM')
  2. 保存checkpoint时的配置策略。
  3. 注:
  4. 在训练过程中,如果数据集是通过数据通道传输的,建议将save_checkpoint_steps设为循环下沉step数量的整数倍数,
  5. 否则,保存checkpoint的时机可能会有偏差。
  6. 建议同时只设置一种触发保存checkpoint策略和一种保留checkpoint文件总数策略。
  7. 如果同时设置了`save_checkpoint_steps`和`save_checkpoint_seconds`,则`save_checkpoint_seconds`无效。
  8. 如果同时设置了`keep_checkpoint_max`和`keep_checkpoint_per_n_minutes`,则`keep_checkpoint_per_n_minutes`无效。
  9. 参数:
  10. save_checkpoint_steps (int):每隔多少个step保存一次checkpoint。默认值:1。
  11. save_checkpoint_seconds (int):每隔多少秒保存一次checkpoint。
  12. 不能同时与save_checkpoint_steps一起使用。默认值:0。
  13. keep_checkpoint_max (int):最多保存多少个checkpoint文件。默认值:5。
  14. keep_checkpoint_per_n_minutes (int):每隔多少分钟保存一个checkpoint文件。
  15. 不能同时与keep_checkpoint_max一起使用。默认值:0。
  16. integrated_save (bool):在自动并行场景下,是否合并保存拆分后的Tensor。
  17. 合并保存功能仅支持在自动并行场景中使用,在手动并行场景中不支持。默认值:True。
  18. async_save (bool):是否异步执行保存checkpoint文件。默认值:False。
  19. saved_network (Cell):保存在checkpoint文件中的网络。如果`saved_network`没有被训练,则保存`saved_network`的初始值。默认值:None。
  20. append_info (list):保存在checkpoint文件中的信息。支持"epoch_num"、"step_num"和dict类型。
  21. dict的key必须是str,dict的value必须是int、float或bool中的一个。默认值:None。
  22. enc_key (Union[None, bytes]):用于加密的字节类型key。如果值为None,则不需要加密。默认值:None。
  23. enc_mode (str):仅当enc_key不设为None时,该参数有效。指定了加密模式,目前支持AES-GCM和AES-CBC。默认值:AES-GCM。
  24. 异常:
  25. ValueError:输入参数的类型不正确。
  26. 示例:
  27. >>> from mindspore import Model, nn
  28. >>> from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
  29. >>>
  30. >>> class LeNet5(nn.Cell):
  31. ... def __init__(self, num_class=10, num_channel=1):
  32. ... super(LeNet5, self).__init__()
  33. ... self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
  34. ... self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
  35. ... self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
  36. ... self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
  37. ... self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
  38. ... self.relu = nn.ReLU()
  39. ... self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
  40. ... self.flatten = nn.Flatten()
  41. ...
  42. ... def construct(self, x):
  43. ... x = self.max_pool2d(self.relu(self.conv1(x)))
  44. ... x = self.max_pool2d(self.relu(self.conv2(x)))
  45. ... x = self.flatten(x)
  46. ... x = self.relu(self.fc1(x))
  47. ... x = self.relu(self.fc2(x))
  48. ... x = self.fc3(x)
  49. ... return x
  50. >>>
  51. >>> net = LeNet5()
  52. >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
  53. >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
  54. >>> model = Model(net, loss_fn=loss, optimizer=optim)
  55. >>> data_path = './MNIST_Data'
  56. >>> dataset = create_dataset(data_path)
  57. >>> config = CheckpointConfig(saved_network=net)
  58. >>> ckpoint_cb = ModelCheckpoint(prefix='LeNet5', directory='./checkpoint', config=config)
  59. >>> model.train(10, dataset, callbacks=ckpoint_cb)
  60. append_dict
  61. 获取checkpoint中添加字典里面的值。
  62. async_save
  63. 获取是否异步保存checkpoint。
  64. enc_key
  65. 获取加密的key值。
  66. enc_mode
  67. 获取加密模式。
  68. get_checkpoint_policy()
  69. 获取checkpoint的保存策略。
  70. integrated_save
  71. 获取是否合并保存拆分后的Tensor。
  72. keep_checkpoint_max
  73. 获取最多保存checkpoint文件的数量。
  74. keep_checkpoint_per_n_minutes
  75. 获取每隔多少分钟保存一个checkpoint文件。
  76. saved_network
  77. 获取_保存的网络。
  78. save_checkpoint_seconds
  79. 获取每隔多少秒保存一次checkpoint文件。。
  80. save_checkpoint_steps
  81. 获取每隔多少个step保存一次checkpoint文件。