
dataset.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset api"""
import os
from itertools import chain

import gensim
import numpy as np
from mindspore.mindrecord import FileWriter
import mindspore.dataset as ds


# preprocess part
def encode_samples(tokenized_samples, word_to_idx):
    """ encode word to index """
    features = []
    for sample in tokenized_samples:
        feature = []
        for token in sample:
            if token in word_to_idx:
                feature.append(word_to_idx[token])
            else:
                feature.append(0)
        features.append(feature)
    return features


def pad_samples(features, maxlen=50, pad=0):
    """ pad all features to the same length """
    padded_features = []
    for feature in features:
        if len(feature) >= maxlen:
            padded_feature = feature[:maxlen]
        else:
            padded_feature = feature
            while len(padded_feature) < maxlen:
                padded_feature.append(pad)
        padded_features.append(padded_feature)
    return padded_features


def read_imdb(path, seg='train'):
    """ read imdb dataset """
    pos_or_neg = ['pos', 'neg']
    data = []
    for label in pos_or_neg:
        # each label corresponds to a single file holding one review per line
        f = os.path.join(path, seg, label)
        with open(f, 'r') as rf:
            for line in rf:
                line = line.strip()
                if label == 'pos':
                    data.append([line, 1])
                elif label == 'neg':
                    data.append([line, 0])
    return data


def tokenizer(text):
    return [tok.lower() for tok in text.split(' ')]


def collect_weight(glove_path, vocab, word_to_idx, embed_size):
    """ collect weight """
    vocab_size = len(vocab)
    wvmodel = gensim.models.KeyedVectors.load_word2vec_format(
        os.path.join(glove_path, 'GoogleNews-vectors-negative300.bin'),
        binary=True)
    weight_np = np.zeros((vocab_size + 1, embed_size)).astype(np.float32)
    idx_to_word = {i + 1: word for i, word in enumerate(vocab)}
    idx_to_word[0] = '<unk>'
    # index2word is the gensim 3.x KeyedVectors attribute (renamed index_to_key in gensim 4+)
    for i in range(len(wvmodel.index2word)):
        try:
            index = word_to_idx[wvmodel.index2word[i]]
        except KeyError:
            continue
        weight_np[index, :] = wvmodel.get_vector(
            idx_to_word[word_to_idx[wvmodel.index2word[i]]])
    return weight_np


def preprocess(data_path, glove_path, embed_size):
    """ preprocess the train and test data """
    train_data = read_imdb(data_path, 'train')
    test_data = read_imdb(data_path, 'test')
    train_tokenized = []
    test_tokenized = []
    for review, _ in train_data:
        train_tokenized.append(tokenizer(review))
    for review, _ in test_data:
        test_tokenized.append(tokenizer(review))
    vocab = set(chain(*train_tokenized))
    vocab_size = len(vocab)
    print("vocab_size: ", vocab_size)
    word_to_idx = {word: i + 1 for i, word in enumerate(vocab)}
    word_to_idx['<unk>'] = 0
    train_features = np.array(pad_samples(encode_samples(train_tokenized, word_to_idx))).astype(np.int32)
    train_labels = np.array([score for _, score in train_data]).astype(np.int32)
    test_features = np.array(pad_samples(encode_samples(test_tokenized, word_to_idx))).astype(np.int32)
    test_labels = np.array([score for _, score in test_data]).astype(np.int32)
    weight_np = collect_weight(glove_path, vocab, word_to_idx, embed_size)
    return train_features, train_labels, test_features, test_labels, weight_np, vocab_size


def get_imdb_data(labels_data, features_data):
    """ pack labels and features into mindrecord-ready samples """
    data_list = []
    for i, (label, feature) in enumerate(zip(labels_data, features_data)):
        data_json = {"id": i,
                     "label": int(label),
                     "feature": feature.reshape(-1)}
        data_list.append(data_json)
    return data_list


def convert_to_mindrecord(embed_size, data_path, proprocess_path, glove_path):
    """ convert imdb dataset to mindrecord """
    num_shard = 4
    train_features, train_labels, test_features, test_labels, weight_np, _ = \
        preprocess(data_path, glove_path, embed_size)
    np.savetxt(os.path.join(proprocess_path, 'weight.txt'), weight_np)
    print("train_features.shape:", train_features.shape, "train_labels.shape:", train_labels.shape,
          "weight_np.shape:", weight_np.shape, "type:", train_labels.dtype)

    # write mindrecord
    schema_json = {"id": {"type": "int32"},
                   "label": {"type": "int32"},
                   "feature": {"type": "int32", "shape": [-1]}}

    writer = FileWriter(os.path.join(proprocess_path, 'aclImdb_train.mindrecord'), num_shard)
    data = get_imdb_data(train_labels, train_features)
    writer.add_schema(schema_json, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data)
    writer.commit()

    writer = FileWriter(os.path.join(proprocess_path, 'aclImdb_test.mindrecord'), num_shard)
    data = get_imdb_data(test_labels, test_features)
    writer.add_schema(schema_json, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data)
    writer.commit()


def create_dataset(base_path, batch_size, is_train):
    """Create dataset for training or evaluation."""
    columns_list = ["feature", "label"]
    num_consumer = 4
    if is_train:
        path = os.path.join(base_path, 'aclImdb_train.mindrecord0')
    else:
        path = os.path.join(base_path, 'aclImdb_test.mindrecord0')
    data_set = ds.MindDataset(path, columns_list, num_consumer)
    ds.config.set_seed(0)
    data_set = data_set.shuffle(buffer_size=data_set.get_dataset_size())
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set
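
A minimal usage sketch of the two entry points above, kept separate from dataset.py. The ./aclImdb, ./glove and ./preprocess paths, the batch size of 64, and the embedding size of 300 (matching the GoogleNews vectors loaded in collect_weight) are illustrative assumptions, not values taken from the repository:

# usage_sketch.py -- hypothetical driver script, not part of dataset.py
from dataset import convert_to_mindrecord, create_dataset

# One-off preprocessing: writes weight.txt plus the sharded
# aclImdb_train.mindrecord* / aclImdb_test.mindrecord* files into ./preprocess.
convert_to_mindrecord(embed_size=300,
                      data_path='./aclImdb',
                      proprocess_path='./preprocess',
                      glove_path='./glove')

# Training-time loading: a shuffled, batched MindDataset over the train shards.
train_ds = create_dataset('./preprocess', batch_size=64, is_train=True)
for batch in train_ds.create_dict_iterator():
    print(batch['feature'].shape, batch['label'].shape)  # expected (64, 50) and (64,)
    break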