
config.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Configuration class for Transformer."""
import os
import json
import copy
from typing import List

import mindspore.common.dtype as mstype


def _is_dataset_file(file: str):
    """Check whether a file name refers to a TFRecord or MindRecord dataset file."""
    return "tfrecord" in file.lower() or "mindrecord" in file.lower()


def _get_files_from_dir(folder: str):
    """Collect all TFRecord/MindRecord files directly under `folder`."""
    _files = []
    for file in os.listdir(folder):
        if _is_dataset_file(file):
            _files.append(os.path.join(folder, file))
    return _files


def get_source_list(folder: str) -> List:
    """
    Get the dataset file list from a folder or a single file path.

    Args:
        folder (str): Path of a dataset folder or of a single dataset file.

    Returns:
        list, file list.
    """
    _list = []
    if not folder:
        return _list

    if os.path.isdir(folder):
        _list = _get_files_from_dir(folder)
    else:
        if _is_dataset_file(folder):
            _list.append(folder)
    return _list
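
# A minimal usage sketch (the paths below are hypothetical): `get_source_list`
# accepts either a directory or a single file path and keeps only
# TFRecord/MindRecord files.
#
#     get_source_list("/data/pretrain")           # the tfrecord/mindrecord files in that folder
#     get_source_list("/data/train.mindrecord")   # ["/data/train.mindrecord"]
#     get_source_list("/data/readme.txt")         # [] (not a dataset file)
#     get_source_list("")                         # [] (empty input)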


PARAM_NODES = {"dataset_config",
               "model_config",
               "loss_scale_config",
               "learn_rate_config",
               "checkpoint_options"}


class TransformerConfig:
    """
    Configuration for `Transformer`.

    Args:
        random_seed (int): Random seed.
        batch_size (int): Batch size of the input dataset.
        epochs (int): Number of epochs.
        dataset_sink_mode (bool): Whether to enable dataset sink mode.
        dataset_sink_step (int): Dataset sink step.
        lr_scheduler (str): Learning rate scheduler to use; only "ISR" is supported now.
        optimizer (str): Optimizer to use. Default: "adam".
        lr (float): Initial learning rate.
        min_lr (float): Minimum learning rate.
        decay_steps (int): Number of decay steps. Default: 10000.
        poly_lr_scheduler_power (float): Power of the polynomial lr scheduler. Default: 1.
        decay_start_step (int): Step at which decay starts. Default: -1.
        warmup_steps (int): Warm-up steps.
        pre_train_dataset (str): Path of the pre-training dataset file or folder.
        fine_tune_dataset (str): Path of the fine-tuning dataset file or folder.
        test_dataset (str): Path of the test dataset file or folder.
        valid_dataset (str): Path of the validation dataset file or folder.
        ckpt_path (str): Checkpoint save path.
        save_ckpt_steps (int): Interval (in steps) for saving checkpoints.
        ckpt_prefix (str): Prefix of checkpoint files.
        existed_ckpt (str): Path of an existing checkpoint to load. Default: "".
        keep_ckpt_max (int): Maximum number of checkpoint files to keep.
        seq_length (int): Length of the input sequence. Default: 128.
        vocab_size (int): Size of the vocabulary. Default: 46192.
        hidden_size (int): Size of the embedding and attention dimensions. Default: 512.
        num_hidden_layers (int): Number of encoder and decoder layers. Default: 6.
        ngram (int): Number of tokens to predict ahead. Default: 2.
        accumulation_steps (int): Number of steps to accumulate gradients before the next
            optimizer update. Default: 1.
        disable_ngram_loss (bool): Whether to disable the n-gram loss. Default: False.
        num_attention_heads (int): Number of attention heads in the Transformer encoder/decoder
            cell. Default: 8.
        intermediate_size (int): Size of the intermediate layer in the Transformer
            encoder/decoder cell. Default: 4096.
        hidden_act (str): Activation function used in the Transformer encoder/decoder
            cell. Default: "relu".
        loss_scale_mode (str): Loss scale mode. Default: "dynamic".
        init_loss_scale (int): Initial loss scale value.
        loss_scale_factor (int): Loss scale factor.
        scale_window (int): Window size of the loss scale.
        beam_width (int): Beam width for beam search at inference time. Default: 5.
        length_penalty_weight (float): Penalty weight for sentence length. Default: 1.0.
        label_smoothing (float): Label smoothing value. Default: 0.1.
        input_mask_from_dataset (bool): Whether to use the input mask loaded from the
            dataset. Default: True.
        save_graphs (bool): Whether to save graphs; set to True if MindInsight is wanted.
        dtype (mstype): Data type of the input. Default: mstype.float32.
        max_decode_length (int): Maximum decode length at inference time. Default: 64.
        hidden_dropout_prob (float): Dropout probability for hidden outputs. Default: 0.1.
        attention_dropout_prob (float): Dropout probability for
            multi-head self-attention. Default: 0.1.
        max_position_embeddings (int): Maximum sequence length supported by this
            model. Default: 64.
        initializer_range (float): Initialization value for TruncatedNormal. Default: 0.02.
    """
    def __init__(self,
                 random_seed=74,
                 batch_size=64, epochs=1,
                 dataset_sink_mode=True, dataset_sink_step=1,
                 lr_scheduler="", optimizer="adam",
                 lr=1e-4, min_lr=1e-6,
                 decay_steps=10000, poly_lr_scheduler_power=1,
                 decay_start_step=-1, warmup_steps=2000,
                 pre_train_dataset: str = None,
                 fine_tune_dataset: str = None,
                 test_dataset: str = None,
                 valid_dataset: str = None,
                 ckpt_path: str = None,
                 save_ckpt_steps=2000,
                 ckpt_prefix="CKPT",
                 existed_ckpt="",
                 keep_ckpt_max=20,
                 seq_length=128,
                 vocab_size=46192,
                 hidden_size=512,
                 num_hidden_layers=6,
                 ngram=2,
                 accumulation_steps=1,
                 disable_ngram_loss=False,
                 num_attention_heads=8,
                 intermediate_size=4096,
                 hidden_act="relu",
                 hidden_dropout_prob=0.1,
                 attention_dropout_prob=0.1,
                 max_position_embeddings=64,
                 initializer_range=0.02,
                 loss_scale_mode="dynamic",
                 init_loss_scale=2 ** 10,
                 loss_scale_factor=2, scale_window=2000,
                 beam_width=5,
                 length_penalty_weight=1.0,
                 label_smoothing=0.1,
                 input_mask_from_dataset=True,
                 save_graphs=False,
                 dtype=mstype.float32,
                 max_decode_length=64):
        self.save_graphs = save_graphs
        self.random_seed = random_seed
        self.pre_train_dataset = get_source_list(pre_train_dataset)  # type: List[str]
        self.fine_tune_dataset = get_source_list(fine_tune_dataset)  # type: List[str]
        self.valid_dataset = get_source_list(valid_dataset)  # type: List[str]
        self.test_dataset = get_source_list(test_dataset)  # type: List[str]

        if not isinstance(epochs, int) or epochs < 0:
            raise ValueError("`epochs` must be a non-negative int.")
        self.epochs = epochs

        self.dataset_sink_mode = dataset_sink_mode
        self.dataset_sink_step = dataset_sink_step
        self.ckpt_path = ckpt_path
        self.keep_ckpt_max = keep_ckpt_max
        self.save_ckpt_steps = save_ckpt_steps
        self.ckpt_prefix = ckpt_prefix
        self.existed_ckpt = existed_ckpt

        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.ngram = ngram
        self.accumulation_steps = accumulation_steps
        self.disable_ngram_loss = disable_ngram_loss
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_dropout_prob = attention_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.label_smoothing = label_smoothing
        self.beam_width = beam_width
        self.length_penalty_weight = length_penalty_weight
        self.max_decode_length = max_decode_length
        self.input_mask_from_dataset = input_mask_from_dataset
        self.compute_type = mstype.float32
        self.dtype = dtype

        self.loss_scale_mode = loss_scale_mode
        self.scale_window = scale_window
        self.loss_scale_factor = loss_scale_factor
        self.init_loss_scale = init_loss_scale

        self.optimizer = optimizer
        self.lr = lr
        self.lr_scheduler = lr_scheduler
        self.min_lr = min_lr
        self.poly_lr_scheduler_power = poly_lr_scheduler_power
        self.decay_steps = decay_steps
        self.decay_start_step = decay_start_step
        self.warmup_steps = warmup_steps

        self.train_url = ""

    @classmethod
    def from_dict(cls, json_object: dict):
        """Constructs a `TransformerConfig` from a Python dictionary of parameters."""
        _params = {}
        for node in PARAM_NODES:
            for key in json_object[node]:
                _params[key] = json_object[node][key]
        return cls(**_params)

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `TransformerConfig` from a JSON file of parameters."""
        with open(json_file, "r") as reader:
            return cls.from_dict(json.load(reader))

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
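
# A minimal usage sketch (the file name, values, and the grouping of keys under
# each node are hypothetical). `from_dict` flattens the five PARAM_NODES sections
# into keyword arguments for `__init__`, so a config JSON is expected to contain
# all five top-level nodes, e.g.:
#
#     {
#       "dataset_config":     {"epochs": 2, "batch_size": 96, "pre_train_dataset": "/data/pretrain"},
#       "model_config":       {"seq_length": 128, "vocab_size": 46192, "hidden_size": 512},
#       "loss_scale_config":  {"loss_scale_mode": "dynamic", "init_loss_scale": 1024},
#       "learn_rate_config":  {"optimizer": "adam", "lr": 1e-4, "warmup_steps": 2000},
#       "checkpoint_options": {"ckpt_prefix": "CKPT", "keep_ckpt_max": 20}
#     }
#
#     config = TransformerConfig.from_json_file("transformer_config.json")
#     assert config.batch_size == 96 and config.optimizer == "adam"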