You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ast_c.cpp 7.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. /**
  2. * \file src/jit/impl/ast_c.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/jit/ast_c.h"
  12. #include "megbrain/jit/executor_opr.h"
  13. #include "megbrain/opr/tensor_manip.h"
  14. #if MGB_JIT
  15. using namespace mgb;
  16. using namespace jit;
  17. using namespace ast_c;
  18. namespace {
  19. ASTPtr gen_powc(ASTPtr inp, float exp) {
  20. auto int_neg = [exp](ASTPtr x) {
  21. if (exp < 0) {
  22. return 1.f / x;
  23. }
  24. return x;
  25. };
  26. if (almost_equal(std::abs(exp), 0.f)) {
  27. return 1.f;
  28. }
  29. if (almost_equal(std::abs(exp), 1.f)) {
  30. return int_neg(inp);
  31. }
  32. if (almost_equal(std::abs(exp), 2.f)) {
  33. return int_neg(inp * inp);
  34. }
  35. if (almost_equal(std::abs(exp), 3.f)) {
  36. return int_neg(inp * inp * inp);
  37. }
  38. if (almost_equal(exp, 1.f / 3.f)) {
  39. return make_call("cbrtf", {inp});
  40. }
  41. if (almost_equal(exp, -1.f / 3.f)) {
  42. return make_call("rcbrtf", {inp});
  43. }
  44. if (almost_equal(exp, .5f)) {
  45. return make_call("sqrtf", {inp});
  46. }
  47. if (almost_equal(exp, -.5f)) {
  48. return make_call("rsqrtf", {inp});
  49. }
  50. int exp_i = std::round(exp);
  51. if (almost_equal(static_cast<float>(exp_i), exp)) {
  52. auto inp_abs = make_call("fabsf", {inp});
  53. if (exp_i & 1) {
  54. auto pow = make_call("powf", {inp_abs, exp});
  55. return make_call("copysign", {pow, inp});
  56. } else {
  57. return make_call("powf", {inp_abs, exp});
  58. }
  59. }
  60. return make_call("powf", {inp, exp});
  61. }
  62. } // anonymous namespace
  63. const ElemGeneratorMap& ast_c::elem_opr_generator() {
  64. #define ENTRY(_mode, _impl) \
  65. { \
  66. ElemMode::_mode, { \
  67. [](const ASTPtrArray& inps) -> ASTPtrArray { return {_impl}; } \
  68. } \
  69. }
  70. static ElemGeneratorMap map = {
  71. // unary
  72. ENTRY(RELU, make_call("fmaxf", {inps[0], 0.f})),
  73. ENTRY(ABS, make_call("fabsf", inps)),
  74. ENTRY(ACOS, make_call("acosf", inps)),
  75. ENTRY(ASIN, make_call("asinf", inps)),
  76. ENTRY(CEIL, make_call("ceilf", inps)),
  77. ENTRY(COS, make_call("cosf", inps)),
  78. ENTRY(EXP, make_call("expf", inps)),
  79. ENTRY(EXPM1, make_call("expm1f", inps)),
  80. ENTRY(FLOOR, make_call("floorf", inps)),
  81. ENTRY(LOG, make_call("logf", inps)),
  82. ENTRY(LOG1P, make_call("log1pf", inps)),
  83. ENTRY(NEGATE, make_call("-", inps)),
  84. ENTRY(SIGMOID, 1 / (1 + make_call("expf", {0 - inps[0]}))),
  85. ENTRY(SIN, make_call("sinf", inps)),
  86. ENTRY(TANH, make_call("tanhf", inps)),
  87. ENTRY(ERF, make_call("erff", inps)),
  88. ENTRY(ERFC, make_call("erfcf", inps)),
  89. ENTRY(H_SWISH,
  90. inps[0] *
  91. make_call("fmaxf",
  92. {make_call("fminf", {inps[0] + 3.f, 6.f}),
  93. 0.f}) /
  94. 6.f),
  95. // binary
  96. ENTRY(ABS_GRAD,
  97. ASTPtr::make<Cond3AST>(inps[0] > 0, inps[1], -inps[1])),
  98. ENTRY(ADD, inps[0] + inps[1]),
  99. ENTRY(FLOOR_DIV, make_call("floorf", {inps[0] / inps[1]})),
  100. ENTRY(MAX, make_call("fmaxf", inps)),
  101. ENTRY(MIN, make_call("fminf", inps)),
  102. ENTRY(MOD, make_call("fmodf", inps)),
  103. ENTRY(MUL, inps[0] * inps[1]),
  104. ENTRY(POW, make_call("powf", inps)),
  105. ENTRY(SIGMOID_GRAD, inps[0] * (1 - inps[0]) * inps[1]),
  106. ENTRY(SUB, inps[0] - inps[1]),
  107. ENTRY(SWITCH_GT0, ASTPtr::make<Cond3AST>(inps[0] > 0, inps[1], 0)),
  108. ENTRY(TANH_GRAD, (1 - inps[0] * inps[0]) * inps[1]),
  109. ENTRY(TRUE_DIV, inps[0] / inps[1]),
  110. ENTRY(LOG_SUM_EXP,
  111. make_call("mgb_log_sum_exp", {inps[0], inps[1]})),
  112. ENTRY(LT, ASTPtr::make<BinaryAST>("<", inps[0], inps[1])),
  113. ENTRY(LEQ, ASTPtr::make<BinaryAST>("<=", inps[0], inps[1])),
  114. ENTRY(EQ, ASTPtr::make<BinaryAST>("==", inps[0], inps[1])),
  115. ENTRY(ATAN2, make_call("atan2f", inps)),
  116. ENTRY(H_SWISH_GRAD,
  117. ASTPtr::make<Cond3AST>(
  118. -inps[0] > 3.f, 0.f,
  119. ASTPtr::make<Cond3AST>(
  120. inps[0] > 3.f, inps[1],
  121. (2.f * inps[0] + 3.f) * inps[1] / 6.f))),
  122. // misc
  123. ENTRY(COND_LEQ_MOV,
  124. ASTPtr::make<BinaryAST>("<=", inps[0], inps[1]) * inps[2]),
  125. ENTRY(FUSE_MUL_ADD3, inps[0] * inps[1] + inps[2]),
  126. ENTRY(FUSE_MUL_ADD4, inps[0] * inps[1] + inps[2] * inps[3]),
  127. ENTRY(FUSE_ADD_RELU, make_call("fmaxf", {inps[0] + inps[1], 0})),
  128. ENTRY(FUSE_ADD_SIGMOID,
  129. 1 / (1 + make_call("expf", {-(inps[0] + inps[1])}))),
  130. ENTRY(FUSE_ADD_TANH, make_call("tanhf", {inps[0] + inps[1]})),
  131. ENTRY(FUSE_ADD_H_SWISH,
  132. (inps[0] + inps[1]) *
  133. make_call(
  134. "fmaxf",
  135. {make_call("fminf",
  136. {(inps[0] + inps[1]) + 3.f, 6.f}),
  137. 0.f}) /
  138. 6.f),
  139. };
  140. mgb_assert(map.size() + 12 == opr::Elemwise::Param::MODE_NR_MEMBER);
  141. // unimplemented modes: SHL, SHR, FAST_TANH, FAST_TANH_GRAD, ROUND, RMULH,
  142. // ERFINV, ERFCINV, NOT, AND, OR, XOR
  143. return map;
  144. #undef ADD_OPR
  145. }
  146. ASTPtrArray ast_c::opr2AST(cg::OperatorNodeBase* opr,
  147. const ASTPtrArray& inputs) {
  148. using namespace opr;
  149. if (auto elem = gopt::try_cast_as_op<Elemwise>(opr)) {
  150. if (check_elem_mode(elem->param().mode)) {
  151. return elem_opr_generator()
  152. .find(elem->param().mode)
  153. ->second(inputs);
  154. }
  155. }
  156. if (auto powc = gopt::try_cast_as_op<PowC>(opr)) {
  157. mgb_assert(inputs.size() == 1);
  158. return {gen_powc(inputs[0], powc->param().exp)};
  159. }
  160. auto imm = SymbolVar{opr->output(0)}.as_immutable_scalar();
  161. if (imm.valid()) {
  162. auto dtype = imm->dtype();
  163. if (dtype == dtype::Int32{}) {
  164. return {ASTPtr::make<IntAST>(imm->get<int>())};
  165. }
  166. float scalar_value;
  167. if (dtype == dtype::Float32()) {
  168. scalar_value = imm->get<float>();
  169. } else if (dtype == dtype::Float16()) {
  170. scalar_value = imm->get<dt_float16>();
  171. } else {
  172. mgb_throw(InternalError,
  173. "dtype(%s) is not any of [Float16, Float32, Int32]",
  174. dtype.name());
  175. }
  176. return {ASTPtr::make<FloatAST>(scalar_value)};
  177. }
  178. if (opr->same_type<opr::TypeCvt>()) {
  179. // simply ignore TypeCvt oprs.
  180. mgb_assert(inputs.size() == 1);
  181. return inputs;
  182. }
  183. mgb_throw(InternalError, "unknown opr %s{%s}", opr->cname(),
  184. opr->dyn_typeinfo()->name);
  185. }
  186. #endif // MGB_JIT
  187. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台