Browse Source

!1742 modify-widedeep

Merge pull request !1742 from wukesong/modify_widedeep
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
fd045e9115
3 changed files with 6 additions and 6 deletions
  1. +1
    -1
      model_zoo/wide_and_deep/src/callbacks.py
  2. +4
    -4
      model_zoo/wide_and_deep/src/config.py
  3. +1
    -1
      model_zoo/wide_and_deep/test.py

+ 1
- 1
model_zoo/wide_and_deep/src/callbacks.py View File

@@ -53,7 +53,7 @@ class LossCallBack(Callback):
print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)

# raise ValueError
if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and config is not None:
if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and self.config is not None:
loss_file = open(self.config.loss_file_name, "a+")
loss_file.write("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" %
(cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss))


+ 4
- 4
model_zoo/wide_and_deep/src/config.py View File

@@ -22,8 +22,8 @@ def argparse_init():
parser = argparse.ArgumentParser(description='WideDeep')
parser.add_argument("--data_path", type=str, default="./test_raw_data/")
parser.add_argument("--epochs", type=int, default=15)
parser.add_argument("--batch_size", type=int, default=10000)
parser.add_argument("--eval_batch_size", type=int, default=15)
parser.add_argument("--batch_size", type=int, default=16000)
parser.add_argument("--eval_batch_size", type=int, default=16000)
parser.add_argument("--field_size", type=int, default=39)
parser.add_argument("--vocab_size", type=int, default=184965)
parser.add_argument("--emb_dim", type=int, default=80)
@@ -45,8 +45,8 @@ class WideDeepConfig():
def __init__(self):
self.data_path = "./test_raw_data/"
self.epochs = 15
self.batch_size = 10000
self.eval_batch_size = 10000
self.batch_size = 16000
self.eval_batch_size = 16000
self.field_size = 39
self.vocab_size = 184965
self.emb_dim = 80


+ 1
- 1
model_zoo/wide_and_deep/test.py View File

@@ -91,4 +91,4 @@ if __name__ == "__main__":
widedeep_config = WideDeepConfig()
widedeep_config.argparse_init()

test_eval(widedeep_config.widedeep)
test_eval(widedeep_config)

Loading…
Cancel
Save