Browse Source

!5713 Align num_samples of CSV with other dataset

Merge pull request !5713 from jiangzhiwen/fix/csv_num_samples
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
01da7d2c1a
6 changed files with 20 additions and 24 deletions
  1. +1
    -1
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +2
    -2
      mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc
  3. +2
    -2
      mindspore/ccsrc/minddata/dataset/include/datasets.h
  4. +3
    -3
      mindspore/dataset/engine/datasets.py
  5. +1
    -5
      mindspore/dataset/engine/validators.py
  6. +11
    -11
      tests/ut/cpp/dataset/c_api_dataset_csv_test.cc

+ 1
- 1
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -1200,7 +1200,7 @@ bool CSVDataset::ValidateParams() {
return false; return false;
} }


if (num_samples_ < -1) {
if (num_samples_ < 0) {
MS_LOG(ERROR) << "CSVDataset: Invalid number of samples: " << num_samples_; MS_LOG(ERROR) << "CSVDataset: Invalid number of samples: " << num_samples_;
return false; return false;
} }


+ 2
- 2
mindspore/ccsrc/minddata/dataset/engine/datasetops/source/csv_op.cc View File

@@ -27,7 +27,7 @@
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
CsvOp::Builder::Builder() CsvOp::Builder::Builder()
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(-1), builder_shuffle_files_(false) {
: builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) {
std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
builder_num_workers_ = config_manager->num_parallel_workers(); builder_num_workers_ = config_manager->num_parallel_workers();
builder_op_connector_size_ = config_manager->op_connector_size(); builder_op_connector_size_ = config_manager->op_connector_size();
@@ -539,7 +539,7 @@ Status CsvOp::operator()() {
RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer));
if (buffer->eoe()) { if (buffer->eoe()) {
workers_done++; workers_done++;
} else if (num_samples_ == -1 || rows_read < num_samples_) {
} else if (num_samples_ == 0 || rows_read < num_samples_) {
if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) {
int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read);
RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove));


+ 2
- 2
mindspore/ccsrc/minddata/dataset/include/datasets.h View File

@@ -191,7 +191,7 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str
/// \param[in] column_names List of column names of the dataset (default={}). If this is not provided, infers the /// \param[in] column_names List of column names of the dataset (default={}). If this is not provided, infers the
/// column_names from the first row of CSV file. /// column_names from the first row of CSV file.
/// \param[in] num_samples The number of samples to be included in the dataset. /// \param[in] num_samples The number of samples to be included in the dataset.
/// (Default = -1 means all samples.)
/// (Default = 0 means all samples.)
/// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal) /// \param[in] shuffle The mode for shuffling data every epoch. (Default=ShuffleMode::kGlobal)
/// Can be any of: /// Can be any of:
/// ShuffleMode::kFalse - No shuffling is performed. /// ShuffleMode::kFalse - No shuffling is performed.
@@ -203,7 +203,7 @@ std::shared_ptr<CocoDataset> Coco(const std::string &dataset_dir, const std::str
/// \return Shared pointer to the current Dataset /// \return Shared pointer to the current Dataset
std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_files, char field_delim = ',', std::shared_ptr<CSVDataset> CSV(const std::vector<std::string> &dataset_files, char field_delim = ',',
const std::vector<std::shared_ptr<CsvBase>> &column_defaults = {}, const std::vector<std::shared_ptr<CsvBase>> &column_defaults = {},
const std::vector<std::string> &column_names = {}, int64_t num_samples = -1,
const std::vector<std::string> &column_names = {}, int64_t num_samples = 0,
ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1, ShuffleMode shuffle = ShuffleMode::kGlobal, int32_t num_shards = 1,
int32_t shard_id = 0); int32_t shard_id = 0);




+ 3
- 3
mindspore/dataset/engine/datasets.py View File

@@ -5140,7 +5140,7 @@ class CSVDataset(SourceDataset):
columns as string type. columns as string type.
column_names (list[str], optional): List of column names of the dataset (default=None). If this column_names (list[str], optional): List of column names of the dataset (default=None). If this
is not provided, infers the column_names from the first row of CSV file. is not provided, infers the column_names from the first row of CSV file.
num_samples (int, optional): number of samples(rows) to read (default=-1, reads the full dataset).
num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
num_parallel_workers (int, optional): number of workers to read the data num_parallel_workers (int, optional): number of workers to read the data
(default=None, number set in the config). (default=None, number set in the config).
shuffle (Union[bool, Shuffle level], optional): perform reshuffling of the data every epoch shuffle (Union[bool, Shuffle level], optional): perform reshuffling of the data every epoch
@@ -5164,7 +5164,7 @@ class CSVDataset(SourceDataset):
""" """


