Merge pull request !1988 from yao_yf/fix_modelzoo_widedeep_run_multinup_train
@@ -13,26 +13,28 @@ The Criteo datasets are used for model training and evaluation.
 The entire code structure is as follows:
 ```
 |--- wide_and_deep/
-        train_and_test.py                "Entrance of Wide&Deep model training and evaluation"
-        test.py                          "Entrance of Wide&Deep model evaluation"
-        train.py                         "Entrance of Wide&Deep model training"
-        train_and_test_multinpu.py       "Entrance of Wide&Deep model data parallel training and evaluation"
-        |--- src/                        "entrance of training and evaluation"
-            config.py                    "parameters configuration"
-            dataset.py                   "Dataset loader class"
-            process_data.py              "process dataset"
-            preprocess_data.py           "pre_process dataset"
-            WideDeep.py                  "Model structure"
-            callbacks.py                 "Callback class for training and evaluation"
-            metrics.py                   "Metric class"
-        |--- script/                     "run shell dir"
-            run_multinpu_train.sh        "run data parallel"
+        train_and_eval.py                "Entrance of Wide&Deep model training and evaluation"
+        eval.py                          "Entrance of Wide&Deep model evaluation"
+        train.py                         "Entrance of Wide&Deep model training"
+        train_and_eval_multinpu.py       "Entrance of Wide&Deep model data parallel training and evaluation"
+        train_and_eval_auto_parallel.py  "Entrance of Wide&Deep model auto parallel training and evaluation"
+        |--- src/                        "Source code for training and evaluation"
+            config.py                    "Parameters configuration"
+            dataset.py                   "Dataset loader class"
+            process_data.py              "Process dataset"
+            preprocess_data.py           "Preprocess dataset"
+            wide_and_deep.py             "Model structure"
+            callbacks.py                 "Callback class for training and evaluation"
+            metrics.py                   "Metric class"
+        |--- script/                     "Shell script directory"
+            run_multinpu_train.sh        "Run data parallel training"
+            run_auto_parallel_train.sh   "Run auto parallel training"
 ```
 
 ### Train and evaluate model
 To train and evaluate the model, run the following command:
 ```
-python train_and_test.py
+python train_and_eval.py
 ```
 Arguments:
 * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument.
@@ -44,6 +46,7 @@ Arguments:
 * `--emb_dim`: The dense embedding dimension of sparse features.
 * `--deep_layers_dim`: The dimension of all deep layers.
 * `--deep_layers_act`: The activation function of all deep layers.
+* `--dropout_flag`: Whether to apply dropout.
 * `--keep_prob`: The fraction of units to keep in the dropout layer.
 * `--ckpt_path`: The location of the checkpoint file.
 * `--eval_file_name`: Eval output file.
@@ -63,6 +66,7 @@ Arguments:
 * `--emb_dim`: The dense embedding dimension of sparse features.
 * `--deep_layers_dim`: The dimension of all deep layers.
 * `--deep_layers_act`: The activation function of all deep layers.
+* `--dropout_flag`: Whether to apply dropout.
 * `--keep_prob`: The fraction of units to keep in the dropout layer.
 * `--ckpt_path`: The location of the checkpoint file.
 * `--eval_file_name`: Eval output file.
@@ -70,13 +74,17 @@ Arguments:
 To train the model in distributed mode, run the following commands:
 ```
-# configure environment path, RANK_TABLE_FILE, RANK_SIZE, MINDSPORE_HCCL_CONFIG_PATH before training
-bash run_multinpu_train.sh
+# configure environment path before training
+bash run_multinpu_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE
 ```
+```
+# configure environment path before training
+bash run_auto_parallel_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE
+```
 To evaluate the model, run the following command:
 ```
-python test.py
+python eval.py
 ```
 Arguments:
 * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument.
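For orientation, not part of this diff: evaluation entry points in this model zoo restore a trained checkpoint before running the metric pass. A minimal sketch, assuming the MindSpore serialization API of this era; the checkpoint path is a placeholder and `net` is assumed to be the already-constructed Wide&Deep network:

```python
# Hedged sketch, not taken from eval.py: load trained parameters into the network.
# The file name is a placeholder; 'widedeep_train' matches the ModelCheckpoint
# prefix set later in this diff.
from mindspore.train.serialization import load_checkpoint, load_param_into_net

param_dict = load_checkpoint("./checkpoints/widedeep_train.ckpt")  # placeholder path
load_param_into_net(net, param_dict)  # `net`: the constructed Wide&Deep model
```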
@@ -0,0 +1,35 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# Usage: bash run_auto_parallel_train.sh RANK_SIZE EPOCH_SIZE DATASET RANK_TABLE_FILE
+execute_path=$(pwd)
+script_self=$(readlink -f "$0")
+self_path=$(dirname "${script_self}")
+export RANK_SIZE=$1
+export EPOCH_SIZE=$2
+export DATASET=$3
+export RANK_TABLE_FILE=$4
+export MINDSPORE_HCCL_CONFIG_PATH=$4
+for((i=0;i<$RANK_SIZE;i++));
+do
+    rm -rf ${execute_path}/device_$i/
+    mkdir ${execute_path}/device_$i/
+    cd ${execute_path}/device_$i/ || exit
+    export RANK_ID=$i
+    export DEVICE_ID=$i
+    python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 &
+done
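The launch pattern in this new script is a plain fan-out: one backgrounded training process per device, each with its own `RANK_ID`/`DEVICE_ID` and scratch directory. As an aside (not part of the PR), the same idea rendered in Python, with arguments and paths assumed to mirror the shell script above:

```python
# Illustrative Python equivalent of run_auto_parallel_train.sh (a sketch, not repo code).
import os
import subprocess
import sys

rank_size = int(sys.argv[1])
epoch_size, dataset, rank_table = sys.argv[2], sys.argv[3], sys.argv[4]
script = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                      "..", "train_and_eval_auto_parallel.py")

for i in range(rank_size):  # exactly rank_size workers: ids 0 .. rank_size-1
    workdir = os.path.join(os.getcwd(), "device_%d" % i)
    os.makedirs(workdir, exist_ok=True)
    env = dict(os.environ,
               RANK_SIZE=str(rank_size), RANK_ID=str(i), DEVICE_ID=str(i),
               RANK_TABLE_FILE=rank_table, MINDSPORE_HCCL_CONFIG_PATH=rank_table)
    with open(os.path.join(workdir, "train_deep%d.log" % i), "w") as log:
        # launch one training process per device, logging like the shell script
        subprocess.Popen([sys.executable, script,
                          "--data_path=%s" % dataset, "--epochs=%s" % epoch_size],
                         cwd=workdir, env=env, stdout=log, stderr=subprocess.STDOUT)
```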
@@ -24,12 +24,12 @@ export DATASET=$3
 export RANK_TABLE_FILE=$4
 export MINDSPORE_HCCL_CONFIG_PATH=$4
-for((i=0;i<=$RANK_SIZE;i++));
+for((i=0;i<$RANK_SIZE;i++));
 do
     rm -rf ${execute_path}/device_$i/
     mkdir ${execute_path}/device_$i/
     cd ${execute_path}/device_$i/ || exit
     export RANK_ID=$i
     export DEVICE_ID=$i
-    python -s ${self_path}/../train_and_test_multinpu.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 &
+    python -s ${self_path}/../train_and_eval_multinpu.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 &
done
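The `<=` to `<` change above is the off-by-one fix this PR is named for: with `RANK_SIZE` devices numbered `0` through `RANK_SIZE-1`, the old loop spawned one extra process bound to a nonexistent `DEVICE_ID`. A two-line Python check of the bound semantics:

```python
RANK_SIZE = 8
assert len(range(0, RANK_SIZE + 1)) == RANK_SIZE + 1  # old bound i<=RANK_SIZE: 9 processes, DEVICE_ID 8 is invalid
assert len(range(0, RANK_SIZE)) == RANK_SIZE          # new bound i<RANK_SIZE: exactly one process per device
```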
@@ -31,7 +31,7 @@ def argparse_init():
     parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128])
     parser.add_argument("--deep_layer_act", type=str, default='relu')
     parser.add_argument("--keep_prob", type=float, default=1.0)
+    parser.add_argument("--dropout_flag", type=int, default=0)
     parser.add_argument("--output_path", type=str, default="./output/")
     parser.add_argument("--ckpt_path", type=str, default="./checkpoints/")
     parser.add_argument("--eval_file_name", type=str, default="eval.log")
@@ -86,7 +86,7 @@ class WideDeepConfig():
         self.weight_bias_init = ['normal', 'normal']
         self.emb_init = 'normal'
         self.init_args = [-0.01, 0.01]
-        self.dropout_flag = False
+        self.dropout_flag = bool(args.dropout_flag)
         self.l2_coef = 8e-5
         self.output_path = args.output_path
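Together, the two hunks above thread the new `--dropout_flag` switch from the command line into the config. Using an integer flag converted with `bool()` sidesteps argparse's unhelpful `type=bool` behavior (`bool("0")` is `True`). A self-contained sketch of just that conversion, with names matching the diff and the rest of the config omitted:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dropout_flag", type=int, default=0)   # 0 = disabled, 1 = enabled
parser.add_argument("--keep_prob", type=float, default=1.0)

args = parser.parse_args(["--dropout_flag", "1", "--keep_prob", "0.7"])
dropout_flag = bool(args.dropout_flag)  # bool(0) -> False, bool(1) -> True
print(dropout_flag, args.keep_prob)     # True 0.7
```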
@@ -19,7 +19,7 @@ import mindspore.common.dtype as mstype
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-# from mindspore.nn import Dropout
+from mindspore.nn import Dropout
 from mindspore.nn.optim import Adam, FTRL
 # from mindspore.nn.metrics import Metric
 from mindspore.common.initializer import Uniform, initializer
@@ -82,7 +82,7 @@ class DenseLayer(nn.Cell):
     """
     def __init__(self, input_dim, output_dim, weight_bias_init, act_str,
-                 keep_prob=0.7, scale_coef=1.0, convert_dtype=True):
+                 keep_prob=0.7, scale_coef=1.0, convert_dtype=True, drop_out=False):
         super(DenseLayer, self).__init__()
         weight_init, bias_init = weight_bias_init
         self.weight = init_method(
@@ -92,11 +92,12 @@ class DenseLayer(nn.Cell):
         self.matmul = P.MatMul(transpose_b=False)
         self.bias_add = P.BiasAdd()
         self.cast = P.Cast()
-        #self.dropout = Dropout(keep_prob=keep_prob)
+        self.dropout = Dropout(keep_prob=keep_prob)
         self.mul = P.Mul()
         self.realDiv = P.RealDiv()
         self.scale_coef = scale_coef
         self.convert_dtype = convert_dtype
+        self.drop_out = drop_out
 
     def _init_activation(self, act_str):
         act_str = act_str.lower()
@@ -110,8 +111,8 @@ class DenseLayer(nn.Cell):
 
     def construct(self, x):
         x = self.act_func(x)
-        # if self.training:
-        #     x = self.dropout(x)
+        if self.training and self.drop_out:
+            x = self.dropout(x)
         x = self.mul(x, self.scale_coef)
         if self.convert_dtype:
             x = self.cast(x, mstype.float16)
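The re-enabled dropout is now doubly gated: the cell must be in training mode and the layer must have been built with `drop_out=True`. A framework-free sketch of the semantics, assuming `keep_prob` means the fraction of units kept (inverted dropout), as with `mindspore.nn.Dropout` here:

```python
import numpy as np

def dense_forward(x, training, drop_out, keep_prob=0.7):
    """Mimics the gating in DenseLayer.construct: dropout only in training, only if enabled."""
    if training and drop_out:
        mask = np.random.binomial(1, keep_prob, size=x.shape)  # keep each unit with prob keep_prob
        x = x * mask / keep_prob                               # rescale so E[x] is unchanged
    return x

x = np.ones((2, 4))
print(dense_forward(x, training=True, drop_out=True))   # some entries zeroed, survivors scaled
print(dense_forward(x, training=False, drop_out=True))  # identity: no dropout at eval time
```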
@@ -163,23 +164,28 @@ class WideDeepModel(nn.Cell):
         self.dense_layer_1 = DenseLayer(self.all_dim_list[0],
                                         self.all_dim_list[1],
                                         self.weight_bias_init,
-                                        self.deep_layer_act, convert_dtype=True)
+                                        self.deep_layer_act,
+                                        convert_dtype=True, drop_out=config.dropout_flag)
         self.dense_layer_2 = DenseLayer(self.all_dim_list[1],
                                         self.all_dim_list[2],
                                         self.weight_bias_init,
-                                        self.deep_layer_act, convert_dtype=True)
+                                        self.deep_layer_act,
+                                        convert_dtype=True, drop_out=config.dropout_flag)
         self.dense_layer_3 = DenseLayer(self.all_dim_list[2],
                                         self.all_dim_list[3],
                                         self.weight_bias_init,
-                                        self.deep_layer_act, convert_dtype=True)
+                                        self.deep_layer_act,
+                                        convert_dtype=True, drop_out=config.dropout_flag)
         self.dense_layer_4 = DenseLayer(self.all_dim_list[3],
                                         self.all_dim_list[4],
                                         self.weight_bias_init,
-                                        self.deep_layer_act, convert_dtype=True)
+                                        self.deep_layer_act,
+                                        convert_dtype=True, drop_out=config.dropout_flag)
         self.dense_layer_5 = DenseLayer(self.all_dim_list[4],
                                         self.all_dim_list[5],
                                         self.weight_bias_init,
-                                        self.deep_layer_act, convert_dtype=True)
+                                        self.deep_layer_act,
+                                        convert_dtype=True, drop_out=config.dropout_flag)
         self.gather_v2 = P.GatherV2()
         self.mul = P.Mul()
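The five constructions above differ only in their dimension indices. Not part of the PR, but a sketch of how the same wiring could be expressed once, assuming `DenseLayer`, `weight_bias_init`, `deep_layer_act`, and `config` behave as in the diff (`make_dense_layers` is a hypothetical helper):

```python
def make_dense_layers(all_dim_list, weight_bias_init, deep_layer_act, config):
    """One DenseLayer per consecutive (in_dim, out_dim) pair, e.g. 6 dims -> 5 layers."""
    return [DenseLayer(in_dim, out_dim, weight_bias_init, deep_layer_act,
                       convert_dtype=True, drop_out=config.dropout_flag)
            for in_dim, out_dim in zip(all_dim_list[:-1], all_dim_list[1:])]
```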
@@ -71,11 +71,10 @@ class ModelBuilder():
         return get_WideDeep_net(config)
 
 
-def test_train_eval():
+def train_and_eval(config):
     """
     test_train_eval
     """
-    config = WideDeepConfig()
     data_path = config.data_path
     batch_size = config.batch_size
     epochs = config.epochs
@@ -109,9 +108,12 @@ def test_train_eval():
     ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
     ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                  directory=config.ckpt_path, config=ckptconfig)
+    context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
     model.train(epochs, ds_train,
                 callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])
 
 
 if __name__ == "__main__":
-    test_train_eval()
+    wide_deep_config = WideDeepConfig()
+    wide_deep_config.argparse_init()
+    train_and_eval(wide_deep_config)
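For context on the added call: `strategy_ckpt_save_file` makes the auto-parallel pass record its chosen operator-sharding strategies to a file, which a later run can reuse. A hedged setup sketch around that call, using MindSpore API names of this era (verify against your version):

```python
from mindspore import context

# graph mode on Ascend, as the model zoo scripts of this era assume
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
# record the sharding strategies chosen during training ...
context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt")
# ... so a later run (e.g. evaluation) can load them instead of re-searching:
# context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_train.ckpt")
```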
@@ -66,7 +66,7 @@ class ModelBuilder():
         return get_WideDeep_net(config)
 
 
-def test_train_eval(config):
+def train_and_eval(config):
     """
     test_train_eval
     """
@@ -105,4 +105,4 @@ def test_train_eval(config):
 if __name__ == "__main__":
     wide_deep_config = WideDeepConfig()
     wide_deep_config.argparse_init()
-    test_train_eval(wide_deep_config)
+    train_and_eval(wide_deep_config)