You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_tensor_string.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import mindspore._c_dataengine as cde
  17. import mindspore.common.dtype as mstype
  18. import mindspore.dataset as ds
  19. from mindspore.dataset.text import to_str, to_bytes
  20. def test_basic():
  21. x = np.array([["ab", "cde", "121"], ["x", "km", "789"]], dtype='S')
  22. n = cde.Tensor(x)
  23. arr = n.as_array()
  24. np.testing.assert_array_equal(x, arr)
  25. def compare(strings, dtype='S'):
  26. arr = np.array(strings, dtype=dtype)
  27. def gen():
  28. (yield arr,)
  29. data = ds.GeneratorDataset(gen, column_names=["col"])
  30. for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
  31. np.testing.assert_array_equal(d[0], arr.astype('S'))
  32. def test_generator():
  33. compare(["ab"])
  34. compare(["", ""])
  35. compare([""])
  36. compare(["ab", ""])
  37. compare(["ab", "cde", "121"])
  38. compare([["ab", "cde", "121"], ["x", "km", "789"]])
  39. compare([["ab", "", "121"], ["", "km", "789"]])
  40. compare(["ab"], dtype='U')
  41. compare(["", ""], dtype='U')
  42. compare([""], dtype='U')
  43. compare(["ab", ""], dtype='U')
  44. compare(["", ""], dtype='U')
  45. compare(["", "ab"], dtype='U')
  46. compare(["ab", "cde", "121"], dtype='U')
  47. compare([["ab", "cde", "121"], ["x", "km", "789"]], dtype='U')
  48. compare([["ab", "", "121"], ["", "km", "789"]], dtype='U')
# Reference data mirrored by the TFRecord/MindRecord fixtures read below;
# the tests compare decoded dataset columns against these arrays row by row.
line = np.array(["This is a text file.",
                 "Be happy every day.",
                 "Good luck to everyone."])

words = np.array([["This", "text", "file", "a"],
                  ["Be", "happy", "day", "b"],
                  ["女", "", "everyone", "c"]])

