You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

create_data.py 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Create training instances for Transformer."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import argparse
  20. import collections
  21. import logging
  22. import numpy as np
  23. import src.tokenization as tokenization
  24. from mindspore.mindrecord import FileWriter
  25. class SampleInstance():
  26. """A single sample instance (sentence pair)."""
  27. def __init__(self, source_sos_tokens, source_eos_tokens, target_sos_tokens, target_eos_tokens):
  28. self.source_sos_tokens = source_sos_tokens
  29. self.source_eos_tokens = source_eos_tokens
  30. self.target_sos_tokens = target_sos_tokens
  31. self.target_eos_tokens = target_eos_tokens
  32. def __str__(self):
  33. s = ""
  34. s += "source sos tokens: %s\n" % (" ".join(
  35. [tokenization.printable_text(x) for x in self.source_sos_tokens]))
  36. s += "source eos tokens: %s\n" % (" ".join(
  37. [tokenization.printable_text(x) for x in self.source_eos_tokens]))
  38. s += "target sos tokens: %s\n" % (" ".join(
  39. [tokenization.printable_text(x) for x in self.target_sos_tokens]))
  40. s += "target eos tokens: %s\n" % (" ".join(
  41. [tokenization.printable_text(x) for x in self.target_eos_tokens]))
  42. s += "\n"
  43. return s
  44. def __repr__(self):
  45. return self.__str__()
  46. def write_instance_to_file(writer, instance, tokenizer, max_seq_length):
  47. """Create files from `SampleInstance`s."""
  48. def _convert_ids_and_mask(input_tokens):
  49. input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
  50. input_mask = [1] * len(input_ids)
  51. assert len(input_ids) <= max_seq_length
  52. while len(input_ids) < max_seq_length:
  53. input_ids.append(0)
  54. input_mask.append(0)
  55. assert len(input_ids) == max_seq_length
  56. assert len(input_mask) == max_seq_length
  57. return input_ids, input_mask
  58. source_sos_ids, source_sos_mask = _convert_ids_and_mask(instance.source_sos_tokens)
  59. source_eos_ids, source_eos_mask = _convert_ids_and_mask(instance.source_eos_tokens)
  60. target_sos_ids, target_sos_mask = _convert_ids_and_mask(instance.target_sos_tokens)
  61. target_eos_ids, target_eos_mask = _convert_ids_and_mask(instance.target_eos_tokens)
  62. features = collections.OrderedDict()
  63. features["source_sos_ids"] = np.asarray(source_sos_ids)
  64. features["source_sos_mask"] = np.asarray(source_sos_mask)
  65. features["source_eos_ids"] = np.asarray(source_eos_ids)
  66. features["source_eos_mask"] = np.asarray(source_eos_mask)
  67. features["target_sos_ids"] = np.asarray(target_sos_ids)
  68. features["target_sos_mask"] = np.asarray(target_sos_mask)
  69. features["target_eos_ids"] = np.asarray(target_eos_ids)
  70. features["target_eos_mask"] = np.asarray(target_eos_mask)
  71. writer.write_raw_data([features])
  72. return features
  73. def create_training_instance(source_words, target_words, max_seq_length, clip_to_max_len):
  74. """Creates `SampleInstance`s for a single sentence pair."""
  75. EOS = "</s>"
  76. SOS = "<s>"
  77. if len(source_words) >= max_seq_length or len(target_words) >= max_seq_length:
  78. if clip_to_max_len:
  79. print("####lalalal")
  80. source_words = source_words[:min([len(source_words, max_seq_length-1)])]
  81. target_words = target_words[:min([len(target_words, max_seq_length-1)])]
  82. else:
  83. return None
  84. source_sos_tokens = [SOS] + source_words
  85. source_eos_tokens = source_words + [EOS]
  86. target_sos_tokens = [SOS] + target_words
  87. target_eos_tokens = target_words + [EOS]
  88. instance = SampleInstance(
  89. source_sos_tokens=source_sos_tokens,
  90. source_eos_tokens=source_eos_tokens,
  91. target_sos_tokens=target_sos_tokens,
  92. target_eos_tokens=target_eos_tokens)
  93. return instance
  94. def main():
  95. parser = argparse.ArgumentParser()
  96. parser.add_argument("--input_file", type=str, required=True,
  97. help='Input raw text file (or comma-separated list of files).')
  98. parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
  99. parser.add_argument("--num_splits", type=int, default=16,
  100. help='The MindRecord file will be split into the number of partition.')
  101. parser.add_argument("--vocab_file", type=str, required=True,
  102. help='The vocabulary file that the Transformer model was trained on.')
  103. parser.add_argument("--clip_to_max_len", type=bool, default=False,
  104. help='clip sequences to maximum sequence length.')
  105. parser.add_argument("--max_seq_length", type=int, default=128, help='Maximum sequence length.')
  106. args = parser.parse_args()
  107. tokenizer = tokenization.WhiteSpaceTokenizer(vocab_file=args.vocab_file)
  108. input_files = []
  109. for input_pattern in args.input_file.split(","):
  110. input_files.append(input_pattern)
  111. logging.info("*** Reading from input files ***")
  112. for input_file in input_files:
  113. logging.info(" %s", input_file)
  114. output_file = args.output_file
  115. logging.info("*** Writing to output files ***")
  116. logging.info(" %s", output_file)
  117. writer = FileWriter(output_file, args.num_splits)
  118. data_schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
  119. "source_sos_mask": {"type": "int64", "shape": [-1]},
  120. "source_eos_ids": {"type": "int64", "shape": [-1]},
  121. "source_eos_mask": {"type": "int64", "shape": [-1]},
  122. "target_sos_ids": {"type": "int64", "shape": [-1]},
  123. "target_sos_mask": {"type": "int64", "shape": [-1]},
  124. "target_eos_ids": {"type": "int64", "shape": [-1]},
  125. "target_eos_mask": {"type": "int64", "shape": [-1]}
  126. }
  127. writer.add_schema(data_schema, "tranformer hisi")
  128. total_written = 0
  129. total_read = 0
  130. for input_file in input_files:
  131. logging.info("*** Reading from %s ***", input_file)
  132. with open(input_file, "r") as reader:
  133. while True:
  134. line = tokenization.convert_to_unicode(reader.readline())
  135. if not line:
  136. break
  137. total_read += 1
  138. if total_read % 100000 == 0:
  139. logging.info("%d ...", total_read)
  140. source_line, target_line = line.strip().split("\t")
  141. source_tokens = tokenizer.tokenize(source_line)
  142. target_tokens = tokenizer.tokenize(target_line)
  143. if len(source_tokens) >= args.max_seq_length or len(target_tokens) >= args.max_seq_length:
  144. logging.info("ignore long sentence!")
  145. continue
  146. instance = create_training_instance(source_tokens, target_tokens, args.max_seq_length,
  147. clip_to_max_len=args.clip_to_max_len)
  148. if instance is None:
  149. continue
  150. features = write_instance_to_file(writer, instance, tokenizer, args.max_seq_length)
  151. total_written += 1
  152. if total_written <= 20:
  153. logging.info("*** Example ***")
  154. logging.info("source tokens: %s", " ".join(
  155. [tokenization.printable_text(x) for x in instance.source_eos_tokens]))
  156. logging.info("target tokens: %s", " ".join(
  157. [tokenization.printable_text(x) for x in instance.target_sos_tokens]))
  158. for feature_name in features.keys():
  159. feature = features[feature_name]
  160. logging.info("%s: %s", feature_name, feature)
  161. writer.commit()
  162. logging.info("Wrote %d total instances", total_written)
  163. if __name__ == "__main__":
  164. main()