You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

util_minddataset.py 3.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. This module contains common utility functions for minddataset tests.
  17. """
  18. import os
  19. import pytest
  20. from mindspore.mindrecord import FileWriter
  21. FILES_NUM = 4
  22. CV_DIR_NAME = "../data/mindrecord/testImageNetData"
  23. def get_data(dir_name):
  24. """
  25. usage: get data from imagenet dataset
  26. params:
  27. dir_name: directory containing folder images and annotation information
  28. """
  29. if not os.path.isdir(dir_name):
  30. raise IOError("Directory {} does not exist".format(dir_name))
  31. img_dir = os.path.join(dir_name, "images")
  32. ann_file = os.path.join(dir_name, "annotation.txt")
  33. with open(ann_file, "r") as file_reader:
  34. lines = file_reader.readlines()
  35. data_list = []
  36. for i, line in enumerate(lines):
  37. try:
  38. filename, label = line.split(",")
  39. label = label.strip("\n")
  40. with open(os.path.join(img_dir, filename), "rb") as file_reader:
  41. img = file_reader.read()
  42. data_json = {"id": i,
  43. "file_name": filename,
  44. "data": img,
  45. "label": int(label)}
  46. data_list.append(data_json)
  47. except FileNotFoundError:
  48. continue
  49. return data_list
  50. @pytest.fixture
  51. def add_and_remove_cv_file():
  52. """add/remove cv file"""
  53. file_name = os.environ.get('PYTEST_CURRENT_TEST').split(':')[-1].split(' ')[0]
  54. paths = ["{}{}".format(file_name, str(x).rjust(1, '0'))
  55. for x in range(FILES_NUM)]
  56. try:
  57. for x in paths:
  58. if os.path.exists("{}".format(x)):
  59. os.remove("{}".format(x))
  60. if os.path.exists("{}.db".format(x)):
  61. os.remove("{}.db".format(x))
  62. writer = FileWriter(file_name, FILES_NUM)
  63. data = get_data(CV_DIR_NAME)
  64. cv_schema_json = {"id": {"type": "int32"},
  65. "file_name": {"type": "string"},
  66. "label": {"type": "int32"},
  67. "data": {"type": "bytes"}}
  68. writer.add_schema(cv_schema_json, "img_schema")
  69. writer.add_index(["file_name", "label"])
  70. writer.write_raw_data(data)
  71. writer.commit()
  72. yield "yield_cv_data"
  73. except Exception as error:
  74. for x in paths:
  75. os.remove("{}".format(x))
  76. os.remove("{}.db".format(x))
  77. raise error
  78. else:
  79. for x in paths:
  80. os.remove("{}".format(x))
  81. os.remove("{}.db".format(x))