Browse Source

fix warpctc precision bug

tags/v1.1.0
gengdongjie 5 years ago
parent
commit
353c677b2c
1 changed files with 10 additions and 1 deletions
  1. +10
    -1
      model_zoo/official/cv/warpctc/src/dataset.py

+ 10
- 1
model_zoo/official/cv/warpctc/src/dataset.py View File

@@ -67,6 +67,12 @@ def transpose_hwc2whc(image):
return image return image




def transpose_hwc2chw(image):
"""transpose image from HWC to CHW"""
image = np.transpose(image, (2, 0, 1))
return image


def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'): def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'):
""" """
create train or evaluation dataset for warpctc create train or evaluation dataset for warpctc
@@ -91,7 +97,10 @@ def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_
c.TypeCast(mstype.int32) c.TypeCast(mstype.int32)
] ]
ds = ds.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8) ds = ds.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8)
ds = ds.map(operations=transpose_hwc2whc, input_columns=["image"], num_parallel_workers=8)
if device_target == 'Ascend':
ds = ds.map(operations=transpose_hwc2whc, input_columns=["image"], num_parallel_workers=8)
else:
ds = ds.map(operations=transpose_hwc2chw, input_columns=["image"], num_parallel_workers=8)
ds = ds.map(operations=label_trans, input_columns=["label"], num_parallel_workers=8) ds = ds.map(operations=label_trans, input_columns=["label"], num_parallel_workers=8)


ds = ds.batch(batch_size, drop_remainder=True) ds = ds.batch(batch_size, drop_remainder=True)


Loading…
Cancel
Save