@@ -17,3 +17,4 @@ bs4
astunparse
packaging >= 20.0
pycocotools >= 2.0.0 # for st test
tables >= 3.6.1 # for st test
| @@ -0,0 +1,233 @@ | |||||
| # coding:utf-8 | |||||
| import os | |||||
| import pickle | |||||
| import collections | |||||
| import argparse | |||||
| import numpy as np | |||||
| import pandas as pd | |||||
| TRAIN_LINE_COUNT = 45840617 | |||||
| TEST_LINE_COUNT = 6042135 | |||||
| class DataStatsDict(): | |||||
| def __init__(self): | |||||
| self.field_size = 39 # value_1-13; cat_1-26; | |||||
| self.val_cols = ["val_{}".format(i + 1) for i in range(13)] | |||||
| self.cat_cols = ["cat_{}".format(i + 1) for i in range(26)] | |||||
| # | |||||
| self.val_min_dict = {col: 0 for col in self.val_cols} | |||||
| self.val_max_dict = {col: 0 for col in self.val_cols} | |||||
| self.cat_count_dict = {col: collections.defaultdict(int) for col in self.cat_cols} | |||||
| # | |||||
| self.oov_prefix = "OOV_" | |||||
| self.cat2id_dict = {} | |||||
| self.cat2id_dict.update({col: i for i, col in enumerate(self.val_cols)}) | |||||
| self.cat2id_dict.update({self.oov_prefix + col: i + len(self.val_cols) for i, col in enumerate(self.cat_cols)}) | |||||
| # { "val_1": , ..., "val_13": , "OOV_cat_1": , ..., "OOV_cat_26": } | |||||
| def stats_vals(self, val_list): | |||||
| assert len(val_list) == len(self.val_cols) | |||||
| def map_max_min(i, val): | |||||
| key = self.val_cols[i] | |||||
| if val != "": | |||||
| if float(val) > self.val_max_dict[key]: | |||||
| self.val_max_dict[key] = float(val) | |||||
| if float(val) < self.val_min_dict[key]: | |||||
| self.val_min_dict[key] = float(val) | |||||
| for i, val in enumerate(val_list): | |||||
| map_max_min(i, val) | |||||
| def stats_cats(self, cat_list): | |||||
| assert len(cat_list) == len(self.cat_cols) | |||||
| def map_cat_count(i, cat): | |||||
| key = self.cat_cols[i] | |||||
| self.cat_count_dict[key][cat] += 1 | |||||
| for i, cat in enumerate(cat_list): | |||||
| map_cat_count(i, cat) | |||||
| # | |||||
| def save_dict(self, output_path, prefix=""): | |||||
| with open(os.path.join(output_path, "{}val_max_dict.pkl".format(prefix)), "wb") as file_wrt: | |||||
| pickle.dump(self.val_max_dict, file_wrt) | |||||
| with open(os.path.join(output_path, "{}val_min_dict.pkl".format(prefix)), "wb") as file_wrt: | |||||
| pickle.dump(self.val_min_dict, file_wrt) | |||||
| with open(os.path.join(output_path, "{}cat_count_dict.pkl".format(prefix)), "wb") as file_wrt: | |||||
| pickle.dump(self.cat_count_dict, file_wrt) | |||||
| def load_dict(self, dict_path, prefix=""): | |||||
| with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "rb") as file_wrt: | |||||
| self.val_max_dict = pickle.load(file_wrt) | |||||
| with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "rb") as file_wrt: | |||||
| self.val_min_dict = pickle.load(file_wrt) | |||||
| with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "rb") as file_wrt: | |||||
| self.cat_count_dict = pickle.load(file_wrt) | |||||
| print("val_max_dict.items()[:50]: {}".format(list(self.val_max_dict.items()))) | |||||
| print("val_min_dict.items()[:50]: {}".format(list(self.val_min_dict.items()))) | |||||
| def get_cat2id(self, threshold=100): | |||||
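# Keep only category values seen more than threshold times; rarer values later fall back to the
# column's OOV id in map_cat2id().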
| for key, cat_count_d in self.cat_count_dict.items(): | |||||
| new_cat_count_d = dict(filter(lambda x: x[1] > threshold, cat_count_d.items())) | |||||
| for cat_str, _ in new_cat_count_d.items(): | |||||
| self.cat2id_dict[key + "_" + cat_str] = len(self.cat2id_dict) | |||||
| # print("before_all_count: {}".format( before_all_count )) # before_all_count: 33762577 | |||||
| # print("after_all_count: {}".format( after_all_count )) # after_all_count: 184926 | |||||
| print("cat2id_dict.size: {}".format(len(self.cat2id_dict))) | |||||
| print("cat2id_dict.items()[:50]: {}".format(list(self.cat2id_dict.items())[:50])) | |||||
| def map_cat2id(self, values, cats): | |||||
| def minmax_scale_value(i, val): | |||||
| # min_v = float(self.val_min_dict[ "val_{}".format(i+1) ]) | |||||
| max_v = float(self.val_max_dict["val_{}".format(i + 1)]) | |||||
| # return ( float(val) - min_v ) * 1.0 / (max_v - min_v) | |||||
| return float(val) * 1.0 / max_v | |||||
| id_list = [] | |||||
| weight_list = [] | |||||
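# Numeric features always occupy slots 0-12: an empty value keeps its slot with weight 0, otherwise
# the weight is the value scaled by the column maximum. Categorical features append either their own
# id or the column's OOV id, always with weight 1.0.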
| for i, val in enumerate(values): | |||||
| if val == "": | |||||
| id_list.append(i) | |||||
| weight_list.append(0) | |||||
| else: | |||||
| key = "val_{}".format(i + 1) | |||||
| id_list.append(self.cat2id_dict[key]) | |||||
| weight_list.append(minmax_scale_value(i, float(val))) | |||||
| for i, cat_str in enumerate(cats): | |||||
| key = "cat_{}".format(i + 1) + "_" + cat_str | |||||
| if key in self.cat2id_dict: | |||||
| id_list.append(self.cat2id_dict[key]) | |||||
| else: | |||||
| id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)]) | |||||
| weight_list.append(1.0) | |||||
| return id_list, weight_list | |||||
| def mkdir_path(file_path): | |||||
| if not os.path.exists(file_path): | |||||
| os.makedirs(file_path) | |||||
| def statsdata(data_source_path, output_path, data_stats1): | |||||
| with open(data_source_path, encoding="utf-8") as file_in: | |||||
| errorline_list = [] | |||||
| count = 0 | |||||
| for line in file_in: | |||||
| count += 1 | |||||
| line = line.strip("\n") | |||||
| items = line.split("\t") | |||||
| if len(items) != 40: | |||||
| errorline_list.append(count) | |||||
| print("line: {}".format(line)) | |||||
| continue | |||||
| if count % 1000000 == 0: | |||||
| print("Have handle {}w lines.".format(count // 10000)) | |||||
| values = items[1:14] | |||||
| cats = items[14:] | |||||
| assert len(values) == 13, "values.size: {}".format(len(values)) | |||||
| assert len(cats) == 26, "cats.size: {}".format(len(cats)) | |||||
| data_stats1.stats_vals(values) | |||||
| data_stats1.stats_cats(cats) | |||||
| data_stats1.save_dict(output_path) | |||||
| def add_write(file_path, wrt_str): | |||||
| with open(file_path, 'a', encoding="utf-8") as file_out: | |||||
| file_out.write(wrt_str + "\n") | |||||
| def random_split_trans2h5(input_file_path, output_path, data_stats2, part_rows=2000000, test_size=0.1, seed=2020): | |||||
| test_size = int(TRAIN_LINE_COUNT * test_size) | |||||
| all_indices = [i for i in range(TRAIN_LINE_COUNT)] | |||||
| np.random.seed(seed) | |||||
| np.random.shuffle(all_indices) | |||||
| print("all_indices.size: {}".format(len(all_indices))) | |||||
| test_indices_set = set(all_indices[: test_size]) | |||||
| print("test_indices_set.size: {}".format(len(test_indices_set))) | |||||
| print("----------" * 10 + "\n" * 2) | |||||
| train_feature_file_name = os.path.join(output_path, "train_input_part_{}.h5") | |||||
| train_label_file_name = os.path.join(output_path, "train_output_part_{}.h5") | |||||
| test_feature_file_name = os.path.join(output_path, "test_input_part_{}.h5") | |||||
| test_label_file_name = os.path.join(output_path, "test_output_part_{}.h5") | |||||
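# Rows are buffered in memory and flushed to numbered h5 part files every part_rows rows;
# each feature row stores the 39 ids followed by the 39 weights.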
| train_feature_list = [] | |||||
| train_label_list = [] | |||||
| test_feature_list = [] | |||||
| test_label_list = [] | |||||
| with open(input_file_path, encoding="utf-8") as file_in: | |||||
| count = 0 | |||||
| train_part_number = 0 | |||||
| test_part_number = 0 | |||||
| for i, line in enumerate(file_in): | |||||
| count += 1 | |||||
| if count % 1000000 == 0: | |||||
| print("Have handle {}w lines.".format(count // 10000)) | |||||
| line = line.strip("\n") | |||||
| items = line.split("\t") | |||||
| if len(items) != 40: | |||||
| continue | |||||
| label = float(items[0]) | |||||
| values = items[1:14] | |||||
| cats = items[14:] | |||||
| assert len(values) == 13, "values.size: {}".format(len(values)) | |||||
| assert len(cats) == 26, "cats.size: {}".format(len(cats)) | |||||
| ids, wts = data_stats2.map_cat2id(values, cats) | |||||
| if i not in test_indices_set: | |||||
| train_feature_list.append(ids + wts) | |||||
| train_label_list.append(label) | |||||
| else: | |||||
| test_feature_list.append(ids + wts) | |||||
| test_label_list.append(label) | |||||
| if train_label_list and (len(train_label_list) % part_rows == 0): | |||||
| pd.DataFrame(np.asarray(train_feature_list)).to_hdf(train_feature_file_name.format(train_part_number), | |||||
| key="fixed") | |||||
| pd.DataFrame(np.asarray(train_label_list)).to_hdf(train_label_file_name.format(train_part_number), | |||||
| key="fixed") | |||||
| train_feature_list = [] | |||||
| train_label_list = [] | |||||
| train_part_number += 1 | |||||
| if test_label_list and (len(test_label_list) % part_rows == 0): | |||||
| pd.DataFrame(np.asarray(test_feature_list)).to_hdf(test_feature_file_name.format(test_part_number), | |||||
| key="fixed") | |||||
| pd.DataFrame(np.asarray(test_label_list)).to_hdf(test_label_file_name.format(test_part_number), | |||||
| key="fixed") | |||||
| test_feature_list = [] | |||||
| test_label_list = [] | |||||
| test_part_number += 1 | |||||
| if train_label_list: | |||||
| pd.DataFrame(np.asarray(train_feature_list)).to_hdf(train_feature_file_name.format(train_part_number), | |||||
| key="fixed") | |||||
| pd.DataFrame(np.asarray(train_label_list)).to_hdf(train_label_file_name.format(train_part_number), | |||||
| key="fixed") | |||||
| if test_label_list: | |||||
| pd.DataFrame(np.asarray(test_feature_list)).to_hdf(test_feature_file_name.format(test_part_number), | |||||
| key="fixed") | |||||
| pd.DataFrame(np.asarray(test_label_list)).to_hdf(test_label_file_name.format(test_part_number), key="fixed") | |||||
| if __name__ == "__main__": | |||||
| parser = argparse.ArgumentParser(description='Get and Process datasets') | |||||
| parser.add_argument('--base_path', default="/home/wushuquan/tmp/", help='The path to save dataset') | |||||
| parser.add_argument('--output_path', default="/home/wushuquan/tmp/h5dataset/", | |||||
| help='The path to save h5 dataset') | |||||
| args, _ = parser.parse_known_args() | |||||
| base_path = args.base_path | |||||
| data_path = base_path + "" | |||||
| # mkdir_path(data_path) | |||||
| # if not os.path.exists(base_path + "dac.tar.gz"): | |||||
| # os.system( | |||||
| # "wget -P {} -c https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz --no-check-certificate".format( | |||||
| # base_path)) | |||||
| os.system("tar -zxvf {}dac.tar.gz".format(data_path)) | |||||
| print("********tar end***********") | |||||
| data_stats = DataStatsDict() | |||||
# step 1: collect statistics for the categorical vocabulary and the numeric value ranges
| data_file_path = "./train.txt" | |||||
| stats_output_path = base_path + "stats_dict/" | |||||
| mkdir_path(stats_output_path) | |||||
| statsdata(data_file_path, stats_output_path, data_stats) | |||||
| print("----------" * 10) | |||||
| data_stats.load_dict(dict_path=stats_output_path, prefix="") | |||||
| data_stats.get_cat2id(threshold=100) | |||||
# step 2: randomly split the data into train/test and write it as h5 part files (shuffled with np.random.shuffle)
| in_file_path = "./train.txt" | |||||
| mkdir_path(args.output_path) | |||||
| random_split_trans2h5(in_file_path, args.output_path, data_stats, part_rows=2000000, test_size=0.1, seed=2020) | |||||
| @@ -0,0 +1,110 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Defined callback for DeepFM. | |||||
| """ | |||||
| import time | |||||
| from mindspore.train.callback import Callback | |||||
| def add_write(file_path, out_str): | |||||
| with open(file_path, 'a+', encoding='utf-8') as file_out: | |||||
| file_out.write(out_str + '\n') | |||||
| class EvalCallBack(Callback): | |||||
| """ | |||||
Evaluate the model on eval_dataset at the end of each epoch and append the metric
and evaluation time to eval_file_path.
| """ | |||||
| def __init__(self, model, eval_dataset, auc_metric, eval_file_path): | |||||
| super(EvalCallBack, self).__init__() | |||||
| self.model = model | |||||
| self.eval_dataset = eval_dataset | |||||
| self.aucMetric = auc_metric | |||||
| self.aucMetric.clear() | |||||
| self.eval_file_path = eval_file_path | |||||
| def epoch_end(self, run_context): | |||||
| start_time = time.time() | |||||
| out = self.model.eval(self.eval_dataset) | |||||
| eval_time = int(time.time() - start_time) | |||||
| time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) | |||||
| out_str = "{} EvalCallBack metric{}; eval_time{}s".format( | |||||
| time_str, out.values(), eval_time) | |||||
| print(out_str) | |||||
| add_write(self.eval_file_path, out_str) | |||||
| class LossCallBack(Callback): | |||||
| """ | |||||
Monitor the loss during training.
If the loss is NAN or INF, terminate training.
Note:
If per_print_times is 0, do not print the loss.
Args:
loss_file_path (str): Absolute path of the file the loss is written to.
per_print_times (int): Print the loss every per_print_times steps. Default: 1.
| """ | |||||
| def __init__(self, loss_file_path, per_print_times=1): | |||||
| super(LossCallBack, self).__init__() | |||||
| if not isinstance(per_print_times, int) or per_print_times < 0: | |||||
| raise ValueError("print_step must be int and >= 0.") | |||||
| self.loss_file_path = loss_file_path | |||||
| self._per_print_times = per_print_times | |||||
| self.loss = 0 | |||||
| def step_end(self, run_context): | |||||
| cb_params = run_context.original_args() | |||||
| loss = cb_params.net_outputs.asnumpy() | |||||
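# cur_step_num counts steps across the whole run; convert it to a 1-based step index within the current epoch.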
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||||
| cur_num = cb_params.cur_step_num | |||||
| if self._per_print_times != 0 and cur_num % self._per_print_times == 0: | |||||
| with open(self.loss_file_path, "a+") as loss_file: | |||||
| time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) | |||||
| loss_file.write("{} epoch: {} step: {}, loss is {}\n".format( | |||||
| time_str, cb_params.cur_epoch_num, cur_step_in_epoch, loss)) | |||||
| print("epoch: {} step: {}, loss is {}\n".format( | |||||
| cb_params.cur_epoch_num, cur_step_in_epoch, loss)) | |||||
| self.loss = loss | |||||
| class TimeMonitor(Callback): | |||||
| """ | |||||
Time monitor for calculating the cost of each epoch.
Args:
data_size (int): Number of steps in one epoch.
| """ | |||||
| def __init__(self, data_size): | |||||
| super(TimeMonitor, self).__init__() | |||||
| self.data_size = data_size | |||||
| self.per_step_time = 0 | |||||
| def epoch_begin(self, run_context): | |||||
| self.epoch_time = time.time() | |||||
| def epoch_end(self, run_context): | |||||
| epoch_mseconds = (time.time() - self.epoch_time) * 1000 | |||||
| per_step_mseconds = epoch_mseconds / self.data_size | |||||
| print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True) | |||||
| self.per_step_time = per_step_mseconds | |||||
| def step_begin(self, run_context): | |||||
| self.step_time = time.time() | |||||
| def step_end(self, run_context): | |||||
| step_mseconds = (time.time() - self.step_time) * 1000 | |||||
| print(f"step time {step_mseconds}", flush=True) | |||||
| @@ -0,0 +1,62 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
Network configuration settings used in train.py and eval.py.
| """ | |||||
| class DataConfig: | |||||
| """ | |||||
| Define parameters of dataset. | |||||
| """ | |||||
| data_vocab_size = 184965 | |||||
| train_num_of_parts = 21 | |||||
| test_num_of_parts = 3 | |||||
| batch_size = 1000 | |||||
| data_field_size = 39 | |||||
| # dataset format, 1: mindrecord, 2: tfrecord, 3: h5 | |||||
| data_format = 3 | |||||
| class ModelConfig: | |||||
| """ | |||||
| Define parameters of model. | |||||
| """ | |||||
| batch_size = DataConfig.batch_size | |||||
| data_field_size = DataConfig.data_field_size | |||||
| data_vocab_size = DataConfig.data_vocab_size | |||||
| data_emb_dim = 80 | |||||
| deep_layer_args = [[400, 400, 512], "relu"] | |||||
| init_args = [-0.01, 0.01] | |||||
| weight_bias_init = ['normal', 'normal'] | |||||
| keep_prob = 0.9 | |||||
| class TrainConfig: | |||||
| """ | |||||
| Define parameters of training. | |||||
| """ | |||||
| batch_size = DataConfig.batch_size | |||||
| l2_coef = 1e-6 | |||||
| learning_rate = 1e-5 | |||||
| epsilon = 1e-8 | |||||
| loss_scale = 1024.0 | |||||
| train_epochs = 3 | |||||
| save_checkpoint = True | |||||
| ckpt_file_name_prefix = "deepfm" | |||||
| save_checkpoint_steps = 1 | |||||
| keep_checkpoint_max = 15 | |||||
| eval_callback = True | |||||
| loss_callback = True | |||||
| @@ -0,0 +1,298 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| Create train or eval dataset. | |||||
| """ | |||||
| import os | |||||
| import math | |||||
| from enum import Enum | |||||
| import pandas as pd | |||||
| import numpy as np | |||||
| import mindspore.dataset.engine as de | |||||
| import mindspore.common.dtype as mstype | |||||
| from .config import DataConfig | |||||
| class DataType(Enum): | |||||
| """ | |||||
| Enumerate supported dataset format. | |||||
| """ | |||||
| MINDRECORD = 1 | |||||
| TFRECORD = 2 | |||||
| H5 = 3 | |||||
| class H5Dataset(): | |||||
| """ | |||||
| Create dataset with H5 format. | |||||
| Args: | |||||
| data_path (str): Dataset directory. | |||||
| train_mode (bool): Whether dataset is used for train or eval (default=True). | |||||
train_num_of_parts (int): The number of train data part files (default=21).
test_num_of_parts (int): The number of test data part files (default=3).
| """ | |||||
| max_length = 39 | |||||
| def __init__(self, data_path, train_mode=True, | |||||
| train_num_of_parts=DataConfig.train_num_of_parts, | |||||
| test_num_of_parts=DataConfig.test_num_of_parts): | |||||
| self._hdf_data_dir = data_path | |||||
| self._is_training = train_mode | |||||
| if self._is_training: | |||||
| self._file_prefix = 'train' | |||||
| self._num_of_parts = train_num_of_parts | |||||
| else: | |||||
| self._file_prefix = 'test' | |||||
| self._num_of_parts = test_num_of_parts | |||||
| self.data_size = self._bin_count(self._hdf_data_dir, self._file_prefix, self._num_of_parts) | |||||
| print("data_size: {}".format(self.data_size)) | |||||
| def _bin_count(self, hdf_data_dir, file_prefix, num_of_parts): | |||||
| size = 0 | |||||
| for part in range(num_of_parts): | |||||
| _y = pd.read_hdf(os.path.join(hdf_data_dir, f'{file_prefix}_output_part_{str(part)}.h5')) | |||||
| size += _y.shape[0] | |||||
| return size | |||||
| def _iterate_hdf_files_(self, num_of_parts=None, | |||||
| shuffle_block=False): | |||||
| """ | |||||
Iterate over the hdf part files (blocks). When the whole dataset has been consumed, the iterator
restarts from the beginning, so the data stream never stops.
:param num_of_parts: number of part files
:param shuffle_block: shuffle the block files at every round
:return: input_hdf_file_name, output_hdf_file_name, finish_flag
| """ | |||||
| parts = np.arange(num_of_parts) | |||||
| while True: | |||||
| if shuffle_block: | |||||
| for _ in range(int(shuffle_block)): | |||||
| np.random.shuffle(parts) | |||||
| for i, p in enumerate(parts): | |||||
| yield os.path.join(self._hdf_data_dir, f'{self._file_prefix}_input_part_{str(p)}.h5'), \ | |||||
| os.path.join(self._hdf_data_dir, f'{self._file_prefix}_output_part_{str(p)}.h5'), \ | |||||
| i + 1 == len(parts) | |||||
| def _generator(self, X, y, batch_size, shuffle=True): | |||||
| """ | |||||
Yield (X_batch, y_batch, finished) tuples from one hdf block; intended for internal use only.
:param X: feature array of the block
:param y: label array of the block
:param batch_size: number of rows per batch
:param shuffle: whether to shuffle the row order before batching
| """ | |||||
| number_of_batches = np.ceil(1. * X.shape[0] / batch_size) | |||||
| counter = 0 | |||||
| finished = False | |||||
| sample_index = np.arange(X.shape[0]) | |||||
| if shuffle: | |||||
| for _ in range(int(shuffle)): | |||||
| np.random.shuffle(sample_index) | |||||
| assert X.shape[0] > 0 | |||||
| while True: | |||||
| batch_index = sample_index[batch_size * counter: batch_size * (counter + 1)] | |||||
| X_batch = X[batch_index] | |||||
| y_batch = y[batch_index] | |||||
| counter += 1 | |||||
| yield X_batch, y_batch, finished | |||||
| if counter == number_of_batches: | |||||
| counter = 0 | |||||
| finished = True | |||||
| def batch_generator(self, batch_size=1000, | |||||
| random_sample=False, shuffle_block=False): | |||||
| """ | |||||
Yield (ids, weights, labels) numpy batches by streaming the h5 part files.
:param batch_size: number of rows per batch
:param random_sample: if True, shuffle rows inside each block
:param shuffle_block: if True, shuffle the order of the block files at every round
:return: generator of (ids, weights, labels) batches
| """ | |||||
| for hdf_in, hdf_out, _ in self._iterate_hdf_files_(self._num_of_parts, | |||||
| shuffle_block): | |||||
| start = stop = None | |||||
| X_all = pd.read_hdf(hdf_in, start=start, stop=stop).values | |||||
| y_all = pd.read_hdf(hdf_out, start=start, stop=stop).values | |||||
| data_gen = self._generator(X_all, y_all, batch_size, | |||||
| shuffle=random_sample) | |||||
| finished = False | |||||
| while not finished: | |||||
| X, y, finished = data_gen.__next__() | |||||
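# The first max_length columns of each row are the feature ids, the remaining columns are the corresponding weights.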
| X_id = X[:, 0:self.max_length] | |||||
| X_va = X[:, self.max_length:] | |||||
| yield np.array(X_id.astype(dtype=np.int32)), \ | |||||
| np.array(X_va.astype(dtype=np.float32)), \ | |||||
| np.array(y.astype(dtype=np.float32)) | |||||
| def _get_h5_dataset(directory, train_mode=True, epochs=1, batch_size=1000): | |||||
| """ | |||||
| Get dataset with h5 format. | |||||
| Args: | |||||
| directory (str): Dataset directory. | |||||
train_mode (bool): Whether the dataset is used for training or evaluation (default=True).
| epochs (int): Dataset epoch size (default=1). | |||||
| batch_size (int): Dataset batch size (default=1000) | |||||
| Returns: | |||||
| Dataset. | |||||
| """ | |||||
| data_para = {'batch_size': batch_size} | |||||
| if train_mode: | |||||
| data_para['random_sample'] = True | |||||
| data_para['shuffle_block'] = True | |||||
| h5_dataset = H5Dataset(data_path=directory, train_mode=train_mode) | |||||
| numbers_of_batch = math.ceil(h5_dataset.data_size / batch_size) | |||||
| def _iter_h5_data(): | |||||
| train_eval_gen = h5_dataset.batch_generator(**data_para) | |||||
| for _ in range(0, numbers_of_batch, 1): | |||||
| yield train_eval_gen.__next__() | |||||
| ds = de.GeneratorDataset(_iter_h5_data, ["ids", "weights", "labels"], num_samples=3000) | |||||
| ds = ds.repeat(epochs) | |||||
| return ds | |||||
| def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000, | |||||
| line_per_sample=1000, rank_size=None, rank_id=None): | |||||
| """ | |||||
| Get dataset with mindrecord format. | |||||
| Args: | |||||
| directory (str): Dataset directory. | |||||
train_mode (bool): Whether the dataset is used for training or evaluation (default=True).
| epochs (int): Dataset epoch size (default=1). | |||||
| batch_size (int): Dataset batch size (default=1000). | |||||
| line_per_sample (int): The number of sample per line (default=1000). | |||||
| rank_size (int): The number of device, not necessary for single device (default=None). | |||||
| rank_id (int): Id of device, not necessary for single device (default=None). | |||||
| Returns: | |||||
| Dataset. | |||||
| """ | |||||
| file_prefix_name = 'train_input_part.mindrecord' if train_mode else 'test_input_part.mindrecord' | |||||
| file_suffix_name = '00' if train_mode else '0' | |||||
| shuffle = train_mode | |||||
| if rank_size is not None and rank_id is not None: | |||||
| ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name), | |||||
| columns_list=['feat_ids', 'feat_vals', 'label'], | |||||
| num_shards=rank_size, shard_id=rank_id, shuffle=shuffle, | |||||
| num_parallel_workers=8) | |||||
| else: | |||||
| ds = de.MindDataset(os.path.join(directory, file_prefix_name + file_suffix_name), | |||||
| columns_list=['feat_ids', 'feat_vals', 'label'], | |||||
| shuffle=shuffle, num_parallel_workers=8) | |||||
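# Each mindrecord row already packs line_per_sample samples, so batching batch_size / line_per_sample
# rows yields batch_size samples after the reshape below.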
| ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True) | |||||
| ds = ds.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 39), | |||||
| np.array(y).flatten().reshape(batch_size, 39), | |||||
| np.array(z).flatten().reshape(batch_size, 1))), | |||||
| input_columns=['feat_ids', 'feat_vals', 'label'], | |||||
| columns_order=['feat_ids', 'feat_vals', 'label'], | |||||
| num_parallel_workers=8) | |||||
| ds = ds.repeat(epochs) | |||||
| return ds | |||||
| def _get_tf_dataset(directory, train_mode=True, epochs=1, batch_size=1000, | |||||
| line_per_sample=1000, rank_size=None, rank_id=None): | |||||
| """ | |||||
| Get dataset with tfrecord format. | |||||
| Args: | |||||
| directory (str): Dataset directory. | |||||
train_mode (bool): Whether the dataset is used for training or evaluation (default=True).
| epochs (int): Dataset epoch size (default=1). | |||||
| batch_size (int): Dataset batch size (default=1000). | |||||
| line_per_sample (int): The number of sample per line (default=1000). | |||||
| rank_size (int): The number of device, not necessary for single device (default=None). | |||||
| rank_id (int): Id of device, not necessary for single device (default=None). | |||||
| Returns: | |||||
| Dataset. | |||||
| """ | |||||
| dataset_files = [] | |||||
file_prefix_name = 'train' if train_mode else 'test'
| shuffle = train_mode | |||||
| for (dir_path, _, filenames) in os.walk(directory): | |||||
| for filename in filenames: | |||||
if file_prefix_name in filename and 'tfrecord' in filename:
| dataset_files.append(os.path.join(dir_path, filename)) | |||||
| schema = de.Schema() | |||||
| schema.add_column('feat_ids', de_type=mstype.int32) | |||||
| schema.add_column('feat_vals', de_type=mstype.float32) | |||||
| schema.add_column('label', de_type=mstype.float32) | |||||
| if rank_size is not None and rank_id is not None: | |||||
| ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, | |||||
| schema=schema, num_parallel_workers=8, | |||||
| num_shards=rank_size, shard_id=rank_id, | |||||
| shard_equal_rows=True, num_samples=3000) | |||||
| else: | |||||
| ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, | |||||
| schema=schema, num_parallel_workers=8, num_samples=3000) | |||||
| ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True) | |||||
| ds = ds.map(operations=(lambda x, y, z: ( | |||||
| np.array(x).flatten().reshape(batch_size, 39), | |||||
| np.array(y).flatten().reshape(batch_size, 39), | |||||
| np.array(z).flatten().reshape(batch_size, 1))), | |||||
| input_columns=['feat_ids', 'feat_vals', 'label'], | |||||
| column_order=['feat_ids', 'feat_vals', 'label'], | |||||
| num_parallel_workers=8) | |||||
| ds = ds.repeat(epochs) | |||||
| return ds | |||||
| def create_dataset(directory, train_mode=True, epochs=1, batch_size=1000, | |||||
| data_type=DataType.TFRECORD, line_per_sample=1000, | |||||
| rank_size=None, rank_id=None): | |||||
| """ | |||||
| Get dataset. | |||||
| Args: | |||||
| directory (str): Dataset directory. | |||||
train_mode (bool): Whether the dataset is used for training or evaluation (default=True).
| epochs (int): Dataset epoch size (default=1). | |||||
| batch_size (int): Dataset batch size (default=1000). | |||||
data_type (DataType): The type of dataset, which is one of H5, TFRECORD, MINDRECORD (default=TFRECORD).
| line_per_sample (int): The number of sample per line (default=1000). | |||||
| rank_size (int): The number of device, not necessary for single device (default=None). | |||||
| rank_id (int): Id of device, not necessary for single device (default=None). | |||||
| Returns: | |||||
| Dataset. | |||||
| """ | |||||
| if data_type == DataType.MINDRECORD: | |||||
| return _get_mindrecord_dataset(directory, train_mode, epochs, | |||||
| batch_size, line_per_sample, | |||||
| rank_size, rank_id) | |||||
| if data_type == DataType.TFRECORD: | |||||
| return _get_tf_dataset(directory, train_mode, epochs, batch_size, | |||||
| line_per_sample, rank_size=rank_size, rank_id=rank_id) | |||||
| if rank_size is not None and rank_size > 1: | |||||
raise ValueError('The h5 dataset does not support distributed training; please use the mindrecord dataset.')
| return _get_h5_dataset(directory, train_mode, epochs, batch_size) | |||||
| @@ -0,0 +1,370 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ test_training """ | |||||
| import os | |||||
| import numpy as np | |||||
| from sklearn.metrics import roc_auc_score | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.nn import Dropout | |||||
| from mindspore.nn.optim import Adam | |||||
| from mindspore.nn.metrics import Metric | |||||
| from mindspore import nn, ParameterTuple, Parameter | |||||
| from mindspore.common.initializer import Uniform, initializer, Normal | |||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig | |||||
| from .callback import EvalCallBack, LossCallBack | |||||
| np_type = np.float32 | |||||
| ms_type = mstype.float32 | |||||
| class AUCMetric(Metric): | |||||
| """AUC metric for DeepFM model.""" | |||||
| def __init__(self): | |||||
| super(AUCMetric, self).__init__() | |||||
| self.pred_probs = [] | |||||
| self.true_labels = [] | |||||
| def clear(self): | |||||
| """Clear the internal evaluation result.""" | |||||
| self.pred_probs = [] | |||||
| self.true_labels = [] | |||||
| def update(self, *inputs): | |||||
| batch_predict = inputs[1].asnumpy() | |||||
| batch_label = inputs[2].asnumpy() | |||||
| self.pred_probs.extend(batch_predict.flatten().tolist()) | |||||
| self.true_labels.extend(batch_label.flatten().tolist()) | |||||
| def eval(self): | |||||
| if len(self.true_labels) != len(self.pred_probs): | |||||
| raise RuntimeError('true_labels.size() is not equal to pred_probs.size()') | |||||
| auc = roc_auc_score(self.true_labels, self.pred_probs) | |||||
| return auc | |||||
| def init_method(method, shape, name, max_val=0.01): | |||||
| """ | |||||
| The method of init parameters. | |||||
| Args: | |||||
method (str): The method used to initialize the parameter.
| shape (list): The shape of parameter. | |||||
| name (str): The name of parameter. | |||||
| max_val (float): Max value in parameter when uses 'random' or 'uniform' to initialize parameter. | |||||
| Returns: | |||||
| Parameter. | |||||
| """ | |||||
| if method in ['random', 'uniform']: | |||||
| params = Parameter(initializer(Uniform(max_val), shape, ms_type), name=name) | |||||
| elif method == "one": | |||||
| params = Parameter(initializer("ones", shape, ms_type), name=name) | |||||
| elif method == 'zero': | |||||
| params = Parameter(initializer("zeros", shape, ms_type), name=name) | |||||
| elif method == "normal": | |||||
| params = Parameter(initializer(Normal(max_val), shape, ms_type), name=name) | |||||
| return params | |||||
| def init_var_dict(init_args, values): | |||||
| """ | |||||
| Init parameter. | |||||
| Args: | |||||
| init_args (list): Define max and min value of parameters. | |||||
| values (list): Define name, shape and init method of parameters. | |||||
| Returns: | |||||
dict, a dict of Parameter.
| """ | |||||
| var_map = {} | |||||
| _, _max_val = init_args | |||||
| for key, shape, init_flag in values: | |||||
| if key not in var_map.keys(): | |||||
| if init_flag in ['random', 'uniform']: | |||||
| var_map[key] = Parameter(initializer(Uniform(_max_val), shape, ms_type), name=key) | |||||
| elif init_flag == "one": | |||||
| var_map[key] = Parameter(initializer("ones", shape, ms_type), name=key) | |||||
| elif init_flag == "zero": | |||||
| var_map[key] = Parameter(initializer("zeros", shape, ms_type), name=key) | |||||
| elif init_flag == 'normal': | |||||
| var_map[key] = Parameter(initializer(Normal(_max_val), shape, ms_type), name=key) | |||||
| return var_map | |||||
| class DenseLayer(nn.Cell): | |||||
| """ | |||||
| Dense Layer for Deep Layer of DeepFM Model; | |||||
| Containing: activation, matmul, bias_add; | |||||
| Args: | |||||
input_dim (int): the shape of weight at 0-axis;
output_dim (int): the shape of weight at 1-axis, and the shape of bias;
| weight_bias_init (list): weight and bias init method, "random", "uniform", "one", "zero", "normal"; | |||||
| act_str (str): activation function method, "relu", "sigmoid", "tanh"; | |||||
| keep_prob (float): Dropout Layer keep_prob_rate; | |||||
| scale_coef (float): input scale coefficient; | |||||
| """ | |||||
| def __init__(self, input_dim, output_dim, weight_bias_init, act_str, keep_prob=0.9, scale_coef=1.0): | |||||
| super(DenseLayer, self).__init__() | |||||
| weight_init, bias_init = weight_bias_init | |||||
| self.weight = init_method(weight_init, [input_dim, output_dim], name="weight") | |||||
| self.bias = init_method(bias_init, [output_dim], name="bias") | |||||
| self.act_func = self._init_activation(act_str) | |||||
| self.matmul = P.MatMul(transpose_b=False) | |||||
| self.bias_add = P.BiasAdd() | |||||
| self.cast = P.Cast() | |||||
| self.dropout = Dropout(keep_prob=keep_prob) | |||||
| self.mul = P.Mul() | |||||
| self.realDiv = P.RealDiv() | |||||
| self.scale_coef = scale_coef | |||||
| def _init_activation(self, act_str): | |||||
| act_str = act_str.lower() | |||||
| if act_str == "relu": | |||||
| act_func = P.ReLU() | |||||
| elif act_str == "sigmoid": | |||||
| act_func = P.Sigmoid() | |||||
| elif act_str == "tanh": | |||||
| act_func = P.Tanh() | |||||
| return act_func | |||||
| def construct(self, x): | |||||
| x = self.act_func(x) | |||||
| if self.training: | |||||
| x = self.dropout(x) | |||||
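# Scale the activations, run the matmul in float16, then cast back to float32 and divide the scale back out.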
| x = self.mul(x, self.scale_coef) | |||||
| x = self.cast(x, mstype.float16) | |||||
| weight = self.cast(self.weight, mstype.float16) | |||||
| wx = self.matmul(x, weight) | |||||
| wx = self.cast(wx, mstype.float32) | |||||
| wx = self.realDiv(wx, self.scale_coef) | |||||
| output = self.bias_add(wx, self.bias) | |||||
| return output | |||||
| class DeepFMModel(nn.Cell): | |||||
| """ | |||||
| From paper: "DeepFM: A Factorization-Machine based Neural Network for CTR Prediction" | |||||
| Args: | |||||
batch_size (int): sample number per step in training; (int, batch_size=128)
field_size (int): input field number, also called id_feature number; (int, field_size=39)
| vocab_size (int): id_feature vocab size, id dict size; (int, vocab_size=200000) | |||||
| emb_dim (int): id embedding vector dim, id mapped to embedding vector; (int, emb_dim=100) | |||||
| deep_layer_args (list): Deep Layer args, layer_dim_list, layer_activator; | |||||
| (int, deep_layer_args=[[100, 100, 100], "relu"]) | |||||
| init_args (list): init args for Parameter init; (list, init_args=[min, max, seeds]) | |||||
| weight_bias_init (list): weight, bias init method for deep layers; | |||||
| (list[str], weight_bias_init=['random', 'zero']) | |||||
| keep_prob (float): if dropout_flag is True, keep_prob rate to keep connect; (float, keep_prob=0.8) | |||||
| """ | |||||
| def __init__(self, config): | |||||
| super(DeepFMModel, self).__init__() | |||||
| self.batch_size = config.batch_size | |||||
| self.field_size = config.data_field_size | |||||
| self.vocab_size = config.data_vocab_size | |||||
| self.emb_dim = config.data_emb_dim | |||||
| self.deep_layer_dims_list, self.deep_layer_act = config.deep_layer_args | |||||
| self.init_args = config.init_args | |||||
| self.weight_bias_init = config.weight_bias_init | |||||
| self.keep_prob = config.keep_prob | |||||
| init_acts = [('W_l2', [self.vocab_size, 1], 'normal'), | |||||
| ('V_l2', [self.vocab_size, self.emb_dim], 'normal'), | |||||
| ('b', [1], 'normal')] | |||||
| var_map = init_var_dict(self.init_args, init_acts) | |||||
| self.fm_w = var_map["W_l2"] | |||||
| self.fm_b = var_map["b"] | |||||
| self.embedding_table = var_map["V_l2"] | |||||
| # Deep Layers | |||||
| self.deep_input_dims = self.field_size * self.emb_dim + 1 | |||||
| self.all_dim_list = [self.deep_input_dims] + self.deep_layer_dims_list + [1] | |||||
| self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1], | |||||
| self.weight_bias_init, self.deep_layer_act, self.keep_prob) | |||||
| self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2], | |||||
| self.weight_bias_init, self.deep_layer_act, self.keep_prob) | |||||
| self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3], | |||||
| self.weight_bias_init, self.deep_layer_act, self.keep_prob) | |||||
| self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4], | |||||
| self.weight_bias_init, self.deep_layer_act, self.keep_prob) | |||||
| # FM, linear Layers | |||||
| self.Gatherv2 = P.GatherV2() | |||||
| self.Mul = P.Mul() | |||||
| self.ReduceSum = P.ReduceSum(keep_dims=False) | |||||
| self.Reshape = P.Reshape() | |||||
| self.Square = P.Square() | |||||
| self.Shape = P.Shape() | |||||
| self.Tile = P.Tile() | |||||
| self.Concat = P.Concat(axis=1) | |||||
| self.Cast = P.Cast() | |||||
| def construct(self, id_hldr, wt_hldr): | |||||
| """ | |||||
| Args: | |||||
| id_hldr: batch ids; [bs, field_size] | |||||
| wt_hldr: batch weights; [bs, field_size] | |||||
| """ | |||||
| mask = self.Reshape(wt_hldr, (self.batch_size, self.field_size, 1)) | |||||
| # Linear layer | |||||
| fm_id_weight = self.Gatherv2(self.fm_w, id_hldr, 0) | |||||
| wx = self.Mul(fm_id_weight, mask) | |||||
| linear_out = self.ReduceSum(wx, 1) | |||||
| # FM layer | |||||
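# Second-order FM term: 0.5 * ((sum_f v_f * x_f)^2 - sum_f (v_f * x_f)^2), reduced over the
# embedding dimension, which equals the sum of all pairwise feature interactions.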
| fm_id_embs = self.Gatherv2(self.embedding_table, id_hldr, 0) | |||||
| vx = self.Mul(fm_id_embs, mask) | |||||
| v1 = self.ReduceSum(vx, 1) | |||||
| v1 = self.Square(v1) | |||||
| v2 = self.Square(vx) | |||||
| v2 = self.ReduceSum(v2, 1) | |||||
| fm_out = 0.5 * self.ReduceSum(v1 - v2, 1) | |||||
| fm_out = self.Reshape(fm_out, (-1, 1)) | |||||
| # Deep layer | |||||
| b = self.Reshape(self.fm_b, (1, 1)) | |||||
| b = self.Tile(b, (self.batch_size, 1)) | |||||
| deep_in = self.Reshape(vx, (-1, self.field_size * self.emb_dim)) | |||||
| deep_in = self.Concat((deep_in, b)) | |||||
| deep_in = self.dense_layer_1(deep_in) | |||||
| deep_in = self.dense_layer_2(deep_in) | |||||
| deep_in = self.dense_layer_3(deep_in) | |||||
| deep_out = self.dense_layer_4(deep_in) | |||||
| out = linear_out + fm_out + deep_out | |||||
| return out, fm_id_weight, fm_id_embs | |||||
| class NetWithLossClass(nn.Cell): | |||||
| """ | |||||
| NetWithLossClass definition. | |||||
| """ | |||||
| def __init__(self, network, l2_coef=1e-6): | |||||
| super(NetWithLossClass, self).__init__(auto_prefix=False) | |||||
| self.loss = P.SigmoidCrossEntropyWithLogits() | |||||
| self.network = network | |||||
| self.l2_coef = l2_coef | |||||
| self.Square = P.Square() | |||||
| self.ReduceMean_false = P.ReduceMean(keep_dims=False) | |||||
| self.ReduceSum_false = P.ReduceSum(keep_dims=False) | |||||
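# construct() returns the mean sigmoid cross-entropy plus L2 regularization of the first-order
# weights and embeddings gathered for the batch.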
| def construct(self, batch_ids, batch_wts, label): | |||||
| predict, fm_id_weight, fm_id_embs = self.network(batch_ids, batch_wts) | |||||
| log_loss = self.loss(predict, label) | |||||
| mean_log_loss = self.ReduceMean_false(log_loss) | |||||
| l2_loss_w = self.ReduceSum_false(self.Square(fm_id_weight)) | |||||
| l2_loss_v = self.ReduceSum_false(self.Square(fm_id_embs)) | |||||
| l2_loss_all = self.l2_coef * (l2_loss_v + l2_loss_w) * 0.5 | |||||
| loss = mean_log_loss + l2_loss_all | |||||
| return loss | |||||
| class TrainStepWrap(nn.Cell): | |||||
| """ | |||||
| TrainStepWrap definition | |||||
| """ | |||||
| def __init__(self, network, lr=5e-8, eps=1e-8, loss_scale=1000.0): | |||||
| super(TrainStepWrap, self).__init__(auto_prefix=False) | |||||
| self.network = network | |||||
| self.network.set_train() | |||||
| self.weights = ParameterTuple(network.trainable_params()) | |||||
| self.optimizer = Adam(self.weights, learning_rate=lr, eps=eps, loss_scale=loss_scale) | |||||
| self.hyper_map = C.HyperMap() | |||||
| self.grad = C.GradOperation(get_by_list=True, sens_param=True) | |||||
| self.sens = loss_scale | |||||
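# Manual loss scaling: the loss gradient is seeded with sens = loss_scale before the backward pass,
# and the optimizer is constructed with the same loss_scale so gradients are scaled back down.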
| def construct(self, batch_ids, batch_wts, label): | |||||
| weights = self.weights | |||||
| loss = self.network(batch_ids, batch_wts, label) | |||||
| sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) # | |||||
| grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens) | |||||
| return F.depend(loss, self.optimizer(grads)) | |||||
| class PredictWithSigmoid(nn.Cell): | |||||
| """ | |||||
| Eval model with sigmoid. | |||||
| """ | |||||
| def __init__(self, network): | |||||
| super(PredictWithSigmoid, self).__init__(auto_prefix=False) | |||||
| self.network = network | |||||
| self.sigmoid = P.Sigmoid() | |||||
| def construct(self, batch_ids, batch_wts, labels): | |||||
logits, _, _ = self.network(batch_ids, batch_wts)
| pred_probs = self.sigmoid(logits) | |||||
| return logits, pred_probs, labels | |||||
| class ModelBuilder: | |||||
| """ | |||||
| Model builder for DeepFM. | |||||
| Args: | |||||
| model_config (ModelConfig): Model configuration. | |||||
| train_config (TrainConfig): Train configuration. | |||||
| """ | |||||
| def __init__(self, model_config, train_config): | |||||
| self.model_config = model_config | |||||
| self.train_config = train_config | |||||
| def get_callback_list(self, model=None, eval_dataset=None): | |||||
| """ | |||||
| Get callbacks which contains checkpoint callback, eval callback and loss callback. | |||||
| Args: | |||||
model (Model): The model the callbacks will be attached to (default=None).
| eval_dataset (Dataset): Dataset for eval (default=None). | |||||
| """ | |||||
| callback_list = [] | |||||
| if self.train_config.save_checkpoint: | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=self.train_config.save_checkpoint_steps, | |||||
| keep_checkpoint_max=self.train_config.keep_checkpoint_max) | |||||
| ckpt_cb = ModelCheckpoint(prefix=self.train_config.ckpt_file_name_prefix, | |||||
| directory=self.train_config.output_path, | |||||
| config=config_ck) | |||||
| callback_list.append(ckpt_cb) | |||||
| if self.train_config.eval_callback: | |||||
| if model is None: | |||||
| raise RuntimeError("train_config.eval_callback is {}; get_callback_list() args model is {}".format( | |||||
| self.train_config.eval_callback, model)) | |||||
| if eval_dataset is None: | |||||
| raise RuntimeError("train_config.eval_callback is {}; get_callback_list() " | |||||
| "args eval_dataset is {}".format(self.train_config.eval_callback, eval_dataset)) | |||||
| auc_metric = AUCMetric() | |||||
| eval_callback = EvalCallBack(model, eval_dataset, auc_metric, | |||||
| eval_file_path=os.path.join(self.train_config.output_path, | |||||
| self.train_config.eval_file_name)) | |||||
| callback_list.append(eval_callback) | |||||
| if self.train_config.loss_callback: | |||||
| loss_callback = LossCallBack(loss_file_path=os.path.join(self.train_config.output_path, | |||||
| self.train_config.loss_file_name)) | |||||
| callback_list.append(loss_callback) | |||||
| if callback_list: | |||||
| return callback_list | |||||
| return None | |||||
| def get_train_eval_net(self): | |||||
| deepfm_net = DeepFMModel(self.model_config) | |||||
| loss_net = NetWithLossClass(deepfm_net, l2_coef=self.train_config.l2_coef) | |||||
| train_net = TrainStepWrap(loss_net, lr=self.train_config.learning_rate, | |||||
| eps=self.train_config.epsilon, | |||||
| loss_scale=self.train_config.loss_scale) | |||||
| eval_net = PredictWithSigmoid(deepfm_net) | |||||
| return train_net, eval_net | |||||
| @@ -0,0 +1,80 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """train_criteo.""" | |||||
| import os | |||||
| import pytest | |||||
| from mindspore import context | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.common import set_seed | |||||
| from src.deepfm import ModelBuilder, AUCMetric | |||||
| from src.config import DataConfig, ModelConfig, TrainConfig | |||||
| from src.dataset import create_dataset, DataType | |||||
| from src.callback import EvalCallBack, LossCallBack, TimeMonitor | |||||
| set_seed(1) | |||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_deepfm(): | |||||
| data_config = DataConfig() | |||||
| train_config = TrainConfig() | |||||
| device_id = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id) | |||||
| rank_size = None | |||||
| rank_id = None | |||||
| dataset_path = "/home/workspace/mindspore_dataset/criteo_data/criteo_h5/" | |||||
| print("dataset_path:", dataset_path) | |||||
| ds_train = create_dataset(dataset_path, | |||||
| train_mode=True, | |||||
| epochs=1, | |||||
| batch_size=train_config.batch_size, | |||||
| data_type=DataType(data_config.data_format), | |||||
| rank_size=rank_size, | |||||
| rank_id=rank_id) | |||||
| model_builder = ModelBuilder(ModelConfig, TrainConfig) | |||||
| train_net, eval_net = model_builder.get_train_eval_net() | |||||
| auc_metric = AUCMetric() | |||||
| model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) | |||||
| loss_file_name = './loss.log' | |||||
| time_callback = TimeMonitor(data_size=ds_train.get_dataset_size()) | |||||
| loss_callback = LossCallBack(loss_file_path=loss_file_name) | |||||
| callback_list = [time_callback, loss_callback] | |||||
| eval_file_name = './auc.log' | |||||
| ds_eval = create_dataset(dataset_path, train_mode=False, | |||||
| epochs=1, | |||||
| batch_size=train_config.batch_size, | |||||
| data_type=DataType(data_config.data_format)) | |||||
| eval_callback = EvalCallBack(model, ds_eval, auc_metric, | |||||
| eval_file_path=eval_file_name) | |||||
| callback_list.append(eval_callback) | |||||
| print("train_config.train_epochs:", train_config.train_epochs) | |||||
| model.train(train_config.train_epochs, ds_train, callbacks=callback_list) | |||||
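# ST-test thresholds: the final loss and the per-step time reported by the callbacks must stay below these bounds.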
expect_loss_value = 0.51
print("loss_callback.loss:", loss_callback.loss)
assert loss_callback.loss < expect_loss_value
expect_per_step_time = 10.4
print("time_callback:", time_callback.per_step_time)
assert time_callback.per_step_time < expect_per_step_time
| print("*******test case pass!********") | |||||