| @@ -26,7 +26,7 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore.train.model import Model | from mindspore.train.model import Model | ||||
| from mindspore.train.loss_scale_manager import FixedLossScaleManager | from mindspore.train.loss_scale_manager import FixedLossScaleManager | ||||
| from mindspore.train.serialization import _exec_save_checkpoint | |||||
| from mindspore.train.serialization import save_checkpoint | |||||
| from mindspore.common import set_seed | from mindspore.common import set_seed | ||||
| from src.dataset import create_dataset, extract_features | from src.dataset import create_dataset, extract_features | ||||
| @@ -116,7 +116,7 @@ if __name__ == '__main__': | |||||
| .format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))), \ | .format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))), \ | ||||
| end="") | end="") | ||||
| if (epoch + 1) % config.save_checkpoint_epochs == 0: | if (epoch + 1) % config.save_checkpoint_epochs == 0: | ||||
| _exec_save_checkpoint(network, os.path.join(config.save_checkpoint_path, \ | |||||
| save_checkpoint(network, os.path.join(config.save_checkpoint_path, \ | |||||
| f"mobilenetv2_head_{epoch+1}.ckpt")) | f"mobilenetv2_head_{epoch+1}.ckpt")) | ||||
| print("total cost {:5.4f} s".format(time.time() - start)) | print("total cost {:5.4f} s".format(time.time() - start)) | ||||