|
|
|
@@ -23,7 +23,7 @@ from train_utils import TrainWrap |
|
|
|
|
|
|
|
n = LeNet5() |
|
|
|
n.set_train() |
|
|
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU", save_graphs=False) |
|
|
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU", save_graphs=False) |
|
|
|
|
|
|
|
BATCH_SIZE = 32 |
|
|
|
x = Tensor(np.ones((BATCH_SIZE, 1, 32, 32)), mstype.float32) |
|
|
|
|