You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

ofrecord.py 6.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. """
  2. /**
  3. * Copyright 2020 Zhejiang Lab. All Rights Reserved.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. * =============================================================
  17. */
  18. """
  19. # -*- coding: utf-8 -*-
  20. import logging
  21. import json
  22. import os
  23. import struct
  24. import cv2
  25. import sched
  26. import numpy as np
  27. import oneflow.core.record.record_pb2 as of_record
  28. import luascript.delaytaskscript as delay_script
  29. import time
  30. import common.config as config
  31. from datetime import datetime
  32. schedule = sched.scheduler(time.time, time.sleep)
  33. delayId = ""
  34. class ImageCoder(object):
  35. """Helper class that provides image coding utilities."""
  36. def __init__(self, size=None):
  37. self.size = size
  38. def _resize(self, image_data):
  39. if self.size is not None and image_data.shape[:2] != self.size:
  40. return cv2.resize(image_data, self.size)
  41. return image_data
  42. def image_to_jpeg(self, image_data):
  43. image_data = cv2.imdecode(np.fromstring(image_data, np.uint8), 1)
  44. image_data = self._resize(image_data)
  45. return cv2.imencode(".jpg", image_data)[1].tostring(
  46. ), image_data.shape[0], image_data.shape[1]
  47. def _process_image(filename, coder):
  48. """Process a single image file.
  49. Args:
  50. filename: string, path to an image file e.g., '/path/to/example.JPG'.
  51. coder: instance of ImageCoder to provide image coding utils.
  52. Returns:
  53. image_buffer: string, JPEG encoding of RGB image.
  54. height: integer, image height in pixels.
  55. width: integer, image width in pixels.
  56. """
  57. # Read the image file.
  58. with open(filename, 'rb') as f:
  59. image_data = f.read()
  60. image_data, height, width = coder.image_to_jpeg(image_data)
  61. return image_data, height, width
  62. def _bytes_feature(value):
  63. """Wrapper for inserting bytes features into Example proto."""
  64. return of_record.Feature(bytes_list=of_record.BytesList(value=[value]))
  65. def dense_to_one_hot(labels_dense, num_classes):
  66. """Convert class labels from scalars to one-hot vectors."""
  67. num_labels = labels_dense.shape[0]
  68. index_offset = np.arange(num_labels) * num_classes
  69. labels_one_hot = np.zeros((num_labels, num_classes))
  70. labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
  71. return labels_one_hot
  72. def extract_img_label(names, path):
  73. """Extract the images and labels into np array [index].
  74. Args:
  75. f: A file object that contain images and annotations.
  76. Returns:
  77. data: A 4D uint8 np array [index, h, w, depth].
  78. labels: a 1D uint8 np array.
  79. num_img: the number of images.
  80. """
  81. train_img = os.path.join(path, 'origin/')
  82. train_label = os.path.join(path, 'annotation/')
  83. num_imgs = len(names)
  84. data = []
  85. labels = []
  86. print('^^^^^^^^^^ start img_set for sycle')
  87. for i in names:
  88. name = os.path.splitext(i)[0]
  89. print(name)
  90. coder = ImageCoder((224, 224))
  91. image_buffer, height, width = _process_image(
  92. os.path.join(train_img, i), coder)
  93. data += [image_buffer]
  94. if os.path.exists(os.path.join(train_label, name)):
  95. with open(os.path.join(train_label, name), "r", encoding='utf-8') as jsonFile:
  96. la = json.load(jsonFile)
  97. if la:
  98. labels += [la[0]['category_id']]
  99. else:
  100. data.pop()
  101. num_imgs -= 1
  102. else:
  103. print('File is not found')
  104. print('^^^^^^^^^ img_set for end')
  105. data = np.array(data)
  106. labels = np.array(labels)
  107. print(data.shape, labels.shape)
  108. return num_imgs, data, labels
  109. def execute(src_path, desc, label_map, files, part_id, key):
  110. """Execute ofrecord task method."""
  111. global delayId
  112. delayId = delayId = "\"" + eval(str(key, encoding="utf-8")) + "\""
  113. logging.info(part_id)
  114. num_imgs, images, labels = extract_img_label(files, src_path)
  115. keys = sorted(list(map(int, label_map.keys())))
  116. for i in range(len(keys)):
  117. label_map[str(keys[i])] = i
  118. if not num_imgs:
  119. return False, 0, 0
  120. try:
  121. os.makedirs(desc)
  122. except Exception as e:
  123. print('{} exists.'.format(desc))
  124. for i in range(num_imgs):
  125. filename = 'part-{}'.format(part_id)
  126. filename = os.path.join(desc, filename)
  127. f = open(filename, 'wb')
  128. print(filename)
  129. img = images[i]
  130. label = label_map[str(labels[i])]
  131. sample = of_record.OFRecord(feature={
  132. 'class/label': of_record.Feature(int32_list=of_record.Int32List(value=[label])),
  133. 'encoded': _bytes_feature(img)
  134. })
  135. size = sample.ByteSize()
  136. f.write(struct.pack("q", size))
  137. f.write(sample.SerializeToString())
  138. if f:
  139. f.close()
  140. def delaySchduled(inc, redisClient):
  141. """Delay task method.
  142. Args:
  143. inc: scheduled task time.
  144. redisClient: redis client.
  145. """
  146. try:
  147. print("delay:" + datetime.now().strftime("B%Y-%m-%d %H:%M:%S"))
  148. redisClient.eval(delay_script.delayTaskLua, 1, config.ofrecordStartQueue, delayId, int(time.time()))
  149. schedule.enter(inc, 0, delaySchduled, (inc, redisClient))
  150. except Exception as e:
  151. print("delay error" + e)
  152. def delayKeyThread(redisClient):
  153. """Delay task thread.
  154. Args:
  155. redisClient: redis client.
  156. """
  157. schedule.enter(0, 0, delaySchduled, (5, redisClient))
  158. schedule.run()

一站式算法开发平台、高性能分布式深度学习框架、先进算法模型库、视觉模型炼知平台、数据可视化分析平台等一系列平台及工具,在模型高效分布式训练、数据处理和可视分析、模型炼知和轻量化等技术上形成独特优势,目前已为产学研等各领域近千家单位及个人提供AI应用赋能。

Contributors (1)