@@ -0,0 +1,173 @@
# Guideline to Convert enwiki Training Data to MindRecord for BERT Pre-Training
| <!-- TOC --> | |||
| - [What does the example do](#what-does-the-example-do) | |||
| - [How to use the example to process enwiki](#how-to-use-the-example-to-process-enwiki) | |||
| - [Download enwiki training data](#download-enwiki-training-data) | |||
| - [Process the enwiki](#process-the-enwiki) | |||
| - [Generate MindRecord](#generate-mindrecord) | |||
| - [Create MindDataset By MindRecord](#create-minddataset-by-mindrecord) | |||
| <!-- /TOC --> | |||
## What does the example do

This example generates MindRecord files from the [enwiki](https://dumps.wikimedia.org/enwiki) training data; the files are then used for BERT network pre-training.

1. run.sh: entry script that generates the MindRecord files.
2. run_read.sh: entry script that creates a MindDataset from the MindRecord files.
    - create_dataset.py: uses MindDataset to read the MindRecord files and build the dataset.
## How to use the example to process enwiki

Download the enwiki training data, process it, convert it to MindRecord, and then read the MindRecord files with MindDataset.
### Download enwiki training data

> [enwiki dataset download address](https://dumps.wikimedia.org/enwiki) **-> 20200501 -> enwiki-20200501-pages-articles-multistream.xml.bz2**
### Process the enwiki

1. Follow the steps in [process enwiki](https://github.com/mlperf/training/tree/master/language_model/tensorflow/bert) to preprocess the dump.
    - All rights to this processing step belong to the linked website.
### Generate MindRecord
1. Run the run.sh script.
```
bash run.sh input_dir output_dir vocab_file
```
- input_dir: the directory that contains input files like 'part-00251-of-00500'.
- output_dir: the directory that will store the generated MindRecord files.
- vocab_file: the vocabulary file, which you can download from another open-source project.
2. The output looks like this:
```
...
Begin preprocess Wed Jun 10 09:21:23 CST 2020
Begin preprocess input file: /mnt/data/results/part-00000-of-00500
Begin output file: part-00000-of-00500.mindrecord
Total task: 510, processing: 1
Begin preprocess input file: /mnt/data/results/part-00001-of-00500
Begin output file: part-00001-of-00500.mindrecord
Total task: 510, processing: 2
Begin preprocess input file: /mnt/data/results/part-00002-of-00500
Begin output file: part-00002-of-00500.mindrecord
Total task: 510, processing: 3
Begin preprocess input file: /mnt/data/results/part-00003-of-00500
Begin output file: part-00003-of-00500.mindrecord
Total task: 510, processing: 4
Begin preprocess input file: /mnt/data/results/part-00004-of-00500
Begin output file: part-00004-of-00500.mindrecord
Total task: 510, processing: 4
...
```
3. MindRecord files like the following are generated:
```bash
$ ls {your_output_dir}/
part-00000-of-00500.mindrecord part-00000-of-00500.mindrecord.db part-00001-of-00500.mindrecord part-00001-of-00500.mindrecord.db part-00002-of-00500.mindrecord part-00002-of-00500.mindrecord.db ...
```
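Under the hood, run.sh patches the upstream create_pretraining_data.py so that each example is written with MindSpore's `FileWriter` instead of a TFRecord writer (the patch hunks are shown at the end of this document). Below is a minimal sketch of that write path; the file name and the all-zero example values are chosen purely for illustration:
```python
import numpy as np
from mindspore.mindrecord import FileWriter

# schema matching the one the patch adds
data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
               "input_mask": {"type": "int64", "shape": [-1]},
               "segment_ids": {"type": "int64", "shape": [-1]},
               "masked_lm_positions": {"type": "int64", "shape": [-1]},
               "masked_lm_ids": {"type": "int64", "shape": [-1]},
               "masked_lm_weights": {"type": "float32", "shape": [-1]},
               "next_sentence_labels": {"type": "int64", "shape": [-1]}}

writer = FileWriter("part-00000-of-00500.mindrecord", shard_num=1)
writer.add_schema(data_schema, "zhwiki schema")

# one dummy pre-training example; the real values come from the tokenized corpus
features = {"input_ids": np.zeros(512, np.int64),
            "input_mask": np.zeros(512, np.int64),
            "segment_ids": np.zeros(512, np.int64),
            "masked_lm_positions": np.zeros(76, np.int64),
            "masked_lm_ids": np.zeros(76, np.int64),
            "masked_lm_weights": np.zeros(76, np.float32),
            "next_sentence_labels": np.asarray([1], np.int64)}
writer.write_raw_data([features])
writer.commit()  # produces the .mindrecord data file and its .db index
```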
### Create MindDataset By MindRecord
1. Run the run_read.sh script.
```bash
bash run_read.sh input_dir
```
- input_dir: the directory that contains the MindRecord files.
2. The output looks like this:
```
...
example 633: input_ids: [ 101 2043 19781 4305 2140 4520 2041 1010 103 2034 2455 2002
7879 2003 1996 2455 1997 103 26378 4160 1012 102 7291 2001
1996 103 1011 2343 1997 6327 1010 3423 1998 103 4262 2005
1996 2118 1997 2329 3996 103 102 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0]
example 633: input_mask: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 633: segment_ids: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
example 633: masked_lm_positions: [ 8 17 20 25 33 41 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0]
example 633: masked_lm_ids: [ 1996 16137 1012 3580 2451 1012 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0]
example 633: masked_lm_weights: [1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0.]
example 633: next_sentence_labels: [1]
...
```
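For actual pre-training, the per-example printing above would typically be replaced by shuffling and batching before the data is fed to the network. A minimal sketch, assuming the standard MindSpore dataset API; the batch size and column list here are illustrative choices, not part of this example:
```python
import mindspore.dataset as ds

def create_bert_pretrain_dataset(data_files, batch_size=32):
    """Build a batched, shuffled MindDataset over the generated MindRecord files."""
    data_set = ds.MindDataset(dataset_file=data_files,
                              columns_list=["input_ids", "input_mask", "segment_ids",
                                            "masked_lm_positions", "masked_lm_ids",
                                            "masked_lm_weights", "next_sentence_labels"],
                              num_parallel_workers=4, shuffle=True)
    # drop_remainder keeps every batch the same shape, as graph compilation expects
    return data_set.batch(batch_size, drop_remainder=True)
```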
@@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
| """create MindDataset by MindRecord""" | |||
| import argparse | |||
| import mindspore.dataset as ds | |||
| def create_dataset(data_file): | |||
| """create MindDataset""" | |||
| num_readers = 4 | |||
| data_set = ds.MindDataset(dataset_file=data_file, num_parallel_workers=num_readers, shuffle=True) | |||
| index = 0 | |||
| for item in data_set.create_dict_iterator(): | |||
| # print("example {}: {}".format(index, item)) | |||
| print("example {}: input_ids: {}".format(index, item['input_ids'])) | |||
| print("example {}: input_mask: {}".format(index, item['input_mask'])) | |||
| print("example {}: segment_ids: {}".format(index, item['segment_ids'])) | |||
| print("example {}: masked_lm_positions: {}".format(index, item['masked_lm_positions'])) | |||
| print("example {}: masked_lm_ids: {}".format(index, item['masked_lm_ids'])) | |||
| print("example {}: masked_lm_weights: {}".format(index, item['masked_lm_weights'])) | |||
| print("example {}: next_sentence_labels: {}".format(index, item['next_sentence_labels'])) | |||
| index += 1 | |||
| if index % 1000 == 0: | |||
| print("read rows: {}".format(index)) | |||
| print("total rows: {}".format(index)) | |||
| if __name__ == '__main__': | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument("-i", "--input_file", nargs='+', type=str, help='Input mindreord file') | |||
| args = parser.parse_args() | |||
| create_dataset(args.input_file) | |||
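
# Example invocation (hypothetical paths; one or more .mindrecord files may be passed):
#   python create_dataset.py --input_file /path/to/output_dir/part-00000-of-00500.mindrecord \
#       /path/to/output_dir/part-00001-of-00500.mindrecord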
@@ -0,0 +1,133 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 3 ]; then
    echo "Usage: $0 input_dir output_dir vocab_file"
    exit 1
fi

if [ ! -d "$1" ]; then
    echo "The input dir: $1 does not exist."
    exit 1
fi

if [ ! -d "$2" ]; then
    echo "The output dir: $2 does not exist."
    exit 1
fi
# clean up mindrecord files left over from a previous run
rm -fr "$2"/*.mindrecord*

if [ ! -f "$3" ]; then
    echo "The vocab file: $3 does not exist."
    exit 1
fi
data_dir=$1
output_dir=$2
vocab_file=$3

file_list=()
output_filename=()
file_index=0
function getdir() {
    elements=`ls $1`
    for element in ${elements[*]};
    do
        dir_or_file=$1"/"$element
        if [ -d $dir_or_file ];
        then
            getdir $dir_or_file
        else
            file_list[$file_index]=$dir_or_file
            # name the output after the input file's base name, e.g. part-00000-of-00500.mindrecord
            output_filename[$file_index]=`basename ${dir_or_file}`".mindrecord"
            file_index=`expr $file_index + 1`
        fi
    done
}
getdir "${data_dir}"
# echo "The input files: "${file_list[@]}
# echo "The output files: "${output_filename[@]}
| if [ ! -d "../../../third_party/to_mindrecord/zhwiki" ]; then | |||
| echo "The patch base dir ../../../third_party/to_mindrecord/zhwiki is not exist." | |||
| exit 1 | |||
| fi | |||
| if [ ! -f "../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch" ]; then | |||
| echo "The patch file ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch is not exist." | |||
| exit 1 | |||
| fi | |||
# patch for create_pretraining_data.py
patch -p0 -d ../../../third_party/to_mindrecord/zhwiki/ -o create_pretraining_data_patched.py < ../../../third_party/patch/to_mindrecord/zhwiki/create_pretraining_data.patch
if [ $? -ne 0 ]; then
    echo "Patch ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data.py failed"
    exit 1
fi
# get the cpu core count and use about two thirds of the cores for parallel preprocessing
num_cpu_core=`cat /proc/cpuinfo | grep "processor" | wc -l`
available_core_size=`expr $num_cpu_core / 3 \* 2`
| echo "Begin preprocess `date`" | |||
| # using patched script to generate mindrecord | |||
| file_list_len=`expr ${#file_list[*]} - 1` | |||
for index in $(seq 0 $file_list_len); do
    echo "Begin preprocess input file: ${file_list[$index]}"
    echo "Begin output file: ${output_filename[$index]}"
    python ../../../third_party/to_mindrecord/zhwiki/create_pretraining_data_patched.py \
        --input_file=${file_list[$index]} \
        --output_file=${output_dir}/${output_filename[$index]} \
        --partition_number=1 \
        --vocab_file=${vocab_file} \
        --do_lower_case=True \
        --max_seq_length=512 \
        --max_predictions_per_seq=76 \
        --masked_lm_prob=0.15 \
        --random_seed=12345 \
        --dupe_factor=10 >/tmp/${output_filename[$index]}.log 2>&1 &
    process_count=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
    echo "Total task: ${#file_list[*]}, processing: ${process_count}"
    # throttle: once enough workers are running, wait until one of them finishes
    if [ $process_count -ge $available_core_size ]; then
        while true; do
            process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
            if [ $process_count -gt $process_num ]; then
                process_count=$process_num
                break
            fi
            sleep 2
        done
    fi
done
# wait for all remaining workers to finish
process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
while true; do
    if [ $process_num -eq 0 ]; then
        break
    fi
    echo "There are still ${process_num} preprocess jobs running ..."
    sleep 2
    process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
done

echo "Preprocessed all the data successfully."
echo "End preprocess `date`"
@@ -0,0 +1,44 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 1 ]; then
    echo "Usage: $0 input_dir"
    exit 1
fi

if [ ! -d "$1" ]; then
    echo "The input dir: $1 does not exist."
    exit 1
fi
file_list=()
file_index=0

# collect all the mindrecord files from the output dir
function getdir() {
    elements=`ls $1/part-*.mindrecord`
    for element in ${elements[*]};
    do
        file_list[$file_index]=$element
        file_index=`expr $file_index + 1`
    done
}

getdir "$1"
echo "Got all the mindrecord files: ${file_list[*]}"

# create the dataset for training
python create_dataset.py --input_file ${file_list[*]}
@@ -85,7 +85,7 @@ for index in $(seq 0 $file_list_len); do
         --random_seed=12345 \
         --dupe_factor=5 >/tmp/${output_filename[$index]}.log 2>&1 &
     process_count=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
-    echo "Total task: ${file_list_len}, processing: ${process_count}"
+    echo "Total task: ${#file_list[*]}, processing: ${process_count}"
     if [ $process_count -ge $available_core_size ]; then
         while true; do
             process_num=`ps -ef | grep create_pretraining_data_patched | grep -v grep | wc -l`
@@ -89,7 +89,7 @@
 + "segment_ids": {"type": "int64", "shape": [-1]},
 + "masked_lm_positions": {"type": "int64", "shape": [-1]},
 + "masked_lm_ids": {"type": "int64", "shape": [-1]},
-+ "masked_lm_weights": {"type": "float64", "shape": [-1]},
++ "masked_lm_weights": {"type": "float32", "shape": [-1]},
 + "next_sentence_labels": {"type": "int64", "shape": [-1]},
 + }
 + writer.add_schema(data_schema, "zhwiki schema")
@@ -112,13 +112,13 @@
 -
 - writers[writer_index].write(tf_example.SerializeToString())
 - writer_index = (writer_index + 1) % len(writers)
-+ features["input_ids"] = np.asarray(input_ids)
-+ features["input_mask"] = np.asarray(input_mask)
-+ features["segment_ids"] = np.asarray(segment_ids)
-+ features["masked_lm_positions"] = np.asarray(masked_lm_positions)
-+ features["masked_lm_ids"] = np.asarray(masked_lm_ids)
-+ features["masked_lm_weights"] = np.asarray(masked_lm_weights)
-+ features["next_sentence_labels"] = np.asarray([next_sentence_label])
++ features["input_ids"] = np.asarray(input_ids, np.int64)
++ features["input_mask"] = np.asarray(input_mask, np.int64)
++ features["segment_ids"] = np.asarray(segment_ids, np.int64)
++ features["masked_lm_positions"] = np.asarray(masked_lm_positions, np.int64)
++ features["masked_lm_ids"] = np.asarray(masked_lm_ids, np.int64)
++ features["masked_lm_weights"] = np.asarray(masked_lm_weights, np.float32)
++ features["next_sentence_labels"] = np.asarray([next_sentence_label], np.int64)
  total_written += 1