Browse Source

!1643 WideDeep modelzoo adjust dir

Merge pull request !1643 from yao_yf/widedepp_modelzoo_adjust
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
c8d910b9c8
9 changed files with 187 additions and 93 deletions
  1. +93
    -0
      model_zoo/wide_and_deep/README.md
  2. +64
    -64
      model_zoo/wide_and_deep/src/callbacks.py
  3. +3
    -2
      model_zoo/wide_and_deep/src/config.py
  4. +0
    -0
      model_zoo/wide_and_deep/src/metrics.py
  5. +2
    -2
      model_zoo/wide_and_deep/src/wide_and_deep.py
  6. +6
    -6
      model_zoo/wide_and_deep/test.py
  7. +8
    -8
      model_zoo/wide_and_deep/train.py
  8. +7
    -7
      model_zoo/wide_and_deep/train_and_test.py
  9. +4
    -4
      model_zoo/wide_and_deep/train_and_test_multinpu.py

+ 93
- 0
model_zoo/wide_and_deep/README.md View File

@@ -0,0 +1,93 @@
recommendation Model
## Overview
This is an implementation of WideDeep as described in the [Wide & Deep Learning for Recommender System](https://arxiv.org/pdf/1606.07792.pdf) paper.

WideDeep model jointly trained wide linear models and deep neural network, which combined the benefits of memorization and generalization for recommender systems.

## Dataset
The [Criteo datasets](http://labs.criteo.com/2014/02/download-kaggle-display-advertising-challenge-dataset/) are used for model training and evaluation.

## Running Code

### Download and preprocess dataset
To download the dataset, please install Pandas package first. Then issue the following command:
```
bash download.sh
```

### Code Structure
The entire code structure is as following:
```
|--- wide_and_deep/
train_and_test.py "Entrance of Wide&Deep model training and evaluation"
test.py "Entrance of Wide&Deep model evaluation"
train.py "Entrance of Wide&Deep model training"
train_and_test_multinpu.py "Entrance of Wide&Deep model data parallel training and evaluation"
|--- src/ "entrance of training and evaluation"
config.py "parameters configuration"
dataset.py "Dataset loader class"
WideDeep.py "Model structure"
callbacks.py "Callback class for training and evaluation"
metrics.py "Metric class"
```

### Train and evaluate model
To train and evaluate the model, issue the following command:
```
python train_and_test.py
```
Arguments:
* `--data_path`: This should be set to the same directory given to the data_download's data_dir argument.
* `--epochs`: Total train epochs.
* `--batch_size`: Training batch size.
* `--eval_batch_size`: Eval batch size.
* `--field_size`: The number of features.
* `--vocab_size`: The total features of dataset.
* `--emb_dim`: The dense embedding dimension of sparse feature.
* `--deep_layers_dim`: The dimension of all deep layers.
* `--deep_layers_act`: The activation of all deep layers.
* `--keep_prob`: The rate to keep in dropout layer.
* `--ckpt_path`:The location of the checkpoint file.
* `--eval_file_name` : Eval output file.
* `--loss_file_name` : Loss output file.

To train the model, issue the following command:
```
python train.py
```
Arguments:
* `--data_path`: This should be set to the same directory given to the data_download's data_dir argument.
* `--epochs`: Total train epochs.
* `--batch_size`: Training batch size.
* `--eval_batch_size`: Eval batch size.
* `--field_size`: The number of features.
* `--vocab_size`: The total features of dataset.
* `--emb_dim`: The dense embedding dimension of sparse feature.
* `--deep_layers_dim`: The dimension of all deep layers.
* `--deep_layers_act`: The activation of all deep layers.
* `--keep_prob`: The rate to keep in dropout layer.
* `--ckpt_path`:The location of the checkpoint file.
* `--eval_file_name` : Eval output file.
* `--loss_file_name` : Loss output file.

To evaluate the model, issue the following command:
```
python test.py
```
Arguments:
* `--data_path`: This should be set to the same directory given to the data_download's data_dir argument.
* `--epochs`: Total train epochs.
* `--batch_size`: Training batch size.
* `--eval_batch_size`: Eval batch size.
* `--field_size`: The number of features.
* `--vocab_size`: The total features of dataset.
* `--emb_dim`: The dense embedding dimension of sparse feature.
* `--deep_layers_dim`: The dimension of all deep layers.
* `--deep_layers_act`: The activation of all deep layers.
* `--keep_prob`: The rate to keep in dropout layer.
* `--ckpt_path`:The location of the checkpoint file.
* `--eval_file_name` : Eval output file.
* `--loss_file_name` : Loss output file.

There are other arguments about models and training process. Use the `--help` or `-h` flag to get a full list of possible arguments with detailed descriptions.


+ 64
- 64
model_zoo/wide_and_deep/src/callbacks.py View File

@@ -26,79 +26,79 @@ def add_write(file_path, out_str):
file_out.write(out_str + "\n")


class LossCallBack(Callback):
"""
Monitor the loss in training.
class LossCallBack(Callback):
"""
Monitor the loss in training.

If the loss is NAN or INF, terminate the training.
If the loss is NAN or INF, terminate the training.

Note:
If per_print_times is 0, do NOT print loss.
Note:
If per_print_times is 0, do NOT print loss.

Args:
per_print_times (int): Print loss every times. Default: 1.
"""
def __init__(self, config, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("per_print_times must be in and >= 0.")
self._per_print_times = per_print_times
self.config = config
Args:
per_print_times (int): Print loss every times. Default: 1.
"""
def __init__(self, config=None, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("per_print_times must be in and >= 0.")
self._per_print_times = per_print_times
self.config = config

def step_end(self, run_context):
cb_params = run_context.original_args()
wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
cur_num = cb_params.cur_step_num
print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)
def step_end(self, run_context):
cb_params = run_context.original_args()
wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), cb_params.net_outputs[1].asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
cur_num = cb_params.cur_step_num
print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss)

