Merge pull request !1988 from yao_yf/fix_modelzoo_widedeep_run_multinup_traintags/v0.5.0-beta
| @@ -13,26 +13,28 @@ The Criteo datasets are used for model training and evaluation. | |||||
| The entire code structure is as follows: | The entire code structure is as follows: | ||||
| ``` | ``` | ||||
| |--- wide_and_deep/ | |--- wide_and_deep/ | ||||
| train_and_test.py "Entrance of Wide&Deep model training and evaluation" | |||||
| test.py "Entrance of Wide&Deep model evaluation" | |||||
| train.py "Entrance of Wide&Deep model training" | |||||
| train_and_test_multinpu.py "Entrance of Wide&Deep model data parallel training and evaluation" | |||||
| |--- src/ "entrance of training and evaluation" | |||||
| config.py "parameters configuration" | |||||
| dataset.py "Dataset loader class" | |||||
| process_data.py "process dataset" | |||||
| preprocess_data.py "pre_process dataset" | |||||
| WideDeep.py "Model structure" | |||||
| callbacks.py "Callback class for training and evaluation" | |||||
| metrics.py "Metric class" | |||||
| |--- script/ "run shell dir" | |||||
| run_multinpu_train.sh "run data parallel" | |||||
| train_and_eval.py "Entrance of Wide&Deep model training and evaluation" | |||||
| eval.py "Entrance of Wide&Deep model evaluation" | |||||
| train.py "Entrance of Wide&Deep model training" | |||||
| train_and_eval_multinpu.py "Entrance of Wide&Deep model data parallel training and evaluation" | |||||
| train_and_eval_auto_parallel.py "Entrance of Wide&Deep model auto parallel training and evaluation" | |||||
| |--- src/ "Entrance of training and evaluation" | |||||
| config.py "Parameters configuration" | |||||
| dataset.py "Dataset loader class" | |||||
| process_data.py "Process dataset" | |||||
| preprocess_data.py "Pre_process dataset" | |||||
| wide_and_deep.py "Model structure" | |||||
| callbacks.py "Callback class for training and evaluation" | |||||
| metrics.py "Metric class" | |||||
| |--- script/ "Run shell dir" | |||||
| run_multinpu_train.sh "Run data parallel" | |||||
| run_auto_parallel_train.sh "Run auto parallel" | |||||
| ``` | ``` | ||||
| ### Train and evaluate model | ### Train and evaluate model | ||||
| To train and evaluate the model, run the following command: | To train and evaluate the model, run the following command: | ||||
| ``` | ``` | ||||
| python train_and_test.py | |||||
| python train_and_eval.py | |||||
| ``` | ``` | ||||
| Arguments: | Arguments: | ||||
| * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument. | * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument. | ||||
| @@ -44,6 +46,7 @@ Arguments: | |||||
| * `--emb_dim`: The dense embedding dimension of sparse feature. | * `--emb_dim`: The dense embedding dimension of sparse feature. | ||||
| * `--deep_layer_dim`: The dimension of all deep layers. | * `--deep_layer_dim`: The dimension of all deep layers. | ||||
| * `--deep_layer_act`: The activation of all deep layers. | * `--deep_layer_act`: The activation of all deep layers. | ||||
| * `--dropout_flag`: Whether to apply dropout (0 disables it, any non-zero integer enables it; default 0). | |||||
| * `--keep_prob`: The rate to keep in dropout layer. | * `--keep_prob`: The rate to keep in dropout layer. | ||||
| * `--ckpt_path`:The location of the checkpoint file. | * `--ckpt_path`:The location of the checkpoint file. | ||||
| * `--eval_file_name` : Eval output file. | * `--eval_file_name` : Eval output file. | ||||
| @@ -63,6 +66,7 @@ Arguments: | |||||
| * `--emb_dim`: The dense embedding dimension of sparse feature. | * `--emb_dim`: The dense embedding dimension of sparse feature. | ||||
| * `--deep_layer_dim`: The dimension of all deep layers. | * `--deep_layer_dim`: The dimension of all deep layers. | ||||
| * `--deep_layer_act`: The activation of all deep layers. | * `--deep_layer_act`: The activation of all deep layers. | ||||
| * `--dropout_flag`: Whether to apply dropout (0 disables it, any non-zero integer enables it; default 0). | |||||
| * `--keep_prob`: The rate to keep in dropout layer. | * `--keep_prob`: The rate to keep in dropout layer. | ||||
| * `--ckpt_path`:The location of the checkpoint file. | * `--ckpt_path`:The location of the checkpoint file. | ||||
| * `--eval_file_name` : Eval output file. | * `--eval_file_name` : Eval output file. | ||||
| @@ -70,13 +74,17 @@ Arguments: | |||||
| To train the model in distributed mode, run one of the following commands: | To train the model in distributed mode, run one of the following commands: | ||||
| ``` | ``` | ||||
| # configure environment path, RANK_TABLE_FILE, RANK_SIZE, MINDSPORE_HCCL_CONFIG_PATH before training | |||||
| bash run_multinpu_train.sh | |||||
| # configure environment path before training | |||||
| bash run_multinpu_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE | |||||
| ``` | |||||
| ``` | |||||
| # configure environment path before training | |||||
| bash run_auto_parallel_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE | |||||
| ``` | ``` | ||||
| To evaluate the model, run the following command: | To evaluate the model, run the following command: | ||||
| ``` | ``` | ||||
| python test.py | |||||
| python eval.py | |||||
| ``` | ``` | ||||
| Arguments: | Arguments: | ||||
| * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument. | * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument. | ||||
| @@ -0,0 +1,35 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| # Usage: bash run_auto_parallel_train.sh RANK_SIZE EPOCHS DATASET RANK_TABLE_FILE | |||||
| execute_path=$(pwd) | |||||
| script_self=$(readlink -f "$0") | |||||
| self_path=$(dirname "${script_self}") | |||||
| export RANK_SIZE=$1 | |||||
| export EPOCH_SIZE=$2 | |||||
| export DATASET=$3 | |||||
| export RANK_TABLE_FILE=$4 | |||||
| export MINDSPORE_HCCL_CONFIG_PATH=$4 | |||||
| for((i=0;i<$RANK_SIZE;i++)); | |||||
| do | |||||
| rm -rf ${execute_path}/device_$i/ | |||||
| mkdir ${execute_path}/device_$i/ | |||||
| cd ${execute_path}/device_$i/ || exit | |||||
| export RANK_ID=$i | |||||
| export DEVICE_ID=$i | |||||
| python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 & | |||||
| done | |||||
| @@ -24,12 +24,12 @@ export DATASET=$3 | |||||
| export RANK_TABLE_FILE=$4 | export RANK_TABLE_FILE=$4 | ||||
| export MINDSPORE_HCCL_CONFIG_PATH=$4 | export MINDSPORE_HCCL_CONFIG_PATH=$4 | ||||
| for((i=0;i<=$RANK_SIZE;i++)); | |||||
| for((i=0;i<$RANK_SIZE;i++)); | |||||
| do | do | ||||
| rm -rf ${execute_path}/device_$i/ | rm -rf ${execute_path}/device_$i/ | ||||
| mkdir ${execute_path}/device_$i/ | mkdir ${execute_path}/device_$i/ | ||||
| cd ${execute_path}/device_$i/ || exit | cd ${execute_path}/device_$i/ || exit | ||||
| export RANK_ID=$i | export RANK_ID=$i | ||||
| export DEVICE_ID=$i | export DEVICE_ID=$i | ||||
| python -s ${self_path}/../train_and_test_multinpu.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 & | |||||
| python -s ${self_path}/../train_and_eval_multinpu.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 & | |||||
| done | done | ||||
| @@ -31,7 +31,7 @@ def argparse_init(): | |||||
| parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128]) | parser.add_argument("--deep_layer_dim", type=int, nargs='+', default=[1024, 512, 256, 128]) | ||||
| parser.add_argument("--deep_layer_act", type=str, default='relu') | parser.add_argument("--deep_layer_act", type=str, default='relu') | ||||
| parser.add_argument("--keep_prob", type=float, default=1.0) | parser.add_argument("--keep_prob", type=float, default=1.0) | ||||
| parser.add_argument("--dropout_flag", type=int, default=0) | |||||
| parser.add_argument("--output_path", type=str, default="./output/") | parser.add_argument("--output_path", type=str, default="./output/") | ||||
| parser.add_argument("--ckpt_path", type=str, default="./checkpoints/") | parser.add_argument("--ckpt_path", type=str, default="./checkpoints/") | ||||
| parser.add_argument("--eval_file_name", type=str, default="eval.log") | parser.add_argument("--eval_file_name", type=str, default="eval.log") | ||||
| @@ -86,7 +86,7 @@ class WideDeepConfig(): | |||||
| self.weight_bias_init = ['normal', 'normal'] | self.weight_bias_init = ['normal', 'normal'] | ||||
| self.emb_init = 'normal' | self.emb_init = 'normal' | ||||
| self.init_args = [-0.01, 0.01] | self.init_args = [-0.01, 0.01] | ||||
| self.dropout_flag = False | |||||
| self.dropout_flag = bool(args.dropout_flag) | |||||
| self.l2_coef = 8e-5 | self.l2_coef = 8e-5 | ||||
| self.output_path = args.output_path | self.output_path = args.output_path | ||||
| @@ -19,7 +19,7 @@ import mindspore.common.dtype as mstype | |||||
| from mindspore.ops import functional as F | from mindspore.ops import functional as F | ||||
| from mindspore.ops import composite as C | from mindspore.ops import composite as C | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| # from mindspore.nn import Dropout | |||||
| from mindspore.nn import Dropout | |||||
| from mindspore.nn.optim import Adam, FTRL | from mindspore.nn.optim import Adam, FTRL | ||||
| # from mindspore.nn.metrics import Metric | # from mindspore.nn.metrics import Metric | ||||
| from mindspore.common.initializer import Uniform, initializer | from mindspore.common.initializer import Uniform, initializer | ||||
| @@ -82,7 +82,7 @@ class DenseLayer(nn.Cell): | |||||
| """ | """ | ||||
| def __init__(self, input_dim, output_dim, weight_bias_init, act_str, | def __init__(self, input_dim, output_dim, weight_bias_init, act_str, | ||||
| keep_prob=0.7, scale_coef=1.0, convert_dtype=True): | |||||
| keep_prob=0.7, scale_coef=1.0, convert_dtype=True, drop_out=False): | |||||
| super(DenseLayer, self).__init__() | super(DenseLayer, self).__init__() | ||||
| weight_init, bias_init = weight_bias_init | weight_init, bias_init = weight_bias_init | ||||
| self.weight = init_method( | self.weight = init_method( | ||||
| @@ -92,11 +92,12 @@ class DenseLayer(nn.Cell): | |||||
| self.matmul = P.MatMul(transpose_b=False) | self.matmul = P.MatMul(transpose_b=False) | ||||
| self.bias_add = P.BiasAdd() | self.bias_add = P.BiasAdd() | ||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| #self.dropout = Dropout(keep_prob=keep_prob) | |||||
| self.dropout = Dropout(keep_prob=keep_prob) | |||||
| self.mul = P.Mul() | self.mul = P.Mul() | ||||
| self.realDiv = P.RealDiv() | self.realDiv = P.RealDiv() | ||||
| self.scale_coef = scale_coef | self.scale_coef = scale_coef | ||||
| self.convert_dtype = convert_dtype | self.convert_dtype = convert_dtype | ||||
| self.drop_out = drop_out | |||||
| def _init_activation(self, act_str): | def _init_activation(self, act_str): | ||||
| act_str = act_str.lower() | act_str = act_str.lower() | ||||
| @@ -110,8 +111,8 @@ class DenseLayer(nn.Cell): | |||||
| def construct(self, x): | def construct(self, x): | ||||
| x = self.act_func(x) | x = self.act_func(x) | ||||
| # if self.training: | |||||
| # x = self.dropout(x) | |||||
| if self.training and self.drop_out: | |||||
| x = self.dropout(x) | |||||
| x = self.mul(x, self.scale_coef) | x = self.mul(x, self.scale_coef) | ||||
| if self.convert_dtype: | if self.convert_dtype: | ||||
| x = self.cast(x, mstype.float16) | x = self.cast(x, mstype.float16) | ||||
| @@ -163,23 +164,28 @@ class WideDeepModel(nn.Cell): | |||||
| self.dense_layer_1 = DenseLayer(self.all_dim_list[0], | self.dense_layer_1 = DenseLayer(self.all_dim_list[0], | ||||
| self.all_dim_list[1], | self.all_dim_list[1], | ||||
| self.weight_bias_init, | self.weight_bias_init, | ||||
| self.deep_layer_act, convert_dtype=True) | |||||
| self.deep_layer_act, | |||||
| convert_dtype=True, drop_out=config.dropout_flag) | |||||
| self.dense_layer_2 = DenseLayer(self.all_dim_list[1], | self.dense_layer_2 = DenseLayer(self.all_dim_list[1], | ||||
| self.all_dim_list[2], | self.all_dim_list[2], | ||||
| self.weight_bias_init, | self.weight_bias_init, | ||||
| self.deep_layer_act, convert_dtype=True) | |||||
| self.deep_layer_act, | |||||
| convert_dtype=True, drop_out=config.dropout_flag) | |||||
| self.dense_layer_3 = DenseLayer(self.all_dim_list[2], | self.dense_layer_3 = DenseLayer(self.all_dim_list[2], | ||||
| self.all_dim_list[3], | self.all_dim_list[3], | ||||
| self.weight_bias_init, | self.weight_bias_init, | ||||
| self.deep_layer_act, convert_dtype=True) | |||||
| self.deep_layer_act, | |||||
| convert_dtype=True, drop_out=config.dropout_flag) | |||||
| self.dense_layer_4 = DenseLayer(self.all_dim_list[3], | self.dense_layer_4 = DenseLayer(self.all_dim_list[3], | ||||
| self.all_dim_list[4], | self.all_dim_list[4], | ||||
| self.weight_bias_init, | self.weight_bias_init, | ||||
| self.deep_layer_act, convert_dtype=True) | |||||
| self.deep_layer_act, | |||||
| convert_dtype=True, drop_out=config.dropout_flag) | |||||
| self.dense_layer_5 = DenseLayer(self.all_dim_list[4], | self.dense_layer_5 = DenseLayer(self.all_dim_list[4], | ||||
| self.all_dim_list[5], | self.all_dim_list[5], | ||||
| self.weight_bias_init, | self.weight_bias_init, | ||||
| self.deep_layer_act, convert_dtype=True) | |||||
| self.deep_layer_act, | |||||
| convert_dtype=True, drop_out=config.dropout_flag) | |||||
| self.gather_v2 = P.GatherV2() | self.gather_v2 = P.GatherV2() | ||||
| self.mul = P.Mul() | self.mul = P.Mul() | ||||
| @@ -71,11 +71,10 @@ class ModelBuilder(): | |||||
| return get_WideDeep_net(config) | return get_WideDeep_net(config) | ||||
| def test_train_eval(): | |||||
| def train_and_eval(config): | |||||
| """ | """ | ||||
| test_train_eval | test_train_eval | ||||
| """ | """ | ||||
| config = WideDeepConfig() | |||||
| data_path = config.data_path | data_path = config.data_path | ||||
| batch_size = config.batch_size | batch_size = config.batch_size | ||||
| epochs = config.epochs | epochs = config.epochs | ||||
| @@ -109,9 +108,12 @@ def test_train_eval(): | |||||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) | ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5) | ||||
| ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', | ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', | ||||
| directory=config.ckpt_path, config=ckptconfig) | directory=config.ckpt_path, config=ckptconfig) | ||||
| context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_train.ckpt") | |||||
| model.train(epochs, ds_train, | model.train(epochs, ds_train, | ||||
| callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb]) | callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb]) | ||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| test_train_eval() | |||||
| wide_deep_config = WideDeepConfig() | |||||
| wide_deep_config.argparse_init() | |||||
| train_and_eval(wide_deep_config) | |||||
| @@ -66,7 +66,7 @@ class ModelBuilder(): | |||||
| return get_WideDeep_net(config) | return get_WideDeep_net(config) | ||||
| def test_train_eval(config): | |||||
| def train_and_eval(config): | |||||
| """ | """ | ||||
| test_train_eval | test_train_eval | ||||
| """ | """ | ||||
| @@ -105,4 +105,4 @@ def test_train_eval(config): | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| wide_deep_config = WideDeepConfig() | wide_deep_config = WideDeepConfig() | ||||
| wide_deep_config.argparse_init() | wide_deep_config.argparse_init() | ||||
| test_train_eval(wide_deep_config) | |||||
| train_and_eval(wide_deep_config) | |||||