From: @mahdirahmanihanzaki Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -113,13 +113,11 @@ int64_t SequentialSamplerRT::CalculateNumSamples(int64_t num_rows) { | |||||
| int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; | int64_t num_samples = (num_samples_ > 0) ? std::min(child_num_rows, num_samples_) : child_num_rows; | ||||
| // For this sampler we need to take start_index into account. Because for example in the case we are given n rows | // For this sampler we need to take start_index into account. Because for example in the case we are given n rows | ||||
| // and start_index != 0 and num_samples >= n then we can't return all the n rows. | // and start_index != 0 and num_samples >= n then we can't return all the n rows. | ||||
| if (child_num_rows - (start_index_ - current_id_) <= 0) { | |||||
| if (child_num_rows - start_index_ <= 0) { | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| if (child_num_rows - (start_index_ - current_id_) < num_samples) | |||||
| num_samples = child_num_rows - (start_index_ - current_id_) > num_samples | |||||
| ? num_samples | |||||
| : num_samples - (start_index_ - current_id_); | |||||
| if (child_num_rows - start_index_ < num_samples) | |||||
| num_samples = child_num_rows - start_index_ > num_samples ? num_samples : num_samples - start_index_; | |||||
| return num_samples; | return num_samples; | ||||
| } | } | ||||
| @@ -70,16 +70,16 @@ def test_numpyslices_sampler_chain(): | |||||
| # Use 1 statement to add child sampler | # Use 1 statement to add child sampler | ||||
| np_data = [1, 2, 3, 4] | np_data = [1, 2, 3, 4] | ||||
| sampler = ds.SequentialSampler(start_index=1, num_samples=2) | sampler = ds.SequentialSampler(start_index=1, num_samples=2) | ||||
| sampler = sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2)) | |||||
| sampler.add_child(ds.SequentialSampler(start_index=1, num_samples=2)) | |||||
| data1 = ds.NumpySlicesDataset(np_data, sampler=sampler) | data1 = ds.NumpySlicesDataset(np_data, sampler=sampler) | ||||
| # Verify dataset size | # Verify dataset size | ||||
| data1_size = data1.get_dataset_size() | data1_size = data1.get_dataset_size() | ||||
| logger.info("dataset size is: {}".format(data1_size)) | logger.info("dataset size is: {}".format(data1_size)) | ||||
| assert data1_size == 4 | |||||
| assert data1_size == 1 | |||||
| # Verify number of rows | # Verify number of rows | ||||
| assert sum([1 for _ in data1]) == 4 | |||||
| assert sum([1 for _ in data1]) == 1 | |||||
| # Verify dataset contents | # Verify dataset contents | ||||
| res = [] | res = [] | ||||