|
|
|
@@ -25,7 +25,7 @@ def argparse_init(): |
|
|
|
parser.add_argument("--data_path", type=str, default="./test_raw_data/", |
|
|
|
help="This should be set to the same directory given to the data_download's data_dir argument") |
|
|
|
parser.add_argument("--epochs", type=int, default=15, help="Total train epochs") |
|
|
|
parser.add_argument("--full_batch", type=bool, default=False, help="Enable loading the full batch ") |
|
|
|
parser.add_argument("--full_batch", type=int, default=0, help="Enable loading the full batch ") |
|
|
|
parser.add_argument("--batch_size", type=int, default=16000, help="Training batch size.") |
|
|
|
parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.") |
|
|
|
parser.add_argument("--field_size", type=int, default=39, help="The number of features.") |
|
|
|
@@ -88,7 +88,7 @@ class WideDeepConfig(): |
|
|
|
self.device_target = args.device_target |
|
|
|
self.data_path = args.data_path |
|
|
|
self.epochs = args.epochs |
|
|
|
self.full_batch = args.full_batch |
|
|
|
self.full_batch = bool(args.full_batch) |
|
|
|
self.batch_size = args.batch_size |
|
|
|
self.eval_batch_size = args.eval_batch_size |
|
|
|
self.field_size = args.field_size |
|
|
|
|