# raise ValueError
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
loss_file = open(self.config.loss_file_name, "a+")
loss_file.write("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" %
(cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss))
loss_file.write("\n")
loss_file.close()
print("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" %
(cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss))
# raise ValueError
if self._per_print_times != 0 and cur_num % self._per_print_times == 0 and config is not None:
loss_file = open(self.config.loss_file_name, "a+")
loss_file.write("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" %
(cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss))
loss_file.write("\n")
loss_file.close()
print("epoch: %s, step: %s, wide_loss: %s, deep_loss: %s" %
(cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss))


class EvalCallBack(Callback):
"""
Monitor the loss in evaluating.
class EvalCallBack(Callback):
"""
Monitor the loss in evaluating.

If the loss is NAN or INF, terminate evaluating.
If the loss is NAN or INF, terminate evaluating.

Note:
If per_print_times is 0, do NOT print loss.
Note:
If per_print_times is 0, do NOT print loss.

Args:
print_per_step (int): Print loss every times. Default: 1.
"""
def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1):
super(EvalCallBack, self).__init__()
if not isinstance(print_per_step, int) or print_per_step < 0:
raise ValueError("print_per_step must be int and >= 0.")
self.print_per_step = print_per_step
self.model = model
self.eval_dataset = eval_dataset
self.aucMetric = auc_metric
self.aucMetric.clear()
self.eval_file_name = config.eval_file_name
Args:
print_per_step (int): Print loss every times. Default: 1.
"""
def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1):
super(EvalCallBack, self).__init__()
if not isinstance(print_per_step, int) or print_per_step < 0:
raise ValueError("print_per_step must be int and >= 0.")
self.print_per_step = print_per_step
self.model = model
self.eval_dataset = eval_dataset
self.aucMetric = auc_metric
self.aucMetric.clear()
self.eval_file_name = config.eval_file_name

def epoch_name(self, run_context):
"""
epoch name
"""
self.aucMetric.clear()
context.set_auto_parallel_context(strategy_ckpt_save_file="",
strategy_ckpt_load_file="./strategy_train.ckpt")
start_time = time.time()
out = self.model.eval(self.eval_dataset)
end_time = time.time()
eval_time = int(end_time - start_time)
def epoch_name(self, run_context):
"""
epoch name
"""
self.aucMetric.clear()
context.set_auto_parallel_context(strategy_ckpt_save_file="",
strategy_ckpt_load_file="./strategy_train.ckpt")
start_time = time.time()
out = self.model.eval(self.eval_dataset)
end_time = time.time()
eval_time = int(end_time - start_time)

time_str = time.strftime("%Y-%m-%d %H:%M%S", time.localtime())
out_str = "{}==== EvalCallBack model.eval(): {}; eval_time: {}s".format(time_str, out.values(), eval_time)
print(out_str)
add_write(self.eval_file_name, out_str)
time_str = time.strftime("%Y-%m-%d %H:%M%S", time.localtime())
out_str = "{}==== EvalCallBack model.eval(): {}; eval_time: {}s".format(time_str, out.values(), eval_time)
print(out_str)
add_write(self.eval_file_name, out_str)

model_zoo/wide_and_deep/tools/config.py → model_zoo/wide_and_deep/src/config.py View File

@@ -38,9 +38,9 @@ def argparse_init():
return parser


