Merge pull request !1643 from yao_yf/widedepp_modelzoo_adjusttags/v0.5.0-beta
| @@ -0,0 +1,93 @@ | |||
| recommendation Model | |||
| ## Overview | |||
| This is an implementation of WideDeep as described in the [Wide & Deep Learning for Recommender System](https://arxiv.org/pdf/1606.07792.pdf) paper. | |||
| WideDeep model jointly trained wide linear models and deep neural network, which combined the benefits of memorization and generalization for recommender systems. | |||
| ## Dataset | |||
| The [Criteo datasets](http://labs.criteo.com/2014/02/download-kaggle-display-advertising-challenge-dataset/) are used for model training and evaluation. | |||
| ## Running Code | |||
| ### Download and preprocess dataset | |||
| To download the dataset, please install Pandas package first. Then issue the following command: | |||
| ``` | |||
| bash download.sh | |||
| ``` | |||
| ### Code Structure | |||
| The entire code structure is as following: | |||
| ``` | |||
| |--- wide_and_deep/ | |||
| train_and_test.py "Entrance of Wide&Deep model training and evaluation" | |||
| test.py "Entrance of Wide&Deep model evaluation" | |||
| train.py "Entrance of Wide&Deep model training" | |||
| train_and_test_multinpu.py "Entrance of Wide&Deep model data parallel training and evaluation" | |||
| |--- src/ "entrance of training and evaluation" | |||
| config.py "parameters configuration" | |||
| dataset.py "Dataset loader class" | |||
| WideDeep.py "Model structure" | |||
| callbacks.py "Callback class for training and evaluation" | |||
| metrics.py "Metric class" | |||
| ``` | |||
| ### Train and evaluate model | |||
| To train and evaluate the model, issue the following command: | |||
| ``` | |||
| python train_and_test.py | |||
| ``` | |||
| Arguments: | |||
| * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument. | |||
| * `--epochs`: Total train epochs. | |||
| * `--batch_size`: Training batch size. | |||
| * `--eval_batch_size`: Eval batch size. | |||
| * `--field_size`: The number of features. | |||
| * `--vocab_size`: The total features of dataset. | |||
| * `--emb_dim`: The dense embedding dimension of sparse feature. | |||
| * `--deep_layers_dim`: The dimension of all deep layers. | |||
| * `--deep_layers_act`: The activation of all deep layers. | |||
| * `--keep_prob`: The rate to keep in dropout layer. | |||
| * `--ckpt_path`:The location of the checkpoint file. | |||
| * `--eval_file_name` : Eval output file. | |||
| * `--loss_file_name` : Loss output file. | |||
| To train the model, issue the following command: | |||
| ``` | |||
| python train.py | |||
| ``` | |||
| Arguments: | |||
| * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument. | |||
| * `--epochs`: Total train epochs. | |||
| * `--batch_size`: Training batch size. | |||
| * `--eval_batch_size`: Eval batch size. | |||
| * `--field_size`: The number of features. | |||
| * `--vocab_size`: The total features of dataset. | |||
| * `--emb_dim`: The dense embedding dimension of sparse feature. | |||
| * `--deep_layers_dim`: The dimension of all deep layers. | |||
| * `--deep_layers_act`: The activation of all deep layers. | |||
| * `--keep_prob`: The rate to keep in dropout layer. | |||
| * `--ckpt_path`:The location of the checkpoint file. | |||
| * `--eval_file_name` : Eval output file. | |||
| * `--loss_file_name` : Loss output file. | |||
| To evaluate the model, issue the following command: | |||
| ``` | |||
| python test.py | |||
| ``` | |||
| Arguments: | |||
| * `--data_path`: This should be set to the same directory given to the data_download's data_dir argument. | |||
| * `--epochs`: Total train epochs. | |||
| * `--batch_size`: Training batch size. | |||
| * `--eval_batch_size`: Eval batch size. | |||
| * `--field_size`: The number of features. | |||
| * `--vocab_size`: The total features of dataset. | |||
| * `--emb_dim`: The dense embedding dimension of sparse feature. | |||
| * `--deep_layers_dim`: The dimension of all deep layers. | |||
| * `--deep_layers_act`: The activation of all deep layers. | |||
| * `--keep_prob`: The rate to keep in dropout layer. | |||
| * `--ckpt_path`:The location of the checkpoint file. | |||
| * `--eval_file_name` : Eval output file. | |||
| * `--loss_file_name` : Loss output file. | |||
| There are other arguments about models and training process. Use the `--help` or `-h` flag to get a full list of possible arguments with detailed descriptions. | |||
| @@ -26,79 +26,79 @@ def add_write(file_path, out_str): | |||
| file_out.write(out_str + "\n") | |||
| class LossCallBack(Callback): | |||
| """ | |||
| Monitor the loss in training. | |||
| class LossCallBack(Callback): | |||
| """ | |||
| Monitor the loss in training. | |||
| If the loss is NAN or INF, terminate the training. | |||
| If the loss is NAN or INF, terminate the training. | |||
| Note: | |||
| If per_print_times is 0, do NOT print loss. | |||
| Note: | |||
| If per_print_times is 0, do NOT print loss. | |||
| Args: | |||
| per_print_times (int): Print loss every times. Default: 1. | |||
| """ | |||
| def __init__(self, config, per_print_times=1): | |||
| super(LossCallBack, self).__init__() | |||
| if not isinstance(per_print_times, int) or per_print_times < 0: | |||
| raise ValueError("per_print_times must be in and >= 0.") | |||
| self._per_print_times = per_print_times | |||
| self.config = config | |||
| Args: | |||
| per_print_times (int): Print loss every times. Default: 1. | |||
| """ | |||
| def __init__(self, config=None, per_print_times=1): | |||
| super(LossCallBack, self).__init__() | |||
| if not isinstance(per_print_times, int) or per_print_times < 0: | |||
| raise ValueError("per_print_times must be in and >= 0.") | |||
| self._per_print_times = per_print_times | |||
| self.config = config | |||
| def step_end(self, run_context): | |||
| cb_params = run_context.original_args() | |||
| wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy() | |||
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||
| cur_num = cb_params.cur_step_num | |||
| print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss) | |||
| def step_end(self, run_context): | |||
| cb_params = run_context.original_args() | |||
| wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy() | |||
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||
| cur_num = cb_params.cur_step_num | |||
| print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss) | |||
| # raise ValueError | |||
| if self._per_print_times != 0 and cur_num % self._per_print_times == 0: | |||
| loss_file = open(self.config.loss_file_name, "a+") | |||
| loss_file.write("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" % | |||
| (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)) | |||
| loss_file.write("\n") | |||
| loss_file.close() | |||
| print("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" % | |||
| (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)) | |||
| # raise ValueError | |||
| if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and config is not None: | |||
| loss_file = open(self.config.loss_file_name, "a+") | |||
| loss_file.write("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" % | |||
| (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)) | |||
| loss_file.write("\n") | |||
| loss_file.close() | |||
| print("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" % | |||
| (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)) | |||
| class EvalCallBack(Callback): | |||
| """ | |||
| Monitor the loss in evaluating. | |||
| class EvalCallBack(Callback): | |||
| """ | |||
| Monitor the loss in evaluating. | |||
| If the loss is NAN or INF, terminate evaluating. | |||
| If the loss is NAN or INF, terminate evaluating. | |||
| Note: | |||
| If per_print_times is 0, do NOT print loss. | |||
| Note: | |||
| If per_print_times is 0, do NOT print loss. | |||
| Args: | |||
| print_per_step (int): Print loss every times. Default: 1. | |||
| """ | |||
| def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1): | |||
| super(EvalCallBack, self).__init__() | |||
| if not isinstance(print_per_step, int) or print_per_step < 0: | |||
| raise ValueError("print_per_step must be int and >= 0.") | |||
| self.print_per_step = print_per_step | |||
| self.model = model | |||
| self.eval_dataset = eval_dataset | |||
| self.aucMetric = auc_metric | |||
| self.aucMetric.clear() | |||
| self.eval_file_name = config.eval_file_name | |||
| Args: | |||
| print_per_step (int): Print loss every times. Default: 1. | |||
| """ | |||
| def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1): | |||
| super(EvalCallBack, self).__init__() | |||
| if not isinstance(print_per_step, int) or print_per_step < 0: | |||
| raise ValueError("print_per_step must be int and >= 0.") | |||
| self.print_per_step = print_per_step | |||
| self.model = model | |||
| self.eval_dataset = eval_dataset | |||
| self.aucMetric = auc_metric | |||
| self.aucMetric.clear() | |||
| self.eval_file_name = config.eval_file_name | |||
| def epoch_name(self, run_context): | |||
| """ | |||
| epoch name | |||
| """ | |||
| self.aucMetric.clear() | |||
| context.set_auto_parallel_context(strategy_ckpt_save_file="", | |||
| strategy_ckpt_load_file="./strategy_train.ckpt") | |||
| start_time = time.time() | |||
| out = self.model.eval(self.eval_dataset) | |||
| end_time = time.time() | |||
| eval_time = int(end_time - start_time) | |||
| def epoch_name(self, run_context): | |||
| """ | |||
| epoch name | |||
| """ | |||
| self.aucMetric.clear() | |||
| context.set_auto_parallel_context(strategy_ckpt_save_file="", | |||
| strategy_ckpt_load_file="./strategy_train.ckpt") | |||
| start_time = time.time() | |||
| out = self.model.eval(self.eval_dataset) | |||
| end_time = time.time() | |||
| eval_time = int(end_time - start_time) | |||
| time_str = time.strftime("%Y-%m-%d %H:%M%S", time.localtime()) | |||
| out_str = "{}==== EvalCallBack model.eval(): {}; eval_time: {}s".format(time_str, out.values(), eval_time) | |||
| print(out_str) | |||
| add_write(self.eval_file_name, out_str) | |||
| time_str = time.strftime("%Y-%m-%d %H:%M%S", time.localtime()) | |||
| out_str = "{}==== EvalCallBack model.eval(): {}; eval_time: {}s".format(time_str, out.values(), eval_time) | |||
| print(out_str) | |||
| add_write(self.eval_file_name, out_str) | |||
| @@ -38,9 +38,9 @@ def argparse_init(): | |||
| return parser | |||
| class Config_WideDeep(): | |||
| class WideDeepConfig(): | |||
| """ | |||
| Config_WideDeep | |||
| WideDeepConfig | |||
| """ | |||
| def __init__(self): | |||
| self.data_path = "./test_raw_data/" | |||
| @@ -70,6 +70,7 @@ class Config_WideDeep(): | |||
| """ | |||
| parser = argparse_init() | |||
| args, _ = parser.parse_known_args() | |||
| self.data_path = args.data_path | |||
| self.epochs = args.epochs | |||
| self.batch_size = args.batch_size | |||
| self.eval_batch_size = args.eval_batch_size | |||
| @@ -135,8 +135,8 @@ class WideDeepModel(nn.Cell): | |||
| self.field_size = config.field_size | |||
| self.vocab_size = config.vocab_size | |||
| self.emb_dim = config.emb_dim | |||
| self.deep_layer_args = config.deep_layer_args | |||
| self.deep_layer_dims_list, self.deep_layer_act = self.deep_layer_args | |||
| self.deep_layer_dims_list = config.deep_layer_dim | |||
| self.deep_layer_act = config.deep_layer_act | |||
| self.init_args = config.init_args | |||
| self.weight_init, self.bias_init = config.weight_bias_init | |||
| self.weight_bias_init = config.weight_bias_init | |||
| @@ -20,11 +20,11 @@ import os | |||
| from mindspore import Model, context | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from wide_deep.models.WideDeep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel | |||
| from wide_deep.utils.callbacks import LossCallBack, EvalCallBack | |||
| from wide_deep.data.datasets import create_dataset | |||
| from wide_deep.utils.metrics import AUCMetric | |||
| from tools.config import Config_WideDeep | |||
| from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel | |||
| from src.callbacks import LossCallBack, EvalCallBack | |||
| from src.datasets import create_dataset | |||
| from src.metrics import AUCMetric | |||
| from src.config import WideDeepConfig | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Davinci", | |||
| save_graphs=True) | |||
| @@ -88,7 +88,7 @@ def test_eval(config): | |||
| if __name__ == "__main__": | |||
| widedeep_config = Config_WideDeep() | |||
| widedeep_config = WideDeepConfig() | |||
| widedeep_config.argparse_init() | |||
| test_eval(widedeep_config.widedeep) | |||
| @@ -16,19 +16,19 @@ import os | |||
| from mindspore import Model, context | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||
| from wide_deep.models.WideDeep import PredictWithSigmoid, TrainStepWarp, NetWithLossClass, WideDeepModel | |||
| from wide_deep.utils.callbacks import LossCallBack | |||
| from wide_deep.data.datasets import create_dataset | |||
| from tools.config import Config_WideDeep | |||
| from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel | |||
| from src.callbacks import LossCallBack | |||
| from src.datasets import create_dataset | |||
| from src.config import WideDeepConfig | |||
| context.set_context(model=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) | |||
| def get_WideDeep_net(configure): | |||
| WideDeep_net = WideDeepModel(configure) | |||
| loss_net = NetWithLossClass(WideDeep_net, configure) | |||
| train_net = TrainStepWarp(loss_net) | |||
| train_net = TrainStepWrap(loss_net) | |||
| eval_net = PredictWithSigmoid(WideDeep_net) | |||
| return train_net, eval_net | |||
| @@ -71,7 +71,7 @@ def test_train(configure): | |||
| train_net.set_train() | |||
| model = Model(train_net) | |||
| callback = LossCallBack(configure) | |||
| callback = LossCallBack(config=configure) | |||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=1, | |||
| keep_checkpoint_max=5) | |||
| ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=configure.ckpt_path, config=ckptconfig) | |||
| @@ -79,7 +79,7 @@ def test_train(configure): | |||
| if __name__ == "__main__": | |||
| config = Config_WideDeep() | |||
| config = WideDeepConfig() | |||
| config.argparse_init() | |||
| test_train(config) | |||
| @@ -17,11 +17,11 @@ import os | |||
| from mindspore import Model, context | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||
| from wide_deep.models.WideDeep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel | |||
| from wide_deep.utils.callbacks import LossCallBack, EvalCallBack | |||
| from wide_deep.data.datasets import create_dataset | |||
| from wide_deep.utils.metrics import AUCMetric | |||
| from tools.config import Config_WideDeep | |||
| from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel | |||
| from src.callbacks import LossCallBack, EvalCallBack | |||
| from src.datasets import create_dataset | |||
| from src.metrics import AUCMetric | |||
| from src.config import WideDeepConfig | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Davinci") | |||
| @@ -81,7 +81,7 @@ def test_train_eval(config): | |||
| eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) | |||
| callback = LossCallBack() | |||
| callback = LossCallBack(config=config) | |||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=5) | |||
| ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig) | |||
| @@ -91,7 +91,7 @@ def test_train_eval(config): | |||
| if __name__ == "__main__": | |||
| wide_deep_config = Config_WideDeep() | |||
| wide_deep_config = WideDeepConfig() | |||
| wide_deep_config.argparse_init() | |||
| test_train_eval(wide_deep_config) | |||
| @@ -12,7 +12,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """train_imagenet.""" | |||
| """train_multinpu.""" | |||
| import os | |||
| @@ -27,7 +27,7 @@ from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClas | |||
| from src.callbacks import LossCallBack, EvalCallBack | |||
| from src.datasets import create_dataset | |||
| from src.metrics import AUCMetric | |||
| from src.config import Config_WideDeep | |||
| from src.config import WideDeepConfig | |||
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |||
| context.set_context(mode=GRAPH_MODE, device_target="Davinci", save_graph=True) | |||
| @@ -71,7 +71,7 @@ def test_train_eval(): | |||
| test_train_eval | |||
| """ | |||
| np.random.seed(1000) | |||
| config = Config_WideDeep | |||
| config = WideDeepConfig | |||
| data_path = Config.data_path | |||
| batch_size = config.batch_size | |||
| epochs = config.epochs | |||
| @@ -93,7 +93,7 @@ def test_train_eval(): | |||
| eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) | |||
| callback = LossCallBack(config) | |||
| callback = LossCallBack(config=config) | |||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=5) | |||
| ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', | |||
| directory=config.ckpt_path, config=ckptconfig) | |||