@@ -21,6 +21,7 @@ import time
 import threading
 import mindspore.context as context
 from mindspore import log as logger
+from mindspore import nn
 from mindspore._checkparam import Validator
 from mindspore.train._utils import _make_directory
 from mindspore.train.serialization import save_checkpoint, _save_graph
@@ -88,13 +89,36 @@ class CheckpointConfig:
         integrated_save (bool): Whether to perform integrated save function in automatic model parallel scene.
             Default: True. Integrated save function is only supported in automatic parallel scene, not supported
             in manual parallel.
-        async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
+        async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False.
+        saved_network (Cell): Network to be saved in checkpoint file. Default: None.
 
     Raises:
         ValueError: If the input_param is None or 0.
 
     Examples:
-        >>> config = CheckpointConfig()
+        >>> class Net(nn.Cell):
+        >>>     def __init__(self):
+        >>>         super(Net, self).__init__()
+        >>>         self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal')
+        >>>         self.bn = nn.BatchNorm2d(64)
+        >>>         self.relu = nn.ReLU()
+        >>>         self.flatten = nn.Flatten()
+        >>>         self.fc = nn.Dense(64*224*224, 12)
+        >>>
+        >>>     def construct(self, x):
+        >>>         x = self.conv(x)
+        >>>         x = self.bn(x)
+        >>>         x = self.relu(x)
+        >>>         x = self.flatten(x)
+        >>>         out = self.fc(x)
+        >>>         return out
+        >>>
+        >>> net = Net()
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
+        >>> optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
+        >>> model = Model(net, loss_fn=loss, optimizer=optim)
+        >>> dataset = get_dataset()
+        >>> config = CheckpointConfig(saved_network=net)
         >>> ckpoint_cb = ModelCheckpoint(prefix="ck_prefix", directory='./', config=config)
        >>> model.train(10, dataset, callbacks=ckpoint_cb)
     """
@@ -104,7 +128,8 @@ class CheckpointConfig:
                  keep_checkpoint_max=5,
                  keep_checkpoint_per_n_minutes=0,
                  integrated_save=True,
-                 async_save=False):
+                 async_save=False,
+                 saved_network=None):
 
         if save_checkpoint_steps is not None:
             save_checkpoint_steps = Validator.check_non_negative_int(save_checkpoint_steps)
@@ -115,6 +140,9 @@ class CheckpointConfig:
         if keep_checkpoint_per_n_minutes is not None:
             keep_checkpoint_per_n_minutes = Validator.check_non_negative_int(keep_checkpoint_per_n_minutes)
 
+        if saved_network is not None and not isinstance(saved_network, nn.Cell):
+            raise TypeError(f"The type of saved_network must be None or Cell, but got {str(type(saved_network))}.")
+
         if not save_checkpoint_steps and not save_checkpoint_seconds and \
                 not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
             raise ValueError("The input_param can't be all None or 0")
@@ -134,6 +162,7 @@ class CheckpointConfig:
 
         self._integrated_save = Validator.check_bool(integrated_save)
         self._async_save = Validator.check_bool(async_save)
+        self._saved_network = saved_network
 
     @property
     def save_checkpoint_steps(self):
@@ -165,12 +194,18 @@ class CheckpointConfig:
         """Get the value of _async_save."""
         return self._async_save
 
+    @property
+    def saved_network(self):
+        """Get the value of _saved_network"""
+        return self._saved_network
+
     def get_checkpoint_policy(self):
         """Get the policy of checkpoint."""
-        checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
-                             'save_checkpoint_seconds': self._save_checkpoint_seconds,
-                             'keep_checkpoint_max': self._keep_checkpoint_max,
-                             'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}
+        checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps,
+                             'save_checkpoint_seconds': self.save_checkpoint_seconds,
+                             'keep_checkpoint_max': self.keep_checkpoint_max,
+                             'keep_checkpoint_per_n_minutes': self.keep_checkpoint_per_n_minutes,
+                             'saved_network': self.saved_network}
 
         return checkpoint_policy
 
@@ -306,7 +341,8 @@ class ModelCheckpoint(Callback):
                 set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()
 
-            save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
+            network = self._config.saved_network if self._config.saved_network is not None else cb_params.train_network
+            save_checkpoint(network, cur_file, self._config.integrated_save,
                             self._config.async_save)
 
             self._latest_ckpt_file_name = cur_file
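Taken together, the hunks above add a saved_network option to CheckpointConfig and make ModelCheckpoint checkpoint that Cell instead of cb_params.train_network when it is set. The short sketch below is not part of the diff; it is a minimal usage illustration that assumes a MindSpore release contemporary with this change, the public import paths mindspore.train.Model and mindspore.train.callback, and reuses the Net cell and the hypothetical get_dataset() helper from the docstring example above.

from mindspore import nn
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# `Net` is the small conv/bn/relu/fc cell from the Examples block above;
# `get_dataset()` stands in for any function returning a training dataset.
net = Net()
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)
model = Model(net, loss_fn=loss, optimizer=optim)

# With saved_network=net the callback passes `net` to save_checkpoint(), so the
# checkpoint holds Net's parameters; leaving it as None falls back to
# cb_params.train_network, which is typically the loss/optimizer-wrapped training graph.
config = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5,
                          async_save=True, saved_network=net)
ckpoint_cb = ModelCheckpoint(prefix="ck_prefix", directory='./', config=config)
model.train(10, get_dataset(), callbacks=ckpoint_cb)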