diff --git a/model_zoo/official/nlp/tinybert/run_task_distill.py b/model_zoo/official/nlp/tinybert/run_task_distill.py index cd35bc5c34..430f27ac3b 100644 --- a/model_zoo/official/nlp/tinybert/run_task_distill.py +++ b/model_zoo/official/nlp/tinybert/run_task_distill.py @@ -301,7 +301,7 @@ def do_eval_standalone(): input_data.append(data[i]) input_ids, input_mask, token_type_id, label_ids = input_data logits = eval_model(input_ids, token_type_id, input_mask) - callback.update(logits[3], label_ids) + callback.update(logits, label_ids) acc = callback.acc_num / callback.total_num print("======================================") print("============== acc is {}".format(acc)) diff --git a/model_zoo/official/nlp/tinybert/src/tinybert_model.py b/model_zoo/official/nlp/tinybert/src/tinybert_model.py index 09504abcd8..d802d4ba8a 100644 --- a/model_zoo/official/nlp/tinybert/src/tinybert_model.py +++ b/model_zoo/official/nlp/tinybert/src/tinybert_model.py @@ -964,7 +964,7 @@ class BertModelCLS(nn.Cell): The returned output represents the final logits as the results of log_softmax is propotional to that of softmax. """ def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, - use_one_hot_embeddings=False, phase_type="teacher"): + use_one_hot_embeddings=False, phase_type="student"): super(BertModelCLS, self).__init__() self.bert = BertModel(config, is_training, use_one_hot_embeddings) self.cast = P.Cast() @@ -992,4 +992,6 @@ class BertModelCLS(nn.Cell): logits = self.dense_1(cls) logits = self.cast(logits, self.dtype) log_probs = self.log_softmax(logits) - return seq_output, att_output, logits, log_probs + if self._phase == 'train' or self.phase_type == "teacher": + return seq_output, att_output, logits, log_probs + return log_probs diff --git a/model_zoo/official/nlp/tinybert/src/utils.py b/model_zoo/official/nlp/tinybert/src/utils.py index 40b970aa8b..2b2fd69c3f 100644 --- a/model_zoo/official/nlp/tinybert/src/utils.py +++ b/model_zoo/official/nlp/tinybert/src/utils.py @@ -100,7 +100,7 @@ class EvalCallBack(Callback): input_ids, input_mask, token_type_id, label_ids = input_data self.network.set_train(False) logits = self.network(input_ids, token_type_id, input_mask) - callback.update(logits[3], label_ids) + callback.update(logits, label_ids) acc = callback.acc_num / callback.total_num with open("./eval.log", "a+") as f: f.write("acc_num {}, total_num{}, accuracy{:.6f}".format(callback.acc_num, callback.total_num,