You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_extractor.py 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # Copyright 2021 Tencent
  2. # SPDX-License-Identifier: BSD-3-Clause
  3. import pytest
  4. import ncnn
# Module-level pool allocator shared by the tests in this file; each test
# passes it to both set_blob_allocator() and set_workspace_allocator().
# NOTE: the name is misspelled ("alloctor") but is referenced by the tests
# below, so it is kept as-is.
alloctor = ncnn.PoolAllocator()
  6. def test_extractor():
  7. with pytest.raises(TypeError, match="No constructor"):
  8. ex = ncnn.Extractor()
  9. dr = ncnn.DataReaderFromEmpty()
  10. net = ncnn.Net()
  11. net.load_param("tests/test.param")
  12. net.load_model(dr)
  13. in_mat = ncnn.Mat((227, 227, 3))
  14. with net.create_extractor() as ex:
  15. ex.set_light_mode(True)
  16. ex.set_num_threads(2)
  17. ex.set_blob_allocator(alloctor)
  18. ex.set_workspace_allocator(alloctor)
  19. ex.input("data", in_mat)
  20. ret, out_mat = ex.extract("conv0_fwd")
  21. assert (
  22. ret == 0
  23. and out_mat.dims == 3
  24. and out_mat.w == 225
  25. and out_mat.h == 225
  26. and out_mat.c == 3
  27. )
  28. ret, out_mat = ex.extract("output")
  29. assert ret == 0 and out_mat.dims == 1 and out_mat.w == 1
  30. def test_extractor_index():
  31. with pytest.raises(TypeError, match="No constructor"):
  32. ex = ncnn.Extractor()
  33. dr = ncnn.DataReaderFromEmpty()
  34. net = ncnn.Net()
  35. net.load_param("tests/test.param")
  36. net.load_model(dr)
  37. in_mat = ncnn.Mat((227, 227, 3))
  38. ex = net.create_extractor()
  39. ex.set_light_mode(True)
  40. ex.set_num_threads(2)
  41. ex.set_blob_allocator(alloctor)
  42. ex.set_workspace_allocator(alloctor)
  43. ex.input(0, in_mat)
  44. ret, out_mat = ex.extract(1)
  45. assert (
  46. ret == 0
  47. and out_mat.dims == 3
  48. and out_mat.w == 225
  49. and out_mat.h == 225
  50. and out_mat.c == 3
  51. )
  52. ret, out_mat = ex.extract(2)
  53. assert ret == 0 and out_mat.dims == 1 and out_mat.w == 1
  54. # not use with sentence, call clear manually to ensure ex destruct before net
  55. ex.clear()