You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

each_mode.cpp 17 kB


  1. #include "megbrain_build_config.h"
  2. #if MGB_JIT && MGB_JIT_MLIR
  3. #include "./common.h"
  4. #include "./each_mode.h"
  5. #include "./numerical.h"
  6. #include "./types.h"
  7. #include "megbrain/common.h"
  8. #include "megbrain/exception.h"
  9. #include "megbrain/jit/mlir/ir/dialect.h"
  10. #include <llvm/Support/raw_ostream.h>
  11. #include <mlir/Dialect/StandardOps/IR/Ops.h>
  12. namespace mgb {
  13. namespace jit {
  14. using Mode = megdnn::param::Elemwise::Mode;
  15. template <Mode mode>
  16. mlir::Value lower_mode(
  17. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands);
  18. /* ===================== trivial implementations ===================== */
  19. #define cb(mode, fun) \
  20. template <> \
  21. mlir::Value lower_mode<Mode::mode>( \
  22. mlir::OpBuilder & builder, mlir::Location loc, ValueRange operands) { \
  23. ValueBuilderHelper helper(builder, loc); \
  24. return helper.fun(operands); \
  25. }
  26. //! unary
  27. cb(ABS, abs);
  28. cb(CEIL, ceil);
  29. cb(COS, cos);
  30. cb(EXP, exp);
  31. cb(FLOOR, floor);
  32. cb(LOG, log);
  33. cb(NEGATE, neg);
  34. cb(SIN, sin);
  35. cb(TANH, tanh);
  36. //! binary
  37. cb(ADD, add);
  38. cb(MAX, max);
  39. cb(MIN, min);
  40. cb(MOD, mod);
  41. cb(MUL, mul);
  42. cb(SUB, sub);
  43. cb(TRUE_DIV, div);
  44. #undef cb
  45. /* ===================== unary op ===================== */
  46. //! ACOS: pi / 2 - arctan2(x, sqrt(1 - x * x))
  47. template <>
  48. mlir::Value lower_mode<Mode::ACOS>(
  49. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  50. ValueBuilderHelper helper(builder, loc);
  51. auto x = operands[0];
  52. auto one_minus_x_2 = helper.sub(helper.const_f32(1.f), helper.mul(x, x));
  53. auto asin = atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
  54. auto pi_over_2 = helper.const_f32(1.57079637f);
  55. return helper.sub(pi_over_2, asin);
  56. }
  57. //! ASIN: arctan2(x, sqrt(1 - x * x))
  58. template <>
  59. mlir::Value lower_mode<Mode::ASIN>(
  60. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  61. ValueBuilderHelper helper(builder, loc);
  62. auto x = operands[0];
  63. auto one_minus_x_2 = helper.sub(helper.const_f32(1.f), helper.mul(x, x));
  64. return atan2_approx(helper, x, helper.sqrt(one_minus_x_2));
  65. }
  66. //! ERFCINV: inverse of complementary gauss error function
  67. //! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c
  68. template <>
  69. mlir::Value lower_mode<Mode::ERFCINV>(
  70. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  71. ValueBuilderHelper helper(builder, loc);
  72. auto minus_sqrt2 = helper.const_f32(-1.4142135623f);
  73. auto x = helper.mul(helper.const_f32(0.5f), operands[0]);
  74. return helper.div(ndtri_approx(helper, x), minus_sqrt2);
  75. }
  76. //! ERFC: complementary error function
  77. template <>
  78. mlir::Value lower_mode<Mode::ERFC>(
  79. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  80. ValueBuilderHelper helper(builder, loc);
  81. return helper.sub(helper.const_f32(1.f), erf_approx(helper, operands[0]));
  82. }
  83. //! ERFINV: inverse of gauss error function
  84. //! https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c
  85. template <>
  86. mlir::Value lower_mode<Mode::ERFINV>(
  87. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  88. ValueBuilderHelper helper(builder, loc);
  89. auto sqrt2 = helper.const_f32(1.4142135623f);
  90. auto x = helper.mul(
  91. helper.const_f32(0.5f), helper.add(operands[0], helper.const_f32(1.f)));
  92. return helper.div(ndtri_approx(helper, x), sqrt2);
  93. }
  94. //! ERF: gauss error function
  95. template <>
  96. mlir::Value lower_mode<Mode::ERF>(
  97. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  98. ValueBuilderHelper helper(builder, loc);
  99. return erf_approx(helper, operands[0]);
  100. }
  101. //! EXPM1: exp(x) - 1
  102. template <>
  103. mlir::Value lower_mode<Mode::EXPM1>(
  104. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  105. ValueBuilderHelper helper(builder, loc);
  106. return helper.sub(helper.exp(operands[0]), helper.const_f32(1.f));
  107. }
  108. //! FAST_TANH: x * (27.f + x * x) / (27.f + 9.f * x * x);
  109. template <>
  110. mlir::Value lower_mode<Mode::FAST_TANH>(
  111. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  112. ValueBuilderHelper helper(builder, loc);
  113. auto square = helper.mul(operands[0], operands[0]);
  114. return helper.div(
  115. helper.mul(operands[0], helper.add(helper.const_f32(27.f), square)),
  116. helper.add(
  117. helper.const_f32(27.f), helper.mul(helper.const_f32(9.f), square)));
  118. }
  119. //! H_SWISH: x * clip(x + 3, 0, 6) / 6
  120. template <>
  121. mlir::Value lower_mode<Mode::H_SWISH>(
  122. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  123. ValueBuilderHelper helper(builder, loc);
  124. auto const_3 = helper.const_f32(3.f);
  125. auto const_0 = helper.const_f32(0.f);
  126. auto const_6 = helper.const_f32(6.f);
  127. auto tmp = helper.add(operands[0], const_3);
  128. return helper.div(
  129. helper.mul(operands[0], helper.min(helper.max(tmp, const_0), const_6)),
  130. const_6);
  131. }
  132. //! LOG1P: log(1 + p)
  133. template <>
  134. mlir::Value lower_mode<Mode::LOG1P>(
  135. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  136. ValueBuilderHelper helper(builder, loc);
  137. return helper.log(helper.add(operands[0], helper.const_f32(1.f)));
  138. }
  139. //! RELU: max(x, 0)
  140. template <>
  141. mlir::Value lower_mode<Mode::RELU>(
  142. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  143. ValueBuilderHelper helper(builder, loc);
  144. return helper.max(operands[0], helper.const_f32(0.f));
  145. }
  146. //! ROUND
  147. template <>
  148. mlir::Value lower_mode<Mode::ROUND>(
  149. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  150. ValueBuilderHelper helper(builder, loc);
  151. return helper.select(
  152. helper.gt(operands[0], helper.const_f32(0.f)),
  153. helper.floor(helper.add(operands[0], helper.const_f32(0.5f))),
  154. helper.ceil(helper.sub(operands[0], helper.const_f32(0.5f))));
  155. }
  156. //! SIGMOID: 1.f / (expf(-y) + 1.f))
  157. template <>
  158. mlir::Value lower_mode<Mode::SIGMOID>(
  159. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  160. ValueBuilderHelper helper(builder, loc);
  161. return helper.div(
  162. helper.const_f32(1.f),
  163. helper.add(helper.exp(helper.neg(operands[0])), helper.const_f32(1.f)));
  164. }
  165. /* ===================== binary op ===================== */
  166. //! ABS_GRAD: x > 0 ? y : -y
  167. template <>
  168. mlir::Value lower_mode<Mode::ABS_GRAD>(
  169. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  170. ValueBuilderHelper helper(builder, loc);
  171. return helper.select(
  172. helper.gt(operands[0], helper.const_f32(0.f)), operands[1],
  173. helper.neg(operands[1]));
  174. }
  175. //! ATAN2
  176. template <>
  177. mlir::Value lower_mode<Mode::ATAN2>(
  178. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  179. ValueBuilderHelper helper(builder, loc);
  180. return atan2_approx(helper, operands[0], operands[1]);
  181. }
  182. //! EQ: x == y ? 1 : 0
  183. template <>
  184. mlir::Value lower_mode<Mode::EQ>(
  185. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  186. ValueBuilderHelper helper(builder, loc);
  187. return helper.select(
  188. helper.eq(operands[0], operands[1]), helper.const_f32(1.f),
  189. helper.const_f32(0.f));
  190. }
  191. //! FAST_TANH_GRAD: ((-48.f * x * x) / (3.f + x * x) + 27.f + x * x) / (3.f + x
  192. //! * x) * y
  193. template <>
  194. mlir::Value lower_mode<Mode::FAST_TANH_GRAD>(
  195. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  196. ValueBuilderHelper helper(builder, loc);
  197. auto x_pow2 = helper.mul(operands[0], operands[0]);
  198. auto deno = helper.add(helper.const_f32(3.f), x_pow2);
  199. return helper.mul(
  200. helper.div(
  201. helper.add(
  202. helper.add(
  203. helper.div(
  204. helper.mul(helper.const_f32(-48.f), x_pow2),
  205. deno),
  206. helper.const_f32(27.f)),
  207. x_pow2),
  208. helper.mul(deno, helper.const_f32(9.f))),
  209. operands[1]);
  210. }
  211. //! FLOOR_DIV: floor(x/y)
  212. template <>
  213. mlir::Value lower_mode<Mode::FLOOR_DIV>(
  214. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  215. ValueBuilderHelper helper(builder, loc);
  216. return helper.floor(helper.div(operands[0], operands[1]));
  217. }
  218. //! FUSE_ADD_H_SWISH: (x+y) * min(max(x + y + 3, 0), 6) * (1/6)
  219. template <>
  220. mlir::Value lower_mode<Mode::FUSE_ADD_H_SWISH>(
  221. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  222. ValueBuilderHelper helper(builder, loc);
  223. auto sum = helper.add(operands[0], operands[1]);
  224. auto const_3 = helper.const_f32(3.f);
  225. auto const_0 = helper.const_f32(0.f);
  226. auto const_6 = helper.const_f32(6.f);
  227. auto tmp = helper.add(sum, const_3);
  228. return helper.div(
  229. helper.mul(sum, helper.min(helper.max(tmp, const_0), const_6)), const_6);
  230. }
  231. //! FUSE_ADD_RELU: (x + y) <= ctype(0) ? ctype(0) : (x + y)
  232. template <>
  233. mlir::Value lower_mode<Mode::FUSE_ADD_RELU>(
  234. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  235. ValueBuilderHelper helper(builder, loc);
  236. auto sum = helper.add(operands[0], operands[1]);
  237. return helper.max(sum, helper.const_f32(0.f));
  238. }
  239. //! FUSE_ADD_SIGMOID: 1.f / (expf(-(x+y)) + 1.f))
  240. template <>
  241. mlir::Value lower_mode<Mode::FUSE_ADD_SIGMOID>(
  242. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  243. ValueBuilderHelper helper(builder, loc);
  244. return helper.div(
  245. helper.const_f32(1.f),
  246. helper.add(
  247. helper.exp(helper.neg(helper.add(operands[0], operands[1]))),
  248. helper.const_f32(1.f)));
  249. }
  250. //! FUSE_ADD_TANH: tanh(x + y)
  251. template <>
  252. mlir::Value lower_mode<Mode::FUSE_ADD_TANH>(
  253. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  254. ValueBuilderHelper helper(builder, loc);
  255. return helper.tanh(helper.add(operands[0], operands[1]));
  256. }
  257. //! H_SWISH_GRAD: x < -3.f ? 0.f : (x > 3.f ? y : (2.f * x + 3.f) / 6.f * y)
  258. template <>
  259. mlir::Value lower_mode<Mode::H_SWISH_GRAD>(
  260. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  261. ValueBuilderHelper helper(builder, loc);
  262. return helper.select(
  263. helper.lt(operands[0], helper.const_f32(-3.f)), helper.const_f32(0.f),
  264. helper.select(
  265. helper.gt(operands[0], helper.const_f32(3.f)), operands[1],
  266. helper.mul(
  267. helper.div(
  268. helper.add(
  269. helper.mul(
  270. helper.const_f32(2.f), operands[0]),
  271. helper.const_f32(3.f)),
  272. helper.const_f32(6.f)),
  273. operands[1])));
  274. }
  275. //! LEQ: x <= y ? 1 : 0
  276. template <>
  277. mlir::Value lower_mode<Mode::LEQ>(
  278. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  279. ValueBuilderHelper helper(builder, loc);
  280. return helper.select(
  281. helper.le(operands[0], operands[1]), helper.const_f32(1.f),
  282. helper.const_f32(0.f));
  283. }
  284. //! LOG_SUM_EXP: log(exp(x) + exp(y))
  285. template <>
  286. mlir::Value lower_mode<Mode::LOG_SUM_EXP>(
  287. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  288. ValueBuilderHelper helper(builder, loc);
  289. return helper.log(helper.add(helper.exp(operands[0]), helper.exp(operands[1])));
  290. }
  291. //! LT: x < y ? 1 : 0
  292. template <>
  293. mlir::Value lower_mode<Mode::LT>(
  294. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  295. ValueBuilderHelper helper(builder, loc);
  296. return helper.select(
  297. helper.lt(operands[0], operands[1]), helper.const_f32(1.f),
  298. helper.const_f32(0.f));
  299. }
  300. //! POW: x^y = exp(y * log(x))
  301. template <>
  302. mlir::Value lower_mode<Mode::POW>(
  303. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  304. ValueBuilderHelper helper(builder, loc);
  305. return helper.exp(helper.mul(operands[1], helper.log(operands[0])));
  306. }
  307. //! SIGMOID_GRAD: x * (1 - x) * y
  308. template <>
  309. mlir::Value lower_mode<Mode::SIGMOID_GRAD>(
  310. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  311. ValueBuilderHelper helper(builder, loc);
  312. return helper.mul(
  313. helper.mul(operands[0], helper.sub(helper.const_f32(1.f), operands[0])),
  314. operands[1]);
  315. }
  316. //! SWITCH_GT0: (x > 0) * y
  317. template <>
  318. mlir::Value lower_mode<Mode::SWITCH_GT0>(
  319. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  320. ValueBuilderHelper helper(builder, loc);
  321. return helper.select(
  322. helper.gt(operands[0], helper.const_f32(0.f)), operands[1],
  323. helper.const_f32(0.f));
  324. }
  325. //! TANH_GRAD: (1 - x * x) * y
  326. template <>
  327. mlir::Value lower_mode<Mode::TANH_GRAD>(
  328. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  329. ValueBuilderHelper helper(builder, loc);
  330. return helper.mul(
  331. helper.sub(helper.const_f32(1.0f), helper.mul(operands[0], operands[0])),
  332. operands[1]);
  333. }
  334. /* ===================== ternary op ===================== */
  335. //! COND_LEQ_MOV: x <= y ? z : ctype(0)
  336. template <>
  337. mlir::Value lower_mode<Mode::COND_LEQ_MOV>(
  338. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  339. ValueBuilderHelper helper(builder, loc);
  340. return helper.select(
  341. helper.le(operands[0], operands[1]), operands[2], helper.const_f32(0.f));
  342. }
  343. //! COND_LT_MOV: x < y ? z : ctype(0)
  344. template <>
  345. mlir::Value lower_mode<Mode::COND_LT_MOV>(
  346. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  347. ValueBuilderHelper helper(builder, loc);
  348. return helper.select(
  349. helper.lt(operands[0], operands[1]), operands[2], helper.const_f32(0.f));
  350. }
  351. //! FUSE_MUL_ADD3: x * y + z
  352. template <>
  353. mlir::Value lower_mode<Mode::FUSE_MUL_ADD3>(
  354. mlir::OpBuilder& builder, mlir::Location loc, ValueRange operands) {
  355. ValueBuilderHelper helper(builder, loc);
  356. return helper.add(helper.mul(operands[0], operands[1]), operands[2]);
  357. }
  358. /* ===================== elemwise ===================== */
  359. mlir::Value lower_elemwise_to_std(
  360. mlir::Operation* op, mlir::OpBuilder& builder, mlir::Location loc,
  361. ValueRange operands) {
  362. auto mode = llvm::dyn_cast<dialect::Elemwise>(op).mode();
  363. switch (mode) {
  364. #define cb(_, _mode) \
  365. case Mode::_mode: \
  366. return lower_mode<Mode::_mode>(builder, loc, operands);
  367. MLIR_MGB_FOREACH_ELEMWISE_MODE_UNARY(cb);
  368. MLIR_MGB_FOREACH_ELEMWISE_MODE_BINARY(cb);
  369. MLIR_MGB_FOREACH_ELEMWISE_MODE_TERNARY(cb);
  370. default:
  371. return nullptr;
  372. }
  373. #undef cb
  374. }
  375. /* ===================== typecvt ===================== */
  376. mlir::Value lower_typecvt_to_std(
  377. mlir::Operation* op, mlir::OpBuilder& builder, mlir::Location loc,
  378. mlir::Value input) {
  379. auto&& typecvt = llvm::dyn_cast<dialect::TypeCvt>(op);
  380. mlir::Type idtype = typecvt.idtype();
  381. mlir::Type odtype =
  382. megdnn_dtype_to_mlir_type(typecvt.dtype(), builder.getContext());
  383. mlir::Type itype = input.getType();
  384. mlir::Type otype = signless(odtype);
  385. mgb_assert(signless(idtype) == itype);
  386. if (mlir::FPExtOp::areCastCompatible(itype, otype)) {
  387. return builder.create<mlir::FPExtOp>(loc, otype, input);
  388. } else if (mlir::FPTruncOp::areCastCompatible(itype, otype)) {
  389. return builder.create<mlir::FPTruncOp>(loc, otype, input);
  390. } else if (
  391. mlir::FPToSIOp::areCastCompatible(itype, otype) and
  392. odtype.isSignedInteger()) {
  393. return builder.create<mlir::FPToSIOp>(loc, otype, input);
  394. } else if (
  395. mlir::FPToUIOp::areCastCompatible(itype, otype) and
  396. odtype.isUnsignedInteger()) {
  397. return builder.create<mlir::FPToUIOp>(loc, otype, input);
  398. } else if (
  399. mlir::SIToFPOp::areCastCompatible(itype, otype) and
  400. idtype.isSignedInteger()) {
  401. return builder.create<mlir::SIToFPOp>(loc, otype, input);
  402. } else if (
  403. mlir::UIToFPOp::areCastCompatible(itype, otype) and
  404. idtype.isUnsignedInteger()) {
  405. return builder.create<mlir::UIToFPOp>(loc, otype, input);
  406. } else {
  407. std::string tmp;
  408. llvm::raw_string_ostream os(tmp);
  409. os << "cannot convert from " << idtype << " to " << odtype;
  410. mgb_throw_raw(InternalError{tmp});
  411. }
  412. return nullptr;
  413. }
  414. } // namespace jit
  415. } // namespace mgb
  416. #endif // MGB_JIT && MGB_JIT_MLIR
  417. // vim: syntax=cpp.doxygen