You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

binaryop.cpp 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  #include "binaryop.h"
  #include <algorithm>
  #include <math.h>
  #include <string.h>
  16. namespace ncnn {
  17. DEFINE_LAYER_CREATOR(BinaryOp)
  18. BinaryOp::BinaryOp()
  19. {
  20. one_blob_only = false;
  21. support_inplace = false;
  22. }
  23. #if NCNN_STDIO
  24. #if NCNN_STRING
  25. int BinaryOp::load_param(FILE* paramfp)
  26. {
  27. int nscan = fscanf(paramfp, "%d", &op_type);
  28. if (nscan != 1)
  29. {
  30. fprintf(stderr, "BinaryOp load_param failed %d\n", nscan);
  31. return -1;
  32. }
  33. return 0;
  34. }
  35. #endif // NCNN_STRING
  36. int BinaryOp::load_param_bin(FILE* paramfp)
  37. {
  38. fread(&op_type, sizeof(int), 1, paramfp);
  39. return 0;
  40. }
  41. #endif // NCNN_STDIO
  42. int BinaryOp::load_param(const unsigned char*& mem)
  43. {
  44. op_type = *(int*)(mem);
  45. mem += 4;
  46. return 0;
  47. }
  48. int BinaryOp::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs) const
  49. {
  50. const Mat& bottom_blob = bottom_blobs[0];
  51. const Mat& bottom_blob1 = bottom_blobs[1];
  52. int w = bottom_blob.w;
  53. int h = bottom_blob.h;
  54. int channels = bottom_blob.c;
  55. int size = w * h;
  56. Mat& top_blob = top_blobs[0];
  57. top_blob.create(w, h, channels);
  58. if (top_blob.empty())
  59. return -100;
  60. if (op_type == Operation_ADD)
  61. {
  62. #pragma omp parallel for
  63. for (int q=0; q<channels; q++)
  64. {
  65. const float* ptr = bottom_blob.channel(q);
  66. const float* ptr1 = bottom_blob1.channel(q);
  67. float* outptr = top_blob.channel(q);
  68. for (int i=0; i<size; i++)
  69. {
  70. outptr[i] = ptr[i] + ptr1[i];
  71. }
  72. }
  73. }
  74. else if (op_type == Operation_SUB)
  75. {
  76. #pragma omp parallel for
  77. for (int q=0; q<channels; q++)
  78. {
  79. const float* ptr = bottom_blob.channel(q);
  80. const float* ptr1 = bottom_blob1.channel(q);
  81. float* outptr = top_blob.channel(q);
  82. for (int i=0; i<size; i++)
  83. {
  84. outptr[i] = ptr[i] - ptr1[i];
  85. }
  86. }
  87. }
  88. else if (op_type == Operation_MUL)
  89. {
  90. #pragma omp parallel for
  91. for (int q=0; q<channels; q++)
  92. {
  93. const float* ptr = bottom_blob.channel(q);
  94. const float* ptr1 = bottom_blob1.channel(q);
  95. float* outptr = top_blob.channel(q);
  96. for (int i=0; i<size; i++)
  97. {
  98. outptr[i] = ptr[i] * ptr1[i];
  99. }
  100. }
  101. }
  102. else if (op_type == Operation_DIV)
  103. {
  104. #pragma omp parallel for
  105. for (int q=0; q<channels; q++)
  106. {
  107. const float* ptr = bottom_blob.channel(q);
  108. const float* ptr1 = bottom_blob1.channel(q);
  109. float* outptr = top_blob.channel(q);
  110. for (int i=0; i<size; i++)
  111. {
  112. outptr[i] = ptr[i] / ptr1[i];
  113. }
  114. }
  115. }
  116. else if (op_type == Operation_MAX)
  117. {
  118. #pragma omp parallel for
  119. for (int q=0; q<channels; q++)
  120. {
  121. const float* ptr = bottom_blob.channel(q);
  122. const float* ptr1 = bottom_blob1.channel(q);
  123. float* outptr = top_blob.channel(q);
  124. for (int i=0; i<size; i++)
  125. {
  126. outptr[i] = std::max(ptr[i], ptr1[i]);
  127. }
  128. }
  129. }
  130. else if (op_type == Operation_MIN)
  131. {
  132. #pragma omp parallel for
  133. for (int q=0; q<channels; q++)
  134. {
  135. const float* ptr = bottom_blob.channel(q);
  136. const float* ptr1 = bottom_blob1.channel(q);
  137. float* outptr = top_blob.channel(q);
  138. for (int i=0; i<size; i++)
  139. {
  140. outptr[i] = std::min(ptr[i], ptr1[i]);
  141. }
  142. }
  143. }
  144. else if (op_type == Operation_POW)
  145. {
  146. #pragma omp parallel for
  147. for (int q=0; q<channels; q++)
  148. {
  149. const float* ptr = bottom_blob.channel(q);
  150. const float* ptr1 = bottom_blob1.channel(q);
  151. float* outptr = top_blob.channel(q);
  152. for (int i=0; i<size; i++)
  153. {
  154. outptr[i] = pow(ptr[i], ptr1[i]);
  155. }
  156. }
  157. }
  158. return 0;
  159. }
  160. } // namespace ncnn