
!3693 Change the column order and add drop_remainder option to make the generated dataset compatible with BertCLS

Merge pull request !3693 from dessyang/master
tags/v0.7.0-beta
mindspore-ci-bot committed 5 years ago
parent commit 1165b27f41
1 changed file with 11 additions and 29 deletions

+11 -29  model_zoo/official/nlp/bert/src/clue_classification_dataset_process.py

@@ -26,8 +26,8 @@ import mindspore.dataset.text as text
 import mindspore.dataset.transforms.c_transforms as ops
 
 
-def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path,
-                               data_usage='train', shuffle_dataset=False, max_seq_len=128, batch_size=64):
+def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage='train', shuffle_dataset=False,
+                               max_seq_len=128, batch_size=64, drop_remainder=True):
     """Process TNEWS dataset"""
     ### Loading TNEWS from CLUEDataset
     assert data_usage in ['train', 'eval', 'test']
@@ -61,26 +61,17 @@ def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path,
     dataset = dataset.map(input_columns=["sentence"], output_columns=["text_ids"], operations=lookup)
     dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0))
     dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"],
-                          columns_order=["label_id", "text_ids", "mask_ids"], operations=ops.Duplicate())
+                          columns_order=["text_ids", "mask_ids", "label_id"], operations=ops.Duplicate())
     dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32))
     dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "segment_ids"],
-                          columns_order=["label_id", "text_ids", "mask_ids", "segment_ids"], operations=ops.Duplicate())
+                          columns_order=["text_ids", "mask_ids", "segment_ids", "label_id"], operations=ops.Duplicate())
     dataset = dataset.map(input_columns=["segment_ids"], operations=ops.Fill(0))
-    dataset = dataset.batch(batch_size)
-    label = []
-    text_ids = []
-    mask_ids = []
-    segment_ids = []
-    for data in dataset:
-        label.append(data[0])
-        text_ids.append(data[1])
-        mask_ids.append(data[2])
-        segment_ids.append(data[3])
-    return label, text_ids, mask_ids, segment_ids
+    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+    return dataset
 
 
-def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path,
-                               data_usage='train', shuffle_dataset=False, max_seq_len=128, batch_size=64):
+def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage='train', shuffle_dataset=False,
+                               max_seq_len=128, batch_size=64, drop_remainder=True):
     """Process CMNLI dataset"""
     ### Loading CMNLI from CLUEDataset
     assert data_usage in ['train', 'eval', 'test']
@@ -138,16 +129,7 @@ def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path,
     dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0))
     ### Generating mask_ids
     dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"],
-                          columns_order=["label_id", "text_ids", "mask_ids", "segment_ids"], operations=ops.Duplicate())
+                          columns_order=["text_ids", "mask_ids", "segment_ids", "label_id"], operations=ops.Duplicate())
     dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32))
-    dataset = dataset.batch(batch_size)
-    label = []
-    text_ids = []
-    mask_ids = []
-    segment_ids = []
-    for data in dataset:
-        label.append(data[0])
-        text_ids.append(data[1])
-        mask_ids.append(data[2])
-        segment_ids.append(data[3])
-    return label, text_ids, mask_ids, segment_ids
+    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+    return dataset
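
For context, a minimal usage sketch of the updated TNEWS helper (the data path, vocab path, and label list below are hypothetical placeholders; create_dict_iterator is the standard MindSpore dataset iterator). After this change the function returns a batched dataset whose columns arrive as text_ids, mask_ids, segment_ids, label_id, the order BertCLS consumes, instead of four Python lists:

from src.clue_classification_dataset_process import process_tnews_clue_dataset

labels = ["100", "101", "102"]  # hypothetical subset of TNEWS label ids
ds = process_tnews_clue_dataset("/path/to/tnews", labels, "/path/to/bert_vocab.txt",
                                data_usage="train", shuffle_dataset=True,
                                max_seq_len=128, batch_size=64,
                                drop_remainder=True)  # drop the tail batch so every batch is full-size

for batch in ds.create_dict_iterator():
    # Column order after this PR: text_ids, mask_ids, segment_ids, label_id
    print(batch["text_ids"].shape, batch["mask_ids"].shape,
          batch["segment_ids"].shape, batch["label_id"].shape)
    break

Dropping the remainder keeps every batch at exactly batch_size, which matters for graph-mode networks like BertCLS that compile against fixed input shapes.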
