You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

dataset.py 9.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """cnn_ctc dataset"""
import io
import math
import pickle
import sys

import six
import numpy as np
from PIL import Image
import lmdb
from mindspore.communication.management import get_rank, get_group_size

from .util import CTCLabelConverter
from .config import Config_CNNCTC
  26. config = Config_CNNCTC()
  27. class NormalizePAD():
  28. def __init__(self, max_size, PAD_type='right'):
  29. self.max_size = max_size
  30. self.PAD_type = PAD_type
  31. def __call__(self, img):
  32. # toTensor
  33. img = np.array(img, dtype=np.float32)
  34. img = img.transpose([2, 0, 1])
  35. img = img.astype(np.float)
  36. img = np.true_divide(img, 255)
  37. # normalize
  38. img = np.subtract(img, 0.5)
  39. img = np.true_divide(img, 0.5)
  40. _, _, w = img.shape
  41. Pad_img = np.zeros(shape=self.max_size, dtype=np.float32)
  42. Pad_img[:, :, :w] = img # right pad
  43. if self.max_size[2] != w: # add border Pad
  44. Pad_img[:, :, w:] = np.tile(np.expand_dims(img[:, :, w - 1], 2), (1, 1, self.max_size[2] - w))
  45. return Pad_img
  46. class AlignCollate():
  47. def __init__(self, imgH=32, imgW=100):
  48. self.imgH = imgH
  49. self.imgW = imgW
  50. def __call__(self, images):
  51. resized_max_w = self.imgW
  52. input_channel = 3
  53. transform = NormalizePAD((input_channel, self.imgH, resized_max_w))
  54. resized_images = []
  55. for image in images:
  56. w, h = image.size
  57. ratio = w / float(h)
  58. if math.ceil(self.imgH * ratio) > self.imgW:
  59. resized_w = self.imgW
  60. else:
  61. resized_w = math.ceil(self.imgH * ratio)
  62. resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
  63. resized_images.append(transform(resized_image))
  64. image_tensors = np.concatenate([np.expand_dims(t, 0) for t in resized_images], 0)
  65. return image_tensors
  66. def get_img_from_lmdb(env, index):
  67. with env.begin(write=False) as txn:
  68. label_key = 'label-%09d'.encode() % index
  69. label = txn.get(label_key).decode('utf-8')
  70. img_key = 'image-%09d'.encode() % index
  71. imgbuf = txn.get(img_key)
  72. buf = six.BytesIO()
  73. buf.write(imgbuf)
  74. buf.seek(0)
  75. try:
  76. img = Image.open(buf).convert('RGB') # for color image
  77. except IOError:
  78. print(f'Corrupted image for {index}')
  79. # make dummy image and dummy label for corrupted image.
  80. img = Image.new('RGB', (config.IMG_W, config.IMG_H))
  81. label = '[dummy_label]'
  82. label = label.lower()
  83. return img, label
  84. class ST_MJ_Generator_batch_fixed_length:
  85. def __init__(self):
  86. self.align_collector = AlignCollate()
  87. self.converter = CTCLabelConverter(config.CHARACTER)
  88. self.env = lmdb.open(config.TRAIN_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False,
  89. meminit=False)
  90. if not self.env:
  91. print('cannot create lmdb from %s' % (config.TRAIN_DATASET_PATH))
  92. raise ValueError(config.TRAIN_DATASET_PATH)
  93. with open(config.TRAIN_DATASET_INDEX_PATH, 'rb') as f:
  94. self.st_mj_filtered_index_list = pickle.load(f)
  95. print(f'num of samples in ST_MJ dataset: {len(self.st_mj_filtered_index_list)}')
  96. self.dataset_size = len(self.st_mj_filtered_index_list) // config.TRAIN_BATCH_SIZE
  97. self.batch_size = config.TRAIN_BATCH_SIZE
  98. def __len__(self):
  99. return self.dataset_size
  100. def __getitem__(self, item):
  101. img_ret = []
  102. text_ret = []
  103. for i in range(item * self.batch_size, (item + 1) * self.batch_size):
  104. index = self.st_mj_filtered_index_list[i]
  105. img, label = get_img_from_lmdb(self.env, index)
  106. img_ret.append(img)
  107. text_ret.append(label)
  108. img_ret = self.align_collector(img_ret)
  109. text_ret, length = self.converter.encode(text_ret)
  110. label_indices = []
  111. for i, _ in enumerate(length):
  112. for j in range(length[i]):
  113. label_indices.append((i, j))
  114. label_indices = np.array(label_indices, np.int64)
  115. sequence_length = np.array([config.FINAL_FEATURE_WIDTH] * config.TRAIN_BATCH_SIZE, dtype=np.int32)
  116. text_ret = text_ret.astype(np.int32)
  117. return img_ret, label_indices, text_ret, sequence_length
  118. class ST_MJ_Generator_batch_fixed_length_para:
  119. def __init__(self):
  120. self.align_collector = AlignCollate()
  121. self.converter = CTCLabelConverter(config.CHARACTER)
  122. self.env = lmdb.open(config.TRAIN_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False,
  123. meminit=False)
  124. if not self.env:
  125. print('cannot create lmdb from %s' % (config.TRAIN_DATASET_PATH))
  126. raise ValueError(config.TRAIN_DATASET_PATH)
  127. with open(config.TRAIN_DATASET_INDEX_PATH, 'rb') as f:
  128. self.st_mj_filtered_index_list = pickle.load(f)
  129. print(f'num of samples in ST_MJ dataset: {len(self.st_mj_filtered_index_list)}')
  130. self.rank_id = get_rank()
  131. self.rank_size = get_group_size()
  132. self.dataset_size = len(self.st_mj_filtered_index_list) // config.TRAIN_BATCH_SIZE // self.rank_size
  133. self.batch_size = config.TRAIN_BATCH_SIZE
  134. def __len__(self):
  135. return self.dataset_size
  136. def __getitem__(self, item):
  137. img_ret = []
  138. text_ret = []
  139. rank_item = (item * self.rank_size) + self.rank_id
  140. for i in range(rank_item * self.batch_size, (rank_item + 1) * self.batch_size):
  141. index = self.st_mj_filtered_index_list[i]
  142. img, label = get_img_from_lmdb(self.env, index)
  143. img_ret.append(img)
  144. text_ret.append(label)
  145. img_ret = self.align_collector(img_ret)
  146. text_ret, length = self.converter.encode(text_ret)
  147. label_indices = []
  148. for i, _ in enumerate(length):
  149. for j in range(length[i]):
  150. label_indices.append((i, j))
  151. label_indices = np.array(label_indices, np.int64)
  152. sequence_length = np.array([config.FINAL_FEATURE_WIDTH] * config.TRAIN_BATCH_SIZE, dtype=np.int32)
  153. text_ret = text_ret.astype(np.int32)
  154. return img_ret, label_indices, text_ret, sequence_length
  155. def IIIT_Generator_batch():
  156. max_len = int((26 + 1) // 2)
  157. align_collector = AlignCollate()
  158. converter = CTCLabelConverter(config.CHARACTER)
  159. env = lmdb.open(config.TEST_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
  160. if not env:
  161. print('cannot create lmdb from %s' % (config.TEST_DATASET_PATH))
  162. sys.exit(0)
  163. with env.begin(write=False) as txn:
  164. nSamples = int(txn.get('num-samples'.encode()))
  165. nSamples = nSamples
  166. # Filtering
  167. filtered_index_list = []
  168. for index in range(nSamples):
  169. index += 1 # lmdb starts with 1
  170. label_key = 'label-%09d'.encode() % index
  171. label = txn.get(label_key).decode('utf-8')
  172. if len(label) > max_len:
  173. continue
  174. illegal_sample = False
  175. for char_item in label.lower():
  176. if char_item not in config.CHARACTER:
  177. illegal_sample = True
  178. break
  179. if illegal_sample:
  180. continue
  181. filtered_index_list.append(index)
  182. img_ret = []
  183. text_ret = []
  184. print(f'num of samples in IIIT dataset: {len(filtered_index_list)}')
  185. for index in filtered_index_list:
  186. img, label = get_img_from_lmdb(env, index)
  187. img_ret.append(img)
  188. text_ret.append(label)
  189. if len(img_ret) == config.TEST_BATCH_SIZE:
  190. img_ret = align_collector(img_ret)
  191. text_ret, length = converter.encode(text_ret)
  192. label_indices = []
  193. for i, _ in enumerate(length):
  194. for j in range(length[i]):
  195. label_indices.append((i, j))
  196. label_indices = np.array(label_indices, np.int64)
  197. sequence_length = np.array([26] * config.TEST_BATCH_SIZE, dtype=np.int32)
  198. text_ret = text_ret.astype(np.int32)
  199. yield img_ret, label_indices, text_ret, sequence_length, length
  200. img_ret = []
  201. text_ret = []