|
|
|
@@ -19,8 +19,8 @@ import numpy as np |
|
|
|
|
|
|
|
from mindspore.mindrecord import FileWriter |
|
|
|
|
|
|
|
from src.config import config |
|
|
|
from src.utils import initialize_vocabulary |
|
|
|
from config import config |
|
|
|
from utils import initialize_vocabulary |
|
|
|
|
|
|
|
|
|
|
|
def serialize_annotation(img_path, lex, vocab): |
|
|
|
@@ -82,7 +82,7 @@ def create_fsns_label(image_dir, anno_file_dirs): |
|
|
|
|
|
|
|
def fsns_train_data_to_mindrecord(mindrecord_dir, prefix="data_ocr.mindrecord", file_num=8): |
|
|
|
|
|
|
|
anno_file_dirs = [config.train_annotation_file] |
|
|
|
anno_file_dirs = [config.annotation_file] |
|
|
|
images, image_path_dict, image_anno_dict = create_fsns_label(image_dir=config.data_root, |
|
|
|
anno_file_dirs=anno_file_dirs) |
|
|
|
vocab, _ = initialize_vocabulary(config.vocab_path) |
|
|
|
@@ -104,8 +104,8 @@ def fsns_train_data_to_mindrecord(mindrecord_dir, prefix="data_ocr.mindrecord", |
|
|
|
image_path = image_path_dict[img_id] |
|
|
|
annotation = image_anno_dict[img_id] |
|
|
|
|
|
|
|
label_max_len = config.max_text_len |
|
|
|
text_max_len = config.max_text_len - 2 |
|
|
|
label_max_len = config.max_length |
|
|
|
text_max_len = config.max_length - 2 |
|
|
|
|
|
|
|
if len(annotation) > text_max_len: |
|
|
|
continue |
|
|
|
@@ -151,8 +151,8 @@ def fsns_train_data_to_mindrecord(mindrecord_dir, prefix="data_ocr.mindrecord", |
|
|
|
|
|
|
|
def fsns_val_data_to_mindrecord(mindrecord_dir, prefix="data_ocr.mindrecord", file_num=8): |
|
|
|
|
|
|
|
anno_file_dirs = [config.train_annotation_file] |
|
|
|
images, image_path_dict, image_anno_dict = create_fsns_label(image_dir=config.data_root, |
|
|
|
anno_file_dirs = [config.val_annotation_file] |
|
|
|
images, image_path_dict, image_anno_dict = create_fsns_label(image_dir=config.val_data_root, |
|
|
|
anno_file_dirs=anno_file_dirs) |
|
|
|
vocab, _ = initialize_vocabulary(config.vocab_path) |
|
|
|
|
|
|
|
@@ -171,8 +171,8 @@ def fsns_val_data_to_mindrecord(mindrecord_dir, prefix="data_ocr.mindrecord", fi |
|
|
|
image_path = image_path_dict[img_id] |
|
|
|
annotation = image_anno_dict[img_id] |
|
|
|
|
|
|
|
label_max_len = config.max_text_len |
|
|
|
text_max_len = config.max_text_len - 2 |
|
|
|
label_max_len = config.max_length |
|
|
|
text_max_len = config.max_length - 2 |
|
|
|
|
|
|
|
if len(annotation) > text_max_len: |
|
|
|
continue |
|
|
|
@@ -226,7 +226,7 @@ def create_mindrecord(dataset="fsns", prefix="fsns.mindrecord", is_training=True |
|
|
|
print("{} dataset is not defined!".format(dataset)) |
|
|
|
|
|
|
|
if not is_training: |
|
|
|
mindrecord_dir = os.path.join(config.mindrecord_dir, "val") |
|
|
|
mindrecord_dir = os.path.join(config.mindrecord_dir, "test") |
|
|
|
mindrecord_files = [os.path.join(mindrecord_dir, prefix + "0")] |
|
|
|
|
|
|
|
if not os.path.exists(mindrecord_files[0]): |
|
|
|
@@ -243,3 +243,9 @@ def create_mindrecord(dataset="fsns", prefix="fsns.mindrecord", is_training=True |
|
|
|
print("{} dataset is not defined!".format(dataset)) |
|
|
|
|
|
|
|
return mindrecord_files |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
create_mindrecord(is_training=True) |
|
|
|
create_mindrecord(is_training=False) |
|
|
|
print("END") |