|
|
|
@@ -66,12 +66,11 @@ class ModelBuilder(): |
|
|
|
return get_WideDeep_net(config) |
|
|
|
|
|
|
|
|
|
|
|
def test_train_eval(): |
|
|
|
def test_train_eval(config): |
|
|
|
""" |
|
|
|
test_train_eval |
|
|
|
""" |
|
|
|
np.random.seed(1000) |
|
|
|
config = WideDeepConfig() |
|
|
|
data_path = config.data_path |
|
|
|
batch_size = config.batch_size |
|
|
|
epochs = config.epochs |
|
|
|
@@ -104,4 +103,6 @@ def test_train_eval(): |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
test_train_eval() |
|
|
|
wide_deep_config = WideDeepConfig() |
|
|
|
wide_deep_config.argparse_init() |
|
|
|
test_train_eval(wide_deep_config) |