Merge pull request !6699 from wukesong/add-dataset-checktags/v1.1.0
| @@ -77,5 +77,8 @@ if __name__ == "__main__": | |||||
| else: | else: | ||||
| raise ValueError("Unsupport dataset.") | raise ValueError("Unsupport dataset.") | ||||
| if ds_eval.get_dataset_size() == 0: | |||||
| raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") | |||||
| result = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) | result = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) | ||||
| print("result : {}".format(result)) | print("result : {}".format(result)) | ||||
| @@ -91,6 +91,9 @@ if __name__ == "__main__": | |||||
| else: | else: | ||||
| raise ValueError("Unsupport dataset.") | raise ValueError("Unsupport dataset.") | ||||
| if ds_train.get_dataset_size() == 0: | |||||
| raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") | |||||
| network = AlexNet(cfg.num_classes) | network = AlexNet(cfg.num_classes) | ||||
| loss_scale_manager = None | loss_scale_manager = None | ||||
| @@ -57,5 +57,8 @@ if __name__ == "__main__": | |||||
| ds_eval = create_dataset(os.path.join(args.data_path, "test"), | ds_eval = create_dataset(os.path.join(args.data_path, "test"), | ||||
| cfg.batch_size, | cfg.batch_size, | ||||
| 1) | 1) | ||||
| if ds_eval.get_dataset_size() == 0: | |||||
| raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") | |||||
| acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) | acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode) | ||||
| print("============== {} ==============".format(acc)) | print("============== {} ==============".format(acc)) | ||||
| @@ -50,6 +50,8 @@ if __name__ == "__main__": | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | ||||
| ds_train = create_dataset(os.path.join(args.data_path, "train"), | ds_train = create_dataset(os.path.join(args.data_path, "train"), | ||||
| cfg.batch_size) | cfg.batch_size) | ||||
| if ds_train.get_dataset_size() == 0: | |||||
| raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") | |||||
| network = LeNet5(cfg.num_classes) | network = LeNet5(cfg.num_classes) | ||||
| net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | ||||