From: @luoyang42 Reviewed-by: @pandoublefeng, @liucunwei Signed-off-by: @liucunwei  tags/v1.1.0
| @@ -782,6 +782,7 @@ Status SchemaObj::from_json(nlohmann::json json_obj) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status SchemaObj::FromJSONString(const std::string &json_string) { | Status SchemaObj::FromJSONString(const std::string &json_string) { | ||||
| try { | try { | ||||
| nlohmann::json js = nlohmann::json::parse(json_string); | nlohmann::json js = nlohmann::json::parse(json_string); | ||||
| @@ -794,6 +795,16 @@ Status SchemaObj::FromJSONString(const std::string &json_string) { | |||||
| return Status::OK(); | return Status::OK(); | ||||
| } | } | ||||
| Status SchemaObj::ParseColumnString(const std::string &json_string) { | |||||
| try { | |||||
| nlohmann::json js = nlohmann::json::parse(json_string); | |||||
| RETURN_IF_NOT_OK(parse_column(js)); | |||||
| } catch (const std::exception &err) { | |||||
| RETURN_STATUS_SYNTAX_ERROR("JSON string is failed to parse"); | |||||
| } | |||||
| return Status::OK(); | |||||
| } | |||||
| // OTHER FUNCTIONS | // OTHER FUNCTIONS | ||||
| #ifndef ENABLE_ANDROID | #ifndef ENABLE_ANDROID | ||||
| @@ -42,6 +42,8 @@ PYBIND_REGISTER( | |||||
| [](SchemaObj &self, std::string name, TypeId de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); }) | [](SchemaObj &self, std::string name, TypeId de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); }) | ||||
| .def("add_column", [](SchemaObj &self, std::string name, | .def("add_column", [](SchemaObj &self, std::string name, | ||||
| std::string de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); }) | std::string de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); }) | ||||
| .def("parse_columns", | |||||
| [](SchemaObj &self, std::string json_string) { THROW_IF_ERROR(self.ParseColumnString(json_string)); }) | |||||
| .def("to_json", &SchemaObj::to_json) | .def("to_json", &SchemaObj::to_json) | ||||
| .def("to_string", &SchemaObj::to_string) | .def("to_string", &SchemaObj::to_string) | ||||
| .def("from_string", | .def("from_string", | ||||
| @@ -424,6 +424,8 @@ class SchemaObj { | |||||
| Status FromJSONString(const std::string &json_string); | Status FromJSONString(const std::string &json_string); | ||||
| Status ParseColumnString(const std::string &json_string); | |||||
| private: | private: | ||||
| /// \brief Parse the columns and add it to columns | /// \brief Parse the columns and add it to columns | ||||
| /// \param[in] columns dataset attribution information, decoded from schema file. | /// \param[in] columns dataset attribution information, decoded from schema file. | ||||
| @@ -96,6 +96,9 @@ def zip(datasets): | |||||
| if len(datasets) <= 1: | if len(datasets) <= 1: | ||||
| raise ValueError( | raise ValueError( | ||||
| "Can't zip empty or just one dataset!") | "Can't zip empty or just one dataset!") | ||||
| for dataset in datasets: | |||||
| if not isinstance(dataset, Dataset): | |||||
| raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset)) | |||||
| return ZipDataset(datasets) | return ZipDataset(datasets) | ||||
| @@ -2452,9 +2455,6 @@ class ZipDataset(Dataset): | |||||
| def __init__(self, datasets): | def __init__(self, datasets): | ||||
| super().__init__(children=datasets) | super().__init__(children=datasets) | ||||
| for dataset in datasets: | |||||
| if not isinstance(dataset, Dataset): | |||||
| raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset)) | |||||
| self.datasets = datasets | self.datasets = datasets | ||||
| def parse(self, children=None): | def parse(self, children=None): | ||||
| @@ -4480,6 +4480,33 @@ class Schema: | |||||
| else: | else: | ||||
| self.cpp_schema.add_column(name, col_type, shape) | self.cpp_schema.add_column(name, col_type, shape) | ||||
def parse_columns(self, columns):
    """
    Parse the columns and add it to self.

    Args:
        columns (Union[dict, list[dict]]): Dataset attribute information, decoded from schema file.

            - list[dict], 'name' and 'type' must be in keys, 'shape' optional.

            - dict, columns.keys() as name, columns.values() is dict, and 'type' inside, 'shape' optional.

    Raises:
        RuntimeError: If failed to parse columns.
        RuntimeError: If unknown items in columns.
        RuntimeError: If column's name field is missing.
        RuntimeError: If column's type field is missing.

    Example:
        >>> schema = Schema()
        >>> columns1 = [{'name': 'image', 'type': 'int8', 'shape': [3, 3]},
        >>>             {'name': 'label', 'type': 'int8', 'shape': [1]}]
        >>> schema.parse_columns(columns1)
        >>> columns2 = {'image': {'shape': [3, 3], 'type': 'int8'}, 'label': {'shape': [1], 'type': 'int8'}}
        >>> schema.parse_columns(columns2)
    """
    # Serialize once, then hand the JSON text to the C++ schema object,
    # which performs the actual validation and column registration.
    serialized = json.dumps(columns, indent=2)
    self.cpp_schema.parse_columns(serialized)
