|
|
|
@@ -39,6 +39,7 @@ import mindspore.dataset.engine.datasets as de |
|
|
|
import mindspore.dataset.transforms.c_transforms as C |
|
|
|
from mindspore import context |
|
|
|
from mindspore.common.tensor import Tensor |
|
|
|
import mindspore.common.dtype as mstype |
|
|
|
from mindspore.train.model import Model |
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor |
|
|
|
from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell |
|
|
|
@@ -49,9 +50,9 @@ def create_train_dataset(batch_size): |
|
|
|
"""create train dataset""" |
|
|
|
# apply repeat operations |
|
|
|
repeat_count = bert_train_cfg.epoch_size |
|
|
|
ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR, |
|
|
|
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", |
|
|
|
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"]) |
|
|
|
ds = de.TFRecordDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR, |
|
|
|
columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels", |
|
|
|
"masked_lm_positions", "masked_lm_ids", "masked_lm_weights"]) |
|
|
|
type_cast_op = C.TypeCast(mstype.int32) |
|
|
|
ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) |
|
|
|
ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) |
|
|
|
|