class Config_WideDeep():
class WideDeepConfig():
"""
Config_WideDeep
WideDeepConfig
"""
def __init__(self):
self.data_path = "./test_raw_data/"
@@ -70,6 +70,7 @@ class Config_WideDeep():
"""
parser = argparse_init()
args, _ = parser.parse_known_args()
self.data_path = args.data_path
self.epochs = args.epochs
self.batch_size = args.batch_size
self.eval_batch_size = args.eval_batch_size

model_zoo/wide_and_deep/metrics.py → model_zoo/wide_and_deep/src/metrics.py View File


+ 2
- 2
model_zoo/wide_and_deep/src/wide_and_deep.py View File

@@ -135,8 +135,8 @@ class WideDeepModel(nn.Cell):
self.field_size = config.field_size
self.vocab_size = config.vocab_size
self.emb_dim = config.emb_dim
self.deep_layer_args = config.deep_layer_args
self.deep_layer_dims_list, self.deep_layer_act = self.deep_layer_args
self.deep_layer_dims_list = config.deep_layer_dim
self.deep_layer_act = config.deep_layer_act
self.init_args = config.init_args
self.weight_init, self.bias_init = config.weight_bias_init
self.weight_bias_init = config.weight_bias_init


+ 6
- 6
model_zoo/wide_and_deep/test.py View File

@@ -20,11 +20,11 @@ import os
from mindspore import Model, context
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from wide_deep.models.WideDeep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from wide_deep.utils.callbacks import LossCallBack, EvalCallBack
from wide_deep.data.datasets import create_dataset
from wide_deep.utils.metrics import AUCMetric
from tools.config import Config_WideDeep
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset
from src.metrics import AUCMetric
from src.config import WideDeepConfig

context.set_context(mode=context.GRAPH_MODE, device_target="Davinci",
save_graphs=True)
@@ -88,7 +88,7 @@ def test_eval(config):


if __name__ == "__main__":
widedeep_config = Config_WideDeep()
widedeep_config = WideDeepConfig()
widedeep_config.argparse_init()

test_eval(widedeep_config.widedeep)

+ 8
- 8
model_zoo/wide_and_deep/train.py View File

@@ -16,19 +16,19 @@ import os
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

from wide_deep.models.WideDeep import PredictWithSigmoid, TrainStepWarp, NetWithLossClass, WideDeepModel
from wide_deep.utils.callbacks import LossCallBack
from wide_deep.data.datasets import create_dataset
from tools.config import Config_WideDeep
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack
from src.datasets import create_dataset
from src.config import WideDeepConfig

context.set_context(model=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)


def get_WideDeep_net(configure):
WideDeep_net = WideDeepModel(configure)

loss_net = NetWithLossClass(WideDeep_net, configure)
train_net = TrainStepWarp(loss_net)
train_net = TrainStepWrap(loss_net)
eval_net = PredictWithSigmoid(WideDeep_net)

return train_net, eval_net
@@ -71,7 +71,7 @@ def test_train(configure):
train_net.set_train()

model = Model(train_net)
callback = LossCallBack(configure)
callback = LossCallBack(config=configure)
ckptconfig = CheckpointConfig(save_checkpoint_steps=1,
keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=configure.ckpt_path, config=ckptconfig)
@@ -79,7 +79,7 @@ def test_train(configure):


if __name__ == "__main__":
config = Config_WideDeep()
config = WideDeepConfig()
config.argparse_init()

test_train(config)

model_zoo/wide_and_deep/tools/train_and_test.py → model_zoo/wide_and_deep/train_and_test.py View File

@@ -17,11 +17,11 @@ import os
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

from wide_deep.models.WideDeep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from wide_deep.utils.callbacks import LossCallBack, EvalCallBack
from wide_deep.data.datasets import create_dataset
from wide_deep.utils.metrics import AUCMetric
from tools.config import Config_WideDeep
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset
from src.metrics import AUCMetric
from src.config import WideDeepConfig

context.set_context(mode=context.GRAPH_MODE, device_target="Davinci")

@@ -81,7 +81,7 @@ def test_train_eval(config):

eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

callback = LossCallBack()
callback = LossCallBack(config=config)
ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig)

@@ -91,7 +91,7 @@ def test_train_eval(config):


if __name__ == "__main__":
wide_deep_config = Config_WideDeep()
wide_deep_config = WideDeepConfig()
wide_deep_config.argparse_init()

test_train_eval(wide_deep_config)

+ 4
- 4
model_zoo/wide_and_deep/train_and_test_multinpu.py View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_imagenet."""
"""train_multinpu."""


import os
@@ -27,7 +27,7 @@ from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClas
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset
from src.metrics import AUCMetric
from src.config import Config_WideDeep
from src.config import WideDeepConfig

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
context.set_context(mode=GRAPH_MODE, device_target="Davinci", save_graph=True)
@@ -71,7 +71,7 @@ def test_train_eval():
test_train_eval
"""
np.random.seed(1000)
config = Config_WideDeep
config = WideDeepConfig
data_path = Config.data_path
batch_size = config.batch_size
epochs = config.epochs
@@ -93,7 +93,7 @@ def test_train_eval():

eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)

callback = LossCallBack(config)
callback = LossCallBack(config=config)
ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path, config=ckptconfig)


Loading…
Cancel
Save