|
|
|
@@ -52,15 +52,15 @@ class ParameterReduce(nn.Cell): |
|
|
|
def parse_args(cloud_args=None): |
|
|
|
"""parse_args""" |
|
|
|
parser = argparse.ArgumentParser('mindspore classification test') |
|
|
|
parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU'], |
|
|
|
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], |
|
|
|
help='device where the code will be implemented. (Default: Ascend)') |
|
|
|
# dataset related |
|
|
|
parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="imagenet2012") |
|
|
|
parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="cifar10") |
|
|
|
parser.add_argument('--data_path', type=str, default='', help='eval data dir') |
|
|
|
parser.add_argument('--per_batch_size', default=32, type=int, help='batch size for per npu') |
|
|
|
# network related |
|
|
|
parser.add_argument('--graph_ckpt', type=int, default=1, help='graph ckpt or feed ckpt') |
|
|
|
parser.add_argument('--pretrained', default='', type=str, help='fully path of pretrained model to load. ' |
|
|
|
parser.add_argument('--pre_trained', default='', type=str, help='fully path of pretrained model to load. ' |
|
|
|
'If it is a direction, it will test all ckpt') |
|
|
|
|
|
|
|
# logging related |
|
|
|
@@ -68,9 +68,6 @@ def parse_args(cloud_args=None): |
|
|
|
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed') |
|
|
|
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') |
|
|
|
|
|
|
|
# roma obs |
|
|
|
parser.add_argument('--train_url', type=str, default="", help='train url') |
|
|
|
|
|
|
|
args_opt = parser.parse_args() |
|
|
|
args_opt = merge_args(args_opt, cloud_args) |
|
|
|
|
|
|
|
@@ -82,6 +79,8 @@ def parse_args(cloud_args=None): |
|
|
|
args_opt.image_size = cfg.image_size |
|
|
|
args_opt.num_classes = cfg.num_classes |
|
|
|
args_opt.per_batch_size = cfg.batch_size |
|
|
|
args_opt.momentum = cfg.momentum |
|
|
|
args_opt.weight_decay = cfg.weight_decay |
|
|
|
args_opt.buffer_size = cfg.buffer_size |
|
|
|
args_opt.pad_mode = cfg.pad_mode |
|
|
|
args_opt.padding = cfg.padding |
|
|
|
@@ -130,23 +129,23 @@ def test(cloud_args=None): |
|
|
|
args.logger.save_args(args) |
|
|
|
|
|
|
|
if args.dataset == "cifar10": |
|
|
|
net = vgg16(num_classes=args.num_classes) |
|
|
|
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, |
|
|
|
net = vgg16(num_classes=args.num_classes, args=args) |
|
|
|
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, args.momentum, |
|
|
|
weight_decay=args.weight_decay) |
|
|
|
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False) |
|
|
|
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}) |
|
|
|
|
|
|
|
param_dict = load_checkpoint(args.checkpoint_path) |
|
|
|
param_dict = load_checkpoint(args.pre_trained) |
|
|
|
load_param_into_net(net, param_dict) |
|
|
|
net.set_train(False) |
|
|
|
dataset = vgg_create_dataset(args.data_path, 1, False) |
|
|
|
dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size, training=False) |
|
|
|
res = model.eval(dataset) |
|
|
|
print("result: ", res) |
|
|
|
else: |
|
|
|
# network |
|
|
|
args.logger.important_info('start create network') |
|
|
|
if os.path.isdir(args.pretrained): |
|
|
|
models = list(glob.glob(os.path.join(args.pretrained, '*.ckpt'))) |
|
|
|
if os.path.isdir(args.pre_trained): |
|
|
|
models = list(glob.glob(os.path.join(args.pre_trained, '*.ckpt'))) |
|
|
|
print(models) |
|
|
|
if args.graph_ckpt: |
|
|
|
f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split('_')[0]) |
|
|
|
@@ -154,14 +153,10 @@ def test(cloud_args=None): |
|
|
|
f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('_')[-1]) |
|
|
|
args.models = sorted(models, key=f) |
|
|
|
else: |
|
|
|
args.models = [args.pretrained,] |
|
|
|
args.models = [args.pre_trained,] |
|
|
|
|
|
|
|
for model in args.models: |
|
|
|
if args.dataset == "cifar10": |
|
|
|
dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size, training=False) |
|
|
|
else: |
|
|
|
dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size) |
|
|
|
|
|
|
|
dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size) |
|
|
|
eval_dataloader = dataset.create_tuple_iterator() |
|
|
|
network = vgg16(args.num_classes, args, phase="test") |
|
|
|
|
|
|
|
|