diff --git a/gpu_mnist_example/train_gpu.py b/gpu_mnist_example/train_gpu.py index 4a94d95..88d20e1 100644 --- a/gpu_mnist_example/train_gpu.py +++ b/gpu_mnist_example/train_gpu.py @@ -24,14 +24,9 @@ from torch.optim import SGD from torch.utils.data import DataLoader from torchvision.transforms import ToTensor import argparse -<<<<<<< HEAD -#导入openi包 -from openi.context import prepare, upload_openi -======= import os #导入c2net包 from c2net.context import prepare, upload_output ->>>>>>> origin/liuzx # Training settings parser = argparse.ArgumentParser(description='PyTorch MNIST Example') @@ -92,40 +87,19 @@ if __name__ == '__main__': #获取预训练模型路径 mnist_example_test2_model_djts_path = c2net_context.pretrain_model_path+"/"+"MNIST_Example_test2_model_djts" - print("dataset_path:") - print(os.listdir(dataset_path)) - os.listdir(dataset_path) - print("pretrain_model_path:") - print(os.listdir(pretrain_model_path)) - os.listdir(pretrain_model_path) - - print("output_path:") - print(os.listdir(output_path)) - os.listdir(output_path) - #log output print('cuda is available:{}'.format(torch.cuda.is_available())) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") batch_size = args.batch_size epochs = args.epoch_size -<<<<<<< HEAD - train_dataset = mnist.MNIST(root=os.path.join(dataset_path + "/MnistDataset_torch", "train"), train=True, transform=ToTensor(),download=False) - test_dataset = mnist.MNIST(root=os.path.join(dataset_path+ "/MnistDataset_torch", "test"), train=False, transform=ToTensor(),download=False) -======= train_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch, "train"), train=True, transform=ToTensor(),download=False) test_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch, "test"), train=False, transform=ToTensor(),download=False) ->>>>>>> origin/liuzx train_loader = DataLoader(train_dataset, batch_size=batch_size) test_loader = DataLoader(test_dataset, batch_size=batch_size) #如果有保存的模型,则加载模型,并在其基础上继续训练 -<<<<<<< HEAD - if os.path.exists(os.path.join(pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j", "mnist_epoch1_0.70.pkl")): - checkpoint = torch.load(os.path.join(pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j", "mnist_epoch1_0.70.pkl")) -======= if os.path.exists(os.path.join(mnist_example_test2_model_djts_path, "mnist_epoch1_0.76.pkl")): checkpoint = torch.load(os.path.join(mnist_example_test2_model_djts_path, "mnist_epoch1_0.76.pkl")) ->>>>>>> origin/liuzx model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) start_epoch = checkpoint['epoch'] diff --git a/train.py b/train.py index b4d0faa..d114c93 100644 --- a/train.py +++ b/train.py @@ -1,25 +1,15 @@ -#安装包 -import os -os.system('pip install -U openi') + #导入包 -from openi.context import prepare, upload_openi +import os +from c2net.context import prepare, upload_output #初始化导入数据集和预训练模型到容器内 -openi_context = prepare() +c2net_context = prepare() #获取数据集路径,预训练模型路径,输出路径 -dataset_path = openi_context.dataset_path -pretrain_model_path = openi_context.pretrain_model_path -output_path = openi_context.output_path - -print("dataset_path:") -os.listdir(dataset_path) - -print("pretrain_model_path:") -os.listdir(pretrain_model_path) - -print("output_path:") -os.listdir(output_path) +dataset_path = c2net_context.dataset_path +pretrain_model_path = c2net_context.pretrain_model_path +output_path = c2net_context.output_path #回传结果到openi -upload_openi() \ No newline at end of file +upload_output() \ No newline at end of file