@@ -162,7 +162,11 @@ Status StorageClient::numRowsFromFile(uint32_t &num_rows) const {
   std::ifstream in(schemaFile);
   nlohmann::json js;
   in >> js;
-  num_rows = js.value("numRows", 0);
+  if (js.find("numRows") == js.end()) {
+    num_rows = MAX_INTEGER_INT32;
+  } else {
+    num_rows = js.value("numRows", 0);
+  }
   if (num_rows == 0) {
     std::string err_msg =
       "Storage client has not properly done dataset "
@@ -163,6 +163,9 @@ Status TFReaderOp::Init() {
   if (total_rows_ == 0) {
     total_rows_ = data_schema_->num_rows();
   }
+  if (total_rows_ < 0) {
+    RETURN_STATUS_UNEXPECTED("The num_samples or numRows for TFRecordDataset should be greater than 0");
+  }
   // Build the index with our files such that each file corresponds to a key id.
   RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_));
@@ -1455,7 +1455,7 @@ class StorageDataset(SourceDataset):
     Args:
         dataset_files (list[str]): List of files to be read.
-        schema (str): Path to the json schema file.
+        schema (str): Path to the json schema file. If numRows (parsed from the schema) does not exist, read the full dataset.
         distribution (str, optional): Path of distribution config file (default="").
         columns_list (list[str], optional): List of columns to be read (default=None, read all columns).
         num_parallel_workers (int, optional): Number of parallel working threads (default=None).
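As an illustrative usage sketch (paths are hypothetical), a schema without `numRows` now reads the whole file set instead of erroring out:

```python
import mindspore.dataset as ds

# Hypothetical paths: any TF-format data files plus a schema that omits "numRows".
data_files = ["/path/to/train-0000-of-0001.data"]
schema_path = "/path/to/datasetNoRowsSchema.json"

dataset = ds.StorageDataset(data_files, schema_path, columns_list=["image"])
print(dataset.get_dataset_size())  # total rows present in the files
```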
@@ -2193,7 +2193,10 @@ class TFRecordDataset(SourceDataset):
         schema (str or Schema, optional): Path to the json schema file or schema object (default=None).
             If the schema is not provided, the meta data from the TFData file is considered the schema.
         columns_list (list[str], optional): List of columns to be read (default=None, read all columns)
-        num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset).
+        num_samples (int, optional): number of samples (rows) to read (default=None).
+            If num_samples is None and numRows (parsed from the schema) does not exist, read the full dataset;
+            if num_samples is None and numRows (parsed from the schema) is greater than 0, read numRows rows;
+            if both num_samples and numRows (parsed from the schema) are greater than 0, read num_samples rows.
         num_parallel_workers (int, optional): number of workers to read the data
             (default=None, number set in the config).
         shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL).
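For illustration, the precedence restated as a hedged usage sketch (file and schema paths are hypothetical; the row counts mirror the tests added below):

```python
import mindspore.dataset as ds

# Hypothetical setup: the files hold 12 rows in total, and one schema declares numRows = 7.
FILES = ["/path/to/data.tfrecord"]

ds1 = ds.TFRecordDataset(FILES, "schema_7_rows.json")                 # numRows applies: 7 rows
ds2 = ds.TFRecordDataset(FILES, "schema_7_rows.json", num_samples=8)  # num_samples wins: 8 rows
ds3 = ds.TFRecordDataset(FILES, "schema_without_numRows.json")        # full dataset: 12 rows
```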
@@ -2711,10 +2714,10 @@ class Schema:
     """

     def __init__(self, schema_file=None):
+        self.num_rows = None
         if schema_file is None:
             self.columns = []
             self.dataset_type = ''
-            self.num_rows = 0
         else:
             if not os.path.isfile(schema_file) or not os.access(schema_file, os.R_OK):
                 raise ValueError("The file %s does not exist or permission denied!" % schema_file)
@@ -2859,6 +2862,9 @@ class Schema:
             raise RuntimeError("DatasetType field is missing.")
         if self.columns is None:
             raise RuntimeError("Columns are missing.")
+        if self.num_rows is not None:
+            if not isinstance(self.num_rows, int) or self.num_rows <= 0:
+                raise ValueError("numRows must be greater than 0")

     def __str__(self):
         return self.to_json()
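A short sketch of what the new check permits (illustrative only: the validation runs inside the schema's own parsing path, so the direct attribute assignments here merely stand in for values read from a schema file):

```python
from mindspore.dataset import Schema

schema = Schema()      # num_rows defaults to None: numRows omitted, read the full dataset
schema.num_rows = 7    # a positive integer passes the check
schema.num_rows = 0    # would raise ValueError("numRows must be greater than 0") when validated
```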
@@ -0,0 +1,45 @@
+{
+  "datasetType": "TF",
+  "columns": {
+    "col_sint16": {
+      "type": "int16",
+      "rank": 1,
+      "shape": [1]
+    },
+    "col_sint32": {
+      "type": "int32",
+      "rank": 1,
+      "shape": [1]
+    },
+    "col_sint64": {
+      "type": "int64",
+      "rank": 1,
+      "shape": [1]
+    },
+    "col_float": {
+      "type": "float32",
+      "rank": 1,
+      "shape": [1]
+    },
+    "col_1d": {
+      "type": "int64",
+      "rank": 1,
+      "shape": [2]
+    },
+    "col_2d": {
+      "type": "int64",
+      "rank": 2,
+      "shape": [2, 2]
+    },
+    "col_3d": {
+      "type": "int64",
+      "rank": 3,
+      "shape": [2, 2, 2]
+    },
+    "col_binary": {
+      "type": "uint8",
+      "rank": 1,
+      "shape": [1]
+    }
+  }
+}
@@ -0,0 +1,15 @@
+{
+  "datasetType": "TF",
+  "columns": {
+    "image": {
+      "type": "uint8",
+      "rank": 1,
+      "t_impl": "cvmat"
+    },
+    "label": {
+      "type": "uint64",
+      "rank": 1,
+      "t_impl": "flex"
+    }
+  }
+}
@@ -37,3 +37,15 @@ def test_case_storage():
     filename = "storage_result.npz"
     save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN)
+
+
+def test_case_no_rows():
+    DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
+    SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json"
+    dataset = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
+    assert dataset.get_dataset_size() == 3
+    count = 0
+    for data in dataset.create_tuple_iterator():
+        count += 1
+    assert count == 3
@@ -37,6 +37,36 @@ def test_case_tf_shape():
     assert (len(output_shape[-1]) == 1)
+
+
+def test_case_tf_read_all_dataset():
+    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json"
+    ds1 = ds.TFRecordDataset(FILES, schema_file)
+    assert ds1.get_dataset_size() == 12
+    count = 0
+    for data in ds1.create_tuple_iterator():
+        count += 1
+    assert count == 12
+
+
+def test_case_num_samples():
+    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
+    ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8)
+    assert ds1.get_dataset_size() == 8
+    count = 0
+    for data in ds1.create_dict_iterator():
+        count += 1
+    assert count == 8
+
+
+def test_case_num_samples2():
+    schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json"
+    ds1 = ds.TFRecordDataset(FILES, schema_file)
+    assert ds1.get_dataset_size() == 7
+    count = 0
+    for data in ds1.create_dict_iterator():
+        count += 1
+    assert count == 7
+
+
 def test_case_tf_shape_2():
     ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE)
     ds1 = ds1.batch(2)