
test_minddataset.py 19 kB

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This is the test module for mindrecord
"""
import collections
import json
import os
import re
import string

import mindspore.dataset.transforms.vision.c_transforms as vision
import numpy as np
import pytest
from mindspore._c_dataengine import InterpolationMode
from mindspore import log as logger
import mindspore.dataset as ds
from mindspore.mindrecord import FileWriter

FILES_NUM = 4
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
CV_DIR_NAME = "../data/mindrecord/testImageNetData"
NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord"
NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos"
NLP_FILE_VOCAB = "../data/mindrecord/testAclImdbData/vocab.txt"


@pytest.fixture
def add_and_remove_cv_file():
    """add/remove cv file"""
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    data = get_data(CV_DIR_NAME)
    cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()
    yield "yield_cv_data"
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))


@pytest.fixture
def add_and_remove_nlp_file():
    """add/remove nlp file"""
    paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
    data = [x for x in get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10)]
    nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"},
                       "rating": {"type": "float32"},
                       "input_ids": {"type": "int64",
                                     "shape": [-1]},
                       "input_mask": {"type": "int64",
                                      "shape": [1, -1]},
                       "segment_ids": {"type": "int64",
                                       "shape": [2, -1]}
                       }
    writer.set_header_size(1 << 14)
    writer.set_page_size(1 << 15)
    writer.add_schema(nlp_schema_json, "nlp_schema")
    writer.add_index(["id", "rating"])
    writer.write_raw_data(data)
    writer.commit()
    yield "yield_nlp_data"
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))


def test_cv_minddataset_writer_tutorial():
    """tutorial for cv dataset writer."""
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    data = get_data(CV_DIR_NAME)
    cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))


def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4

    def partitions(num_shards):
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards, shard_id=partition_id)
            num_iter = 0
            for item in data_set.create_dict_iterator():
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
                num_iter += 1
        return num_iter

    assert partitions(4) == 3
    assert partitions(5) == 2
    assert partitions(9) == 2


def test_cv_minddataset_dataset_size(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    assert data_set.get_dataset_size() == 10
    repeat_num = 2
    data_set = data_set.repeat(repeat_num)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- get dataset size {} -----------------".format(num_iter))
        logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
        logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
        num_iter += 1
    assert num_iter == 20

    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              num_shards=4, shard_id=3)
    assert data_set.get_dataset_size() == 3


def test_cv_minddataset_issue_888(add_and_remove_cv_file):
    """issue 888 test."""
    columns_list = ["data", "label"]
    num_readers = 2
    data = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False, num_shards=5, shard_id=1)
    data = data.shuffle(2)
    data = data.repeat(9)
    num_iter = 0
    for item in data.create_dict_iterator():
        num_iter += 1
    assert num_iter == 18


def test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              block_reader=True)
    assert data_set.get_dataset_size() == 10
    repeat_num = 2
    data_set = data_set.repeat(repeat_num)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- block reader repeat two {} -----------------".format(num_iter))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
        num_iter += 1
    assert num_iter == 20


def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 10


def test_nlp_minddataset_reader_basic_tutorial(add_and_remove_nlp_file):
    """tutorial for nlp minddataset."""
    num_readers = 4
    data_set = ds.MindDataset(NLP_FILE_NAME + "0", None, num_readers)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- nlp reader basic: {} -----------------------".format(num_iter))
        logger.info("-------------- num_iter: {} ------------------------".format(num_iter))
        logger.info("-------------- item[id]: {} ------------------------".format(item["id"]))
        logger.info("-------------- item[rating]: {} --------------------".format(item["rating"]))
        logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(
            item["input_ids"], item["input_ids"].shape))
        logger.info("-------------- item[input_mask]: {}, shape: {} -----------------".format(
            item["input_mask"], item["input_mask"].shape))
        logger.info("-------------- item[segment_ids]: {}, shape: {} -----------------".format(
            item["segment_ids"], item["segment_ids"].shape))
        assert item["input_ids"].shape == (50,)
        assert item["input_mask"].shape == (1, 50)
        assert item["segment_ids"].shape == (2, 25)
        num_iter += 1
    assert num_iter == 10


def test_cv_minddataset_reader_basic_tutorial_5_epoch(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    assert data_set.get_dataset_size() == 10
    for epoch in range(5):
        num_iter = 0
        for data in data_set:
            logger.info("data is {}".format(data))
            num_iter += 1
        assert num_iter == 10
        data_set.reset()


def test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)

    resize_height = 32
    resize_width = 32

    # define map operations
    decode_op = vision.Decode()
    resize_op = vision.Resize((resize_height, resize_width), ds.transforms.vision.Inter.LINEAR)

    data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=4)
    data_set = data_set.map(input_columns=["data"], operations=resize_op, num_parallel_workers=4)

    data_set = data_set.batch(2)
    assert data_set.get_dataset_size() == 5
    for epoch in range(5):
        num_iter = 0
        for data in data_set:
            logger.info("data is {}".format(data))
            num_iter += 1
        assert num_iter == 5
        data_set.reset()


def test_cv_minddataset_reader_no_columns(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    data_set = ds.MindDataset(CV_FILE_NAME + "0")
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 10


def test_cv_minddataset_reader_repeat_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    repeat_num = 2
    data_set = data_set.repeat(repeat_num)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- repeat two test {} ------------------------".format(num_iter))
        logger.info("-------------- len(item[data]): {} -----------------------".format(len(item["data"])))
        logger.info("-------------- item[data]: {} ----------------------------".format(item["data"]))
        logger.info("-------------- item[file_name]: {} -----------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} ---------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 20


def get_data(dir_name):
    """
    Get data from the imagenet test dataset.

    Args:
        dir_name (str): Directory containing the images folder and annotation file.
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "images")
    ann_file = os.path.join(dir_name, "annotation.txt")
    with open(ann_file, "r") as file_reader:
        lines = file_reader.readlines()

    data_list = []
    for line in lines:
        try:
            filename, label = line.split(",")
            label = label.strip("\n")
            with open(os.path.join(img_dir, filename), "rb") as file_reader:
                img = file_reader.read()
            data_json = {"file_name": filename,
                         "data": img,
                         "label": int(label)}
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list


def get_multi_bytes_data(file_name, bytes_num=3):
    """
    Return raw data of the multi-bytes dataset.

    Args:
        file_name (str): Path of the multi-bytes dataset's map file.
        bytes_num (int): Number of bytes fields.

    Returns:
        List
    """
    if not os.path.exists(file_name):
        raise IOError("map file {} does not exist".format(file_name))
    dir_name = os.path.dirname(file_name)
    with open(file_name, "r") as file_reader:
        lines = file_reader.readlines()
    data_list = []
    row_num = 0
    for line in lines:
        try:
            img10_path = line.strip('\n').split(" ")
            img5 = []
            for path in img10_path[:bytes_num]:
                with open(os.path.join(dir_name, path), "rb") as file_reader:
                    img5 += [file_reader.read()]
            data_json = {"image_{}".format(i): img5[i]
                         for i in range(len(img5))}
            data_json.update({"id": row_num})
            row_num += 1
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list


def get_mkv_data(dir_name):
    """
    Return raw data of the Vehicle_and_Person dataset.

    Args:
        dir_name (str): Path of the Vehicle_and_Person dataset.

    Returns:
        List
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "Image")
    label_dir = os.path.join(dir_name, "prelabel")

    data_list = []
    file_list = os.listdir(label_dir)

    index = 1
    for item in file_list:
        if os.path.splitext(item)[1] == '.json':
            file_path = os.path.join(label_dir, item)

            image_name = ''.join([os.path.splitext(item)[0], ".jpg"])
            image_path = os.path.join(img_dir, image_name)

            with open(file_path, "r") as load_f:
                load_dict = json.load(load_f)

            if os.path.exists(image_path):
                with open(image_path, "rb") as file_reader:
                    img = file_reader.read()
                data_json = {"file_name": image_name,
                             "prelabel": str(load_dict),
                             "data": img,
                             "id": index}
                data_list.append(data_json)
            index += 1
    logger.info('{} images are missing'.format(len(file_list) - len(data_list)))
    return data_list


def get_nlp_data(dir_name, vocab_file, num):
    """
    Yield raw data of the aclImdb dataset.

    Args:
        dir_name (str): Path of the aclImdb dataset.
        vocab_file (str): Path of the vocabulary file.
        num (int): Number of samples.

    Returns:
        Generator of sample dictionaries.
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    for root, dirs, files in os.walk(dir_name):
        for index, file_name_extension in enumerate(files):
            if index < num:
                file_path = os.path.join(root, file_name_extension)
                file_name, _ = file_name_extension.split('.', 1)
                id_, rating = file_name.split('_', 1)
                with open(file_path, 'r') as f:
                    raw_content = f.read()

                dictionary = load_vocab(vocab_file)
                vectors = [dictionary.get('[CLS]')]
                vectors += [dictionary.get(i) if i in dictionary
                            else dictionary.get('[UNK]')
                            for i in re.findall(r"[\w']+|[{}]"
                                                .format(string.punctuation),
                                                raw_content)]
                vectors += [dictionary.get('[SEP]')]
                input_, mask, segment = inputs(vectors)
                input_ids = np.reshape(np.array(input_), [-1])
                input_mask = np.reshape(np.array(mask), [1, -1])
                segment_ids = np.reshape(np.array(segment), [2, -1])
                data = {
                    "label": 1,
                    "id": id_,
                    "rating": float(rating),
                    "input_ids": input_ids,
                    "input_mask": input_mask,
                    "segment_ids": segment_ids
                }
                yield data


def convert_to_uni(text):
    """Return text as a unicode string, decoding bytes if necessary."""
    if isinstance(text, str):
        return text
    if isinstance(text, bytes):
        return text.decode('utf-8', 'ignore')
    raise Exception("The type %s cannot be converted!" % type(text))


def load_vocab(vocab_file):
    """load vocabulary to translate statement."""
    vocab = collections.OrderedDict()
    vocab.setdefault('blank', 2)
    index = 0
    with open(vocab_file) as reader:
        while True:
            tmp = reader.readline()
            if not tmp:
                break
            token = convert_to_uni(tmp)
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab


def inputs(vectors, maxlen=50):
    """Pad or truncate a token-id vector to maxlen and build mask/segment vectors."""
    length = len(vectors)
    if length > maxlen:
        return vectors[0:maxlen], [1] * maxlen, [0] * maxlen
    input_ = vectors + [0] * (maxlen - length)
    mask = [1] * length + [0] * (maxlen - length)
    segment = [0] * maxlen
    return input_, mask, segment

MindSpore is a new open-source deep learning training/inference framework that can be used in mobile, edge, and cloud scenarios.
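
For readers skimming the tests above, the write-then-read round trip they exercise can be reduced to a short sketch. This is a minimal illustration and not part of the test file: the output path, the schema contents, and the single sample record are assumptions, and it further assumes that a single-shard FileWriter writes its output to exactly the path given.

# Minimal sketch of the write/read round trip used by the fixtures and tests above.
# Assumption: MINDRECORD_FILE, the schema values, and the sample record are illustrative only.
import mindspore.dataset as ds
from mindspore.mindrecord import FileWriter

MINDRECORD_FILE = "./demo.mindrecord"  # hypothetical path, not used by the tests

# Write one record with the same schema the cv tests use.
writer = FileWriter(MINDRECORD_FILE, 1)
schema = {"file_name": {"type": "string"},
          "label": {"type": "int32"},
          "data": {"type": "bytes"}}
writer.add_schema(schema, "img_schema")
writer.add_index(["file_name", "label"])
writer.write_raw_data([{"file_name": "demo.jpg", "label": 0, "data": b"\x00\x01"}])
writer.commit()

# Read it back, as the reader tutorials above do.
data_set = ds.MindDataset(MINDRECORD_FILE, ["data", "file_name", "label"], 2)
for item in data_set.create_dict_iterator():
    print(item["file_name"], item["label"])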