Merge pull request !1821 from yoonlee666/pretrain-eva
@@ -19,7 +19,7 @@ Bert evaluation script.
 import os
 from src import BertModel, GetMaskedLMOutput
-from evaluation_config import cfg, bert_net_cfg
+from src.evaluation_config import cfg, bert_net_cfg
 import mindspore.common.dtype as mstype
 from mindspore import context
 from mindspore.common.tensor import Tensor
@@ -87,17 +87,18 @@ class BertPretrainEva(nn.Cell):
         self.cast = P.Cast()

-    def construct(self, input_ids, input_mask, token_type_id, masked_pos, masked_ids, nsp_label, masked_weights):
+    def construct(self, input_ids, input_mask, token_type_id, masked_pos, masked_ids, masked_weights, nsp_label):
         bs, _ = self.shape(input_ids)
         probs = self.bert(input_ids, input_mask, token_type_id, masked_pos)
         index = self.argmax(probs)
         index = self.reshape(index, (bs, -1))
         eval_acc = self.equal(index, masked_ids)
         eval_acc1 = self.cast(eval_acc, mstype.float32)
-        acc = self.mean(eval_acc1)
-        P.Print()(acc)
-        self.total += self.shape(probs)[0]
-        self.acc += self.sum(eval_acc1)
+        real_acc = eval_acc1 * masked_weights
+        acc = self.sum(real_acc)
+        total = self.sum(masked_weights)
+        self.total += total
+        self.acc += acc
         return acc, self.total, self.acc
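Note: the accuracy the updated `construct` accumulates is weighted by `masked_lm_weights`, so padded mask positions no longer inflate either the numerator or the denominator. A minimal NumPy sketch of the same bookkeeping (function and variable names here are illustrative, not from the patch):

```python
import numpy as np

def masked_lm_accuracy(pred_ids, masked_ids, masked_weights):
    """Weighted masked-LM accuracy: correct predictions over real masked positions.

    All arrays have shape (batch, max_masked); masked_weights is 1.0 for real
    masked positions and 0.0 for padding.
    """
    correct = (pred_ids == masked_ids).astype(np.float32)
    acc = np.sum(correct * masked_weights)    # weighted number of hits
    total = np.sum(masked_weights)            # number of real masked positions
    return acc, total, acc / max(total, 1.0)

# Example: the padded third position is excluded from both counts.
pred = np.array([[3, 7, 0], [5, 2, 0]])
gold = np.array([[3, 9, 0], [5, 2, 0]])
weights = np.array([[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]])
print(masked_lm_accuracy(pred, gold, weights))  # (3.0, 4.0, 0.75)
```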
@@ -107,8 +108,8 @@ def get_enwiki_512_dataset(batch_size=1, repeat_count=1, distribute_file=''):
     '''
     ds = de.TFRecordDataset([cfg.data_file], cfg.schema_file, columns_list=["input_ids", "input_mask", "segment_ids",
                                                                             "masked_lm_positions", "masked_lm_ids",
-                                                                            "next_sentence_labels",
-                                                                            "masked_lm_weights"])
+                                                                            "masked_lm_weights",
+                                                                            "next_sentence_labels"])
     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
     ds = ds.map(input_columns="input_mask", operations=type_cast_op)
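The column reorder keeps the dataset output aligned with the new `construct` signature above, since the columns are bound to the evaluation network's inputs positionally. A hedged illustration of the intended mapping (column names are from the patch, the pairing is as assumed):

```python
# Dataset column (in order)   ->  BertPretrainEva.construct parameter
#   "input_ids"               ->  input_ids
#   "input_mask"              ->  input_mask
#   "segment_ids"             ->  token_type_id
#   "masked_lm_positions"     ->  masked_pos
#   "masked_lm_ids"           ->  masked_ids
#   "masked_lm_weights"       ->  masked_weights
#   "next_sentence_labels"    ->  nsp_label
columns_list = ["input_ids", "input_mask", "segment_ids",
                "masked_lm_positions", "masked_lm_ids",
                "masked_lm_weights", "next_sentence_labels"]
```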
@@ -143,7 +144,8 @@ def MLM_eval():
     Evaluate function
     '''
     _, dataset, net_for_pretraining = bert_predict()
-    net = Model(net_for_pretraining, eval_network=net_for_pretraining, eval_indexes=[0, 1, 2], metrics={myMetric()})
+    net = Model(net_for_pretraining, eval_network=net_for_pretraining, eval_indexes=[0, 1, 2],
+                metrics={'name': myMetric()})
     res = net.eval(dataset, dataset_sink_mode=False)
     print("==============================================================")
     for _, v in res.items():
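Passing `metrics` as a dict gives the custom metric a name, so its value shows up as a key in the result returned by `net.eval(dataset)`. A rough sketch of what a metric consuming the three evaluation outputs could look like, assuming the (acc, total, count) ordering returned by `BertPretrainEva.construct`; this is illustrative and not the repository's `myMetric`:

```python
from mindspore.nn.metrics import Metric  # import path as in MindSpore 0.5-era releases

class MaskedLMAccuracy(Metric):
    """Accumulate the weighted correct/total counts emitted by the eval network."""
    def __init__(self):
        super().__init__()
        self.clear()

    def clear(self):
        self.total_num = 0.0
        self.acc_num = 0.0

    def update(self, *inputs):
        # With eval_indexes=[0, 1, 2], the three network outputs are assumed to
        # arrive here; the last two are the running totals kept by the network.
        total, acc = inputs[-2], inputs[-1]
        self.total_num = float(total.asnumpy()) if hasattr(total, "asnumpy") else float(total)
        self.acc_num = float(acc.asnumpy()) if hasattr(acc, "asnumpy") else float(acc)

    def eval(self):
        return self.acc_num / max(self.total_num, 1.0)
```

Registered as `metrics={'name': MaskedLMAccuracy()}`, the final value would then be read back as `res['name']` after evaluation.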
@@ -66,6 +66,8 @@ def run_pretrain():
     parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
     parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
                                                                                 "default is 1000.")
+    parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, "
+                                                                    "meaning run all steps according to epoch number.")
     parser.add_argument("--save_checkpoint_num", type=int, default=1, help="Save checkpoint numbers, default is 1.")
     parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
     parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
@@ -93,11 +95,12 @@ def run_pretrain():
     ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
                                                args_opt.enable_data_sink, args_opt.data_sink_steps,
                                                args_opt.data_dir, args_opt.schema_dir)
+    if args_opt.train_steps > 0:
+        new_repeat_count = min(new_repeat_count, args_opt.train_steps // args_opt.data_sink_steps)
     netwithloss = BertNetworkWithLoss(bert_net_cfg, True)
     if cfg.optimizer == 'Lamb':
-        optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
+        optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size() * new_repeat_count,
                          start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
                          power=cfg.Lamb.power, warmup_steps=cfg.Lamb.warmup_steps, weight_decay=cfg.Lamb.weight_decay,
                          eps=cfg.Lamb.eps)
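The clamp on `new_repeat_count` and the switch from `ds.get_repeat_count()` to `new_repeat_count` keep the number of sink iterations and the optimizer's `decay_steps` consistent once `--train_steps` is set. A small sketch of the arithmetic with illustrative numbers (not taken from the patch):

```python
def effective_repeat_count(new_repeat_count, train_steps, data_sink_steps):
    """Mirror the clamp added in run_pretrain: cap the sink-mode repeat count
    when --train_steps is positive; train_steps <= 0 means 'run all epochs'."""
    if train_steps > 0:
        return min(new_repeat_count, train_steps // data_sink_steps)
    return new_repeat_count

# Illustrative numbers: the dataset would allow 300 sink iterations, but we
# only want ~10000 training steps at 100 steps per sink iteration.
repeat = effective_repeat_count(new_repeat_count=300, train_steps=10000, data_sink_steps=100)
dataset_size = 100                     # assumed steps per repeat in sink mode
decay_steps = dataset_size * repeat    # what Lamb / AdamWeightDecayDynamicLR now see
print(repeat, decay_steps)             # 100 10000
```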
@@ -106,7 +109,7 @@ def run_pretrain():
                              momentum=cfg.Momentum.momentum)
     elif cfg.optimizer == 'AdamWeightDecayDynamicLR':
         optimizer = AdamWeightDecayDynamicLR(netwithloss.trainable_params(),
-                                             decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
+                                             decay_steps=ds.get_dataset_size() * new_repeat_count,
                                              learning_rate=cfg.AdamWeightDecayDynamicLR.learning_rate,
                                              end_learning_rate=cfg.AdamWeightDecayDynamicLR.end_learning_rate,
                                              power=cfg.AdamWeightDecayDynamicLR.power,
@@ -19,8 +19,8 @@ import json
 import numpy as np
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
-import tokenization
-from sample_process import label_generation, process_one_example_p
+from . import tokenization
+from .sample_process import label_generation, process_one_example_p
 from .evaluation_config import cfg
 from .CRF import postprocess
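The switch to explicit relative imports assumes the file is loaded as part of its package rather than executed by path. A short illustration of the difference (package layout and module names here are hypothetical):

```python
# Hypothetical layout:
#   src/
#     __init__.py
#     tokenization.py
#     sample_process.py
#     cluener_evaluation.py      # the patched file; name assumed
#
# "from . import tokenization" works when the module is imported as
# src.cluener_evaluation (e.g. "python -m src.cluener_evaluation" or
# "from src import cluener_evaluation"), because __package__ is set to "src".
#
# Running the file directly by path ("python src/cluener_evaluation.py") leaves
# it without a parent package, and the relative imports fail with
# "ImportError: attempted relative import with no known parent package".
```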