Merge pull request !7646 from wanghua/master tags/v1.1.0
| @@ -301,7 +301,7 @@ def do_eval_standalone(): | |||||
| input_data.append(data[i]) | input_data.append(data[i]) | ||||
| input_ids, input_mask, token_type_id, label_ids = input_data | input_ids, input_mask, token_type_id, label_ids = input_data | ||||
| logits = eval_model(input_ids, token_type_id, input_mask) | logits = eval_model(input_ids, token_type_id, input_mask) | ||||
| callback.update(logits[3], label_ids) | |||||
| callback.update(logits, label_ids) | |||||
| acc = callback.acc_num / callback.total_num | acc = callback.acc_num / callback.total_num | ||||
| print("======================================") | print("======================================") | ||||
| print("============== acc is {}".format(acc)) | print("============== acc is {}".format(acc)) | ||||
| @@ -964,7 +964,7 @@ class BertModelCLS(nn.Cell): | |||||
| The returned output represents the final logits as the results of log_softmax is proportional to that of softmax. | The returned output represents the final logits as the results of log_softmax is proportional to that of softmax. | ||||
| """ | """ | ||||
| def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, | def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, | ||||
| use_one_hot_embeddings=False, phase_type="teacher"): | |||||
| use_one_hot_embeddings=False, phase_type="student"): | |||||
| super(BertModelCLS, self).__init__() | super(BertModelCLS, self).__init__() | ||||
| self.bert = BertModel(config, is_training, use_one_hot_embeddings) | self.bert = BertModel(config, is_training, use_one_hot_embeddings) | ||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| @@ -992,4 +992,6 @@ class BertModelCLS(nn.Cell): | |||||
| logits = self.dense_1(cls) | logits = self.dense_1(cls) | ||||
| logits = self.cast(logits, self.dtype) | logits = self.cast(logits, self.dtype) | ||||
| log_probs = self.log_softmax(logits) | log_probs = self.log_softmax(logits) | ||||
| return seq_output, att_output, logits, log_probs | |||||
| if self._phase == 'train' or self.phase_type == "teacher": | |||||
| return seq_output, att_output, logits, log_probs | |||||
| return log_probs | |||||
| @@ -100,7 +100,7 @@ class EvalCallBack(Callback): | |||||
| input_ids, input_mask, token_type_id, label_ids = input_data | input_ids, input_mask, token_type_id, label_ids = input_data | ||||
| self.network.set_train(False) | self.network.set_train(False) | ||||
| logits = self.network(input_ids, token_type_id, input_mask) | logits = self.network(input_ids, token_type_id, input_mask) | ||||
| callback.update(logits[3], label_ids) | |||||
| callback.update(logits, label_ids) | |||||
| acc = callback.acc_num / callback.total_num | acc = callback.acc_num / callback.total_num | ||||
| with open("./eval.log", "a+") as f: | with open("./eval.log", "a+") as f: | ||||
| f.write("acc_num {}, total_num{}, accuracy{:.6f}".format(callback.acc_num, callback.total_num, | f.write("acc_num {}, total_num{}, accuracy{:.6f}".format(callback.acc_num, callback.total_num, | ||||