@@ -39,6 +39,7 @@ import mindspore.dataset.engine.datasets as de
 import mindspore.dataset.transforms.c_transforms as C
 from mindspore import context
 from mindspore.common.tensor import Tensor
+import mindspore.common.dtype as mstype
 from mindspore.train.model import Model
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
 from mindspore.model_zoo.Bert_NEZHA import BertNetworkWithLoss, BertTrainOneStepCell
@@ -49,9 +50,9 @@ def create_train_dataset(batch_size):
     """create train dataset"""
     # apply repeat operations
     repeat_count = bert_train_cfg.epoch_size
-    ds = de.StorageDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
-                           columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
-                                         "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
+    ds = de.TFRecordDataset([bert_train_cfg.DATA_DIR], bert_train_cfg.SCHEMA_DIR,
+                            columns_list=["input_ids", "input_mask", "segment_ids", "next_sentence_labels",
+                                          "masked_lm_positions", "masked_lm_ids", "masked_lm_weights"])
     type_cast_op = C.TypeCast(mstype.int32)
     ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op)
     ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op)