
create_data.py 10 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create training instances for Transformer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import ast
import collections
import logging

import numpy as np

import src.tokenization as tokenization
from mindspore.mindrecord import FileWriter

class SampleInstance():
    """A single sample instance (sentence pair)."""

    def __init__(self, source_sos_tokens, source_eos_tokens, target_sos_tokens, target_eos_tokens):
        self.source_sos_tokens = source_sos_tokens
        self.source_eos_tokens = source_eos_tokens
        self.target_sos_tokens = target_sos_tokens
        self.target_eos_tokens = target_eos_tokens

    def __str__(self):
        s = ""
        s += "source sos tokens: %s\n" % (" ".join(
            [tokenization.convert_to_printable(x) for x in self.source_sos_tokens]))
        s += "source eos tokens: %s\n" % (" ".join(
            [tokenization.convert_to_printable(x) for x in self.source_eos_tokens]))
        s += "target sos tokens: %s\n" % (" ".join(
            [tokenization.convert_to_printable(x) for x in self.target_sos_tokens]))
        s += "target eos tokens: %s\n" % (" ".join(
            [tokenization.convert_to_printable(x) for x in self.target_eos_tokens]))
        s += "\n"
        return s

    def __repr__(self):
        return self.__str__()

def get_instance_features(instance, tokenizer, max_seq_length, bucket):
    """Get features from a `SampleInstance`."""

    def _find_bucket_length(source_tokens, target_tokens):
        """Return the smallest bucket boundary that fits the longer of the two sequences."""
        source_ids = tokenizer.convert_tokens_to_ids(source_tokens)
        target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
        num = max(len(source_ids), len(target_ids))
        assert num <= bucket[-1]
        for index in range(1, len(bucket)):
            if bucket[index - 1] < num <= bucket[index]:
                return bucket[index]
        return bucket[0]

    def _convert_ids_and_mask(input_tokens, seq_max_bucket_length):
        """Convert tokens to ids, then zero-pad both ids and mask up to the bucket length."""
        input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
        input_mask = [1] * len(input_ids)
        assert len(input_ids) <= max_seq_length
        while len(input_ids) < seq_max_bucket_length:
            input_ids.append(0)
            input_mask.append(0)
        assert len(input_ids) == seq_max_bucket_length
        assert len(input_mask) == seq_max_bucket_length
        return input_ids, input_mask
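    # Concrete example, assuming the default bucket list [16, 32, 48, 64, 128]:
    # a pair whose longer side converts to 20 ids falls into the 32 bucket, so
    # all four sequences below are padded to length 32 and each mask marks the
    # positions that hold real tokens.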
    seq_max_bucket_length = _find_bucket_length(instance.source_sos_tokens, instance.target_sos_tokens)
    source_sos_ids, source_sos_mask = _convert_ids_and_mask(instance.source_sos_tokens, seq_max_bucket_length)
    source_eos_ids, source_eos_mask = _convert_ids_and_mask(instance.source_eos_tokens, seq_max_bucket_length)
    target_sos_ids, target_sos_mask = _convert_ids_and_mask(instance.target_sos_tokens, seq_max_bucket_length)
    target_eos_ids, target_eos_mask = _convert_ids_and_mask(instance.target_eos_tokens, seq_max_bucket_length)

    features = collections.OrderedDict()
    features["source_sos_ids"] = np.asarray(source_sos_ids)
    features["source_sos_mask"] = np.asarray(source_sos_mask)
    features["source_eos_ids"] = np.asarray(source_eos_ids)
    features["source_eos_mask"] = np.asarray(source_eos_mask)
    features["target_sos_ids"] = np.asarray(target_sos_ids)
    features["target_sos_mask"] = np.asarray(target_sos_mask)
    features["target_eos_ids"] = np.asarray(target_eos_ids)
    features["target_eos_mask"] = np.asarray(target_eos_mask)

    return features, seq_max_bucket_length

def create_training_instance(source_words, target_words, max_seq_length, clip_to_max_len):
    """Creates a `SampleInstance` for a single sentence pair."""
    EOS = "</s>"
    SOS = "<s>"
    if len(source_words) >= max_seq_length or len(target_words) >= max_seq_length:
        if clip_to_max_len:
            # Clip to max_seq_length - 1 so there is still room for the
            # SOS/EOS token added below.
            source_words = source_words[:min(len(source_words), max_seq_length - 1)]
            target_words = target_words[:min(len(target_words), max_seq_length - 1)]
        else:
            return None

    source_sos_tokens = [SOS] + source_words
    source_eos_tokens = source_words + [EOS]
    target_sos_tokens = [SOS] + target_words
    target_eos_tokens = target_words + [EOS]

    instance = SampleInstance(
        source_sos_tokens=source_sos_tokens,
        source_eos_tokens=source_eos_tokens,
        target_sos_tokens=target_sos_tokens,
        target_eos_tokens=target_eos_tokens)
    return instance
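# For example, source tokens ["hello", "world"] become sos tokens
# ["<s>", "hello", "world"] and eos tokens ["hello", "world", "</s>"]. In the
# usual Transformer setup (an assumption; the consuming model is not shown
# here), the sos-prefixed target serves as decoder input and the eos-suffixed
# target as the label.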

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True,
                        help='Input raw text file (or comma-separated list of files).')
    parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
    parser.add_argument("--num_splits", type=int, default=16,
                        help='Number of partitions each MindRecord file is split into.')
    parser.add_argument("--vocab_file", type=str, required=True,
                        help='The vocabulary file that the Transformer model was trained on.')
    # ast.literal_eval is used instead of bool: argparse's type=bool treats any
    # non-empty string (including "False") as True.
    parser.add_argument("--clip_to_max_len", type=ast.literal_eval, default=False,
                        help='Whether to clip sequences to the maximum sequence length.')
    parser.add_argument("--max_seq_length", type=int, default=128, help='Maximum sequence length.')
    parser.add_argument("--bucket", type=ast.literal_eval, default=[16, 32, 48, 64, 128],
                        help='Bucket sequence lengths.')
    args = parser.parse_args()
    tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file)

    input_files = []
    for input_pattern in args.input_file.split(","):
        input_files.append(input_pattern)

    logging.info("*** Read from input files ***")
    for input_file in input_files:
        logging.info("  %s", input_file)

    output_file = args.output_file
    logging.info("*** Write to output files ***")
    logging.info("  %s", output_file)

    total_written = 0
    total_read = 0

    feature_dict = {}
    for i in args.bucket:
        feature_dict[i] = []
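    # feature_dict maps each bucket length to the feature records that fall
    # into that bucket; the buckets are flushed to disk only after all input
    # files have been read.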
    for input_file in input_files:
        logging.info("*** Reading from %s ***", input_file)
        with open(input_file, "r") as reader:
            while True:
                line = tokenization.convert_to_unicode(reader.readline())
                if not line:
                    break

                total_read += 1
                if total_read % 100000 == 0:
                    logging.info("Read %d ...", total_read)

                source_line, target_line = line.strip().split("\t")
                source_tokens = tokenizer.tokenize(source_line)
                target_tokens = tokenizer.tokenize(target_line)

                if len(source_tokens) >= args.max_seq_length or len(target_tokens) >= args.max_seq_length:
                    logging.info("ignore long sentence!")
                    continue

                instance = create_training_instance(source_tokens, target_tokens, args.max_seq_length,
                                                    clip_to_max_len=args.clip_to_max_len)
                if instance is None:
                    continue

                features, seq_max_bucket_length = get_instance_features(instance, tokenizer, args.max_seq_length,
                                                                        args.bucket)
                for key in feature_dict:
                    if key == seq_max_bucket_length:
                        feature_dict[key].append(features)

                if total_read <= 10:
                    logging.info("*** Example ***")
                    logging.info("source tokens: %s", " ".join(
                        [tokenization.convert_to_printable(x) for x in instance.source_eos_tokens]))
                    logging.info("target tokens: %s", " ".join(
                        [tokenization.convert_to_printable(x) for x in instance.target_sos_tokens]))
                    for feature_name in features.keys():
                        feature = features[feature_name]
                        logging.info("%s: %s", feature_name, feature)
    for i in args.bucket:
        if args.num_splits == 1:
            output_file_name = output_file
        else:
            output_file_name = output_file + '_' + str(i) + '_'
        writer = FileWriter(output_file_name, args.num_splits)
        data_schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
                       "source_sos_mask": {"type": "int64", "shape": [-1]},
                       "source_eos_ids": {"type": "int64", "shape": [-1]},
                       "source_eos_mask": {"type": "int64", "shape": [-1]},
                       "target_sos_ids": {"type": "int64", "shape": [-1]},
                       "target_sos_mask": {"type": "int64", "shape": [-1]},
                       "target_eos_ids": {"type": "int64", "shape": [-1]},
                       "target_eos_mask": {"type": "int64", "shape": [-1]}
                       }
        writer.add_schema(data_schema, "transformer")
        features_ = feature_dict[i]
        logging.info("Bucket length %d has %d samples, start writing...", i, len(features_))
        for item in features_:
            writer.write_raw_data([item])
            total_written += 1
        writer.commit()

    logging.info("Wrote %d total instances", total_written)

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()