chinese = np.array(["今天天气太好了我们一起去外面玩吧",
                    "男默女泪",
                    "江州市长江大桥参加了长江大桥的通车仪式"])
  58. def test_batching_strings():
  59. def gen():
  60. for row in chinese:
  61. yield (np.array(row),)
  62. data = ds.GeneratorDataset(gen, column_names=["col"])
  63. data = data.batch(2, drop_remainder=True)
  64. for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
  65. np.testing.assert_array_equal(d[0], to_bytes(chinese[0:2]))
  66. def test_map():
  67. def gen():
  68. yield (np.array(["ab cde 121"], dtype='S'),)
  69. data = ds.GeneratorDataset(gen, column_names=["col"])
  70. def split(b):
  71. s = to_str(b)
  72. splits = s.item().split()
  73. return np.array(splits)
  74. data = data.map(operations=split, input_columns=["col"])
  75. expected = np.array(["ab", "cde", "121"], dtype='S')
  76. for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
  77. np.testing.assert_array_equal(d[0], expected)
  78. def test_map2():
  79. def gen():
  80. yield (np.array(["ab cde 121"], dtype='S'),)
  81. data = ds.GeneratorDataset(gen, column_names=["col"])
  82. def upper(b):
  83. out = np.char.upper(b)
  84. return out
  85. data = data.map(operations=upper, input_columns=["col"])
  86. expected = np.array(["AB CDE 121"], dtype='S')
  87. for d in data.create_tuple_iterator(num_epochs=1, output_numpy=True):
  88. np.testing.assert_array_equal(d[0], expected)
  89. def test_tfrecord1():
  90. s = ds.Schema()
  91. s.add_column("line", "string", [])
  92. s.add_column("words", "string", [-1])
  93. s.add_column("chinese", "string", [])
  94. data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s)
  95. for i, d in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
  96. assert d["line"].shape == line[i].shape
  97. assert d["words"].shape == words[i].shape
  98. assert d["chinese"].shape == chinese[i].shape
  99. np.testing.assert_array_equal(line[i], to_str(d["line"]))
  100. np.testing.assert_array_equal(words[i], to_str(d["words"]))
  101. np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))
  102. def test_tfrecord2():
  103. data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False,
  104. schema='../data/dataset/testTextTFRecord/datasetSchema.json')
  105. for i, d in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
  106. assert d["line"].shape == line[i].shape
  107. assert d["words"].shape == words[i].shape
  108. assert d["chinese"].shape == chinese[i].shape
  109. np.testing.assert_array_equal(line[i], to_str(d["line"]))
  110. np.testing.assert_array_equal(words[i], to_str(d["words"]))
  111. np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))
  112. def test_tfrecord3():
  113. s = ds.Schema()
  114. s.add_column("line", mstype.string, [])
  115. s.add_column("words", mstype.string, [-1, 2])
  116. s.add_column("chinese", mstype.string, [])
  117. data = ds.TFRecordDataset("../data/dataset/testTextTFRecord/text.tfrecord", shuffle=False, schema=s)
  118. for i, d in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
  119. assert d["line"].shape == line[i].shape
  120. assert d["words"].shape == words[i].reshape([2, 2]).shape
  121. assert d["chinese"].shape == chinese[i].shape
  122. np.testing.assert_array_equal(line[i], to_str(d["line"]))
  123. np.testing.assert_array_equal(words[i].reshape([2, 2]), to_str(d["words"]))
  124. np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))
  125. def create_text_mindrecord():
  126. # methood to create mindrecord with string data, used to generate testTextMindRecord/test.mindrecord
  127. from mindspore.mindrecord import FileWriter
  128. mindrecord_file_name = "test.mindrecord"
  129. data = [{"english": "This is a text file.",
  130. "chinese": "今天天气太好了我们一起去外面玩吧"},
  131. {"english": "Be happy every day.",
  132. "chinese": "男默女泪"},
  133. {"english": "Good luck to everyone.",
  134. "chinese": "江州市长江大桥参加了长江大桥的通车仪式"},
  135. ]
  136. writer = FileWriter(mindrecord_file_name)
  137. schema = {"english": {"type": "string"},
  138. "chinese": {"type": "string"},
  139. }
  140. writer.add_schema(schema)
  141. writer.write_raw_data(data)
  142. writer.commit()
  143. def test_mindrecord():
  144. data = ds.MindDataset("../data/dataset/testTextMindRecord/test.mindrecord", shuffle=False)
  145. for i, d in enumerate(data.create_dict_iterator(num_epochs=1, output_numpy=True)):
  146. assert d["english"].shape == line[i].shape
  147. assert d["chinese"].shape == chinese[i].shape
  148. np.testing.assert_array_equal(line[i], to_str(d["english"]))
  149. np.testing.assert_array_equal(chinese[i], to_str(d["chinese"]))
  150. # The following tests cases were copied from test_pad_batch but changed to strings instead
  151. # this generator function yield two columns
  152. # col1d: [0],[1], [2], [3]
  153. # col2d: [[100],[200]], [[101],[201]], [102],[202]], [103],[203]]
  154. def gen_2cols(num):
  155. for i in range(num):
  156. yield (np.array([str(i)]), np.array([[str(i + 100)], [str(i + 200)]]))
  157. # this generator function yield one column of variable shapes
  158. # col: [0], [0,1], [0,1,2], [0,1,2,3]
  159. def gen_var_col(num):
  160. for i in range(num):
  161. yield (np.array([str(j) for j in range(i + 1)]),)
  162. # this generator function yield two columns of variable shapes
  163. # col1: [0], [0,1], [0,1,2], [0,1,2,3]
  164. # col2: [100], [100,101], [100,101,102], [100,110,102,103]
  165. def gen_var_cols(num):
  166. for i in range(num):
  167. yield (np.array([str(j) for j in range(i + 1)]), np.array([str(100 + j) for j in range(i + 1)]))
  168. # this generator function yield two columns of variable shapes
  169. # col1: [[0]], [[0,1]], [[0,1,2]], [[0,1,2,3]]
  170. # col2: [[100]], [[100,101]], [[100,101,102]], [[100,110,102,103]]
  171. def gen_var_cols_2d(num):
  172. for i in range(num):
  173. yield (np.array([[str(j) for j in range(i + 1)]]), np.array([[str(100 + j) for j in range(i + 1)]]))
  174. def test_batch_padding_01():
  175. data1 = ds.GeneratorDataset((lambda: gen_2cols(2)), ["col1d", "col2d"])
  176. data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col2d": ([2, 2], b"-2"), "col1d": ([2], b"-1")})
  177. data1 = data1.repeat(2)
  178. for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  179. np.testing.assert_array_equal([[b"0", b"-1"], [b"1", b"-1"]], data["col1d"])
  180. np.testing.assert_array_equal([[[b"100", b"-2"], [b"200", b"-2"]], [[b"101", b"-2"], [b"201", b"-2"]]],
  181. data["col2d"])
  182. def test_batch_padding_02():
  183. data1 = ds.GeneratorDataset((lambda: gen_2cols(2)), ["col1d", "col2d"])
  184. data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col2d": ([1, 2], "")})
  185. data1 = data1.repeat(2)
  186. for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  187. np.testing.assert_array_equal([[b"0"], [b"1"]], data["col1d"])
  188. np.testing.assert_array_equal([[[b"100", b""]], [[b"101", b""]]], data["col2d"])
  189. def test_batch_padding_03():
  190. data1 = ds.GeneratorDataset((lambda: gen_var_col(4)), ["col"])
  191. data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={"col": (None, "PAD_VALUE")}) # pad automatically
  192. data1 = data1.repeat(2)
  193. res = dict()
  194. for ind, data in enumerate(data1.create_dict_iterator(num_epochs=1, output_numpy=True)):
  195. res[ind] = data["col"].copy()
  196. np.testing.assert_array_equal(res[0], [[b"0", b"PAD_VALUE"], [0, 1]])
  197. np.testing.assert_array_equal(res[1], [[b"0", b"1", b"2", b"PAD_VALUE"], [b"0", b"1", b"2", b"3"]])
  198. np.testing.assert_array_equal(res[2], [[b"0", b"PAD_VALUE"], [b"0", b"1"]])
  199. np.testing.assert_array_equal(res[3], [[b"0", b"1", b"2", b"PAD_VALUE"], [b"0", b"1", b"2", b"3"]])
  200. def test_batch_padding_04():
  201. data1 = ds.GeneratorDataset((lambda: gen_var_cols(2)), ["col1", "col2"])
  202. data1 = data1.batch(batch_size=2, drop_remainder=False, pad_info={}) # pad automatically
  203. data1 = data1.repeat(2)
  204. for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  205. np.testing.assert_array_equal(data["col1"], [[b"0", b""], [b"0", b"1"]])
  206. np.testing.assert_array_equal(data["col2"], [[b"100", b""], [b"100", b"101"]])
  207. def test_batch_padding_05():
  208. data1 = ds.GeneratorDataset((lambda: gen_var_cols_2d(3)), ["col1", "col2"])
  209. data1 = data1.batch(batch_size=3, drop_remainder=False,
  210. pad_info={"col2": ([2, None], "-2"), "col1": (None, "-1")}) # pad automatically
  211. for data in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
  212. np.testing.assert_array_equal(data["col1"],
  213. [[[b"0", b"-1", b"-1"]], [[b"0", b"1", b"-1"]], [[b"0", b"1", b"2"]]])
  214. np.testing.assert_array_equal(data["col2"],
  215. [[[b"100", b"-2", b"-2"], [b"-2", b"-2", b"-2"]],
  216. [[b"100", b"101", b"-2"], [b"-2", b"-2", b"-2"]],
  217. [[b"100", b"101", b"102"], [b"-2", b"-2", b"-2"]]])
if __name__ == '__main__':
    # Allow running this test module directly (outside of pytest).
    test_generator()
    test_basic()
    test_batching_strings()
    test_map()
    test_map2()
    test_tfrecord1()
    test_tfrecord2()
    test_tfrecord3()
    test_mindrecord()
    test_batch_padding_01()
    test_batch_padding_02()
    test_batch_padding_03()
    test_batch_padding_04()
    test_batch_padding_05()