Merge pull request !2790 from wanghua/master (tag: v0.6.0-beta)
@@ -5,9 +5,9 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
 ## Requirements
 - Install [MindSpore](https://www.mindspore.cn/install/en).
 - Download the zhwiki dataset for pre-training. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor) (see the sketch after these notes). Convert the dataset to TFRecord format and move the files to a specified path.
-- Download the CLUE/SQuAD v1.1 dataset for fine-tuning and evaluation.
+- Download a dataset for fine-tuning and evaluation, such as CLUENER, TNEWS, or SQuAD v1.1.
 > Notes:
-  If you are running a fine-tuning or evaluation task, prepare the corresponding checkpoint file.
+  If you are running a fine-tuning or evaluation task, prepare a checkpoint generated by pre-training.
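The WikiExtractor step referenced above can be sketched as follows; the dump file name and output directory are illustrative placeholders, and the subsequent TFRecord conversion is repo-specific and not shown here:

```bash
# Illustrative only: extract and clean article text from a zhwiki dump.
python WikiExtractor.py -o extracted/ zhwiki-latest-pages-articles.xml.bz2
```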
 ## Running the Example
 ### Pre-Training
@@ -24,31 +24,15 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
 sh scripts/run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR MINDSPORE_HCCL_CONFIG_PATH
 ```
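For example, a pre-training run on 8 devices for 40 epochs might look like the line below; every path is a placeholder for your own files:

```bash
# Placeholders throughout: 8 devices, 40 epochs, your own data/schema/HCCL paths.
sh scripts/run_distribute_pretrain.sh 8 40 /path/to/zhwiki_tfrecord /path/to/schema.json /path/to/hccl.json
```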
-### Fine-Tuning
-- Set options in `finetune_config.py`. Make sure that 'data_file', 'schema_file' and 'pre_training_file' are set to your own paths. Set 'pre_training_ckpt' to a saved checkpoint file generated by pre-training.
+### Fine-Tuning and Evaluation
+- Set the BERT network config and optimizer hyperparameters in `finetune_eval_config.py`.
-- Run `finetune.py` for fine-tuning of the BERT-base and BERT-NEZHA models.
+- Set the task-related hyperparameters in scripts/run_XXX.sh.
-```bash
-python finetune.py
-```
-### Evaluation
-- Set options in `evaluation_config.py`. Make sure that 'data_file', 'schema_file' and 'finetune_ckpt' are set to your own paths.
-- NER: Run `evaluation.py` for evaluation of the BERT-base and BERT-NEZHA models.
-```bash
-python evaluation.py
-```
-- SQuAD v1.1: Run `squadeval.py` and `SQuAD_postprocess.py` for evaluation of the BERT-base and BERT-NEZHA models.
-```bash
-python squadeval.py
-```
+- Run `bash scripts/run_XXX.sh` for fine-tuning and evaluation of the BERT-base and BERT-NEZHA models (see the sketch below).
 ```bash
-python SQuAD_postprocess.py
+bash scripts/run_XXX.sh
 ```
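Each `scripts/run_XXX.sh` is a thin wrapper around the corresponding Python entry point. As a hedged sketch, the body of a `scripts/run_squad.sh` might resemble the following; the flag names come from the Usage section below, while all values are placeholders:

```bash
# Hypothetical body of scripts/run_squad.sh; adjust every value to your setup.
python run_squad.py \
    --device_target="Ascend" \
    --do_train="true" \
    --do_eval="true" \
    --device_id=0 \
    --epoch_num=3 \
    --num_class=2 \
    --vocab_file_path="/path/to/vocab.txt" \
    --eval_json_path="/path/to/dev-v1.1.json" \
    --save_finetune_checkpoint_path="/path/to/save/" \
    --load_pretrain_checkpoint_path="/path/to/pretrain.ckpt" \
    --train_data_file_path="/path/to/train1.1.tfrecord" \
    --eval_data_file_path="/path/to/dev1.1.tfrecord" \
    --schema_file_path="/path/to/schema.json"
```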
 ## Usage
@@ -88,26 +72,56 @@ config.py:
     scale_window                    steps for one update of loss scale: N, default is 1000
     optimizer                       optimizer used in the network: AdamWeightDecayDynamicLR | Lamb | Momentum, default is "Lamb"
-finetune_config.py:
-    task                            task type: SeqLabeling | Regression | Classification | COLA | SQUAD
-    num_labels                      number of labels to do classification
-    data_file                       dataset file to load: PATH, default is "/your/path/train.tfrecord"
-    schema_file                     dataset schema file to load: PATH, default is "/your/path/schema.json"
-    epoch_num                       repeat counts of training: N, default is 5
-    ckpt_prefix                     prefix used to save checkpoint files: PREFIX, default is "bert"
-    ckpt_dir                        path to save checkpoint files: PATH, default is None
-    pre_training_ckpt               checkpoint file to load: PATH, default is "/your/path/pre_training.ckpt"
-    use_crf                         whether to use crf for evaluation. use_crf takes effect only when task type is NER, default is False
-    optimizer                       optimizer used in the fine-tune network: AdamWeightDecayDynamicLR | Lamb | Momentum, default is "Lamb"
-evaluation_config.py:
-    task                            task type: SeqLabeling | Regression | Classification | COLA
-    num_labels                      number of labels to do classification
-    data_file                       dataset file to load: PATH, default is "/your/path/evaluation.tfrecord"
-    schema_file                     dataset schema file to load: PATH, default is "/your/path/schema.json"
-    finetune_ckpt                   checkpoint file to load: PATH, default is "/your/path/your.ckpt"
-    use_crf                         whether to use crf for evaluation. use_crf takes effect only when task type is NER, default is False
-    clue_benchmark                  whether to use the clue benchmark. clue_benchmark takes effect only when task type is NER, default is False
+scripts/run_ner.sh:
+    device_target                   targeted device to run task: Ascend | GPU
+    do_train                        whether to run training on the training set: true | false
+    do_eval                         whether to run eval on the dev set: true | false
+    assessment_method               assessment method to do evaluation: f1 | clue_benchmark
+    use_crf                         whether to use crf to calculate loss: true | false
+    device_id                       device id to run task
+    epoch_num                       total number of training epochs to perform
+    num_class                       number of classes to do labeling
+    vocab_file_path                 the vocabulary file that the BERT model was trained on
+    label2id_file_path              label-to-id json file
+    save_finetune_checkpoint_path   path to save the generated fine-tuning checkpoint
+    load_pretrain_checkpoint_path   initial checkpoint (usually from a pre-trained BERT model)
+    load_finetune_checkpoint_path   give a fine-tuning checkpoint path if only doing eval
+    train_data_file_path            ner tfrecord for training, e.g., train.tfrecord
+    eval_data_file_path             ner tfrecord for predictions if f1 is used to evaluate the result, ner json for predictions if clue_benchmark is used to evaluate the result
+    schema_file_path                path to the dataset schema file
+scripts/run_squad.sh:
+    device_target                   targeted device to run task: Ascend | GPU
+    do_train                        whether to run training on the training set: true | false
+    do_eval                         whether to run eval on the dev set: true | false
+    device_id                       device id to run task
+    epoch_num                       total number of training epochs to perform
+    num_class                       number of classes to classify, usually 2 for the squad task
+    vocab_file_path                 the vocabulary file that the BERT model was trained on
+    eval_json_path                  path to the squad dev json file
+    save_finetune_checkpoint_path   path to save the generated fine-tuning checkpoint
+    load_pretrain_checkpoint_path   initial checkpoint (usually from a pre-trained BERT model)
+    load_finetune_checkpoint_path   give a fine-tuning checkpoint path if only doing eval
+    train_data_file_path            squad tfrecord for training, e.g., train1.1.tfrecord
+    eval_data_file_path             squad tfrecord for predictions, e.g., dev1.1.tfrecord
+    schema_file_path                path to the dataset schema file
+scripts/run_classifier.sh:
+    device_target                   targeted device to run task: Ascend | GPU
+    do_train                        whether to run training on the training set: true | false
+    do_eval                         whether to run eval on the dev set: true | false
+    assessment_method               assessment method to do evaluation: accuracy | f1 | mcc | spearman_correlation
+    device_id                       device id to run task
+    epoch_num                       total number of training epochs to perform
+    num_class                       number of classes to do labeling
+    save_finetune_checkpoint_path   path to save the generated fine-tuning checkpoint
+    load_pretrain_checkpoint_path   initial checkpoint (usually from a pre-trained BERT model)
+    load_finetune_checkpoint_path   give a fine-tuning checkpoint path if only doing eval
+    train_data_file_path            tfrecord for training, e.g., train.tfrecord
+    eval_data_file_path             tfrecord for predictions, e.g., dev.tfrecord
+    schema_file_path                path to the dataset schema file
 ```
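As a concrete reading of the `scripts/run_ner.sh` parameters above, a fine-tune-then-evaluate NER run could pass flags like these to the NER entry point; every value is a placeholder, including the class count:

```bash
# Placeholder values throughout; num_class depends on your label set.
python run_ner.py \
    --device_target="Ascend" \
    --do_train="true" \
    --do_eval="true" \
    --assessment_method="f1" \
    --use_crf="false" \
    --device_id=0 \
    --epoch_num=5 \
    --num_class=41 \
    --vocab_file_path="/path/to/vocab.txt" \
    --label2id_file_path="/path/to/label2id.json" \
    --save_finetune_checkpoint_path="/path/to/save/" \
    --load_pretrain_checkpoint_path="/path/to/pretrain.ckpt" \
    --train_data_file_path="/path/to/train.tfrecord" \
    --eval_data_file_path="/path/to/dev.tfrecord" \
    --schema_file_path="/path/to/schema.json"
```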
 ### Parameters:
@@ -115,7 +129,7 @@ evaluation_config.py:
 Parameters for dataset and network (Pre-Training/Fine-Tuning/Evaluation):
     batch_size                      batch size of input dataset: N, default is 16
     seq_length                      length of input sequence: N, default is 128
-    vocab_size                      size of each embedding vector: N, default is 21136
+    vocab_size                      size of each embedding vector: N, must be consistent with the dataset you use. Default is 21136
     hidden_size                     size of bert encoder layers: N, default is 768
     num_hidden_layers               number of hidden layers: N, default is 12
     num_attention_heads             number of attention heads: N, default is 12
@@ -1,272 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Bert evaluation script.
"""
import os
import argparse
import math
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import log as logger
from mindspore.common.tensor import Tensor
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.evaluation_config import cfg, bert_net_cfg
from src.utils import BertNER, BertCLS, BertReg
from src.CRF import postprocess
from src.cluener_evaluation import submit
from src.finetune_config import tag_to_index
class Accuracy():
    """
    Calculate accuracy.
    """
    def __init__(self):
        self.acc_num = 0
        self.total_num = 0

    def update(self, logits, labels):
        """
        Update accuracy.
        """
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        # The predicted class is the argmax over the logits.
        logit_id = np.argmax(logits, axis=-1)
        self.acc_num += np.sum(labels == logit_id)
        self.total_num += len(labels)
        print("=========================accuracy is ", self.acc_num / self.total_num)
class F1():
    """
    Calculate the F1 score.
    """
    def __init__(self):
        self.TP = 0
        self.FP = 0
        self.FN = 0

    def update(self, logits, labels):
        """
        Update the F1 score.
        """
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        if cfg.use_crf:
            # CRF decoding returns backpointers; recover the best tag path.
            backpointers, best_tag_id = logits
            best_path = postprocess(backpointers, best_tag_id)
            logit_id = []
            for ele in best_path:
                logit_id.extend(ele)
        else:
            logits = logits.asnumpy()
            logit_id = np.argmax(logits, axis=-1)
            logit_id = np.reshape(logit_id, -1)
        # Label 0 is the negative (non-entity) class; all other labels count as positive.
        pos_eva = np.isin(logit_id, [i for i in range(1, cfg.num_labels)])
        pos_label = np.isin(labels, [i for i in range(1, cfg.num_labels)])
        self.TP += np.sum(pos_eva & pos_label)
        self.FP += np.sum(pos_eva & (~pos_label))
        self.FN += np.sum((~pos_eva) & pos_label)
class MCC():
    """
    Calculate the Matthews Correlation Coefficient.
    """
    def __init__(self):
        self.TP = 0
        self.FP = 0
        self.FN = 0
        self.TN = 0

    def update(self, logits, labels):
        """
        Update the MCC counts.
        """
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        labels = labels.astype(bool)
        logits = logits.asnumpy()
        logit_id = np.argmax(logits, axis=-1)
        logit_id = np.reshape(logit_id, -1)
        logit_id = logit_id.astype(bool)
        # ornot is True exactly where the prediction disagrees with the label.
        ornot = logit_id ^ labels
        self.TP += (~ornot & labels).sum()
        self.FP += (ornot & ~labels).sum()
        self.FN += (ornot & labels).sum()
        self.TN += (~ornot & ~labels).sum()
class Spearman_Correlation():
    """
    Calculate the Spearman correlation coefficient.
    """
    def __init__(self):
        self.label = []
        self.logit = []

    def update(self, logits, labels):
        """
        Collect labels and predictions for the Spearman correlation.
        """
        labels = labels.asnumpy()
        labels = np.reshape(labels, -1)
        logits = logits.asnumpy()
        logits = np.reshape(logits, -1)
        self.label.append(labels)
        self.logit.append(logits)

    def cal(self):
        """
        Calculate the Spearman correlation from rank differences:
        ps = 1 - 6 * sum(d_i ** 2) / (n * (n ** 2 - 1)).
        """
        label = np.concatenate(self.label)
        logit = np.concatenate(self.logit)
        sort_label = label.argsort()[::-1]
        sort_logit = logit.argsort()[::-1]
        n = len(label)
        d_acc = 0
        for i in range(n):
            d = np.where(sort_label == i)[0] - np.where(sort_logit == i)[0]
            d_acc += d**2
        ps = 1 - 6 * d_acc / n / (n**2 - 1)
        return ps
def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
    """
    Get dataset.
    """
    _ = distribute_file
    ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
                                                                            "segment_ids", "label_ids"])
    type_cast_op = C.TypeCast(mstype.int32)
    ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
    ds = ds.map(input_columns="input_mask", operations=type_cast_op)
    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
    if cfg.task == "Regression":
        type_cast_op_float = C.TypeCast(mstype.float32)
        ds = ds.map(input_columns="label_ids", operations=type_cast_op_float)
    else:
        ds = ds.map(input_columns="label_ids", operations=type_cast_op)
    ds = ds.repeat(repeat_count)
    # apply shuffle operation
    buffer_size = 960
    ds = ds.shuffle(buffer_size=buffer_size)
    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
def bert_predict(Evaluation):
    """
    Prediction function.
    """
    target = args_opt.device_target
    if target == "Ascend":
        devid = int(os.getenv('DEVICE_ID'))
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
    elif target == "GPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        if bert_net_cfg.compute_type != mstype.float32:
            logger.warning('GPU only supports fp32 temporarily, running with fp32.')
            bert_net_cfg.compute_type = mstype.float32
    else:
        raise Exception("Target error, GPU or Ascend is supported.")
    dataset = get_dataset(bert_net_cfg.batch_size, 1)
    if cfg.use_crf:
        net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True,
                                         tag_to_index=tag_to_index, dropout_prob=0.0)
    else:
        net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels)
    net_for_pretraining.set_train(False)
    param_dict = load_checkpoint(cfg.finetune_ckpt)
    load_param_into_net(net_for_pretraining, param_dict)
    model = Model(net_for_pretraining)
    return model, dataset
def test_eval():
    """
    Evaluation function.
    """
    if cfg.task == "SeqLabeling":
        task_type = BertNER
    elif cfg.task == "Regression":
        task_type = BertReg
    elif cfg.task == "Classification":
        task_type = BertCLS
    elif cfg.task == "COLA":
        task_type = BertCLS
    else:
        raise ValueError("Task not supported.")
    model, dataset = bert_predict(task_type)

    if cfg.clue_benchmark:
        submit(model, cfg.data_file, bert_net_cfg.seq_length)
    else:
        # Pick the metric that matches the task type.
        if cfg.task == "SeqLabeling":
            callback = F1()
        elif cfg.task == "COLA":
            callback = MCC()
        elif cfg.task == "Regression":
            callback = Spearman_Correlation()
        else:
            callback = Accuracy()

        columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
        for data in dataset.create_dict_iterator():
            input_data = []
            for i in columns_list:
                input_data.append(Tensor(data[i]))
            input_ids, input_mask, token_type_id, label_ids = input_data
            logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
            callback.update(logits, label_ids)

        print("==============================================================")
        if cfg.task == "SeqLabeling":
            print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
            print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
            print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN)))
        elif cfg.task == "COLA":
            TP = callback.TP
            TN = callback.TN
            FP = callback.FP
            FN = callback.FN
            # MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
            mcc = (TP * TN - FP * FN) / math.sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))
            print("MCC: {:.6f}".format(mcc))
        elif cfg.task == "Regression":
            print("Spearman Correlation is {:.6f}".format(callback.cal()[0]))
        else:
            print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
                                                                      callback.acc_num / callback.total_num))
        print("==============================================================")
parser = argparse.ArgumentParser(description='Bert eval')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
args_opt = parser.parse_args()

if __name__ == "__main__":
    num_labels = cfg.num_labels
    test_eval()
@@ -1,178 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Bert finetune script.
"""
import os
import argparse
from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCell, BertReg
from src.finetune_config import cfg, bert_net_cfg, tag_to_index
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import log as logger
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as C
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum
from mindspore.train.model import Model
from mindspore.train.callback import Callback
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.train.serialization import load_checkpoint, load_param_into_net
class LossCallBack(Callback):
    """
    Monitor the loss in training.
    If the loss is NAN or INF, terminate training.
    Note:
        If per_print_times is 0, do not print loss.
    Args:
        per_print_times (int): Print loss every given number of steps. Default: 1.
    """
    def __init__(self, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        # Append the loss of every step to a local log file.
        with open("./loss.log", "a+") as f:
            f.write("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
                                                                 str(cb_params.net_outputs)))
            f.write("\n")
def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
    """
    Get dataset.
    """
    ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
                                                                            "segment_ids", "label_ids"])
    type_cast_op = C.TypeCast(mstype.int32)
    ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
    ds = ds.map(input_columns="input_mask", operations=type_cast_op)
    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
    if cfg.task == "Regression":
        type_cast_op_float = C.TypeCast(mstype.float32)
        ds = ds.map(input_columns="label_ids", operations=type_cast_op_float)
    else:
        ds = ds.map(input_columns="label_ids", operations=type_cast_op)
    ds = ds.repeat(repeat_count)
    # apply shuffle operation
    buffer_size = 960
    ds = ds.shuffle(buffer_size=buffer_size)
    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''):
    """
    Get SQuAD dataset.
    """
    ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", "segment_ids",
                                                                            "start_positions", "end_positions",
                                                                            "unique_ids", "is_impossible"])
    type_cast_op = C.TypeCast(mstype.int32)
    ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
    ds = ds.map(input_columns="input_mask", operations=type_cast_op)
    ds = ds.map(input_columns="start_positions", operations=type_cast_op)
    ds = ds.map(input_columns="end_positions", operations=type_cast_op)
    ds = ds.repeat(repeat_count)
    buffer_size = 960
    ds = ds.shuffle(buffer_size=buffer_size)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
def test_train():
    """
    Finetune function.
    """
    target = args_opt.device_target
    if target == "Ascend":
        devid = int(os.getenv('DEVICE_ID'))
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
    elif target == "GPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        if bert_net_cfg.compute_type != mstype.float32:
            logger.warning('GPU only supports fp32 temporarily, running with fp32.')
            bert_net_cfg.compute_type = mstype.float32
    else:
        raise Exception("Target error, GPU or Ascend is supported.")
    # BertCLSTrain for classification
    # BertNERTrain for sequence labeling
    if cfg.task == 'SeqLabeling':
        if cfg.use_crf:
            netwithloss = BertNER(bert_net_cfg, True, num_labels=len(tag_to_index), use_crf=True,
                                  tag_to_index=tag_to_index, dropout_prob=0.1)
        else:
            netwithloss = BertNER(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
    elif cfg.task == 'SQUAD':
        netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
    elif cfg.task == 'Regression':
        netwithloss = BertReg(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
    elif cfg.task == 'Classification':
        netwithloss = BertCLS(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
    else:
        raise Exception("Task not supported.")
    if cfg.task == 'SQUAD':
        dataset = get_squad_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
    else:
        dataset = get_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
    # optimizer
    steps_per_epoch = dataset.get_dataset_size()
    if cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(),
                                             decay_steps=steps_per_epoch * cfg.epoch_num,
                                             learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
                                             end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
                                             power=cfg.AdamWeightDecayDynamicLR.power,
                                             warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1),
                                             weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
                                             eps=cfg.AdamWeightDecayDynamicLR.eps)
    elif cfg.optimizer == 'Lamb':
        optimizer = Lamb(netwithloss.trainable_params(), decay_steps=steps_per_epoch * cfg.epoch_num,
                         start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
                         power=cfg.Lamb.power, weight_decay=cfg.Lamb.weight_decay,
                         warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1), decay_filter=cfg.Lamb.decay_filter)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported.")
    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix=cfg.ckpt_prefix, directory=cfg.ckpt_dir, config=ckpt_config)
    param_dict = load_checkpoint(cfg.pre_training_ckpt)
    load_param_into_net(netwithloss, param_dict)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    if cfg.task == 'SQUAD':
        netwithgrads = BertSquadCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    model.train(cfg.epoch_num, dataset, callbacks=[LossCallBack(), ckpoint_cb])


parser = argparse.ArgumentParser(description='Bert finetune')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
args_opt = parser.parse_args()

if __name__ == "__main__":
    test_train()
@@ -0,0 +1,201 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

'''
Bert finetune and evaluation script.
'''
import os
import argparse
from src.bert_for_finetune import BertFinetuneCell, BertCLS
from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
from src.dataset import create_classification_dataset
from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation
from src.utils import make_directory, LossCallBack, LoadNewestCkpt
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import log as logger
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

_cur_dir = os.getcwd()
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
    """ do train """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missing: the finetune task must load a pretrained model!")
    steps_per_epoch = dataset.get_dataset_size()
    epoch_num = dataset.get_repeat_count()
    # optimizer
    if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
                                             decay_steps=steps_per_epoch * epoch_num,
                                             learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate,
                                             end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate,
                                             power=optimizer_cfg.AdamWeightDecayDynamicLR.power,
                                             warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                             weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay,
                                             eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps)
    elif optimizer_cfg.optimizer == 'Lamb':
        optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num,
                         start_learning_rate=optimizer_cfg.Lamb.start_learning_rate,
                         end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
                         power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay,
                         warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                         decay_filter=optimizer_cfg.Lamb.decay_filter)
    elif optimizer_cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate,
                             momentum=optimizer_cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]")
    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix="classifier", directory=save_checkpoint_path, config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, param_dict)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(), ckpoint_cb]
    model.train(epoch_num, dataset, callbacks=callbacks)
def eval_result_print(assessment_method="accuracy", callback=None):
    """ print eval result """
    if assessment_method == "accuracy":
        print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
                                                                  callback.acc_num / callback.total_num))
    elif assessment_method == "f1":
        print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
        print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
        print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN)))
    elif assessment_method == "mcc":
        print("MCC {:.6f} ".format(callback.cal()))
    elif assessment_method == "spearman_correlation":
        print("Spearman Correlation is {:.6f} ".format(callback.cal()[0]))
    else:
        raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]")
def do_eval(dataset=None, network=None, num_class=2, assessment_method="accuracy", load_checkpoint_path=""):
    """ do eval """
    if load_checkpoint_path == "":
        raise ValueError("Finetune model missing: the evaluation task must load a finetuned model!")
    net_for_pretraining = network(bert_net_cfg, False, num_class)
    net_for_pretraining.set_train(False)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(net_for_pretraining, param_dict)
    model = Model(net_for_pretraining)
    if assessment_method == "accuracy":
        callback = Accuracy()
    elif assessment_method == "f1":
        callback = F1(False, num_class)
    elif assessment_method == "mcc":
        callback = MCC()
    elif assessment_method == "spearman_correlation":
        callback = Spearman_Correlation()
    else:
        raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]")
    columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
    for data in dataset.create_dict_iterator():
        input_data = []
        for i in columns_list:
            input_data.append(Tensor(data[i]))
        input_ids, input_mask, token_type_id, label_ids = input_data
        logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
        callback.update(logits, label_ids)
    print("==============================================================")
    eval_result_print(assessment_method, callback)
    print("==============================================================")
def run_classifier():
    """run classifier task"""
    parser = argparse.ArgumentParser(description="run classifier")
    parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend")
    parser.add_argument("--assessment_method", type=str, default="accuracy", help="assessment_method include: "
                                                                                  "[Accuracy, F1, MCC, "
                                                                                  "Spearman_correlation], "
                                                                                  "default is accuracy")
    parser.add_argument("--do_train", type=str, default="false", help="Enable train, default is false")
    parser.add_argument("--do_eval", type=str, default="false", help="Enable eval, default is false")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--epoch_num", type=int, default=1, help="Epoch number, default is 1.")
    parser.add_argument("--num_class", type=int, default=2, help="The number of classes, default is 2.")
    parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
    parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--train_data_file_path", type=str, default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--eval_data_file_path", type=str, default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_file_path", type=str, default="",
                        help="Schema path, it is better to use absolute path")
    args_opt = parser.parse_args()
    epoch_num = args_opt.epoch_num
    assessment_method = args_opt.assessment_method.lower()
    load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path
    save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path
    load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path
    if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
        raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
    if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
        raise ValueError("'train_data_file_path' must be set when doing the finetune task")
    if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
        raise ValueError("'eval_data_file_path' must be set when doing the evaluation task")
    target = args_opt.device_target
    if target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    elif target == "GPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        if bert_net_cfg.compute_type != mstype.float32:
            logger.warning('GPU only supports fp32 temporarily, running with fp32.')
            bert_net_cfg.compute_type = mstype.float32
    else:
        raise Exception("Target error, GPU or Ascend is supported.")
    netwithloss = BertCLS(bert_net_cfg, True, num_labels=args_opt.num_class, dropout_prob=0.1,
                          assessment_method=assessment_method)

    if args_opt.do_train.lower() == "true":
        ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
                                           assessment_method=assessment_method,
                                           data_file_path=args_opt.train_data_file_path,
                                           schema_file_path=args_opt.schema_file_path)
        do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path)

        # If evaluation follows training, pick up the newest checkpoint written during training.
        if args_opt.do_eval.lower() == "true":
            if save_finetune_checkpoint_path == "":
                load_finetune_checkpoint_dir = _cur_dir
            else:
                load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path)
            load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir,
                                                           ds.get_dataset_size(), epoch_num, "classifier")

    if args_opt.do_eval.lower() == "true":
        ds = create_classification_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
                                           assessment_method=assessment_method,
                                           data_file_path=args_opt.eval_data_file_path,
                                           schema_file_path=args_opt.schema_file_path)
        do_eval(ds, BertCLS, args_opt.num_class, assessment_method, load_finetune_checkpoint_path)


if __name__ == "__main__":
    run_classifier()
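A hypothetical end-to-end invocation of this classifier entry point, with placeholder values only, could look like:

```bash
# Placeholder values; num_class must match the downstream dataset.
python run_classifier.py \
    --device_target="Ascend" \
    --do_train="true" \
    --do_eval="true" \
    --assessment_method="accuracy" \
    --device_id=0 \
    --epoch_num=3 \
    --num_class=15 \
    --save_finetune_checkpoint_path="/path/to/save/" \
    --load_pretrain_checkpoint_path="/path/to/pretrain.ckpt" \
    --train_data_file_path="/path/to/train.tfrecord" \
    --eval_data_file_path="/path/to/dev.tfrecord" \
    --schema_file_path="/path/to/schema.json"
```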
@@ -0,0 +1,228 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

'''
Bert finetune and evaluation script.
'''
import os
import json
import argparse
from src.bert_for_finetune import BertFinetuneCell, BertNER
from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
from src.dataset import create_ner_dataset
from src.utils import make_directory, LossCallBack, LoadNewestCkpt
from src.assessment_method import Accuracy, F1, MCC, Spearman_Correlation
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore import log as logger
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum
from mindspore.common.tensor import Tensor
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

_cur_dir = os.getcwd()
def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""):
    """ do train """
    if load_checkpoint_path == "":
        raise ValueError("Pretrain model missing: the finetune task must load a pretrained model!")
    steps_per_epoch = dataset.get_dataset_size()
    epoch_num = dataset.get_repeat_count()
    # optimizer
    if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
        optimizer = AdamWeightDecayDynamicLR(network.trainable_params(),
                                             decay_steps=steps_per_epoch * epoch_num,
                                             learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate,
                                             end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate,
                                             power=optimizer_cfg.AdamWeightDecayDynamicLR.power,
                                             warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                             weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay,
                                             eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps)
    elif optimizer_cfg.optimizer == 'Lamb':
        optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num,
                         start_learning_rate=optimizer_cfg.Lamb.start_learning_rate,
                         end_learning_rate=optimizer_cfg.Lamb.end_learning_rate,
                         power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay,
                         warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                         decay_filter=optimizer_cfg.Lamb.decay_filter)
    elif optimizer_cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate,
                             momentum=optimizer_cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]")
    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    ckpoint_cb = ModelCheckpoint(prefix="ner", directory=save_checkpoint_path, config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(network, param_dict)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
    netwithgrads = BertFinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(), ckpoint_cb]
    model.train(epoch_num, dataset, callbacks=callbacks)
def eval_result_print(assessment_method="accuracy", callback=None):
    """print eval result"""
    if assessment_method == "accuracy":
        print("acc_num {} , total_num {}, accuracy {:.6f}".format(callback.acc_num, callback.total_num,
                                                                  callback.acc_num / callback.total_num))
    elif assessment_method == "f1":
        print("Precision {:.6f} ".format(callback.TP / (callback.TP + callback.FP)))
        print("Recall {:.6f} ".format(callback.TP / (callback.TP + callback.FN)))
        print("F1 {:.6f} ".format(2 * callback.TP / (2 * callback.TP + callback.FP + callback.FN)))
    elif assessment_method == "mcc":
        print("MCC {:.6f} ".format(callback.cal()))
    elif assessment_method == "spearman_correlation":
        print("Spearman Correlation is {:.6f} ".format(callback.cal()[0]))
    else:
        raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]")
def do_eval(dataset=None, network=None, use_crf="", num_class=2, assessment_method="accuracy", data_file="",
            load_checkpoint_path="", vocab_file="", label2id_file="", tag_to_index=None):
    """ do eval """
    if load_checkpoint_path == "":
        raise ValueError("Finetune model missing: the evaluation task must load a finetuned model!")
    if assessment_method == "clue_benchmark":
        # The CLUE benchmark submission is generated sample by sample.
        bert_net_cfg.batch_size = 1
    net_for_pretraining = network(bert_net_cfg, False, num_class, use_crf=(use_crf.lower() == "true"),
                                  tag_to_index=tag_to_index)
    net_for_pretraining.set_train(False)
    param_dict = load_checkpoint(load_checkpoint_path)
    load_param_into_net(net_for_pretraining, param_dict)
    model = Model(net_for_pretraining)
    if assessment_method == "clue_benchmark":
        from src.cluener_evaluation import submit
        submit(model=model, path=data_file, vocab_file=vocab_file, use_crf=use_crf, label2id_file=label2id_file)
    else:
        if assessment_method == "accuracy":
            callback = Accuracy()
        elif assessment_method == "f1":
            callback = F1((use_crf.lower() == "true"), num_class)
        elif assessment_method == "mcc":
            callback = MCC()
        elif assessment_method == "spearman_correlation":
            callback = Spearman_Correlation()
        else:
            raise ValueError("Assessment method not supported, support: [accuracy, f1, mcc, spearman_correlation]")
        columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
        for data in dataset.create_dict_iterator():
            input_data = []
            for i in columns_list:
                input_data.append(Tensor(data[i]))
            input_ids, input_mask, token_type_id, label_ids = input_data
            logits = model.predict(input_ids, input_mask, token_type_id, label_ids)
            callback.update(logits, label_ids)
        print("==============================================================")
        eval_result_print(assessment_method, callback)
        print("==============================================================")
def run_ner():
    """run ner task"""
    parser = argparse.ArgumentParser(description="run ner")
    parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend")
    parser.add_argument("--assessment_method", type=str, default="accuracy", help="assessment_method include: "
                                                                                  "[F1, clue_benchmark], "
                                                                                  "default is accuracy")
    parser.add_argument("--do_train", type=str, default="false", help="Enable train, default is false")
    parser.add_argument("--do_eval", type=str, default="false", help="Enable eval, default is false")
    parser.add_argument("--use_crf", type=str, default="false", help="Use crf, default is false")
    parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--epoch_num", type=int, default=1, help="Epoch number, default is 1.")
    parser.add_argument("--num_class", type=int, default=2, help="The number of classes, default is 2.")
    parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path, used in clue benchmark")
    parser.add_argument("--label2id_file_path", type=str, default="", help="label2id file path, used in clue benchmark")
    parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path")
    parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path")
    parser.add_argument("--train_data_file_path", type=str, default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--eval_data_file_path", type=str, default="",
                        help="Data path, it is better to use absolute path")
    parser.add_argument("--schema_file_path", type=str, default="",
                        help="Schema path, it is better to use absolute path")
    args_opt = parser.parse_args()
    epoch_num = args_opt.epoch_num
    assessment_method = args_opt.assessment_method.lower()
    load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path
    save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path
    load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path
    if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
        raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
    if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
        raise ValueError("'train_data_file_path' must be set when doing the finetune task")
    if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
        raise ValueError("'eval_data_file_path' must be set when doing the evaluation task")
    if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.vocab_file_path == "":
        raise ValueError("'vocab_file_path' must be set to do the clue benchmark")
    if args_opt.use_crf.lower() == "true" and args_opt.label2id_file_path == "":
        raise ValueError("'label2id_file_path' must be set to use crf")
    if args_opt.assessment_method.lower() == "clue_benchmark" and args_opt.label2id_file_path == "":
        raise ValueError("'label2id_file_path' must be set to do the clue benchmark")
    target = args_opt.device_target
    if target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
    elif target == "GPU":
        context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
        if bert_net_cfg.compute_type != mstype.float32:
            logger.warning('GPU only supports fp32 temporarily, running with fp32.')
            bert_net_cfg.compute_type = mstype.float32
    else:
        raise Exception("Target error, GPU or Ascend is supported.")
    tag_to_index = None
    if args_opt.use_crf.lower() == "true":
        with open(args_opt.label2id_file_path) as json_file:
            tag_to_index = json.load(json_file)
        max_val = max(tag_to_index.values())
        # Add the special transition tags required by the CRF.
        tag_to_index["<START>"] = max_val + 1
        tag_to_index["<STOP>"] = max_val + 2
        number_labels = len(tag_to_index)
    else:
        number_labels = args_opt.num_class
    netwithloss = BertNER(bert_net_cfg, True, num_labels=number_labels,
                          use_crf=(args_opt.use_crf.lower() == "true"),
                          tag_to_index=tag_to_index, dropout_prob=0.1)

    if args_opt.do_train.lower() == "true":
        ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
                                assessment_method=assessment_method, data_file_path=args_opt.train_data_file_path,
                                schema_file_path=args_opt.schema_file_path)
        do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path)

        # If evaluation follows training, pick up the newest checkpoint written during training.
        if args_opt.do_eval.lower() == "true":
            if save_finetune_checkpoint_path == "":
                load_finetune_checkpoint_dir = _cur_dir
            else:
                load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path)
            load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir,
                                                           ds.get_dataset_size(), epoch_num, "ner")

    if args_opt.do_eval.lower() == "true":
        ds = create_ner_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num,
                                assessment_method=assessment_method, data_file_path=args_opt.eval_data_file_path,
                                schema_file_path=args_opt.schema_file_path)
        do_eval(ds, BertNER, args_opt.use_crf, number_labels, assessment_method, args_opt.eval_data_file_path,
                load_finetune_checkpoint_path, args_opt.vocab_file_path, args_opt.label2id_file_path, tag_to_index)


if __name__ == "__main__":
    run_ner()
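For an evaluation-only CLUE benchmark run with CRF decoding, a hypothetical command might be the following; note that with clue_benchmark the eval input is the CLUENER json file rather than a tfrecord:

```bash
# Placeholder values; evaluation-only, so no training data or pretrain checkpoint.
python run_ner.py \
    --device_target="Ascend" \
    --do_train="false" \
    --do_eval="true" \
    --assessment_method="clue_benchmark" \
    --use_crf="true" \
    --num_class=41 \
    --vocab_file_path="/path/to/vocab.txt" \
    --label2id_file_path="/path/to/label2id.json" \
    --load_finetune_checkpoint_path="/path/to/ner.ckpt" \
    --eval_data_file_path="/path/to/cluener_dev.json" \
    --schema_file_path="/path/to/schema.json"
```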
@@ -26,33 +26,16 @@ from mindspore import context
 from mindspore.train.model import Model
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
-from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig, TimeMonitor
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
 from mindspore import log as logger
 from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
 from src.dataset import create_bert_dataset
 from src.config import cfg, bert_net_cfg
+from src.utils import LossCallBack
 _current_dir = os.path.dirname(os.path.realpath(__file__))
-class LossCallBack(Callback):
-    """
-    Monitor the loss in training.
-    If the loss is NAN or INF, terminate training.
-    Note:
-        If per_print_times is 0, do not print loss.
-    Args:
-        per_print_times (int): Print loss every times. Default: 1.
-    """
-    def __init__(self, per_print_times=1):
-        super(LossCallBack, self).__init__()
-        if not isinstance(per_print_times, int) or per_print_times < 0:
-            raise ValueError("print_step must be int and >= 0")
-        self._per_print_times = per_print_times
-    def step_end(self, run_context):
-        cb_params = run_context.original_args()
-        print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
-                                                           str(cb_params.net_outputs)))
 def run_pretrain():
     """pre-train bert_clue"""
| @@ -0,0 +1,204 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| ''' | |||||
| Bert finetune and evaluation script. | |||||
| ''' | |||||
| import os | |||||
| import argparse | |||||
| import collections | |||||
| from src.bert_for_finetune import BertSquadCell, BertSquad | |||||
| from src.finetune_eval_config import optimizer_cfg, bert_net_cfg | |||||
| from src.dataset import create_squad_dataset | |||||
| from src import tokenization | |||||
| from src.create_squad_data import read_squad_examples, convert_examples_to_features | |||||
| from src.run_squad import write_predictions | |||||
| from src.utils import make_directory, LossCallBack, LoadNewestCkpt | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore import context | |||||
| from mindspore import log as logger | |||||
| from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell | |||||
| from mindspore.nn.optim import AdamWeightDecayDynamicLR, Lamb, Momentum | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| _cur_dir = os.getcwd() | |||||
| def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path=""): | |||||
| """ do train """ | |||||
| if load_checkpoint_path == "": | |||||
| raise ValueError("Pretrain model missed, finetune task must load pretrain model!") | |||||
| steps_per_epoch = dataset.get_dataset_size() | |||||
| epoch_num = dataset.get_repeat_count() | |||||
| # optimizer | |||||
| if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR': | |||||
| optimizer = AdamWeightDecayDynamicLR(network.trainable_params(), | |||||
| decay_steps=steps_per_epoch * epoch_num, | |||||
| learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.learning_rate, | |||||
| end_learning_rate=optimizer_cfg.AdamWeightDecayDynamicLR.end_learning_rate, | |||||
| power=optimizer_cfg.AdamWeightDecayDynamicLR.power, | |||||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||||
| weight_decay=optimizer_cfg.AdamWeightDecayDynamicLR.weight_decay, | |||||
| eps=optimizer_cfg.AdamWeightDecayDynamicLR.eps) | |||||
| elif optimizer_cfg.optimizer == 'Lamb': | |||||
| optimizer = Lamb(network.trainable_params(), decay_steps=steps_per_epoch * epoch_num, | |||||
| start_learning_rate=optimizer_cfg.Lamb.start_learning_rate, | |||||
| end_learning_rate=optimizer_cfg.Lamb.end_learning_rate, | |||||
| power=optimizer_cfg.Lamb.power, weight_decay=optimizer_cfg.Lamb.weight_decay, | |||||
| warmup_steps=int(steps_per_epoch * epoch_num * 0.1), | |||||
| decay_filter=optimizer_cfg.Lamb.decay_filter) | |||||
| elif optimizer_cfg.optimizer == 'Momentum': | |||||
| optimizer = Momentum(network.trainable_params(), learning_rate=optimizer_cfg.Momentum.learning_rate, | |||||
| momentum=optimizer_cfg.Momentum.momentum) | |||||
| else: | |||||
| raise Exception("Optimizer not supported. support: [AdamWeightDecayDynamicLR, Lamb, Momentum]") | |||||
| # load checkpoint into network | |||||
| ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1) | |||||
| ckpoint_cb = ModelCheckpoint(prefix="squad", directory=save_checkpoint_path, config=ckpt_config) | |||||
| param_dict = load_checkpoint(load_checkpoint_path) | |||||
| load_param_into_net(network, param_dict) | |||||
| update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000) | |||||
| netwithgrads = BertSquadCell(network, optimizer=optimizer, scale_update_cell=update_cell) | |||||
| model = Model(netwithgrads) | |||||
| callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack(), ckpoint_cb] | |||||
| model.train(epoch_num, dataset, callbacks=callbacks) | |||||
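All three optimizer branches derive their schedule from the dataset: decay runs for `steps_per_epoch * epoch_num` steps and warmup covers the first 10% of them (Momentum uses a fixed learning rate instead). A worked example of that arithmetic, with made-up sizes:

```python
steps_per_epoch = 1000                      # dataset.get_dataset_size()
epoch_num = 3                               # dataset.get_repeat_count()
decay_steps = steps_per_epoch * epoch_num   # 3000 total training steps
warmup_steps = int(decay_steps * 0.1)       # 300 warmup steps
```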
| def do_eval(dataset=None, vocab_file="", eval_json="", load_checkpoint_path="", seq_length=384): | |||||
| """ do eval """ | |||||
| if load_checkpoint_path == "": | |||||
| raise ValueError("Finetune model missed, evaluation task must load finetune model!") | |||||
| tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True) | |||||
| eval_examples = read_squad_examples(eval_json, False) | |||||
| eval_features = convert_examples_to_features( | |||||
| examples=eval_examples, | |||||
| tokenizer=tokenizer, | |||||
| max_seq_length=seq_length, | |||||
| doc_stride=128, | |||||
| max_query_length=64, | |||||
| is_training=False, | |||||
| output_fn=None, | |||||
| verbose_logging=False) | |||||
| net = BertSquad(bert_net_cfg, False, 2) | |||||
| net.set_train(False) | |||||
| param_dict = load_checkpoint(load_checkpoint_path) | |||||
| load_param_into_net(net, param_dict) | |||||
| model = Model(net) | |||||
| output = [] | |||||
| RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"]) | |||||
| columns_list = ["input_ids", "input_mask", "segment_ids", "unique_ids"] | |||||
| for data in dataset.create_dict_iterator(): | |||||
| input_data = [] | |||||
| for i in columns_list: | |||||
| input_data.append(Tensor(data[i])) | |||||
| input_ids, input_mask, segment_ids, unique_ids = input_data | |||||
| start_positions = Tensor([1], mstype.float32) | |||||
| end_positions = Tensor([1], mstype.float32) | |||||
| is_impossible = Tensor([1], mstype.float32) | |||||
| logits = model.predict(input_ids, input_mask, segment_ids, start_positions, | |||||
| end_positions, unique_ids, is_impossible) | |||||
| ids = logits[0].asnumpy() | |||||
| start = logits[1].asnumpy() | |||||
| end = logits[2].asnumpy() | |||||
| for i in range(bert_net_cfg.batch_size): | |||||
| unique_id = int(ids[i]) | |||||
| start_logits = [float(x) for x in start[i].flat] | |||||
| end_logits = [float(x) for x in end[i].flat] | |||||
| output.append(RawResult( | |||||
| unique_id=unique_id, | |||||
| start_logits=start_logits, | |||||
| end_logits=end_logits)) | |||||
| write_predictions(eval_examples, eval_features, output, 20, 30, True, "./predictions.json", None, None) | |||||
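`do_eval` stops at writing `predictions.json`; the SQuAD score comes from comparing those predictions against the reference answers. Below is a simplified exact-match scorer, assuming the standard SQuAD v1.1 JSON layout (the official evaluate script also reports token-level F1 and normalizes answers more carefully):

```python
import json
import re
import string

def normalize(text):
    """Lower-case, drop punctuation and articles, collapse whitespace (simplified)."""
    text = "".join(ch for ch in text.lower() if ch not in set(string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(eval_json, pred_json="./predictions.json"):
    with open(pred_json) as f:
        predictions = json.load(f)
    with open(eval_json) as f:
        articles = json.load(f)["data"]
    match, total = 0, 0
    for article in articles:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                total += 1
                golds = [normalize(answer["text"]) for answer in qa["answers"]]
                match += normalize(predictions.get(qa["id"], "")) in golds
    return 100.0 * match / total
```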
| def run_squad(): | |||||
| """run squad task""" | |||||
| parser = argparse.ArgumentParser(description="run classifier") | |||||
| parser.add_argument("--device_target", type=str, default="Ascend", help="Device type, default is Ascend") | |||||
| parser.add_argument("--do_train", type=str, default="false", help="Eable train, default is false") | |||||
| parser.add_argument("--do_eval", type=str, default="false", help="Eable eval, default is false") | |||||
| parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") | |||||
| parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.") | |||||
| parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.") | |||||
| parser.add_argument("--vocab_file_path", type=str, default="", help="Vocab file path") | |||||
| parser.add_argument("--eval_json_path", type=str, default="", help="Evaluation json file path, can be eval.json") | |||||
| parser.add_argument("--save_finetune_checkpoint_path", type=str, default="", help="Save checkpoint path") | |||||
| parser.add_argument("--load_pretrain_checkpoint_path", type=str, default="", help="Load checkpoint file path") | |||||
| parser.add_argument("--load_finetune_checkpoint_path", type=str, default="", help="Load checkpoint file path") | |||||
| parser.add_argument("--train_data_file_path", type=str, default="", | |||||
| help="Data path, it is better to use absolute path") | |||||
| parser.add_argument("--eval_data_file_path", type=str, default="", | |||||
| help="Data path, it is better to use absolute path") | |||||
| parser.add_argument("--schema_file_path", type=str, default="", | |||||
| help="Schema path, it is better to use absolute path") | |||||
| args_opt = parser.parse_args() | |||||
| epoch_num = args_opt.epoch_num | |||||
| load_pretrain_checkpoint_path = args_opt.load_pretrain_checkpoint_path | |||||
| save_finetune_checkpoint_path = args_opt.save_finetune_checkpoint_path | |||||
| load_finetune_checkpoint_path = args_opt.load_finetune_checkpoint_path | |||||
| if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false": | |||||
| raise ValueError("At least one of 'do_train' or 'do_eval' must be true") | |||||
| if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "": | |||||
| raise ValueError("'train_data_file_path' must be set when do finetune task") | |||||
| if args_opt.do_eval.lower() == "true": | |||||
| if args_opt.eval_data_file_path == "": | |||||
| raise ValueError("'eval_data_file_path' must be set when do evaluation task") | |||||
| if args_opt.vocab_file_path == "": | |||||
| raise ValueError("'vocab_file_path' must be set when do evaluation task") | |||||
| if args_opt.eval_json_path == "": | |||||
| raise ValueError("'tokenization_file_path' must be set when do evaluation task") | |||||
| target = args_opt.device_target | |||||
| if target == "Ascend": | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) | |||||
| elif target == "GPU": | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||||
| if bert_net_cfg.compute_type != mstype.float32: | |||||
| logger.warning('GPU only supports fp32 for now; running with fp32.') | |||||
| bert_net_cfg.compute_type = mstype.float32 | |||||
| else: | |||||
| raise Exception("Target error, GPU or Ascend is supported.") | |||||
| netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1) | |||||
| if args_opt.do_train.lower() == "true": | |||||
| ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, | |||||
| data_file_path=args_opt.train_data_file_path, | |||||
| schema_file_path=args_opt.schema_file_path) | |||||
| do_train(ds, netwithloss, load_pretrain_checkpoint_path, save_finetune_checkpoint_path) | |||||
| if args_opt.do_eval.lower() == "true": | |||||
| if save_finetune_checkpoint_path == "": | |||||
| load_finetune_checkpoint_dir = _cur_dir | |||||
| else: | |||||
| load_finetune_checkpoint_dir = make_directory(save_finetune_checkpoint_path) | |||||
| # NOTE: ds is the training dataset created in the do_train branch above | |||||
| load_finetune_checkpoint_path = LoadNewestCkpt(load_finetune_checkpoint_dir, | |||||
| ds.get_dataset_size(), epoch_num, "squad") | |||||
| ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=epoch_num, | |||||
| data_file_path=args_opt.eval_data_file_path, | |||||
| schema_file_path=args_opt.schema_file_path, is_training=False) | |||||
| do_eval(ds, args_opt.vocab_file_path, args_opt.eval_json_path, | |||||
| load_finetune_checkpoint_path, bert_net_cfg.seq_length) | |||||
| if __name__ == "__main__": | |||||
| run_squad() | |||||
| @@ -0,0 +1,42 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| echo "==============================================================================================================" | |||||
| echo "Please run the scipt as: " | |||||
| echo "bash scripts/run_classifier.sh" | |||||
| echo "for example: bash scripts/run_classifier.sh" | |||||
| echo "assessment_method include: [MCC, Spearman_correlation ,Accuracy]" | |||||
| echo "==============================================================================================================" | |||||
| mkdir -p ms_log | |||||
| CUR_DIR=`pwd` | |||||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||||
| export GLOG_logtostderr=0 | |||||
| python ${PROJECT_DIR}/../run_classifier.py \ | |||||
| --device_target="Ascend" \ | |||||
| --do_train="true" \ | |||||
| --do_eval="false" \ | |||||
| --assessment_method="Accuracy" \ | |||||
| --device_id=0 \ | |||||
| --epoch_num=1 \ | |||||
| --num_class=2 \ | |||||
| --save_finetune_checkpoint_path="" \ | |||||
| --load_pretrain_checkpoint_path="" \ | |||||
| --load_finetune_checkpoint_path="" \ | |||||
| --train_data_file_path="" \ | |||||
| --eval_data_file_path="" \ | |||||
| --schema_file_path="" > log.txt 2>&1 & | |||||
| @@ -24,8 +24,7 @@ echo "========================================================================== | |||||
| EPOCH_SIZE=$2 | EPOCH_SIZE=$2 | ||||
| DATA_DIR=$3 | DATA_DIR=$3 | ||||
| SCHEMA_DIR=$4 | SCHEMA_DIR=$4 | ||||
| export MINDSPORE_HCCL_CONFIG_PATH=$5 | |||||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||||
| export RANK_TABLE_FILE=$5 | export RANK_TABLE_FILE=$5 | ||||
| export RANK_SIZE=$1 | export RANK_SIZE=$1 | ||||
| cores=`cat /proc/cpuinfo|grep "processor" |wc -l` | cores=`cat /proc/cpuinfo|grep "processor" |wc -l` | ||||
| @@ -54,7 +53,7 @@ do | |||||
| export GLOG_log_dir=${CUR_DIR}/ms_log | export GLOG_log_dir=${CUR_DIR}/ms_log | ||||
| export GLOG_logtostderr=0 | export GLOG_logtostderr=0 | ||||
| env > env.log | env > env.log | ||||
| taskset -c $cmdopt python ../run_pretrain.py \ | |||||
| taskset -c $cmdopt python ${PROJECT_DIR}/../run_pretrain.py \ | |||||
| --distribute="true" \ | --distribute="true" \ | ||||
| --epoch_size=$EPOCH_SIZE \ | --epoch_size=$EPOCH_SIZE \ | ||||
| --device_id=$DEVICE_ID \ | --device_id=$DEVICE_ID \ | ||||
| @@ -0,0 +1,45 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| echo "==============================================================================================================" | |||||
| echo "Please run the scipt as: " | |||||
| echo "bash scripts/run_ner.sh" | |||||
| echo "for example: bash scripts/run_ner.sh" | |||||
| echo "assessment_method include: [F1, clue_benchmark]" | |||||
| echo "==============================================================================================================" | |||||
| mkdir -p ms_log | |||||
| CUR_DIR=`pwd` | |||||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||||
| export GLOG_logtostderr=0 | |||||
| python ${PROJECT_DIR}/../run_ner.py \ | |||||
| --device_target="Ascend" \ | |||||
| --do_train="true" \ | |||||
| --do_eval="false" \ | |||||
| --assessment_method="F1" \ | |||||
| --use_crf="false" \ | |||||
| --device_id=0 \ | |||||
| --epoch_num=1 \ | |||||
| --num_class=2 \ | |||||
| --vocab_file_path="" \ | |||||
| --label2id_file_path="" \ | |||||
| --save_finetune_checkpoint_path="" \ | |||||
| --load_pretrain_checkpoint_path="" \ | |||||
| --load_finetune_checkpoint_path="" \ | |||||
| --train_data_file_path="" \ | |||||
| --eval_data_file_path="" \ | |||||
| --schema_file_path="" > log.txt 2>&1 & | |||||
| @@ -0,0 +1,43 @@ | |||||
| #!/bin/bash | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| echo "==============================================================================================================" | |||||
| echo "Please run the scipt as: " | |||||
| echo "bash scripts/run_squad.sh" | |||||
| echo "for example: bash scripts/run_squad.sh" | |||||
| echo "assessment_method include: [Accuracy]" | |||||
| echo "==============================================================================================================" | |||||
| mkdir -p ms_log | |||||
| CUR_DIR=`pwd` | |||||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||||
| export GLOG_log_dir=${CUR_DIR}/ms_log | |||||
| export GLOG_logtostderr=0 | |||||
| python ${PROJECT_DIR}/../run_squad.py \ | |||||
| --device_target="Ascend" \ | |||||
| --do_train="true" \ | |||||
| --do_eval="false" \ | |||||
| --device_id=0 \ | |||||
| --epoch_num=1 \ | |||||
| --num_class=2 \ | |||||
| --vocab_file_path="" \ | |||||
| --eval_json_path="" \ | |||||
| --save_finetune_checkpoint_path="" \ | |||||
| --load_pretrain_checkpoint_path="" \ | |||||
| --load_finetune_checkpoint_path="" \ | |||||
| --train_data_file_path="" \ | |||||
| --eval_data_file_path="" \ | |||||
| --schema_file_path="" > log.txt 2>&1 & | |||||
| @@ -26,10 +26,11 @@ DATA_DIR=$3 | |||||
| SCHEMA_DIR=$4 | SCHEMA_DIR=$4 | ||||
| mkdir -p ms_log | mkdir -p ms_log | ||||
| PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd) | |||||
| CUR_DIR=`pwd` | CUR_DIR=`pwd` | ||||
| export GLOG_log_dir=${CUR_DIR}/ms_log | export GLOG_log_dir=${CUR_DIR}/ms_log | ||||
| export GLOG_logtostderr=0 | export GLOG_logtostderr=0 | ||||
| python run_pretrain.py \ | |||||
| python ${PROJECT_DIR}/../run_pretrain.py \ | |||||
| --distribute="false" \ | --distribute="false" \ | ||||
| --epoch_size=$EPOCH_SIZE \ | --epoch_size=$EPOCH_SIZE \ | ||||
| --device_id=$DEVICE_ID \ | --device_id=$DEVICE_ID \ | ||||
| @@ -1,99 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """Evaluation script for SQuAD task""" | |||||
| import os | |||||
| import collections | |||||
| import mindspore.dataset as de | |||||
| import mindspore.dataset.transforms.c_transforms as C | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore import context | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore.train.model import Model | |||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||||
| from src import tokenization | |||||
| from src.evaluation_config import cfg, bert_net_cfg | |||||
| from src.utils import BertSquad | |||||
| from src.create_squad_data import read_squad_examples, convert_examples_to_features | |||||
| from src.run_squad import write_predictions | |||||
| def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''): | |||||
| """get SQuAD dataset from tfrecord""" | |||||
| ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", | |||||
| "segment_ids", "unique_ids"], | |||||
| shuffle=False) | |||||
| type_cast_op = C.TypeCast(mstype.int32) | |||||
| ds = ds.map(input_columns="segment_ids", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="input_ids", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="input_mask", operations=type_cast_op) | |||||
| ds = ds.repeat(repeat_count) | |||||
| ds = ds.batch(batch_size, drop_remainder=True) | |||||
| return ds | |||||
| def test_eval(): | |||||
| """Evaluation function for SQuAD task""" | |||||
| tokenizer = tokenization.FullTokenizer(vocab_file="./vocab.txt", do_lower_case=True) | |||||
| input_file = "dataset/v1.1/dev-v1.1.json" | |||||
| eval_examples = read_squad_examples(input_file, False) | |||||
| eval_features = convert_examples_to_features( | |||||
| examples=eval_examples, | |||||
| tokenizer=tokenizer, | |||||
| max_seq_length=384, | |||||
| doc_stride=128, | |||||
| max_query_length=64, | |||||
| is_training=False, | |||||
| output_fn=None, | |||||
| verbose_logging=False) | |||||
| device_id = int(os.getenv('DEVICE_ID')) | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=device_id) | |||||
| dataset = get_squad_dataset(bert_net_cfg.batch_size, 1) | |||||
| net = BertSquad(bert_net_cfg, False, 2) | |||||
| net.set_train(False) | |||||
| param_dict = load_checkpoint(cfg.finetune_ckpt) | |||||
| load_param_into_net(net, param_dict) | |||||
| model = Model(net) | |||||
| output = [] | |||||
| RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"]) | |||||
| columns_list = ["input_ids", "input_mask", "segment_ids", "unique_ids"] | |||||
| for data in dataset.create_dict_iterator(): | |||||
| input_data = [] | |||||
| for i in columns_list: | |||||
| input_data.append(Tensor(data[i])) | |||||
| input_ids, input_mask, segment_ids, unique_ids = input_data | |||||
| start_positions = Tensor([1], mstype.float32) | |||||
| end_positions = Tensor([1], mstype.float32) | |||||
| is_impossible = Tensor([1], mstype.float32) | |||||
| logits = model.predict(input_ids, input_mask, segment_ids, start_positions, | |||||
| end_positions, unique_ids, is_impossible) | |||||
| ids = logits[0].asnumpy() | |||||
| start = logits[1].asnumpy() | |||||
| end = logits[2].asnumpy() | |||||
| for i in range(bert_net_cfg.batch_size): | |||||
| unique_id = int(ids[i]) | |||||
| start_logits = [float(x) for x in start[i].flat] | |||||
| end_logits = [float(x) for x in end[i].flat] | |||||
| output.append(RawResult( | |||||
| unique_id=unique_id, | |||||
| start_logits=start_logits, | |||||
| end_logits=end_logits)) | |||||
| write_predictions(eval_examples, eval_features, output, 20, 30, True, "./predictions.json", | |||||
| None, None, False, False) | |||||
| if __name__ == "__main__": | |||||
| test_eval() | |||||
| @@ -0,0 +1,134 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| ''' | |||||
| Bert evaluation assessment method script. | |||||
| ''' | |||||
| import math | |||||
| import numpy as np | |||||
| from .CRF import postprocess | |||||
| class Accuracy(): | |||||
| ''' | |||||
| calculate accuracy | |||||
| ''' | |||||
| def __init__(self): | |||||
| self.acc_num = 0 | |||||
| self.total_num = 0 | |||||
| def update(self, logits, labels): | |||||
| labels = labels.asnumpy() | |||||
| labels = np.reshape(labels, -1) | |||||
| logits = logits.asnumpy() | |||||
| logit_id = np.argmax(logits, axis=-1) | |||||
| self.acc_num += np.sum(labels == logit_id) | |||||
| self.total_num += len(labels) | |||||
| print("=========================accuracy is ", self.acc_num / self.total_num) | |||||
| class F1(): | |||||
| ''' | |||||
| calculate F1 score | |||||
| ''' | |||||
| def __init__(self, use_crf=False, num_labels=2): | |||||
| self.TP = 0 | |||||
| self.FP = 0 | |||||
| self.FN = 0 | |||||
| self.use_crf = use_crf | |||||
| self.num_labels = num_labels | |||||
| def update(self, logits, labels): | |||||
| ''' | |||||
| update F1 score | |||||
| ''' | |||||
| labels = labels.asnumpy() | |||||
| labels = np.reshape(labels, -1) | |||||
| if self.use_crf: | |||||
| backpointers, best_tag_id = logits | |||||
| best_path = postprocess(backpointers, best_tag_id) | |||||
| logit_id = [] | |||||
| for ele in best_path: | |||||
| logit_id.extend(ele) | |||||
| else: | |||||
| logits = logits.asnumpy() | |||||
| logit_id = np.argmax(logits, axis=-1) | |||||
| logit_id = np.reshape(logit_id, -1) | |||||
| pos_eva = np.isin(logit_id, list(range(1, self.num_labels))) | |||||
| pos_label = np.isin(labels, list(range(1, self.num_labels))) | |||||
| self.TP += np.sum(pos_eva & pos_label) | |||||
| self.FP += np.sum(pos_eva & (~pos_label)) | |||||
| self.FN += np.sum((~pos_eva) & pos_label) | |||||
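`F1` only accumulates TP/FP/FN; unlike `MCC` and `Spearman_Correlation` below it has no `cal` method, so the caller derives the final score from the counters. A minimal sketch of that last step (the function is illustrative, not part of this diff):

```python
def f1_from_counts(tp, fp, fn):
    """F1 = 2*TP / (2*TP + FP + FN), written via precision/recall with zero guards."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0
```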
| class MCC(): | |||||
| ''' | |||||
| Calculate Matthews Correlation Coefficient | |||||
| ''' | |||||
| def __init__(self): | |||||
| self.TP = 0 | |||||
| self.FP = 0 | |||||
| self.FN = 0 | |||||
| self.TN = 0 | |||||
| def update(self, logits, labels): | |||||
| ''' | |||||
| MCC update | |||||
| ''' | |||||
| labels = labels.asnumpy() | |||||
| labels = np.reshape(labels, -1) | |||||
| labels = labels.astype(np.bool) | |||||
| logits = logits.asnumpy() | |||||
| logit_id = np.argmax(logits, axis=-1) | |||||
| logit_id = np.reshape(logit_id, -1) | |||||
| logit_id = logit_id.astype(np.bool) | |||||
| ornot = logit_id ^ labels | |||||
| self.TP += (~ornot & labels).sum() | |||||
| self.FP += (ornot & ~labels).sum() | |||||
| self.FN += (ornot & labels).sum() | |||||
| self.TN += (~ornot & ~labels).sum() | |||||
| def cal(self): | |||||
| mcc = (self.TP*self.TN - self.FP*self.FN)/math.sqrt((self.TP+self.FP)*(self.TP+self.FN) * | |||||
| (self.TN+self.FP)*(self.TN+self.FN)) | |||||
| return mcc | |||||
| class Spearman_Correlation(): | |||||
| ''' | |||||
| Calculate Spearman Correlation Coefficient | |||||
| ''' | |||||
| def __init__(self): | |||||
| self.label = [] | |||||
| self.logit = [] | |||||
| def update(self, logits, labels): | |||||
| labels = labels.asnumpy() | |||||
| labels = np.reshape(labels, -1) | |||||
| logits = logits.asnumpy() | |||||
| logits = np.reshape(logits, -1) | |||||
| self.label.append(labels) | |||||
| self.logit.append(logits) | |||||
| def cal(self): | |||||
| ''' | |||||
| Calculate Spearman Correlation | |||||
| ''' | |||||
| label = np.concatenate(self.label) | |||||
| logit = np.concatenate(self.logit) | |||||
| sort_label = label.argsort()[::-1] | |||||
| sort_logit = logit.argsort()[::-1] | |||||
| n = len(label) | |||||
| d_acc = 0 | |||||
| for i in range(n): | |||||
| d = np.where(sort_label == i)[0] - np.where(sort_logit == i)[0] | |||||
| d_acc += d**2 | |||||
| ps = 1 - 6*d_acc/n/(n**2-1) | |||||
| return ps | |||||
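Where SciPy and scikit-learn are available, the MCC and Spearman implementations above can be sanity-checked against reference implementations. Note that `Spearman_Correlation.cal` uses the tie-free rank-difference formula, so it agrees with `scipy.stats.spearmanr` only when there are no tied values. A small sketch with synthetic data (not part of this diff):

```python
import numpy as np
from scipy.stats import spearmanr
from sklearn.metrics import matthews_corrcoef

labels_bin = np.array([0, 1, 1, 0, 1, 0])            # binary references for MCC
preds_bin = np.array([0, 1, 0, 0, 1, 1])             # binary predictions for MCC
labels = np.array([3.2, 1.0, 4.5, 2.2, 5.0, 0.5])    # continuous targets (regression task)
scores = np.array([2.9, 1.2, 4.1, 2.0, 4.8, 0.7])    # model outputs, no ties

print(matthews_corrcoef(labels_bin, preds_bin))      # reference MCC for the counter-based cal()
print(spearmanr(scores, labels).correlation)         # reference Spearman rho
```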
| @@ -0,0 +1,327 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| ''' | |||||
| Bert for finetune script. | |||||
| ''' | |||||
| import mindspore.nn as nn | |||||
| from mindspore.ops import operations as P | |||||
| from mindspore.ops import functional as F | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.common.tensor import Tensor | |||||
| from mindspore.common.parameter import Parameter, ParameterTuple | |||||
| from mindspore.common import dtype as mstype | |||||
| from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||||
| from mindspore.train.parallel_utils import ParallelMode | |||||
| from mindspore.communication.management import get_group_size | |||||
| from mindspore import context | |||||
| from .bert_for_pre_training import clip_grad | |||||
| from .finetune_eval_model import BertCLSModel, BertNERModel, BertSquadModel | |||||
| from .utils import CrossEntropyCalculation | |||||
| GRADIENT_CLIP_TYPE = 1 | |||||
| GRADIENT_CLIP_VALUE = 1.0 | |||||
| grad_scale = C.MultitypeFuncGraph("grad_scale") | |||||
| reciprocal = P.Reciprocal() | |||||
| @grad_scale.register("Tensor", "Tensor") | |||||
| def tensor_grad_scale(scale, grad): | |||||
| return grad * reciprocal(scale) | |||||
| _grad_overflow = C.MultitypeFuncGraph("_grad_overflow") | |||||
| grad_overflow = P.FloatStatus() | |||||
| @_grad_overflow.register("Tensor") | |||||
| def _tensor_grad_overflow(grad): | |||||
| return grad_overflow(grad) | |||||
| class BertFinetuneCell(nn.Cell): | |||||
| """ | |||||
| Specifically defined for fine-tuning, where only four input tensors are needed. | |||||
| """ | |||||
| def __init__(self, network, optimizer, scale_update_cell=None): | |||||
| super(BertFinetuneCell, self).__init__(auto_prefix=False) | |||||
| self.network = network | |||||
| self.weights = ParameterTuple(network.trainable_params()) | |||||
| self.optimizer = optimizer | |||||
| self.grad = C.GradOperation('grad', | |||||
| get_by_list=True, | |||||
| sens_param=True) | |||||
| self.reducer_flag = False | |||||
| self.allreduce = P.AllReduce() | |||||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||||
| if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | |||||
| self.reducer_flag = True | |||||
| self.grad_reducer = None | |||||
| if self.reducer_flag: | |||||
| mean = context.get_auto_parallel_context("mirror_mean") | |||||
| degree = get_group_size() | |||||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||||
| self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) | |||||
| self.cast = P.Cast() | |||||
| self.gpu_target = False | |||||
| if context.get_context("device_target") == "GPU": | |||||
| self.gpu_target = True | |||||
| self.float_status = P.FloatStatus() | |||||
| self.addn = P.AddN() | |||||
| self.reshape = P.Reshape() | |||||
| else: | |||||
| self.alloc_status = P.NPUAllocFloatStatus() | |||||
| self.get_status = P.NPUGetFloatStatus() | |||||
| self.clear_before_grad = P.NPUClearFloatStatus() | |||||
| self.reduce_sum = P.ReduceSum(keep_dims=False) | |||||
| self.depend_parameter_use = P.ControlDepend(depend_mode=1) | |||||
| self.base = Tensor(1, mstype.float32) | |||||
| self.less_equal = P.LessEqual() | |||||
| self.hyper_map = C.HyperMap() | |||||
| self.loss_scale = None | |||||
| self.loss_scaling_manager = scale_update_cell | |||||
| if scale_update_cell: | |||||
| self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | |||||
| name="loss_scale") | |||||
| def construct(self, | |||||
| input_ids, | |||||
| input_mask, | |||||
| token_type_id, | |||||
| label_ids, | |||||
| sens=None): | |||||
| weights = self.weights | |||||
| init = False | |||||
| loss = self.network(input_ids, | |||||
| input_mask, | |||||
| token_type_id, | |||||
| label_ids) | |||||
| if sens is None: | |||||
| scaling_sens = self.loss_scale | |||||
| else: | |||||
| scaling_sens = sens | |||||
| if not self.gpu_target: | |||||
| init = self.alloc_status() | |||||
| clear_before_grad = self.clear_before_grad(init) | |||||
| F.control_depend(loss, init) | |||||
| self.depend_parameter_use(clear_before_grad, scaling_sens) | |||||
| grads = self.grad(self.network, weights)(input_ids, | |||||
| input_mask, | |||||
| token_type_id, | |||||
| label_ids, | |||||
| self.cast(scaling_sens, | |||||
| mstype.float32)) | |||||
| grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) | |||||
| grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) | |||||
| if self.reducer_flag: | |||||
| grads = self.grad_reducer(grads) | |||||
| if not self.gpu_target: | |||||
| flag = self.get_status(init) | |||||
| flag_sum = self.reduce_sum(init, (0,)) | |||||
| F.control_depend(grads, flag) | |||||
| F.control_depend(flag, flag_sum) | |||||
| else: | |||||
| flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) | |||||
| flag_sum = self.addn(flag_sum) | |||||
| flag_sum = self.reshape(flag_sum, (())) | |||||
| if self.is_distributed: | |||||
| flag_reduce = self.allreduce(flag_sum) | |||||
| cond = self.less_equal(self.base, flag_reduce) | |||||
| else: | |||||
| cond = self.less_equal(self.base, flag_sum) | |||||
| overflow = cond | |||||
| if sens is None: | |||||
| overflow = self.loss_scaling_manager(self.loss_scale, cond) | |||||
| if overflow: | |||||
| succ = False | |||||
| else: | |||||
| succ = self.optimizer(grads) | |||||
| ret = (loss, cond) | |||||
| return F.depend(ret, succ) | |||||
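`BertFinetuneCell` follows the standard dynamic loss-scaling recipe: clear the float-status registers, scale the sensitivity, unscale and clip the gradients, then skip the optimizer step when an overflow was detected, while `DynamicLossScaleUpdateCell` adapts the scale. A plain-Python sketch of that update policy, an approximation rather than MindSpore's implementation, using the values configured in `do_train` above (`loss_scale_value=2**32`, `scale_factor=2`, `scale_window=1000`):

```python
class DynamicLossScaleSketch:
    """Illustrative policy only; approximates the dynamic loss-scale behavior."""

    def __init__(self, init_scale=2 ** 32, scale_factor=2, scale_window=1000):
        self.scale = float(init_scale)
        self.factor = scale_factor
        self.window = scale_window  # consecutive clean steps before growing the scale
        self.clean_steps = 0

    def update(self, overflow):
        if overflow:
            self.scale = max(self.scale / self.factor, 1.0)  # back off on overflow
            self.clean_steps = 0
        else:
            self.clean_steps += 1
            if self.clean_steps >= self.window:
                self.scale *= self.factor                    # grow after a stable window
                self.clean_steps = 0
        return overflow  # when True, the training cell skips optimizer(grads)
```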
| class BertSquadCell(nn.Cell): | |||||
| """ | |||||
| Specifically defined for SQuAD fine-tuning, where seven input tensors are needed. | |||||
| """ | |||||
| def __init__(self, network, optimizer, scale_update_cell=None): | |||||
| super(BertSquadCell, self).__init__(auto_prefix=False) | |||||
| self.network = network | |||||
| self.weights = ParameterTuple(network.trainable_params()) | |||||
| self.optimizer = optimizer | |||||
| self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) | |||||
| self.reducer_flag = False | |||||
| self.allreduce = P.AllReduce() | |||||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||||
| if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | |||||
| self.reducer_flag = True | |||||
| self.grad_reducer = None | |||||
| if self.reducer_flag: | |||||
| mean = context.get_auto_parallel_context("mirror_mean") | |||||
| degree = get_group_size() | |||||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||||
| self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) | |||||
| self.cast = P.Cast() | |||||
| self.alloc_status = P.NPUAllocFloatStatus() | |||||
| self.get_status = P.NPUGetFloatStatus() | |||||
| self.clear_before_grad = P.NPUClearFloatStatus() | |||||
| self.reduce_sum = P.ReduceSum(keep_dims=False) | |||||
| self.depend_parameter_use = P.ControlDepend(depend_mode=1) | |||||
| self.base = Tensor(1, mstype.float32) | |||||
| self.less_equal = P.LessEqual() | |||||
| self.hyper_map = C.HyperMap() | |||||
| self.loss_scale = None | |||||
| self.loss_scaling_manager = scale_update_cell | |||||
| if scale_update_cell: | |||||
| self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | |||||
| name="loss_scale") | |||||
| def construct(self, | |||||
| input_ids, | |||||
| input_mask, | |||||
| token_type_id, | |||||
| start_position, | |||||
| end_position, | |||||
| unique_id, | |||||
| is_impossible, | |||||
| sens=None): | |||||
| weights = self.weights | |||||
| init = self.alloc_status() | |||||
| loss = self.network(input_ids, | |||||
| input_mask, | |||||
| token_type_id, | |||||
| start_position, | |||||
| end_position, | |||||
| unique_id, | |||||
| is_impossible) | |||||
| if sens is None: | |||||
| scaling_sens = self.loss_scale | |||||
| else: | |||||
| scaling_sens = sens | |||||
| grads = self.grad(self.network, weights)(input_ids, | |||||
| input_mask, | |||||
| token_type_id, | |||||
| start_position, | |||||
| end_position, | |||||
| unique_id, | |||||
| is_impossible, | |||||
| self.cast(scaling_sens, | |||||
| mstype.float32)) | |||||
| clear_before_grad = self.clear_before_grad(init) | |||||
| F.control_depend(loss, init) | |||||
| self.depend_parameter_use(clear_before_grad, scaling_sens) | |||||
| grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) | |||||
| grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) | |||||
| if self.reducer_flag: | |||||
| grads = self.grad_reducer(grads) | |||||
| flag = self.get_status(init) | |||||
| flag_sum = self.reduce_sum(init, (0,)) | |||||
| if self.is_distributed: | |||||
| flag_reduce = self.allreduce(flag_sum) | |||||
| cond = self.less_equal(self.base, flag_reduce) | |||||
| else: | |||||
| cond = self.less_equal(self.base, flag_sum) | |||||
| F.control_depend(grads, flag) | |||||
| F.control_depend(flag, flag_sum) | |||||
| overflow = cond | |||||
| if sens is None: | |||||
| overflow = self.loss_scaling_manager(self.loss_scale, cond) | |||||
| if overflow: | |||||
| succ = False | |||||
| else: | |||||
| succ = self.optimizer(grads) | |||||
| ret = (loss, cond) | |||||
| return F.depend(ret, succ) | |||||
| class BertCLS(nn.Cell): | |||||
| """ | |||||
| Train interface for classification finetuning task. | |||||
| """ | |||||
| def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False, | |||||
| assessment_method=""): | |||||
| super(BertCLS, self).__init__() | |||||
| self.bert = BertCLSModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings, | |||||
| assessment_method) | |||||
| self.loss = CrossEntropyCalculation(is_training) | |||||
| self.num_labels = num_labels | |||||
| self.assessment_method = assessment_method | |||||
| self.is_training = is_training | |||||
| def construct(self, input_ids, input_mask, token_type_id, label_ids): | |||||
| logits = self.bert(input_ids, input_mask, token_type_id) | |||||
| if self.assessment_method == "spearman_correlation": | |||||
| if self.is_training: | |||||
| loss = self.loss(logits, label_ids) | |||||
| else: | |||||
| loss = logits | |||||
| else: | |||||
| loss = self.loss(logits, label_ids, self.num_labels) | |||||
| return loss | |||||
| class BertNER(nn.Cell): | |||||
| """ | |||||
| Train interface for sequence labeling finetuning task. | |||||
| """ | |||||
| def __init__(self, config, is_training, num_labels=11, use_crf=False, tag_to_index=None, dropout_prob=0.0, | |||||
| use_one_hot_embeddings=False): | |||||
| super(BertNER, self).__init__() | |||||
| self.bert = BertNERModel(config, is_training, num_labels, use_crf, dropout_prob, use_one_hot_embeddings) | |||||
| if use_crf: | |||||
| if not tag_to_index: | |||||
| raise Exception("The dict for tag-index mapping should be provided for CRF.") | |||||
| from src.CRF import CRF | |||||
| self.loss = CRF(tag_to_index, config.batch_size, config.seq_length, is_training) | |||||
| else: | |||||
| self.loss = CrossEntropyCalculation(is_training) | |||||
| self.num_labels = num_labels | |||||
| self.use_crf = use_crf | |||||
| def construct(self, input_ids, input_mask, token_type_id, label_ids): | |||||
| logits = self.bert(input_ids, input_mask, token_type_id) | |||||
| if self.use_crf: | |||||
| loss = self.loss(logits, label_ids) | |||||
| else: | |||||
| loss = self.loss(logits, label_ids, self.num_labels) | |||||
| return loss | |||||
| class BertSquad(nn.Cell): | |||||
| ''' | |||||
| Train interface for SQuAD finetuning task. | |||||
| ''' | |||||
| def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): | |||||
| super(BertSquad, self).__init__() | |||||
| self.bert = BertSquadModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings) | |||||
| self.loss = CrossEntropyCalculation(is_training) | |||||
| self.num_labels = num_labels | |||||
| self.seq_length = config.seq_length | |||||
| self.is_training = is_training | |||||
| self.total_num = Parameter(Tensor([0], mstype.float32), name='total_num') | |||||
| self.start_num = Parameter(Tensor([0], mstype.float32), name='start_num') | |||||
| self.end_num = Parameter(Tensor([0], mstype.float32), name='end_num') | |||||
| self.sum = P.ReduceSum() | |||||
| self.equal = P.Equal() | |||||
| self.argmax = P.ArgMaxWithValue(axis=1) | |||||
| self.squeeze = P.Squeeze(axis=-1) | |||||
| def construct(self, input_ids, input_mask, token_type_id, start_position, end_position, unique_id, is_impossible): | |||||
| logits = self.bert(input_ids, input_mask, token_type_id) | |||||
| if self.is_training: | |||||
| unstacked_logits_0 = self.squeeze(logits[:, :, 0:1]) | |||||
| unstacked_logits_1 = self.squeeze(logits[:, :, 1:2]) | |||||
| start_loss = self.loss(unstacked_logits_0, start_position, self.seq_length) | |||||
| end_loss = self.loss(unstacked_logits_1, end_position, self.seq_length) | |||||
| total_loss = (start_loss + end_loss) / 2.0 | |||||
| else: | |||||
| start_logits = self.squeeze(logits[:, :, 0:1]) | |||||
| end_logits = self.squeeze(logits[:, :, 1:2]) | |||||
| total_loss = (unique_id, start_logits, end_logits) | |||||
| return total_loss | |||||
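In training mode `BertSquad` averages the start- and end-position cross entropies over the squeezed logit slices; writing $\mathrm{CE}$ for the cross entropy over the sequence dimension, the objective is

$$\mathcal{L} = \tfrac{1}{2}\bigl(\mathrm{CE}(\text{start\_logits}, \text{start\_position}) + \mathrm{CE}(\text{end\_logits}, \text{end\_position})\bigr)$$

while in eval mode it simply forwards `(unique_id, start_logits, end_logits)` for `do_eval` to collect into `RawResult` and hand to `write_predictions`.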
| @@ -19,15 +19,13 @@ import json | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from . import tokenization | |||||
| from .sample_process import label_generation, process_one_example_p | |||||
| from .evaluation_config import cfg | |||||
| from .CRF import postprocess | |||||
| from src import tokenization | |||||
| from src.sample_process import label_generation, process_one_example_p | |||||
| from src.CRF import postprocess | |||||
| from src.finetune_eval_config import bert_net_cfg | |||||
| vocab_file = "./vocab.txt" | |||||
| tokenizer_ = tokenization.FullTokenizer(vocab_file=vocab_file) | |||||
| def process(model, text, sequence_length): | |||||
| def process(model=None, text="", tokenizer_=None, use_crf="", label2id_file=""): | |||||
| """ | """ | ||||
| process text. | process text. | ||||
| """ | """ | ||||
| @@ -36,13 +34,13 @@ def process(model, text, sequence_length): | |||||
| res = [] | res = [] | ||||
| ids = [] | ids = [] | ||||
| for i in data: | for i in data: | ||||
| feature = process_one_example_p(tokenizer_, i, max_seq_len=sequence_length) | |||||
| feature = process_one_example_p(tokenizer_, i, max_seq_len=bert_net_cfg.seq_length) | |||||
| features.append(feature) | features.append(feature) | ||||
| input_ids, input_mask, token_type_id = feature | input_ids, input_mask, token_type_id = feature | ||||
| input_ids = Tensor(np.array(input_ids), mstype.int32) | input_ids = Tensor(np.array(input_ids), mstype.int32) | ||||
| input_mask = Tensor(np.array(input_mask), mstype.int32) | input_mask = Tensor(np.array(input_mask), mstype.int32) | ||||
| token_type_id = Tensor(np.array(token_type_id), mstype.int32) | token_type_id = Tensor(np.array(token_type_id), mstype.int32) | ||||
| if cfg.use_crf: | |||||
| if use_crf.lower() == "true": | |||||
| backpointers, best_tag_id = model.predict(input_ids, input_mask, token_type_id, Tensor(1)) | backpointers, best_tag_id = model.predict(input_ids, input_mask, token_type_id, Tensor(1)) | ||||
| best_path = postprocess(backpointers, best_tag_id) | best_path = postprocess(backpointers, best_tag_id) | ||||
| logits = [] | logits = [] | ||||
| @@ -54,19 +52,21 @@ def process(model, text, sequence_length): | |||||
| ids = logits.asnumpy() | ids = logits.asnumpy() | ||||
| ids = np.argmax(ids, axis=-1) | ids = np.argmax(ids, axis=-1) | ||||
| ids = list(ids) | ids = list(ids) | ||||
| res = label_generation(text, ids) | |||||
| res = label_generation(text=text, probs=ids, label2id_file=label2id_file) | |||||
| return res | return res | ||||
| def submit(model, path, sequence_length): | |||||
| def submit(model=None, path="", vocab_file="", use_crf="", label2id_file=""): | |||||
| """ | """ | ||||
| submit task | submit task | ||||
| """ | """ | ||||
| tokenizer_ = tokenization.FullTokenizer(vocab_file=vocab_file) | |||||
| data = [] | data = [] | ||||
| for line in open(path): | for line in open(path): | ||||
| if not line.strip(): | if not line.strip(): | ||||
| continue | continue | ||||
| oneline = json.loads(line.strip()) | oneline = json.loads(line.strip()) | ||||
| res = process(model, oneline["text"], sequence_length) | |||||
| res = process(model=model, text=oneline["text"], tokenizer_=tokenizer_, | |||||
| use_crf=use_crf, label2id_file=label2id_file) | |||||
| print("text", oneline["text"]) | print("text", oneline["text"]) | ||||
| print("res:", res) | print("res:", res) | ||||
| data.append(json.dumps({"label": res}, ensure_ascii=False)) | data.append(json.dumps({"label": res}, ensure_ascii=False)) | ||||
| @@ -58,3 +58,77 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e | |||||
| logger.info("data size: {}".format(ds.get_dataset_size())) | logger.info("data size: {}".format(ds.get_dataset_size())) | ||||
| logger.info("repeatcount: {}".format(ds.get_repeat_count())) | logger.info("repeatcount: {}".format(ds.get_repeat_count())) | ||||
| return ds, new_repeat_count | return ds, new_repeat_count | ||||
| def create_ner_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", | |||||
| data_file_path=None, schema_file_path=None): | |||||
| """create finetune or evaluation dataset""" | |||||
| type_cast_op = C.TypeCast(mstype.int32) | |||||
| ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, | |||||
| columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"]) | |||||
| if assessment_method == "Spearman_correlation": | |||||
| type_cast_op_float = C.TypeCast(mstype.float32) | |||||
| ds = ds.map(input_columns="label_ids", operations=type_cast_op_float) | |||||
| else: | |||||
| ds = ds.map(input_columns="label_ids", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="segment_ids", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="input_mask", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="input_ids", operations=type_cast_op) | |||||
| ds = ds.repeat(repeat_count) | |||||
| # apply shuffle operation | |||||
| buffer_size = 960 | |||||
| ds = ds.shuffle(buffer_size=buffer_size) | |||||
| # apply batch operations | |||||
| ds = ds.batch(batch_size, drop_remainder=True) | |||||
| return ds | |||||
| def create_classification_dataset(batch_size=1, repeat_count=1, assessment_method="accuracy", | |||||
| data_file_path=None, schema_file_path=None): | |||||
| """create finetune or evaluation dataset""" | |||||
| type_cast_op = C.TypeCast(mstype.int32) | |||||
| ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, | |||||
| columns_list=["input_ids", "input_mask", "segment_ids", "label_ids"]) | |||||
| if assessment_method == "Spearman_correlation": | |||||
| type_cast_op_float = C.TypeCast(mstype.float32) | |||||
| ds = ds.map(input_columns="label_ids", operations=type_cast_op_float) | |||||
| else: | |||||
| ds = ds.map(input_columns="label_ids", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="segment_ids", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="input_mask", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="input_ids", operations=type_cast_op) | |||||
| ds = ds.repeat(repeat_count) | |||||
| # apply shuffle operation | |||||
| buffer_size = 960 | |||||
| ds = ds.shuffle(buffer_size=buffer_size) | |||||
| # apply batch operations | |||||
| ds = ds.batch(batch_size, drop_remainder=True) | |||||
| return ds | |||||
| def create_squad_dataset(batch_size=1, repeat_count=1, data_file_path=None, schema_file_path=None, is_training=True): | |||||
| """create finetune or evaluation dataset""" | |||||
| type_cast_op = C.TypeCast(mstype.int32) | |||||
| if is_training: | |||||
| ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, | |||||
| columns_list=["input_ids", "input_mask", "segment_ids", | |||||
| "start_positions", "end_positions", | |||||
| "unique_ids", "is_impossible"]) | |||||
| ds = ds.map(input_columns="start_positions", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="end_positions", operations=type_cast_op) | |||||
| else: | |||||
| ds = de.TFRecordDataset([data_file_path], schema_file_path if schema_file_path != "" else None, | |||||
| columns_list=["input_ids", "input_mask", "segment_ids", "unique_ids"]) | |||||
| ds = ds.map(input_columns="input_ids", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="input_mask", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="segment_ids", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="segment_ids", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="input_mask", operations=type_cast_op) | |||||
| ds = ds.map(input_columns="input_ids", operations=type_cast_op) | |||||
| ds = ds.repeat(repeat_count) | |||||
| # apply shuffle operation | |||||
| buffer_size = 960 | |||||
| ds = ds.shuffle(buffer_size=buffer_size) | |||||
| # apply batch operations | |||||
| ds = ds.batch(batch_size, drop_remainder=True) | |||||
| return ds | |||||
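The three helpers share one TFRecord pipeline (type casts, `repeat`, a 960-sample shuffle buffer, drop-remainder batching) and differ only in their column lists plus the training-only position columns for SQuAD. A minimal usage sketch, with placeholder paths:

```python
from src.dataset import create_squad_dataset
from src.finetune_eval_config import bert_net_cfg

train_ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                                data_file_path="/your/path/train.tf_record",
                                schema_file_path="", is_training=True)
eval_ds = create_squad_dataset(batch_size=bert_net_cfg.batch_size, repeat_count=1,
                               data_file_path="/your/path/eval.tf_record",
                               schema_file_path="", is_training=False)
print(train_ds.get_dataset_size())  # batches per epoch after batching
```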
| @@ -1,120 +0,0 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| """ | |||||
| config settings, will be used in finetune.py | |||||
| """ | |||||
| from easydict import EasyDict as edict | |||||
| import mindspore.common.dtype as mstype | |||||
| from .bert_model import BertConfig | |||||
| cfg = edict({ | |||||
| 'task': 'NER', | |||||
| 'num_labels': 41, | |||||
| 'data_file': '/your/path/train.tfrecord', | |||||
| 'schema_file': '/your/path/schema.json', | |||||
| 'epoch_num': 5, | |||||
| 'ckpt_prefix': 'bert', | |||||
| 'ckpt_dir': None, | |||||
| 'pre_training_ckpt': '/your/path/pre_training.ckpt', | |||||
| 'use_crf': False, | |||||
| 'optimizer': 'Lamb', | |||||
| 'AdamWeightDecayDynamicLR': edict({ | |||||
| 'learning_rate': 2e-5, | |||||
| 'end_learning_rate': 1e-7, | |||||
| 'power': 1.0, | |||||
| 'weight_decay': 1e-5, | |||||
| 'eps': 1e-6, | |||||
| }), | |||||
| 'Lamb': edict({ | |||||
| 'start_learning_rate': 2e-5, | |||||
| 'end_learning_rate': 1e-7, | |||||
| 'power': 1.0, | |||||
| 'weight_decay': 0.01, | |||||
| 'decay_filter': lambda x: False, | |||||
| }), | |||||
| 'Momentum': edict({ | |||||
| 'learning_rate': 2e-5, | |||||
| 'momentum': 0.9, | |||||
| }), | |||||
| }) | |||||
| bert_net_cfg = BertConfig( | |||||
| batch_size=16, | |||||
| seq_length=128, | |||||
| vocab_size=21128, | |||||
| hidden_size=768, | |||||
| num_hidden_layers=12, | |||||
| num_attention_heads=12, | |||||
| intermediate_size=3072, | |||||
| hidden_act="gelu", | |||||
| hidden_dropout_prob=0.1, | |||||
| attention_probs_dropout_prob=0.1, | |||||
| max_position_embeddings=512, | |||||
| type_vocab_size=2, | |||||
| initializer_range=0.02, | |||||
| use_relative_positions=False, | |||||
| input_mask_from_dataset=True, | |||||
| token_type_ids_from_dataset=True, | |||||
| dtype=mstype.float32, | |||||
| compute_type=mstype.float16, | |||||
| ) | |||||
| tag_to_index = { | |||||
| "O": 0, | |||||
| "S_address": 1, | |||||
| "B_address": 2, | |||||
| "M_address": 3, | |||||
| "E_address": 4, | |||||
| "S_book": 5, | |||||
| "B_book": 6, | |||||
| "M_book": 7, | |||||
| "E_book": 8, | |||||
| "S_company": 9, | |||||
| "B_company": 10, | |||||
| "M_company": 11, | |||||
| "E_company": 12, | |||||
| "S_game": 13, | |||||
| "B_game": 14, | |||||
| "M_game": 15, | |||||
| "E_game": 16, | |||||
| "S_government": 17, | |||||
| "B_government": 18, | |||||
| "M_government": 19, | |||||
| "E_government": 20, | |||||
| "S_movie": 21, | |||||
| "B_movie": 22, | |||||
| "M_movie": 23, | |||||
| "E_movie": 24, | |||||
| "S_name": 25, | |||||
| "B_name": 26, | |||||
| "M_name": 27, | |||||
| "E_name": 28, | |||||
| "S_organization": 29, | |||||
| "B_organization": 30, | |||||
| "M_organization": 31, | |||||
| "E_organization": 32, | |||||
| "S_position": 33, | |||||
| "B_position": 34, | |||||
| "M_position": 35, | |||||
| "E_position": 36, | |||||
| "S_scene": 37, | |||||
| "B_scene": 38, | |||||
| "M_scene": 39, | |||||
| "E_scene": 40, | |||||
| "<START>": 41, | |||||
| "<STOP>": 42 | |||||
| } | |||||
| @@ -21,18 +21,30 @@ from easydict import EasyDict as edict | |||||
| import mindspore.common.dtype as mstype | import mindspore.common.dtype as mstype | ||||
| from .bert_model import BertConfig | from .bert_model import BertConfig | ||||
| cfg = edict({ | |||||
| 'task': 'NER', | |||||
| 'num_labels': 41, | |||||
| 'data_file': '/your/path/evaluation.tfrecord', | |||||
| 'schema_file': '/your/path/schema.json', | |||||
| 'finetune_ckpt': '/your/path/your.ckpt', | |||||
| 'use_crf': False, | |||||
| 'clue_benchmark': False, | |||||
| optimizer_cfg = edict({ | |||||
| 'optimizer': 'Lamb', | |||||
| 'AdamWeightDecayDynamicLR': edict({ | |||||
| 'learning_rate': 2e-5, | |||||
| 'end_learning_rate': 1e-7, | |||||
| 'power': 1.0, | |||||
| 'weight_decay': 1e-5, | |||||
| 'eps': 1e-6, | |||||
| }), | |||||
| 'Lamb': edict({ | |||||
| 'start_learning_rate': 2e-5, | |||||
| 'end_learning_rate': 1e-7, | |||||
| 'power': 1.0, | |||||
| 'weight_decay': 0.01, | |||||
| 'decay_filter': lambda x: False, | |||||
| }), | |||||
| 'Momentum': edict({ | |||||
| 'learning_rate': 2e-5, | |||||
| 'momentum': 0.9, | |||||
| }), | |||||
| }) | }) | ||||
| bert_net_cfg = BertConfig( | bert_net_cfg = BertConfig( | ||||
| batch_size=16 if not cfg.clue_benchmark else 1, | |||||
| batch_size=16, | |||||
| seq_length=128, | seq_length=128, | ||||
| vocab_size=21128, | vocab_size=21128, | ||||
| hidden_size=768, | hidden_size=768, | ||||
| @@ -40,8 +52,8 @@ bert_net_cfg = BertConfig( | |||||
| num_attention_heads=12, | num_attention_heads=12, | ||||
| intermediate_size=3072, | intermediate_size=3072, | ||||
| hidden_act="gelu", | hidden_act="gelu", | ||||
| hidden_dropout_prob=0.0, | |||||
| attention_probs_dropout_prob=0.0, | |||||
| hidden_dropout_prob=0.1, | |||||
| attention_probs_dropout_prob=0.1, | |||||
| max_position_embeddings=512, | max_position_embeddings=512, | ||||
| type_vocab_size=2, | type_vocab_size=2, | ||||
| initializer_range=0.02, | initializer_range=0.02, | ||||
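`optimizer_cfg` keeps one hyperparameter group per optimizer and a top-level `optimizer` key to choose among them. A sketch of how a finetune script might dispatch on it (the `build_optimizer` helper and its arguments are illustrative, and the exact `nn.*` signatures may vary across MindSpore versions):

```python
# Illustrative dispatch over optimizer_cfg; not the repository's actual code.
from mindspore import nn

def build_optimizer(params, decay_steps):
    if optimizer_cfg.optimizer == 'AdamWeightDecayDynamicLR':
        c = optimizer_cfg.AdamWeightDecayDynamicLR
        return nn.AdamWeightDecayDynamicLR(params, decay_steps,
                                           learning_rate=c.learning_rate,
                                           end_learning_rate=c.end_learning_rate,
                                           power=c.power,
                                           weight_decay=c.weight_decay, eps=c.eps)
    if optimizer_cfg.optimizer == 'Lamb':
        c = optimizer_cfg.Lamb
        return nn.Lamb(params, decay_steps,
                       start_learning_rate=c.start_learning_rate,
                       end_learning_rate=c.end_learning_rate,
                       power=c.power, weight_decay=c.weight_decay,
                       decay_filter=c.decay_filter)
    c = optimizer_cfg.Momentum
    return nn.Momentum(params, learning_rate=c.learning_rate, momentum=c.momentum)
```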
| @@ -0,0 +1,123 @@ | |||||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| ''' | |||||
| Bert finetune and evaluation model script. | |||||
| ''' | |||||
| import mindspore.nn as nn | |||||
| from mindspore.common.initializer import TruncatedNormal | |||||
| from mindspore.ops import operations as P | |||||
| from .bert_model import BertModel | |||||
| class BertCLSModel(nn.Cell): | |||||
| """ | |||||
| This class is responsible for classification task evaluation, i.e. XNLI (num_labels=3), | |||||
| LCQMC (num_labels=2), Chnsenti (num_labels=2). The returned output represents the final | |||||
| logits, as log_softmax preserves the ordering (though not the values) of softmax. | |||||
| """ | |||||
| def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False, | |||||
| assessment_method=""): | |||||
| super(BertCLSModel, self).__init__() | |||||
| if not is_training: | |||||
| config.hidden_dropout_prob = 0.0 | |||||
| config.attention_probs_dropout_prob = 0.0 | |||||
| self.bert = BertModel(config, is_training, use_one_hot_embeddings) | |||||
| self.cast = P.Cast() | |||||
| self.weight_init = TruncatedNormal(config.initializer_range) | |||||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||||
| self.dtype = config.dtype | |||||
| self.num_labels = num_labels | |||||
| self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init, | |||||
| has_bias=True).to_float(config.compute_type) | |||||
| self.dropout = nn.Dropout(1 - dropout_prob) | |||||
| self.assessment_method = assessment_method | |||||
| def construct(self, input_ids, input_mask, token_type_id): | |||||
| _, pooled_output, _ = \ | |||||
| self.bert(input_ids, token_type_id, input_mask) | |||||
| cls = self.cast(pooled_output, self.dtype) | |||||
| cls = self.dropout(cls) | |||||
| logits = self.dense_1(cls) | |||||
| logits = self.cast(logits, self.dtype) | |||||
| if self.assessment_method != "spearman_correlation": | |||||
| logits = self.log_softmax(logits) | |||||
| return logits | |||||
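The reason `log_softmax` can stand in for `softmax` at evaluation time is that it is a monotonic transform, so argmax over log-probabilities picks the same class. A quick NumPy check, for illustration only:

```python
import numpy as np

logits = np.array([2.0, 0.5, -1.0])
softmax = np.exp(logits) / np.exp(logits).sum()
log_softmax = logits - np.log(np.exp(logits).sum())
# The two orderings agree, so class predictions are identical.
assert np.argmax(softmax) == np.argmax(log_softmax)
```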
| class BertSquadModel(nn.Cell): | |||||
| ''' | |||||
| This class is responsible for the SQuAD task. | |||||
| ''' | |||||
| def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): | |||||
| super(BertSquadModel, self).__init__() | |||||
| if not is_training: | |||||
| config.hidden_dropout_prob = 0.0 | |||||
| config.attention_probs_dropout_prob = 0.0 | |||||
| self.bert = BertModel(config, is_training, use_one_hot_embeddings) | |||||
| self.weight_init = TruncatedNormal(config.initializer_range) | |||||
| self.dense1 = nn.Dense(config.hidden_size, num_labels, weight_init=self.weight_init, | |||||
| has_bias=True).to_float(config.compute_type) | |||||
| self.num_labels = num_labels | |||||
| self.dtype = config.dtype | |||||
| self.log_softmax = P.LogSoftmax(axis=1) | |||||
| self.is_training = is_training | |||||
| def construct(self, input_ids, input_mask, token_type_id): | |||||
| sequence_output, _, _ = self.bert(input_ids, token_type_id, input_mask) | |||||
| batch_size, seq_length, hidden_size = P.Shape()(sequence_output) | |||||
| sequence = P.Reshape()(sequence_output, (-1, hidden_size)) | |||||
| logits = self.dense1(sequence) | |||||
| logits = P.Cast()(logits, self.dtype) | |||||
| logits = P.Reshape()(logits, (batch_size, seq_length, self.num_labels)) | |||||
| logits = self.log_softmax(logits) | |||||
| return logits | |||||
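`BertSquadModel` returns log-probabilities of shape `(batch_size, seq_length, 2)`: `LogSoftmax(axis=1)` normalizes over token positions, with slice 0 holding the start-position distribution and slice 1 the end-position distribution. A NumPy sketch of reading off a predicted span (illustration only; real postprocessing also enforces `start <= end` and other constraints):

```python
import numpy as np

log_probs = np.random.randn(1, 128, 2)          # stand-in for one model output
start_pos = int(np.argmax(log_probs[0, :, 0]))  # most likely answer start token
end_pos = int(np.argmax(log_probs[0, :, 1]))    # most likely answer end token
print(start_pos, end_pos)
```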
| class BertNERModel(nn.Cell): | |||||
| """ | |||||
| This class is responsible for sequence labeling task evaluation, i.e. NER (num_labels=11). | |||||
| The returned output represents the final logits, as log_softmax preserves the ordering of softmax. | |||||
| """ | |||||
| def __init__(self, config, is_training, num_labels=11, use_crf=False, dropout_prob=0.0, | |||||
| use_one_hot_embeddings=False): | |||||
| super(BertNERModel, self).__init__() | |||||
| if not is_training: | |||||
| config.hidden_dropout_prob = 0.0 | |||||
| config.attention_probs_dropout_prob = 0.0 | |||||
| self.bert = BertModel(config, is_training, use_one_hot_embeddings) | |||||
| self.cast = P.Cast() | |||||
| self.weight_init = TruncatedNormal(config.initializer_range) | |||||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||||
| self.dtype = config.dtype | |||||
| self.num_labels = num_labels | |||||
| self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init, | |||||
| has_bias=True).to_float(config.compute_type) | |||||
| self.dropout = nn.Dropout(1 - dropout_prob) | |||||
| self.reshape = P.Reshape() | |||||
| self.shape = (-1, config.hidden_size) | |||||
| self.use_crf = use_crf | |||||
| self.origin_shape = (config.batch_size, config.seq_length, self.num_labels) | |||||
| def construct(self, input_ids, input_mask, token_type_id): | |||||
| sequence_output, _, _ = \ | |||||
| self.bert(input_ids, token_type_id, input_mask) | |||||
| seq = self.dropout(sequence_output) | |||||
| seq = self.reshape(seq, self.shape) | |||||
| logits = self.dense_1(seq) | |||||
| logits = self.cast(logits, self.dtype) | |||||
| if self.use_crf: | |||||
| return_value = self.reshape(logits, self.origin_shape) | |||||
| else: | |||||
| return_value = self.log_softmax(logits) | |||||
| return return_value | |||||
| @@ -52,12 +52,12 @@ def process_one_example_p(tokenizer, text, max_seq_len=128): | |||||
| feature = (input_ids, input_mask, segment_ids) | feature = (input_ids, input_mask, segment_ids) | ||||
| return feature | return feature | ||||
| def label_generation(text, probs): | |||||
| def label_generation(text="", probs=None, label2id_file=""): | |||||
| """generate label""" | """generate label""" | ||||
| data = [text] | data = [text] | ||||
| probs = [probs] | probs = [probs] | ||||
| result = [] | result = [] | ||||
| label2id = json.loads(open("./label2id.json").read()) | |||||
| label2id = json.loads(open(label2id_file).read()) | |||||
| id2label = [k for k, v in label2id.items()] | id2label = [k for k, v in label2id.items()] | ||||
| for index, prob in enumerate(probs): | for index, prob in enumerate(probs): | ||||
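`label_generation` now receives the label-to-id mapping file as a parameter rather than the hard-coded `./label2id.json`. The file is plain JSON from tag names to integer ids; note that `id2label` relies on the JSON object's insertion order matching the ids. A sketch (the file contents here are hypothetical):

```python
import json

# Hypothetical label2id.json contents: {"O": 0, "B_address": 1, "E_address": 2}
label2id = json.loads(open("/your/path/label2id.json").read())
id2label = [k for k, v in label2id.items()]  # index i -> tag name, as in the function
```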
| @@ -17,347 +17,13 @@ | |||||
| Functional Cells used in Bert finetune and evaluation. | Functional Cells used in Bert finetune and evaluation. | ||||
| """ | """ | ||||
| import os | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore.common.initializer import TruncatedNormal | |||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops import functional as F | |||||
| from mindspore.ops import composite as C | |||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.common.parameter import Parameter, ParameterTuple | |||||
| from mindspore.common import dtype as mstype | from mindspore.common import dtype as mstype | ||||
| from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||||
| from mindspore.train.parallel_utils import ParallelMode | |||||
| from mindspore.communication.management import get_group_size | |||||
| from mindspore import context | |||||
| from .bert_model import BertModel | |||||
| from .bert_for_pre_training import clip_grad | |||||
| from .CRF import CRF | |||||
| from mindspore.train.callback import Callback | |||||
| GRADIENT_CLIP_TYPE = 1 | |||||
| GRADIENT_CLIP_VALUE = 1.0 | |||||
| grad_scale = C.MultitypeFuncGraph("grad_scale") | |||||
| reciprocal = P.Reciprocal() | |||||
| @grad_scale.register("Tensor", "Tensor") | |||||
| def tensor_grad_scale(scale, grad): | |||||
| return grad * reciprocal(scale) | |||||
| _grad_overflow = C.MultitypeFuncGraph("_grad_overflow") | |||||
| grad_overflow = P.FloatStatus() | |||||
| @_grad_overflow.register("Tensor") | |||||
| def _tensor_grad_overflow(grad): | |||||
| return grad_overflow(grad) | |||||
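`grad_scale` and `_grad_overflow` are registered as `MultitypeFuncGraph`s so that `HyperMap` can apply them elementwise across the tuple of gradients. Conceptually (a plain-Python sketch, not the graph-mode implementation):

```python
# Conceptual equivalent of hyper_map(F.partial(grad_scale, scaling_sens), grads):
def unscale(scaling_sens, grads):
    return tuple(g * (1.0 / scaling_sens) for g in grads)
```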
| class BertFinetuneCell(nn.Cell): | |||||
| """ | |||||
| Specifically defined for finetuning, where only four input tensors are needed. | |||||
| """ | |||||
| def __init__(self, network, optimizer, scale_update_cell=None): | |||||
| super(BertFinetuneCell, self).__init__(auto_prefix=False) | |||||
| self.network = network | |||||
| self.weights = ParameterTuple(network.trainable_params()) | |||||
| self.optimizer = optimizer | |||||
| self.grad = C.GradOperation('grad', | |||||
| get_by_list=True, | |||||
| sens_param=True) | |||||
| self.reducer_flag = False | |||||
| self.allreduce = P.AllReduce() | |||||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||||
| if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | |||||
| self.reducer_flag = True | |||||
| self.grad_reducer = None | |||||
| if self.reducer_flag: | |||||
| mean = context.get_auto_parallel_context("mirror_mean") | |||||
| degree = get_group_size() | |||||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||||
| self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) | |||||
| self.cast = P.Cast() | |||||
| self.gpu_target = False | |||||
| if context.get_context("device_target") == "GPU": | |||||
| self.gpu_target = True | |||||
| self.float_status = P.FloatStatus() | |||||
| self.addn = P.AddN() | |||||
| self.reshape = P.Reshape() | |||||
| else: | |||||
| self.alloc_status = P.NPUAllocFloatStatus() | |||||
| self.get_status = P.NPUGetFloatStatus() | |||||
| self.clear_before_grad = P.NPUClearFloatStatus() | |||||
| self.reduce_sum = P.ReduceSum(keep_dims=False) | |||||
| self.depend_parameter_use = P.ControlDepend(depend_mode=1) | |||||
| self.base = Tensor(1, mstype.float32) | |||||
| self.less_equal = P.LessEqual() | |||||
| self.hyper_map = C.HyperMap() | |||||
| self.loss_scale = None | |||||
| self.loss_scaling_manager = scale_update_cell | |||||
| if scale_update_cell: | |||||
| self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | |||||
| name="loss_scale") | |||||
| def construct(self, | |||||
| input_ids, | |||||
| input_mask, | |||||
| token_type_id, | |||||
| label_ids, | |||||
| sens=None): | |||||
| weights = self.weights | |||||
| init = False | |||||
| loss = self.network(input_ids, | |||||
| input_mask, | |||||
| token_type_id, | |||||
| label_ids) | |||||
| if sens is None: | |||||
| scaling_sens = self.loss_scale | |||||
| else: | |||||
| scaling_sens = sens | |||||
| if not self.gpu_target: | |||||
| init = self.alloc_status() | |||||
| clear_before_grad = self.clear_before_grad(init) | |||||
| F.control_depend(loss, init) | |||||
| self.depend_parameter_use(clear_before_grad, scaling_sens) | |||||
| grads = self.grad(self.network, weights)(input_ids, | |||||
| input_mask, | |||||
| token_type_id, | |||||
| label_ids, | |||||
| self.cast(scaling_sens, | |||||
| mstype.float32)) | |||||
| grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) | |||||
| grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) | |||||
| if self.reducer_flag: | |||||
| grads = self.grad_reducer(grads) | |||||
| if not self.gpu_target: | |||||
| flag = self.get_status(init) | |||||
| flag_sum = self.reduce_sum(init, (0,)) | |||||
| F.control_depend(grads, flag) | |||||
| F.control_depend(flag, flag_sum) | |||||
| else: | |||||
| flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) | |||||
| flag_sum = self.addn(flag_sum) | |||||
| flag_sum = self.reshape(flag_sum, (())) | |||||
| if self.is_distributed: | |||||
| flag_reduce = self.allreduce(flag_sum) | |||||
| cond = self.less_equal(self.base, flag_reduce) | |||||
| else: | |||||
| cond = self.less_equal(self.base, flag_sum) | |||||
| overflow = cond | |||||
| if sens is None: | |||||
| overflow = self.loss_scaling_manager(self.loss_scale, cond) | |||||
| if overflow: | |||||
| succ = False | |||||
| else: | |||||
| succ = self.optimizer(grads) | |||||
| ret = (loss, cond) | |||||
| return F.depend(ret, succ) | |||||
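The overflow branch above is standard dynamic loss scaling: gradients are computed against a scaled sensitivity, unscaled and clipped, and the optimizer step is skipped whenever the float-status flags report inf/nan. A conceptual sketch of one step (all helper names here are hypothetical, and the real code runs as graph-mode control flow):

```python
# Conceptual view of one BertFinetuneCell step with dynamic loss scaling.
def finetune_step(batch, loss_scale):
    loss = network(*batch)
    grads = grad_fn(*batch, sens=loss_scale)       # backprop with scaled sensitivity
    grads = [clip(g / loss_scale) for g in grads]  # unscale, then clip
    if has_overflow(grads):
        loss_scale = update_scale(loss_scale, overflow=True)   # shrink scale, skip step
    else:
        loss_scale = update_scale(loss_scale, overflow=False)  # keep or grow scale
        optimizer(grads)
    return loss, loss_scale
```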
| class BertSquadCell(nn.Cell): | |||||
| """ | |||||
| Specifically defined for SQuAD finetuning, where seven input tensors are needed. | |||||
| """ | |||||
| def __init__(self, network, optimizer, scale_update_cell=None): | |||||
| super(BertSquadCell, self).__init__(auto_prefix=False) | |||||
| self.network = network | |||||
| self.weights = ParameterTuple(network.trainable_params()) | |||||
| self.optimizer = optimizer | |||||
| self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) | |||||
| self.reducer_flag = False | |||||
| self.allreduce = P.AllReduce() | |||||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||||
| if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | |||||
| self.reducer_flag = True | |||||
| self.grad_reducer = None | |||||
| if self.reducer_flag: | |||||
| mean = context.get_auto_parallel_context("mirror_mean") | |||||
| degree = get_group_size() | |||||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||||
| self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) | |||||
| self.cast = P.Cast() | |||||
| self.alloc_status = P.NPUAllocFloatStatus() | |||||
| self.get_status = P.NPUGetFloatStatus() | |||||
| self.clear_before_grad = P.NPUClearFloatStatus() | |||||
| self.reduce_sum = P.ReduceSum(keep_dims=False) | |||||
| self.depend_parameter_use = P.ControlDepend(depend_mode=1) | |||||
| self.base = Tensor(1, mstype.float32) | |||||
| self.less_equal = P.LessEqual() | |||||
| self.hyper_map = C.HyperMap() | |||||
| self.loss_scale = None | |||||
| self.loss_scaling_manager = scale_update_cell | |||||
| if scale_update_cell: | |||||
| self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | |||||
| name="loss_scale") | |||||
| def construct(self, | |||||
| input_ids, | |||||
| input_mask, | |||||
| token_type_id, | |||||
| start_position, | |||||
| end_position, | |||||
| unique_id, | |||||
| is_impossible, | |||||
| sens=None): | |||||
| weights = self.weights | |||||
| init = self.alloc_status() | |||||
| loss = self.network(input_ids, | |||||
| input_mask, | |||||
| token_type_id, | |||||
| start_position, | |||||
| end_position, | |||||
| unique_id, | |||||
| is_impossible) | |||||
| if sens is None: | |||||
| scaling_sens = self.loss_scale | |||||
| else: | |||||
| scaling_sens = sens | |||||
| grads = self.grad(self.network, weights)(input_ids, | |||||
| input_mask, | |||||
| token_type_id, | |||||
| start_position, | |||||
| end_position, | |||||
| unique_id, | |||||
| is_impossible, | |||||
| self.cast(scaling_sens, | |||||
| mstype.float32)) | |||||
| clear_before_grad = self.clear_before_grad(init) | |||||
| F.control_depend(loss, init) | |||||
| self.depend_parameter_use(clear_before_grad, scaling_sens) | |||||
| grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) | |||||
| grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) | |||||
| if self.reducer_flag: | |||||
| grads = self.grad_reducer(grads) | |||||
| flag = self.get_status(init) | |||||
| flag_sum = self.reduce_sum(init, (0,)) | |||||
| if self.is_distributed: | |||||
| flag_reduce = self.allreduce(flag_sum) | |||||
| cond = self.less_equal(self.base, flag_reduce) | |||||
| else: | |||||
| cond = self.less_equal(self.base, flag_sum) | |||||
| F.control_depend(grads, flag) | |||||
| F.control_depend(flag, flag_sum) | |||||
| overflow = cond | |||||
| if sens is None: | |||||
| overflow = self.loss_scaling_manager(self.loss_scale, cond) | |||||
| if overflow: | |||||
| succ = False | |||||
| else: | |||||
| succ = self.optimizer(grads) | |||||
| ret = (loss, cond) | |||||
| return F.depend(ret, succ) | |||||
| class BertRegressionModel(nn.Cell): | |||||
| """ | |||||
| Bert finetune model for regression task | |||||
| """ | |||||
| def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): | |||||
| super(BertRegressionModel, self).__init__() | |||||
| self.bert = BertModel(config, is_training, use_one_hot_embeddings) | |||||
| self.cast = P.Cast() | |||||
| self.weight_init = TruncatedNormal(config.initializer_range) | |||||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||||
| self.dtype = config.dtype | |||||
| self.num_labels = num_labels | |||||
| self.dropout = nn.Dropout(1 - dropout_prob) | |||||
| self.dense_1 = nn.Dense(config.hidden_size, 1, weight_init=self.weight_init, | |||||
| has_bias=True).to_float(mstype.float16) | |||||
| def construct(self, input_ids, input_mask, token_type_id): | |||||
| _, pooled_output, _ = self.bert(input_ids, token_type_id, input_mask) | |||||
| cls = self.cast(pooled_output, self.dtype) | |||||
| cls = self.dropout(cls) | |||||
| logits = self.dense_1(cls) | |||||
| logits = self.cast(logits, self.dtype) | |||||
| return logits | |||||
| class BertCLSModel(nn.Cell): | |||||
| """ | |||||
| This class is responsible for classification task evaluation, i.e. XNLI (num_labels=3), | |||||
| LCQMC (num_labels=2), Chnsenti (num_labels=2). The returned output represents the final | |||||
| logits, as log_softmax preserves the ordering (though not the values) of softmax. | |||||
| """ | |||||
| def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): | |||||
| super(BertCLSModel, self).__init__() | |||||
| self.bert = BertModel(config, is_training, use_one_hot_embeddings) | |||||
| self.cast = P.Cast() | |||||
| self.weight_init = TruncatedNormal(config.initializer_range) | |||||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||||
| self.dtype = config.dtype | |||||
| self.num_labels = num_labels | |||||
| self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init, | |||||
| has_bias=True).to_float(config.compute_type) | |||||
| self.dropout = nn.Dropout(1 - dropout_prob) | |||||
| def construct(self, input_ids, input_mask, token_type_id): | |||||
| _, pooled_output, _ = \ | |||||
| self.bert(input_ids, token_type_id, input_mask) | |||||
| cls = self.cast(pooled_output, self.dtype) | |||||
| cls = self.dropout(cls) | |||||
| logits = self.dense_1(cls) | |||||
| logits = self.cast(logits, self.dtype) | |||||
| log_probs = self.log_softmax(logits) | |||||
| return log_probs | |||||
| class BertSquadModel(nn.Cell): | |||||
| """ | |||||
| Bert finetune model for SQuAD v1.1 task | |||||
| """ | |||||
| def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): | |||||
| super(BertSquadModel, self).__init__() | |||||
| self.bert = BertModel(config, is_training, use_one_hot_embeddings) | |||||
| self.weight_init = TruncatedNormal(config.initializer_range) | |||||
| self.dense1 = nn.Dense(config.hidden_size, num_labels, weight_init=self.weight_init, | |||||
| has_bias=True).to_float(config.compute_type) | |||||
| self.num_labels = num_labels | |||||
| self.dtype = config.dtype | |||||
| self.log_softmax = P.LogSoftmax(axis=1) | |||||
| self.is_training = is_training | |||||
| def construct(self, input_ids, input_mask, token_type_id): | |||||
| sequence_output, _, _ = self.bert(input_ids, token_type_id, input_mask) | |||||
| batch_size, seq_length, hidden_size = P.Shape()(sequence_output) | |||||
| sequence = P.Reshape()(sequence_output, (-1, hidden_size)) | |||||
| logits = self.dense1(sequence) | |||||
| logits = P.Cast()(logits, self.dtype) | |||||
| logits = P.Reshape()(logits, (batch_size, seq_length, self.num_labels)) | |||||
| logits = self.log_softmax(logits) | |||||
| return logits | |||||
| class BertNERModel(nn.Cell): | |||||
| """ | |||||
| This class is responsible for sequence labeling task evaluation, i.e. NER (num_labels=11). | |||||
| The returned output represents the final logits, as log_softmax preserves the ordering of softmax. | |||||
| """ | |||||
| def __init__(self, config, is_training, num_labels=11, use_crf=False, dropout_prob=0.0, | |||||
| use_one_hot_embeddings=False): | |||||
| super(BertNERModel, self).__init__() | |||||
| self.bert = BertModel(config, is_training, use_one_hot_embeddings) | |||||
| self.cast = P.Cast() | |||||
| self.weight_init = TruncatedNormal(config.initializer_range) | |||||
| self.log_softmax = P.LogSoftmax(axis=-1) | |||||
| self.dtype = config.dtype | |||||
| self.num_labels = num_labels | |||||
| self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init, | |||||
| has_bias=True).to_float(config.compute_type) | |||||
| self.dropout = nn.Dropout(1 - dropout_prob) | |||||
| self.reshape = P.Reshape() | |||||
| self.shape = (-1, config.hidden_size) | |||||
| self.use_crf = use_crf | |||||
| self.origin_shape = (config.batch_size, config.seq_length, self.num_labels) | |||||
| def construct(self, input_ids, input_mask, token_type_id): | |||||
| sequence_output, _, _ = \ | |||||
| self.bert(input_ids, token_type_id, input_mask) | |||||
| seq = self.dropout(sequence_output) | |||||
| seq = self.reshape(seq, self.shape) | |||||
| logits = self.dense_1(seq) | |||||
| logits = self.cast(logits, self.dtype) | |||||
| if self.use_crf: | |||||
| return_value = self.reshape(logits, self.origin_shape) | |||||
| else: | |||||
| return_value = self.log_softmax(logits) | |||||
| return return_value | |||||
| class CrossEntropyCalculation(nn.Cell): | class CrossEntropyCalculation(nn.Cell): | ||||
| """ | """ | ||||
| @@ -387,95 +53,73 @@ class CrossEntropyCalculation(nn.Cell): | |||||
| return_value = logits * 1.0 | return_value = logits * 1.0 | ||||
| return return_value | return return_value | ||||
| class BertCLS(nn.Cell): | |||||
| """ | |||||
| Train interface for classification finetuning task. | |||||
| """ | |||||
| def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): | |||||
| super(BertCLS, self).__init__() | |||||
| self.bert = BertCLSModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings) | |||||
| self.loss = CrossEntropyCalculation(is_training) | |||||
| self.num_labels = num_labels | |||||
| def construct(self, input_ids, input_mask, token_type_id, label_ids): | |||||
| log_probs = self.bert(input_ids, input_mask, token_type_id) | |||||
| loss = self.loss(log_probs, label_ids, self.num_labels) | |||||
| return loss | |||||
| class BertNER(nn.Cell): | |||||
| """ | |||||
| Train interface for sequence labeling finetuning task. | |||||
| """ | |||||
| def __init__(self, config, is_training, num_labels=11, use_crf=False, tag_to_index=None, dropout_prob=0.0, | |||||
| use_one_hot_embeddings=False): | |||||
| super(BertNER, self).__init__() | |||||
| self.bert = BertNERModel(config, is_training, num_labels, use_crf, dropout_prob, use_one_hot_embeddings) | |||||
| if use_crf: | |||||
| if not tag_to_index: | |||||
| raise Exception("The dict for tag-index mapping should be provided for CRF.") | |||||
| self.loss = CRF(tag_to_index, config.batch_size, config.seq_length, is_training) | |||||
| else: | |||||
| self.loss = CrossEntropyCalculation(is_training) | |||||
| self.num_labels = num_labels | |||||
| self.use_crf = use_crf | |||||
| def construct(self, input_ids, input_mask, token_type_id, label_ids): | |||||
| logits = self.bert(input_ids, input_mask, token_type_id) | |||||
| if self.use_crf: | |||||
| loss = self.loss(logits, label_ids) | |||||
| else: | |||||
| loss = self.loss(logits, label_ids, self.num_labels) | |||||
| return loss | |||||
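`BertNER` switches its loss by `use_crf`: with a CRF head it needs the tag-to-index mapping (including the `<START>`/`<STOP>` entries shown earlier) so the CRF can build its transition matrix; otherwise it falls back to per-token cross entropy. A construction sketch, assuming the `bert_net_cfg` and `tag_to_index` from the configuration above:

```python
# Illustrative construction; num_labels matches the CLUENER setting shown earlier.
ner_with_crf = BertNER(bert_net_cfg, is_training=True, num_labels=41,
                       use_crf=True, tag_to_index=tag_to_index)
ner_softmax = BertNER(bert_net_cfg, is_training=True, num_labels=41, use_crf=False)
```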
| class BertSquad(nn.Cell): | |||||
| """ | |||||
| Train interface for SQuAD finetuning task. | |||||
| """ | |||||
| def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): | |||||
| super(BertSquad, self).__init__() | |||||
| self.bert = BertSquadModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings) | |||||
| self.loss = CrossEntropyCalculation(is_training) | |||||
| self.num_labels = num_labels | |||||
| self.seq_length = config.seq_length | |||||
| self.is_training = is_training | |||||
| self.total_num = Parameter(Tensor([0], mstype.float32), name='total_num') | |||||
| self.start_num = Parameter(Tensor([0], mstype.float32), name='start_num') | |||||
| self.end_num = Parameter(Tensor([0], mstype.float32), name='end_num') | |||||
| self.sum = P.ReduceSum() | |||||
| self.equal = P.Equal() | |||||
| self.argmax = P.ArgMaxWithValue(axis=1) | |||||
| self.squeeze = P.Squeeze(axis=-1) | |||||
| def construct(self, input_ids, input_mask, token_type_id, start_position, end_position, unique_id, is_impossible): | |||||
| logits = self.bert(input_ids, input_mask, token_type_id) | |||||
| if self.is_training: | |||||
| unstacked_logits_0 = self.squeeze(logits[:, :, 0:1]) | |||||
| unstacked_logits_1 = self.squeeze(logits[:, :, 1:2]) | |||||
| start_loss = self.loss(unstacked_logits_0, start_position, self.seq_length) | |||||
| end_loss = self.loss(unstacked_logits_1, end_position, self.seq_length) | |||||
| total_loss = (start_loss + end_loss) / 2.0 | |||||
| else: | |||||
| start_logits = self.squeeze(logits[:, :, 0:1]) | |||||
| end_logits = self.squeeze(logits[:, :, 1:2]) | |||||
| total_loss = (unique_id, start_logits, end_logits) | |||||
| return total_loss | |||||
| class BertReg(nn.Cell): | |||||
| """ | |||||
| Bert finetune model with loss for regression task | |||||
| """ | |||||
| def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): | |||||
| super(BertReg, self).__init__() | |||||
| self.bert = BertRegressionModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings) | |||||
| self.loss = nn.MSELoss() | |||||
| self.is_training = is_training | |||||
| self.sigmoid = P.Sigmoid() | |||||
| self.cast = P.Cast() | |||||
| self.mul = P.Mul() | |||||
| def construct(self, input_ids, input_mask, token_type_id, labels): | |||||
| logits = self.bert(input_ids, input_mask, token_type_id) | |||||
| if self.is_training: | |||||
| loss = self.loss(logits, labels) | |||||
| else: | |||||
| loss = logits | |||||
| return loss | |||||
| def make_directory(path: str): | |||||
| """Make directory.""" | |||||
| if path is None or not isinstance(path, str) or path.strip() == "": | |||||
| logger.error("The path(%r) is of invalid type.", path) | |||||
| raise TypeError("Input path is of invalid type") | |||||
| # convert the relative paths | |||||
| path = os.path.realpath(path) | |||||
| logger.debug("The abs path is %r", path) | |||||
| # check whether the path exists and is writable | |||||
| if os.path.exists(path): | |||||
| real_path = path | |||||
| else: | |||||
| # Creating the directory may fail (e.g. for permission reasons), so catch the error | |||||
| logger.debug("The directory(%s) doesn't exist, will create it", path) | |||||
| try: | |||||
| os.makedirs(path, exist_ok=True) | |||||
| real_path = path | |||||
| except PermissionError as e: | |||||
| logger.error("No write permission on the directory(%r), error = %r", path, e) | |||||
| raise TypeError("No write permission on the directory.") | |||||
| return real_path | |||||
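Usage is a single call: relative paths are resolved and missing directories created, with the absolute path returned (a hypothetical call, assuming `os` and a `logger` are imported in the module):

```python
ckpt_dir = make_directory("./finetune_checkpoints")  # returns the absolute path
```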
| class LossCallBack(Callback): | |||||
| """ | |||||
| Monitor the loss in training. | |||||
| If the loss is NAN or INF, training is terminated. | |||||
| Note: | |||||
| If per_print_times is 0, the loss is not printed. | |||||
| Args: | |||||
| per_print_times (int): Print the loss every per_print_times steps. Default: 1. | |||||
| """ | |||||
| def __init__(self, per_print_times=1): | |||||
| super(LossCallBack, self).__init__() | |||||
| if not isinstance(per_print_times, int) or per_print_times < 0: | |||||
| raise ValueError("per_print_times must be an int and >= 0") | |||||
| self._per_print_times = per_print_times | |||||
| def step_end(self, run_context): | |||||
| cb_params = run_context.original_args() | |||||
| print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num, | |||||
| str(cb_params.net_outputs))) | |||||
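The callback plugs into MindSpore's `Model.train` callback list; a sketch of typical usage (the `model`, `dataset`, and `epoch_num` names are illustrative):

```python
# Illustrative: print the loss after every step during finetuning.
callbacks = [LossCallBack(per_print_times=1)]
model.train(epoch_num, dataset, callbacks=callbacks)
```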
| def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix): | |||||
| """ | |||||
| Find the newest checkpoint that finetuning generated, to be loaded into the eval network. | |||||
| """ | |||||
| files = os.listdir(load_finetune_checkpoint_dir) | |||||
| pre_len = len(prefix) | |||||
| max_num = 0 | |||||
| for filename in files: | |||||
| name_ext = os.path.splitext(filename) | |||||
| if name_ext[-1] != ".ckpt": | |||||
| continue | |||||
| if filename.find(prefix) == 0 and not filename[pre_len].isalpha(): | |||||
| index = filename[pre_len:].find("-") | |||||
| if index == 0 and max_num == 0: | |||||
| load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename) | |||||
| elif index not in (0, -1): | |||||
| name_split = name_ext[-2].split('_') | |||||
| if (steps_per_epoch != int(name_split[len(name_split)-1])) \ | |||||
| or (epoch_num != int(filename[pre_len + index + 1:pre_len + index + 2])): | |||||
| continue | |||||
| num = filename[pre_len + 1:pre_len + index] | |||||
| if int(num) > max_num: | |||||
| max_num = int(num) | |||||
| load_finetune_checkpoint_path = os.path.join(load_finetune_checkpoint_dir, filename) | |||||
| return load_finetune_checkpoint_path | |||||
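`LoadNewestCkpt` scans the directory for files following MindSpore's `{prefix}-{epoch}_{step}.ckpt` checkpoint naming and returns the best match for the given training shape. A hypothetical call:

```python
# Hypothetical arguments; returns the full path of the matching checkpoint.
ckpt_path = LoadNewestCkpt("/your/path/ckpt", steps_per_epoch=1000,
                           epoch_num=3, prefix="ner")
```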