You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_tensor_string.py 6.6 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import mindspore._c_dataengine as cde
  16. import numpy as np
  17. import pytest
  18. from mindspore.dataset.text import to_str, to_bytes
  19. import mindspore.dataset as ds
  20. import mindspore.common.dtype as mstype
  21. # pylint: disable=comparison-with-itself
  22. def test_basic():
  23. x = np.array([["ab", "cde", "121"], ["x", "km", "789"]], dtype='S')
  24. n = cde.Tensor(x)
  25. arr = n.as_array()
  26. np.testing.assert_array_equal(x, arr)
  27. def compare(strings):
  28. arr = np.array(strings, dtype='S')
  29. def gen():
  30. yield arr,
  31. data = ds.GeneratorDataset(gen, column_names=["col"])
  32. for d in data:
  33. np.testing.assert_array_equal(d[0], arr)
  34. def test_generator():
  35. compare(["ab"])
  36. compare(["ab", "cde", "121"])
  37. compare([["ab", "cde", "121"], ["x", "km", "789"]])
  38. def test_batching_strings():
  39. def gen():
  40. yield np.array(["ab", "cde", "121"], dtype='S'),
  41. data = ds.GeneratorDataset(gen, column_names=["col"]).batch(10)
  42. with pytest.raises(RuntimeError) as info:
  43. for _ in data:
  44. pass
  45. assert "[Batch ERROR] Batch does not support" in str(info.value)
  46. def test_map():
  47. def gen():
  48. yield np.array(["ab cde 121"], dtype='S'),
  49. data = ds.GeneratorDataset(gen, column_names=["col"])
  50. def split(b):
  51. s = to_str(b)
  52. splits = s.item().split()
  53. return np.array(splits, dtype='S')
  54. data = data.map(input_columns=["col"], operations=split)
  55. expected = np.array(["ab", "cde", "121"], dtype='S')
  56. for d in data:
  57. np.testing.assert_array_equal(d[0], expected)
  58. def test_map2():
  59. def gen():
  60. yield np.array(["ab cde 121"], dtype='S'),
  61. data = ds.GeneratorDataset(gen, column_names=["col"])
  62. def upper(b):
  63. out = np.char.upper(b)
  64. return out
  65. data = data.map(input_columns=["col"], operations=upper)
  66. expected = np.array(["AB CDE 121"], dtype='S')
  67. for d in data:
  68. np.testing.assert_array_equal(d[0], expected)
  69. line = np.array(["This is a text file.",
  70. "Be happy every day.",
  71. "Good luck to everyone."])
  72. words = np.array([["This", "text", "file", "a"],
  73. ["Be", "happy", "day", "b"],
  74. ["女", "", "everyone", "c"]])
  75. chinese = np.array(["今天天气太好了我们一起去外面玩吧",
  76. "男默女泪",
  77. "江州市长江大桥参加了长江大桥的通车仪式"])
  78. def test_tfrecord1():
  79. s = ds.Schema()
  80. s.add_column("line", "string", [])
  81. s.add_column("words", "string", [-1])
  82. s.add_column("chinese", "string", [])
  83. data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s)
  84. for i, d in enumerate(data.create_dict_iterator()):
  85. assert d["line"].shape == line[i].shape
  86. assert d["words"].shape == words[i].shape
  87. assert d["chinese"].shape == chinese[i].shape
  88. np.testing.assert_array_equal(line[i], to_str(d["line"]))
  89. np.testing.assert_array_equal(words[i], to_str(d["words"]))
  90. np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))
  91. def test_tfrecord2():
  92. data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False,
  93. schema='../data/dataset/testTextTFRecord/datasetSchema.json')
  94. for i, d in enumerate(data.create_dict_iterator()):
  95. assert d["line"].shape == line[i].shape
  96. assert d["words"].shape == words[i].shape
  97. assert d["chinese"].shape == chinese[i].shape
  98. np.testing.assert_array_equal(line[i], to_str(d["line"]))
  99. np.testing.assert_array_equal(words[i], to_str(d["words"]))
  100. np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))
  101. def test_tfrecord3():
  102. s = ds.Schema()
  103. s.add_column("line", mstype.string, [])
  104. s.add_column("words", mstype.string, [-1, 2])
  105. s.add_column("chinese", mstype.string, [])
  106. data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s)
  107. for i, d in enumerate(data.create_dict_iterator()):
  108. assert d["line"].shape == line[i].shape
  109. assert d["words"].shape == words[i].reshape([2, 2]).shape
  110. assert d["chinese"].shape == chinese[i].shape
  111. np.testing.assert_array_equal(line[i], to_str(d["line"]))
  112. np.testing.assert_array_equal(words[i].reshape([2, 2]), to_str(d["words"]))
  113. np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))
  114. def create_text_mindrecord():
  115. # methood to create mindrecord with string data, used to generate testTextMindRecord/test.mindrecord
  116. from mindspore.mindrecord import FileWriter
  117. mindrecord_file_name = "test.mindrecord"
  118. data = [{"english": "This is a text file.",
  119. "chinese": "今天天气太好了我们一起去外面玩吧"},
  120. {"english": "Be happy every day.",
  121. "chinese": "男默女泪"},
  122. {"english": "Good luck to everyone.",
  123. "chinese": "江州市长江大桥参加了长江大桥的通车仪式"},
  124. ]
  125. writer = FileWriter(mindrecord_file_name)
  126. schema = {"english": {"type": "string"},
  127. "chinese": {"type": "string"},
  128. }
  129. writer.add_schema(schema)
  130. writer.write_raw_data(data)
  131. writer.commit()
  132. def test_mindrecord():
  133. data = ds.MindDataset("../data/dataset/testTextMindRecord/test.mindrecord", shuffle=False)
  134. for i, d in enumerate(data.create_dict_iterator()):
  135. assert d["english"].shape == line[i].shape
  136. assert d["chinese"].shape == chinese[i].shape
  137. np.testing.assert_array_equal(line[i], to_str(d["english"]))
  138. np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))
  139. if __name__ == '__main__':
  140. test_generator()
  141. test_basic()
  142. test_batching_strings()
  143. test_map()
  144. test_map2()
  145. test_tfrecord1()
  146. test_tfrecord2()
  147. test_tfrecord3()
  148. test_mindrecord()