|
|
|
@@ -15,7 +15,6 @@ |
|
|
|
|
|
|
|
"""CTPN dataset""" |
|
|
|
from __future__ import division |
|
|
|
import os |
|
|
|
import numpy as np |
|
|
|
from numpy import random |
|
|
|
import mmcv |
|
|
|
@@ -23,7 +22,6 @@ import mindspore.dataset as de |
|
|
|
import mindspore.dataset.vision.c_transforms as C |
|
|
|
import mindspore.dataset.transforms.c_transforms as CC |
|
|
|
import mindspore.common.dtype as mstype |
|
|
|
from mindspore.mindrecord import FileWriter |
|
|
|
from src.config import config |
|
|
|
|
|
|
|
class PhotoMetricDistortion: |
|
|
|
@@ -98,7 +96,7 @@ class Expand: |
|
|
|
boxes += np.tile((left, top), 2) |
|
|
|
return img, boxes, labels |
|
|
|
|
|
|
|
def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num): |
|
|
|
def rescale_column(img, gt_bboxes, gt_label, gt_num, img_shape): |
|
|
|
"""rescale operation for image""" |
|
|
|
img_data, scale_factor = mmcv.imrescale(img, (config.img_width, config.img_height), return_scale=True) |
|
|
|
if img_data.shape[0] > config.img_height: |
|
|
|
@@ -112,10 +110,10 @@ def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num): |
|
|
|
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) |
|
|
|
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) |
|
|
|
|
|
|
|
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) |
|
|
|
return (img_data, gt_bboxes, gt_label, gt_num, img_shape) |
|
|
|
|
|
|
|
|
|
|
|
def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num): |
|
|
|
def resize_column(img, gt_bboxes, gt_label, gt_num, img_shape): |
|
|
|
"""resize operation for image""" |
|
|
|
img_data = img |
|
|
|
img_data, w_scale, h_scale = mmcv.imresize( |
|
|
|
@@ -129,10 +127,10 @@ def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num): |
|
|
|
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) |
|
|
|
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) |
|
|
|
|
|
|
|
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) |
|
|
|
return (img_data, gt_bboxes, gt_label, gt_num, img_shape) |
|
|
|
|
|
|
|
|
|
|
|
def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num): |
|
|
|
def resize_column_test(img, gt_bboxes, gt_label, gt_num, img_shape): |
|
|
|
"""resize operation for image of eval""" |
|
|
|
img_data = img |
|
|
|
img_data, w_scale, h_scale = mmcv.imresize( |
|
|
|
@@ -149,34 +147,34 @@ def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num): |
|
|
|
gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) |
|
|
|
gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) |
|
|
|
|
|
|
|
return (img_data, img_shape, gt_bboxes, gt_label, gt_num) |
|
|
|
return (img_data, gt_bboxes, gt_label, gt_num, img_shape) |
|
|
|
|
|
|
|
def flipped_generation(img, gt_bboxes, gt_label, gt_num, img_shape):
    """Mirror the ground-truth boxes horizontally across the image width.

    Only the box coordinates are mirrored; the image array is returned
    exactly as it was passed in.
    NOTE(review): the pixel data is not flipped here - confirm whether a
    later pipeline step is expected to do that.

    Args:
        img: image array whose ``shape`` unpacks into three values, the
            second of which is used as the width.
        gt_bboxes: numpy array of boxes. Along the last axis, every 4th
            value starting at index 0 is treated as a left x coordinate and
            every 4th value starting at index 2 as a right x coordinate.
        gt_label: passed through unchanged.
        gt_num: passed through unchanged.
        img_shape: passed through unchanged.

    Returns:
        tuple: ``(img, flipped_bboxes, gt_label, gt_num, img_shape)``.
    """
    _, width, _ = img.shape
    # Work on a copy so the caller's box array is never modified in place.
    flipped = gt_bboxes.copy()
    # Mirroring swaps the roles of the two vertical edges: the new left edge
    # is the mirrored old right edge, and vice versa. Both assignments read
    # from the untouched ``gt_bboxes`` so they do not interfere.
    flipped[..., 0::4] = width - gt_bboxes[..., 2::4] - 1
    flipped[..., 2::4] = width - gt_bboxes[..., 0::4] - 1
    return (img, flipped, gt_label, gt_num, img_shape)
|
|
|
|
|
|
|
def image_bgr_rgb(img, gt_bboxes, gt_label, gt_num, img_shape):
    """Reverse the order of the image's third axis.

    Going by the function name, this is used to convert between BGR and RGB
    channel layouts; the code itself simply reverses whatever is stored
    along the third axis.

    Args:
        img: image array indexable with three axes.
        gt_bboxes: passed through unchanged.
        gt_label: passed through unchanged.
        gt_num: passed through unchanged.
        img_shape: passed through unchanged.

    Returns:
        tuple: ``(img_data, gt_bboxes, gt_label, gt_num, img_shape)`` where
        ``img_data`` is ``img`` with its third axis reversed.
    """
    img_data = img[:, :, ::-1]
    return (img_data, gt_bboxes, gt_label, gt_num, img_shape)
