diff --git a/gpu_mnist_example/train_gpu.py b/gpu_mnist_example/train_gpu.py index bb2f938..f163b42 100644 --- a/gpu_mnist_example/train_gpu.py +++ b/gpu_mnist_example/train_gpu.py @@ -12,7 +12,7 @@ If there are Chinese comments in the code,please add at the beginning: ''' -import os + from model import Model import numpy as np import torch @@ -22,6 +22,7 @@ from torch.optim import SGD from torch.utils.data import DataLoader from torchvision.transforms import ToTensor import argparse +import os #导入openi包 from openi.context import prepare, upload_openi