You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

config.py 3.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """config script"""
  16. import mindspore.common.dtype as mstype
  17. from easydict import EasyDict as edict
  18. from .tinybert_model import BertConfig
  19. from .assessment_method import Accuracy, F1, Pearsonr, Matthews
  20. gradient_cfg = edict({
  21. 'clip_type': 1,
  22. 'clip_value': 1.0
  23. })
  24. task_cfg = edict({
  25. "sst-2": edict({"num_labels": 2, "seq_length": 64, "task_type": "classification", "metrics": Accuracy}),
  26. "qnli": edict({"num_labels": 2, "seq_length": 128, "task_type": "classification", "metrics": Accuracy}),
  27. "mnli": edict({"num_labels": 3, "seq_length": 128, "task_type": "classification", "metrics": Accuracy}),
  28. "cola": edict({"num_labels": 2, "seq_length": 64, "task_type": "classification", "metrics": Matthews}),
  29. "mrpc": edict({"num_labels": 2, "seq_length": 128, "task_type": "classification", "metrics": F1}),
  30. "sts-b": edict({"num_labels": 1, "seq_length": 128, "task_type": "regression", "metrics": Pearsonr}),
  31. "qqp": edict({"num_labels": 2, "seq_length": 128, "task_type": "classification", "metrics": F1}),
  32. "rte": edict({"num_labels": 2, "seq_length": 128, "task_type": "classification", "metrics": Accuracy})
  33. })
  34. train_cfg = edict({
  35. 'batch_size': 16,
  36. 'loss_scale_value': 2 ** 16,
  37. 'scale_factor': 2,
  38. 'scale_window': 50,
  39. 'optimizer_cfg': edict({
  40. 'AdamWeightDecay': edict({
  41. 'learning_rate': 5e-5,
  42. 'end_learning_rate': 1e-14,
  43. 'power': 1.0,
  44. 'weight_decay': 1e-4,
  45. 'eps': 1e-6,
  46. 'decay_filter': lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
  47. 'warmup_ratio': 0.1
  48. }),
  49. }),
  50. })
  51. eval_cfg = edict({
  52. 'batch_size': 32,
  53. })
  54. teacher_net_cfg = BertConfig(
  55. seq_length=128,
  56. vocab_size=30522,
  57. hidden_size=768,
  58. num_hidden_layers=6,
  59. num_attention_heads=12,
  60. intermediate_size=3072,
  61. hidden_act="gelu",
  62. hidden_dropout_prob=0.1,
  63. attention_probs_dropout_prob=0.1,
  64. max_position_embeddings=512,
  65. type_vocab_size=2,
  66. initializer_range=0.02,
  67. use_relative_positions=False,
  68. dtype=mstype.float32,
  69. compute_type=mstype.float32,
  70. do_quant=False
  71. )
  72. student_net_cfg = BertConfig(
  73. seq_length=128,
  74. vocab_size=30522,
  75. hidden_size=768,
  76. num_hidden_layers=6,
  77. num_attention_heads=12,
  78. intermediate_size=3072,
  79. hidden_act="gelu",
  80. hidden_dropout_prob=0.1,
  81. attention_probs_dropout_prob=0.1,
  82. max_position_embeddings=512,
  83. type_vocab_size=2,
  84. initializer_range=0.02,
  85. use_relative_positions=False,
  86. dtype=mstype.float32,
  87. compute_type=mstype.float32,
  88. do_quant=True,
  89. embedding_bits=2,
  90. weight_bits=2,
  91. weight_clip_value=3.0,
  92. cls_dropout_prob=0.1,
  93. activation_init=2.5,
  94. is_lgt_fit=False
  95. )