| def to_json(self): | def to_json(self): | ||||
| """ | """ | ||||
| Get a JSON string of the schema. | Get a JSON string of the schema. | ||||
| @@ -50,6 +50,12 @@ def test_schema_exception(): | |||||
| ds.Schema(1) | ds.Schema(1) | ||||
| assert "Argument schema_file with value 1 is not of type (<class 'str'>,)" in str(info.value) | assert "Argument schema_file with value 1 is not of type (<class 'str'>,)" in str(info.value) | ||||
| with pytest.raises(RuntimeError) as info: | |||||
| schema = ds.Schema(SCHEMA_FILE) | |||||
| columns = [{'type': 'int8', 'shape': [3, 3]}] | |||||
| schema.parse_columns(columns) | |||||
| assert "Column's name is missing" in str(info.value) | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| test_schema_simple() | test_schema_simple() | ||||
| @@ -250,16 +250,46 @@ def test_zip_exception_06(): | |||||
| logger.info("Got an exception in DE: {}".format(str(e))) | logger.info("Got an exception in DE: {}".format(str(e))) | ||||
def test_zip_exception_07():
    """
    Test zip: zip with string as parameter.

    Both ds.zip(...) and Dataset.zip(...) must reject non-Dataset inputs
    (here: plain strings) by raising an error.
    """
    logger.info("test_zip_exception_07")
    # NOTE: a bare `except Exception` would also swallow the AssertionError
    # from `assert False`, making the test pass even when no error is raised.
    # Re-raise AssertionError explicitly so a missing exception still fails.
    try:
        dataz = ds.zip(('dataset1', 'dataset2'))
        num_iter = 0
        for _ in dataz.create_dict_iterator(num_epochs=1, output_numpy=True):
            num_iter += 1
        assert False, "ds.zip with string inputs should have raised"
    except AssertionError:
        raise
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))

    try:
        data = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
        dataz = data.zip(('dataset1',))
        num_iter = 0
        for _ in dataz.create_dict_iterator(num_epochs=1, output_numpy=True):
            num_iter += 1
        assert False, "Dataset.zip with string input should have raised"
    except AssertionError:
        raise
    except Exception as e:
        logger.info("Got an exception in DE: {}".format(str(e)))
if __name__ == '__main__':
    # Run every zip test; the previously commented-out duplicate calls
    # (diff residue) are removed — all tests are live.
    test_zip_01()
    test_zip_02()
    test_zip_03()
    test_zip_04()
    test_zip_05()
    test_zip_06()
    test_zip_exception_01()
    test_zip_exception_02()
    test_zip_exception_03()
    test_zip_exception_04()
    test_zip_exception_05()
    test_zip_exception_06()
    test_zip_exception_07()