Browse Source

!6699 add dataset size check

Merge pull request !6699 from wukesong/add-dataset-check
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
9a1cfe76c7
4 changed files with 11 additions and 0 deletions
  1. +3
    -0
      model_zoo/official/cv/alexnet/eval.py
  2. +3
    -0
      model_zoo/official/cv/alexnet/train.py
  3. +3
    -0
      model_zoo/official/cv/lenet/eval.py
  4. +2
    -0
      model_zoo/official/cv/lenet/train.py

+ 3
- 0
model_zoo/official/cv/alexnet/eval.py View File

@@ -77,5 +77,8 @@ if __name__ == "__main__":
else:
raise ValueError("Unsupport dataset.")

if ds_eval.get_dataset_size() == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")

result = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("result : {}".format(result))

+ 3
- 0
model_zoo/official/cv/alexnet/train.py View File

@@ -91,6 +91,9 @@ if __name__ == "__main__":
else:
raise ValueError("Unsupport dataset.")

if ds_train.get_dataset_size() == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")

network = AlexNet(cfg.num_classes)

loss_scale_manager = None


+ 3
- 0
model_zoo/official/cv/lenet/eval.py View File

@@ -57,5 +57,8 @@ if __name__ == "__main__":
ds_eval = create_dataset(os.path.join(args.data_path, "test"),
cfg.batch_size,
1)
if ds_eval.get_dataset_size() == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")

acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("============== {} ==============".format(acc))

+ 2
- 0
model_zoo/official/cv/lenet/train.py View File

@@ -50,6 +50,8 @@ if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
ds_train = create_dataset(os.path.join(args.data_path, "train"),
cfg.batch_size)
if ds_train.get_dataset_size() == 0:
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size")

network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")


Loading…
Cancel
Save