
dataset.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset loading, creation and processing"""
import logging
import math
import os
import time
import timeit
import pickle

import numpy as np
import pandas as pd
from mindspore.dataset import GeneratorDataset, Sampler

import src.constants as rconst
import src.movielens as movielens
import src.stat_utils as stat_utils

DATASET_TO_NUM_USERS_AND_ITEMS = {
    "ml-1m": (6040, 3706),
    "ml-20m": (138493, 26744)
}

_EXPECTED_CACHE_KEYS = (
    rconst.TRAIN_USER_KEY, rconst.TRAIN_ITEM_KEY, rconst.EVAL_USER_KEY,
    rconst.EVAL_ITEM_KEY, rconst.USER_MAP, rconst.ITEM_MAP)


def load_data(data_dir, dataset):
    """
    Load data in .csv format and output structured data.

    This function reads in the raw CSV of positive items, and performs three
    preprocessing transformations:

    1)  Filter out all users who have not rated at least a certain number
        of items. (Typically 20 items)

    2)  Zero index the users and items such that the largest user_id is
        `num_users - 1` and the largest item_id is `num_items - 1`.

    3)  Sort the dataframe by user_id, with timestamp as a secondary sort key.
        This allows the dataframe to be sliced by user in-place, and for the
        last item to be selected simply by calling the `-1` index of a user's
        slice.

    While all of these transformations are performed by Pandas (and are
    therefore single-threaded), they only take ~2 minutes, and the overhead of
    applying a MapReduce pattern to process the dataset in parallel adds
    significant complexity for no computational gain. For a larger dataset,
    parallelizing this preprocessing could yield speedups. (Also, this
    preprocessing is only performed once for an entire run.)
    """
    logging.info("Beginning loading data...")

    raw_rating_path = os.path.join(data_dir, dataset, movielens.RATINGS_FILE)
    cache_path = os.path.join(data_dir, dataset, rconst.RAW_CACHE_FILE)

    valid_cache = os.path.exists(cache_path)
    if valid_cache:
        with open(cache_path, 'rb') as f:
            cached_data = pickle.load(f)

        for key in _EXPECTED_CACHE_KEYS:
            if key not in cached_data:
                valid_cache = False

        if not valid_cache:
            logging.info("Removing stale raw data cache file.")
            os.remove(cache_path)

    if valid_cache:
        data = cached_data
    else:
        # Process the raw ratings CSV and cache the structured result.
        with open(raw_rating_path) as f:
            df = pd.read_csv(f)

        # Keep only users who have rated at least MIN_NUM_RATINGS items.
        grouped = df.groupby(movielens.USER_COLUMN)
        df = grouped.filter(lambda x: len(x) >= rconst.MIN_NUM_RATINGS)

        original_users = df[movielens.USER_COLUMN].unique()
        original_items = df[movielens.ITEM_COLUMN].unique()

        # Map user and item ids to a 0-based index for the following processing.
        logging.info("Generating user_map and item_map...")
        user_map = {user: index for index, user in enumerate(original_users)}
        item_map = {item: index for index, item in enumerate(original_items)}

        df[movielens.USER_COLUMN] = df[movielens.USER_COLUMN].apply(
            lambda user: user_map[user])
        df[movielens.ITEM_COLUMN] = df[movielens.ITEM_COLUMN].apply(
            lambda item: item_map[item])

        num_users = len(original_users)
        num_items = len(original_items)

        assert num_users <= np.iinfo(rconst.USER_DTYPE).max
        assert num_items <= np.iinfo(rconst.ITEM_DTYPE).max
        assert df[movielens.USER_COLUMN].max() == num_users - 1
        assert df[movielens.ITEM_COLUMN].max() == num_items - 1

        # This sort is used to shard the dataframe by user, and later to select
        # the last item for a user to be used in validation.
        logging.info("Sorting by user, timestamp...")

        # This sort is equivalent to
        #   df.sort_values([movielens.USER_COLUMN, movielens.TIMESTAMP_COLUMN],
        #                  inplace=True)
        # except that the order of items with the same user and timestamp are
        # sometimes different. For some reason, this sort results in a better
        # hit-rate during evaluation, matching the performance of the MLPerf
        # reference implementation.
        df.sort_values(by=movielens.TIMESTAMP_COLUMN, inplace=True)
        df.sort_values([movielens.USER_COLUMN, movielens.TIMESTAMP_COLUMN],
                       inplace=True, kind="mergesort")

        # The dataframe does not reconstruct indices in the sort or filter steps.
        df = df.reset_index()

        grouped = df.groupby(movielens.USER_COLUMN, group_keys=False)
        eval_df, train_df = grouped.tail(1), grouped.apply(lambda x: x.iloc[:-1])

        data = {
            rconst.TRAIN_USER_KEY:
                train_df[movielens.USER_COLUMN].values.astype(rconst.USER_DTYPE),
            rconst.TRAIN_ITEM_KEY:
                train_df[movielens.ITEM_COLUMN].values.astype(rconst.ITEM_DTYPE),
            rconst.EVAL_USER_KEY:
                eval_df[movielens.USER_COLUMN].values.astype(rconst.USER_DTYPE),
            rconst.EVAL_ITEM_KEY:
                eval_df[movielens.ITEM_COLUMN].values.astype(rconst.ITEM_DTYPE),
            rconst.USER_MAP: user_map,
            rconst.ITEM_MAP: item_map,
            "create_time": time.time(),
        }

        logging.info("Writing raw data cache.")
        with open(cache_path, "wb") as f:
            pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)

    num_users, num_items = DATASET_TO_NUM_USERS_AND_ITEMS[dataset]
    if num_users != len(data[rconst.USER_MAP]):
        raise ValueError("Expected to find {} users, but found {}".format(
            num_users, len(data[rconst.USER_MAP])))
    if num_items != len(data[rconst.ITEM_MAP]):
        raise ValueError("Expected to find {} items, but found {}".format(
            num_items, len(data[rconst.ITEM_MAP])))

    return data, num_users, num_items
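
# Illustrative usage sketch (not called anywhere in this module; the data
# directory below is an assumption and must already contain the downloaded
# MovieLens files):
#
#   data, n_users, n_items = load_data("./dataset/", "ml-1m")
#   data[rconst.TRAIN_USER_KEY]   # int32 user ids, one per training rating
#   data[rconst.TRAIN_ITEM_KEY]   # int32 item ids, one per training rating
#   data[rconst.EVAL_USER_KEY]    # one entry per user
#   data[rconst.EVAL_ITEM_KEY]    # each user's last-rated (held-out) item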


def construct_lookup_variables(train_pos_users, train_pos_items, num_users):
    """Build per-user lookup tables used for negative-item sampling."""
    index_bounds = None
    sorted_train_pos_items = None

    def index_segment(user):
        lower, upper = index_bounds[user:user + 2]
        items = sorted_train_pos_items[lower:upper]

        negatives_since_last_positive = np.concatenate(
            [items[0][np.newaxis], items[1:] - items[:-1] - 1])

        return np.cumsum(negatives_since_last_positive)

    start_time = timeit.default_timer()
    inner_bounds = np.argwhere(train_pos_users[1:] -
                               train_pos_users[:-1])[:, 0] + 1
    (upper_bound,) = train_pos_users.shape
    index_bounds = np.array([0] + inner_bounds.tolist() + [upper_bound])

    # Later logic will assume that the users are in sequential ascending order.
    assert np.array_equal(train_pos_users[index_bounds[:-1]], np.arange(num_users))

    sorted_train_pos_items = train_pos_items.copy()

    for i in range(num_users):
        lower, upper = index_bounds[i:i + 2]
        sorted_train_pos_items[lower:upper].sort()

    total_negatives = np.concatenate([
        index_segment(i) for i in range(num_users)])

    logging.info("Negative total vector built. Time: {:.1f} seconds".format(
        timeit.default_timer() - start_time))

    return total_negatives, index_bounds, sorted_train_pos_items
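
# Worked toy example (illustrative numbers only): if one user's sorted positive
# items are [1, 3, 4], then for that user
#   negatives_since_last_positive = [1, 1, 0]   (item 0; item 2; none between 3 and 4)
#   total_negatives segment       = [1, 2, 2]   (cumulative sum)
# i.e. one negative id precedes positive 1, and two negative ids precede
# positives 3 and 4. NCFDataset.lookup_negative_items below searches these
# cumulative tallies to turn a sampled "k-th negative" index into an item id.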


class NCFDataset:
    """
    A dataset for NCF network.
    """
    def __init__(self,
                 pos_users,
                 pos_items,
                 num_users,
                 num_items,
                 batch_size,
                 total_negatives,
                 index_bounds,
                 sorted_train_pos_items,
                 num_neg,
                 is_training=True):
        self._pos_users = pos_users
        self._pos_items = pos_items
        self._num_users = num_users
        self._num_items = num_items

        self._batch_size = batch_size

        self._total_negatives = total_negatives
        self._index_bounds = index_bounds
        self._sorted_train_pos_items = sorted_train_pos_items
        self._is_training = is_training

        if self._is_training:
            self._train_pos_count = self._pos_users.shape[0]
        else:
            self._eval_users_per_batch = int(
                batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))

        _pos_count = pos_users.shape[0]
        _num_samples = (1 + num_neg) * _pos_count
        self.dataset_len = math.ceil(_num_samples / batch_size)

    def lookup_negative_items(self, negative_users):
        """Lookup negative items"""
        output = np.zeros(shape=negative_users.shape, dtype=rconst.ITEM_DTYPE) - 1

        left_index = self._index_bounds[negative_users]
        right_index = self._index_bounds[negative_users + 1] - 1

        num_positives = right_index - left_index + 1
        num_negatives = self._num_items - num_positives
        neg_item_choice = stat_utils.very_slightly_biased_randint(num_negatives)

        # Shortcuts:
        # For points where the negative is greater than or equal to the tally
        # before the last positive point there is no need to bisect. Instead the
        # item id corresponding to the negative item choice is simply:
        #   last_positive_index + 1 + (neg_choice - last_negative_tally)
        # Similarly, if the selection is less than the tally at the first
        # positive then the item_id is simply the selection.
        #
        # Because MovieLens organizes popular movies into low integers (which is
        # preserved through the preprocessing), the first shortcut is very
        # efficient, allowing ~60% of samples to bypass the bisection. For the
        # same reason, the second shortcut is rarely triggered (<0.02%) and is
        # therefore not worth implementing.
        use_shortcut = neg_item_choice >= self._total_negatives[right_index]
        output[use_shortcut] = (
            self._sorted_train_pos_items[right_index] + 1 +
            (neg_item_choice - self._total_negatives[right_index])
        )[use_shortcut]

        if np.all(use_shortcut):
            # The bisection code is ill-posed when there are no elements.
            return output

        not_use_shortcut = np.logical_not(use_shortcut)
        left_index = left_index[not_use_shortcut]
        right_index = right_index[not_use_shortcut]
        neg_item_choice = neg_item_choice[not_use_shortcut]

        num_loops = np.max(
            np.ceil(np.log2(num_positives[not_use_shortcut])).astype(np.int32))

        for _ in range(num_loops):
            mid_index = (left_index + right_index) // 2
            right_criteria = self._total_negatives[mid_index] > neg_item_choice
            left_criteria = np.logical_not(right_criteria)

            right_index[right_criteria] = mid_index[right_criteria]
            left_index[left_criteria] = mid_index[left_criteria]

        # Expected state after bisection pass:
        #   The right index is the smallest index whose tally is greater than
        #   the negative item choice index.
        assert np.all((right_index - left_index) <= 1)

        output[not_use_shortcut] = (
            self._sorted_train_pos_items[right_index] -
            (self._total_negatives[right_index] - neg_item_choice))

        assert np.all(output >= 0)

        return output
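
    # Continuing the toy example from construct_lookup_variables: with sorted
    # positives [1, 3, 4] the user's negatives are [0, 2, 5, 6, ...]. A sampled
    # neg_item_choice of 2 is >= the last tally (2), so the shortcut applies:
    #   item = last positive + 1 + (choice - last tally) = 4 + 1 + (2 - 2) = 5,
    # which is indeed the negative item at 0-indexed choice 2.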

    def _get_train_item(self, index):
        """Get train item"""
        (mask_start_index,) = index.shape
        index_mod = np.mod(index, self._train_pos_count)

        # get batch of users
        users = self._pos_users[index_mod]

        # get batch of items
        negative_indices = np.greater_equal(index, self._train_pos_count)
        negative_users = users[negative_indices]
        negative_items = self.lookup_negative_items(negative_users=negative_users)
        items = self._pos_items[index_mod]
        items[negative_indices] = negative_items

        # get batch of labels
        labels = np.logical_not(negative_indices)

        # pad last partial batch
        pad_length = self._batch_size - index.shape[0]
        if pad_length:
            user_pad = np.arange(pad_length, dtype=users.dtype) % self._num_users
            item_pad = np.arange(pad_length, dtype=items.dtype) % self._num_items
            label_pad = np.zeros(shape=(pad_length,), dtype=labels.dtype)
            users = np.concatenate([users, user_pad])
            items = np.concatenate([items, item_pad])
            labels = np.concatenate([labels, label_pad])

        users = np.reshape(users, (self._batch_size, 1))  # (_batch_size, 1), int32
        items = np.reshape(items, (self._batch_size, 1))  # (_batch_size, 1), int32
        mask_start_index = np.array(mask_start_index, dtype=np.int32)  # scalar, int32
        valid_pt_mask = np.expand_dims(
            np.less(np.arange(self._batch_size), mask_start_index), -1).astype(np.float32)  # (_batch_size, 1), float32
        labels = np.reshape(labels, (self._batch_size, 1)).astype(np.int32)  # (_batch_size, 1), int32
        return users, items, labels, valid_pt_mask
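
    # Index convention sketch (toy numbers only): with _train_pos_count = 5 and
    # num_neg = 4, the sampler draws indices from [0, 25). An index of 12 maps
    # to positive row 12 % 5 = 2, and because 12 >= 5 it is treated as a
    # negative sample: its label is 0 and its item id is replaced by a freshly
    # sampled negative for that user.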

    @staticmethod
    def _assemble_eval_batch(users, positive_items, negative_items,
                             users_per_batch):
        """Construct duplicate_mask and structure data accordingly.

        The positive items should be last so that they lose ties. However, they
        should not be masked out if the true eval positive happens to be
        selected as a negative. So instead, the positive is placed in the first
        position, and then switched with the last element after the duplicate
        mask has been computed.

        Args:
            users: An array of users in a batch. (should be identical along axis 1)
            positive_items: An array (batch_size x 1) of positive item indices.
            negative_items: An array of negative item indices.
            users_per_batch: How many users should be in the batch. This is passed
                as an argument so that ncf_test.py can use this method.

        Returns:
            User, item, and duplicate_mask arrays.
        """
        items = np.concatenate([positive_items, negative_items], axis=1)

        # We pad the users and items here so that the duplicate mask calculation
        # will include padding. The metric function relies on all padded elements
        # except the positive being marked as duplicate to mask out padded points.
        if users.shape[0] < users_per_batch:
            pad_rows = users_per_batch - users.shape[0]
            padding = np.zeros(shape=(pad_rows, users.shape[1]), dtype=np.int32)
            users = np.concatenate([users, padding.astype(users.dtype)], axis=0)
            items = np.concatenate([items, padding.astype(items.dtype)], axis=0)

        duplicate_mask = stat_utils.mask_duplicates(items, axis=1).astype(np.float32)

        items[:, (0, -1)] = items[:, (-1, 0)]
        duplicate_mask[:, (0, -1)] = duplicate_mask[:, (-1, 0)]

        assert users.shape == items.shape == duplicate_mask.shape
        return users, items, duplicate_mask
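
    # Column layout sketch: each row enters as [positive, neg_1, ..., neg_K]
    # (K = rconst.NUM_EVAL_NEGATIVES). The duplicate mask is computed with the
    # positive in column 0, then columns 0 and -1 are swapped so the row leaves
    # as [neg_K, neg_1, ..., neg_K-1, positive], putting the true item last so
    # that it loses ties in the metric.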

    def _get_eval_item(self, index):
        """Get eval item"""
        low_index, high_index = index
        users = np.repeat(self._pos_users[low_index:high_index, np.newaxis],
                          1 + rconst.NUM_EVAL_NEGATIVES, axis=1)
        positive_items = self._pos_items[low_index:high_index, np.newaxis]
        negative_items = (self.lookup_negative_items(negative_users=users[:, :-1])
                          .reshape(-1, rconst.NUM_EVAL_NEGATIVES))

        users, items, duplicate_mask = self._assemble_eval_batch(
            users, positive_items, negative_items, self._eval_users_per_batch)

        users = np.reshape(users.flatten(), (self._batch_size, 1))  # (self._batch_size, 1), int32
        items = np.reshape(items.flatten(), (self._batch_size, 1))  # (self._batch_size, 1), int32
        duplicate_mask = np.reshape(duplicate_mask.flatten(), (self._batch_size, 1))  # (self._batch_size, 1), float32

        return users, items, duplicate_mask

    def __getitem__(self, index):
        """
        Get a batch of samples.
        """
        if self._is_training:
            return self._get_train_item(index)

        return self._get_eval_item(index)

    def __len__(self):
        """
        Return length of the dataset, i.e., the number of batches in an epoch.
        """
        return self.dataset_len


class RandomSampler(Sampler):
    """
    A random sampler for dataset.
    """
    def __init__(self, pos_count, num_train_negatives, batch_size):
        self.pos_count = pos_count
        self._num_samples = (1 + num_train_negatives) * self.pos_count
        self._batch_size = batch_size
        self._num_batches = math.ceil(self._num_samples / self._batch_size)
        super().__init__(self._num_batches)

    def __iter__(self):
        """
        Return indices of all batches within an epoch.
        """
        indices = stat_utils.permutation((self._num_samples, stat_utils.random_int32()))

        batch_indices = [indices[x * self._batch_size:(x + 1) * self._batch_size] for x in range(self._num_batches)]

        # padding last batch indices if necessary
        if len(batch_indices) > 2 and len(batch_indices[-2]) != len(batch_indices[-1]):
            pad_nums = len(batch_indices[-2]) - len(batch_indices[-1])
            pad_indices = np.random.randint(0, self._num_samples, pad_nums)
            batch_indices[-1] = np.hstack((batch_indices[-1], pad_indices))

        return iter(batch_indices)


class DistributedSamplerOfTrain:
    """
    A distributed sampler for the training dataset.
    """
    def __init__(self, pos_count, num_train_negatives, batch_size, rank_id, rank_size):
        """
        Distributed sampler of training dataset.
        """
        self._num_samples = (1 + num_train_negatives) * pos_count
        self._rank_id = rank_id
        self._rank_size = rank_size
        self._batch_size = batch_size

        self._batchs_per_rank = int(math.ceil(self._num_samples / self._batch_size / rank_size))
        self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._batch_size))
        self._total_num_samples = self._samples_per_rank * self._rank_size

    def __iter__(self):
        """
        Return the batch indices assigned to this rank for one epoch.
        """
        indices = stat_utils.permutation((self._num_samples, stat_utils.random_int32()))
        indices = indices.tolist()
        indices.extend(indices[:self._total_num_samples - len(indices)])
        indices = indices[self._rank_id:self._total_num_samples:self._rank_size]
        batch_indices = [indices[x * self._batch_size:(x + 1) * self._batch_size] for x in range(self._batchs_per_rank)]

        return iter(np.array(batch_indices))

    def __len__(self):
        """
        Return the number of batches per rank.
        """
        return self._batchs_per_rank


class SequenceSampler(Sampler):
    """
    A sequence sampler for dataset.
    """
    def __init__(self, eval_batch_size, num_users):
        self._eval_users_per_batch = int(
            eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
        self._eval_elements_in_epoch = num_users * (1 + rconst.NUM_EVAL_NEGATIVES)
        self._eval_batches_per_epoch = self.count_batches(
            self._eval_elements_in_epoch, eval_batch_size)
        super().__init__(self._eval_batches_per_epoch)

    def __iter__(self):
        indices = [(x * self._eval_users_per_batch, (x + 1) * self._eval_users_per_batch)
                   for x in range(self._eval_batches_per_epoch)]

        # padding last batch indices if necessary
        if len(indices) > 2 and len(indices[-2]) != len(indices[-1]):
            pad_nums = len(indices[-2]) - len(indices[-1])
            pad_indices = np.random.randint(0, self._eval_elements_in_epoch, pad_nums)
            indices[-1] = np.hstack((indices[-1], pad_indices))

        return iter(indices)

    @staticmethod
    def count_batches(example_count, batch_size, batches_per_step=1):
        """Determine the number of batches, rounding up to fill all devices."""
        x = (example_count + batch_size - 1) // batch_size
        return (x + batches_per_step - 1) // batches_per_step * batches_per_step
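
# Rounding example (assuming the MLPerf value of 999 eval negatives for
# rconst.NUM_EVAL_NEGATIVES): ml-1m has 6040 users, each contributing
# 1 + 999 = 1000 eval elements, so with the default eval_batch_size of 160000
# count_batches returns ceil(6040 * 1000 / 160000) = 38 eval batches per epoch.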


class DistributedSamplerOfEval:
    """
    A distributed sampler for eval dataset.
    """
    def __init__(self, eval_batch_size, num_users, rank_id, rank_size):
        self._eval_users_per_batch = int(
            eval_batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
        self._eval_elements_in_epoch = num_users * (1 + rconst.NUM_EVAL_NEGATIVES)
        self._eval_batches_per_epoch = self.count_batches(
            self._eval_elements_in_epoch, eval_batch_size)

        self._rank_id = rank_id
        self._rank_size = rank_size
        self._eval_batch_size = eval_batch_size

        self._batchs_per_rank = int(math.ceil(self._eval_batches_per_epoch / rank_size))
        # self._samples_per_rank = int(math.ceil(self._batchs_per_rank * self._eval_batch_size))
        # self._total_num_samples = self._samples_per_rank * self._rank_size

    def __iter__(self):
        indices = [(x * self._eval_users_per_batch, (x + self._rank_id + 1) * self._eval_users_per_batch)
                   for x in range(self._batchs_per_rank)]

        return iter(np.array(indices))

    @staticmethod
    def count_batches(example_count, batch_size, batches_per_step=1):
        """Determine the number of batches, rounding up to fill all devices."""
        x = (example_count + batch_size - 1) // batch_size
        return (x + batches_per_step - 1) // batches_per_step * batches_per_step

    def __len__(self):
        return self._batchs_per_rank


def parse_eval_batch_size(eval_batch_size):
    """
    Parse eval batch size.
    """
    if eval_batch_size % (1 + rconst.NUM_EVAL_NEGATIVES):
        raise ValueError("Eval batch size {} is not divisible by {}".format(
            eval_batch_size, 1 + rconst.NUM_EVAL_NEGATIVES))
    return eval_batch_size
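
# Under the same 999-negatives assumption as above, eval_batch_size must be a
# multiple of 1000; the default of 160000 passes this check and corresponds to
# 160 users per eval batch.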


def create_dataset(test_train=True, data_dir='./dataset/', dataset='ml-1m', train_epochs=14, batch_size=256,
                   eval_batch_size=160000, num_neg=4, rank_id=None, rank_size=None):
    """
    Create NCF dataset.
    """
    data, num_users, num_items = load_data(data_dir, dataset)
    train_pos_users = data[rconst.TRAIN_USER_KEY]
    train_pos_items = data[rconst.TRAIN_ITEM_KEY]
    eval_pos_users = data[rconst.EVAL_USER_KEY]
    eval_pos_items = data[rconst.EVAL_ITEM_KEY]

    total_negatives, index_bounds, sorted_train_pos_items = \
        construct_lookup_variables(train_pos_users, train_pos_items, num_users)

    if test_train:
        print(train_pos_users, train_pos_items, num_users, num_items, batch_size, total_negatives, index_bounds,
              sorted_train_pos_items)
        train_dataset = NCFDataset(train_pos_users, train_pos_items, num_users, num_items, batch_size,
                                   total_negatives, index_bounds, sorted_train_pos_items, num_neg)
        sampler = RandomSampler(train_pos_users.shape[0], num_neg, batch_size)
        if rank_id is not None and rank_size is not None:
            sampler = DistributedSamplerOfTrain(train_pos_users.shape[0], num_neg, batch_size, rank_id, rank_size)

        # The larger ml-20m dataset benefits from more parallel workers.
        if dataset == 'ml-20m':
            ds = GeneratorDataset(train_dataset,
                                  column_names=[movielens.USER_COLUMN,
                                                movielens.ITEM_COLUMN,
                                                "labels",
                                                rconst.VALID_POINT_MASK],
                                  sampler=sampler, num_parallel_workers=32, python_multiprocessing=False)
        else:
            ds = GeneratorDataset(train_dataset,
                                  column_names=[movielens.USER_COLUMN,
                                                movielens.ITEM_COLUMN,
                                                "labels",
                                                rconst.VALID_POINT_MASK],
                                  sampler=sampler)

    else:
        eval_batch_size = parse_eval_batch_size(eval_batch_size=eval_batch_size)
        eval_dataset = NCFDataset(eval_pos_users, eval_pos_items, num_users, num_items,
                                  eval_batch_size, total_negatives, index_bounds,
                                  sorted_train_pos_items, num_neg, is_training=False)
        sampler = SequenceSampler(eval_batch_size, num_users)

        ds = GeneratorDataset(eval_dataset,
                              column_names=[movielens.USER_COLUMN,
                                            movielens.ITEM_COLUMN,
                                            rconst.DUPLICATE_MASK],
                              sampler=sampler)

    repeat_count = train_epochs if test_train else train_epochs + 1
    ds = ds.repeat(repeat_count)

    return ds, num_users, num_items
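

if __name__ == "__main__":
    # Minimal smoke-test sketch, not used by the training scripts. It assumes
    # the MovieLens ml-1m files have already been downloaded to ./dataset/ml-1m;
    # adjust data_dir/dataset for your environment.
    logging.basicConfig(level=logging.INFO)
    train_ds, n_users, n_items = create_dataset(test_train=True, data_dir='./dataset/',
                                                dataset='ml-1m', train_epochs=1,
                                                batch_size=256, num_neg=4)
    print("users: {}, items: {}, batches: {}".format(n_users, n_items, train_ds.get_dataset_size()))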