You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

arithmetic.h 8.7 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_LITE_NNACL_ARITHMETIC_H_
  17. #define MINDSPORE_LITE_NNACL_ARITHMETIC_H_
  18. #ifdef ENABLE_NEON
  19. #include <arm_neon.h>
  20. #endif
  21. #include "nnacl/op_base.h"
  22. #include "nnacl/arithmetic_common.h"
  23. #include "nnacl/errorcode.h"
  24. #ifdef __cplusplus
  25. extern "C" {
  26. #endif
  // "Opt" variants of the Add/Sub/Mul/Div declarations. Each takes two input buffers,
  // one output buffer, an int element_size and -- unlike the plain Element* prototypes
  // declared later in this header -- an extra ArithmeticParameter *param.
  // Float variants use float buffers; the *Int variants use int buffers.
  // NOTE(review): function bodies are not in this header. Presumably param tells the
  // kernel which input is a single scalar to be applied against every element of the
  // other input -- confirm against the implementation file.
  // NOTE(review): the int return is presumably a status code from nnacl/errorcode.h -- confirm.
  // NOTE(review): Relu/Relu6 suffixes presumably mean the result is passed through that
  // activation before being stored -- confirm.
  27. int ElementOptAdd(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
  28. int ElementOptAddInt(int *input0, int *input1, int *output, int element_size, ArithmeticParameter *param);
  29. int ElementOptAddRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
  30. int ElementOptAddRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
  31. int ElementOptSub(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
  32. int ElementOptSubRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
  33. int ElementOptSubRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
  34. int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
  35. int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
  36. int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
  37. int ElementOptMulInt(int *input0, int *input1, int *output, int element_size, ArithmeticParameter *param);
  38. int ElementOptMulReluInt(int *input0, int *input1, int *output, int element_size, ArithmeticParameter *param);
  39. int ElementOptMulRelu6Int(int *input0, int *input1, int *output, int element_size, ArithmeticParameter *param);
  40. int ElementOptDiv(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
  41. int ElementOptDivRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
  42. int ElementOptDivRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param);
  // Multiplication family. The ElementMul* prototypes take two inputs, one output and an
  // int element_size; the float variants use float buffers and the *Int variants use int.
  // NOTE(review): presumably output[i] = input0[i] * input1[i] for element_size entries,
  // with the Relu/Relu6 variants applying that activation to the product -- bodies are
  // not visible in this header, confirm in the implementation.
  43. int ElementMul(float *input0, float *input1, float *output, int element_size);
  44. int ElementMulRelu(float *input0, float *input1, float *output, int element_size);
  45. int ElementMulRelu6(float *input0, float *input1, float *output, int element_size);
  46. int ElementMulInt(int *input0, int *input1, int *output, int element_size);
  47. int ElementMulReluInt(int *input0, int *input1, int *output, int element_size);
  48. int ElementMulRelu6Int(int *input0, int *input1, int *output, int element_size);
  // BroadcastMul additionally receives tile_input0/tile_input1 buffers and an
  // ArithmeticParameter. NOTE(review): presumably the inputs are expanded into the tile
  // buffers to a common shape described by param before multiplying -- confirm.
  49. int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
  50. ArithmeticParameter *param);
  // Addition family. ElementAdd* take two inputs, one output and an int element_size;
  // ElementAddInt is the only int-buffer variant declared here (no Relu/Relu6 int forms).
  // NOTE(review): presumably output[i] = input0[i] + input1[i], with Relu/Relu6 variants
  // applying that activation to the sum -- confirm in the implementation.
  51. int ElementAdd(float *input0, float *input1, float *output, int element_size);
  52. int ElementAddRelu(float *input0, float *input1, float *output, int element_size);
  53. int ElementAddRelu6(float *input0, float *input1, float *output, int element_size);
  54. int ElementAddInt(int *input0, int *input1, int *output, int element_size);
  // Broadcast forms take extra tile_input0/tile_input1 scratch pointers and an
  // ArithmeticParameter. BroadcastAddInt8 has the same parameter layout but uses
  // int8_t buffers throughout.
  // NOTE(review): how the tile buffers are sized and filled is not visible here -- confirm.
  55. int BroadcastAdd(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
  56. ArithmeticParameter *param);
  57. int BroadcastAddInt8(int8_t *input0, int8_t *input1, int8_t *tile_input0, int8_t *tile_input1, int8_t *output,
  58. int element_size, ArithmeticParameter *param);
  // Subtraction family (float only in this header). Each ElementSub* takes two float
  // inputs, a float output and an int element_size.
  // NOTE(review): presumably output[i] = input0[i] - input1[i] (operand order matters;
  // confirm in the implementation), with Relu/Relu6 applied by the suffixed variants.
  59. int ElementSub(float *input0, float *input1, float *output, int element_size);
  60. int ElementSubRelu(float *input0, float *input1, float *output, int element_size);
  61. int ElementSubRelu6(float *input0, float *input1, float *output, int element_size);
  // Broadcast form: same operands plus tile_input0/tile_input1 buffers and an
  // ArithmeticParameter. NOTE(review): tiling behavior is not visible here -- confirm.
  62. int BroadcastSub(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
  63. ArithmeticParameter *param);
  // Division family (float only in this header). Each ElementDiv* takes two float
  // inputs, a float output and an int element_size.
  // NOTE(review): presumably output[i] = input0[i] / input1[i]; whether a zero divisor
  // is rejected via the int return value cannot be determined from this header -- confirm.
  64. int ElementDiv(float *input0, float *input1, float *output, int element_size);
  65. int ElementDivRelu(float *input0, float *input1, float *output, int element_size);
  66. int ElementDivRelu6(float *input0, float *input1, float *output, int element_size);
  // Broadcast form: same operands plus tile_input0/tile_input1 buffers and an
  // ArithmeticParameter. NOTE(review): tiling behavior is not visible here -- confirm.
  67. int BroadcastDiv(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
  68. ArithmeticParameter *param);
  // Logical AND / OR declarations. Note that inputs and output are all float buffers,
  // not bool or int.
  // NOTE(review): presumably truth values are encoded as floats (non-zero = true) and
  // the result is written as a float -- the encoding is not visible here, confirm.
  // Each Broadcast* form adds tile_input0/tile_input1 buffers and an ArithmeticParameter.
  69. int ElementLogicalAnd(float *input0, float *input1, float *output, int element_size);
  70. int BroadcastLogicalAnd(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output,
  71. int element_size, ArithmeticParameter *param);
  72. int ElementLogicalOr(float *input0, float *input1, float *output, int element_size);
  73. int BroadcastLogicalOr(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output,
  74. int element_size, ArithmeticParameter *param);
  // Maximum / Minimum declarations over float buffers.
  // NOTE(review): presumably output[i] is the larger (Maximum) or smaller (Minimum) of
  // input0[i] and input1[i]; NaN handling is not visible here -- confirm.
  // Each Broadcast* form adds tile_input0/tile_input1 buffers and an ArithmeticParameter.
  75. int ElementMaximum(float *input0, float *input1, float *output, int element_size);
  76. int BroadcastMaximum(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output,
  77. int element_size, ArithmeticParameter *param);
  78. int ElementMinimum(float *input0, float *input1, float *output, int element_size);
  79. int BroadcastMinimum(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output,
  80. int element_size, ArithmeticParameter *param);
  // FloorDiv / FloorMod declarations. Each operation has a float form, an int form
  // (*Int, int buffers) and a float Broadcast* form with tile buffers and an
  // ArithmeticParameter; no int broadcast form is declared here.
  // NOTE(review): presumably FloorDiv rounds the quotient toward negative infinity and
  // FloorMod returns the matching remainder; sign conventions and zero-divisor handling
  // are not visible in this header -- confirm in the implementation.
  81. int ElementFloorDiv(float *input0, float *input1, float *output, int element_size);
  82. int ElementFloorDivInt(int *input0, int *input1, int *output, int element_size);
  83. int BroadcastFloorDiv(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output,
  84. int element_size, ArithmeticParameter *param);
  85. int ElementFloorMod(float *input0, float *input1, float *output, int element_size);
  86. int ElementFloorModInt(int *input0, int *input1, int *output, int element_size);
  87. int BroadcastFloorMod(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output,
  88. int element_size, ArithmeticParameter *param);
  // Squared-difference declarations over float buffers.
  // NOTE(review): presumably output[i] = (input0[i] - input1[i])^2 -- the body is not
  // visible in this header, confirm in the implementation.
  // The Broadcast form adds tile_input0/tile_input1 buffers and an ArithmeticParameter.
  89. int ElementSquaredDifference(float *input0, float *input1, float *output, int element_size);
  90. int BroadcastSquaredDifference(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output,
  91. int element_size, ArithmeticParameter *param);
  // Comparison declarations: NotEqual, Equal, Less, LessEqual, Greater, GreaterEqual.
  // Every Element* form takes two float inputs, a float output and an int element_size;
  // every Broadcast* form adds tile_input0/tile_input1 buffers and an ArithmeticParameter.
  // Note that the comparison result is written to a float buffer, not a bool/int one.
  // NOTE(review): presumably the result is stored as 1.0f / 0.0f per element -- the
  // encoding is not visible in this header, confirm in the implementation.
  // NOTE(review): Equal/NotEqual on floats are presumably exact comparisons with no
  // tolerance -- confirm.
  92. int ElementNotEqual(float *input0, float *input1, float *output, int element_size);
  93. int BroadcastNotEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output,
  94. int element_size, ArithmeticParameter *param);
  95. int ElementEqual(float *input0, float *input1, float *output, int element_size);
  96. int BroadcastEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output,
  97. int element_size, ArithmeticParameter *param);
  98. int ElementLess(float *input0, float *input1, float *output, int element_size);
  99. int BroadcastLess(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
  100. ArithmeticParameter *param);
  101. int ElementLessEqual(float *input0, float *input1, float *output, int element_size);
  102. int BroadcastLessEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output,
  103. int element_size, ArithmeticParameter *param);
  104. int ElementGreater(float *input0, float *input1, float *output, int element_size);
  105. int BroadcastGreater(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output,
  106. int element_size, ArithmeticParameter *param);
  107. int ElementGreaterEqual(float *input0, float *input1, float *output, int element_size);
  108. int BroadcastGreaterEqual(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output,
  109. int element_size, ArithmeticParameter *param);
  110. #ifdef ENABLE_NNACL_INFER_SHAPE
  // Shape-inference entry point; declared only when ENABLE_NNACL_INFER_SHAPE is defined.
  // Takes a generic OpParameter * rather than the ArithmeticParameter * used by the
  // kernels declared earlier in this header.
  // NOTE(review): presumably in_shape/dim_size/in_format/in_datatype describe the inputs
  // and out_shape/out_format/out_datatype are filled in by the call; array lengths and
  // ownership are not visible in this header -- confirm in the implementation.
  111. int ArithmeticInferShape(int **in_shape, size_t *dim_size, int *out_shape, int *in_format, int *out_format,
  112. int *in_datatype, int *out_datatype, OpParameter *param);
  113. #endif
  114. #ifdef __cplusplus
  115. }
  116. #endif
  117. #endif // MINDSPORE_LITE_NNACL_ARITHMETIC_H_