convert_svt.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
from xml.etree import ElementTree as ET

from PIL import Image
import numpy as np


def init_args():
    parser = argparse.ArgumentParser('')
    parser.add_argument('-d', '--dataset_dir', type=str, default='./',
                        help='Directory containing the SVT images')
    parser.add_argument('-x', '--xml_file', type=str, default='test.xml',
                        help='Path to the SVT annotation XML file')
    parser.add_argument('-o', '--output_dir', type=str, default='./processed',
                        help='Directory where the cropped word images are written')
    return parser.parse_args()
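
# The annotation file is expected to follow the standard SVT XML layout.
# The sketch below is illustrative (based on the dataset's published format):
# the tag names and attributes shown are exactly what xml_to_dict() reads,
# while the values are made up.
#
#   <tagset>
#     <image>
#       <imageName>img/14_03.jpg</imageName>
#       <address>...</address>
#       <lex>LIVING,ROOM,THEATERS,...</lex>
#       <Resolution x="1280" y="880"/>
#       <taggedRectangles>
#         <taggedRectangle height="75" width="236" x="375" y="253">
#           <tag>LIVING</tag>
#         </taggedRectangle>
#       </taggedRectangles>
#     </image>
#   </tagset>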

def xml_to_dict(xml_file, save_file=False):
    """Parse the SVT annotation XML into a list of per-image dicts."""
    tree = ET.parse(xml_file)
    root = tree.getroot()
    imgs_labels = []
    for ch in root:
        im_label = {}
        for ch01 in ch:
            if ch01.tag == "address":
                continue
            elif ch01.tag == 'taggedRectangles':
                # taggedRectangles holds one taggedRectangle child per word
                rect_list = []
                for ch02 in ch01:
                    rect = {}
                    rect['location'] = ch02.attrib
                    rect['label'] = ch02[0].text
                    rect_list.append(rect)
                im_label['rect'] = rect_list
            else:
                im_label[ch01.tag] = ch01.text
        imgs_labels.append(im_label)
    if save_file:
        np.save("annotation_train.npy", imgs_labels)
    return imgs_labels

def image_crop_save(image, location, output_dir):
    """
    Crop a region given as (h, w, x, y) from image and save it to output_dir.
    Negative offsets are clamped to 0 so the numpy slices do not wrap around.
    """
    start_x = max(location[2], 0)
    end_x = start_x + location[1]
    start_y = max(location[3], 0)
    end_y = start_y + location[0]
    print("image array shape: {}".format(image.shape))
    print("crop region", start_x, end_x, start_y, end_y)
    if len(image.shape) == 3:
        cropped = image[start_y:end_y, start_x:end_x, :]
    else:
        cropped = image[start_y:end_y, start_x:end_x]
    im = Image.fromarray(np.uint8(cropped))
    im.save(output_dir)

def convert():
    args = init_args()
    if not os.path.exists(args.dataset_dir):
        raise ValueError("dataset_dir: {} does not exist".format(args.dataset_dir))
    if not os.path.exists(args.xml_file):
        raise ValueError("xml_file: {} does not exist".format(args.xml_file))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    ims_labels_dict = xml_to_dict(args.xml_file, True)
    num_images = len(ims_labels_dict)
    lexicon_list = []
    annotation_list = []
    print("Converting annotation, {} images in total".format(num_images))
    for i, img_label in enumerate(ims_labels_dict):
        image_name = img_label['imageName']
        lex = img_label['lex']
        rects = img_label['rect']
        # drop any directory prefix (SVT stores names like img/xx.jpg) so the
        # crops land directly in output_dir; splitext keeps the leading dot
        name, ext = os.path.splitext(os.path.basename(image_name))
        fullpath = os.path.join(args.dataset_dir, image_name)
        im_array = np.asarray(Image.open(fullpath))
        lexicon_list.append(lex)
        print("processing image: {}".format(image_name))
        for j, rect in enumerate(rects):
            location = rect['location']
            h = int(location['height'])
            w = int(location['width'])
            x = int(location['x'])
            y = int(location['y'])
            label = rect['label']
            loc = [h, w, x, y]
            output_name = name + "_" + str(j) + "_" + label + ext
            output_file = os.path.join(args.output_dir, output_name)
            image_crop_save(im_array, loc, output_file)
            # one annotation line per crop: <crop name>,<word label>,<image index>
            ann = output_name + ',' + label + ',' + str(i)
            annotation_list.append(ann)
    lex_file = './lexicon_ann_train.txt'
    ann_file = './annotation_train.txt'
    with open(lex_file, 'w') as f:
        for line in lexicon_list:
            f.write(line + '\n')
    with open(ann_file, 'w') as f:
        for line in annotation_list:
            f.write(line + '\n')


if __name__ == "__main__":
    convert()
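
A typical invocation, assuming the standard SVT layout where test.xml sits next to the img/ directory (the paths below are illustrative; the flags are defined in init_args above):

    python convert_svt.py --dataset_dir ./svt1 --xml_file ./svt1/test.xml --output_dir ./processed

This writes the cropped word images to --output_dir and produces lexicon_ann_train.txt, annotation_train.txt, and annotation_train.npy in the working directory.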