You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_datasets_amazon_review.py 9.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import numpy as np
  16. import mindspore.dataset as ds
  17. import mindspore.dataset.text.transforms as a_c_trans
# Relative path to the "polarity" AmazonReview test data used by the tests below.
POLARITY_DIR = '../data/dataset/testAmazonReview/polarity'
# Relative path to the "full" AmazonReview test data used by the tests below.
FULL_DIR = '../data/dataset/testAmazonReview/full'
  20. def count_unequal_element(data_expected, data_me):
  21. assert data_expected.shape == data_me.shape
  22. assert data_expected == data_me
  23. def test_amazon_review_polarity_dataset_basic():
  24. """
  25. Feature: Test AmazonReviewPolarity Dataset.
  26. Description: read data from a single file.
  27. Expectation: the data is processed successfully.
  28. """
  29. buffer = []
  30. data = ds.AmazonReviewDataset(POLARITY_DIR, usage='test', shuffle=False)
  31. data = data.repeat(2)
  32. data = data.skip(2)
  33. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  34. buffer.append(d)
  35. assert len(buffer) == 2
  36. def test_amazon_review_full_dataset_basic():
  37. """
  38. Feature: Test AmazonReviewFull Dataset.
  39. Description: read data from a single file.
  40. Expectation: the data is processed successfully.
  41. """
  42. buffer = []
  43. data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False)
  44. data = data.repeat(2)
  45. data = data.skip(2)
  46. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  47. buffer.append(d)
  48. assert len(buffer) == 4
  49. def test_amazon_review_dataset_quoted():
  50. """
  51. Feature: Test get the AmazonReview Dataset.
  52. Description: read AmazonReviewPolarityDataset data and get data.
  53. Expectation: the data is processed successfully.
  54. """
  55. data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False)
  56. buffer = []
  57. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  58. buffer.extend([d['label'].item().decode("utf8"),
  59. d['title'].item().decode("utf8"),
  60. d['content'].item().decode("utf8")])
  61. assert buffer == ["1", "amazing", "unlimited buyback!",
  62. "4", "delightful", "a funny book!",
  63. "3", "Small", "It is a small ball!"]
  64. def test_amazon_review_full_dataset_usage_all():
  65. """
  66. Feature: Test AmazonReviewPolarity Dataset(usage=all).
  67. Description: read train data and test data.
  68. Expectation: the data is processed successfully.
  69. """
  70. buffer = []
  71. data = ds.AmazonReviewDataset(FULL_DIR, usage='all', shuffle=False)
  72. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  73. buffer.extend([d['label'].item().decode("utf8"),
  74. d['title'].item().decode("utf8"),
  75. d['content'].item().decode("utf8")])
  76. assert buffer == ["1", "amazing", "unlimited buyback!",
  77. "3", "Satisfied", "good quality.",
  78. "4", "delightful", "a funny book!",
  79. "5", "good", "This is an very good product.",
  80. "3", "Small", "It is a small ball!",
  81. "1", "bad", "work badly."]
  82. def test_amazon_review_polarity_dataset_usage_all():
  83. """
  84. Feature: Test AmazonReviewPolarityPolarity Dataset(usage=all).
  85. Description: read train data and test data.
  86. Expectation: the data is processed successfully.
  87. """
  88. buffer = []
  89. data = ds.AmazonReviewDataset(POLARITY_DIR, usage='all', shuffle=False)
  90. for d in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  91. buffer.extend([d['label'].item().decode("utf8"),
  92. d['title'].item().decode("utf8"),
  93. d['content'].item().decode("utf8")])
  94. assert buffer == ["1", "DVD", "It is very good!",
  95. "2", "Great Read", "I thought this book was excellent!",
  96. "2", "Book", "I would read it again lol.",
  97. "1", "Oh dear", "It is so bad!",
  98. "2", "Delicious", "A funny product."]
  99. def test_amazon_review_dataset_get_datasetsize():
  100. """
  101. Feature: Test Getters.
  102. Description: test get_dataset_size of AmazonReview dataset.
  103. Expectation: the data is processed successfully.
  104. """
  105. data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False)
  106. size = data.get_dataset_size()
  107. assert size == 3
  108. def test_amazon_review_dataset_distribution():
  109. """
  110. Feature: Test AmazonReviewDataset in distribution.
  111. Description: test in a distributed state.
  112. Expectation: the data is processed successfully.
  113. """
  114. data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False, num_shards=2, shard_id=0)
  115. count = 0
  116. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  117. count += 1
  118. assert count == 2
  119. def test_amazon_review_dataset_num_samples():
  120. """
  121. Feature: Test AmazonReview Dataset(num_samples = 2).
  122. Description: test get num_samples.
  123. Expectation: the data is processed successfully.
  124. """
  125. data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False, num_samples=2)
  126. count = 0
  127. for _ in data.create_dict_iterator(num_epochs=1, output_numpy=True):
  128. count += 1
  129. assert count == 2
  130. def test_amazon_review_dataset_exception():
  131. """
  132. Feature: Error Test.
  133. Description: test the wrong input.
  134. Expectation: unable to read in data.
  135. """
  136. def exception_func(item):
  137. raise Exception("Error occur!")
  138. try:
  139. data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False)
  140. data = data.map(operations=exception_func, input_columns=["label"], num_parallel_workers=1)
  141. for _ in data.create_dict_iterator():
  142. pass
  143. assert False
  144. except RuntimeError as e:
  145. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  146. try:
  147. data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False)
  148. data = data.map(operations=exception_func, input_columns=["title"], num_parallel_workers=1)
  149. for _ in data.create_dict_iterator():
  150. pass
  151. assert False
  152. except RuntimeError as e:
  153. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  154. try:
  155. data = ds.AmazonReviewDataset(FULL_DIR, usage='test', shuffle=False)
  156. data = data.map(operations=exception_func, input_columns=["content"], num_parallel_workers=1)
  157. for _ in data.create_dict_iterator():
  158. pass
  159. assert False
  160. except RuntimeError as e:
  161. assert "map operation: [PyFunc] failed. The corresponding data files" in str(e)
  162. def test_amazon_review_dataset_pipeline():
  163. """
  164. Feature: AmazonReviewDataset
  165. Description: test AmazonReviewDataset in pipeline mode
  166. Expectation: the data is processed successfully
  167. """
  168. expected_columns1 = np.array(["3", "5", "1"], dtype=np.string_)
  169. dataset = ds.AmazonReviewDataset(FULL_DIR, 'train', shuffle=False)
  170. filter_wikipedia_xml_op = a_c_trans.CaseFold()
  171. dataset = dataset.map(input_columns=["label"], operations=filter_wikipedia_xml_op, num_parallel_workers=1)
  172. i = 0
  173. for data in dataset.create_dict_iterator(output_numpy=True):
  174. count_unequal_element(np.array(expected_columns1[i]), data['label'])
  175. i += 1
  176. assert i == 3
  177. expected_columns2 = np.array(["satisfied", "good", "bad"], dtype=np.string_)
  178. dataset = ds.AmazonReviewDataset(FULL_DIR, 'train', shuffle=False)
  179. filter_wikipedia_xml_op = a_c_trans.CaseFold()
  180. dataset = dataset.map(input_columns=["title"], operations=filter_wikipedia_xml_op, num_parallel_workers=1)
  181. i = 0
  182. for data in dataset.create_dict_iterator(output_numpy=True):
  183. count_unequal_element(np.array(expected_columns2[i]), data['title'])
  184. i += 1
  185. assert i == 3
  186. expected_columns3 = np.array(["good quality.",
  187. "this is an very good product.",
  188. "work badly."], dtype=np.string_)
  189. dataset = ds.AmazonReviewDataset(FULL_DIR, 'train', shuffle=False)
  190. filter_wikipedia_xml_op = a_c_trans.CaseFold()
  191. dataset = dataset.map(input_columns=["content"], operations=filter_wikipedia_xml_op, num_parallel_workers=1)
  192. i = 0
  193. for data in dataset.create_dict_iterator(output_numpy=True):
  194. count_unequal_element(np.array(expected_columns3[i]), data['content'])
  195. i += 1
  196. assert i == 3
  197. if __name__ == "__main__":
  198. test_amazon_review_polarity_dataset_basic()
  199. test_amazon_review_full_dataset_basic()
  200. test_amazon_review_dataset_quoted()
  201. test_amazon_review_full_dataset_usage_all()
  202. test_amazon_review_polarity_dataset_usage_all()
  203. test_amazon_review_dataset_get_datasetsize()
  204. test_amazon_review_dataset_distribution()
  205. test_amazon_review_dataset_num_samples()
  206. test_amazon_review_dataset_exception()
  207. test_amazon_review_dataset_pipeline()