|
|
|
@@ -67,12 +67,6 @@ def transpose_hwc2whc(image): |
|
|
|
return image |
|
|
|
|
|
|
|
|
|
|
|
def transpose_hwc2chw(image): |
|
|
|
"""transpose image from HWC to CHW""" |
|
|
|
image = np.transpose(image, (2, 0, 1)) |
|
|
|
return image |
|
|
|
|
|
|
|
|
|
|
|
def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'): |
|
|
|
""" |
|
|
|
create train or evaluation dataset for warpctc |
|
|
|
@@ -93,14 +87,20 @@ def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_ |
|
|
|
vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)), |
|
|
|
c.TypeCast(mstype.float16) |
|
|
|
] |
|
|
|
image_trans_gpu = [ |
|
|
|
vc.Rescale(1.0 / 255.0, 0.0), |
|
|
|
vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]), |
|
|
|
vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)), |
|
|
|
vc.HWC2CHW() |
|
|
|
] |
|
|
|
label_trans = [ |
|
|
|
c.TypeCast(mstype.int32) |
|
|
|
] |
|
|
|
data_set = data_set.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8) |
|
|
|
if device_target == 'Ascend': |
|
|
|
data_set = data_set.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8) |
|
|
|
data_set = data_set.map(operations=transpose_hwc2whc, input_columns=["image"], num_parallel_workers=8) |
|
|
|
else: |
|
|
|
data_set = data_set.map(operations=transpose_hwc2chw, input_columns=["image"], num_parallel_workers=8) |
|
|
|
data_set = data_set.map(operations=image_trans_gpu, input_columns=["image"], num_parallel_workers=8) |
|
|
|
data_set = data_set.map(operations=label_trans, input_columns=["label"], num_parallel_workers=8) |
|
|
|
|
|
|
|
data_set = data_set.batch(batch_size, drop_remainder=True) |
|
|
|
|