| @@ -67,6 +67,12 @@ def transpose_hwc2whc(image): | |||||
| return image | return image | ||||
| def transpose_hwc2chw(image): | |||||
| """transpose image from HWC to CHW""" | |||||
| image = np.transpose(image, (2, 0, 1)) | |||||
| return image | |||||
| def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'): | def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'): | ||||
| """ | """ | ||||
| create train or evaluation dataset for warpctc | create train or evaluation dataset for warpctc | ||||
| @@ -91,7 +97,10 @@ def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_ | |||||
| c.TypeCast(mstype.int32) | c.TypeCast(mstype.int32) | ||||
| ] | ] | ||||
| ds = ds.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8) | ds = ds.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8) | ||||
| ds = ds.map(operations=transpose_hwc2whc, input_columns=["image"], num_parallel_workers=8) | |||||
| if device_target == 'Ascend': | |||||
| ds = ds.map(operations=transpose_hwc2whc, input_columns=["image"], num_parallel_workers=8) | |||||
| else: | |||||
| ds = ds.map(operations=transpose_hwc2chw, input_columns=["image"], num_parallel_workers=8) | |||||
| ds = ds.map(operations=label_trans, input_columns=["label"], num_parallel_workers=8) | ds = ds.map(operations=label_trans, input_columns=["label"], num_parallel_workers=8) | ||||
| ds = ds.batch(batch_size, drop_remainder=True) | ds = ds.batch(batch_size, drop_remainder=True) | ||||