You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

label_image.py 2.9 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # This program is free software: you can redistribute it and/or modify
  2. # it under the terms of the GNU General Public License as published by
  3. # the Free Software Foundation, either version 3 of the License, or
  4. # (at your option) any later version.
  5. # This program is distributed in the hope that it will be useful,
  6. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  7. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  8. # GNU General Public License for more details.
  9. # You should have received a copy of the GNU General Public License
  10. # along with this program. If not, see <http://www.gnu.org/licenses/>.
  11. #!/usr/bin/python
  12. # -*-coding:utf-8-*-
  13. from tflite_runtime.interpreter import Interpreter
  14. from PIL import Image
  15. import re
  16. import numpy as np
  17. def loadLabels(labelPath):
  18. p = re.compile(r'\s*(\d+)(.+)')
  19. with open(labelPath, 'r', encoding='utf-8') as labelFile:
  20. lines = (p.match(line).groups() for line in labelFile.readlines())
  21. return {int(num): text.strip() for num, text in lines}
  22. def load_labels(path):
  23. with open(path, 'r', errors='ignore') as f:
  24. return {i: line.strip() for i, line in enumerate(f.readlines())}
  25. def set_input_tensor(interpreter, image):
  26. tensor_index = interpreter.get_input_details()[0]['index']
  27. input_tensor = interpreter.tensor(tensor_index)()[0]
  28. input_tensor[:, :] = image
  29. def classify_image(interpreter, image, top_k=1):
  30. set_input_tensor(interpreter, image)
  31. interpreter.invoke()
  32. output_details = interpreter.get_output_details()[0]
  33. output = np.squeeze(interpreter.get_tensor(output_details['index']))
  34. # If the model is quantized (uint8 data), then dequantize the results
  35. if output_details['dtype'] == np.uint8:
  36. scale, zero_point = output_details['quantization']
  37. output = scale * (output - zero_point)
  38. ordered = np.argpartition(-output, top_k)
  39. return [(i, output[i]) for i in ordered[:top_k]]
  40. def main():
  41. # labels = load_labels('labels_mobilenet_quant_v1_224.txt')
  42. # interpreter = Interpreter('mobilenet_v1_1.0_224_quant.tflite')
  43. labels = loadLabels('../WasteSorting/tensorflow/labels.txt')
  44. interpreter = Interpreter('../WasteSorting/tensorflow/model.tflite')
  45. interpreter.allocate_tensors()
  46. pil_im = Image.open('../WasteSorting/WasteSorting.jpg').convert(
  47. 'RGB').resize((224, 224), Image.ANTIALIAS)
  48. #pil_im.transpose(Image.FLIP_LEFT_RIGHT)
  49. results = classify_image(interpreter, pil_im)
  50. print(results)
  51. print(labels[results[0][0]])
  52. label = results[0][0]
  53. if label == 0:
  54. print('识别失败')
  55. elif label in range(1, 4):
  56. print('有害垃圾')
  57. elif label in range(4, 7):
  58. print('可回收垃圾')
  59. elif label in range(7, 9):
  60. print('厨余垃圾')
  61. else:
  62. print('其他垃圾')
  63. if __name__ == '__main__':
  64. main()

No Description