@@ -22,17 +22,18 @@ import tarfile
import numpy as np
from mindspore.mindrecord import FileWriter
TRAIN_LINE_COUNT = 45840617
TEST_LINE_COUNT = 6042135
class CriteoStatsDict():
class StatsDict():
"""preprocessed data"""
def __init__(self):
self.field_size = 39
self.val_cols = ["val_{}".format(i + 1) for i in range(13)]
self.cat_cols = ["cat_{}".format(i + 1) for i in range(26)]
def __init__(self, field_size, dense_dim, slot_dim, skip_id_convert):
self.field_size = field_size
self.dense_dim = dense_dim
self.slot_dim = slot_dim
self.skip_id_convert = bool(skip_id_convert)
self.val_cols = ["val_{}".format(i + 1) for i in range(self.dense_dim)]
self.cat_cols = ["cat_{}".format(i + 1) for i in range(self.slot_dim)]
self.val_min_dict = {col: 0 for col in self.val_cols}
self.val_max_dict = {col: 0 for col in self.val_cols}
@@ -120,7 +121,14 @@ class CriteoStatsDict():
for i, cat_str in enumerate(cats):
key = "cat_{}".format(i + 1) + "_" + cat_str
if key in self.cat2id_dict:
id_list.append(self.cat2id_dict[key])
if self.skip_id_convert is True:
# For synthetic data, the generated ids lie in [0, max_vocab], but if the number of examples is
# less than vocab_size / slot_nums, the ids would still be converted into [0, real_vocab], where
# real_vocab is the actual vocabulary size rather than max_vocab. A simple way to alleviate this
# problem is to skip the id conversion and treat the synthetic data id as the final id.
id_list.append(cat_str)
else:
id_list.append(self.cat2id_dict[key])
else:
id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)])
weight_list.append(1.0)
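For reference, here is a minimal, self-contained sketch (not part of the patch) of the mapping this branch implements; the toy cat2id_dict and the "OOV_" prefix are assumptions for illustration only:

cat2id_dict = {"cat_1_13a4": 0, "cat_2_9f2b": 1, "OOV_cat_1": 2, "OOV_cat_2": 3}

def map_one(i, cat_str, skip_id_convert):
    key = "cat_{}".format(i + 1) + "_" + cat_str
    if key in cat2id_dict:
        # With skip_id_convert the raw (synthetic) id is kept as the final id;
        # otherwise it is re-mapped through the vocabulary built from the stats.
        return cat_str if skip_id_convert else cat2id_dict[key]
    # Unseen categories fall back to the per-slot OOV id.
    return cat2id_dict["OOV_" + "cat_{}".format(i + 1)]

print(map_one(0, "13a4", skip_id_convert=False))   # 0 -> converted id
print(map_one(0, "13a4", skip_id_convert=True))    # '13a4' -> raw id kept
print(map_one(1, "unseen", skip_id_convert=True))  # 3 -> OOV id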
@@ -132,7 +140,7 @@ def mkdir_path(file_path):
os.makedirs(file_path)
def statsdata(file_path, dict_output_path, criteo_stats_dict):
def statsdata(file_path, dict_output_path, criteo_stats_dict, dense_dim=13, slot_dim=26):
"""Preprocess data and save data"""
with open(file_path, encoding="utf-8") as file_in:
errorline_list = []
@@ -141,28 +149,31 @@ def statsdata(file_path, dict_output_path, criteo_stats_dict):
count += 1
line = line.strip("\n")
items = line.split("\t")
if len(items) != 40:
if len(items) != (dense_dim + slot_dim + 1):
errorline_list.append(count)
print("line: {}".format(line))
print("Found line length: {}, suppose to be {}, the line is {}".format(len(items),
dense_dim + slot_dim + 1, line))
continue
if count % 1000000 == 0:
print("Have handled {}w lines.".format(count // 10000))
values = items[1:14]
cats = items[14:]
values = items[1: dense_dim + 1]
cats = items[dense_dim + 1:]
assert len(values) == 13, "values.size: {}".format(len(values))
assert len(cats) == 26, "cats.size: {}".format(len(cats))
assert len(values) == dense_dim, "values.size: {}".format(len(values))
assert len(cats) == slot_dim, "cats.size: {}".format(len(cats))
criteo_stats_dict.stats_vals(values)
criteo_stats_dict.stats_cats(cats)
criteo_stats_dict.save_dict(dict_output_path)
def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stats_dict, part_rows=2000000,
line_per_sample=1000,
test_size=0.1, seed=2020):
line_per_sample=1000, train_line_count=None,
test_size=0.1, seed=2020, dense_dim=13, slot_dim=26):
"""Random split data and save mindrecord"""
test_size = int(TRAIN_LINE_COUNT * test_size)
all_indices = [i for i in range(TRAIN_LINE_COUNT)]
if train_line_count is None:
raise ValueError("Please provide training file line count")
test_size = int(train_line_count * test_size)
all_indices = [i for i in range(train_line_count)]
np.random.seed(seed)
np.random.shuffle(all_indices)
print("all_indices.size:{}".format(len(all_indices)))
@@ -195,15 +206,15 @@ def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stat
print("Have handle {}w lines.".format(count // 10000))
line = line.strip("\n")
items = line.split("\t")
if len(items) != 40:
if len(items) != (1 + dense_dim + slot_dim):
items_error_size_lineCount.append(i)
continue
label = float(items[0])
values = items[1:14]
cats = items[14:]
values = items[1:1 + dense_dim]
cats = items[1 + dense_dim:]
assert len(values) == 13, "values.size: {}".format(len(values))
assert len(cats) == 26, "cats.size: {}".format(len(cats))
assert len(values) == dense_dim, "values.size: {}".format(len(values))
assert len(cats) == slot_dim, "cats.size: {}".format(len(cats))
ids, wts = criteo_stats_dict.map_cat2id(values, cats)
@@ -251,35 +262,48 @@ def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stat
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="criteo data")
parser.add_argument("--data_path", type=str, default="./criteo_data/")
parser.add_argument("--data_type", type=str, default='criteo', choices=['criteo', 'synthetic'],
help='Currently we support criteo dataset and synthetic dataset')
parser.add_argument("--data_path", type=str, default="./criteo_data/", help='The path of the data file')
parser.add_argument("--dense_dim", type=int, default=13, help='The number of your continues fields')
parser.add_argument("--slot_dim", type=int, default=26,
help='The number of your sparse fields, it can also be called catelogy features.')
parser.add_argument("--threshold", type=int, default=100,
help='Words with frequency below this threshold will be regarded as OOV. It aims to reduce the vocab size.')
parser.add_argument("--train_line_count", type=int, help='The number of examples in your dataset')
parser.add_argument("--skip_id_convert", type=int, default=0, choices=[0, 1],
help='Skip the id convert, regarding the original id as the final id.')
args, _ = parser.parse_known_args()
data_path = args.data_path
download_data_path = data_path + "origin_data/"
mkdir_path(download_data_path)
url = "https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz"
file_name = download_data_path + '/' + url.split('/')[-1]
urllib.request.urlretrieve(url, filename=file_name)
tar = tarfile.open(file_name)
names = tar.getnames()
for name in names:
tar.extract(name, path=download_data_path)
tar.close()
criteo_stats = CriteoStatsDict()
if args.data_type == 'criteo':
download_data_path = data_path + "origin_data/"
mkdir_path(download_data_path)
url = "https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz"
file_name = download_data_path + '/' + url.split('/')[-1]
urllib.request.urlretrieve(url, filename=file_name)
tar = tarfile.open(file_name)
names = tar.getnames()
for name in names:
tar.extract(name, path=download_data_path)
tar.close()
target_field_size = args.dense_dim + args.slot_dim
stats = StatsDict(field_size=target_field_size, dense_dim=args.dense_dim, slot_dim=args.slot_dim,
skip_id_convert=args.skip_id_convert)
data_file_path = data_path + "origin_data/train.txt"
stats_output_path = data_path + "stats_dict/"
mkdir_path(stats_output_path)
statsdata(data_file_path, stats_output_path, criteo_stats)
statsdata(data_file_path, stats_output_path, stats, dense_dim=args.dense_dim, slot_dim=args.slot_dim)
criteo_stats.load_dict(dict_path=stats_output_path, prefix="")
criteo_stats.get_cat2id(threshold=100)
stats.load_dict(dict_path=stats_output_path, prefix="")
stats.get_cat2id(threshold=args.threshold)
in_file_path = data_path + "origin_data/train.txt"
output_path = data_path + "mindrecord/"
mkdir_path(output_path)
random_split_trans2mindrecord(in_file_path, output_path, criteo_stats, part_rows=2000000, line_per_sample=1000,
test_size=0.1, seed=2020)
random_split_trans2mindrecord(in_file_path, output_path, stats, part_rows=2000000,
train_line_count=args.train_line_count, line_per_sample=1000,
test_size=0.1, seed=2020, dense_dim=args.dense_dim, slot_dim=args.slot_dim)
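For reference, assuming the script is saved as preprocess_data.py (the file name here is illustrative), the Criteo path could then be run along these lines:

python preprocess_data.py --data_path=./criteo_data/ --data_type=criteo --dense_dim=13 --slot_dim=26 --threshold=100 --train_line_count=45840617

A pre-generated synthetic dataset would instead use --data_type=synthetic with its own --train_line_count and, if the generated ids should be used as-is, --skip_id_convert=1.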