diff --git a/gpu_mnist_example/parallel_train.py b/gpu_mnist_example/parallel_train.py index 690d450..4d2ac02 100644 --- a/gpu_mnist_example/parallel_train.py +++ b/gpu_mnist_example/parallel_train.py @@ -28,6 +28,7 @@ from torch.utils.data import DataLoader from torchvision.transforms import ToTensor import argparse import os +os.system("pip install c2net") #导入c2net包 from c2net.context import prepare