You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_ngram_op.py 5.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Testing Ngram in mindspore.dataset
  17. """
  18. import numpy as np
  19. import mindspore.dataset as ds
  20. import mindspore.dataset.text as text
  21. def test_ngram_callable():
  22. """
  23. Test ngram op is callable
  24. """
  25. op = text.Ngram(2, separator="-")
  26. input1 = " WildRose Country"
  27. input1 = np.array(input1.split(" "), dtype='S')
  28. expect1 = ['-WildRose', 'WildRose-Country']
  29. result1 = op(input1)
  30. assert np.array_equal(result1, expect1)
  31. input2 = ["WildRose Country", "Canada's Ocean Playground", "Land of Living Skies"]
  32. expect2 = ["WildRose Country-Canada's Ocean Playground", "Canada's Ocean Playground-Land of Living Skies"]
  33. result2 = op(input2)
  34. assert np.array_equal(result2, expect2)
  35. def test_multiple_ngrams():
  36. """ test n-gram where n is a list of integers"""
  37. plates_mottos = ["WildRose Country", "Canada's Ocean Playground", "Land of Living Skies"]
  38. n_gram_mottos = []
  39. n_gram_mottos.append(
  40. ['WildRose', 'Country', '_ WildRose', 'WildRose Country', 'Country _', '_ _ WildRose', '_ WildRose Country',
  41. 'WildRose Country _', 'Country _ _'])
  42. n_gram_mottos.append(
  43. ["Canada's", 'Ocean', 'Playground', "_ Canada's", "Canada's Ocean", 'Ocean Playground', 'Playground _',
  44. "_ _ Canada's", "_ Canada's Ocean", "Canada's Ocean Playground", 'Ocean Playground _', 'Playground _ _'])
  45. n_gram_mottos.append(
  46. ['Land', 'of', 'Living', 'Skies', '_ Land', 'Land of', 'of Living', 'Living Skies', 'Skies _', '_ _ Land',
  47. '_ Land of', 'Land of Living', 'of Living Skies', 'Living Skies _', 'Skies _ _'])
  48. def gen(texts):
  49. for line in texts:
  50. yield (np.array(line.split(" "), dtype='S'),)
  51. dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"])
  52. dataset = dataset.map(operations=text.Ngram([1, 2, 3], ("_", 2), ("_", 2), " "), input_columns="text")
  53. i = 0
  54. for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  55. assert [d.decode("utf8") for d in data["text"]] == n_gram_mottos[i]
  56. i += 1
  57. def test_simple_ngram():
  58. """ test simple gram with only one n value"""
  59. plates_mottos = ["Friendly Manitoba", "Yours to Discover", "Land of Living Skies",
  60. "Birthplace of the Confederation"]
  61. n_gram_mottos = [[""]]
  62. n_gram_mottos.append(["Yours to Discover"])
  63. n_gram_mottos.append(['Land of Living', 'of Living Skies'])
  64. n_gram_mottos.append(['Birthplace of the', 'of the Confederation'])
  65. def gen(texts):
  66. for line in texts:
  67. yield (np.array(line.split(" "), dtype='S'),)
  68. dataset = ds.GeneratorDataset(gen(plates_mottos), column_names=["text"])
  69. dataset = dataset.map(operations=text.Ngram(3, separator=" "), input_columns="text")
  70. i = 0
  71. for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  72. assert [d.decode("utf8") for d in data["text"]] == n_gram_mottos[i], i
  73. i += 1
  74. def test_corner_cases():
  75. """ testing various corner cases and exceptions"""
  76. def test_config(input_line, n, l_pad=("", 0), r_pad=("", 0), sep=" "):
  77. def gen(texts):
  78. yield (np.array(texts.split(" "), dtype='S'),)
  79. try:
  80. dataset = ds.GeneratorDataset(gen(input_line), column_names=["text"])
  81. dataset = dataset.map(operations=text.Ngram(n, l_pad, r_pad, separator=sep), input_columns=["text"])
  82. for data in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
  83. return [d.decode("utf8") for d in data["text"]]
  84. except (ValueError, TypeError) as e:
  85. return str(e)
  86. # test tensor length smaller than n
  87. assert test_config("Lone Star", [2, 3, 4, 5]) == ["Lone Star", "", "", ""]
  88. # test empty separator
  89. assert test_config("Beautiful British Columbia", 2, sep="") == ['BeautifulBritish', 'BritishColumbia']
  90. # test separator with longer length
  91. assert test_config("Beautiful British Columbia", 3, sep="^-^") == ['Beautiful^-^British^-^Columbia']
  92. # test left pad != right pad
  93. assert test_config("Lone Star", 4, ("The", 1), ("State", 1)) == ['The Lone Star State']
  94. # test invalid n
  95. assert "gram[1] with value [1] is not of type (<class 'int'>,)" in test_config("Yours to Discover", [1, [1]])
  96. assert "n needs to be a non-empty list" in test_config("Yours to Discover", [])
  97. # test invalid pad
  98. assert "padding width need to be positive numbers" in test_config("Yours to Discover", [1], ("str", -1))
  99. assert "pad needs to be a tuple of (str, int)" in test_config("Yours to Discover", [1], ("str", "rts"))
  100. # test 0 as in valid input
  101. assert "gram_0 must be greater than 0" in test_config("Yours to Discover", 0)
  102. assert "gram_0 must be greater than 0" in test_config("Yours to Discover", [0])
  103. assert "gram_1 must be greater than 0" in test_config("Yours to Discover", [1, 0])
  104. if __name__ == '__main__':
  105. test_ngram_callable()
  106. test_multiple_ngrams()
  107. test_simple_ngram()
  108. test_corner_cases()