Browse Source

!9277 fix schema & zip validation

From: @luoyang42
Reviewed-by: @pandoublefeng,@liucunwei
Signed-off-by: @liucunwei
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
920bbf1541
6 changed files with 92 additions and 14 deletions
  1. +11
    -0
      mindspore/ccsrc/minddata/dataset/api/datasets.cc
  2. +2
    -0
      mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/schema_bindings.cc
  3. +2
    -0
      mindspore/ccsrc/minddata/dataset/include/datasets.h
  4. +30
    -3
      mindspore/dataset/engine/datasets.py
  5. +6
    -0
      tests/ut/python/dataset/test_schema.py
  6. +41
    -11
      tests/ut/python/dataset/test_zip.py

+ 11
- 0
mindspore/ccsrc/minddata/dataset/api/datasets.cc View File

@@ -782,6 +782,7 @@ Status SchemaObj::from_json(nlohmann::json json_obj) {

return Status::OK();
}

Status SchemaObj::FromJSONString(const std::string &json_string) {
try {
nlohmann::json js = nlohmann::json::parse(json_string);
@@ -794,6 +795,16 @@ Status SchemaObj::FromJSONString(const std::string &json_string) {
return Status::OK();
}

Status SchemaObj::ParseColumnString(const std::string &json_string) {
try {
nlohmann::json js = nlohmann::json::parse(json_string);
RETURN_IF_NOT_OK(parse_column(js));
} catch (const std::exception &err) {
RETURN_STATUS_SYNTAX_ERROR("JSON string is failed to parse");
}
return Status::OK();
}

// OTHER FUNCTIONS

#ifndef ENABLE_ANDROID


+ 2
- 0
mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/include/schema_bindings.cc View File

@@ -42,6 +42,8 @@ PYBIND_REGISTER(
[](SchemaObj &self, std::string name, TypeId de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); })
.def("add_column", [](SchemaObj &self, std::string name,
std::string de_type) { THROW_IF_ERROR(self.add_column(name, de_type)); })
.def("parse_columns",
[](SchemaObj &self, std::string json_string) { THROW_IF_ERROR(self.ParseColumnString(json_string)); })
.def("to_json", &SchemaObj::to_json)
.def("to_string", &SchemaObj::to_string)
.def("from_string",


+ 2
- 0
mindspore/ccsrc/minddata/dataset/include/datasets.h View File

@@ -424,6 +424,8 @@ class SchemaObj {

Status FromJSONString(const std::string &json_string);

Status ParseColumnString(const std::string &json_string);

private:
/// \brief Parse the columns and add it to columns
/// \param[in] columns dataset attribution information, decoded from schema file.


+ 30
- 3
mindspore/dataset/engine/datasets.py View File

@@ -96,6 +96,9 @@ def zip(datasets):
if len(datasets) <= 1:
raise ValueError(
"Can't zip empty or just one dataset!")
for dataset in datasets:
if not isinstance(dataset, Dataset):
raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset))
return ZipDataset(datasets)


@@ -2452,9 +2455,6 @@ class ZipDataset(Dataset):

def __init__(self, datasets):
super().__init__(children=datasets)
for dataset in datasets:
if not isinstance(dataset, Dataset):
raise TypeError("Invalid dataset, expected Dataset object, but got %s!" % type(dataset))
self.datasets = datasets

def parse(self, children=None):
@@ -4480,6 +4480,33 @@ class Schema:
else:
self.cpp_schema.add_column(name, col_type, shape)

def parse_columns(self, columns):
"""
Parse the columns and add it to self.

Args:
columns (Union[dict, list[dict]]): Dataset attribute information, decoded from schema file.

- list[dict], 'name' and 'type' must be in keys, 'shape' optional.

- dict, columns.keys() as name, columns.values() is dict, and 'type' inside, 'shape' optional.

Raises:
RuntimeError: If failed to parse columns.
RuntimeError: If unknown items in columns.
RuntimeError: If column's name field is missing.
RuntimeError: If column's type field is missing.

Example:
>>> schema = Schema()
>>> columns1 = [{'name': 'image', 'type': 'int8', 'shape': [3, 3]},
>>> {'name': 'label', 'type': 'int8', 'shape': [1]}]
>>> schema.parse_columns(columns1)
>>> columns2 = {'image': {'shape': [3, 3], 'type': 'int8'}, 'label': {'shape': [1], 'type': 'int8'}}
>>> schema.parse_columns(columns2)
"""
self.cpp_schema.parse_columns(json.dumps(columns, indent=2))

def to_json(self):
"""
Get a JSON string of the schema.


+ 6
- 0
tests/ut/python/dataset/test_schema.py View File

@@ -50,6 +50,12 @@ def test_schema_exception():
ds.Schema(1)
assert "Argument schema_file with value 1 is not of type (<class 'str'>,)" in str(info.value)

with pytest.raises(RuntimeError) as info:
schema = ds.Schema(SCHEMA_FILE)
columns = [{'type': 'int8', 'shape': [3, 3]}]
schema.parse_columns(columns)
assert "Column's name is missing" in str(info.value)


if __name__ == '__main__':
test_schema_simple()


+ 41
- 11
tests/ut/python/dataset/test_zip.py View File

@@ -250,16 +250,46 @@ def test_zip_exception_06():
logger.info("Got an exception in DE: {}".format(str(e)))


def test_zip_exception_07():
"""
Test zip: zip with string as parameter
"""
logger.info("test_zip_exception_07")

try:
dataz = ds.zip(('dataset1', 'dataset2'))

num_iter = 0
for _ in dataz.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert False

except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))

try:
data = ds.TFRecordDataset(DATA_DIR_1, SCHEMA_DIR_1)
dataz = data.zip(('dataset1',))

num_iter = 0
for _ in dataz.create_dict_iterator(num_epochs=1, output_numpy=True):
num_iter += 1
assert False

except Exception as e:
logger.info("Got an exception in DE: {}".format(str(e)))

if __name__ == '__main__':
test_zip_01()
#test_zip_02()
#test_zip_03()
#test_zip_04()
#test_zip_05()
#test_zip_06()
#test_zip_exception_01()
#test_zip_exception_02()
#test_zip_exception_03()
#test_zip_exception_04()
#test_zip_exception_05()
#test_zip_exception_06()
test_zip_02()
test_zip_03()
test_zip_04()
test_zip_05()
test_zip_06()
test_zip_exception_01()
test_zip_exception_02()
test_zip_exception_03()
test_zip_exception_04()
test_zip_exception_05()
test_zip_exception_06()
test_zip_exception_07()

Loading…
Cancel
Save