You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

writer.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. ######################## write mindrecord example ########################
  17. Write mindrecord by data dictionary:
  18. python writer.py --mindrecord_script /YourScriptPath ...
  19. """
  20. import argparse
  21. import os
  22. import pickle
  23. import time
  24. from importlib import import_module
  25. from multiprocessing import Pool
  26. from mindspore.mindrecord import FileWriter
  27. def _exec_task(task_id, parallel_writer=True):
  28. """
  29. Execute task with specified task id
  30. """
  31. print("exec task {}, parallel: {} ...".format(task_id, parallel_writer))
  32. imagenet_iter = mindrecord_dict_data(task_id)
  33. batch_size = 2048
  34. transform_count = 0
  35. while True:
  36. data_list = []
  37. try:
  38. for _ in range(batch_size):
  39. data_list.append(imagenet_iter.__next__())
  40. transform_count += 1
  41. writer.write_raw_data(data_list, parallel_writer=parallel_writer)
  42. print("transformed {} record...".format(transform_count))
  43. except StopIteration:
  44. if data_list:
  45. writer.write_raw_data(data_list, parallel_writer=parallel_writer)
  46. print("transformed {} record...".format(transform_count))
  47. break
  48. if __name__ == "__main__":
  49. parser = argparse.ArgumentParser(description='Mind record writer')
  50. parser.add_argument('--mindrecord_script', type=str, default="template",
  51. help='path where script is saved')
  52. parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord",
  53. help='written file name prefix')
  54. parser.add_argument('--mindrecord_partitions', type=int, default=1,
  55. help='number of written files')
  56. parser.add_argument('--mindrecord_workers', type=int, default=8,
  57. help='number of parallel workers')
  58. args = parser.parse_known_args()
  59. args, other_args = parser.parse_known_args()
  60. print(args)
  61. print(other_args)
  62. with open('mr_argument.pickle', 'wb') as file_handle:
  63. pickle.dump(other_args, file_handle)
  64. try:
  65. mr_api = import_module(args.mindrecord_script + '.mr_api')
  66. except ModuleNotFoundError:
  67. raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api'))
  68. num_tasks = mr_api.mindrecord_task_number()
  69. print("Write mindrecord ...")
  70. mindrecord_dict_data = mr_api.mindrecord_dict_data
  71. # get number of files
  72. writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions)
  73. start_time = time.time()
  74. # set the header size
  75. try:
  76. header_size = mr_api.mindrecord_header_size
  77. writer.set_header_size(header_size)
  78. except AttributeError:
  79. print("Default header size: {}".format(1 << 24))
  80. # set the page size
  81. try:
  82. page_size = mr_api.mindrecord_page_size
  83. writer.set_page_size(page_size)
  84. except AttributeError:
  85. print("Default page size: {}".format(1 << 25))
  86. # get schema
  87. try:
  88. mindrecord_schema = mr_api.mindrecord_schema
  89. except AttributeError:
  90. raise RuntimeError("mindrecord_schema is not defined in mr_api.py.")
  91. # create the schema
  92. writer.add_schema(mindrecord_schema, "mindrecord_schema")
  93. # add the index
  94. try:
  95. index_fields = mr_api.mindrecord_index_fields
  96. writer.add_index(index_fields)
  97. except AttributeError:
  98. print("Default index fields: all simple fields are indexes.")
  99. writer.open_and_set_header()
  100. task_list = list(range(num_tasks))
  101. # set number of workers
  102. num_workers = args.mindrecord_workers
  103. if num_tasks < 1:
  104. num_tasks = 1
  105. if num_workers > num_tasks:
  106. num_workers = num_tasks
  107. if os.name == 'nt':
  108. for window_task_id in task_list:
  109. _exec_task(window_task_id, False)
  110. elif num_tasks > 1:
  111. with Pool(num_workers) as p:
  112. p.map(_exec_task, task_list)
  113. else:
  114. _exec_task(0, False)
  115. ret = writer.commit()
  116. os.remove("{}".format("mr_argument.pickle"))
  117. end_time = time.time()
  118. print("--------------------------------------------")
  119. print("END. Total time: {}".format(end_time - start_time))
  120. print("--------------------------------------------")