|
|
|
@@ -0,0 +1,119 @@ |
|
|
|
#!/usr/bin/python |
|
|
|
#coding=utf-8 |
|
|
|
''' |
|
|
|
If there are Chinese comments in the code,please add at the beginning: |
|
|
|
#!/usr/bin/python |
|
|
|
#coding=utf-8 |
|
|
|
|
|
|
|
数据集结构是: |
|
|
|
MnistDataset_torch.zip |
|
|
|
├── test |
|
|
|
└── train |
|
|
|
|
|
|
|
预训练模型文件夹结构是: |
|
|
|
Torch_MNIST_Example_Model |
|
|
|
├── mnist_epoch1_0.76.pkl |
|
|
|
|
|
|
|
''' |
|
|
|
|
|
|
|
|
|
|
|
from model import Model |
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
from torchvision.datasets import mnist |
|
|
|
from torch.nn import CrossEntropyLoss |
|
|
|
from torch.optim import SGD |
|
|
|
from torch.utils.data import DataLoader |
|
|
|
from torchvision.transforms import ToTensor |
|
|
|
import argparse |
|
|
|
import os |
|
|
|
#导入c2net包 |
|
|
|
from c2net.context import prepare |
|
|
|
|
|
|
|
# Training settings |
|
|
|
parser = argparse.ArgumentParser(description='PyTorch MNIST Example') |
|
|
|
parser.add_argument('--epoch_size', type=int, default=10, help='how much epoch to train') |
|
|
|
parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch') |
|
|
|
|
|
|
|
# 参数声明 |
|
|
|
WORKERS = 0 # dataloder线程数 |
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
|
|
model = Model().to(device) |
|
|
|
optimizer = SGD(model.parameters(), lr=1e-1) |
|
|
|
cost = CrossEntropyLoss() |
|
|
|
|
|
|
|
# 模型训练 |
|
|
|
def train(model, train_loader, epoch): |
|
|
|
model.train() |
|
|
|
train_loss = 0 |
|
|
|
for i, data in enumerate(train_loader, 0): |
|
|
|
x, y = data |
|
|
|
x = x.to(device) |
|
|
|
y = y.to(device) |
|
|
|
optimizer.zero_grad() |
|
|
|
y_hat = model(x) |
|
|
|
loss = cost(y_hat, y) |
|
|
|
loss.backward() |
|
|
|
optimizer.step() |
|
|
|
train_loss += loss |
|
|
|
loss_mean = train_loss / (i+1) |
|
|
|
print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item())) |
|
|
|
|
|
|
|
# 模型测试 |
|
|
|
def test(model, test_loader, test_data): |
|
|
|
model.eval() |
|
|
|
test_loss = 0 |
|
|
|
correct = 0 |
|
|
|
with torch.no_grad(): |
|
|
|
for i, data in enumerate(test_loader, 0): |
|
|
|
x, y = data |
|
|
|
x = x.to(device) |
|
|
|
y = y.to(device) |
|
|
|
optimizer.zero_grad() |
|
|
|
y_hat = model(x) |
|
|
|
test_loss += cost(y_hat, y).item() |
|
|
|
pred = y_hat.max(1, keepdim=True)[1] |
|
|
|
correct += pred.eq(y.view_as(pred)).sum().item() |
|
|
|
test_loss /= (i+1) |
|
|
|
print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( |
|
|
|
test_loss, correct, len(test_data), 100. * correct / len(test_data))) |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
args, unknown = parser.parse_known_args() |
|
|
|
|
|
|
|
#初始化导入数据集和预训练模型到容器内 |
|
|
|
c2net_context = prepare() |
|
|
|
#获取数据集路径 |
|
|
|
MnistDataset_torch_path = c2net_context.dataset_path+"/"+"MnistDataset_torch" |
|
|
|
#获取预训练模型路径 |
|
|
|
Torch_MNIST_Example_Model_path = c2net_context.pretrain_model_path+"/"+"Torch_MNIST_Example_Model" |
|
|
|
|
|
|
|
#log output |
|
|
|
print('cuda is available:{}'.format(torch.cuda.is_available())) |
|
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
|
|
batch_size = args.batch_size |
|
|
|
epochs = args.epoch_size |
|
|
|
train_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch_path, "train"), train=True, transform=ToTensor(),download=False) |
|
|
|
test_dataset = mnist.MNIST(root=os.path.join(MnistDataset_torch_path, "test"), train=False, transform=ToTensor(),download=False) |
|
|
|
train_loader = DataLoader(train_dataset, batch_size=batch_size) |
|
|
|
test_loader = DataLoader(test_dataset, batch_size=batch_size) |
|
|
|
|
|
|
|
#如果有保存的模型,则加载模型,并在其基础上继续训练 |
|
|
|
if os.path.exists(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1_0.76.pkl")): |
|
|
|
checkpoint = torch.load(os.path.join(Torch_MNIST_Example_Model_path, "mnist_epoch1_0.76.pkl")) |
|
|
|
model.load_state_dict(checkpoint['model']) |
|
|
|
optimizer.load_state_dict(checkpoint['optimizer']) |
|
|
|
start_epoch = checkpoint['epoch'] |
|
|
|
print('加载 epoch {} 权重成功!'.format(start_epoch)) |
|
|
|
else: |
|
|
|
start_epoch = 0 |
|
|
|
print('无保存模型,将从头开始训练!') |
|
|
|
|
|
|
|
for epoch in range(start_epoch+1, epochs+1): |
|
|
|
train(model, train_loader, epoch) |
|
|
|
test(model, test_loader, test_dataset) |
|
|
|
# 将模型保存到c2net_context.output_path |
|
|
|
state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch} |
|
|
|
torch.save(state, '{}/mnist_epoch{}.pkl'.format(c2net_context.output_path, epoch)) |
|
|
|
|
|
|
|
|