
test_minddataset.py 22 kB

# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This is the test module for mindrecord
"""
import collections
import json
import os
import re
import string

import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore import log as logger
from mindspore._c_dataengine import InterpolationMode
from mindspore.dataset.transforms.vision import Inter
from mindspore.mindrecord import FileWriter

FILES_NUM = 4
CV_FILE_NAME = "../data/mindrecord/imagenet.mindrecord"
CV_DIR_NAME = "../data/mindrecord/testImageNetData"
NLP_FILE_NAME = "../data/mindrecord/aclImdb.mindrecord"
NLP_FILE_POS = "../data/mindrecord/testAclImdbData/pos"
NLP_FILE_VOCAB = "../data/mindrecord/testAclImdbData/vocab.txt"


@pytest.fixture
def add_and_remove_cv_file():
    """add/remove cv file"""
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    data = get_data(CV_DIR_NAME)
    cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()
    yield "yield_cv_data"
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))


@pytest.fixture
def add_and_remove_nlp_file():
    """add/remove nlp file"""
    paths = ["{}{}".format(NLP_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(NLP_FILE_NAME, FILES_NUM)
    data = [x for x in get_nlp_data(NLP_FILE_POS, NLP_FILE_VOCAB, 10)]
    nlp_schema_json = {"id": {"type": "string"}, "label": {"type": "int32"},
                       "rating": {"type": "float32"},
                       "input_ids": {"type": "int64",
                                     "shape": [-1]},
                       "input_mask": {"type": "int64",
                                      "shape": [1, -1]},
                       "segment_ids": {"type": "int64",
                                       "shape": [2, -1]}
                       }
    writer.set_header_size(1 << 14)
    writer.set_page_size(1 << 15)
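    # The header and page sizes above are set explicitly: 1 << 14 bytes = 16 KB
    # for the file header, 1 << 15 bytes = 32 KB per data page.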
    writer.add_schema(nlp_schema_json, "nlp_schema")
    writer.add_index(["id", "rating"])
    writer.write_raw_data(data)
    writer.commit()
    yield "yield_nlp_data"
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))


def test_cv_minddataset_writer_tutorial():
    """tutorial for cv dataset writer."""
    paths = ["{}{}".format(CV_FILE_NAME, str(x).rjust(1, '0'))
             for x in range(FILES_NUM)]
    for x in paths:
        if os.path.exists("{}".format(x)):
            os.remove("{}".format(x))
        if os.path.exists("{}.db".format(x)):
            os.remove("{}.db".format(x))
    writer = FileWriter(CV_FILE_NAME, FILES_NUM)
    data = get_data(CV_DIR_NAME)
    cv_schema_json = {"file_name": {"type": "string"}, "label": {"type": "int32"},
                      "data": {"type": "bytes"}}
    writer.add_schema(cv_schema_json, "img_schema")
    writer.add_index(["file_name", "label"])
    writer.write_raw_data(data)
    writer.commit()
    for x in paths:
        os.remove("{}".format(x))
        os.remove("{}.db".format(x))


def test_cv_minddataset_partition_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4

    def partitions(num_shards):
        for partition_id in range(num_shards):
            data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                                      num_shards=num_shards, shard_id=partition_id)
            num_iter = 0
            for item in data_set.create_dict_iterator():
                logger.info("-------------- partition : {} ------------------------".format(partition_id))
                logger.info("-------------- item[label]: {} -----------------------".format(item["label"]))
                num_iter += 1
        return num_iter
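
    # The fixture writes 10 records, so the asserted per-shard row counts below
    # match ceil(10 / num_shards) for 4, 5 and 9 shards.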
    assert partitions(4) == 3
    assert partitions(5) == 2
    assert partitions(9) == 2


def test_cv_minddataset_dataset_size(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    assert data_set.get_dataset_size() == 10
    repeat_num = 2
    data_set = data_set.repeat(repeat_num)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- get dataset size {} -----------------".format(num_iter))
        logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
        logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
        num_iter += 1
    assert num_iter == 20

    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              num_shards=4, shard_id=3)
    assert data_set.get_dataset_size() == 3


def test_cv_minddataset_repeat_reshuffle(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    decode_op = vision.Decode()
    data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2)
    resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
    data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2)
    data_set = data_set.batch(2)
    data_set = data_set.repeat(2)
    num_iter = 0
    labels = []
    for item in data_set.create_dict_iterator():
        logger.info("-------------- get dataset size {} -----------------".format(num_iter))
        logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
        logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
        num_iter += 1
        labels.append(item["label"])
    assert num_iter == 10
    logger.info("repeat shuffle: {}".format(labels))
    assert len(labels) == 10
    # Reshuffling between the two repeats should change the batch order; the labels
    # are numpy arrays, so compare the two epochs element-wise.
    assert not all(np.array_equal(first, second)
                   for first, second in zip(labels[0:5], labels[5:10]))


def test_cv_minddataset_batch_size_larger_than_records(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    decode_op = vision.Decode()
    data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=2)
    resize_op = vision.Resize((32, 32), interpolation=Inter.LINEAR)
    data_set = data_set.map(input_columns="data", operations=resize_op, num_parallel_workers=2)
    data_set = data_set.batch(32, drop_remainder=True)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- get dataset size {} -----------------".format(num_iter))
        logger.info("-------------- item[label]: {} ---------------------".format(item["label"]))
        logger.info("-------------- item[data]: {} ----------------------".format(item["data"]))
        num_iter += 1
    assert num_iter == 0


def test_cv_minddataset_issue_888(add_and_remove_cv_file):
    """issue 888 test."""
    columns_list = ["data", "label"]
    num_readers = 2
    data = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, shuffle=False, num_shards=5, shard_id=1)
    data = data.shuffle(2)
    data = data.repeat(9)
    num_iter = 0
    for item in data.create_dict_iterator():
        num_iter += 1
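    # 10 records split over 5 shards leaves 2 rows on this shard; repeated 9 times
    # that gives the 18 iterations asserted below.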
    assert num_iter == 18


def test_cv_minddataset_blockreader_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers,
                              block_reader=True)
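    # block_reader=True presumably makes the reader fetch records block by block
    # rather than row by row; the row counts asserted below should be unchanged.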
    assert data_set.get_dataset_size() == 10
    repeat_num = 2
    data_set = data_set.repeat(repeat_num)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- block reader repeat two {} -----------------".format(num_iter))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
        num_iter += 1
    assert num_iter == 20


def test_cv_minddataset_reader_basic_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 10


def test_nlp_minddataset_reader_basic_tutorial(add_and_remove_nlp_file):
    """tutorial for nlp minddataset."""
    num_readers = 4
    data_set = ds.MindDataset(NLP_FILE_NAME + "0", None, num_readers)
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- nlp reader basic: {} -----------------------".format(num_iter))
        logger.info("-------------- num_iter: {} ------------------------".format(num_iter))
        logger.info("-------------- item[id]: {} ------------------------".format(item["id"]))
        logger.info("-------------- item[rating]: {} --------------------".format(item["rating"]))
        logger.info("-------------- item[input_ids]: {}, shape: {} -----------------".format(
            item["input_ids"], item["input_ids"].shape))
        logger.info("-------------- item[input_mask]: {}, shape: {} -----------------".format(
            item["input_mask"], item["input_mask"].shape))
        logger.info("-------------- item[segment_ids]: {}, shape: {} -----------------".format(
            item["segment_ids"], item["segment_ids"].shape))
  241. assert item["input_ids"].shape == (50,)
  242. assert item["input_mask"].shape == (1, 50)
  243. assert item["segment_ids"].shape == (2, 25)
  244. num_iter += 1
  245. assert num_iter == 10


def test_cv_minddataset_reader_basic_tutorial_5_epoch(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    assert data_set.get_dataset_size() == 10
    for epoch in range(5):
        num_iter = 0
        for data in data_set:
            logger.info("data is {}".format(data))
            num_iter += 1
        assert num_iter == 10
        data_set.reset()


def test_cv_minddataset_reader_basic_tutorial_5_epoch_with_batch(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    resize_height = 32
    resize_width = 32

    # define map operations
    decode_op = vision.Decode()
    resize_op = vision.Resize((resize_height, resize_width), Inter.LINEAR)
    data_set = data_set.map(input_columns=["data"], operations=decode_op, num_parallel_workers=4)
    data_set = data_set.map(input_columns=["data"], operations=resize_op, num_parallel_workers=4)
    data_set = data_set.batch(2)
    assert data_set.get_dataset_size() == 5
    for epoch in range(5):
        num_iter = 0
        for data in data_set:
            logger.info("data is {}".format(data))
            num_iter += 1
        assert num_iter == 5
        data_set.reset()


def test_cv_minddataset_reader_no_columns(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    data_set = ds.MindDataset(CV_FILE_NAME + "0")
    assert data_set.get_dataset_size() == 10
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter))
        logger.info("-------------- len(item[data]): {} ------------------------".format(len(item["data"])))
        logger.info("-------------- item[data]: {} -----------------------------".format(item["data"]))
        logger.info("-------------- item[file_name]: {} ------------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} ----------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 10


def test_cv_minddataset_reader_repeat_tutorial(add_and_remove_cv_file):
    """tutorial for cv minddataset."""
    columns_list = ["data", "file_name", "label"]
    num_readers = 4
    data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers)
    repeat_num = 2
    data_set = data_set.repeat(repeat_num)
    num_iter = 0
    for item in data_set.create_dict_iterator():
        logger.info("-------------- repeat two test {} ------------------------".format(num_iter))
        logger.info("-------------- len(item[data]): {} -----------------------".format(len(item["data"])))
        logger.info("-------------- item[data]: {} ----------------------------".format(item["data"]))
        logger.info("-------------- item[file_name]: {} -----------------------".format(item["file_name"]))
        logger.info("-------------- item[label]: {} ---------------------------".format(item["label"]))
        num_iter += 1
    assert num_iter == 20


def get_data(dir_name):
    """
    Return raw data of the ImageNet test dataset.

    Args:
        dir_name (str): Directory containing the images folder and annotation file.

    Returns:
        List
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "images")
    ann_file = os.path.join(dir_name, "annotation.txt")
    with open(ann_file, "r") as file_reader:
        lines = file_reader.readlines()

    data_list = []
    for line in lines:
        try:
            filename, label = line.split(",")
            label = label.strip("\n")
            with open(os.path.join(img_dir, filename), "rb") as file_reader:
                img = file_reader.read()
            data_json = {"file_name": filename,
                         "data": img,
                         "label": int(label)}
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list


def get_multi_bytes_data(file_name, bytes_num=3):
    """
    Return raw data of the multi-bytes dataset.

    Args:
        file_name (str): Path of the multi-bytes dataset.
        bytes_num (int): Number of bytes fields.

    Returns:
        List
    """
    if not os.path.exists(file_name):
        raise IOError("map file {} does not exist".format(file_name))
    dir_name = os.path.dirname(file_name)
    with open(file_name, "r") as file_reader:
        lines = file_reader.readlines()

    data_list = []
    row_num = 0
    for line in lines:
        try:
            img10_path = line.strip('\n').split(" ")
            img5 = []
            for path in img10_path[:bytes_num]:
                with open(os.path.join(dir_name, path), "rb") as file_reader:
                    img5 += [file_reader.read()]
            data_json = {"image_{}".format(i): img5[i]
                         for i in range(len(img5))}
            data_json.update({"id": row_num})
            row_num += 1
            data_list.append(data_json)
        except FileNotFoundError:
            continue
    return data_list


def get_mkv_data(dir_name):
    """
    Return raw data of the Vehicle_and_Person dataset.

    Args:
        dir_name (str): Path of the Vehicle_and_Person dataset.

    Returns:
        List
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    img_dir = os.path.join(dir_name, "Image")
    label_dir = os.path.join(dir_name, "prelabel")

    data_list = []
    file_list = os.listdir(label_dir)

    index = 1
    for item in file_list:
        if os.path.splitext(item)[1] == '.json':
            file_path = os.path.join(label_dir, item)
            image_name = ''.join([os.path.splitext(item)[0], ".jpg"])
            image_path = os.path.join(img_dir, image_name)
            with open(file_path, "r") as load_f:
                load_dict = json.load(load_f)
            if os.path.exists(image_path):
                with open(image_path, "rb") as file_reader:
                    img = file_reader.read()
                data_json = {"file_name": image_name,
                             "prelabel": str(load_dict),
                             "data": img,
                             "id": index}
                data_list.append(data_json)
            index += 1
    logger.info('{} images are missing'.format(len(file_list) - len(data_list)))
    return data_list


def get_nlp_data(dir_name, vocab_file, num):
    """
    Yield raw data of the aclImdb dataset.

    Args:
        dir_name (str): Path of the aclImdb dataset.
        vocab_file (str): Path of the vocabulary file.
        num (int): Number of samples to take from each directory.

    Yields:
        Dict, one raw sample at a time.
    """
    if not os.path.isdir(dir_name):
        raise IOError("Directory {} does not exist".format(dir_name))
    for root, dirs, files in os.walk(dir_name):
        for index, file_name_extension in enumerate(files):
            if index < num:
                file_path = os.path.join(root, file_name_extension)
                file_name, _ = file_name_extension.split('.', 1)
                id_, rating = file_name.split('_', 1)
                with open(file_path, 'r') as f:
                    raw_content = f.read()

                dictionary = load_vocab(vocab_file)
                vectors = [dictionary.get('[CLS]')]
                vectors += [dictionary.get(i) if i in dictionary
                            else dictionary.get('[UNK]')
                            for i in re.findall(r"[\w']+|[{}]"
                                                .format(string.punctuation),
                                                raw_content)]
                vectors += [dictionary.get('[SEP]')]
                input_, mask, segment = inputs(vectors)
                input_ids = np.reshape(np.array(input_), [-1])
                input_mask = np.reshape(np.array(mask), [1, -1])
                segment_ids = np.reshape(np.array(segment), [2, -1])
                data = {
                    "label": 1,
                    "id": id_,
                    "rating": float(rating),
                    "input_ids": input_ids,
                    "input_mask": input_mask,
                    "segment_ids": segment_ids
                }
                yield data


def convert_to_uni(text):
    """Return text as a unicode str, decoding bytes if necessary."""
    if isinstance(text, str):
        return text
    if isinstance(text, bytes):
        return text.decode('utf-8', 'ignore')
    raise Exception("The type %s cannot be converted to unicode." % type(text))


def load_vocab(vocab_file):
    """load vocabulary to translate statement."""
    vocab = collections.OrderedDict()
    vocab.setdefault('blank', 2)
    index = 0
    with open(vocab_file) as reader:
        while True:
            tmp = reader.readline()
            if not tmp:
                break
            token = convert_to_uni(tmp)
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab
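

# inputs() pads (or truncates) a list of token ids to a fixed length, returning the
# ids together with a 0/1 mask that marks real tokens versus padding and all-zero
# segment ids, matching what get_nlp_data() reshapes into the mindrecord fields.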
def inputs(vectors, maxlen=50):
    length = len(vectors)
    if length > maxlen:
        return vectors[0:maxlen], [1] * maxlen, [0] * maxlen
    input_ = vectors + [0] * (maxlen - length)
    mask = [1] * length + [0] * (maxlen - length)
    segment = [0] * maxlen
    return input_, mask, segment