|
|
|
|
|
|
|
def photo_crop_column(img, gt_bboxes, gt_label, gt_num, img_shape):
    """Apply a ``PhotoMetricDistortion`` transform to one sample.

    A fresh ``PhotoMetricDistortion`` instance is created with its default
    arguments on every call and invoked on the image, boxes and labels.

    Args:
        img: image handed to the transform.
        gt_bboxes: ground-truth boxes handed to the transform.
        gt_label: ground-truth labels handed to the transform.
        gt_num: passed through unchanged.
        img_shape: passed through unchanged.

    Returns:
        tuple: ``(img_data, gt_bboxes, gt_label, gt_num, img_shape)`` where
        the first three entries are whatever the transform returned.
    """
    distortion = PhotoMetricDistortion()
    img_data, gt_bboxes, gt_label = distortion(img, gt_bboxes, gt_label)
    return (img_data, gt_bboxes, gt_label, gt_num, img_shape)
|
|
|
|
|
|
|
def expand_column(img, gt_bboxes, gt_label, gt_num, img_shape):
    """Apply an ``Expand`` transform to one sample.

    A fresh ``Expand`` instance is created with its default arguments on
    every call and invoked on the image, boxes and labels.

    Args:
        img: image handed to the transform.
        gt_bboxes: ground-truth boxes handed to the transform.
        gt_label: ground-truth labels handed to the transform.
        gt_num: passed through unchanged.
        img_shape: passed through unchanged. NOTE(review): it is not
            recomputed from the transformed image here - confirm that
            downstream steps do not rely on it matching the new size.

    Returns:
        tuple: ``(img, gt_bboxes, gt_label, gt_num, img_shape)`` where the
        first three entries are whatever the transform returned.
    """
    expander = Expand()
    img, gt_bboxes, gt_label = expander(img, gt_bboxes, gt_label)
    return (img, gt_bboxes, gt_label, gt_num, img_shape)
|
|
|
|
|
|
|
def split_gtbox_label(gt_bbox_total): |
|
|
|
"""split ground truth box label""" |
|
|
|
@@ -193,7 +191,7 @@ def split_gtbox_label(gt_bbox_total): |
|
|
|
gtbox_list.append([x0, gt_bbox[1], x0+15, gt_bbox[3], 1]) |
|
|
|
return np.array(gtbox_list) |
|
|
|
|
|
|
|
def pad_label(img, img_shape, gt_bboxes, gt_label, gt_valid): |
|
|
|
def pad_label(img, gt_bboxes, gt_label, gt_valid, img_shape): |
|
|
|
"""pad ground truth label""" |
|
|
|
pad_max_number = 256 |
|
|
|
gt_label = gt_bboxes[:, 4] |
|
|
|
@@ -208,13 +206,13 @@ def pad_label(img, img_shape, gt_bboxes, gt_label, gt_valid): |
|
|
|
gt_box = gt_bboxes[0:pad_max_number] |
|
|
|
gt_label = gt_label[0:pad_max_number] |
|
|
|
gt_valid = gt_valid[0:pad_max_number] |
|
|
|
return (img, img_shape, gt_box[:, :4], gt_label, gt_valid) |
|
|
|
return (img, gt_box[:, :4], gt_label, gt_valid, img_shape) |
|
|
|
|
|
|
|
def preprocess_fn(image, box, is_training): |
|
|
|
"""Preprocess function for dataset.""" |
|
|
|
def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, gt_valid): |
|
|
|
def _infer_data(image_bgr, gt_box_new, gt_label_new, gt_valid, image_shape): |
|
|
|
image_shape = image_shape[:2] |
|
|
|
input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_valid |
|
|
|
input_data = image_bgr, gt_box_new, gt_label_new, gt_valid, image_shape |
|
|
|
if config.keep_ratio: |
|
|
|
input_data = rescale_column(*input_data) |
|
|
|
else: |
|
|
|
@@ -234,9 +232,9 @@ def preprocess_fn(image, box, is_training): |
|
|
|
gt_box = box[:, :4] |
|
|
|
gt_label = box[:, 4] |
|
|
|
gt_valid = box[:, 4] |
|
|
|
input_data = image_bgr, image_shape, gt_box, gt_label, gt_valid |
|
|
|
input_data = image_bgr, gt_box, gt_label, gt_valid, image_shape |
|
|
|
if not is_training: |
|
|
|
return _infer_data(image_bgr, image_shape, gt_box, gt_label, gt_valid) |
|
|
|
return _infer_data(image_bgr, gt_box, gt_label, gt_valid, image_shape) |
|
|
|
expand = (np.random.rand() < config.expand_ratio) |
|
|
|
if expand: |
|
|
|
input_data = expand_column(*input_data) |
|
|
|
@@ -260,46 +258,6 @@ def anno_parser(annos_str): |
|
|
|
annos.append(anno) |
|
|
|
return annos |
|
|
|
|
|
|
|
def filter_valid_data(image_dir, anno_path):
    """Filter valid image files, i.e. those both in image_dir and anno_path.

    Every line of the annotation file is stripped and split on single
    spaces. The first token is taken as a file name relative to
    ``image_dir``; the remaining tokens are handed to ``anno_parser``.
    Lines whose image file does not exist on disk are skipped silently.

    Args:
        image_dir: directory that holds the image files.
        anno_path: path of the UTF-8 encoded annotation file.

    Returns:
        tuple: ``(image_files, image_anno_dict)`` - the list of existing
        image paths in file order, and a dict mapping each such path to the
        value returned by ``anno_parser`` for its line.

    Raises:
        RuntimeError: if ``image_dir`` is not a directory or ``anno_path``
            is not a file.
    """
    if not os.path.isdir(image_dir):
        raise RuntimeError("Path given is not valid.")
    if not os.path.isfile(anno_path):
        raise RuntimeError("Annotation file is not valid.")

    image_files = []
    image_anno_dict = {}
    with open(anno_path, "rb") as f:
        # Stream the file line by line instead of loading it all at once.
        for line in f:
            line_split = line.decode("utf-8").strip().split(' ')
            image_path = os.path.join(image_dir, line_split[0])
            if os.path.isfile(image_path):
                image_anno_dict[image_path] = anno_parser(line_split[1:])
                image_files.append(image_path)
    return image_files, image_anno_dict
