| @@ -26,8 +26,8 @@ import mindspore.dataset.text as text | |||
| import mindspore.dataset.transforms.c_transforms as ops | |||
| def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path, | |||
| data_usage='train', shuffle_dataset=False, max_seq_len=128, batch_size=64): | |||
| def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage='train', shuffle_dataset=False, | |||
| max_seq_len=128, batch_size=64, drop_remainder=True): | |||
| """Process TNEWS dataset""" | |||
| ### Loading TNEWS from CLUEDataset | |||
| assert data_usage in ['train', 'eval', 'test'] | |||
| @@ -61,26 +61,17 @@ def process_tnews_clue_dataset(data_dir, label_list, bert_vocab_path, | |||
| dataset = dataset.map(input_columns=["sentence"], output_columns=["text_ids"], operations=lookup) | |||
| dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0)) | |||
| dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"], | |||
| columns_order=["label_id", "text_ids", "mask_ids"], operations=ops.Duplicate()) | |||
| columns_order=["text_ids", "mask_ids", "label_id"], operations=ops.Duplicate()) | |||
| dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32)) | |||
| dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "segment_ids"], | |||
| columns_order=["label_id", "text_ids", "mask_ids", "segment_ids"], operations=ops.Duplicate()) | |||
| columns_order=["text_ids", "mask_ids", "segment_ids", "label_id"], operations=ops.Duplicate()) | |||
| dataset = dataset.map(input_columns=["segment_ids"], operations=ops.Fill(0)) | |||
| dataset = dataset.batch(batch_size) | |||
| label = [] | |||
| text_ids = [] | |||
| mask_ids = [] | |||
| segment_ids = [] | |||
| for data in dataset: | |||
| label.append(data[0]) | |||
| text_ids.append(data[1]) | |||
| mask_ids.append(data[2]) | |||
| segment_ids.append(data[3]) | |||
| return label, text_ids, mask_ids, segment_ids | |||
| dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) | |||
| return dataset | |||
| def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path, | |||
| data_usage='train', shuffle_dataset=False, max_seq_len=128, batch_size=64): | |||
| def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path, data_usage='train', shuffle_dataset=False, | |||
| max_seq_len=128, batch_size=64, drop_remainder=True): | |||
| """Process CMNLI dataset""" | |||
| ### Loading CMNLI from CLUEDataset | |||
| assert data_usage in ['train', 'eval', 'test'] | |||
| @@ -138,16 +129,7 @@ def process_cmnli_clue_dataset(data_dir, label_list, bert_vocab_path, | |||
| dataset = dataset.map(input_columns=["text_ids"], operations=ops.PadEnd([max_seq_len], 0)) | |||
| ### Generating mask_ids | |||
| dataset = dataset.map(input_columns=["text_ids"], output_columns=["text_ids", "mask_ids"], | |||
| columns_order=["label_id", "text_ids", "mask_ids", "segment_ids"], operations=ops.Duplicate()) | |||
| columns_order=["text_ids", "mask_ids", "segment_ids", "label_id"], operations=ops.Duplicate()) | |||
| dataset = dataset.map(input_columns=["mask_ids"], operations=ops.Mask(ops.Relational.NE, 0, mstype.int32)) | |||
| dataset = dataset.batch(batch_size) | |||
| label = [] | |||
| text_ids = [] | |||
| mask_ids = [] | |||
| segment_ids = [] | |||
| for data in dataset: | |||
| label.append(data[0]) | |||
| text_ids.append(data[1]) | |||
| mask_ids.append(data[2]) | |||
| segment_ids.append(data[3]) | |||
| return label, text_ids, mask_ids, segment_ids | |||
| dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) | |||
| return dataset | |||