Browse Source

modify _exec_save_checkpoint

tags/v1.0.0
liuyang_655 5 years ago
parent
commit
18c442e724
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      model_zoo/official/cv/mobilenetv2/train.py

+ 2
- 2
model_zoo/official/cv/mobilenetv2/train.py View File

@@ -26,7 +26,7 @@ from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import _exec_save_checkpoint
from mindspore.train.serialization import save_checkpoint
from mindspore.common import set_seed from mindspore.common import set_seed


from src.dataset import create_dataset, extract_features from src.dataset import create_dataset, extract_features
@@ -116,7 +116,7 @@ if __name__ == '__main__':
.format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))), \ .format(epoch + 1, step_size, epoch_mseconds, per_step_mseconds, np.mean(np.array(losses))), \
end="") end="")
if (epoch + 1) % config.save_checkpoint_epochs == 0: if (epoch + 1) % config.save_checkpoint_epochs == 0:
_exec_save_checkpoint(network, os.path.join(config.save_checkpoint_path, \
save_checkpoint(network, os.path.join(config.save_checkpoint_path, \
f"mobilenetv2_head_{epoch+1}.ckpt")) f"mobilenetv2_head_{epoch+1}.ckpt"))
print("total cost {:5.4f} s".format(time.time() - start)) print("total cost {:5.4f} s".format(time.time() - start))




Loading…
Cancel
Save