| @@ -69,8 +69,8 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp | |||
| auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); | |||
| auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr); | |||
| auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); | |||
| const size_t thread_num = 8; | |||
| std::thread threads[8]; | |||
| const size_t thread_num = 16; | |||
| std::thread threads[16]; | |||
| size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num; | |||
| size_t i; | |||
| size_t task_offset = 0; | |||
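As a quick aside on the partitioning above: task_proc_lens is the ceiling of indices_lens_ / thread_num, so at most thread_num chunks are produced and the last chunk is clamped to whatever remains. A small Python sketch of that arithmetic, with an assumed index count:

# Hypothetical illustration of the work split used above (not the kernel itself).
indices_lens = 1000            # assumed number of lookup indices
thread_num = 16                # matches the new thread count
task_proc_lens = (indices_lens + thread_num - 1) // thread_num   # ceiling division
chunks = []
offset = 0
while offset < indices_lens:
    length = min(task_proc_lens, indices_lens - offset)
    chunks.append((offset, length))    # each tuple would become one thread's task
    offset += length
print(len(chunks), chunks[0], chunks[-1])   # 16 chunks; the last one is shorter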
| @@ -12,7 +12,7 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """train_imagenet.""" | |||
| """train_dataset.""" | |||
| import os | |||
| @@ -164,9 +164,6 @@ class WideDeepModel(nn.Cell): | |||
| init_acts = [('Wide_b', [1], self.emb_init)] | |||
| var_map = init_var_dict(self.init_args, init_acts) | |||
| self.wide_b = var_map["Wide_b"] | |||
| if parameter_server: | |||
| self.wide_w.set_param_ps() | |||
| self.embedding_table.set_param_ps() | |||
| self.dense_layer_1 = DenseLayer(self.all_dim_list[0], | |||
| self.all_dim_list[1], | |||
| self.weight_bias_init, | |||
| @@ -217,6 +214,8 @@ class WideDeepModel(nn.Cell): | |||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim) | |||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1) | |||
| self.embedding_table = self.deep_embeddinglookup.embedding_table | |||
| self.wide_w.set_param_ps() | |||
| self.embedding_table.set_param_ps() | |||
| else: | |||
| self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE') | |||
| self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE') | |||
| @@ -0,0 +1,3 @@ | |||
| numpy | |||
| pandas | |||
| # pickle ships with the Python standard library; no pip package is required | |||
| @@ -0,0 +1,34 @@ | |||
| #!/bin/bash | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| # bash run_multinpu_train.sh RANK_SIZE EPOCH_SIZE DATASET RANK_TABLE_FILE | |||
| execute_path=$(pwd) | |||
| script_self=$(readlink -f "$0") | |||
| self_path=$(dirname "${script_self}") | |||
| export RANK_SIZE=$1 | |||
| export EPOCH_SIZE=$2 | |||
| export DATASET=$3 | |||
| export RANK_TABLE_FILE=$4 | |||
| for((i=0;i<$RANK_SIZE;i++)); | |||
| do | |||
| rm -rf ${execute_path}/device_$i/ | |||
| mkdir ${execute_path}/device_$i/ | |||
| cd ${execute_path}/device_$i/ || exit | |||
| export RANK_ID=$i | |||
| export DEVICE_ID=$i | |||
| python -s ${self_path}/../train_and_eval_distribute.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 & | |||
| done | |||
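The script takes four positional arguments (rank size, epoch count, dataset path, rank table file) and launches one training process per device, each in its own device_$i working directory with RANK_ID and DEVICE_ID set. A hypothetical 8-device invocation would look roughly like: bash run_multinpu_train.sh 8 5 /path/to/dataset /path/to/rank_table_8p.json, with the actual paths depending on the local environment.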
| @@ -0,0 +1,96 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ | |||
| callbacks | |||
| """ | |||
| import time | |||
| from mindspore.train.callback import Callback | |||
| def add_write(file_path, out_str): | |||
| with open(file_path, 'a+', encoding="utf-8") as file_out: | |||
| file_out.write(out_str + "\n") | |||
| class LossCallBack(Callback): | |||
| """ | |||
| Monitor the loss in training. | |||
| If the loss is NAN or INF, terminate training. | |||
| Note: | |||
| If per_print_times is 0, the loss is not printed. | |||
| Args: | |||
| per_print_times (int): Print the loss every per_print_times steps. Default: 1. | |||
| """ | |||
| def __init__(self, config, per_print_times=1): | |||
| super(LossCallBack, self).__init__() | |||
| if not isinstance(per_print_times, int) or per_print_times < 0: | |||
| raise ValueError("print_step must be int and >= 0.") | |||
| self._per_print_times = per_print_times | |||
| self.config = config | |||
| def step_end(self, run_context): | |||
| """Monitor the loss in training.""" | |||
| cb_params = run_context.original_args() | |||
| wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), \ | |||
| cb_params.net_outputs[1].asnumpy() | |||
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||
| cur_num = cb_params.cur_step_num | |||
| print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss, flush=True) | |||
| if self._per_print_times != 0 and cur_num % self._per_print_times == 0: | |||
| loss_file = open(self.config.loss_file_name, "a+") | |||
| loss_file.write( | |||
| "epoch: %s step: %s, wide_loss is %s, deep_loss is %s" % | |||
| (cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, | |||
| deep_loss)) | |||
| loss_file.write("\n") | |||
| loss_file.close() | |||
| print("epoch: %s step: %s, wide_loss is %s, deep_loss is %s" % ( | |||
| cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, | |||
| deep_loss)) | |||
| class EvalCallBack(Callback): | |||
| """ | |||
| Monitor the AUC on the evaluation dataset during training. | |||
| Evaluation runs at the end of every training epoch. | |||
| Note: | |||
| If print_per_step is 0, the result is not printed. | |||
| Args: | |||
| print_per_step (int): Print the evaluation result every print_per_step steps. Default: 1. | |||
| """ | |||
| def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1): | |||
| super(EvalCallBack, self).__init__() | |||
| if not isinstance(print_per_step, int) or print_per_step < 0: | |||
| raise ValueError("print_step must be int and >= 0.") | |||
| self.print_per_step = print_per_step | |||
| self.model = model | |||
| self.eval_dataset = eval_dataset | |||
| self.aucMetric = auc_metric | |||
| self.aucMetric.clear() | |||
| self.eval_file_name = config.eval_file_name | |||
| def epoch_end(self, run_context): | |||
| """Monitor the auc in training.""" | |||
| self.aucMetric.clear() | |||
| start_time = time.time() | |||
| out = self.model.eval(self.eval_dataset) | |||
| end_time = time.time() | |||
| eval_time = int(end_time - start_time) | |||
| time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) | |||
| out_str = "{}=====EvalCallBack model.eval(): {} ; eval_time:{}s".format(time_str, out.values(), eval_time) | |||
| print(out_str) | |||
| add_write(self.eval_file_name, out_str) | |||
| @@ -0,0 +1,95 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ config. """ | |||
| import argparse | |||
| def argparse_init(): | |||
| """ | |||
| argparse_init | |||
| """ | |||
| parser = argparse.ArgumentParser(description='WideDeep') | |||
| parser.add_argument("--data_path", type=str, default="./test_raw_data/") # The location of the input data. | |||
| parser.add_argument("--epochs", type=int, default=200) # The number of epochs used to train. | |||
| parser.add_argument("--batch_size", type=int, default=131072) # Batch size for training and evaluation | |||
| parser.add_argument("--eval_batch_size", type=int, default=131072) # The batch size used for evaluation. | |||
| parser.add_argument("--deep_layers_dim", type=int, nargs='+', default=[1024, 512, 256, 128]) # The sizes of hidden layers for MLP | |||
| parser.add_argument("--deep_layers_act", type=str, default='relu') # The act of hidden layers for MLP | |||
| parser.add_argument("--keep_prob", type=float, default=1.0) # The Embedding size of MF model. | |||
| parser.add_argument("--adam_lr", type=float, default=0.003) # The Adam lr | |||
| parser.add_argument("--ftrl_lr", type=float, default=0.1) # The ftrl lr. | |||
| parser.add_argument("--l2_coef", type=float, default=0.0) # The l2 coefficient. | |||
| parser.add_argument("--is_tf_dataset", type=bool, default=True) # The l2 coefficient. | |||
| parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file. | |||
| parser.add_argument("--ckpt_path", type=str, default="./checkpoints/") # The location of the checkpoints file. | |||
| parser.add_argument("--eval_file_name", type=str, default="eval.log") # Eval output file. | |||
| parser.add_argument("--loss_file_name", type=str, default="loss.log") # Loss output file. | |||
| return parser | |||
| class WideDeepConfig(): | |||
| """ | |||
| WideDeepConfig | |||
| """ | |||
| def __init__(self): | |||
| self.data_path = '' | |||
| self.epochs = 200 | |||
| self.batch_size = 131072 | |||
| self.eval_batch_size = 131072 | |||
| self.deep_layers_act = 'relu' | |||
| self.weight_bias_init = ['normal', 'normal'] | |||
| self.emb_init = 'normal' | |||
| self.init_args = [-0.01, 0.01] | |||
| self.dropout_flag = False | |||
| self.keep_prob = 1.0 | |||
| self.l2_coef = 0.0 | |||
| self.adam_lr = 0.003 | |||
| self.ftrl_lr = 0.1 | |||
| self.is_tf_dataset = True | |||
| self.input_emb_dim = 0 | |||
| self.output_path = "./output/" | |||
| self.eval_file_name = "eval.log" | |||
| self.loss_file_name = "loss.log" | |||
| self.ckpt_path = "./checkpoints/" | |||
| def argparse_init(self): | |||
| """ | |||
| argparse_init | |||
| """ | |||
| parser = argparse_init() | |||
| args, _ = parser.parse_known_args() | |||
| self.data_path = args.data_path | |||
| self.epochs = args.epochs | |||
| self.batch_size = args.batch_size | |||
| self.eval_batch_size = args.eval_batch_size | |||
| self.deep_layers_act = args.deep_layers_act | |||
| self.keep_prob = args.keep_prob | |||
| self.weight_bias_init = ['normal', 'normal'] | |||
| self.emb_init = 'normal' | |||
| self.init_args = [-0.01, 0.01] | |||
| self.dropout_flag = False | |||
| self.l2_coef = args.l2_coef | |||
| self.ftrl_lr = args.ftrl_lr | |||
| self.adam_lr = args.adam_lr | |||
| self.is_tf_dataset = args.is_tf_dataset | |||
| self.output_path = args.output_path | |||
| self.eval_file_name = args.eval_file_name | |||
| self.loss_file_name = args.loss_file_name | |||
| self.ckpt_path = args.ckpt_path | |||
| @@ -0,0 +1,341 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """train_dataset.""" | |||
| import os | |||
| import math | |||
| import pickle | |||
| import numpy as np | |||
| import pandas as pd | |||
| import mindspore.dataset.engine as de | |||
| import mindspore.common.dtype as mstype | |||
| class H5Dataset(): | |||
| """ | |||
| H5Dataset | |||
| """ | |||
| input_length = 39 | |||
| def __init__(self, | |||
| data_path, | |||
| train_mode=True, | |||
| train_num_of_parts=21, | |||
| test_num_of_parts=3): | |||
| self._hdf_data_dir = data_path | |||
| self._is_training = train_mode | |||
| if self._is_training: | |||
| self._file_prefix = 'train' | |||
| self._num_of_parts = train_num_of_parts | |||
| else: | |||
| self._file_prefix = 'test' | |||
| self._num_of_parts = test_num_of_parts | |||
| self.data_size = self._bin_count(self._hdf_data_dir, self._file_prefix, | |||
| self._num_of_parts) | |||
| print("data_size: {}".format(self.data_size)) | |||
| def _bin_count(self, hdf_data_dir, file_prefix, num_of_parts): | |||
| size = 0 | |||
| for part in range(num_of_parts): | |||
| _y = pd.read_hdf( | |||
| os.path.join(hdf_data_dir, file_prefix + '_output_part_' + | |||
| str(part) + '.h5')) | |||
| size += _y.shape[0] | |||
| return size | |||
| def _iterate_hdf_files_(self, num_of_parts=None, shuffle_block=False): | |||
| """ | |||
| Iterate over the HDF file parts (blocks). When the whole data set has been consumed, the iterator | |||
| restarts from the beginning, so the data stream never stops. | |||
| :param num_of_parts: number of file parts | |||
| :param shuffle_block: if truthy, shuffle the block files at every round | |||
| :return: input_hdf_file_name, output_hdf_file_name, finish_flag | |||
| """ | |||
| parts = np.arange(num_of_parts) | |||
| while True: | |||
| if shuffle_block: | |||
| for _ in range(int(shuffle_block)): | |||
| np.random.shuffle(parts) | |||
| for i, p in enumerate(parts): | |||
| yield os.path.join(self._hdf_data_dir, | |||
| self._file_prefix + '_input_part_' + str( | |||
| p) + '.h5'), \ | |||
| os.path.join(self._hdf_data_dir, | |||
| self._file_prefix + '_output_part_' + str( | |||
| p) + '.h5'), i + 1 == len(parts) | |||
| def _generator(self, X, y, batch_size, shuffle=True): | |||
| """ | |||
| Intended for internal use only. | |||
| :param X: feature array | |||
| :param y: label array | |||
| :param batch_size: number of rows per batch | |||
| :param shuffle: if truthy, shuffle the sample index before batching | |||
| :return: generator yielding (X_batch, y_batch, finished) | |||
| """ | |||
| number_of_batches = np.ceil(1. * X.shape[0] / batch_size) | |||
| counter = 0 | |||
| finished = False | |||
| sample_index = np.arange(X.shape[0]) | |||
| if shuffle: | |||
| for _ in range(int(shuffle)): | |||
| np.random.shuffle(sample_index) | |||
| assert X.shape[0] > 0 | |||
| while True: | |||
| batch_index = sample_index[batch_size * counter:batch_size * | |||
| (counter + 1)] | |||
| X_batch = X[batch_index] | |||
| y_batch = y[batch_index] | |||
| counter += 1 | |||
| yield X_batch, y_batch, finished | |||
| if counter == number_of_batches: | |||
| counter = 0 | |||
| finished = True | |||
| def batch_generator(self, | |||
| batch_size=1000, | |||
| random_sample=False, | |||
| shuffle_block=False): | |||
| """ | |||
| :param batch_size: number of samples per batch | |||
| :param random_sample: if True, shuffle the samples within each file | |||
| :param shuffle_block: shuffle the file blocks at every round | |||
| :return: generator yielding (ids, weights, labels) batches | |||
| """ | |||
| for hdf_in, hdf_out, _ in self._iterate_hdf_files_( | |||
| self._num_of_parts, shuffle_block): | |||
| start = stop = None | |||
| X_all = pd.read_hdf(hdf_in, start=start, stop=stop).values | |||
| y_all = pd.read_hdf(hdf_out, start=start, stop=stop).values | |||
| data_gen = self._generator(X_all, | |||
| y_all, | |||
| batch_size, | |||
| shuffle=random_sample) | |||
| finished = False | |||
| while not finished: | |||
| X, y, finished = data_gen.__next__() | |||
| X_id = X[:, 0:self.input_length] | |||
| X_va = X[:, self.input_length:] | |||
| yield np.array(X_id.astype(dtype=np.int32)), np.array( | |||
| X_va.astype(dtype=np.float32)), np.array( | |||
| y.astype(dtype=np.float32)) | |||
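A minimal sketch, using random data rather than the HDF5 parts, of the column split performed by batch_generator above: the first input_length columns become the int32 id batch and the remaining columns the float32 weight batch.

import numpy as np

input_length = 39                          # matches H5Dataset.input_length
X = np.random.rand(4, input_length + 13)   # 4 hypothetical rows with 13 value columns
X_id = X[:, :input_length].astype(np.int32)
X_va = X[:, input_length:].astype(np.float32)
print(X_id.shape, X_va.shape)              # (4, 39) (4, 13)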
| def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000): | |||
| """ | |||
| _get_h5_dataset | |||
| """ | |||
| data_para = { | |||
| 'batch_size': batch_size, | |||
| } | |||
| if train_mode: | |||
| data_para['random_sample'] = True | |||
| data_para['shuffle_block'] = True | |||
| h5_dataset = H5Dataset(data_path=data_dir, train_mode=train_mode) | |||
| numbers_of_batch = math.ceil(h5_dataset.data_size / batch_size) | |||
| def _iter_h5_data(): | |||
| train_eval_gen = h5_dataset.batch_generator(**data_para) | |||
| for _ in range(0, numbers_of_batch, 1): | |||
| yield train_eval_gen.__next__() | |||
| ds = de.GeneratorDataset(_iter_h5_data(), | |||
| ["ids", "weights", "labels"]) | |||
| ds.set_dataset_size(numbers_of_batch) | |||
| ds = ds.repeat(epochs) | |||
| return ds | |||
| def _get_tf_dataset(data_dir, | |||
| schema_dict, | |||
| input_shape_dict, | |||
| train_mode=True, | |||
| epochs=1, | |||
| batch_size=4096, | |||
| line_per_sample=4096, | |||
| rank_size=None, | |||
| rank_id=None): | |||
| """ | |||
| _get_tf_dataset | |||
| """ | |||
| dataset_files = [] | |||
| file_prefix_name = 'train' if train_mode else 'eval' | |||
| shuffle = bool(train_mode) | |||
| for (dirpath, _, filenames) in os.walk(data_dir): | |||
| for filename in filenames: | |||
| if file_prefix_name in filename and "tfrecord" in filename: | |||
| dataset_files.append(os.path.join(dirpath, filename)) | |||
| schema = de.Schema() | |||
| float_key_list = ["label", "continue_val"] | |||
| columns_list = [] | |||
| for key, attr_dict in schema_dict.items(): | |||
| print("key: {}; shape: {}".format(key, attr_dict["tf_shape"])) | |||
| columns_list.append(key) | |||
| if key in set(float_key_list): | |||
| ms_dtype = mstype.float32 | |||
| else: | |||
| ms_dtype = mstype.int32 | |||
| schema.add_column(key, de_type=ms_dtype) | |||
| if rank_size is not None and rank_id is not None: | |||
| ds = de.TFRecordDataset(dataset_files=dataset_files, | |||
| shuffle=shuffle, | |||
| schema=schema, | |||
| num_parallel_workers=8, | |||
| num_shards=rank_size, | |||
| shard_id=rank_id, | |||
| shard_equal_rows=True) | |||
| else: | |||
| ds = de.TFRecordDataset(dataset_files=dataset_files, | |||
| shuffle=shuffle, | |||
| schema=schema, | |||
| num_parallel_workers=8) | |||
| ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True) | |||
| operations_list = [] | |||
| for key in columns_list: | |||
| operations_list.append(lambda x: np.array(x).flatten().reshape(input_shape_dict[key])) | |||
| print("ssssssssssssssssssssss---------------------" * 10) | |||
| print(input_shape_dict) | |||
| print("---------------------" * 10) | |||
| print(schema_dict) | |||
| def mixup(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u): | |||
| a = np.asarray(a.reshape(batch_size,)) | |||
| b = np.array(b).flatten().reshape(batch_size, -1) | |||
| c = np.array(c).flatten().reshape(batch_size, -1) | |||
| d = np.array(d).flatten().reshape(batch_size, -1) | |||
| e = np.array(e).flatten().reshape(batch_size, -1) | |||
| f = np.array(f).flatten().reshape(batch_size, -1) | |||
| g = np.array(g).flatten().reshape(batch_size, -1) | |||
| h = np.array(h).flatten().reshape(batch_size, -1) | |||
| i = np.array(i).flatten().reshape(batch_size, -1) | |||
| j = np.array(j).flatten().reshape(batch_size, -1) | |||
| k = np.array(k).flatten().reshape(batch_size, -1) | |||
| l = np.array(l).flatten().reshape(batch_size, -1) | |||
| m = np.array(m).flatten().reshape(batch_size, -1) | |||
| n = np.array(n).flatten().reshape(batch_size, -1) | |||
| o = np.array(o).flatten().reshape(batch_size, -1) | |||
| p = np.array(p).flatten().reshape(batch_size, -1) | |||
| q = np.array(q).flatten().reshape(batch_size, -1) | |||
| r = np.array(r).flatten().reshape(batch_size, -1) | |||
| s = np.array(s).flatten().reshape(batch_size, -1) | |||
| t = np.array(t).flatten().reshape(batch_size, -1) | |||
| u = np.array(u).flatten().reshape(batch_size, -1) | |||
| return a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u | |||
| ds = ds.map( | |||
| operations=mixup, | |||
| input_columns=[ | |||
| 'label', 'continue_val', 'indicator_id', 'emb_128_id', | |||
| 'emb_64_single_id', 'multi_doc_ad_category_id', | |||
| 'multi_doc_ad_category_id_mask', 'multi_doc_event_entity_id', | |||
| 'multi_doc_event_entity_id_mask', 'multi_doc_ad_entity_id', | |||
| 'multi_doc_ad_entity_id_mask', 'multi_doc_event_topic_id', | |||
| 'multi_doc_event_topic_id_mask', 'multi_doc_event_category_id', | |||
| 'multi_doc_event_category_id_mask', 'multi_doc_ad_topic_id', | |||
| 'multi_doc_ad_topic_id_mask', 'ad_id', 'display_ad_and_is_leak', | |||
| 'display_id', 'is_leak' | |||
| ], | |||
| columns_order=[ | |||
| 'label', 'continue_val', 'indicator_id', 'emb_128_id', | |||
| 'emb_64_single_id', 'multi_doc_ad_category_id', | |||
| 'multi_doc_ad_category_id_mask', 'multi_doc_event_entity_id', | |||
| 'multi_doc_event_entity_id_mask', 'multi_doc_ad_entity_id', | |||
| 'multi_doc_ad_entity_id_mask', 'multi_doc_event_topic_id', | |||
| 'multi_doc_event_topic_id_mask', 'multi_doc_event_category_id', | |||
| 'multi_doc_event_category_id_mask', 'multi_doc_ad_topic_id', | |||
| 'multi_doc_ad_topic_id_mask', 'display_id', 'ad_id', | |||
| 'display_ad_and_is_leak', 'is_leak' | |||
| ], | |||
| num_parallel_workers=8) | |||
| ds = ds.repeat(epochs) | |||
| return ds | |||
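For intuition on the batch arithmetic above: each TFRecord row is assumed to pack line_per_sample samples, so ds.batch(batch_size / line_per_sample) reads just enough rows per step, and the mixup map flattens every column back to per-sample shape. A tiny sketch with the default sizes:

# Hypothetical illustration of the row packing handled above.
batch_size = 131072        # samples per training step (the config default)
line_per_sample = 4096     # samples assumed to be packed into one TFRecord row
rows_per_step = int(batch_size / line_per_sample)
print(rows_per_step)       # 32 rows are batched; mixup then reshapes each column
                           # back to (batch_size, -1), i.e. 131072 samples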
| def compute_emb_dim(config): | |||
| """ | |||
| compute_emb_dim | |||
| """ | |||
| with open( | |||
| os.path.join(config.data_path + 'dataformat/', | |||
| "input_shape_dict.pkl"), "rb") as file_in: | |||
| input_shape_dict = pickle.load(file_in) | |||
| input_field_size = {} | |||
| for key, shape in input_shape_dict.items(): | |||
| if len(shape) < 2: | |||
| input_field_size[key] = 1 | |||
| else: | |||
| input_field_size[key] = shape[1] | |||
| multi_key_list = [ | |||
| "multi_doc_event_topic_id", "multi_doc_event_entity_id", | |||
| "multi_doc_ad_category_id", "multi_doc_event_category_id", | |||
| "multi_doc_ad_entity_id", "multi_doc_ad_topic_id" | |||
| ] | |||
| config.input_emb_dim = input_field_size["continue_val"] + \ | |||
| input_field_size["indicator_id"] * 64 + \ | |||
| input_field_size["emb_128_id"] * 128 + \ | |||
| input_field_size["emb_64_single_id"] * 64 + \ | |||
| len(multi_key_list) * 64 | |||
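A worked example of the input_emb_dim formula above, with made-up field sizes (the real ones come from input_shape_dict.pkl): each indicator and emb64 slot contributes 64 dimensions, each emb_128 slot 128, and the six multi_doc_* features contribute 64 each after pooling.

# Hypothetical field sizes; the real values are loaded from input_shape_dict.pkl.
field = {"continue_val": 32, "indicator_id": 16, "emb_128_id": 2, "emb_64_single_id": 5}
multi_fields = 6                      # the six multi_doc_* features, 64-dim each
input_emb_dim = (field["continue_val"]
                 + field["indicator_id"] * 64
                 + field["emb_128_id"] * 128
                 + field["emb_64_single_id"] * 64
                 + multi_fields * 64)
print(input_emb_dim)                  # 32 + 1024 + 256 + 320 + 384 = 2016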
| def create_dataset(data_dir, | |||
| train_mode=True, | |||
| epochs=1, | |||
| batch_size=4096, | |||
| is_tf_dataset=True, | |||
| line_per_sample=4096, | |||
| rank_size=None, | |||
| rank_id=None): | |||
| """ | |||
| create_dataset | |||
| """ | |||
| if is_tf_dataset: | |||
| with open(os.path.join(data_dir + 'dataformat/', "schema_dict.pkl"), | |||
| "rb") as file_in: | |||
| print(os.path.join(data_dir + 'dataformat/', "schema_dict.pkl")) | |||
| schema_dict = pickle.load(file_in) | |||
| with open( | |||
| os.path.join(data_dir + 'dataformat/', "input_shape_dict.pkl"), | |||
| "rb") as file_in: | |||
| input_shape_dict = pickle.load(file_in) | |||
| return _get_tf_dataset(data_dir, | |||
| schema_dict, | |||
| input_shape_dict, | |||
| train_mode, | |||
| epochs, | |||
| batch_size, | |||
| line_per_sample, | |||
| rank_size=rank_size, | |||
| rank_id=rank_id) | |||
| if rank_size is not None and rank_size > 1: | |||
| raise RuntimeError("please use tfrecord dataset.") | |||
| return _get_h5_dataset(data_dir, train_mode, epochs, batch_size) | |||
| @@ -0,0 +1,153 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ | |||
| Area under the curve (AUC) metric | |||
| """ | |||
| import time | |||
| import numpy as np | |||
| import pandas as pd | |||
| from sklearn.metrics import roc_auc_score, average_precision_score | |||
| from mindspore.nn.metrics import Metric | |||
| def groupby_df_v1(test_df, gb_key): | |||
| """ | |||
| groupby_df_v1 | |||
| """ | |||
| data_groups = test_df.groupby(gb_key) | |||
| return data_groups | |||
| def _compute_metric_v1(batch_groups, topk): | |||
| """ | |||
| _compute_metric_v1 | |||
| """ | |||
| results = [] | |||
| for df in batch_groups: | |||
| df = df.sort_values(by="preds", ascending=False) | |||
| if df.shape[0] > topk: | |||
| df = df.head(topk) | |||
| preds = df["preds"].values | |||
| labels = df["labels"].values | |||
| if np.sum(labels) > 0: | |||
| results.append(average_precision_score(labels, preds)) | |||
| else: | |||
| results.append(0.0) | |||
| return results | |||
| def mean_AP_topk(batch_labels, batch_preds, topk=12): | |||
| """ | |||
| mean_AP_topk | |||
| """ | |||
| def ap_score(label, y_preds, topk): | |||
| ind_list = np.argsort(y_preds)[::-1] | |||
| ind_list = ind_list[:topk] | |||
| if label not in set(ind_list): | |||
| return 0.0 | |||
| rank = list(ind_list).index(label) | |||
| return 1.0 / (rank + 1) | |||
| mAP_list = [] | |||
| for label, preds in zip(batch_labels, batch_preds): | |||
| mAP = ap_score(label, preds, topk) | |||
| mAP_list.append(mAP) | |||
| return mAP_list | |||
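A tiny worked example of the ap_score helper above: the score is the reciprocal rank of the true label within the top-k predictions, and 0 when the label falls outside the top-k.

import numpy as np

def ap_score(label, y_preds, topk=12):
    # same logic as the inner helper of mean_AP_topk above
    ind_list = np.argsort(y_preds)[::-1][:topk]
    if label not in set(ind_list):
        return 0.0
    return 1.0 / (list(ind_list).index(label) + 1)

print(ap_score(2, [0.1, 0.9, 0.8, 0.05]))  # index 2 is ranked second -> 0.5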
| def new_compute_mAP(test_df, gb_key="display_ids", top_k=12): | |||
| """ | |||
| new_compute_mAP | |||
| """ | |||
| total_start = time.time() | |||
| display_ids = test_df["display_ids"] | |||
| labels = test_df["labels"] | |||
| predictions = test_df["preds"] | |||
| test_df.sort_values(by=[gb_key], inplace=True, ascending=True) | |||
| display_ids = test_df["display_ids"] | |||
| labels = test_df["labels"] | |||
| predictions = test_df["preds"] | |||
| _, display_ids_idx = np.unique(display_ids, return_index=True) | |||
| preds = np.split(predictions.tolist(), display_ids_idx.tolist()[1:]) | |||
| labels = np.split(labels.tolist(), display_ids_idx.tolist()[1:]) | |||
| def pad_fn(ele_l): | |||
| res_list = ele_l + [0.0 for i in range(30 - len(ele_l))] | |||
| return res_list | |||
| preds = list(map(lambda x: pad_fn(x.tolist()), preds)) | |||
| labels = [np.argmax(l) for l in labels] | |||
| result_list = [] | |||
| batch_size = 100000 | |||
| for idx in range(0, len(labels), batch_size): | |||
| batch_labels = labels[idx:idx + batch_size] | |||
| batch_preds = preds[idx:idx + batch_size] | |||
| meanAP = mean_AP_topk(batch_labels, batch_preds, topk=top_k) | |||
| result_list.extend(meanAP) | |||
| mean_AP = np.mean(result_list) | |||
| print("compute time: {}".format(time.time() - total_start)) | |||
| print("mean_AP: {}".format(mean_AP)) | |||
| return mean_AP | |||
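The grouping trick above relies on np.unique(return_index=True) over the sorted display_ids followed by np.split; a minimal sketch with made-up ids:

import numpy as np

# Hypothetical illustration of the group-by-display_id split used above.
display_ids = np.array([1, 1, 2, 2, 2, 3])        # already sorted, as in new_compute_mAP
preds = np.array([0.9, 0.1, 0.3, 0.8, 0.2, 0.5])
_, idx = np.unique(display_ids, return_index=True)
groups = np.split(preds, idx[1:])
print([g.tolist() for g in groups])               # [[0.9, 0.1], [0.3, 0.8, 0.2], [0.5]]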
| class AUCMetric(Metric): | |||
| """ | |||
| AUCMetric | |||
| """ | |||
| def __init__(self): | |||
| super(AUCMetric, self).__init__() | |||
| self.index = 1 | |||
| def clear(self): | |||
| """Clear the internal evaluation result.""" | |||
| self.true_labels = [] | |||
| self.pred_probs = [] | |||
| self.display_id = [] | |||
| def update(self, *inputs): | |||
| """ | |||
| update | |||
| """ | |||
| all_predict = inputs[1].asnumpy() # predict | |||
| all_label = inputs[2].asnumpy() # label | |||
| all_display_id = inputs[3].asnumpy() # display_id | |||
| self.true_labels.extend(all_label.flatten().tolist()) | |||
| self.pred_probs.extend(all_predict.flatten().tolist()) | |||
| self.display_id.extend(all_display_id.flatten().tolist()) | |||
| def eval(self): | |||
| """ | |||
| eval | |||
| """ | |||
| if len(self.true_labels) != len(self.pred_probs): | |||
| raise RuntimeError( | |||
| 'true_labels.size() is not equal to pred_probs.size()') | |||
| result_df = pd.DataFrame({ | |||
| "display_ids": self.display_id, | |||
| "preds": self.pred_probs, | |||
| "labels": self.true_labels, | |||
| }) | |||
| auc = roc_auc_score(self.true_labels, self.pred_probs) | |||
| MAP = new_compute_mAP(result_df, gb_key="display_ids", top_k=12) | |||
| print("=====" * 20 + " auc_metric end ") | |||
| print("=====" * 20 + " auc: {}, map: {}".format(auc, MAP)) | |||
| return auc | |||
| @@ -0,0 +1,638 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """wide and deep model""" | |||
| import numpy as np | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore import nn | |||
| from mindspore import Tensor, Parameter, ParameterTuple | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from mindspore.nn import Dropout, Flatten | |||
| from mindspore.nn.optim import Adam, FTRL | |||
| from mindspore.common.initializer import Uniform, initializer | |||
| from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean | |||
| from mindspore.train.parallel_utils import ParallelMode | |||
| from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||
| np_type = np.float32 | |||
| ms_type = mstype.float32 | |||
| def init_method(method, shape, name, max_val=1.0): | |||
| """ | |||
| Init method | |||
| """ | |||
| if method in ['uniform']: | |||
| params = Parameter(initializer(Uniform(max_val), shape, ms_type), | |||
| name=name) | |||
| elif method == "one": | |||
| params = Parameter(initializer("ones", shape, ms_type), name=name) | |||
| elif method == 'zero': | |||
| params = Parameter(initializer("zeros", shape, ms_type), name=name) | |||
| elif method == "normal": | |||
| params = Parameter(Tensor( | |||
| np.random.normal(loc=0.0, scale=0.01, | |||
| size=shape).astype(dtype=np_type)), | |||
| name=name) | |||
| return params | |||
| def init_var_dict(init_args, in_vars): | |||
| """ | |||
| Init parameters by dict | |||
| """ | |||
| var_map = {} | |||
| _, _max_val = init_args | |||
| for _, iterm in enumerate(in_vars): | |||
| key, shape, method = iterm | |||
| if key not in var_map.keys(): | |||
| if method in ['random', 'uniform']: | |||
| var_map[key] = Parameter(initializer(Uniform(_max_val), shape, | |||
| ms_type), | |||
| name=key) | |||
| elif method == "one": | |||
| var_map[key] = Parameter(initializer("ones", shape, ms_type), | |||
| name=key) | |||
| elif method == "zero": | |||
| var_map[key] = Parameter(initializer("zeros", shape, ms_type), | |||
| name=key) | |||
| elif method == 'normal': | |||
| var_map[key] = Parameter(Tensor( | |||
| np.random.normal(loc=0.0, scale=0.01, | |||
| size=shape).astype(dtype=np_type)), | |||
| name=key) | |||
| return var_map | |||
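A minimal usage sketch of init_var_dict, mirroring the Wide_b initialisation in the model hunk earlier in this change; it assumes the function above and MindSpore are importable in the current scope:

# Usage sketch only; init_var_dict is the function defined above.
init_args = [-0.01, 0.01]                 # only the upper bound is used for 'uniform'
init_acts = [('Wide_b', [1], 'normal')]   # (name, shape, init method)
var_map = init_var_dict(init_args, init_acts)
wide_b = var_map['Wide_b']                # a Parameter of shape [1]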
| class DenseLayer(nn.Cell): | |||
| """ | |||
| Dense layer for the deep part of the Wide&Deep model. | |||
| Contains: activation, matmul and bias_add. | |||
| """ | |||
| def __init__(self, | |||
| input_dim, | |||
| output_dim, | |||
| weight_bias_init, | |||
| act_str, | |||
| keep_prob=0.7, | |||
| scale_coef=1.0, | |||
| convert_dtype=True): | |||
| super(DenseLayer, self).__init__() | |||
| weight_init, bias_init = weight_bias_init | |||
| self.weight = init_method(weight_init, [input_dim, output_dim], | |||
| name="weight") | |||
| self.bias = init_method(bias_init, [output_dim], name="bias") | |||
| self.act_func = self._init_activation(act_str) | |||
| self.matmul = P.MatMul(transpose_b=False) | |||
| self.bias_add = P.BiasAdd() | |||
| self.cast = P.Cast() | |||
| self.dropout = Dropout(keep_prob=0.8) | |||
| self.mul = P.Mul() | |||
| self.realDiv = P.RealDiv() | |||
| self.scale_coef = scale_coef | |||
| self.convert_dtype = convert_dtype | |||
| def _init_activation(self, act_str): | |||
| act_str = act_str.lower() | |||
| if act_str == "relu": | |||
| act_func = P.ReLU() | |||
| elif act_str == "sigmoid": | |||
| act_func = P.Sigmoid() | |||
| elif act_str == "tanh": | |||
| act_func = P.Tanh() | |||
| return act_func | |||
| def construct(self, x): | |||
| """ | |||
| DenseLayer construct | |||
| """ | |||
| x = self.act_func(x) | |||
| if self.training: | |||
| x = self.dropout(x) | |||
| x = self.mul(x, self.scale_coef) | |||
| if self.convert_dtype: | |||
| x = self.cast(x, mstype.float16) | |||
| weight = self.cast(self.weight, mstype.float16) | |||
| wx = self.matmul(x, weight) | |||
| wx = self.cast(wx, mstype.float32) | |||
| else: | |||
| wx = self.matmul(x, self.weight) | |||
| wx = self.realDiv(wx, self.scale_coef) | |||
| output = self.bias_add(wx, self.bias) | |||
| return output | |||
| class WideDeepModel(nn.Cell): | |||
| """ | |||
| From paper: " Wide & Deep Learning for Recommender Systems" | |||
| Args: | |||
| config (Class): The default config of Wide&Deep | |||
| """ | |||
| def __init__(self, config): | |||
| super(WideDeepModel, self).__init__() | |||
| emb_128_size = 650000 | |||
| emb64_single_size = 17300 | |||
| emb64_multi_size = 20900 | |||
| indicator_size = 16 | |||
| deep_dim_list = [1024, 1024, 1024, 1024, 1024] | |||
| # deep_dropout=0.0 | |||
| wide_reg_coef = [0.0, 0.0] | |||
| deep_reg_coef = [0.0, 0.0] | |||
| wide_lr = 0.2 | |||
| deep_lr = 1.0 | |||
| self.input_emb_dim = config.input_emb_dim | |||
| self.batch_size = config.batch_size | |||
| self.deep_layer_act = config.deep_layers_act | |||
| self.init_args = config.init_args | |||
| self.weight_init, self.bias_init = config.weight_bias_init | |||
| self.weight_bias_init = config.weight_bias_init | |||
| self.emb_init = config.emb_init | |||
| self.keep_prob = config.keep_prob | |||
| self.layer_dims = deep_dim_list + [1] | |||
| self.all_dim_list = [self.input_emb_dim] + self.layer_dims | |||
| self.continue_field_size = 32 | |||
| self.emb_128_size = emb_128_size | |||
| self.emb64_single_size = emb64_single_size | |||
| self.emb64_multi_size = emb64_multi_size | |||
| self.indicator_size = indicator_size | |||
| self.wide_l1_coef, self.wide_l2_coef = wide_reg_coef | |||
| self.deep_l1_coef, self.deep_l2_coef = deep_reg_coef | |||
| self.wide_lr = wide_lr | |||
| self.deep_lr = deep_lr | |||
| init_acts_embedding_metrix = [ | |||
| ('emb128_embedding', [self.emb_128_size, 128], self.emb_init), | |||
| ('emb64_single', [self.emb64_single_size, 64], self.emb_init), | |||
| ('emb64_multi', [self.emb64_multi_size, 64], self.emb_init), | |||
| ('emb64_indicator', [self.indicator_size, 64], self.emb_init) | |||
| ] | |||
| var_map = init_var_dict(self.init_args, init_acts_embedding_metrix) | |||
| self.emb128_embedding = var_map["emb128_embedding"] | |||
| self.emb64_single = var_map["emb64_single"] | |||
| self.emb64_multi = var_map["emb64_multi"] | |||
| self.emb64_indicator = var_map["emb64_indicator"] | |||
| init_acts_wide_weight = [ | |||
| ('wide_continue_w', [self.continue_field_size], self.emb_init), | |||
| ('wide_emb128_w', [self.emb_128_size], self.emb_init), | |||
| ('wide_emb64_single_w', [self.emb64_single_size], self.emb_init), | |||
| ('wide_emb64_multi_w', [self.emb64_multi_size], self.emb_init), | |||
| ('wide_indicator_w', [self.indicator_size], self.emb_init), | |||
| ('wide_bias', [1], self.emb_init) | |||
| ] | |||
| var_map = init_var_dict(self.init_args, init_acts_wide_weight) | |||
| self.wide_continue_w = var_map["wide_continue_w"] | |||
| self.wide_emb128_w = var_map["wide_emb128_w"] | |||
| self.wide_emb64_single_w = var_map["wide_emb64_single_w"] | |||
| self.wide_emb64_multi_w = var_map["wide_emb64_multi_w"] | |||
| self.wide_indicator_w = var_map["wide_indicator_w"] | |||
| self.wide_bias = var_map["wide_bias"] | |||
| self.dense_layer_1 = DenseLayer(self.all_dim_list[0], | |||
| self.all_dim_list[1], | |||
| self.weight_bias_init, | |||
| self.deep_layer_act, | |||
| convert_dtype=True) | |||
| self.dense_layer_2 = DenseLayer(self.all_dim_list[1], | |||
| self.all_dim_list[2], | |||
| self.weight_bias_init, | |||
| self.deep_layer_act, | |||
| convert_dtype=True) | |||
| self.dense_layer_3 = DenseLayer(self.all_dim_list[2], | |||
| self.all_dim_list[3], | |||
| self.weight_bias_init, | |||
| self.deep_layer_act, | |||
| convert_dtype=True) | |||
| self.dense_layer_4 = DenseLayer(self.all_dim_list[3], | |||
| self.all_dim_list[4], | |||
| self.weight_bias_init, | |||
| self.deep_layer_act, | |||
| convert_dtype=True) | |||
| self.dense_layer_5 = DenseLayer(self.all_dim_list[4], | |||
| self.all_dim_list[5], | |||
| self.weight_bias_init, | |||
| self.deep_layer_act, | |||
| convert_dtype=True) | |||
| self.deep_predict = DenseLayer(self.all_dim_list[5], | |||
| self.all_dim_list[6], | |||
| self.weight_bias_init, | |||
| self.deep_layer_act, | |||
| convert_dtype=True) | |||
| self.gather_v2 = P.GatherV2() | |||
| self.mul = P.Mul() | |||
| self.reduce_sum_false = P.ReduceSum(keep_dims=False) | |||
| self.reduce_sum_true = P.ReduceSum(keep_dims=True) | |||
| self.reshape = P.Reshape() | |||
| self.square = P.Square() | |||
| self.shape = P.Shape() | |||
| self.tile = P.Tile() | |||
| self.concat = P.Concat(axis=1) | |||
| self.cast = P.Cast() | |||
| self.reduceMean_false = P.ReduceMean(keep_dims=False) | |||
| self.Concat = P.Concat(axis=1) | |||
| self.BiasAdd = P.BiasAdd() | |||
| self.expand_dims = P.ExpandDims() | |||
| self.flatten = Flatten() | |||
| def construct(self, continue_val, indicator_id, emb_128_id, | |||
| emb_64_single_id, multi_doc_ad_category_id, | |||
| multi_doc_ad_category_id_mask, multi_doc_event_entity_id, | |||
| multi_doc_event_entity_id_mask, multi_doc_ad_entity_id, | |||
| multi_doc_ad_entity_id_mask, multi_doc_event_topic_id, | |||
| multi_doc_event_topic_id_mask, multi_doc_event_category_id, | |||
| multi_doc_event_category_id_mask, multi_doc_ad_topic_id, | |||
| multi_doc_ad_topic_id_mask, display_id, ad_id, | |||
| display_ad_and_is_leak, is_leak): | |||
| """ | |||
| Args: | |||
| continue_val: continuous feature values; | |||
| indicator_id, emb_128_id, emb_64_single_id: single-value id features; | |||
| multi_doc_*_id and multi_doc_*_id_mask: multi-value id features and their masks; | |||
| display_id, ad_id, display_ad_and_is_leak, is_leak: extra columns carried through the pipeline, not used in the forward computation. | |||
| """ | |||
| val_hldr = continue_val | |||
| ind_hldr = indicator_id | |||
| emb128_id_hldr = emb_128_id | |||
| emb64_single_hldr = emb_64_single_id | |||
| ind_emb = self.gather_v2(self.emb64_indicator, ind_hldr, 0) | |||
| ind_emb = self.flatten(ind_emb) | |||
| emb128_id_emb = self.gather_v2(self.emb128_embedding, emb128_id_hldr, | |||
| 0) | |||
| emb128_id_emb = self.flatten(emb128_id_emb) | |||
| emb64_sgl_emb = self.gather_v2(self.emb64_single, emb64_single_hldr, 0) | |||
| emb64_sgl_emb = self.flatten(emb64_sgl_emb) | |||
| mult_emb_1 = self.gather_v2(self.emb64_multi, multi_doc_ad_category_id, | |||
| 0) | |||
| mult_emb_1 = self.mul( | |||
| self.cast(mult_emb_1, mstype.float32), | |||
| self.cast(self.expand_dims(multi_doc_ad_category_id_mask, 2), | |||
| mstype.float32)) | |||
| mult_emb_1 = self.reduceMean_false(mult_emb_1, 1) | |||
| mult_emb_2 = self.gather_v2(self.emb64_multi, | |||
| multi_doc_event_entity_id, 0) | |||
| mult_emb_2 = self.mul( | |||
| self.cast(mult_emb_2, mstype.float32), | |||
| self.cast(self.expand_dims(multi_doc_event_entity_id_mask, 2), | |||
| mstype.float32)) | |||
| mult_emb_2 = self.reduceMean_false(mult_emb_2, 1) | |||
| mult_emb_3 = self.gather_v2(self.emb64_multi, multi_doc_ad_entity_id, | |||
| 0) | |||
| mult_emb_3 = self.mul( | |||
| self.cast(mult_emb_3, mstype.float32), | |||
| self.cast(self.expand_dims(multi_doc_ad_entity_id_mask, 2), | |||
| mstype.float32)) | |||
| mult_emb_3 = self.reduceMean_false(mult_emb_3, 1) | |||
| mult_emb_4 = self.gather_v2(self.emb64_multi, multi_doc_event_topic_id, | |||
| 0) | |||
| mult_emb_4 = self.mul( | |||
| self.cast(mult_emb_4, mstype.float32), | |||
| self.cast(self.expand_dims(multi_doc_event_topic_id_mask, 2), | |||
| mstype.float32)) | |||
| mult_emb_4 = self.reduceMean_false(mult_emb_4, 1) | |||
| mult_emb_5 = self.gather_v2(self.emb64_multi, | |||
| multi_doc_event_category_id, 0) | |||
| mult_emb_5 = self.mul( | |||
| self.cast(mult_emb_5, mstype.float32), | |||
| self.cast(self.expand_dims(multi_doc_event_category_id_mask, 2), | |||
| mstype.float32)) | |||
| mult_emb_5 = self.reduceMean_false(mult_emb_5, 1) | |||
| mult_emb_6 = self.gather_v2(self.emb64_multi, multi_doc_ad_topic_id, 0) | |||
| mult_emb_6 = self.mul( | |||
| self.cast(mult_emb_6, mstype.float32), | |||
| self.cast(self.expand_dims(multi_doc_ad_topic_id_mask, 2), | |||
| mstype.float32)) | |||
| mult_emb_6 = self.reduceMean_false(mult_emb_6, 1) | |||
| mult_embedding = self.Concat((mult_emb_1, mult_emb_2, mult_emb_3, | |||
| mult_emb_4, mult_emb_5, mult_emb_6)) | |||
| input_embedding = self.Concat((val_hldr * 1, ind_emb, emb128_id_emb, | |||
| emb64_sgl_emb, mult_embedding)) | |||
| deep_out = self.dense_layer_1(input_embedding) | |||
| deep_out = self.dense_layer_2(deep_out) | |||
| deep_out = self.dense_layer_3(deep_out) | |||
| deep_out = self.dense_layer_4(deep_out) | |||
| deep_out = self.dense_layer_5(deep_out) | |||
| deep_out = self.deep_predict(deep_out) | |||
| val_weight = self.mul(val_hldr, | |||
| self.expand_dims(self.wide_continue_w, 0)) | |||
| val_w_sum = self.reduce_sum_true(val_weight, 1) | |||
| ind_weight = self.gather_v2(self.wide_indicator_w, ind_hldr, 0) | |||
| ind_w_sum = self.reduce_sum_true(ind_weight, 1) | |||
| emb128_id_weight = self.gather_v2(self.wide_emb128_w, emb128_id_hldr, | |||
| 0) | |||
| emb128_w_sum = self.reduce_sum_true(emb128_id_weight, 1) | |||
| emb64_sgl_weight = self.gather_v2(self.wide_emb64_single_w, | |||
| emb64_single_hldr, 0) | |||
| emb64_w_sum = self.reduce_sum_true(emb64_sgl_weight, 1) | |||
| mult_weight_1 = self.gather_v2(self.wide_emb64_multi_w, | |||
| multi_doc_ad_category_id, 0) | |||
| mult_weight_1 = self.mul( | |||
| self.cast(mult_weight_1, mstype.float32), | |||
| self.cast(multi_doc_ad_category_id_mask, mstype.float32)) | |||
| mult_weight_1 = self.reduce_sum_true(mult_weight_1, 1) | |||
| mult_weight_2 = self.gather_v2(self.wide_emb64_multi_w, | |||
| multi_doc_event_entity_id, 0) | |||
| mult_weight_2 = self.mul( | |||
| self.cast(mult_weight_2, mstype.float32), | |||
| self.cast(multi_doc_event_entity_id_mask, mstype.float32)) | |||
| mult_weight_2 = self.reduce_sum_true(mult_weight_2, 1) | |||
| mult_weight_3 = self.gather_v2(self.wide_emb64_multi_w, | |||
| multi_doc_ad_entity_id, 0) | |||
| mult_weight_3 = self.mul( | |||
| self.cast(mult_weight_3, mstype.float32), | |||
| self.cast(multi_doc_ad_entity_id_mask, mstype.float32)) | |||
| mult_weight_3 = self.reduce_sum_true(mult_weight_3, 1) | |||
| mult_weight_4 = self.gather_v2(self.wide_emb64_multi_w, | |||
| multi_doc_event_topic_id, 0) | |||
| mult_weight_4 = self.mul( | |||
| self.cast(mult_weight_4, mstype.float32), | |||
| self.cast(multi_doc_event_topic_id_mask, mstype.float32)) | |||
| mult_weight_4 = self.reduce_sum_true(mult_weight_4, 1) | |||
| mult_weight_5 = self.gather_v2(self.wide_emb64_multi_w, | |||
| multi_doc_event_category_id, 0) | |||
| mult_weight_5 = self.mul( | |||
| self.cast(mult_weight_5, mstype.float32), | |||
| self.cast(multi_doc_event_category_id_mask, mstype.float32)) | |||
| mult_weight_5 = self.reduce_sum_true(mult_weight_5, 1) | |||
| mult_weight_6 = self.gather_v2(self.wide_emb64_multi_w, | |||
| multi_doc_ad_topic_id, 0) | |||
| mult_weight_6 = self.mul( | |||
| self.cast(mult_weight_6, mstype.float32), | |||
| self.cast(multi_doc_ad_topic_id_mask, mstype.float32)) | |||
| mult_weight_6 = self.reduce_sum_true(mult_weight_6, 1) | |||
| mult_weight_sum = mult_weight_1 + mult_weight_2 + mult_weight_3 + mult_weight_4 + mult_weight_5 + mult_weight_6 | |||
| wide_out = self.BiasAdd( | |||
| val_w_sum + ind_w_sum + emb128_w_sum + emb64_w_sum + | |||
| mult_weight_sum, self.wide_bias) | |||
| out = wide_out + deep_out | |||
| return out, self.emb128_embedding, self.emb64_single, self.emb64_multi | |||
| class NetWithLossClass(nn.Cell): | |||
| """" | |||
| Provide WideDeep training loss through network. | |||
| Args: | |||
| network (Cell): The training network | |||
| config (Class): WideDeep config | |||
| """ | |||
| def __init__(self, network, config): | |||
| super(NetWithLossClass, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.l2_coef = config.l2_coef | |||
| self.loss = P.SigmoidCrossEntropyWithLogits() | |||
| self.square = P.Square() | |||
| self.reduceMean_false = P.ReduceMean(keep_dims=False) | |||
| self.reduceSum_false = P.ReduceSum(keep_dims=False) | |||
| self.reshape = P.Reshape() | |||
| def construct(self, label, continue_val, indicator_id, emb_128_id, | |||
| emb_64_single_id, multi_doc_ad_category_id, | |||
| multi_doc_ad_category_id_mask, multi_doc_event_entity_id, | |||
| multi_doc_event_entity_id_mask, multi_doc_ad_entity_id, | |||
| multi_doc_ad_entity_id_mask, multi_doc_event_topic_id, | |||
| multi_doc_event_topic_id_mask, multi_doc_event_category_id, | |||
| multi_doc_event_category_id_mask, multi_doc_ad_topic_id, | |||
| multi_doc_ad_topic_id_mask, display_id, ad_id, | |||
| display_ad_and_is_leak, is_leak): | |||
| """ | |||
| NetWithLossClass construct | |||
| """ | |||
| # emb128_embedding, emb64_single, emb64_multi | |||
| predict, _, _, _ = self.network( | |||
| continue_val, indicator_id, emb_128_id, emb_64_single_id, | |||
| multi_doc_ad_category_id, multi_doc_ad_category_id_mask, | |||
| multi_doc_event_entity_id, multi_doc_event_entity_id_mask, | |||
| multi_doc_ad_entity_id, multi_doc_ad_entity_id_mask, | |||
| multi_doc_event_topic_id, multi_doc_event_topic_id_mask, | |||
| multi_doc_event_category_id, multi_doc_event_category_id_mask, | |||
| multi_doc_ad_topic_id, multi_doc_ad_topic_id_mask, display_id, | |||
| ad_id, display_ad_and_is_leak, is_leak) | |||
| predict = self.reshape(predict, (-1,)) | |||
| basic_loss = self.loss(predict, label) | |||
| wide_loss = self.reduceMean_false(basic_loss) | |||
| deep_loss = self.reduceMean_false(basic_loss) | |||
| return wide_loss, deep_loss | |||
| class IthOutputCell(nn.Cell): | |||
| """ | |||
| IthOutputCell | |||
| """ | |||
| def __init__(self, network, output_index): | |||
| super(IthOutputCell, self).__init__() | |||
| self.network = network | |||
| self.output_index = output_index | |||
| def construct(self, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, | |||
| x14, x15, x16, x17, x18, x19, x20, x21): | |||
| """ | |||
| IthOutputCell construct | |||
| """ | |||
| predict = self.network(x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, | |||
| x12, x13, x14, x15, x16, x17, x18, x19, x20, | |||
| x21)[self.output_index] | |||
| return predict | |||
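IthOutputCell just forwards to the wrapped network and selects one of its outputs, so GradOperation can differentiate the wide loss and the deep loss separately. A plain-Python sketch of the same selection pattern, with hypothetical stand-in functions:

# Plain-Python sketch of the output-selection pattern (hypothetical, not MindSpore code).
def two_losses(x):
    return x * 2.0, x * 3.0              # stands in for (wide_loss, deep_loss)

def ith_output(fn, index):
    def wrapped(*args):
        return fn(*args)[index]          # pick a single output, as IthOutputCell does
    return wrapped

loss_net_w = ith_output(two_losses, 0)
loss_net_d = ith_output(two_losses, 1)
print(loss_net_w(1.0), loss_net_d(1.0))  # 2.0 3.0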
| class TrainStepWrap(nn.Cell): | |||
| """ | |||
| Encapsulation class of WideDeep network training. | |||
| Append the Adam and FTRL optimizers to the training network. After that, the construct | |||
| function can be called to create the backward graph. | |||
| Args: | |||
| network (Cell): the training network. Note that the loss function should have been added. | |||
| config (Class): WideDeep config. | |||
| sens (Number): the loss scaling coefficient. Default: 1000.0 | |||
| """ | |||
| def __init__(self, network, config, sens=1000.0): | |||
| super(TrainStepWrap, self).__init__() | |||
| self.network = network | |||
| self.network.set_train() | |||
| self.trainable_params = network.trainable_params() | |||
| weights_w = [] | |||
| weights_d = [] | |||
| for params in self.trainable_params: | |||
| if 'wide' in params.name: | |||
| weights_w.append(params) | |||
| else: | |||
| weights_d.append(params) | |||
| self.weights_w = ParameterTuple(weights_w) | |||
| self.weights_d = ParameterTuple(weights_d) | |||
| self.optimizer_w = FTRL(learning_rate=config.ftrl_lr, | |||
| params=self.weights_w, | |||
| l1=5e-4, | |||
| l2=5e-4, | |||
| initial_accum=0.1, | |||
| loss_scale=sens) | |||
| #self.optimizer_d = ProximalAdagrad(self.weights_d, learning_rate=config.adam_lr,loss_scale=sens) | |||
| self.optimizer_d = Adam(self.weights_d, | |||
| learning_rate=config.adam_lr, | |||
| eps=1e-6, | |||
| loss_scale=sens) | |||
| self.hyper_map = C.HyperMap() | |||
| self.grad_w = C.GradOperation('grad_w', | |||
| get_by_list=True, | |||
| sens_param=True) | |||
| self.grad_d = C.GradOperation('grad_d', | |||
| get_by_list=True, | |||
| sens_param=True) | |||
| self.sens = sens | |||
| self.loss_net_w = IthOutputCell(network, output_index=0) | |||
| self.loss_net_d = IthOutputCell(network, output_index=1) | |||
| self.reducer_flag = False | |||
| self.grad_reducer_w = None | |||
| self.grad_reducer_d = None | |||
| parallel_mode = _get_parallel_mode() | |||
| if parallel_mode in (ParallelMode.DATA_PARALLEL, | |||
| ParallelMode.HYBRID_PARALLEL): | |||
| self.reducer_flag = True | |||
| if self.reducer_flag: | |||
| mean = _get_mirror_mean() | |||
| degree = _get_device_num() | |||
| self.grad_reducer_w = DistributedGradReducer( | |||
| self.optimizer_w.parameters, mean, degree) | |||
| self.grad_reducer_d = DistributedGradReducer( | |||
| self.optimizer_d.parameters, mean, degree) | |||
| def construct(self, label, continue_val, indicator_id, emb_128_id, | |||
| emb_64_single_id, multi_doc_ad_category_id, | |||
| multi_doc_ad_category_id_mask, multi_doc_event_entity_id, | |||
| multi_doc_event_entity_id_mask, multi_doc_ad_entity_id, | |||
| multi_doc_ad_entity_id_mask, multi_doc_event_topic_id, | |||
| multi_doc_event_topic_id_mask, multi_doc_event_category_id, | |||
| multi_doc_event_category_id_mask, multi_doc_ad_topic_id, | |||
| multi_doc_ad_topic_id_mask, display_id, ad_id, | |||
| display_ad_and_is_leak, is_leak): | |||
| """ | |||
| TrainStepWrap construct | |||
| """ | |||
| weights_w = self.weights_w | |||
| weights_d = self.weights_d | |||
| loss_w, loss_d = self.network( | |||
| label, continue_val, indicator_id, emb_128_id, emb_64_single_id, | |||
| multi_doc_ad_category_id, multi_doc_ad_category_id_mask, | |||
| multi_doc_event_entity_id, multi_doc_event_entity_id_mask, | |||
| multi_doc_ad_entity_id, multi_doc_ad_entity_id_mask, | |||
| multi_doc_event_topic_id, multi_doc_event_topic_id_mask, | |||
| multi_doc_event_category_id, multi_doc_event_category_id_mask, | |||
| multi_doc_ad_topic_id, multi_doc_ad_topic_id_mask, display_id, | |||
| ad_id, display_ad_and_is_leak, is_leak) | |||
| sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens) # | |||
| sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens) # | |||
| grads_w = self.grad_w(self.loss_net_w, weights_w)( | |||
| label, continue_val, indicator_id, emb_128_id, emb_64_single_id, | |||
| multi_doc_ad_category_id, multi_doc_ad_category_id_mask, | |||
| multi_doc_event_entity_id, multi_doc_event_entity_id_mask, | |||
| multi_doc_ad_entity_id, multi_doc_ad_entity_id_mask, | |||
| multi_doc_event_topic_id, multi_doc_event_topic_id_mask, | |||
| multi_doc_event_category_id, multi_doc_event_category_id_mask, | |||
| multi_doc_ad_topic_id, multi_doc_ad_topic_id_mask, display_id, | |||
| ad_id, display_ad_and_is_leak, is_leak, sens_w) | |||
| grads_d = self.grad_d(self.loss_net_d, weights_d)( | |||
| label, continue_val, indicator_id, emb_128_id, emb_64_single_id, | |||
| multi_doc_ad_category_id, multi_doc_ad_category_id_mask, | |||
| multi_doc_event_entity_id, multi_doc_event_entity_id_mask, | |||
| multi_doc_ad_entity_id, multi_doc_ad_entity_id_mask, | |||
| multi_doc_event_topic_id, multi_doc_event_topic_id_mask, | |||
| multi_doc_event_category_id, multi_doc_event_category_id_mask, | |||
| multi_doc_ad_topic_id, multi_doc_ad_topic_id_mask, display_id, | |||
| ad_id, display_ad_and_is_leak, is_leak, sens_d) | |||
| if self.reducer_flag: | |||
| # apply grad reducer on grads | |||
| grads_w = self.grad_reducer_w(grads_w) | |||
| grads_d = self.grad_reducer_d(grads_d) | |||
| return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend( | |||
| loss_d, self.optimizer_d(grads_d)) | |||
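On the sens handling above: Fill creates a tensor of value sens with the loss's shape and dtype, which is fed to GradOperation as the sensitivity, so the raw gradients come back scaled by sens; because both optimizers are constructed with loss_scale=sens, that scale is divided back out before the update. A small numpy sketch of that round trip, under those assumptions:

import numpy as np

# Hypothetical illustration of the loss-scaling round trip used above.
sens = 1000.0
true_grad = np.array([1e-4, -2e-4])          # made-up "real" gradient
scaled_grad = true_grad * sens               # what grad_w/grad_d return when sens is fed in
applied_grad = scaled_grad / sens            # an optimizer built with loss_scale=sens
                                             # divides the scale back out before updating
print(np.allclose(applied_grad, true_grad))  # True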
| class PredictWithSigmoid(nn.Cell): | |||
| """ | |||
| PredictWithSigmoid | |||
| """ | |||
| def __init__(self, network): | |||
| super(PredictWithSigmoid, self).__init__() | |||
| self.network = network | |||
| self.sigmoid = P.Sigmoid() | |||
| self.reshape = P.Reshape() | |||
| def construct(self, label, continue_val, indicator_id, emb_128_id, | |||
| emb_64_single_id, multi_doc_ad_category_id, | |||
| multi_doc_ad_category_id_mask, multi_doc_event_entity_id, | |||
| multi_doc_event_entity_id_mask, multi_doc_ad_entity_id, | |||
| multi_doc_ad_entity_id_mask, multi_doc_event_topic_id, | |||
| multi_doc_event_topic_id_mask, multi_doc_event_category_id, | |||
| multi_doc_event_category_id_mask, multi_doc_ad_topic_id, | |||
| multi_doc_ad_topic_id_mask, display_id, ad_id, | |||
| display_ad_and_is_leak, is_leak): | |||
| """ | |||
| PredictWithSigmoid construct | |||
| """ | |||
| logits, _, _, _ = self.network( | |||
| continue_val, indicator_id, emb_128_id, emb_64_single_id, | |||
| multi_doc_ad_category_id, multi_doc_ad_category_id_mask, | |||
| multi_doc_event_entity_id, multi_doc_event_entity_id_mask, | |||
| multi_doc_ad_entity_id, multi_doc_ad_entity_id_mask, | |||
| multi_doc_event_topic_id, multi_doc_event_topic_id_mask, | |||
| multi_doc_event_category_id, multi_doc_event_category_id_mask, | |||
| multi_doc_ad_topic_id, multi_doc_ad_topic_id_mask, display_id, | |||
| ad_id, display_ad_and_is_leak, is_leak) | |||
| logits = self.reshape(logits, (-1,)) | |||
| pred_probs = self.sigmoid(logits) | |||
| return logits, pred_probs, label, display_id | |||
| @@ -0,0 +1,107 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ training_and_evaluating """ | |||
| import os | |||
| import sys | |||
| from mindspore import Model, context | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||
| from mindspore.train.callback import TimeMonitor | |||
| from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel | |||
| from src.callbacks import LossCallBack, EvalCallBack | |||
| from src.datasets import create_dataset, compute_emb_dim | |||
| from src.metrics import AUCMetric | |||
| from src.config import WideDeepConfig | |||
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |||
| def get_WideDeep_net(config): | |||
| """ | |||
| Get network of wide&deep model. | |||
| """ | |||
| WideDeep_net = WideDeepModel(config) | |||
| loss_net = NetWithLossClass(WideDeep_net, config) | |||
| train_net = TrainStepWrap(loss_net, config) | |||
| eval_net = PredictWithSigmoid(WideDeep_net) | |||
| return train_net, eval_net | |||
| class ModelBuilder(): | |||
| """ | |||
| ModelBuilder. | |||
| """ | |||
| def __init__(self): | |||
| pass | |||
| def get_hook(self): | |||
| pass | |||
| def get_train_hook(self): | |||
| hooks = [] | |||
| callback = LossCallBack() | |||
| hooks.append(callback) | |||
| if int(os.getenv('DEVICE_ID')) == 0: | |||
| pass | |||
| return hooks | |||
| def get_net(self, config): | |||
| return get_WideDeep_net(config) | |||
| def train_and_eval(config): | |||
| """ | |||
| train_and_eval. | |||
| """ | |||
| data_path = config.data_path | |||
| epochs = config.epochs | |||
| print("epochs is {}".format(epochs)) | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=1, | |||
| batch_size=config.batch_size, is_tf_dataset=config.is_tf_dataset) | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=1, | |||
| batch_size=config.batch_size, is_tf_dataset=config.is_tf_dataset) | |||
| print("ds_train.size: {}".format(ds_train.get_dataset_size())) | |||
| print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) | |||
| net_builder = ModelBuilder() | |||
| train_net, eval_net = net_builder.get_net(config) | |||
| train_net.set_train() | |||
| auc_metric = AUCMetric() | |||
| model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) | |||
| eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) | |||
| callback = LossCallBack(config) | |||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), | |||
| keep_checkpoint_max=10) | |||
| ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', | |||
| directory=config.ckpt_path, config=ckptconfig) | |||
| model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, | |||
| callback, ckpoint_cb]) | |||
| if __name__ == "__main__": | |||
| wide_and_deep_config = WideDeepConfig() | |||
| wide_and_deep_config.argparse_init() | |||
| compute_emb_dim(wide_and_deep_config) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Davinci", | |||
| save_graphs=True) | |||
| train_and_eval(wide_and_deep_config) | |||
| @@ -0,0 +1,113 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """ training_multinpu""" | |||
| import os | |||
| import sys | |||
| from mindspore import Model, context | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||
| from mindspore.train.callback import TimeMonitor | |||
| from mindspore.train import ParallelMode | |||
| from mindspore.communication.management import get_rank, get_group_size, init | |||
| from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel | |||
| from src.callbacks import LossCallBack, EvalCallBack | |||
| from src.datasets import create_dataset, compute_emb_dim | |||
| from src.metrics import AUCMetric | |||
| from src.config import WideDeepConfig | |||
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |||
| def get_WideDeep_net(config): | |||
| """ | |||
| get_WideDeep_net | |||
| """ | |||
| WideDeep_net = WideDeepModel(config) | |||
| loss_net = NetWithLossClass(WideDeep_net, config) | |||
| train_net = TrainStepWrap(loss_net, config) | |||
| eval_net = PredictWithSigmoid(WideDeep_net) | |||
| return train_net, eval_net | |||
| class ModelBuilder(): | |||
| """ | |||
| ModelBuilder | |||
| """ | |||
| def __init__(self): | |||
| pass | |||
| def get_hook(self): | |||
| pass | |||
| def get_train_hook(self): | |||
| hooks = [] | |||
| callback = LossCallBack() | |||
| hooks.append(callback) | |||
| if int(os.getenv('DEVICE_ID')) == 0: | |||
| pass | |||
| return hooks | |||
| def get_net(self, config): | |||
| return get_WideDeep_net(config) | |||
| def train_and_eval(config): | |||
| """ | |||
| train_and_eval | |||
| """ | |||
| data_path = config.data_path | |||
| epochs = config.epochs | |||
| print("epochs is {}".format(epochs)) | |||
| ds_train = create_dataset(data_path, train_mode=True, epochs=1, | |||
| batch_size=config.batch_size, is_tf_dataset=config.is_tf_dataset, | |||
| rank_id=get_rank(), rank_size=get_group_size()) | |||
| ds_eval = create_dataset(data_path, train_mode=False, epochs=1, | |||
| batch_size=config.batch_size, is_tf_dataset=config.is_tf_dataset, | |||
| rank_id=get_rank(), rank_size=get_group_size()) | |||
| print("ds_train.size: {}".format(ds_train.get_dataset_size())) | |||
| print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) | |||
| net_builder = ModelBuilder() | |||
| train_net, eval_net = net_builder.get_net(config) | |||
| train_net.set_train() | |||
| auc_metric = AUCMetric() | |||
| model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) | |||
| eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) | |||
| callback = LossCallBack(config) | |||
| ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), | |||
| keep_checkpoint_max=10) | |||
| ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', | |||
| directory=config.ckpt_path, config=ckptconfig) | |||
| model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, | |||
| callback, ckpoint_cb]) | |||
| if __name__ == "__main__": | |||
| wide_and_deep_config = WideDeepConfig() | |||
| wide_and_deep_config.argparse_init() | |||
| compute_emb_dim(wide_and_deep_config) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Davinci", | |||
| save_graphs=True) | |||
| init() | |||
| context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True, | |||
| device_num=get_group_size()) | |||
| train_and_eval(wide_and_deep_config) | |||