Browse Source

update

liuzx-patch-1
liuzx 2 years ago
parent
commit
9b7d4c9549
2 changed files with 8 additions and 44 deletions
  1. +0
    -26
      gpu_mnist_example/train_gpu.py
  2. +8
    -18
      train.py

+ 0
- 26
gpu_mnist_example/train_gpu.py View File

@@ -24,14 +24,9 @@ from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import argparse
<<<<<<< HEAD
#导入openi包
from openi.context import prepare, upload_openi
=======
import os
#导入c2net包
from c2net.context import prepare, upload_output
>>>>>>> origin/liuzx

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
@@ -92,40 +87,19 @@ if __name__ == '__main__':
#获取预训练模型路径
mnist_example_test2_model_djts_path = c2net_context.pretrain_model_path+"/"+"MNIST_Example_test2_model_djts"

print("dataset_path:")
print(os.listdir(dataset_path))
os.listdir(dataset_path)
print("pretrain_model_path:")
print(os.listdir(pretrain_model_path))
os.listdir(pretrain_model_path)

print("output_path:")
print(os.listdir(output_path))
os.listdir(output_path)

#log output
print('cuda is available:{}'.format(torch.cuda.is_available()))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = args.batch_size
epochs = args.epoch_size
<<<<<<< HEAD
train_dataset = mnist.MNIST(root=os.path.join(dataset_path + "/MnistDataset_torch", "train"), train=True, transform=ToTensor(),download=False)
test_dataset = mnist.MNIST(root=os.path.join(dataset_path+ "/MnistDataset_torch", "test"), train=False, transform=ToTensor(),download=False)
=======
train_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch, "train"), train=True, transform=ToTensor(),download=False)
test_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch, "test"), train=False, transform=ToTensor(),download=False)
>>>>>>> origin/liuzx
train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

#如果有保存的模型,则加载模型,并在其基础上继续训练
<<<<<<< HEAD
if os.path.exists(os.path.join(pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j", "mnist_epoch1_0.70.pkl")):
checkpoint = torch.load(os.path.join(pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j", "mnist_epoch1_0.70.pkl"))
=======
if os.path.exists(os.path.join(mnist_example_test2_model_djts_path, "mnist_epoch1_0.76.pkl")):
checkpoint = torch.load(os.path.join(mnist_example_test2_model_djts_path, "mnist_epoch1_0.76.pkl"))
>>>>>>> origin/liuzx
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']


+ 8
- 18
train.py View File

@@ -1,25 +1,15 @@
#安装包
import os
os.system('pip install -U openi')

#导入包
from openi.context import prepare, upload_openi
import os
from c2net.context import prepare, upload_output

#初始化导入数据集和预训练模型到容器内
openi_context = prepare()
c2net_context = prepare()

#获取数据集路径,预训练模型路径,输出路径
dataset_path = openi_context.dataset_path
pretrain_model_path = openi_context.pretrain_model_path
output_path = openi_context.output_path

print("dataset_path:")
os.listdir(dataset_path)

print("pretrain_model_path:")
os.listdir(pretrain_model_path)

print("output_path:")
os.listdir(output_path)
dataset_path = c2net_context.dataset_path
pretrain_model_path = c2net_context.pretrain_model_path
output_path = c2net_context.output_path

#回传结果到openi
upload_openi()
upload_output()

Loading…
Cancel
Save