Browse Source

update

liuzx-patch-1
liuzx 2 years ago
parent
commit
9b7d4c9549
2 changed files with 8 additions and 44 deletions
  1. +0
    -26
      gpu_mnist_example/train_gpu.py
  2. +8
    -18
      train.py

+ 0
- 26
gpu_mnist_example/train_gpu.py View File

@@ -24,14 +24,9 @@ from torch.optim import SGD
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor from torchvision.transforms import ToTensor
import argparse import argparse
<<<<<<< HEAD
#导入openi包
from openi.context import prepare, upload_openi
=======
import os import os
#导入c2net包 #导入c2net包
from c2net.context import prepare, upload_output from c2net.context import prepare, upload_output
>>>>>>> origin/liuzx


# Training settings # Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
@@ -92,40 +87,19 @@ if __name__ == '__main__':
#获取预训练模型路径 #获取预训练模型路径
mnist_example_test2_model_djts_path = c2net_context.pretrain_model_path+"/"+"MNIST_Example_test2_model_djts" mnist_example_test2_model_djts_path = c2net_context.pretrain_model_path+"/"+"MNIST_Example_test2_model_djts"


print("dataset_path:")
print(os.listdir(dataset_path))
os.listdir(dataset_path)
print("pretrain_model_path:")
print(os.listdir(pretrain_model_path))
os.listdir(pretrain_model_path)

print("output_path:")
print(os.listdir(output_path))
os.listdir(output_path)

#log output #log output
print('cuda is available:{}'.format(torch.cuda.is_available())) print('cuda is available:{}'.format(torch.cuda.is_available()))
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = args.batch_size batch_size = args.batch_size
epochs = args.epoch_size epochs = args.epoch_size
<<<<<<< HEAD
train_dataset = mnist.MNIST(root=os.path.join(dataset_path + "/MnistDataset_torch", "train"), train=True, transform=ToTensor(),download=False)
test_dataset = mnist.MNIST(root=os.path.join(dataset_path+ "/MnistDataset_torch", "test"), train=False, transform=ToTensor(),download=False)
=======
train_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch, "train"), train=True, transform=ToTensor(),download=False) train_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch, "train"), train=True, transform=ToTensor(),download=False)
test_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch, "test"), train=False, transform=ToTensor(),download=False) test_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch, "test"), train=False, transform=ToTensor(),download=False)
>>>>>>> origin/liuzx
train_loader = DataLoader(train_dataset, batch_size=batch_size) train_loader = DataLoader(train_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size) test_loader = DataLoader(test_dataset, batch_size=batch_size)


#如果有保存的模型,则加载模型,并在其基础上继续训练 #如果有保存的模型,则加载模型,并在其基础上继续训练
<<<<<<< HEAD
if os.path.exists(os.path.join(pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j", "mnist_epoch1_0.70.pkl")):
checkpoint = torch.load(os.path.join(pretrain_model_path + "/MNIST_PytorchExample_GPU_test34_model_7f9j", "mnist_epoch1_0.70.pkl"))
=======
if os.path.exists(os.path.join(mnist_example_test2_model_djts_path, "mnist_epoch1_0.76.pkl")): if os.path.exists(os.path.join(mnist_example_test2_model_djts_path, "mnist_epoch1_0.76.pkl")):
checkpoint = torch.load(os.path.join(mnist_example_test2_model_djts_path, "mnist_epoch1_0.76.pkl")) checkpoint = torch.load(os.path.join(mnist_example_test2_model_djts_path, "mnist_epoch1_0.76.pkl"))
>>>>>>> origin/liuzx
model.load_state_dict(checkpoint['model']) model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer']) optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] start_epoch = checkpoint['epoch']


+ 8
- 18
train.py View File

@@ -1,25 +1,15 @@
#安装包
import os
os.system('pip install -U openi')

#导入包 #导入包
from openi.context import prepare, upload_openi
import os
from c2net.context import prepare, upload_output


#初始化导入数据集和预训练模型到容器内 #初始化导入数据集和预训练模型到容器内
openi_context = prepare()
c2net_context = prepare()


#获取数据集路径,预训练模型路径,输出路径 #获取数据集路径,预训练模型路径,输出路径
dataset_path = openi_context.dataset_path
pretrain_model_path = openi_context.pretrain_model_path
output_path = openi_context.output_path

print("dataset_path:")
os.listdir(dataset_path)

print("pretrain_model_path:")
os.listdir(pretrain_model_path)

print("output_path:")
os.listdir(output_path)
dataset_path = c2net_context.dataset_path
pretrain_model_path = c2net_context.pretrain_model_path
output_path = c2net_context.output_path


#回传结果到openi #回传结果到openi
upload_openi()
upload_output()

Loading…
Cancel
Save