You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

label_image.py 2.9 kB

5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright(C) 2021 刘臣轩
  2. # This program is free software: you can redistribute it and/or modify
  3. # it under the terms of the GNU General Public License as published by
  4. # the Free Software Foundation, either version 3 of the License, or
  5. # (at your option) any later version.
  6. # This program is distributed in the hope that it will be useful,
  7. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  8. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  9. # GNU General Public License for more details.
  10. # You should have received a copy of the GNU General Public License
  11. # along with this program. If not, see <http://www.gnu.org/licenses/>.
  12. #!/usr/bin/python
  13. # -*-coding:utf-8-*-
  14. from tflite_runtime.interpreter import Interpreter
  15. from PIL import Image
  16. import cv2
  17. import re
  18. import os
  19. import numpy as np
  20. def loadLabels(labelPath):
  21. p = re.compile(r'\s*(\d+)(.+)')
  22. with open(labelPath, 'r', encoding='utf-8') as labelFile:
  23. lines = (p.match(line).groups() for line in labelFile.readlines())
  24. return {int(num): text.strip() for num, text in lines}
  25. def load_labels(path):
  26. with open(path, 'r', errors='ignore') as f:
  27. return {i: line.strip() for i, line in enumerate(f.readlines())}
  28. def set_input_tensor(interpreter, image):
  29. tensor_index = interpreter.get_input_details()[0]['index']
  30. input_tensor = interpreter.tensor(tensor_index)()[0]
  31. input_tensor[:, :] = image
  32. def classify_image(interpreter, image, top_k=1):
  33. set_input_tensor(interpreter, image)
  34. interpreter.invoke()
  35. output_details = interpreter.get_output_details()[0]
  36. output = np.squeeze(interpreter.get_tensor(output_details['index']))
  37. # If the model is quantized (uint8 data), then dequantize the results
  38. if output_details['dtype'] == np.uint8:
  39. scale, zero_point = output_details['quantization']
  40. output = scale * (output - zero_point)
  41. ordered = np.argpartition(-output, top_k)
  42. return [(i, output[i]) for i in ordered[:top_k]]
  43. def main():
  44. # labels = load_labels('labels_mobilenet_quant_v1_224.txt')
  45. # interpreter = Interpreter('mobilenet_v1_1.0_224_quant.tflite')
  46. # labels = loadLabels('../WasteSorting/tensorflow/labels.txt')
  47. interpreter = Interpreter('../WasteSorting/tensorflow/model.tflite')
  48. interpreter.allocate_tensors()
  49. pil_im = Image.open('../WasteSorting/WasteSorting.jpg').convert(
  50. 'RGB').resize((224, 224), Image.ANTIALIAS)
  51. pil_im.transpose(Image.FLIP_LEFT_RIGHT)
  52. results = classify_image(interpreter, pil_im)
  53. # print(results)
  54. label = results[0][0]
  55. if label == 0:
  56. print('识别失败')
  57. elif label in range(1, 4):
  58. print('有害垃圾')
  59. elif label in range(4, 7):
  60. print('可回收物')
  61. elif label in range(7, 10):
  62. print('厨余垃圾')
  63. else:
  64. print('其他垃圾')
  65. if __name__ == '__main__':
  66. main()

No Description

Contributors (1)