You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

utils.py 8.3 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """utils for test"""
  16. import os
  17. import re
  18. import string
  19. import collections
  20. import json
  21. import numpy as np
  22. from mindspore import log as logger
  23. def get_data(dir_name):
  24. """
  25. Return raw data of imagenet dataset.
  26. Args:
  27. dir_name (str): String of imagenet dataset's path.
  28. Returns:
  29. List
  30. """
  31. if not os.path.isdir(dir_name):
  32. raise IOError("Directory {} not exists".format(dir_name))
  33. img_dir = os.path.join(dir_name, "images")
  34. ann_file = os.path.join(dir_name, "annotation.txt")
  35. with open(ann_file, "r") as file_reader:
  36. lines = file_reader.readlines()
  37. data_list = []
  38. for line in lines:
  39. try:
  40. filename, label = line.split(",")
  41. label = label.strip("\n")
  42. with open(os.path.join(img_dir, filename), "rb") as file_reader:
  43. img = file_reader.read()
  44. data_json = {"file_name": filename,
  45. "data": img,
  46. "label": int(label)}
  47. data_list.append(data_json)
  48. except FileNotFoundError:
  49. continue
  50. return data_list
  51. def get_two_bytes_data(file_name):
  52. """
  53. Return raw data of two-bytes dataset.
  54. Args:
  55. file_name (str): String of two-bytes dataset's path.
  56. Returns:
  57. List
  58. """
  59. if not os.path.exists(file_name):
  60. raise IOError("map file {} not exists".format(file_name))
  61. dir_name = os.path.dirname(file_name)
  62. with open(file_name, "r") as file_reader:
  63. lines = file_reader.readlines()
  64. data_list = []
  65. row_num = 0
  66. for line in lines:
  67. try:
  68. img, label = line.strip('\n').split(" ")
  69. with open(os.path.join(dir_name, img), "rb") as file_reader:
  70. img_data = file_reader.read()
  71. with open(os.path.join(dir_name, label), "rb") as file_reader:
  72. label_data = file_reader.read()
  73. data_json = {"file_name": img,
  74. "img_data": img_data,
  75. "label_name": label,
  76. "label_data": label_data,
  77. "id": row_num
  78. }
  79. row_num += 1
  80. data_list.append(data_json)
  81. except FileNotFoundError:
  82. continue
  83. return data_list
  84. def get_multi_bytes_data(file_name, bytes_num=3):
  85. """
  86. Return raw data of multi-bytes dataset.
  87. Args:
  88. file_name (str): String of multi-bytes dataset's path.
  89. bytes_num (int): Number of bytes fields.
  90. Returns:
  91. List
  92. """
  93. if not os.path.exists(file_name):
  94. raise IOError("map file {} not exists".format(file_name))
  95. dir_name = os.path.dirname(file_name)
  96. with open(file_name, "r") as file_reader:
  97. lines = file_reader.readlines()
  98. data_list = []
  99. row_num = 0
  100. for line in lines:
  101. try:
  102. img10_path = line.strip('\n').split(" ")
  103. img5 = []
  104. for path in img10_path[:bytes_num]:
  105. with open(os.path.join(dir_name, path), "rb") as file_reader:
  106. img5 += [file_reader.read()]
  107. data_json = {"image_{}".format(i): img5[i]
  108. for i in range(len(img5))}
  109. data_json.update({"id": row_num})
  110. row_num += 1
  111. data_list.append(data_json)
  112. except FileNotFoundError:
  113. continue
  114. return data_list
  115. def get_mkv_data(dir_name):
  116. """
  117. Return raw data of Vehicle_and_Person dataset.
  118. Args:
  119. dir_name (str): String of Vehicle_and_Person dataset's path.
  120. Returns:
  121. List
  122. """
  123. if not os.path.isdir(dir_name):
  124. raise IOError("Directory {} not exists".format(dir_name))
  125. img_dir = os.path.join(dir_name, "Image")
  126. label_dir = os.path.join(dir_name, "prelabel")
  127. data_list = []
  128. file_list = os.listdir(label_dir)
  129. index = 1
  130. for file in file_list:
  131. if os.path.splitext(file)[1] == '.json':
  132. file_path = os.path.join(label_dir, file)
  133. image_name = ''.join([os.path.splitext(file)[0], ".jpg"])
  134. image_path = os.path.join(img_dir, image_name)
  135. with open(file_path, "r") as load_f:
  136. load_dict = json.load(load_f)
  137. if os.path.exists(image_path):
  138. with open(image_path, "rb") as file_reader:
  139. img = file_reader.read()
  140. data_json = {"file_name": image_name,
  141. "prelabel": str(load_dict),
  142. "data": img,
  143. "id": index}
  144. data_list.append(data_json)
  145. index += 1
  146. logger.info('{} images are missing'.format(len(file_list) - len(data_list)))
  147. return data_list
  148. def get_nlp_data(dir_name, vocab_file, num):
  149. """
  150. Return raw data of aclImdb dataset.
  151. Args:
  152. dir_name (str): String of aclImdb dataset's path.
  153. vocab_file (str): String of dictionary's path.
  154. num (int): Number of sample.
  155. Returns:
  156. List
  157. """
  158. if not os.path.isdir(dir_name):
  159. raise IOError("Directory {} not exists".format(dir_name))
  160. for root, _, files in os.walk(dir_name):
  161. for index, file_name_extension in enumerate(files):
  162. if index < num:
  163. file_path = os.path.join(root, file_name_extension)
  164. file_name, _ = file_name_extension.split('.', 1)
  165. id_, rating = file_name.split('_', 1)
  166. with open(file_path, 'r') as f:
  167. raw_content = f.read()
  168. dictionary = load_vocab(vocab_file)
  169. vectors = [dictionary.get('[CLS]')]
  170. vectors += [dictionary.get(i) if i in dictionary
  171. else dictionary.get('[UNK]')
  172. for i in re.findall(r"[\w']+|[{}]"
  173. .format(string.punctuation),
  174. raw_content)]
  175. vectors += [dictionary.get('[SEP]')]
  176. input_, mask, segment = inputs(vectors)
  177. input_ids = np.reshape(np.array(input_), [1, -1])
  178. input_mask = np.reshape(np.array(mask), [1, -1])
  179. segment_ids = np.reshape(np.array(segment), [1, -1])
  180. data = {
  181. "label": 1,
  182. "id": id_,
  183. "rating": float(rating),
  184. "input_ids": input_ids,
  185. "input_mask": input_mask,
  186. "segment_ids": segment_ids
  187. }
  188. yield data
  189. def convert_to_uni(text):
  190. if isinstance(text, str):
  191. return text
  192. if isinstance(text, bytes):
  193. return text.decode('utf-8', 'ignore')
  194. raise Exception("The type %s does not convert!" % type(text))
  195. def load_vocab(vocab_file):
  196. """load vocabulary to translate statement."""
  197. vocab = collections.OrderedDict()
  198. vocab.setdefault('blank', 2)
  199. index = 0
  200. with open(vocab_file) as reader:
  201. while True:
  202. tmp = reader.readline()
  203. if not tmp:
  204. break
  205. token = convert_to_uni(tmp)
  206. token = token.strip()
  207. vocab[token] = index
  208. index += 1
  209. return vocab
  210. def inputs(vectors, maxlen=50):
  211. length = len(vectors)
  212. if length > maxlen:
  213. return vectors[0:maxlen], [1] * maxlen, [0] * maxlen
  214. input_ = vectors + [0] * (maxlen - length)
  215. mask = [1] * length + [0] * (maxlen - length)
  216. segment = [0] * maxlen
  217. return input_, mask, segment