@check_csvdataset @check_csvdataset
def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=-1,
def __init__(self, dataset_files, field_delim=',', column_defaults=None, column_names=None, num_samples=None,
num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None):
super().__init__(num_parallel_workers) super().__init__(num_parallel_workers)
self.dataset_files = self._find_files(dataset_files) self.dataset_files = self._find_files(dataset_files)
@@ -5215,7 +5215,7 @@ class CSVDataset(SourceDataset):
if self.dataset_size is None: if self.dataset_size is None:
num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None) num_rows = CsvOp.get_num_rows(self.dataset_files, self.column_names is None)
self.dataset_size = get_num_rows(num_rows, self.num_shards) self.dataset_size = get_num_rows(num_rows, self.num_shards)
if self.num_samples != -1 and self.num_samples < self.dataset_size:
if self.num_samples is not None and self.num_samples < self.dataset_size:
self.dataset_size = num_rows self.dataset_size = num_rows
return self.dataset_size return self.dataset_size




+ 1
- 5
mindspore/dataset/engine/validators.py View File

@@ -830,16 +830,12 @@ def check_csvdataset(method):
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
_, param_dict = parse_user_args(method, *args, **kwargs) _, param_dict = parse_user_args(method, *args, **kwargs)


nreq_param_int = ['num_parallel_workers', 'num_shards', 'shard_id']
nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id']


# check dataset_files; required argument # check dataset_files; required argument
dataset_files = param_dict.get('dataset_files') dataset_files = param_dict.get('dataset_files')
type_check(dataset_files, (str, list), "dataset files") type_check(dataset_files, (str, list), "dataset files")


# check num_samples
num_samples = param_dict.get('num_samples')
check_value(num_samples, [-1, INT32_MAX], "num_samples")

# check field_delim # check field_delim
field_delim = param_dict.get('field_delim') field_delim = param_dict.get('field_delim')
type_check(field_delim, (str,), 'field delim') type_check(field_delim, (str,), 'field delim')


+ 11
- 11
tests/ut/cpp/dataset/c_api_dataset_csv_test.cc View File

@@ -33,7 +33,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetBasic) {
// Create a CSVDataset, with single CSV file // Create a CSVDataset, with single CSV file
std::string train_file = datasets_root_path_ + "/testCSV/1.csv"; std::string train_file = datasets_root_path_ + "/testCSV/1.csv";
std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"}; std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
std::shared_ptr<Dataset> ds = CSV({train_file}, ',', {}, column_names, -1, ShuffleMode::kFalse);
std::shared_ptr<Dataset> ds = CSV({train_file}, ',', {}, column_names, 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);


// Create an iterator over the result of the above dataset // Create an iterator over the result of the above dataset
@@ -85,7 +85,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetMultiFiles) {
std::string file1 = datasets_root_path_ + "/testCSV/1.csv"; std::string file1 = datasets_root_path_ + "/testCSV/1.csv";
std::string file2 = datasets_root_path_ + "/testCSV/append.csv"; std::string file2 = datasets_root_path_ + "/testCSV/append.csv";
std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"}; std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
std::shared_ptr<Dataset> ds = CSV({file1, file2}, ',', {}, column_names, -1, ShuffleMode::kGlobal);
std::shared_ptr<Dataset> ds = CSV({file1, file2}, ',', {}, column_names, 0, ShuffleMode::kGlobal);
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);


// Create an iterator over the result of the above dataset // Create an iterator over the result of the above dataset
@@ -179,7 +179,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetDistribution) {
// Create a CSVDataset, with single CSV file // Create a CSVDataset, with single CSV file
std::string file = datasets_root_path_ + "/testCSV/1.csv"; std::string file = datasets_root_path_ + "/testCSV/1.csv";
std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"}; std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
std::shared_ptr<Dataset> ds = CSV({file}, ',', {}, column_names, -1, ShuffleMode::kFalse, 2, 0);
std::shared_ptr<Dataset> ds = CSV({file}, ',', {}, column_names, 0, ShuffleMode::kFalse, 2, 0);
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);


// Create an iterator over the result of the above dataset // Create an iterator over the result of the above dataset
@@ -228,7 +228,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetType) {
std::make_shared<CsvRecord<std::string>>(CsvType::STRING, ""), std::make_shared<CsvRecord<std::string>>(CsvType::STRING, ""),
}; };
std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"}; std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
std::shared_ptr<Dataset> ds = CSV({file}, ',', colum_type, column_names, -1, ShuffleMode::kFalse);
std::shared_ptr<Dataset> ds = CSV({file}, ',', colum_type, column_names, 0, ShuffleMode::kFalse);
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);


