# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
| """ | |||
| Criteo data process | |||
| """ | |||
import os
import pickle
import collections
import argparse

import numpy as np
import pandas as pd

TRAIN_LINE_COUNT = 45840617
TEST_LINE_COUNT = 6042135
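

# The raw Criteo files are tab-separated: one label field, 13 numeric fields
# (val_1..val_13) and 26 categorical fields (cat_1..cat_26), 40 fields per
# line in total; missing fields are empty strings.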
class CriteoStatsDict:
    """Collect min/max statistics for numeric columns and occurrence counts for categorical columns."""

    def __init__(self):
        self.field_size = 39  # 13 numeric fields (val_1..val_13) + 26 categorical fields (cat_1..cat_26)
        self.val_cols = ["val_{}".format(i + 1) for i in range(13)]
        self.cat_cols = ["cat_{}".format(i + 1) for i in range(26)]

        self.val_min_dict = {col: 0 for col in self.val_cols}
        self.val_max_dict = {col: 0 for col in self.val_cols}
        self.cat_count_dict = {col: collections.defaultdict(int) for col in self.cat_cols}

        # Ids 0..12 are reserved for the numeric fields and ids 13..38 for the
        # per-field OOV buckets; frequent categorical values are appended later
        # by get_cat2id.
        self.oov_prefix = "OOV_"
        self.cat2id_dict = {}
        self.cat2id_dict.update({col: i for i, col in enumerate(self.val_cols)})
        self.cat2id_dict.update({self.oov_prefix + col: i + len(self.val_cols)
                                 for i, col in enumerate(self.cat_cols)})

    def stats_vals(self, val_list):
        """Update the running min/max for each numeric column."""
        assert len(val_list) == len(self.val_cols)

        def map_max_min(i, val):
            key = self.val_cols[i]
            if val != "":
                if float(val) > self.val_max_dict[key]:
                    self.val_max_dict[key] = float(val)
                if float(val) < self.val_min_dict[key]:
                    self.val_min_dict[key] = float(val)

        for i, val in enumerate(val_list):
            map_max_min(i, val)

    def stats_cats(self, cat_list):
        """Update the occurrence count of each categorical value."""
        assert len(cat_list) == len(self.cat_cols)

        def map_cat_count(i, cat):
            key = self.cat_cols[i]
            self.cat_count_dict[key][cat] += 1

        for i, cat in enumerate(cat_list):
            map_cat_count(i, cat)

    def save_dict(self, output_path, prefix=""):
        """Pickle the collected statistics to output_path."""
        with open(os.path.join(output_path, "{}val_max_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.val_max_dict, file_wrt)
        with open(os.path.join(output_path, "{}val_min_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.val_min_dict, file_wrt)
        with open(os.path.join(output_path, "{}cat_count_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.cat_count_dict, file_wrt)

    def load_dict(self, dict_path, prefix=""):
        """Load previously pickled statistics from dict_path."""
        with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.val_max_dict = pickle.load(file_wrt)
        with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.val_min_dict = pickle.load(file_wrt)
        with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.cat_count_dict = pickle.load(file_wrt)
        print("val_max_dict.items(): {}".format(list(self.val_max_dict.items())))
        print("val_min_dict.items(): {}".format(list(self.val_min_dict.items())))

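    # Categorical values seen no more than `threshold` times are dropped from the
    # vocabulary and collapse into their field's OOV bucket at encoding time
    # (counts logged in the original run: 33762577 raw values -> 184926 kept ids).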
    def get_cat2id(self, threshold=100):
        """Assign an embedding id to every categorical value seen more than threshold times."""
        for key, cat_count_d in self.cat_count_dict.items():
            new_cat_count_d = dict(filter(lambda x: x[1] > threshold, cat_count_d.items()))
            for cat_str, _ in new_cat_count_d.items():
                self.cat2id_dict[key + "_" + cat_str] = len(self.cat2id_dict)
        print("cat2id_dict.size: {}".format(len(self.cat2id_dict)))
        print("cat2id_dict.items()[:50]: {}".format(list(self.cat2id_dict.items())[:50]))

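    # map_cat2id encodes one raw sample as two parallel lists of length 39:
    # embedding ids and weights. Numeric fields keep their fixed column id and
    # use the max-scaled value as the weight (0 when the field is missing);
    # categorical fields look up "cat_<i>_<value>" and fall back to the field's
    # OOV bucket, always with weight 1.0.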
    def map_cat2id(self, values, cats):
        """Map one sample's raw values to (id_list, weight_list)."""
        def minmax_scale_value(i, val):
            # Scale the raw value by the column's observed maximum.
            max_v = float(self.val_max_dict["val_{}".format(i + 1)])
            return float(val) * 1.0 / max_v

        id_list = []
        weight_list = []
        for i, val in enumerate(values):
            if val == "":
                id_list.append(i)  # cat2id_dict maps "val_{i+1}" to i by construction
                weight_list.append(0)
            else:
                key = "val_{}".format(i + 1)
                id_list.append(self.cat2id_dict[key])
                weight_list.append(minmax_scale_value(i, float(val)))

        for i, cat_str in enumerate(cats):
            key = "cat_{}".format(i + 1) + "_" + cat_str
            if key in self.cat2id_dict:
                id_list.append(self.cat2id_dict[key])
            else:
                id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)])
            weight_list.append(1.0)
        return id_list, weight_list


def mkdir_path(file_path):
    if not os.path.exists(file_path):
        os.makedirs(file_path)


def statsdata(data_file_path, output_path, criteo_stats):
    """Stream the raw data once and accumulate the feature statistics."""
    with open(data_file_path, encoding="utf-8") as file_in:
        errorline_list = []
        count = 0
        for line in file_in:
            count += 1
            line = line.strip("\n")
            items = line.split("\t")
            if len(items) != 40:
                errorline_list.append(count)
                print("line: {}".format(line))
                continue
            if count % 1000000 == 0:
                print("Handled {}w lines.".format(count // 10000))
            values = items[1:14]
            cats = items[14:]
            assert len(values) == 13, "value.size: {}".format(len(values))
            assert len(cats) == 26, "cat.size: {}".format(len(cats))
            criteo_stats.stats_vals(values)
            criteo_stats.stats_cats(cats)
    criteo_stats.save_dict(output_path)


def add_write(file_path, wr_str):
    with open(file_path, "a", encoding="utf-8") as file_out:
        file_out.write(wr_str + "\n")


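# random_split_trans2h5 shuffles the row indices 0..TRAIN_LINE_COUNT-1 once with
# a fixed seed, holds out `test_size` of them as the test split, and flushes
# every `part_rows` encoded rows to a numbered HDF5 part file so the whole
# dataset never has to sit in memory at once.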
def random_split_trans2h5(in_file_path, output_path, criteo_stats, part_rows=2000000, test_size=0.1, seed=2020):
    """Randomly split the raw data and transform it into HDF5 part files."""
    test_size = int(TRAIN_LINE_COUNT * test_size)
    all_indices = [i for i in range(TRAIN_LINE_COUNT)]
    np.random.seed(seed)
    np.random.shuffle(all_indices)
    print("all_indices.size: {}".format(len(all_indices)))
    test_indices_set = set(all_indices[:test_size])
    print("test_indices_set.size: {}".format(len(test_indices_set)))
    print("------" * 10 + "\n" * 2)
    train_feature_file_name = os.path.join(output_path, "train_input_part_{}.h5")
    train_label_file_name = os.path.join(output_path, "train_output_part_{}.h5")
    test_feature_file_name = os.path.join(output_path, "test_input_part_{}.h5")
    test_label_file_name = os.path.join(output_path, "test_output_part_{}.h5")
    train_feature_list = []
    train_label_list = []
    test_feature_list = []
    test_label_list = []
    with open(in_file_path, encoding="utf-8") as file_in:
        count = 0
        train_part_number = 0
        test_part_number = 0
        for i, line in enumerate(file_in):
            count += 1
            if count % 1000000 == 0:
                print("Handled {}w lines.".format(count // 10000))
            line = line.strip("\n")
            items = line.split("\t")
            if len(items) != 40:
                continue
            label = float(items[0])
            values = items[1:14]
            cats = items[14:]
            assert len(values) == 13, "value.size: {}".format(len(values))
            assert len(cats) == 26, "cat.size: {}".format(len(cats))
            ids, wts = criteo_stats.map_cat2id(values, cats)
            if i not in test_indices_set:
                train_feature_list.append(ids + wts)
                train_label_list.append(label)
            else:
                test_feature_list.append(ids + wts)
                test_label_list.append(label)
            if train_label_list and (len(train_label_list) % part_rows == 0):
                pd.DataFrame(np.asarray(train_feature_list)).to_hdf(train_feature_file_name.format(train_part_number),
                                                                    key="fixed")
                pd.DataFrame(np.asarray(train_label_list)).to_hdf(train_label_file_name.format(train_part_number),
                                                                  key="fixed")
                train_feature_list = []
                train_label_list = []
                train_part_number += 1
            if test_label_list and (len(test_label_list) % part_rows == 0):
                pd.DataFrame(np.asarray(test_feature_list)).to_hdf(test_feature_file_name.format(test_part_number),
                                                                   key="fixed")
                pd.DataFrame(np.asarray(test_label_list)).to_hdf(test_label_file_name.format(test_part_number),
                                                                 key="fixed")
                test_feature_list = []
                test_label_list = []
                test_part_number += 1

    # Flush whatever is left over after the last full part.
    if train_label_list:
        pd.DataFrame(np.asarray(train_feature_list)).to_hdf(train_feature_file_name.format(train_part_number),
                                                            key="fixed")
        pd.DataFrame(np.asarray(train_label_list)).to_hdf(train_label_file_name.format(train_part_number),
                                                          key="fixed")
    if test_label_list:
        pd.DataFrame(np.asarray(test_feature_list)).to_hdf(test_feature_file_name.format(test_part_number),
                                                           key="fixed")
        pd.DataFrame(np.asarray(test_label_list)).to_hdf(test_label_file_name.format(test_part_number),
                                                         key="fixed")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Get and process the Criteo dataset")
    parser.add_argument("--raw_data_path", default="/opt/npu/data/origin_criteo_data/",
                        help="The path of the raw Criteo data")
    parser.add_argument("--output_path", default="/opt/npu/data/origin_criteo_data/h5_data/",
                        help="The path to save the HDF5 files")
    args, _ = parser.parse_known_args()
    base_path = args.raw_data_path
    criteo_stat = CriteoStatsDict()
    # Step 1: collect the statistics used for the vocabulary and value normalization.
    datafile_path = base_path + "train_small.txt"
    stats_out_path = base_path + "stats_dict/"
    mkdir_path(stats_out_path)
    statsdata(datafile_path, stats_out_path, criteo_stat)
    print("------" * 10)
    criteo_stat.load_dict(dict_path=stats_out_path, prefix="")
    criteo_stat.get_cat2id(threshold=100)
    # Step 2: transform the data to HDF5 after an np.random.shuffle-based split.
    infile_path = base_path + "train_small.txt"
    mkdir_path(args.output_path)
    random_split_trans2h5(infile_path, args.output_path, criteo_stat, part_rows=2000000, test_size=0.1, seed=2020)
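
# Example invocation (the script filename is illustrative, and the paths should
# point at the directory holding the raw train_small.txt):
#   python process_data.py --raw_data_path=/path/to/criteo/ --output_path=/path/to/criteo/h5_data/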