diff --git a/mindspore/ccsrc/minddata/dataset/api/datasets.cc b/mindspore/ccsrc/minddata/dataset/api/datasets.cc index be7d942200..8417665139 100644 --- a/mindspore/ccsrc/minddata/dataset/api/datasets.cc +++ b/mindspore/ccsrc/minddata/dataset/api/datasets.cc @@ -782,6 +782,7 @@ Status SchemaObj::from_json(nlohmann::json json_obj) { return Status::OK(); } + Status SchemaObj::FromJSONString(const std::string &json_string) { try { nlohmann::json js = nlohmann::json::parse(json_string); @@ -794,6 +795,16 @@ Status SchemaObj::FromJSONString(const std::string &json_string) { return Status::OK(); } +Status SchemaObj::ParseColumnString(const std::string &json_string) { + try { + nlohmann::json js = nlohmann::json::parse(json_string); + RETURN_IF_NOT_OK(parse_column(js)); + } catch (const std::exception &err) { + RETURN_STATUS_SYNTAX_ERROR("JSON string is failed to parse"); + } + return Status::OK(); +} + // OTHER FUNCTIONS #ifndef ENABLE_ANDROID diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/schema_bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/schema_bindings.cc index fd0256cd69..c91fdc1920 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/schema_bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/schema_bindings.cc @@ -42,6 +42,8 @@ PYBIND_REGISTER( [](SchemaObj &self, std::string name, TypeId de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); }) .def("add_column", [](SchemaObj &self, std::string name, std::string de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); }) + .def("parse_columns", + [](SchemaObj &self, std::string json_string) { THROW_IF_ERROR(self.ParseColumnString(json_string)); }) .def("to_json", &SchemaObj::to_json) .def("to_string", &SchemaObj::to_string) .def("from_string", diff --git a/mindspore/ccsrc/minddata/dataset/include/datasets.h b/mindspore/ccsrc/minddata/dataset/include/datasets.h index bec58cad6b..d974671a7e 100644 --- a/mindspore/ccsrc/minddata/dataset/include/datasets.h +++ b/mindspore/ccsrc/minddata/dataset/include/datasets.h @@ -424,6 +424,8 @@ class SchemaObj { Status FromJSONString(const std::string &json_string); + Status ParseColumnString(const std::string &json_string); + private: /// \brief Parse the columns and add it to columns /// \param[in] columns dataset attribution information, decoded from schema file. diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index c73302d369..a6267b4191 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -96,6 +96,9 @@ def zip(datasets): if len(datasets) <= 1: raise ValueError( "Can't zip empty or just one dataset!") + for dataset in datasets: + if not isinstance(dataset, Dataset): + raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset)) return ZipDataset(datasets) @@ -2452,9 +2455,6 @@ class ZipDataset(Dataset): def __init__(self, datasets): super().__init__(children=datasets) - for dataset in datasets: - if not isinstance(dataset, Dataset): - raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset)) self.datasets = datasets def parse(self, children=None): @@ -4480,6 +4480,33 @@ class Schema: else: self.cpp_schema.add_column(name, col_type, shape) + def parse_columns(self, columns): + """ + Parse the columns and add it to self. + + Args: + columns (Union[dict, list[dict]]): Dataset attribute information, decoded from schema file. + + - list[dict], 'name' and 'type' must be in keys, 'shape' optional. + + - dict, columns.keys() as name, columns.values() is dict, and 'type' inside, 'shape' optional. + + Raises: + RuntimeError: If failed to parse columns. + RuntimeError: If unknown items in columns. + RuntimeError: If column's name field is missing. + RuntimeError: If column's type field is missing. + + Example: + >>> schema = Schema() + >>> columns1 = [{'name': 'image', 'type': 'int8', 'shape': [3, 3]}, + >>> {'name': 'label', 'type': 'int8', 'shape': [1]}] + >>> schema.parse_columns(columns1) + >>> columns2 = {'image': {'shape': [3, 3], 'type': 'int8'}, 'label': {'shape': [1], 'type': 'int8'}} + >>> schema.parse_columns(columns2) + """ + self.cpp_schema.parse_columns(json.dumps(columns, indent=2)) + def to_json(self): """ Get a JSON string of the schema. diff --git a/tests/ut/python/dataset/test_schema.py b/tests/ut/python/dataset/test_schema.py index 0e7c2d2464..520d8c9a6c 100644 --- a/tests/ut/python/dataset/test_schema.py +++ b/tests/ut/python/dataset/test_schema.py @@ -50,6 +50,12 @@ def test_schema_exception(): ds.Schema(1) assert "Argument schema_file with value 1 is not of type (,)" in str(info.value) + with pytest.raises(RuntimeError) as info: + schema = ds.Schema(SCHEMA_FILE) + columns = [{'type': 'int8', 'shape': [3, 3]}] + schema.parse_columns(columns) + assert "Column's name is missing" in str(info.value) + if __name__ == '__main__': test_schema_simple() diff --git a/tests/ut/python/dataset/test_zip.py b/tests/ut/python/dataset/test_zip.py index ab6136473f..d3cc67e91c 100644 --- a/tests/ut/python/dataset/test_zip.py +++ b/tests/ut/python/dataset/test_zip.py @@ -250,16 +250,46 @@ def test_zip_exception_06(): logger.info("Got an exception in DE: {}".format(str(e))) +def test_zip_exception_07(): + """ + Test zip: zip with string as parameter + """ + logger.info("test_zip_exception_07") + + try: + dataz = ds.zip(('dataset1', 'dataset2')) + + num_iter = 0 + for _ in dataz.create_dict_iterator(num_epochs=1, output_numpy=True): + num_iter += 1 + assert False + + except Exception as e: + logger.info("Got an exception in DE: {}".format(str(e))) + + try: + data = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1) + dataz = data.zip(('dataset1',)) + + num_iter = 0 + for _ in dataz.create_dict_iterator(num_epochs=1, output_numpy=True): + num_iter += 1 + assert False + + except Exception as e: + logger.info("Got an exception in DE: {}".format(str(e))) + if __name__ == '__main__': test_zip_01() - #test_zip_02() - #test_zip_03() - #test_zip_04() - #test_zip_05() - #test_zip_06() - #test_zip_exception_01() - #test_zip_exception_02() - #test_zip_exception_03() - #test_zip_exception_04() - #test_zip_exception_05() - #test_zip_exception_06() + test_zip_02() + test_zip_03() + test_zip_04() + test_zip_05() + test_zip_06() + test_zip_exception_01() + test_zip_exception_02() + test_zip_exception_03() + test_zip_exception_04() + test_zip_exception_05() + test_zip_exception_06() + test_zip_exception_07()