|
|
|
@@ -68,7 +68,10 @@ if __name__ == "__main__": |
|
|
|
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") |
|
|
|
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) |
|
|
|
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) |
|
|
|
load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt"))) |
|
|
|
# load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt"))) |
|
|
|
if os.path.exists(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt")): |
|
|
|
load_param_into_net(network, load_checkpoint(os.path.join(Mindspore_MNIST_Example_Model_path, "checkpoint_lenet-1_1875.ckpt"))) |
|
|
|
|
|
|
|
if args.device_target != "Ascend": |
|
|
|
model = Model(network, |
|
|
|
net_loss, |
|
|
|
|