@@ -27,7 +27,7 @@
 namespace mindspore {
 namespace dataset {
 CsvOp::Builder::Builder()
-    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
+    : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(-1), builder_shuffle_files_(false) {
   std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
   builder_num_workers_ = config_manager->num_parallel_workers();
   builder_op_connector_size_ = config_manager->op_connector_size();
@@ -451,7 +451,7 @@ Status CsvOp::operator()() {
       RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
       if (buffer->eoe()) {
         workers_done++;
-      } else if (num_samples_ == 0 || rows_read < num_samples_) {
+      } else if (num_samples_ == -1 || rows_read < num_samples_) {
        if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
          int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
          RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));
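The hunk above caps the rows emitted once `num_samples_` is about to be exceeded. A minimal sketch of that bookkeeping (plain Python with made-up values, not MindSpore code; `SliceOff` drops rows from the tail of the buffer):

# Illustrative only -- mirrors the arithmetic in CsvOp::operator()() above.
num_samples = 100   # requested cap (assumed)
rows_read = 95      # rows already emitted (assumed)
buffer_rows = 10    # rows in the incoming buffer (assumed)

if num_samples > 0 and rows_read + buffer_rows > num_samples:
    # Same formula as rowsToRemove above: keep num_samples - rows_read rows.
    rows_to_remove = buffer_rows - (num_samples - rows_read)  # 10 - (100 - 95) = 5
    buffer_rows -= rows_to_remove

rows_read += buffer_rows
assert rows_read == num_samples  # exactly 100 rows emitted, never more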
@@ -4935,7 +4935,7 @@ class CSVDataset(SourceDataset):
             columns as string type.
         column_names (list[str], optional): List of column names of the dataset (default=None). If this
             is not provided, infers the column_names from the first row of CSV file.
-        num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
+        num_samples (int, optional): number of samples(rows) to read (default=-1, reads the full dataset).
         num_parallel_workers (int, optional): number of workers to read the data
             (default=None, number set in the config).
         shuffle (Union[bool, Shuffle level], optional): perform reshuffling of the data every epoch
@@ -4959,7 +4959,7 @@ class CSVDataset(SourceDataset):
     """

     @check_csvdataset
-    def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None,
+    def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=-1,
                  num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
         super().__init__(num_parallel_workers)
         self.dataset_files = self._find_files(dataset_files)
@@ -5010,7 +5010,7 @@ class CSVDataset(SourceDataset):
         if self._dataset_size is None:
             num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None)
             num_rows = get_num_rows(num_rows, self.num_shards)
-            if self.num_samples is None:
+            if self.num_samples == -1:
                 return num_rows
             return min(self.num_samples, num_rows)
         return self._dataset_size
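With `None` replaced by the `-1` sentinel, `get_dataset_size` keeps the same two behaviors: `-1` returns the full (possibly sharded) row count, and any explicit value is capped by it. A small sketch with hypothetical counts (not MindSpore code):

# Illustrative sketch of the sentinel semantics above.
def dataset_size(num_samples, num_rows):
    if num_samples == -1:   # -1 now means "read the full dataset"
        return num_rows
    return min(num_samples, num_rows)

assert dataset_size(-1, 500) == 500    # default: whole dataset
assert dataset_size(100, 500) == 100   # explicit cap
assert dataset_size(1000, 500) == 500  # cap above dataset size: bounded by num_rows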
@@ -813,12 +813,16 @@ def check_csvdataset(method):
     def new_method(self, *args, **kwargs):
         _, param_dict = parse_user_args(method, *args, **kwargs)

-        nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']
+        nreq_param_int = ['num_parallel_workers', 'num_shards', 'shard_id']

         # check dataset_files; required argument
         dataset_files = param_dict.get('dataset_files')
         type_check(dataset_files, (str, list), "dataset files")

+        # check num_samples
+        num_samples = param_dict.get('num_samples')
+        check_value(num_samples, [-1, INT32_MAX], "num_samples")
+
         # check field_delim
         field_delim = param_dict.get('field_delim')
         type_check(field_delim, (str,), 'field delim')
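From the user's side, the patch replaces the old `None`/`0` defaults with a single `-1` sentinel, range-checked by the new `check_value(num_samples, [-1, INT32_MAX], "num_samples")` guard. A hedged usage sketch (the file path and column names are made up):

import mindspore.dataset as ds

# num_samples now defaults to -1, i.e. read every row.
data = ds.CSVDataset("/path/to/file.csv", column_names=["col1", "col2"])

# Read at most 100 rows.
data = ds.CSVDataset("/path/to/file.csv", column_names=["col1", "col2"],
                     num_samples=100)

# Values below -1 fail the new range check:
#   ds.CSVDataset("/path/to/file.csv", column_names=["col1"], num_samples=-2)  # ValueError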