|
|
|
@@ -24,14 +24,9 @@ from torch.optim import SGD |
|
|
|
from torch.utils.data import DataLoader |
|
|
|
from torchvision.transforms import ToTensor |
|
|
|
import argparse |
|
|
|
<<<<<<< HEAD |
|
|
|
#导入openi包 |
|
|
|
from openi.context import prepare, upload_openi |
|
|
|
======= |
|
|
|
import os |
|
|
|
#导入c2net包 |
|
|
|
from c2net.context import prepare, upload_output |
|
|
|
>>>>>>> origin/liuzx |
|
|
|
|
|
|
|
# Training settings |
|
|
|
parser = argparse.ArgumentParser(description='PyTorch MNIST Example') |
|
|
|
@@ -92,40 +87,19 @@ if __name__ == '__main__': |
|
|
|
#获取预训练模型路径 |
|
|
|
mnist_example_test2_model_djts_path = c2net_context.pretrain_model_path+"/"+"MNIST_Example_test2_model_djts" |
|
|
|
|
|
|
|
print("dataset_path:") |
|
|
|
print(os.listdir(dataset_path)) |
|
|
|
os.listdir(dataset_path) |
|
|
|
print("pretrain_model_path:") |
|
|
|
print(os.listdir(pretrain_model_path)) |
|
|
|
os.listdir(pretrain_model_path) |
|
|
|
|
|
|
|
print("output_path:") |
|
|
|
print(os.listdir(output_path)) |
|
|
|
os.listdir(output_path) |
|
|
|
|
|
|
|
#log output |
|
|
|
print('cuda is available:{}'.format(torch.cuda.is_available())) |
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
|
|
batch_size = args.batch_size |
|
|
|
epochs = args.epoch_size |
|
|
|
<<<<<<< HEAD |
|
|
|
train_dataset = mnist.MNIST(root=os.path.join(dataset_path + "/MnistDataset_torch", "train"), train=True, transform=ToTensor(),download=False) |
|
|
|
test_dataset = mnist.MNIST(root=os.path.join(dataset_path+ "/MnistDataset_torch", "test"), train=False, transform=ToTensor(),download=False) |
|
|
|
======= |
|
|
|
train_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch, "train"), train=True, transform=ToTensor(),download=False) |
|
|
|
test_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch, "test"), train=False, transform=ToTensor(),download=False) |
|
|
|
>>>>>>> origin/liuzx |
|
|
|
train_loader = DataLoader(train_dataset, batch_size=batch_size) |
|
|
|
test_loader = DataLoader(test_dataset, batch_size=batch_size) |
|
|
|
|
|
|
|
#如果有保存的模型,则加载模型,并在其基础上继续训练 |
|
|
|
<<<<<<< HEAD |
|
|
|
if os.path.exists(os.path.join(pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j", "mnist_epoch1_0.70.pkl")): |
|
|
|
checkpoint = torch.load(os.path.join(pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j", "mnist_epoch1_0.70.pkl")) |
|
|
|
======= |
|
|
|
if os.path.exists(os.path.join(mnist_example_test2_model_djts_path, "mnist_epoch1_0.76.pkl")): |
|
|
|
checkpoint = torch.load(os.path.join(mnist_example_test2_model_djts_path, "mnist_epoch1_0.76.pkl")) |
|
|
|
>>>>>>> origin/liuzx |
|
|
|
model.load_state_dict(checkpoint['model']) |
|
|
|
optimizer.load_state_dict(checkpoint['optimizer']) |
|
|
|
start_epoch = checkpoint['epoch'] |
|
|
|
|