You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_vocab.py 1.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import mindspore.dataset as ds
  2. import mindspore.dataset.text as text
  3. # this file contains "home is behind the world head" each word is 1 line
  4. DATA_FILE = "../data/dataset/testVocab/words.txt"
  5. VOCAB_FILE = "../data/dataset/testVocab/vocab_list.txt"
  6. def test_from_list():
  7. vocab = text.Vocab.from_list("home IS behind the world ahead !".split(" "))
  8. lookup = text.Lookup(vocab)
  9. data = ds.TextFileDataset(DATA_FILE, shuffle=False)
  10. data = data.map(input_columns=["text"], operations=lookup)
  11. ind = 0
  12. res = [2, 1, 4, 5, 6, 7]
  13. for d in data.create_dict_iterator():
  14. assert d["text"] == res[ind], ind
  15. ind += 1
  16. def test_from_file():
  17. vocab = text.Vocab.from_file(VOCAB_FILE, ",")
  18. lookup = text.Lookup(vocab)
  19. data = ds.TextFileDataset(DATA_FILE, shuffle=False)
  20. data = data.map(input_columns=["text"], operations=lookup)
  21. ind = 0
  22. res = [10, 11, 12, 15, 13, 14]
  23. for d in data.create_dict_iterator():
  24. assert d["text"] == res[ind], ind
  25. ind += 1
  26. def test_from_dict():
  27. vocab = text.Vocab.from_dict({"home": 3, "behind": 2, "the": 4, "world": 5, "<unk>": 6})
  28. lookup = text.Lookup(vocab, 6) # default value is -1
  29. data = ds.TextFileDataset(DATA_FILE, shuffle=False)
  30. data = data.map(input_columns=["text"], operations=lookup)
  31. res = [3, 6, 2, 4, 5, 6]
  32. ind = 0
  33. for d in data.create_dict_iterator():
  34. assert d["text"] == res[ind], ind
  35. ind += 1
  36. if __name__ == '__main__':
  37. test_from_list()
  38. test_from_file()
  39. test_from_dict()