|
|
|
@@ -32,15 +32,12 @@ else: |
|
|
|
path = '' |
|
|
|
sys.path.append(os.environ['CLOUD_MODEL_ZOO'] + 'official/nlp/tinybert') |
|
|
|
|
|
|
|
from official.nlp.tinybert.src.tinybert_model import TinyBertModel |
|
|
|
from official.nlp.tinybert.src.model_utils.config import bert_student_net_cfg |
|
|
|
from train_utils import save_t |
|
|
|
from official.nlp.tinybert.src.tinybert_model import TinyBertModel # noqa: E402 |
|
|
|
from official.nlp.tinybert.src.model_utils.config import bert_student_net_cfg # noqa: E402 |
|
|
|
from train_utils import save_t # noqa: E402 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BertNetworkWithLoss_gd(M.nn.Cell): |
|
|
|
class BertNetworkWithLossGenDistill(M.nn.Cell): |
|
|
|
""" |
|
|
|
Provide bert pre-training loss through network. |
|
|
|
Args: |
|
|
|
@@ -53,7 +50,7 @@ class BertNetworkWithLoss_gd(M.nn.Cell): |
|
|
|
|
|
|
|
def __init__(self, student_config, is_training, use_one_hot_embeddings=False, |
|
|
|
is_att_fit=False, is_rep_fit=True): |
|
|
|
super(BertNetworkWithLoss_gd, self).__init__() |
|
|
|
super(BertNetworkWithLossGenDistill, self).__init__() |
|
|
|
# build the student network (TinyBertModel) from student_config |
|
|
|
self.bert = TinyBertModel( |
|
|
|
student_config, is_training, use_one_hot_embeddings) |
|
|
|
@@ -169,7 +166,7 @@ bert_student_net_cfg.attention_probs_dropout_prob = 0.0 |
|
|
|
bert_student_net_cfg.compute_type = mstype.float32 |
|
|
|
|
|
|
|
#==============Training=============== |
|
|
|
nloss = BertNetworkWithLoss_gd( |
|
|
|
nloss = BertNetworkWithLossGenDistill( |
|
|
|
bert_student_net_cfg, is_training=True, use_one_hot_embeddings=False) |
|
|
|
optimizer = M.nn.Adam(nloss.bert.trainable_params(), learning_rate=1e-3, beta1=0.5, beta2=0.7, |
|
|
|
eps=1e-2, use_locking=True, use_nesterov=False, weight_decay=0.1, loss_scale=0.3) |
|
|
|
|