|
|
|
@@ -50,6 +50,8 @@ if __name__ == "__main__": |
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) |
|
|
|
ds_train = create_dataset(os.path.join(args.data_path, "train"), |
|
|
|
cfg.batch_size) |
|
|
|
if ds_train.get_dataset_size() == 0: |
|
|
|
raise ValueError("Please check dataset size > 0 and batch_size <= dataset size") |
|
|
|
|
|
|
|
network = LeNet5(cfg.num_classes) |
|
|
|
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") |
|
|
|
|