diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index f5be4d9a25..950c2057e7 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -129,9 +129,9 @@ class EmbeddingLookup(Cell): embedding_size (int): The size of each embedding vector. param_init (str): The initialize way of embedding table. Default: 'normal'. target (str): Specify the target where the op is executed. The value should in - ['DEVICE', 'CPU']. Default: 'CPU'. + ['DEVICE', 'CPU']. Default: 'CPU'. slice_mode (str): The slicing way in semi auto parallel/auto parallel. The value should get through - nn.EmbeddingLookUpSplitMode. Default: 'batch_slice'. + nn.EmbeddingLookUpSplitMode. Default: nn.EmbeddingLookUpSplitMode.BATCH_SLICE. manual_shapes (tuple): The accompaniment array in field slice mode. Inputs: diff --git a/model_zoo/official/recommend/wide_and_deep/src/config.py b/model_zoo/official/recommend/wide_and_deep/src/config.py index a7d1035a10..d8d4ff6ae5 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/config.py +++ b/model_zoo/official/recommend/wide_and_deep/src/config.py @@ -25,7 +25,7 @@ def argparse_init(): parser.add_argument("--data_path", type=str, default="./test_raw_data/", help="This should be set to the same directory given to the data_download's data_dir argument") parser.add_argument("--epochs", type=int, default=15, help="Total train epochs") - parser.add_argument("--full_batch", type=bool, default=False, help="Enable loading the full batch ") + parser.add_argument("--full_batch", type=int, default=0, help="Enable loading the full batch ") parser.add_argument("--batch_size", type=int, default=16000, help="Training batch size.") parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.") parser.add_argument("--field_size", type=int, default=39, help="The number of features.") @@ -88,7 +88,7 @@ class WideDeepConfig(): self.device_target = args.device_target self.data_path = args.data_path self.epochs = args.epochs - self.full_batch = args.full_batch + self.full_batch = bool(args.full_batch) self.batch_size = args.batch_size self.eval_batch_size = args.eval_batch_size self.field_size = args.field_size diff --git a/model_zoo/official/recommend/wide_and_deep_multitable/src/config.py b/model_zoo/official/recommend/wide_and_deep_multitable/src/config.py index 19d73a6d96..6dffade4f7 100644 --- a/model_zoo/official/recommend/wide_and_deep_multitable/src/config.py +++ b/model_zoo/official/recommend/wide_and_deep_multitable/src/config.py @@ -31,7 +31,7 @@ def argparse_init(): parser.add_argument("--adam_lr", type=float, default=0.003) # The Adam lr parser.add_argument("--ftrl_lr", type=float, default=0.1) # The ftrl lr. parser.add_argument("--l2_coef", type=float, default=0.0) # The l2 coefficient. - parser.add_argument("--is_tf_dataset", type=bool, default=True) # The l2 coefficient. + parser.add_argument("--is_tf_dataset", type=int, default=1) # Is tf_dataset. parser.add_argument("--dropout_flag", type=int, default=1) # The dropout rate parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file. @@ -87,7 +87,7 @@ class WideDeepConfig(): self.l2_coef = args.l2_coef self.ftrl_lr = args.ftrl_lr self.adam_lr = args.adam_lr - self.is_tf_dataset = args.is_tf_dataset + self.is_tf_dataset = bool(args.is_tf_dataset) self.output_path = args.output_path self.eval_file_name = args.eval_file_name