You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

preprocess_data.py 11 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Download raw data and preprocessed data."""
  16. import os
  17. import pickle
  18. import collections
  19. import argparse
  20. import numpy as np
  21. from mindspore.mindrecord import FileWriter
  # Number of line indices shuffled by random_split_trans2mindrecord to pick the
  # held-out test lines; presumably the line count of Criteo train.txt — TODO confirm.
  22. TRAIN_LINE_COUNT = 45840617
  # Not referenced anywhere in the visible code; presumably the line count of
  # Criteo test.txt — TODO confirm.
  23. TEST_LINE_COUNT = 6042135
  24. class CriteoStatsDict():
  25. """preprocessed data"""
  26. def __init__(self):
  27. self.field_size = 39
  28. self.val_cols = ["val_{}".format(i + 1) for i in range(13)]
  29. self.cat_cols = ["cat_{}".format(i + 1) for i in range(26)]
  30. self.val_min_dict = {col: 0 for col in self.val_cols}
  31. self.val_max_dict = {col: 0 for col in self.val_cols}
  32. self.cat_count_dict = {col: collections.defaultdict(int) for col in self.cat_cols}
  33. self.oov_prefix = "OOV"
  34. self.cat2id_dict = {}
  35. self.cat2id_dict.update({col: i for i, col in enumerate(self.val_cols)})
  36. self.cat2id_dict.update(
  37. {self.oov_prefix + col: i + len(self.val_cols) for i, col in enumerate(self.cat_cols)})
  38. def stats_vals(self, val_list):
  39. """Handling weights column"""
  40. assert len(val_list) == len(self.val_cols)
  41. def map_max_min(i, val):
  42. key = self.val_cols[i]
  43. if val != "":
  44. if float(val) > self.val_max_dict[key]:
  45. self.val_max_dict[key] = float(val)
  46. if float(val) < self.val_min_dict[key]:
  47. self.val_min_dict[key] = float(val)
  48. for i, val in enumerate(val_list):
  49. map_max_min(i, val)
  50. def stats_cats(self, cat_list):
  51. """Handling cats column"""
  52. assert len(cat_list) == len(self.cat_cols)
  53. def map_cat_count(i, cat):
  54. key = self.cat_cols[i]
  55. self.cat_count_dict[key][cat] += 1
  56. for i, cat in enumerate(cat_list):
  57. map_cat_count(i, cat)
  58. def save_dict(self, dict_path, prefix=""):
  59. with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "wb") as file_wrt:
  60. pickle.dump(self.val_max_dict, file_wrt)
  61. with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "wb") as file_wrt:
  62. pickle.dump(self.val_min_dict, file_wrt)
  63. with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "wb") as file_wrt:
  64. pickle.dump(self.cat_count_dict, file_wrt)
  65. def load_dict(self, dict_path, prefix=""):
  66. with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "rb") as file_wrt:
  67. self.val_max_dict = pickle.load(file_wrt)
  68. with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "rb") as file_wrt:
  69. self.val_min_dict = pickle.load(file_wrt)
  70. with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "rb") as file_wrt:
  71. self.cat_count_dict = pickle.load(file_wrt)
  72. print("val_max_dict.items()[:50]:{}".format(list(self.val_max_dict.items())))
  73. print("val_min_dict.items()[:50]:{}".format(list(self.val_min_dict.items())))
  74. def get_cat2id(self, threshold=100):
  75. for key, cat_count_d in self.cat_count_dict.items():
  76. new_cat_count_d = dict(filter(lambda x: x[1] > threshold, cat_count_d.items()))
  77. for cat_str, _ in new_cat_count_d.items():
  78. self.cat2id_dict[key + "_" + cat_str] = len(self.cat2id_dict)
  79. print("cat2id_dict.size:{}".format(len(self.cat2id_dict)))
  80. print("cat2id.dict.items()[:50]:{}".format(list(self.cat2id_dict.items())[:50]))
  81. def map_cat2id(self, values, cats):
  82. """Cat to id"""
  83. def minmax_scale_value(i, val):
  84. max_v = float(self.val_max_dict["val_{}".format(i + 1)])
  85. return float(val) * 1.0 / max_v
  86. id_list = []
  87. weight_list = []
  88. for i, val in enumerate(values):
  89. if val == "":
  90. id_list.append(i)
  91. weight_list.append(0)
  92. else:
  93. key = "val_{}".format(i + 1)
  94. id_list.append(self.cat2id_dict[key])
  95. weight_list.append(minmax_scale_value(i, float(val)))
  96. for i, cat_str in enumerate(cats):
  97. key = "cat_{}".format(i + 1) + "_" + cat_str
  98. if key in self.cat2id_dict:
  99. id_list.append(self.cat2id_dict[key])
  100. else:
  101. id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)])
  102. weight_list.append(1.0)
  103. return id_list, weight_list
  104. def mkdir_path(file_path):
  105. if not os.path.exists(file_path):
  106. os.makedirs(file_path)
  107. def statsdata(file_path, dict_output_path, criteo_stats_dict):
  108. """Preprocess data and save data"""
  109. with open(file_path, encoding="utf-8") as file_in:
  110. errorline_list = []
  111. count = 0
  112. for line in file_in:
  113. count += 1
  114. line = line.strip("\n")
  115. items = line.split("\t")
  116. if len(items) != 40:
  117. errorline_list.append(count)
  118. print("line: {}".format(line))
  119. continue
  120. if count % 1000000 == 0:
  121. print("Have handled {}w lines.".format(count // 10000))
  122. values = items[1:14]
  123. cats = items[14:]
  124. assert len(values) == 13, "values.size: {}".format(len(values))
  125. assert len(cats) == 26, "cats.size: {}".format(len(cats))
  126. criteo_stats_dict.stats_vals(values)
  127. criteo_stats_dict.stats_cats(cats)
  128. criteo_stats_dict.save_dict(dict_output_path)
  129. def random_split_trans2mindrecord(input_file_path, output_file_path, criteo_stats_dict, part_rows=2000000,
  130. line_per_sample=1000,
  131. test_size=0.1, seed=2020):
  132. """Random split data and save mindrecord"""
  133. test_size = int(TRAIN_LINE_COUNT * test_size)
  134. all_indices = [i for i in range(TRAIN_LINE_COUNT)]
  135. np.random.seed(seed)
  136. np.random.shuffle(all_indices)
  137. print("all_indices.size:{}".format(len(all_indices)))
  138. test_indices_set = set(all_indices[:test_size])
  139. print("test_indices_set.size:{}".format(len(test_indices_set)))
  140. print("-----------------------" * 10 + "\n" * 2)
  141. train_data_list = []
  142. test_data_list = []
  143. ids_list = []
  144. wts_list = []
  145. label_list = []
  146. writer_train = FileWriter(os.path.join(output_file_path, "train_input_part.mindrecord"), 21)
  147. writer_test = FileWriter(os.path.join(output_file_path, "test_input_part.mindrecord"), 3)
  148. schema = {"label": {"type": "float32", "shape": [-1]}, "feat_vals": {"type": "float32", "shape": [-1]},
  149. "feat_ids": {"type": "int32", "shape": [-1]}}
  150. writer_train.add_schema(schema, "CRITEO_TRAIN")
  151. writer_test.add_schema(schema, "CRITEO_TEST")
  152. with open(input_file_path, encoding="utf-8") as file_in:
  153. items_error_size_lineCount = []
  154. count = 0
  155. train_part_number = 0
  156. test_part_number = 0
  157. for i, line in enumerate(file_in):
  158. count += 1
  159. if count % 1000000 == 0:
  160. print("Have handle {}w lines.".format(count // 10000))
  161. line = line.strip("\n")
  162. items = line.split("\t")
  163. if len(items) != 40:
  164. items_error_size_lineCount.append(i)
  165. continue
  166. label = float(items[0])
  167. values = items[1:14]
  168. cats = items[14:]
  169. assert len(values) == 13, "values.size: {}".format(len(values))
  170. assert len(cats) == 26, "cats.size: {}".format(len(cats))
  171. ids, wts = criteo_stats_dict.map_cat2id(values, cats)
  172. ids_list.extend(ids)
  173. wts_list.extend(wts)
  174. label_list.append(label)
  175. if count % line_per_sample == 0:
  176. if i not in test_indices_set:
  177. train_data_list.append({"feat_ids": np.array(ids_list, dtype=np.int32),
  178. "feat_vals": np.array(wts_list, dtype=np.float32),
  179. "label": np.array(label_list, dtype=np.float32)
  180. })
  181. else:
  182. test_data_list.append({"feat_ids": np.array(ids_list, dtype=np.int32),
  183. "feat_vals": np.array(wts_list, dtype=np.float32),
  184. "label": np.array(label_list, dtype=np.float32)
  185. })
  186. if train_data_list and len(train_data_list) % part_rows == 0:
  187. writer_train.write_raw_data(train_data_list)
  188. train_data_list.clear()
  189. train_part_number += 1
  190. if test_data_list and len(test_data_list) % part_rows == 0:
  191. writer_test.write_raw_data(test_data_list)
  192. test_data_list.clear()
  193. test_part_number += 1
  194. ids_list.clear()
  195. wts_list.clear()
  196. label_list.clear()
  197. if train_data_list:
  198. writer_train.write_raw_data(train_data_list)
  199. if test_data_list:
  200. writer_test.write_raw_data(test_data_list)
  201. writer_train.commit()
  202. writer_test.commit()
  203. print("-------------" * 10)
  204. print("items_error_size_lineCount.size(): {}.".format(len(items_error_size_lineCount)))
  205. print("-------------" * 10)
  206. np.save("items_error_size_lineCount.npy", items_error_size_lineCount)
  207. if __name__ == '__main__':
  208. parser = argparse.ArgumentParser(description="criteo data")
  209. parser.add_argument("--data_path", type=str, default="./criteo_data/")
  210. args, _ = parser.parse_known_args()
  211. data_path = args.data_path
  212. download_data_path = data_path + "origin_data/"
  213. mkdir_path(download_data_path)
  214. os.system(
  215. "wget -P {} -c https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz --no-check-certificate".format(
  216. download_data_path))
  217. os.system("tar -zxvf {}dac.tar.gz".format(download_data_path))
  218. criteo_stats = CriteoStatsDict()
  219. data_file_path = data_path + "origin_data/train.txt"
  220. stats_output_path = data_path + "stats_dict/"
  221. mkdir_path(stats_output_path)
  222. statsdata(data_file_path, stats_output_path, criteo_stats)
  223. criteo_stats.load_dict(dict_path=stats_output_path, prefix="")
  224. criteo_stats.get_cat2id(threshold=100)
  225. in_file_path = data_path + "origin_data/train.txt"
  226. output_path = data_path + "mindrecord/"
  227. mkdir_path(output_path)
  228. random_split_trans2mindrecord(in_file_path, output_path, criteo_stats, part_rows=2000000, line_per_sample=1000,
  229. test_size=0.1, seed=2020)