From 77578fdf7d5680d48d5ef399dad5366d02bd6cd2 Mon Sep 17 00:00:00 2001
From: zhouneng
Date: Thu, 11 Mar 2021 15:28:43 +0800
Subject: [PATCH] pad the sampler's last batch of indices to the same length
 as the previous batch

---
 .../official/recommend/ncf/src/dataset.py    | 49 ++++++++++++-------
 1 file changed, 32 insertions(+), 17 deletions(-)

diff --git a/model_zoo/official/recommend/ncf/src/dataset.py b/model_zoo/official/recommend/ncf/src/dataset.py
index 589ff742bf..010eb197aa 100644
--- a/model_zoo/official/recommend/ncf/src/dataset.py
+++ b/model_zoo/official/recommend/ncf/src/dataset.py
@@ -22,7 +22,7 @@ import pickle
 import numpy as np
 import pandas as pd
 
-from mindspore.dataset import GeneratorDataset
+from mindspore.dataset import GeneratorDataset, Sampler
 
 import src.constants as rconst
 import src.movielens as movielens
@@ -214,6 +214,7 @@ class NCFDataset:
                  total_negatives,
                  index_bounds,
                  sorted_train_pos_items,
+                 num_neg,
                  is_training=True):
         self._pos_users = pos_users
         self._pos_items = pos_items
@@ -234,6 +235,10 @@ class NCFDataset:
         self._eval_users_per_batch = int(
             batch_size // (1 + rconst.NUM_EVAL_NEGATIVES))
 
+        _pos_count = pos_users.shape[0]
+        _num_samples = (1 + num_neg) * _pos_count
+        self.dataset_len = math.ceil(_num_samples / batch_size)
+
     def lookup_negative_items(self, negative_users):
         """Lookup negative items"""
         output = np.zeros(shape=negative_users.shape, dtype=rconst.ITEM_DTYPE) - 1
@@ -402,8 +407,14 @@ class NCFDataset:
 
         return self._get_eval_item(index)
 
+    def __len__(self):
+        """
+        Return the length of the dataset, i.e., the number of batches in an epoch.
+        """
+        return self.dataset_len
+
 
-class RandomSampler:
+class RandomSampler(Sampler):
     """
     A random sampler for dataset.
     """
@@ -413,6 +424,7 @@ class RandomSampler:
         self._num_samples = (1 + num_train_negatives) * self.pos_count
         self._batch_size = batch_size
         self._num_batches = math.ceil(self._num_samples / self._batch_size)
+        super().__init__(self._num_batches)
 
     def __iter__(self):
         """
@@ -421,13 +433,14 @@ class RandomSampler:
         indices = stat_utils.permutation((self._num_samples, stat_utils.random_int32()))
         batch_indices = [indices[x * self._batch_size:(x + 1) * self._batch_size]
                          for x in range(self._num_batches)]
-        return iter(batch_indices)
 
-    def __len__(self):
-        """
-        Return length of the sampler, i.e., the number of batches for an epoch.
-        """
-        return self._num_batches
+        # pad the last batch of indices if necessary
+        if len(batch_indices) > 2 and len(batch_indices[-2]) != len(batch_indices[-1]):
+            pad_nums = len(batch_indices[-2]) - len(batch_indices[-1])
+            pad_indices = np.random.randint(0, self._num_samples, pad_nums)
+            batch_indices[-1] = np.hstack((batch_indices[-1], pad_indices))
+
+        return iter(batch_indices)
 
 
 class DistributedSamplerOfTrain:
@@ -467,7 +480,7 @@ class DistributedSamplerOfTrain:
 
         return self._batchs_per_rank
 
-class SequenceSampler:
+class SequenceSampler(Sampler):
     """
     A sequence sampler for dataset.
     """
@@ -478,10 +491,18 @@ class SequenceSampler:
         self._eval_elements_in_epoch = num_users * (1 + rconst.NUM_EVAL_NEGATIVES)
         self._eval_batches_per_epoch = self.count_batches(
             self._eval_elements_in_epoch, eval_batch_size)
+        super().__init__(self._eval_batches_per_epoch)
 
     def __iter__(self):
         indices = [(x * self._eval_users_per_batch, (x + 1) * self._eval_users_per_batch)
                    for x in range(self._eval_batches_per_epoch)]
+
+        # pad the last batch of indices if necessary
+        if len(indices) > 2 and len(indices[-2]) != len(indices[-1]):
+            pad_nums = len(indices[-2]) - len(indices[-1])
+            pad_indices = np.random.randint(0, self._eval_elements_in_epoch, pad_nums)
+            indices[-1] = np.hstack((indices[-1], pad_indices))
+
         return iter(indices)
 
     @staticmethod
@@ -490,12 +511,6 @@ class SequenceSampler:
         x = (example_count + batch_size - 1) // batch_size
         return (x + batches_per_step - 1) // batches_per_step * batches_per_step
 
-    def __len__(self):
-        """
-        Return the length of the sampler, i,e, the number of batches in an epoch.
-        """
-        return self._eval_batches_per_epoch
-
 
 class DistributedSamplerOfEval:
     """
@@ -562,7 +577,7 @@ def create_dataset(test_train=True, data_dir='./dataset/', dataset='ml-1m', trai
     print(train_pos_users, train_pos_items, num_users, num_items, batch_size, total_negatives, index_bounds,
           sorted_train_pos_items)
     dataset = NCFDataset(train_pos_users, train_pos_items, num_users, num_items, batch_size, total_negatives,
-                         index_bounds, sorted_train_pos_items)
+                         index_bounds, sorted_train_pos_items, num_neg)
     sampler = RandomSampler(train_pos_users.shape[0], num_neg, batch_size)
     if rank_id is not None and rank_size is not None:
         sampler = DistributedSamplerOfTrain(train_pos_users.shape[0], num_neg, batch_size, rank_id, rank_size)
@@ -585,7 +600,7 @@ def create_dataset(test_train=True, data_dir='./dataset/', dataset='ml-1m', trai
     eval_batch_size = parse_eval_batch_size(eval_batch_size=eval_batch_size)
     dataset = NCFDataset(eval_pos_users, eval_pos_items, num_users, num_items,
                          eval_batch_size, total_negatives, index_bounds,
-                         sorted_train_pos_items, is_training=False)
+                         sorted_train_pos_items, num_neg, is_training=False)
     sampler = SequenceSampler(eval_batch_size, num_users)
     ds = GeneratorDataset(dataset,