You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pybind11_layer.h 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. /* Tencent is pleased to support the open source community by making ncnn available.
  2. *
  3. * Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
  4. *
  5. * Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. * in compliance with the License. You may obtain a copy of the License at
  7. *
  8. * https://opensource.org/licenses/BSD-3-Clause
  9. *
  10. * Unless required by applicable law or agreed to in writing, software distributed
  11. * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. * CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. * specific language governing permissions and limitations under the License.
  14. */
  15. #ifndef PYBIND11_NCNN_LAYER_H
  16. #define PYBIND11_NCNN_LAYER_H
  17. #include <layer.h>
  18. #include "pybind11_bind.h"
  19. class PyLayer : public ncnn::Layer
  20. {
  21. public:
  22. virtual int load_param(const ncnn::ParamDict& pd)
  23. {
  24. PYBIND11_OVERRIDE_REFERENCE(
  25. int,
  26. ncnn::Layer,
  27. load_param,
  28. pd);
  29. }
  30. virtual int load_model(const ncnn::ModelBin& mb)
  31. {
  32. PYBIND11_OVERRIDE_REFERENCE(
  33. int,
  34. ncnn::Layer,
  35. load_model,
  36. mb);
  37. }
  38. virtual int create_pipeline(const ncnn::Option& opt)
  39. {
  40. PYBIND11_OVERRIDE_REFERENCE(
  41. int,
  42. ncnn::Layer,
  43. create_pipeline,
  44. opt);
  45. }
  46. virtual int destroy_pipeline(const ncnn::Option& opt)
  47. {
  48. PYBIND11_OVERRIDE_REFERENCE(
  49. int,
  50. ncnn::Layer,
  51. destroy_pipeline,
  52. opt);
  53. }
  54. public:
  55. virtual int forward(const std::vector<ncnn::Mat>& bottom_blobs, std::vector<ncnn::Mat>& top_blobs, const ncnn::Option& opt) const
  56. {
  57. PYBIND11_OVERRIDE_REFERENCE(
  58. int,
  59. ncnn::Layer,
  60. forward,
  61. bottom_blobs,
  62. top_blobs,
  63. opt);
  64. }
  65. virtual int forward(const ncnn::Mat& bottom_blob, ncnn::Mat& top_blob, const ncnn::Option& opt) const
  66. {
  67. PYBIND11_OVERRIDE_REFERENCE(
  68. int,
  69. ncnn::Layer,
  70. forward,
  71. bottom_blob,
  72. top_blob,
  73. opt);
  74. }
  75. virtual int forward_inplace(std::vector<ncnn::Mat>& bottom_top_blobs, const ncnn::Option& opt) const
  76. {
  77. PYBIND11_OVERRIDE_REFERENCE(
  78. int,
  79. ncnn::Layer,
  80. forward_inplace,
  81. bottom_top_blobs,
  82. opt);
  83. }
  84. virtual int forward_inplace(ncnn::Mat& bottom_top_blob, const ncnn::Option& opt) const
  85. {
  86. PYBIND11_OVERRIDE_REFERENCE(
  87. int,
  88. ncnn::Layer,
  89. forward_inplace,
  90. bottom_top_blob,
  91. opt);
  92. }
  93. #if NCNN_VULKAN
  94. public:
  95. virtual int upload_model(ncnn::VkTransfer& cmd, const ncnn::Option& opt)
  96. {
  97. PYBIND11_OVERRIDE_REFERENCE(
  98. int,
  99. ncnn::Layer,
  100. upload_model,
  101. cmd,
  102. opt);
  103. }
  104. public:
  105. virtual int forward(const std::vector<ncnn::VkMat>& bottom_blobs, std::vector<ncnn::VkMat>& top_blobs, ncnn::VkCompute& cmd, const ncnn::Option& opt) const
  106. {
  107. PYBIND11_OVERRIDE_REFERENCE(
  108. int,
  109. ncnn::Layer,
  110. forward,
  111. bottom_blobs,
  112. top_blobs,
  113. cmd,
  114. opt);
  115. }
  116. virtual int forward(const ncnn::VkMat& bottom_blob, ncnn::VkMat& top_blob, ncnn::VkCompute& cmd, const ncnn::Option& opt) const
  117. {
  118. PYBIND11_OVERRIDE_REFERENCE(
  119. int,
  120. ncnn::Layer,
  121. forward,
  122. bottom_blob,
  123. top_blob,
  124. cmd,
  125. opt);
  126. }
  127. virtual int forward_inplace(std::vector<ncnn::VkMat>& bottom_top_blobs, ncnn::VkCompute& cmd, const ncnn::Option& opt) const
  128. {
  129. PYBIND11_OVERRIDE_REFERENCE(
  130. int,
  131. ncnn::Layer,
  132. forward_inplace,
  133. bottom_top_blobs,
  134. cmd,
  135. opt);
  136. }
  137. virtual int forward_inplace(ncnn::VkMat& bottom_top_blob, ncnn::VkCompute& cmd, const ncnn::Option& opt) const
  138. {
  139. PYBIND11_OVERRIDE_REFERENCE(
  140. int,
  141. ncnn::Layer,
  142. forward_inplace,
  143. bottom_top_blob,
  144. cmd,
  145. opt);
  146. }
  147. #endif // NCNN_VULKAN
  148. };
  149. #endif