Merge pull request !1596 from yoonlee666/edit-example (tag: v0.5.0-beta)

@@ -5,7 +5,7 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
 ## Requirements
 - Install [MindSpore](https://www.mindspore.cn/install/en).
 - Download the zhwiki dataset for pre-training. Extract and clean text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format and move the files to a specified path.
-- Download the CLUE dataset for fine-tuning and evaluation.
+- Download the CLUE/SQuAD v1.1 dataset for fine-tuning and evaluation.
 > Notes:
 If you are running a fine-tuning or evaluation task, prepare the corresponding checkpoint file.

@@ -36,11 +36,20 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
 ### Evaluation
 - Set options in `evaluation_config.py`. Make sure the 'data_file', 'schema_file' and 'finetune_ckpt' are set to your own path.
-- Run `evaluation.py` for evaluation of BERT-base and BERT-NEZHA model.
+- NER: Run `evaluation.py` for evaluation of BERT-base and BERT-NEZHA model.
 ```bash
 python evaluation.py
 ```
+- SQuAD v1.1: Run `squadeval.py` and `SQuAD_postprocess.py` for evaluation of BERT-base and BERT-NEZHA model.
+```bash
+python squadeval.py
+```
+```bash
+python SQuAD_postprocess.py
+```

 ## Usage
 ### Pre-Training

@@ -80,7 +89,7 @@ config.py:
     optimizer                       optimizer used in the network: AdamWeightDecayDynamicLR | Lamb | Momentum, default is "Lamb"
 finetune_config.py:
-    task                            task type: NER | XNLI | LCQMC | SENTI | OTHERS
+    task                            task type: NER | SQUAD | OTHERS
     num_labels                      number of labels to do classification
     data_file                       dataset file to load: PATH, default is "/your/path/train.tfrecord"
     schema_file                     dataset schema file to load: PATH, default is "/your/path/schema.json"

@@ -92,7 +101,7 @@ finetune_config.py:
     optimizer                       optimizer used in fine-tune network: AdamWeightDecayDynamicLR | Lamb | Momentum, default is "Lamb"
 evaluation_config.py:
-    task                            task type: NER | XNLI | LCQMC | SENTI | OTHERS
+    task                            task type: NER | SQUAD | OTHERS
     num_labels                      number of labels to do classification
     data_file                       dataset file to load: PATH, default is "/your/path/evaluation.tfrecord"
     schema_file                     dataset schema file to load: PATH, default is "/your/path/schema.json"
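
For reference, a minimal sketch of how the new SQUAD task could be configured in `evaluation_config.py`, assuming the same `edict` layout the config files already use; every path and file name below is a placeholder, and only the keys actually read by the scripts (`task`, `num_labels`, `data_file`, `schema_file`, `finetune_ckpt`) are taken from this diff:

```python
# Illustrative sketch only; not part of the PR. Values are placeholders.
from easydict import EasyDict as edict

cfg = edict({
    'task': 'SQUAD',                                    # selects the SQuAD branch in the scripts
    'num_labels': 2,                                    # start and end position channels
    'data_file': '/your/path/evaluation.tfrecord',      # SQuAD dev set converted to TFRecord
    'schema_file': '/your/path/schema.json',
    'finetune_ckpt': '/your/path/squad_finetune.ckpt',  # checkpoint produced by finetune.py
})
```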

@@ -18,10 +18,9 @@ Bert finetune script.
 '''
 import os
-from src.utils import BertFinetuneCell, BertCLS, BertNER
+from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCell
 from src.finetune_config import cfg, bert_net_cfg, tag_to_index
 import mindspore.common.dtype as mstype
-import mindspore.communication.management as D
 from mindspore import context
 import mindspore.dataset as de
 import mindspore.dataset.transforms.c_transforms as C

@@ -58,8 +57,6 @@ def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
     '''
     get dataset
     '''
-    _ = distribute_file
     ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
                                                                             "segment_ids", "label_ids"])
     type_cast_op = C.TypeCast(mstype.int32)

@@ -77,10 +74,29 @@ def get_dataset(batch_size=1, repeat_count=1, distribute_file=''):
     ds = ds.batch(batch_size, drop_remainder=True)
     return ds
 
+def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''):
+    '''
+    get SQuAD dataset
+    '''
+    ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", "segment_ids",
+                                                                            "start_positions", "end_positions",
+                                                                            "unique_ids", "is_impossible"])
+    type_cast_op = C.TypeCast(mstype.int32)
+    ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
+    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
+    ds = ds.map(input_columns="input_mask", operations=type_cast_op)
+    ds = ds.map(input_columns="start_positions", operations=type_cast_op)
+    ds = ds.map(input_columns="end_positions", operations=type_cast_op)
+    ds = ds.repeat(repeat_count)
+    buffer_size = 960
+    ds = ds.shuffle(buffer_size=buffer_size)
+    ds = ds.batch(batch_size, drop_remainder=True)
+    return ds
+
 def test_train():
     '''
     finetune function
+    pytest -s finetune.py::test_train
     '''
     devid = int(os.getenv('DEVICE_ID'))
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)

@@ -92,9 +108,14 @@ def test_train():
                                   tag_to_index=tag_to_index, dropout_prob=0.1)
         else:
             netwithloss = BertNER(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
+    elif cfg.task == 'SQUAD':
+        netwithloss = BertSquad(bert_net_cfg, True, 2, dropout_prob=0.1)
     else:
         netwithloss = BertCLS(bert_net_cfg, True, num_labels=cfg.num_labels, dropout_prob=0.1)
-    dataset = get_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
+    if cfg.task == 'SQUAD':
+        dataset = get_squad_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
+    else:
+        dataset = get_dataset(bert_net_cfg.batch_size, cfg.epoch_num)
     # optimizer
     steps_per_epoch = dataset.get_dataset_size()
     if cfg.optimizer == 'AdamWeightDecayDynamicLR':

@@ -103,13 +124,14 @@ def test_train():
                                               learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
                                               end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
                                               power=cfg.AdamWeightDecayDynamicLR.power,
-                                              warmup_steps=steps_per_epoch,
+                                              warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1),
                                               weight_decay=cfg.AdamWeightDecayDynamicLR.weight_decay,
                                               eps=cfg.AdamWeightDecayDynamicLR.eps)
     elif cfg.optimizer == 'Lamb':
         optimizer = Lamb(netwithloss.trainable_params(), decay_steps=steps_per_epoch * cfg.epoch_num,
                          start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
-                         power=cfg.Lamb.power, warmup_steps=steps_per_epoch, decay_filter=cfg.Lamb.decay_filter)
+                         power=cfg.Lamb.power, weight_decay=cfg.Lamb.weight_decay,
+                         warmup_steps=int(steps_per_epoch * cfg.epoch_num * 0.1), decay_filter=cfg.Lamb.decay_filter)
     elif cfg.optimizer == 'Momentum':
         optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                              momentum=cfg.Momentum.momentum)
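
The warmup change above swaps a fixed one-epoch warmup for 10% of the total number of training steps, for both the AdamWeightDecayDynamicLR and Lamb branches. A quick sketch of the arithmetic with made-up numbers:

```python
# Illustrative values only: warmup now covers 10% of all training steps
# instead of exactly one epoch's worth of steps.
steps_per_epoch = 1000                                       # dataset.get_dataset_size()
epoch_num = 3                                                # cfg.epoch_num
old_warmup_steps = steps_per_epoch                           # 1000 steps (one epoch)
new_warmup_steps = int(steps_per_epoch * epoch_num * 0.1)    # 300 steps (10% of 3000)
print(old_warmup_steps, new_warmup_steps)
```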

@@ -122,10 +144,12 @@ def test_train():
     load_param_into_net(netwithloss, param_dict)
     update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2**32, scale_factor=2, scale_window=1000)
-    netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    if cfg.task == 'SQUAD':
+        netwithgrads = BertSquadCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
+    else:
+        netwithgrads = BertFinetuneCell(netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
     model = Model(netwithgrads)
     model.train(cfg.epoch_num, dataset, callbacks=[LossCallBack(), ckpoint_cb])
-    D.release()
 
 if __name__ == "__main__":
     test_train()
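
With the `mindspore.communication.management` import and the `D.release()` call removed, `finetune.py` now runs as a plain single-device script; the only runtime input it still expects is the `DEVICE_ID` environment variable. A hypothetical single-card launcher, purely for illustration:

```python
# Hypothetical wrapper, not part of the PR: run fine-tuning on one Ascend card.
# finetune.py resolves its target device from the DEVICE_ID environment variable.
import os
import subprocess

os.environ["DEVICE_ID"] = "0"                          # placeholder device id
subprocess.run(["python", "finetune.py"], check=True)
```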

@@ -0,0 +1,99 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Evaluation script for SQuAD task"""
+import os
+import collections
+import mindspore.dataset as de
+import mindspore.dataset.transforms.c_transforms as C
+import mindspore.common.dtype as mstype
+from mindspore import context
+from mindspore.common.tensor import Tensor
+from mindspore.train.model import Model
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from src import tokenization
+from src.evaluation_config import cfg, bert_net_cfg
+from src.utils import BertSquad
+from src.create_squad_data import read_squad_examples, convert_examples_to_features
+from src.run_squad import write_predictions
+
+def get_squad_dataset(batch_size=1, repeat_count=1, distribute_file=''):
+    """get SQuAD dataset from tfrecord"""
+    ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask",
+                                                                            "segment_ids", "unique_ids"],
+                            shuffle=False)
+    type_cast_op = C.TypeCast(mstype.int32)
+    ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
+    ds = ds.map(input_columns="input_ids", operations=type_cast_op)
+    ds = ds.map(input_columns="input_mask", operations=type_cast_op)
+    ds = ds.repeat(repeat_count)
+    ds = ds.batch(batch_size, drop_remainder=True)
+    return ds
+
+def test_eval():
+    """Evaluation function for SQuAD task"""
+    tokenizer = tokenization.FullTokenizer(vocab_file="./vocab.txt", do_lower_case=True)
+    input_file = "dataset/v1.1/dev-v1.1.json"
+    eval_examples = read_squad_examples(input_file, False)
+    eval_features = convert_examples_to_features(
+        examples=eval_examples,
+        tokenizer=tokenizer,
+        max_seq_length=384,
+        doc_stride=128,
+        max_query_length=64,
+        is_training=False,
+        output_fn=None,
+        verbose_logging=False)
+
+    device_id = int(os.getenv('DEVICE_ID'))
+    context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=device_id)
+    dataset = get_squad_dataset(bert_net_cfg.batch_size, 1)
+    net = BertSquad(bert_net_cfg, False, 2)
+    net.set_train(False)
+    param_dict = load_checkpoint(cfg.finetune_ckpt)
+    load_param_into_net(net, param_dict)
+    model = Model(net)
+    output = []
+    RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])
+    columns_list = ["input_ids", "input_mask", "segment_ids", "unique_ids"]
+    for data in dataset.create_dict_iterator():
+        input_data = []
+        for i in columns_list:
+            input_data.append(Tensor(data[i]))
+        input_ids, input_mask, segment_ids, unique_ids = input_data
+        start_positions = Tensor([1], mstype.float32)
+        end_positions = Tensor([1], mstype.float32)
+        is_impossible = Tensor([1], mstype.float32)
+        logits = model.predict(input_ids, input_mask, segment_ids, start_positions,
+                               end_positions, unique_ids, is_impossible)
+        ids = logits[0].asnumpy()
+        start = logits[1].asnumpy()
+        end = logits[2].asnumpy()
+        for i in range(bert_net_cfg.batch_size):
+            unique_id = int(ids[i])
+            start_logits = [float(x) for x in start[i].flat]
+            end_logits = [float(x) for x in end[i].flat]
+            output.append(RawResult(
+                unique_id=unique_id,
+                start_logits=start_logits,
+                end_logits=end_logits))
+    write_predictions(eval_examples, eval_features, output, 20, 30, True, "./predictions.json",
+                      None, None, False, False)
+
+if __name__ == "__main__":
+    test_eval()
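
`test_eval` only produces `./predictions.json`; converting the raw start/end logits into answer strings is delegated to `write_predictions` (the `20` and `30` arguments are presumably the usual `n_best_size` and `max_answer_length`), and the final metrics come from `SQuAD_postprocess.py`, which is not part of this diff. As a rough illustration of what span selection does with those logits, a deliberately naive sketch:

```python
# Naive span selection from one example's start/end logits (illustration only;
# write_predictions in the repo additionally handles n-best lists and
# token-to-original-text mapping).
import numpy as np

def best_span(start_logits, end_logits, max_answer_length=30):
    """Return (start, end, score) for the highest-scoring span of bounded length."""
    best = (0, 0, float("-inf"))
    for s in range(len(start_logits)):
        for e in range(s, min(s + max_answer_length, len(end_logits))):
            score = start_logits[s] + end_logits[e]
            if score > best[2]:
                best = (s, e, score)
    return best

print(best_span(np.random.randn(384), np.random.randn(384)))
```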

@@ -43,6 +43,7 @@ cfg = edict({
         'start_learning_rate': 2e-5,
         'end_learning_rate': 1e-7,
         'power': 1.0,
+        'weight_decay': 0.01,
         'decay_filter': lambda x: False,
     }),
     'Momentum': edict({

@@ -29,7 +29,7 @@ from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.communication.management import get_group_size
 from mindspore import context
-from mindspore.model_zoo.Bert_NEZHA.bert_model import BertModel
+from .bert_model import BertModel
 from .bert_for_pre_training import clip_grad
 from .CRF import CRF

@@ -131,6 +131,98 @@ class BertFinetuneCell(nn.Cell):
         ret = (loss, cond)
         return F.depend(ret, succ)
+class BertSquadCell(nn.Cell):
+    """
+    Specifically defined for SQuAD finetuning, where seven input tensors are needed.
+    """
+    def __init__(self, network, optimizer, scale_update_cell=None):
+        super(BertSquadCell, self).__init__(auto_prefix=False)
+        self.network = network
+        self.weights = ParameterTuple(network.trainable_params())
+        self.optimizer = optimizer
+        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
+        self.reducer_flag = False
+        self.allreduce = P.AllReduce()
+        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
+        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
+            self.reducer_flag = True
+        self.grad_reducer = None
+        if self.reducer_flag:
+            mean = context.get_auto_parallel_context("mirror_mean")
+            degree = get_group_size()
+            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
+        self.cast = P.Cast()
+        self.alloc_status = P.NPUAllocFloatStatus()
+        self.get_status = P.NPUGetFloatStatus()
+        self.clear_before_grad = P.NPUClearFloatStatus()
+        self.reduce_sum = P.ReduceSum(keep_dims=False)
+        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
+        self.base = Tensor(1, mstype.float32)
+        self.less_equal = P.LessEqual()
+        self.hyper_map = C.HyperMap()
+        self.loss_scale = None
+        self.loss_scaling_manager = scale_update_cell
+        if scale_update_cell:
+            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
+                                        name="loss_scale")
+
+    def construct(self,
+                  input_ids,
+                  input_mask,
+                  token_type_id,
+                  start_position,
+                  end_position,
+                  unique_id,
+                  is_impossible,
+                  sens=None):
+        weights = self.weights
+        init = self.alloc_status()
+        loss = self.network(input_ids,
+                            input_mask,
+                            token_type_id,
+                            start_position,
+                            end_position,
+                            unique_id,
+                            is_impossible)
+        if sens is None:
+            scaling_sens = self.loss_scale
+        else:
+            scaling_sens = sens
+        grads = self.grad(self.network, weights)(input_ids,
+                                                 input_mask,
+                                                 token_type_id,
+                                                 start_position,
+                                                 end_position,
+                                                 unique_id,
+                                                 is_impossible,
+                                                 self.cast(scaling_sens,
+                                                           mstype.float32))
+        clear_before_grad = self.clear_before_grad(init)
+        F.control_depend(loss, init)
+        self.depend_parameter_use(clear_before_grad, scaling_sens)
+        grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
+        if self.reducer_flag:
+            grads = self.grad_reducer(grads)
+        flag = self.get_status(init)
+        flag_sum = self.reduce_sum(init, (0,))
+        if self.is_distributed:
+            flag_reduce = self.allreduce(flag_sum)
+            cond = self.less_equal(self.base, flag_reduce)
+        else:
+            cond = self.less_equal(self.base, flag_sum)
+        F.control_depend(grads, flag)
+        F.control_depend(flag, flag_sum)
+        overflow = cond
+        if sens is None:
+            overflow = self.loss_scaling_manager(self.loss_scale, cond)
+        if overflow:
+            succ = False
+        else:
+            succ = self.optimizer(grads)
+        ret = (loss, cond)
+        return F.depend(ret, succ)
+
 class BertCLSModel(nn.Cell):
     """
     This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3),
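
`BertSquadCell` follows the same pattern as `BertFinetuneCell`: compute the loss, back-propagate with the loss scale as sensitivity, unscale and clip the gradients, read the NPU float status to detect overflow, and only let the optimizer run when nothing overflowed, while `DynamicLossScaleUpdateCell` adjusts the scale. A plain-Python sketch of that control flow (illustrative, not MindSpore graph code):

```python
def overflow_gated_step(loss_scale, compute_scaled_grads, grads_overflowed,
                        apply_optimizer, stable_steps, scale_factor=2, scale_window=1000):
    """One update mirroring BertSquadCell plus DynamicLossScaleUpdateCell (sketch)."""
    grads = compute_scaled_grads(loss_scale)    # backward pass with sens = loss_scale
    if grads_overflowed(grads):                 # float status reported inf/nan
        return loss_scale / scale_factor, 0     # skip the update, shrink the scale
    apply_optimizer(grads)                      # optimizer step on the unscaled gradients
    stable_steps += 1
    if stable_steps >= scale_window:            # grow the scale after a stable window
        return loss_scale * scale_factor, 0
    return loss_scale, stable_steps
```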

@@ -159,6 +251,30 @@ class BertCLSModel(nn.Cell):
         log_probs = self.log_softmax(logits)
         return log_probs
 
+class BertSquadModel(nn.Cell):
+    '''
+    This class is responsible for the SQuAD task, i.e. predicting the start and end positions of the answer span.
+    '''
+    def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
+        super(BertSquadModel, self).__init__()
+        self.bert = BertModel(config, is_training, use_one_hot_embeddings)
+        self.weight_init = TruncatedNormal(config.initializer_range)
+        self.dense1 = nn.Dense(config.hidden_size, num_labels, weight_init=self.weight_init,
+                               has_bias=True).to_float(config.compute_type)
+        self.num_labels = num_labels
+        self.dtype = config.dtype
+        self.log_softmax = P.LogSoftmax(axis=1)
+        self.is_training = is_training
+
+    def construct(self, input_ids, input_mask, token_type_id):
+        sequence_output, _, _ = self.bert(input_ids, token_type_id, input_mask)
+        batch_size, seq_length, hidden_size = P.Shape()(sequence_output)
+        sequence = P.Reshape()(sequence_output, (-1, hidden_size))
+        logits = self.dense1(sequence)
+        logits = P.Cast()(logits, self.dtype)
+        logits = P.Reshape()(logits, (batch_size, seq_length, self.num_labels))
+        logits = self.log_softmax(logits)
+        return logits
+
 class BertNERModel(nn.Cell):
     """

@@ -261,3 +377,36 @@ class BertNER(nn.Cell):
         else:
             loss = self.loss(logits, label_ids, self.num_labels)
         return loss
+
+class BertSquad(nn.Cell):
+    '''
+    Train interface for SQuAD finetuning task.
+    '''
+    def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False):
+        super(BertSquad, self).__init__()
+        self.bert = BertSquadModel(config, is_training, num_labels, dropout_prob, use_one_hot_embeddings)
+        self.loss = CrossEntropyCalculation(is_training)
+        self.num_labels = num_labels
+        self.seq_length = config.seq_length
+        self.is_training = is_training
+        self.total_num = Parameter(Tensor([0], mstype.float32), name='total_num')
+        self.start_num = Parameter(Tensor([0], mstype.float32), name='start_num')
+        self.end_num = Parameter(Tensor([0], mstype.float32), name='end_num')
+        self.sum = P.ReduceSum()
+        self.equal = P.Equal()
+        self.argmax = P.ArgMaxWithValue(axis=1)
+        self.squeeze = P.Squeeze(axis=-1)
+
+    def construct(self, input_ids, input_mask, token_type_id, start_position, end_position, unique_id, is_impossible):
+        logits = self.bert(input_ids, input_mask, token_type_id)
+        if self.is_training:
+            unstacked_logits_0 = self.squeeze(logits[:, :, 0:1])
+            unstacked_logits_1 = self.squeeze(logits[:, :, 1:2])
+            start_loss = self.loss(unstacked_logits_0, start_position, self.seq_length)
+            end_loss = self.loss(unstacked_logits_1, end_position, self.seq_length)
+            total_loss = (start_loss + end_loss) / 2.0
+        else:
+            start_logits = self.squeeze(logits[:, :, 0:1])
+            end_logits = self.squeeze(logits[:, :, 1:2])
+            total_loss = (unique_id, start_logits, end_logits)
+        return total_loss
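
During training, `BertSquad` squeezes the two channels into per-position start and end log-probabilities and averages their cross-entropy losses against the gold positions; in eval mode it instead returns the `(unique_id, start_logits, end_logits)` tuple consumed by `squadeval.py`. A small numeric illustration of the training loss (dummy uniform distributions and made-up gold indices):

```python
# Illustrative only: the training loss is the mean of the start and end
# cross-entropies over the token positions.
import numpy as np

seq_length = 384
start_log_probs = np.log(np.full(seq_length, 1.0 / seq_length))  # dummy uniform log-probs
end_log_probs = np.log(np.full(seq_length, 1.0 / seq_length))
start_position, end_position = 17, 25                            # made-up gold indices

start_loss = -start_log_probs[start_position]
end_loss = -end_log_probs[end_position]
total_loss = (start_loss + end_loss) / 2.0
print(total_loss)   # log(384) in this uniform case
```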