Browse Source

Merge pull request 'update npu train.py' (#8) from liuzx into master

Reviewed-on: https://openi.pcl.ac.cn/OpenIOSSG/OpenI_Cloudbrain_Example/pulls/8
liuzx-patch-2
liuzx 1 year ago
parent
commit
f3859e69bb
1 changed files with 4 additions and 1 deletions
  1. +4
    -1
      npu_mnist_example/train.py

+ 4
- 1
npu_mnist_example/train.py View File

@@ -68,7 +68,10 @@ if __name__ == "__main__":
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt")))
# load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt")))
if os.path.exists(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt")):
load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt")))

if args.device_target != "Ascend":
model = Model(network,
net_loss,


Loading…
Cancel
Save