|
|
|
@@ -20,6 +20,8 @@ def argparse_init(): |
|
|
|
argparse_init |
|
|
|
""" |
|
|
|
parser = argparse.ArgumentParser(description='WideDeep') |
|
|
|
parser.add_argument("--device_target", type=str, default="Ascend", choices=["Ascend", "GPU"], |
|
|
|
help="device where the code will be implemented. (Default: Ascend)") |
|
|
|
parser.add_argument("--data_path", type=str, default="./test_raw_data/") |
|
|
|
parser.add_argument("--epochs", type=int, default=15) |
|
|
|
parser.add_argument("--full_batch", type=bool, default=False) |
|
|
|
@@ -44,6 +46,7 @@ class WideDeepConfig(): |
|
|
|
WideDeepConfig |
|
|
|
""" |
|
|
|
def __init__(self): |
|
|
|
self.device_target = "Ascend" |
|
|
|
self.data_path = "./test_raw_data/" |
|
|
|
self.full_batch = False |
|
|
|
self.epochs = 15 |
|
|
|
@@ -72,6 +75,7 @@ class WideDeepConfig(): |
|
|
|
""" |
|
|
|
parser = argparse_init() |
|
|
|
args, _ = parser.parse_known_args() |
|
|
|
self.device_target = args.device_target |
|
|
|
self.data_path = args.data_path |
|
|
|
self.epochs = args.epochs |
|
|
|
self.full_batch = args.full_batch |
|
|
|
|