You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

dataset.py 3.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Dataset preprocessing."""
  16. import os
  17. import math as m
  18. import numpy as np
  19. import mindspore.common.dtype as mstype
  20. import mindspore.dataset.engine as de
  21. import mindspore.dataset.transforms.c_transforms as c
  22. import mindspore.dataset.vision.c_transforms as vc
  23. from PIL import Image
  24. from src.config import config as cf
  25. class _CaptchaDataset:
  26. """
  27. create train or evaluation dataset for warpctc
  28. Args:
  29. img_root_dir(str): root path of images
  30. max_captcha_digits(int): max number of digits in images.
  31. device_target(str): platform of training, support Ascend and GPU.
  32. """
  33. def __init__(self, img_root_dir, max_captcha_digits, device_target='Ascend'):
  34. if not os.path.exists(img_root_dir):
  35. raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
  36. self.img_root_dir = img_root_dir
  37. self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')]
  38. self.max_captcha_digits = max_captcha_digits
  39. self.target = device_target
  40. self.blank = 10
  41. self.label_length = [len(os.path.splitext(n)[0].split('-')[-1]) for n in self.img_names]
  42. def __len__(self):
  43. return len(self.img_names)
  44. def __getitem__(self, item):
  45. img_name = self.img_names[item]
  46. im = Image.open(os.path.join(self.img_root_dir, img_name))
  47. r, g, b = im.split()
  48. im = Image.merge("RGB", (b, g, r))
  49. image = np.array(im)
  50. label_str = os.path.splitext(img_name)[0]
  51. label_str = label_str[label_str.find('-') + 1:]
  52. label = [int(i) for i in label_str]
  53. label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
  54. label = np.array(label)
  55. return image, label
  56. def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'):
  57. """
  58. create train or evaluation dataset for warpctc
  59. Args:
  60. dataset_path(int): dataset path
  61. batch_size(int): batch size of generated dataset, default is 1
  62. num_shards(int): number of devices
  63. shard_id(int): rank id
  64. device_target(str): platform of training, support Ascend and GPU
  65. """
  66. dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits, device_target)
  67. ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=num_shards, shard_id=shard_id)
  68. image_trans = [
  69. vc.Rescale(1.0 / 255.0, 0.0),
  70. vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]),
  71. vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)),
  72. vc.HWC2CHW()
  73. ]
  74. label_trans = [
  75. c.TypeCast(mstype.int32)
  76. ]
  77. ds = ds.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8)
  78. ds = ds.map(operations=label_trans, input_columns=["label"], num_parallel_workers=8)
  79. ds = ds.batch(batch_size, drop_remainder=True)
  80. return ds