You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prepare.py 1.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import pickle
  2. import Word2Idx
  def get_sets(m, n, path="tuples.pkl"):
      """Load pickled samples and split them into a train set and a test set.

      The first ``m`` samples become the train set and the next ``n`` samples
      become the test set; any remaining samples are ignored.

      Args:
          m: Number of samples taken from the front of the data for training.
          n: Number of samples taken immediately after the train set for
              testing.
          path: Pickle file holding the sample sequence. Defaults to
              ``"tuples.pkl"`` relative to the current working directory.

      Returns:
          A ``(train_samples, test_samples)`` tuple of slices, or ``None``
          (after printing a message) when fewer than ``m + n`` samples are
          available. Callers that unpack the result directly will then fail
          with a ``TypeError``.

      Note:
          ``pickle.load`` can execute arbitrary code; only load trusted files.
      """
      # Use a context manager so the handle is closed deterministically
      # instead of being leaked until garbage collection.
      with open(path, "rb") as f:
          samples = pickle.load(f)
      if m + n > len(samples):
          print("asking for too many tuples\n")
          return None
      return samples[:m], samples[m:m + n]
  def build_wordidx(train_size=500000, test_size=2000, out_path="wordidx.pkl"):
      """Build a ``Word2Idx`` vocabulary from the training split and save it.

      Args:
          train_size: Number of training samples requested from ``get_sets``.
          test_size: Number of test samples requested from ``get_sets``. Only
              the train split is used here; ``test_size`` is still passed so
              ``get_sets`` applies the same ``train_size + test_size``
              availability check.
          out_path: File the vocabulary is saved to via ``Word2Idx.save``.
      """
      train, _ = get_sets(train_size, test_size)
      # Flatten element 0 of every training sample (its token sequence) into
      # a single list for the vocabulary builder.
      words = [w for sample in train for w in sample[0]]
      wordidx = Word2Idx.Word2Idx()
      wordidx.build(words)
      # Sanity output; presumably the vocabulary size and the word mapped to
      # index 0 -- confirm against Word2Idx.
      print(wordidx.num)
      print(wordidx.i2w(0))
      wordidx.save(out_path)
  def _encode_samples(samples, wordidx):
      """Map each sample to ``{"sent": [w2i(token), ...], "class": sample[1]}``."""
      return [
          {"sent": [wordidx.w2i(w) for w in sample[0]], "class": sample[1]}
          for sample in samples
      ]


  def build_sets(train_size=500000, test_size=2000,
                 wordidx_path="wordidx.pkl",
                 train_path="train_set.pkl", test_path="test_set.pkl"):
      """Encode the train and test splits as word indices and pickle them.

      Args:
          train_size: Number of training samples requested from ``get_sets``.
          test_size: Number of test samples requested from ``get_sets``.
          wordidx_path: Saved ``Word2Idx`` vocabulary to load.
          train_path: Output pickle file for the encoded train set.
          test_path: Output pickle file for the encoded test set.

      Each sample is written as a dict with ``"sent"`` (the list of
      ``w2i`` results for the sample's element 0) and ``"class"`` (the
      sample's element 1).
      """
      train, test = get_sets(train_size, test_size)
      wordidx = Word2Idx.Word2Idx()
      wordidx.load(wordidx_path)
      train_set = _encode_samples(train, wordidx)
      test_set = _encode_samples(test, wordidx)
      # Context managers guarantee the output files are flushed and closed.
      with open(train_path, "wb") as f:
          pickle.dump(train_set, f)
      with open(test_path, "wb") as f:
          pickle.dump(test_set, f)
  if __name__ == "__main__":
      # Order matters: build_wordidx() saves wordidx.pkl, which build_sets()
      # then loads to encode the splits.
      for stage in (build_wordidx, build_sets):
          stage()

一款轻量级的自然语言处理（NLP）工具包，目标是减少用户项目中的工程型代码，例如数据处理循环、训练循环、多卡运行等。