From db173cff73411802627a82eeb3d0b20a1185654e Mon Sep 17 00:00:00 2001 From: wukesong Date: Tue, 22 Sep 2020 10:31:00 +0800 Subject: [PATCH] add check dataset_size --- model_zoo/official/cv/alexnet/eval.py | 3 +++ model_zoo/official/cv/alexnet/train.py | 3 +++ model_zoo/official/cv/lenet/eval.py | 3 +++ model_zoo/official/cv/lenet/train.py | 2 ++ 4 files changed, 11 insertions(+) diff --git a/model_zoo/official/cv/alexnet/eval.py b/model_zoo/official/cv/alexnet/eval.py index 3c7cac7758..0b2e6096e8 100644 --- a/model_zoo/official/cv/alexnet/eval.py +++ b/model_zoo/official/cv/alexnet/eval.py @@ -77,5 +77,8 @@ if __name__ == "__main__": else: raise ValueError("Unsupport dataset.") + if ds_eval.get_dataset_size() == 0: + raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") + result = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) print("result : {}".format(result)) diff --git a/model_zoo/official/cv/alexnet/train.py b/model_zoo/official/cv/alexnet/train.py index ce23307438..d4d1de55c8 100644 --- a/model_zoo/official/cv/alexnet/train.py +++ b/model_zoo/official/cv/alexnet/train.py @@ -91,6 +91,9 @@ if __name__ == "__main__": else: raise ValueError("Unsupport dataset.") + if ds_train.get_dataset_size() == 0: + raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") + network = AlexNet(cfg.num_classes) loss_scale_manager = None diff --git a/model_zoo/official/cv/lenet/eval.py b/model_zoo/official/cv/lenet/eval.py index c4bcf79da2..669fb22421 100644 --- a/model_zoo/official/cv/lenet/eval.py +++ b/model_zoo/official/cv/lenet/eval.py @@ -57,5 +57,8 @@ if __name__ == "__main__": ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1) + if ds_eval.get_dataset_size() == 0: + raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") + acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) print("============== {} ==============".format(acc)) diff --git a/model_zoo/official/cv/lenet/train.py b/model_zoo/official/cv/lenet/train.py index 4dbcaedae5..a47f543d4b 100644 --- a/model_zoo/official/cv/lenet/train.py +++ b/model_zoo/official/cv/lenet/train.py @@ -50,6 +50,8 @@ if __name__ == "__main__": context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size) + if ds_train.get_dataset_size() == 0: + raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") network = LeNet5(cfg.num_classes) net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")