// Create an iterator over the result of the above dataset // Create an iterator over the result of the above dataset
@@ -343,15 +343,15 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetException) {
EXPECT_EQ(ds1, nullptr); EXPECT_EQ(ds1, nullptr);


// Test invalid num_samples < 0 // Test invalid num_samples < 0
std::shared_ptr<Dataset> ds2 = CSV({file}, ',', {}, column_names, -2);
std::shared_ptr<Dataset> ds2 = CSV({file}, ',', {}, column_names, -1);
EXPECT_EQ(ds2, nullptr); EXPECT_EQ(ds2, nullptr);


// Test invalid num_shards < 1 // Test invalid num_shards < 1
std::shared_ptr<Dataset> ds3 = CSV({file}, ',', {}, column_names, -1, ShuffleMode::kFalse, 0);
std::shared_ptr<Dataset> ds3 = CSV({file}, ',', {}, column_names, 0, ShuffleMode::kFalse, 0);
EXPECT_EQ(ds3, nullptr); EXPECT_EQ(ds3, nullptr);


// Test invalid shard_id >= num_shards // Test invalid shard_id >= num_shards
std::shared_ptr<Dataset> ds4 = CSV({file}, ',', {}, column_names, -1, ShuffleMode::kFalse, 2, 2);
std::shared_ptr<Dataset> ds4 = CSV({file}, ',', {}, column_names, 0, ShuffleMode::kFalse, 2, 2);
EXPECT_EQ(ds4, nullptr); EXPECT_EQ(ds4, nullptr);


// Test invalid field_delim // Test invalid field_delim
@@ -373,7 +373,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFilesA) {
std::string file1 = datasets_root_path_ + "/testCSV/1.csv"; std::string file1 = datasets_root_path_ + "/testCSV/1.csv";
std::string file2 = datasets_root_path_ + "/testCSV/append.csv"; std::string file2 = datasets_root_path_ + "/testCSV/append.csv";
std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"}; std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
std::shared_ptr<Dataset> ds = CSV({file1, file2}, ',', {}, column_names, -1, ShuffleMode::kFiles);
std::shared_ptr<Dataset> ds = CSV({file1, file2}, ',', {}, column_names, 0, ShuffleMode::kFiles);
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);


// Create an iterator over the result of the above dataset // Create an iterator over the result of the above dataset
@@ -432,7 +432,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleFilesB) {
std::string file1 = datasets_root_path_ + "/testCSV/1.csv"; std::string file1 = datasets_root_path_ + "/testCSV/1.csv";
std::string file2 = datasets_root_path_ + "/testCSV/append.csv"; std::string file2 = datasets_root_path_ + "/testCSV/append.csv";
std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"}; std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
std::shared_ptr<Dataset> ds = CSV({file2, file1}, ',', {}, column_names, -1, ShuffleMode::kFiles);
std::shared_ptr<Dataset> ds = CSV({file2, file1}, ',', {}, column_names, 0, ShuffleMode::kFiles);
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);


// Create an iterator over the result of the above dataset // Create an iterator over the result of the above dataset
@@ -492,7 +492,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetShuffleGlobal) {
// Create a CSVFile Dataset, with single CSV file // Create a CSVFile Dataset, with single CSV file
std::string train_file = datasets_root_path_ + "/testCSV/1.csv"; std::string train_file = datasets_root_path_ + "/testCSV/1.csv";
std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"}; std::vector<std::string> column_names = {"col1", "col2", "col3", "col4"};
std::shared_ptr<Dataset> ds = CSV({train_file}, ',', {}, column_names, -1, ShuffleMode::kGlobal);
std::shared_ptr<Dataset> ds = CSV({train_file}, ',', {}, column_names, 0, ShuffleMode::kGlobal);
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);


// Create an iterator over the result of the above dataset // Create an iterator over the result of the above dataset
@@ -540,7 +540,7 @@ TEST_F(MindDataTestPipeline, TestCSVDatasetDuplicateColumnName) {
// Create a CSVDataset, with single CSV file // Create a CSVDataset, with single CSV file
std::string train_file = datasets_root_path_ + "/testCSV/1.csv"; std::string train_file = datasets_root_path_ + "/testCSV/1.csv";
std::vector<std::string> column_names = {"col1", "col1", "col3", "col4"}; std::vector<std::string> column_names = {"col1", "col1", "col3", "col4"};
std::shared_ptr<Dataset> ds = CSV({train_file}, ',', {}, column_names, -1, ShuffleMode::kFalse);
std::shared_ptr<Dataset> ds = CSV({train_file}, ',', {}, column_names, 0, ShuffleMode::kFalse);
// Expect failure: duplicate column names // Expect failure: duplicate column names
EXPECT_EQ(ds, nullptr); EXPECT_EQ(ds, nullptr);
} }

Loading…
Cancel
Save