You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

writer.py 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. ######################## write mindrecord example ########################
  17. Write mindrecord by data dictionary:
  18. python writer.py --mindrecord_script /YourScriptPath ...
  19. """
  20. import argparse
  21. import os
  22. import time
  23. from importlib import import_module
  24. from multiprocessing import Pool
  25. from mindspore.mindrecord import FileWriter
  26. from graph_map_schema import GraphMapSchema
  27. def exec_task(task_id, parallel_writer=True):
  28. """
  29. Execute task with specified task id
  30. """
  31. print("exec task {}, parallel: {} ...".format(task_id, parallel_writer))
  32. imagenet_iter = mindrecord_dict_data(task_id)
  33. batch_size = 512
  34. transform_count = 0
  35. while True:
  36. data_list = []
  37. try:
  38. for _ in range(batch_size):
  39. data = imagenet_iter.__next__()
  40. if 'dst_id' in data:
  41. data = graph_map_schema.transform_edge(data)
  42. else:
  43. data = graph_map_schema.transform_node(data)
  44. data_list.append(data)
  45. transform_count += 1
  46. writer.write_raw_data(data_list, parallel_writer=parallel_writer)
  47. print("transformed {} record...".format(transform_count))
  48. except StopIteration:
  49. if data_list:
  50. writer.write_raw_data(data_list, parallel_writer=parallel_writer)
  51. print("transformed {} record...".format(transform_count))
  52. break
  53. def read_args():
  54. """
  55. read args
  56. """
  57. parser = argparse.ArgumentParser(description='Mind record writer')
  58. parser.add_argument('--mindrecord_script', type=str, default="template",
  59. help='path where script is saved')
  60. parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord/xyz",
  61. help='written file name prefix')
  62. parser.add_argument('--mindrecord_partitions', type=int, default=1,
  63. help='number of written files')
  64. parser.add_argument('--mindrecord_header_size_by_bit', type=int, default=24,
  65. help='mindrecord file header size')
  66. parser.add_argument('--mindrecord_page_size_by_bit', type=int, default=25,
  67. help='mindrecord file page size')
  68. parser.add_argument('--mindrecord_workers', type=int, default=8,
  69. help='number of parallel workers')
  70. parser.add_argument('--num_node_tasks', type=int, default=1,
  71. help='number of node tasks')
  72. parser.add_argument('--num_edge_tasks', type=int, default=1,
  73. help='number of node tasks')
  74. parser.add_argument('--graph_api_args', type=str, default="/tmp/nodes.csv:/tmp/edges.csv",
  75. help='nodes and edges data file, csv format with header.')
  76. ret_args = parser.parse_args()
  77. return ret_args
  78. def init_writer(mr_schema):
  79. """
  80. init writer
  81. """
  82. print("Init writer ...")
  83. mr_writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions)
  84. # set the header size
  85. if args.mindrecord_header_size_by_bit != 24:
  86. header_size = 1 << args.mindrecord_header_size_by_bit
  87. mr_writer.set_header_size(header_size)
  88. # set the page size
  89. if args.mindrecord_page_size_by_bit != 25:
  90. page_size = 1 << args.mindrecord_page_size_by_bit
  91. mr_writer.set_page_size(page_size)
  92. # create the schema
  93. mr_writer.add_schema(mr_schema, "mindrecord_graph_schema")
  94. # open file and set header
  95. mr_writer.open_and_set_header()
  96. return mr_writer
  97. def run_parallel_workers(num_tasks):
  98. """
  99. run parallel workers
  100. """
  101. # set number of workers
  102. num_workers = args.mindrecord_workers
  103. task_list = list(range(num_tasks))
  104. if num_workers > num_tasks:
  105. num_workers = num_tasks
  106. if os.name == 'nt':
  107. for window_task_id in task_list:
  108. exec_task(window_task_id, False)
  109. elif num_tasks > 1:
  110. with Pool(num_workers) as p:
  111. p.map(exec_task, task_list)
  112. else:
  113. exec_task(0, False)
if __name__ == "__main__":
    # Parse CLI options; `args` is read as a global by init_writer and
    # run_parallel_workers above.
    args = read_args()
    print(args)
    start_time = time.time()
    # pass mr_api arguments
    os.environ['graph_api_args'] = args.graph_api_args
    # import mr_api
    # The user script must live at <mindrecord_script>/mr_api.py and expose
    # node_profile, edge_profile, yield_nodes and yield_edges.
    try:
        mr_api = import_module(args.mindrecord_script + '.mr_api')
    except ModuleNotFoundError:
        raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api'))
    # init graph schema
    graph_map_schema = GraphMapSchema()
    num_features, feature_data_types, feature_shapes = mr_api.node_profile
    graph_map_schema.set_node_feature_profile(num_features, feature_data_types, feature_shapes)
    num_features, feature_data_types, feature_shapes = mr_api.edge_profile
    graph_map_schema.set_edge_feature_profile(num_features, feature_data_types, feature_shapes)
    graph_schema = graph_map_schema.get_schema()
    # init writer
    # `writer` is a global consumed by exec_task.
    writer = init_writer(graph_schema)
    # write nodes data
    # `mindrecord_dict_data` is the iterator factory exec_task calls; it is
    # rebound below to switch from node records to edge records.
    mindrecord_dict_data = mr_api.yield_nodes
    run_parallel_workers(args.num_node_tasks)
    # write edges data
    mindrecord_dict_data = mr_api.yield_edges
    run_parallel_workers(args.num_edge_tasks)
    # writer wrap up
    ret = writer.commit()
    end_time = time.time()
    print("--------------------------------------------")
    print("END. Total time: {}".format(end_time - start_time))
    print("--------------------------------------------")