|
|
|
|
|
|
|
def data_to_mindrecord_byte_image(is_training=True, prefix="cptn_mlt.mindrecord", file_num=8):
    """Create MindRecord file.

    Writes one record per image: the raw bytes of the image file plus its
    annotations as an int32 array, using the schema declared below.

    Args:
        is_training: accepted but not read anywhere in this function.
        prefix: file name joined onto ``config.mindrecord_dir`` to form the
            output path given to ``FileWriter``.
        file_num: second positional argument given to ``FileWriter``
            (presumably the number of shard files - confirm against the
            MindSpore documentation).
    """
    mindrecord_dir = config.mindrecord_dir
    mindrecord_path = os.path.join(mindrecord_dir, prefix)
    writer = FileWriter(mindrecord_path, file_num)
    # NOTE(review): create_icdar_test_label is not defined in the visible
    # part of this module - confirm where it is expected to come from.
    image_files, image_anno_dict = create_icdar_test_label()
    ctpn_json = {
        "image": {"type": "bytes"},
        "annotation": {"type": "int32", "shape": [-1, 6]},
    }
    writer.add_schema(ctpn_json, "ctpn_json")
    for image_name in image_files:
        # Store the encoded file contents as-is; nothing is decoded here.
        with open(image_name, 'rb') as f:
            img = f.read()
        annos = np.array(image_anno_dict[image_name], dtype=np.int32)
        row = {"image": img, "annotation": annos}
        writer.write_raw_data([row])
    writer.commit()
|
|
|
|
|
|
|
def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num=1, rank_id=0, |
|
|
|
is_training=True, num_parallel_workers=12): |
|
|
|
"""Creatr ctpn dataset with MindDataset.""" |
|
|
|
@@ -316,8 +274,8 @@ def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num= |
|
|
|
type_cast3 = CC.TypeCast(mstype.bool_) |
|
|
|
if is_training: |
|
|
|
ds = ds.map(operations=compose_map_func, input_columns=["image", "annotation"], |
|
|
|
output_columns=["image", "image_shape", "box", "label", "valid_num"], |
|
|
|
column_order=["image", "image_shape", "box", "label", "valid_num"], |
|
|
|
output_columns=["image", "box", "label", "valid_num", "image_shape"], |
|
|
|
column_order=["image", "box", "label", "valid_num", "image_shape"], |
|
|
|
num_parallel_workers=num_parallel_workers, |
|
|
|
python_multiprocessing=True) |
|
|
|
ds = ds.map(operations=[normalize_op, type_cast0], input_columns=["image"], |
|
|
|
@@ -329,8 +287,8 @@ def create_ctpn_dataset(mindrecord_file, batch_size=1, repeat_num=1, device_num= |
|
|
|
else: |
|
|
|
ds = ds.map(operations=compose_map_func, |
|
|
|
input_columns=["image", "annotation"], |
|
|
|
output_columns=["image", "image_shape", "box", "label", "valid_num"], |
|
|
|
column_order=["image", "image_shape", "box", "label", "valid_num"], |
|
|
|
output_columns=["image", "box", "label", "valid_num", "image_shape"], |
|
|
|
column_order=["image", "box", "label", "valid_num", "image_shape"], |
|
|
|
num_parallel_workers=num_parallel_workers, |
|
|
|
python_multiprocessing=True) |
|
|
|
|
|
|
|
|