You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_schema.py 2.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import json
  16. import pytest
  17. import mindspore.dataset as ds
  18. from mindspore import log as logger
  19. from util import dataset_equal
  # Directory holding the "all types" TFRecord test fixtures. The path is
  # relative, so it is resolved against the current working directory.
  DATASET_ROOT = "../data/dataset/testTFTestAllTypes/"
  # TFRecord file list handed to ds.TFRecordDataset by the tests below.
  FILES = [DATASET_ROOT + "test.data"]
  # JSON schema file loaded via ds.Schema / json.load in the tests below.
  SCHEMA_FILE = DATASET_ROOT + "datasetSchema.json"
  def test_schema_simple():
      """Construct a ``ds.Schema`` directly from the JSON schema file.

      The test passes as long as construction does not raise.
      """
      logger.info("test_schema_simple")
      ds.Schema(schema_file=SCHEMA_FILE)
  def test_schema_file_vs_string():
      """Compare a schema loaded from a path with one built from parsed JSON.

      Two TFRecord datasets are read with each schema and handed to the
      project helper ``dataset_equal`` for comparison.
      """
      logger.info("test_schema_file_vs_string")

      # Variant 1: let ds.Schema read the file itself.
      schema_from_path = ds.Schema(SCHEMA_FILE)

      # Variant 2: parse the JSON ourselves and feed the object to from_json.
      with open(SCHEMA_FILE) as schema_fp:
          parsed_schema = json.load(schema_fp)
      schema_from_json = ds.Schema()
      schema_from_json.from_json(parsed_schema)

      # Arguments are evaluated left to right, so the path-based dataset is
      # still created first. The trailing 0 is forwarded unchanged to
      # dataset_equal (NOTE(review): presumably a tolerance/threshold; confirm
      # against util.dataset_equal).
      dataset_equal(ds.TFRecordDataset(FILES, schema_from_path),
                    ds.TFRecordDataset(FILES, schema_from_json),
                    0)
  def test_schema_exception():
      """Exercise the error paths of ``ds.Schema``.

      * Passing a non-string ``schema_file`` is expected to raise
        ``TypeError`` with a message naming the argument and value.
      * ``parse_columns`` is expected to raise ``RuntimeError`` when a column
        description has no ``name`` entry.
      """
      logger.info("test_schema_exception")

      with pytest.raises(TypeError) as info:
          ds.Schema(1)
      assert "Argument schema_file with value 1 is not of type (<class 'str'>,)" in str(info.value)

      # Build the valid schema and the bad column list *outside*
      # pytest.raises. Only parse_columns should be under test here; a
      # RuntimeError while loading SCHEMA_FILE must surface as a setup
      # error. Otherwise it would be caught and reported as a message
      # mismatch.
      schema = ds.Schema(SCHEMA_FILE)
      columns = [{'type': 'int8', 'shape': [3, 3]}]  # deliberately missing 'name'
      with pytest.raises(RuntimeError) as info:
          schema.parse_columns(columns)
      assert "Column's name is missing" in str(info.value)
  if __name__ == '__main__':
      # Allow running the module directly (without pytest), in file order.
      for test_fn in (test_schema_simple,
                      test_schema_file_vs_string,
                      test_schema_exception):
          test_fn()