
dataset.py 6.6 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset api"""
import os
from itertools import chain

import gensim
import numpy as np

from mindspore.mindrecord import FileWriter
import mindspore.dataset as ds


# preprocess part
def encode_samples(tokenized_samples, word_to_idx):
    """ encode word to index """
    features = []
    for sample in tokenized_samples:
        feature = []
        for token in sample:
            if token in word_to_idx:
                feature.append(word_to_idx[token])
            else:
                feature.append(0)
        features.append(feature)
    return features


def pad_samples(features, maxlen=50, pad=0):
    """ pad all features to the same length """
    padded_features = []
    for feature in features:
        if len(feature) >= maxlen:
            padded_feature = feature[:maxlen]
        else:
            padded_feature = feature
            while len(padded_feature) < maxlen:
                padded_feature.append(pad)
        padded_features.append(padded_feature)
    return padded_features


def read_imdb(path, seg='train'):
    """ read imdb dataset """
    pos_or_neg = ['pos', 'neg']
    data = []
    for label in pos_or_neg:
        f = os.path.join(path, seg, label)
        # one review per line; label 1 = positive, 0 = negative
        with open(f, 'r') as rf:
            for line in rf:
                line = line.strip()
                if label == 'pos':
                    data.append([line, 1])
                elif label == 'neg':
                    data.append([line, 0])
    return data


def tokenizer(text):
    """ split a review into lower-cased whitespace tokens """
    return [tok.lower() for tok in text.split(' ')]


def collect_weight(glove_path, vocab, word_to_idx, embed_size):
    """ collect weight """
    vocab_size = len(vocab)
    # wvmodel = gensim.models.KeyedVectors.load_word2vec_format(os.path.join(glove_path, 'glove.6B.300d.txt'),
    #                                                           binary=False, encoding='utf-8')
    wvmodel = gensim.models.KeyedVectors.load_word2vec_format(
        os.path.join(glove_path, 'GoogleNews-vectors-negative300.bin'), binary=True)
    weight_np = np.zeros((vocab_size + 1, embed_size)).astype(np.float32)
    idx_to_word = {i + 1: word for i, word in enumerate(vocab)}
    idx_to_word[0] = '<unk>'
    for i in range(len(wvmodel.index2word)):
        try:
            index = word_to_idx[wvmodel.index2word[i]]
        except KeyError:
            continue
        weight_np[index, :] = wvmodel.get_vector(
            idx_to_word[word_to_idx[wvmodel.index2word[i]]])
    return weight_np


def preprocess(data_path, glove_path, embed_size):
    """ preprocess the train and test data """
    train_data = read_imdb(data_path, 'train')
    test_data = read_imdb(data_path, 'test')
    train_tokenized = []
    test_tokenized = []
    for review, _ in train_data:
        train_tokenized.append(tokenizer(review))
    for review, _ in test_data:
        test_tokenized.append(tokenizer(review))
    vocab = set(chain(*train_tokenized))
    vocab_size = len(vocab)
    print("vocab_size: ", vocab_size)
    word_to_idx = {word: i + 1 for i, word in enumerate(vocab)}
    word_to_idx['<unk>'] = 0
    train_features = np.array(pad_samples(encode_samples(train_tokenized, word_to_idx))).astype(np.int32)
    train_labels = np.array([score for _, score in train_data]).astype(np.int32)
    test_features = np.array(pad_samples(encode_samples(test_tokenized, word_to_idx))).astype(np.int32)
    test_labels = np.array([score for _, score in test_data]).astype(np.int32)
    weight_np = collect_weight(glove_path, vocab, word_to_idx, embed_size)
    return train_features, train_labels, test_features, test_labels, weight_np, vocab_size


def get_imdb_data(labels_data, features_data):
    """ pack labels and features into the record layout expected by FileWriter """
    data_list = []
    for i, (label, feature) in enumerate(zip(labels_data, features_data)):
        data_json = {"id": i,
                     "label": int(label),
                     "feature": feature.reshape(-1)}
        data_list.append(data_json)
    return data_list


def convert_to_mindrecord(embed_size, data_path, preprocess_path, glove_path):
    """ convert imdb dataset to mindrecord """
    num_shard = 4
    train_features, train_labels, test_features, test_labels, weight_np, _ = \
        preprocess(data_path, glove_path, embed_size)
    np.savetxt(os.path.join(preprocess_path, 'weight.txt'), weight_np)
    print("train_features.shape:", train_features.shape, "train_labels.shape:", train_labels.shape,
          "weight_np.shape:", weight_np.shape, "type:", train_labels.dtype)
    # write mindrecord
    schema_json = {"id": {"type": "int32"},
                   "label": {"type": "int32"},
                   "feature": {"type": "int32", "shape": [-1]}}

    writer = FileWriter(os.path.join(preprocess_path, 'aclImdb_train.mindrecord'), num_shard)
    data = get_imdb_data(train_labels, train_features)
    writer.add_schema(schema_json, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data)
    writer.commit()

    writer = FileWriter(os.path.join(preprocess_path, 'aclImdb_test.mindrecord'), num_shard)
    data = get_imdb_data(test_labels, test_features)
    writer.add_schema(schema_json, "nlp_schema")
    writer.add_index(["id", "label"])
    writer.write_raw_data(data)
    writer.commit()


def create_dataset(base_path, batch_size, num_epochs, is_train):
    """Create the train or test dataset from the MindRecord shards."""
    columns_list = ["feature", "label"]
    num_consumer = 4
    if is_train:
        path = os.path.join(base_path, 'aclImdb_train.mindrecord0')
    else:
        path = os.path.join(base_path, 'aclImdb_test.mindrecord0')
    data_set = ds.MindDataset(path, columns_list, num_consumer)
    ds.config.set_seed(1)
    data_set = data_set.shuffle(buffer_size=data_set.get_dataset_size())
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set
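
For context, a minimal sketch of how these functions are typically chained is shown below. The directory layout (./aclImdb, ./preprocess, ./word2vec) and the batch size are assumptions chosen for illustration, not values defined in this file.

# Usage sketch (assumed paths and hyperparameters; adjust to your setup).
from dataset import convert_to_mindrecord, create_dataset

if __name__ == "__main__":
    # One-off conversion: tokenize the raw IMDb text, look up word2vec vectors,
    # then write weight.txt plus the sharded MindRecord files.
    convert_to_mindrecord(embed_size=300,                  # matches the 300-d word2vec binary
                          data_path='./aclImdb',           # assumed raw dataset root
                          preprocess_path='./preprocess',  # assumed output directory
                          glove_path='./word2vec')         # assumed dir holding GoogleNews-vectors-negative300.bin
    # Build a shuffled, batched MindSpore dataset from the generated train shards.
    train_ds = create_dataset('./preprocess', batch_size=64, num_epochs=1, is_train=True)
    for batch in train_ds.create_dict_iterator():
        features, labels = batch["feature"], batch["label"]
        break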