You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

label_image.py 2.9 kB

5 years ago
5 years ago
5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # Copyright(C) 2021 刘臣轩
  2. # This program is free software: you can redistribute it and/or modify
  3. # it under the terms of the GNU General Public License as published by
  4. # the Free Software Foundation, either version 3 of the License, or
  5. # (at your option) any later version.
  6. # This program is distributed in the hope that it will be useful,
  7. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  8. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  9. # GNU General Public License for more details.
  10. # You should have received a copy of the GNU General Public License
  11. # along with this program. If not, see <http://www.gnu.org/licenses/>.
  12. #!/usr/bin/python
  13. # -*-coding:utf-8-*-
  14. from tflite_runtime.interpreter import Interpreter
  15. from PIL import Image
  16. import re
  17. import numpy as np
  18. def loadLabels(labelPath):
  19. p = re.compile(r'\s*(\d+)(.+)')
  20. with open(labelPath, 'r', encoding='utf-8') as labelFile:
  21. lines = (p.match(line).groups() for line in labelFile.readlines())
  22. return {int(num): text.strip() for num, text in lines}
  23. def load_labels(path):
  24. with open(path, 'r', errors='ignore') as f:
  25. return {i: line.strip() for i, line in enumerate(f.readlines())}
  26. def set_input_tensor(interpreter, image):
  27. tensor_index = interpreter.get_input_details()[0]['index']
  28. input_tensor = interpreter.tensor(tensor_index)()[0]
  29. input_tensor[:, :] = image
  30. def classify_image(interpreter, image, top_k=1):
  31. set_input_tensor(interpreter, image)
  32. interpreter.invoke()
  33. output_details = interpreter.get_output_details()[0]
  34. output = np.squeeze(interpreter.get_tensor(output_details['index']))
  35. # If the model is quantized (uint8 data), then dequantize the results
  36. if output_details['dtype'] == np.uint8:
  37. scale, zero_point = output_details['quantization']
  38. output = scale * (output - zero_point)
  39. ordered = np.argpartition(-output, top_k)
  40. return [(i, output[i]) for i in ordered[:top_k]]
  41. def main():
  42. # labels = load_labels('labels_mobilenet_quant_v1_224.txt')
  43. # interpreter = Interpreter('mobilenet_v1_1.0_224_quant.tflite')
  44. # labels = loadLabels('../WasteSorting/tensorflow/labels.txt')
  45. interpreter = Interpreter('../WasteSorting/tensorflow/model.tflite')
  46. interpreter.allocate_tensors()
  47. pil_im = Image.open('../WasteSorting/WasteSorting.jpg').convert(
  48. 'RGB').resize((224, 224), Image.ANTIALIAS)
  49. pil_im.transpose(Image.FLIP_LEFT_RIGHT)
  50. results = classify_image(interpreter, pil_im)
  51. #print(results)
  52. label = results[0][0]
  53. if label == 0:
  54. print('识别失败')
  55. elif label in range(1, 4):
  56. print('有害垃圾')
  57. elif label in range(4, 7):
  58. print('可回收垃圾')
  59. elif label in range(7, 9):
  60. print('厨余垃圾')
  61. else:
  62. print('其他垃圾')
  63. if __name__ == '__main__':
  64. main()

No Description

Contributors (1)