|
|
|
@@ -27,8 +27,7 @@ from mindspore.common.tensor import Tensor |
|
|
|
from mindspore.nn import TrainOneStepCell, WithLossCell |
|
|
|
from mindspore.nn.optim import Momentum |
|
|
|
from mindspore.train.callback import ModelCheckpoint, RunContext, LossMonitor, _InternalCallbackParam, \ |
|
|
|
_CallbackManager, Callback, CheckpointConfig |
|
|
|
from mindspore.train.callback._callback import set_cur_net, checkpoint_cb_for_save_op |
|
|
|
_CallbackManager, Callback, CheckpointConfig, _set_cur_net, _checkpoint_cb_for_save_op |
|
|
|
from mindspore.train.callback._checkpoint import _check_file_name_prefix, _chg_ckpt_file_name_if_same_exist |
|
|
|
|
|
|
|
class Net(nn.Cell): |
|
|
|
@@ -189,7 +188,7 @@ def test_checkpoint_cb_for_save_op(): |
|
|
|
one_param['name'] = "conv1.weight" |
|
|
|
one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32) |
|
|
|
parameter_list.append(one_param) |
|
|
|
checkpoint_cb_for_save_op(parameter_list) |
|
|
|
_checkpoint_cb_for_save_op(parameter_list) |
|
|
|
|
|
|
|
|
|
|
|
def test_checkpoint_cb_for_save_op_update_net(): |
|
|
|
@@ -200,8 +199,8 @@ def test_checkpoint_cb_for_save_op_update_net(): |
|
|
|
one_param['data'] = Tensor(np.ones(shape=(64, 3, 3, 3)), dtype=mstype.float32) |
|
|
|
parameter_list.append(one_param) |
|
|
|
net = Net() |
|
|
|
set_cur_net(net) |
|
|
|
checkpoint_cb_for_save_op(parameter_list) |
|
|
|
_set_cur_net(net) |
|
|
|
_checkpoint_cb_for_save_op(parameter_list) |
|
|
|
assert net.conv.weight.default_input.asnumpy()[0][0][0][0] == 1 |
|
|
|
|
|
|
|
|
|
|
|
|