|
|
|
Usage:
|
|
|
python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID |
|
|
|
""" |
|
|
|
# Standard library.
import argparse
import os
import random

# Third-party.
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.communication.management import init
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.model_zoo.vgg import vgg16

# Local project modules.
import dataset  # NOTE(review): shadowed below by `dataset = create_dataset(...)`; kept in case other code uses it
from dataset import create_dataset
from config import cifar_cfg as cfg

# Fix RNG seeds so runs are reproducible.
random.seed(1)
np.random.seed(1)
|
|
|
# NOTE(review): an `if __name__ == '__main__':` guard and the argparse setup
# that produces `args_opt` exist above this point but are not visible in this chunk.
|
|
|
|
|
|
|
# Runtime configuration: graph mode on the selected device, with task/loop
# sinking and memory reuse enabled; HCCL stays off for single-device runs.
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(device_id=args_opt.device_id)
context.set_context(enable_task_sink=True)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True, enable_hccl=False)

# Distributed setup: when launched across several devices, enable HCCL and
# data-parallel training with gradient mirror-mean, then init communication.
device_num = int(os.environ.get("DEVICE_NUM", 1))
if device_num > 1:
    context.reset_auto_parallel_context()
    context.set_context(enable_hccl=True)
    context.set_auto_parallel_context(device_num=device_num,
                                      parallel_mode=ParallelMode.DATA_PARALLEL,
                                      mirror_mean=True)
    init()

# Build the training dataset once; batch_num drives both the lr schedule and
# the checkpoint interval. (The merged diff created the dataset twice — the
# second call `dataset.create_dataset(...)` would have raised AttributeError
# because `dataset` is the dataset object by then, not the module.)
dataset = create_dataset(args_opt.data_path, cfg.epoch_size)
batch_num = dataset.get_dataset_size()

net = vgg16(num_classes=cfg.num_classes)
# Size the lr schedule from the real dataset length, not a hard-coded 50000.
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
               Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
# Mixed-precision (O2) training; no loss scaling.
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
              amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)

# Callbacks: checkpoint every 5 epochs' worth of steps, per-epoch timing,
# and loss logging. Train exactly once (the merged diff called train twice).
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
                             keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpoint_cb = ModelCheckpoint(prefix="train_vgg_cifar10", directory="./", config=config_ck)
loss_cb = LossMonitor()
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")