Browse Source

no bool parameter parser in wide_and_deep

tags/v0.7.0-beta
yao_yf 5 years ago
parent
commit
0c175b2cc0
3 changed files with 6 additions and 6 deletions
  1. +2
    -2
      mindspore/nn/layer/embedding.py
  2. +2
    -2
      model_zoo/official/recommend/wide_and_deep/src/config.py
  3. +2
    -2
      model_zoo/official/recommend/wide_and_deep_multitable/src/config.py

+ 2
- 2
mindspore/nn/layer/embedding.py View File

@@ -129,9 +129,9 @@ class EmbeddingLookup(Cell):
embedding_size (int): The size of each embedding vector.
param_init (str): The initialize way of embedding table. Default: 'normal'.
target (str): Specify the target where the op is executed. The value should in
['DEVICE', 'CPU']. Default: 'CPU'.
['DEVICE', 'CPU']. Default: 'CPU'.
slice_mode (str): The slicing way in semi auto parallel/auto parallel. The value should get through
nn.EmbeddingLookUpSplitMode. Default: 'batch_slice'.
nn.EmbeddingLookUpSplitMode. Default: nn.EmbeddingLookUpSplitMode.BATCH_SLICE.
manual_shapes (tuple): The accompaniment array in field slice mode.

Inputs:


+ 2
- 2
model_zoo/official/recommend/wide_and_deep/src/config.py View File

@@ -25,7 +25,7 @@ def argparse_init():
parser.add_argument("--data_path", type=str, default="./test_raw_data/",
help="This should be set to the same directory given to the data_download's data_dir argument")
parser.add_argument("--epochs", type=int, default=15, help="Total train epochs")
parser.add_argument("--full_batch", type=bool, default=False, help="Enable loading the full batch ")
parser.add_argument("--full_batch", type=int, default=0, help="Enable loading the full batch ")
parser.add_argument("--batch_size", type=int, default=16000, help="Training batch size.")
parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.")
parser.add_argument("--field_size", type=int, default=39, help="The number of features.")
@@ -88,7 +88,7 @@ class WideDeepConfig():
self.device_target = args.device_target
self.data_path = args.data_path
self.epochs = args.epochs
self.full_batch = args.full_batch
self.full_batch = bool(args.full_batch)
self.batch_size = args.batch_size
self.eval_batch_size = args.eval_batch_size
self.field_size = args.field_size


+ 2
- 2
model_zoo/official/recommend/wide_and_deep_multitable/src/config.py View File

@@ -31,7 +31,7 @@ def argparse_init():
parser.add_argument("--adam_lr", type=float, default=0.003) # The Adam lr
parser.add_argument("--ftrl_lr", type=float, default=0.1) # The ftrl lr.
parser.add_argument("--l2_coef", type=float, default=0.0) # The l2 coefficient.
parser.add_argument("--is_tf_dataset", type=bool, default=True) # The l2 coefficient.
parser.add_argument("--is_tf_dataset", type=int, default=1) # Is tf_dataset.
parser.add_argument("--dropout_flag", type=int, default=1) # The dropout rate
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
@@ -87,7 +87,7 @@ class WideDeepConfig():
self.l2_coef = args.l2_coef
self.ftrl_lr = args.ftrl_lr
self.adam_lr = args.adam_lr
self.is_tf_dataset = args.is_tf_dataset
self.is_tf_dataset = bool(args.is_tf_dataset)
self.output_path = args.output_path
self.eval_file_name = args.eval_file_name


Loading…
Cancel
Save