You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

elemwise.cpp 44 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258
  1. #include "./erfinv.h"
  2. #include "megbrain/opr/basic_arith.h"
  3. #include "megbrain/opr/io.h"
  4. #include "megbrain/opr/tensor_manip.h"
  5. #include "megbrain/test/autocheck.h"
  6. #include "megbrain/test/helper.h"
  7. #include <algorithm>
  8. #include <cmath>
  9. using namespace mgb;
  10. namespace {
using Mode = opr::Elemwise::Mode;
//! optional per-input generator callback; an invalid Maybe means the
//! checker's default random generator is used
using InputGenerator = Maybe<thin_function<void(HostTensorND&)>>;
// msvc would check for callable of None, so we use this to replace None
const InputGenerator NONE_INPUT_GEN;
// every mode exercised by TestRunner::run() is recorded here
std::unordered_set<Mode, enumhash> tested_mode;
  16. /* ======================= opr special impls ======================= */
  17. float do_mod(float a, float b) {
  18. return std::fmod(a, b);
  19. }
  20. int do_mod(int a, int b) {
  21. return a % b;
  22. }
  23. float do_floor_div(float a, float b) {
  24. return std::floor(a / b);
  25. }
  26. int do_floor_div(int a, int b) {
  27. if ((a ^ b) < 0) {
  28. const auto quot = a / b;
  29. const auto rem = a % b;
  30. return rem ? quot - 1 : quot;
  31. }
  32. return a / b;
  33. }
//! inverse error function; reference impl provided by ./erfinv.h
float do_erfinv(float x) {
    return erfinvf(x);
}
//! inverse complementary error function (erfc^-1), also from ./erfinv.h
float do_erfcinv(float x) {
    return erfcinvf(x);
}
  40. float do_h_swish(float x) {
  41. return x * fmaxf(fminf(x + 3.f, 6.f), 0.f) / 6.f;
  42. }
  43. float do_h_swish_grad(float x, float y) {
  44. return x < -3.f ? 0.f : (x > 3.f ? y : (2.f * x + 3.f) / 6.f * y);
  45. }
  46. template <typename T>
  47. T do_log_sum_exp(T a, T b) {
  48. return std::log(std::exp(a) + std::exp(b));
  49. }
  50. float do_fast_tanh(float x) {
  51. return x * (27.f + x * x) / (27.f + 9.f * x * x);
  52. }
  53. float do_fast_tanh_grad(float x, float y) {
  54. float x_pow2 = x * x;
  55. float deno = 3.f + x_pow2;
  56. return ((-48.f * x_pow2) / deno + 27.f + x_pow2) / (deno * 9.f) * y;
  57. }
  58. float do_fuse_add_h_swish(float x, float y) {
  59. float z = x + y;
  60. return z * fmaxf(fminf(z + 3.f, 6.f), 0.f) / 6.f;
  61. }
//! bit-shift reference impls; only the int overloads are defined, so
//! instantiating the templates for any other type fails at link time
template <typename T>
T do_shl(T, T); // undefined
template <typename T>
T do_shr(T, T); // undefined
// NOTE(review): assumes y is in [0, 31] -- shifting a 32-bit int by
// >= 32 is UB; relies on the SHL/SHR input generator to keep y in
// range (verify the generator's upper bound is exclusive)
int do_shl(int x, int y) {
    return x << y;
}
int do_shr(int x, int y) {
    return x >> y;
}
  72. template <typename T>
  73. struct MulType {};
  74. template <>
  75. struct MulType<int8_t> {
  76. typedef int16_t type;
  77. };
  78. template <>
  79. struct MulType<int16_t> {
  80. typedef int32_t type;
  81. };
  82. template <>
  83. struct MulType<int32_t> {
  84. typedef int64_t type;
  85. };
  86. template <>
  87. struct MulType<uint8_t> {
  88. typedef uint16_t type;
  89. };
  90. template <typename T>
  91. T rounding_shift_right_upward(T x, int k) {
  92. T mask = (T(1) << k) - 1;
  93. T threshold = mask >> 1;
  94. return (x >> k) + ((x & mask) > threshold);
  95. }
//! reference for rounded-mulh-saturate (RMULH): computes the product
//! in a widened type, shifts it right by the number of value bits with
//! round-half-up, and saturates the single overflowing case
template <typename T>
T do_round_mulh_saturate(T a, T b) {
    MEGDNN_STATIC_ASSERT(
            std::numeric_limits<T>::digits <= 32,
            "Portable RMULH is not supported for integer "
            "types larger than 32 bits.")
    MEGDNN_STATIC_ASSERT(
            std::numeric_limits<T>::is_integer,
            "Input types should be integer for RMULH")
    // min() * min() is the only product whose shifted result does not
    // fit back into T, so clamp it to max()
    bool overflow = a == b && a == DTypeTrait<T>::min();
    // TODO: This really should be
    // rounding_shift_right_away_from_zero, but we haven't yet found a fast
    // way to implement it on ARM NEON. For now, we just try to align with
    // NEON's VQRDMULH and hope that it does not harm our NN badly.
    return overflow
                 ? DTypeTrait<T>::max()
                 : static_cast<T>(rounding_shift_right_upward(
                           typename MulType<T>::type(a) * typename MulType<T>::type(b),
                           std::numeric_limits<T>::digits));
}
  116. float do_gelu_grad(float x, float y) {
  117. float phi = 1.f / sqrtf(2.0 * M_PI) * expf(-0.5f * x * x);
  118. float normcdf_v = 0.5f * (1.f + erff(x / sqrtf(2.f)));
  119. return y * (normcdf_v + x * phi);
  120. }
  121. /* ======================= basic framework ======================= */
//! fill dest with random values bounded away from zero
//! float: magnitude in [0.1, 1.6]; sign is random, or strictly
//! alternating when stable_sign is set (keeps numeric-diff probes on
//! one side of a discontinuity at 0)
//! int: roughly uniform in [-32767, 32768], with 0 replaced by +/-1
template <typename ctype, bool stable_sign = false>
void gen_nozero(HostTensorND& dest) {
    // one process-wide RNG, seeded once
    static RNGxorshf rng{next_rand_seed()};
    auto ptr = dest.template ptr<ctype>();
    if (DTypeTrait<ctype>::category == DTypeCategory::FLOAT) {
        for (size_t i = 0, it = dest.shape().total_nr_elems(); i < it; ++i) {
            // v uniform in [-1.5, 1.5)
            auto v = rng() / (rng.max() + 1.0) * 3 - 1.5;
            bool vsign = v > 0;
            if (stable_sign) {
                // alternate sign with element index instead of randomly
                vsign = i % 2;
            }
            // push the magnitude at least 0.1 away from zero
            v = std::abs(v) + 0.1;
            ptr[i] = vsign ? v : -v;
        }
    } else {
        for (size_t i = 0, it = dest.shape().total_nr_elems(); i < it; ++i) {
            // vsat alternates between -1 and +1 and replaces exact zeros
            ctype v = rng() / (rng.max() + 1.0) * 65536 - 32767, vsat = i % 2 * 2 - 1;
            ptr[i] = v == 0 ? vsat : v;
        }
    }
}
//! per-mode test configuration; this primary template supplies the
//! defaults and the specializations below override individual knobs
template <class Trait>
struct CheckerConfig {
    //! whether a binary mode should also be tested with swapped inputs
    static constexpr bool enable_binary_inp_swap() { return true; }
    //! whether the numeric gradient w.r.t. input idx should be checked
    static constexpr bool allow_inp_grad(size_t idx) {
        MGB_MARK_USED_VAR(idx);
        return true;
    }
    //! custom generator for input idx; an invalid Maybe selects the
    //! checker's default random generator
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t idx) {
        MGB_MARK_USED_VAR(idx);
        return NONE_INPUT_GEN;
    }
    //! tweak numeric-differentiation options
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 1e-2;
    }
    //! last-chance hook to customize the checker object itself
    template <class Checker>
    static void update_checker(Checker& checker) {
        MGB_MARK_USED_VAR(checker);
    }
};
//! build a generator producing uniform float32 values in [low, high];
//! only valid when ctype is dt_float32, and the range must span at
//! least 0.1
template <typename ctype>
InputGenerator get_inp_gen_f32_range(float low, float high) {
    mgb_assert(std::is_same<ctype MGB_COMMA dt_float32>::value && high - low >= 0.1);
    auto gen = [low, high](HostTensorND& dest) {
        HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen{low, high};
        dest = *gen(dest.shape());
    };
    return gen;
}
//! For each elemwise mode, the .inl files below invoke DEF_TRAIT to
//! define a trait struct carrying the mode's arity, per-dtype support
//! flags and a scalar reference implementation (_expr, evaluated
//! element-wise at index idx). _CUR_ARITY, _ALLOW_* and
//! _EXPAND_PARAMS are set up inside the .inl files before each group
//! of DEF_TRAIT invocations.
#define DEF_TRAIT(_mode, _expr) \
struct _mode { \
static constexpr size_t ARITY = _CUR_ARITY; \
static constexpr Mode MODE = Mode::_mode; \
static constexpr bool ALLOW_INT = _ALLOW_INT; \
static constexpr bool ALLOW_FLOAT = _ALLOW_FLOAT; \
static constexpr bool ALLOW_BOOL = _ALLOW_BOOL; \
static constexpr const char* NAME = #_mode; \
template <typename ctype> \
static inline ctype apply(std::array<const ctype*, ARITY> inp, size_t idx) { \
_EXPAND_PARAMS; \
return _expr; \
} \
};
#include "./elemwise_binary_trait_def.inl"
#include "./elemwise_ternary_trait_def.inl"
#include "./elemwise_unary_trait_def.inl"
#undef DEF_TRAIT
//! ensure nonzero value on some specific input
template <size_t nozero_idx, bool large_eps = true>
struct NoZeroCheckerConfig : public CheckerConfig<void> {
    // swapping operands would move the nonzero guarantee to the wrong one
    static constexpr bool enable_binary_inp_swap() { return false; }
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t idx) {
        if (idx != nozero_idx)
            return NONE_INPUT_GEN;
        return gen_nozero<ctype>;
    }
    template <class Opt>
    static void update_opt(Opt& opt) {
        // wider numeric-diff step for the guaranteed-nonzero input
        if (large_eps)
            opt.numdiff_eps_single_inp[nozero_idx] = 0.05;
    }
};
//! config for modes whose gradients are never checked
struct NoGradCheckerConfig : public CheckerConfig<void> {
    static constexpr bool allow_inp_grad(size_t) { return false; }
};
/* ======================= unary config ======================= */
// RELU/ABS: gradient is discontinuous at 0, keep inputs away from it
template <>
struct CheckerConfig<RELU> : public NoZeroCheckerConfig<0> {};
template <>
struct CheckerConfig<ABS> : public NoZeroCheckerConfig<0> {};
// CEIL/FLOOR/ROUND: piecewise-constant, so gradients are not checked
template <>
struct CheckerConfig<CEIL> : public NoGradCheckerConfig {};
template <>
struct CheckerConfig<FLOOR> : public NoGradCheckerConfig {};
template <>
struct CheckerConfig<ROUND> : public NoGradCheckerConfig {};
// LOG: positive inputs bounded away from 0
template <>
struct CheckerConfig<LOG> : public CheckerConfig<void> {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return get_inp_gen_f32_range<ctype>(0.1, 4);
    }
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 1e-2;
        opt.numdiff_max_err = 0.1;
    }
};
// LOG1P: inputs near 0, the region log1p exists for
template <>
struct CheckerConfig<LOG1P> : public CheckerConfig<void> {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return get_inp_gen_f32_range<ctype>(-0.2, 0.2);
    }
};
// ACOS/ASIN: stay strictly inside (-1, 1) where the derivative is finite
template <>
struct CheckerConfig<ACOS> : public CheckerConfig<void> {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return get_inp_gen_f32_range<ctype>(-0.95, 0.95);
    }
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 2e-3;
        opt.numdiff_max_err = 4e-3;
    }
};
template <>
struct CheckerConfig<ASIN> : public CheckerConfig<ACOS> {};
// TANH: moderate range with a larger numeric-diff step
template <>
struct CheckerConfig<TANH> : public CheckerConfig<void> {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return get_inp_gen_f32_range<ctype>(-5, 5);
    }
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 2e-2;
    }
};
template <>
struct CheckerConfig<SIGMOID_GRAD> : public CheckerConfig<void> {
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 2e-2;
    }
};
template <>
struct CheckerConfig<ERF> : public CheckerConfig<void> {
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 2e-2;
    }
};
// ERFINV: domain is (-1, 1); gradient is not checked
template <>
struct CheckerConfig<ERFINV> : public NoGradCheckerConfig {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return get_inp_gen_f32_range<ctype>(-1, 1);
    }
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 2e-2;
    }
};
template <>
struct CheckerConfig<ERFC> : public CheckerConfig<void> {
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 2e-2;
    }
};
// ERFCINV: domain is (0, 2); gradient is not checked
template <>
struct CheckerConfig<ERFCINV> : public NoGradCheckerConfig {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return get_inp_gen_f32_range<ctype>(0, 2);
    }
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 2e-2;
    }
};
template <>
struct CheckerConfig<H_SWISH> : public CheckerConfig<void> {};
template <>
struct CheckerConfig<H_SWISH_GRAD> : public NoGradCheckerConfig {};
// TAN: keep away from the poles at +-pi/2 (~1.57)
template <>
struct CheckerConfig<TAN> : public NoGradCheckerConfig {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return get_inp_gen_f32_range<ctype>(-1.2, 1.2);
    }
};
// SINH: grows fast, so the numeric-diff tolerance is loosened
template <>
struct CheckerConfig<SINH> : public CheckerConfig<void> {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return get_inp_gen_f32_range<ctype>(-5, 5);
    }
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 1e-2;
        opt.numdiff_max_err = 0.1;
    }
};
template <>
struct CheckerConfig<COSH> : public CheckerConfig<SINH> {};
template <>
struct CheckerConfig<ASINH> : public CheckerConfig<void> {
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 1e-2;
        opt.numdiff_max_err = 0.1;
    }
};
// ACOSH: domain is [1, +inf); start slightly above 1
template <>
struct CheckerConfig<ACOSH> : public CheckerConfig<ASINH> {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return get_inp_gen_f32_range<ctype>(1.05, 5);
    }
};
// ATANH: domain is (-1, 1)
template <>
struct CheckerConfig<ATANH> : public CheckerConfig<ASINH> {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return get_inp_gen_f32_range<ctype>(-0.95, 0.95);
    }
};
template <>
struct CheckerConfig<SOFTPLUS> : public CheckerConfig<void> {};
template <>
struct CheckerConfig<LOGSIGMOID> : public CheckerConfig<void> {};
template <>
struct CheckerConfig<SQUARE> : public CheckerConfig<void> {};
// SQRT: positive inputs only; the derivative diverges toward 0
template <>
struct CheckerConfig<SQRT> : public CheckerConfig<void> {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return get_inp_gen_f32_range<ctype>(0.05, 5);
    }
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 1e-2;
        opt.numdiff_max_err = 0.1;
    }
};
// RELU6: nudge inputs away from the kinks at 0 and 6, where numeric
// differentiation would straddle a discontinuity of the derivative
template <>
struct CheckerConfig<RELU6> : public CheckerConfig<void> {
    template <typename ctype, class Checker>
    static void do_update_checker(Checker& checker) {
        auto icoord = [](const typename Checker::NumInpArray& inp) {
            auto p0 = inp[0]->template ptr<ctype>();
            for (size_t i = 0, it = inp[0]->shape().total_nr_elems(); i < it; ++i) {
                // shift values within 1 of either kink by +2
                if (std::abs(p0[i]) < 1) {
                    p0[i] += 2;
                } else if (std::abs(p0[i] - 6) < 1) {
                    p0[i] += 2;
                }
            }
        };
        checker.set_input_coordinator(icoord);
    }
    template <class Checker>
    static void update_checker(Checker& checker) {
        using ctype = typename Checker::ctype;
        return do_update_checker<ctype>(checker);
    }
};
// HSIGMOID: keep inputs strictly inside (-3, 3), away from the kinks
template <>
struct CheckerConfig<HSIGMOID> : public CheckerConfig<void> {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return get_inp_gen_f32_range<ctype>(-2.95, 2.95);
    }
};
// SIGN: discontinuous at 0
template <>
struct CheckerConfig<SIGN> : public NoZeroCheckerConfig<0> {};
/* ======================= binary config ======================= */
//! keep the two float inputs (or the mod residue) separated by a
//! minimum gap, so piecewise reference functions (min/max/mod/
//! cond-move) are never numerically differentiated next to a branch
//! boundary
template <bool for_mod>
struct BinaryInputMinGap : public CheckerConfig<void> {
    template <typename ctype, class Checker>
    static void do_update_checker(Checker& checker) {
        auto icoord = [](const typename Checker::NumInpArray& inp) {
            static const ctype GAP{for_mod ? 0.01f : 0.1f};
            // integer inputs need no adjustment
            if (DTypeTrait<ctype>::category != DTypeCategory::FLOAT)
                return;
            auto p0 = inp[0]->template ptr<ctype>(), p1 = inp[1]->template ptr<ctype>();
            for (size_t i = 0, it = inp[0]->shape().total_nr_elems(); i < it; ++i) {
                if (for_mod) {
                    // residue normalized into [0, |p1|)
                    auto p1v = std::abs(p1[i]), mod = std::fmod(p0[i], p1v);
                    mod += mod < 0 ? p1v : 0;
                    if (mod < GAP || mod > p1v - GAP) {
                        mgb_assert(p1v > GAP * 4);
                        ctype m0, m1;
                        // shift p0 until fmod(p0 +- GAP, p1) stays on
                        // the same branch (no wrap within the probe)
                        do {
                            p0[i] += GAP;
                            m0 = std::fmod(p0[i] - GAP, p1[i]);
                            m1 = std::fmod(p0[i] + GAP, p1[i]);
                        } while (std::abs(m1 - m0) > GAP * 2 + 1e-3);
                    }
                } else {
                    // push the operands at least GAP apart
                    if (std::abs(p0[i] - p1[i]) < GAP) {
                        p1[i] += p0[i] < p1[i] ? GAP : -GAP;
                    }
                }
            }
        };
        checker.set_input_coordinator(icoord);
    }
    template <class Checker>
    static void update_checker(Checker& checker) {
        using ctype = typename Checker::ctype;
        if (std::is_integral<ctype>::value)
            return;
        if (std::is_same<ctype, dt_float16>::value)
            return do_update_checker<dt_float16>(checker);
        if (std::is_same<ctype, dt_float32>::value)
            return do_update_checker<dt_float32>(checker);
        mgb_assert(0);
    }
};
//! force the two inputs equal at roughly one third of the positions so
//! that the equality comparison actually takes both branches
struct BinaryEQInput : public CheckerConfig<void> {
    // inputs 0/1 feed a comparison, so no gradient flows through them
    static constexpr bool allow_inp_grad(size_t idx) { return idx >= 2; }
    template <class Checker>
    static void update_checker(Checker& checker) {
        using ctype = typename Checker::ctype;
        auto icoord = [](const typename Checker::NumInpArray& inp) {
            if (DTypeTrait<ctype>::category != DTypeCategory::FLOAT)
                return;
            auto p0 = inp[0]->template ptr<ctype>(), p1 = inp[1]->template ptr<ctype>();
            RNGxorshf rng{next_rand_seed()};
            for (size_t i = 0, it = inp[0]->shape().total_nr_elems(); i < it; ++i) {
                // with probability ~1/3, copy p1 into p0
                p0[i] = rng() % 3 == 0 ? p1[i] : p0[i];
            }
        };
        checker.set_input_coordinator(icoord);
    }
};
//! generate (y, x) points on an annulus for ATAN2: radius in
//! [0.5, 2.5] and angle restricted to (-pi, pi) to avoid the branch
//! cut where numeric differentiation breaks down
struct BinaryPlaneNoPiInput : public CheckerConfig<void> {
    template <class Checker>
    static void update_checker(Checker& checker) {
        using ctype = typename Checker::ctype;
        auto icoord = [](const typename Checker::NumInpArray& inp) {
            if (DTypeTrait<ctype>::category != DTypeCategory::FLOAT)
                return;
            auto p0 = inp[0]->template ptr<ctype>(), p1 = inp[1]->template ptr<ctype>();
            RNGxorshf rng{next_rand_seed()};
            auto maxv = rng.max() + 1.0;
            for (size_t i = 0, it = inp[0]->shape().total_nr_elems(); i < it; ++i) {
                //! To be numerically stable, r cannot be too small
                auto r = rng() / maxv * 2 + 0.5; //! radius
                //! Avoid pi value due to periodicity
                //! Numerical diff will be wrong there
                //! Range [-pi+eps, pi-eps]
                auto t = rng() / maxv * 3.1 * 2 - 3.1; //! angle
                //! First input is y in space
                p0[i] = r * std::sin(t);
                //! Second input is x in space
                p1[i] = r * std::cos(t);
            }
        };
        checker.set_input_coordinator(icoord);
    }
    // (y, x) roles are asymmetric, so never swap the operands
    static constexpr bool enable_binary_inp_swap() { return false; }
};
template <>
struct CheckerConfig<ATAN2> : public BinaryPlaneNoPiInput {
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 1e-3;
        opt.numdiff_max_err = 0.02;
    }
};
// ABS_GRAD: discontinuous where input 0 crosses 0
template <>
struct CheckerConfig<ABS_GRAD> : public NoZeroCheckerConfig<0> {};
// FLOOR_DIV: divisor (input 1) must be nonzero; the result is
// piecewise-constant, so gradients are not checked
template <>
struct CheckerConfig<FLOOR_DIV> : public NoZeroCheckerConfig<1, false> {
    static constexpr bool allow_inp_grad(size_t) { return false; }
};
// TRUE_DIV: nonzero divisor, looser tolerance for small denominators
template <>
struct CheckerConfig<TRUE_DIV> : public NoZeroCheckerConfig<1, false> {
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 1e-2;
        opt.numdiff_max_err = 0.1;
    }
};
template <>
struct CheckerConfig<EQ> : public BinaryEQInput {};
template <>
struct CheckerConfig<LEQ> : public NoGradCheckerConfig {};
template <>
struct CheckerConfig<LT> : public NoGradCheckerConfig {};
template <>
struct CheckerConfig<FUSE_ADD_H_SWISH> : public CheckerConfig<void> {};
// SWITCH_GT0: branches on the sign of input 0
template <>
struct CheckerConfig<SWITCH_GT0> : public NoZeroCheckerConfig<0> {};
// POW: the base (input 0) is forced positive so x^y and its gradient
// are well-defined; base and exponent must never be swapped
template <>
struct CheckerConfig<POW> : public CheckerConfig<void> {
    static constexpr bool enable_binary_inp_swap() { return false; }
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 1e-2;
        opt.numdiff_max_err = 0.06;
    }
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t idx) {
        // remap default random values into [0.1, +inf) for the base
        auto func = [](HostTensorND& dest) {
            dest = *HostTensorGenerator<typename DTypeTrait<ctype>::dtype>{}(
                    dest.shape());
            auto ptr = dest.ptr<ctype>();
            for (size_t i = 0, t = dest.shape().total_nr_elems(); i < t; ++i) {
                ptr[i] = std::abs(ptr[i]) + 0.1;
            }
        };
        if (idx == 0)
            return func;
        return NONE_INPUT_GEN;
    }
};
template <>
struct CheckerConfig<MAX> : public BinaryInputMinGap<false> {};
template <>
struct CheckerConfig<MIN> : public BinaryInputMinGap<false> {};
// MOD: nonzero divisor (from NoZeroCheckerConfig) plus residue kept
// away from the wrap-around boundary (from BinaryInputMinGap<true>)
template <>
struct CheckerConfig<MOD> : public NoZeroCheckerConfig<1, false>,
                            public BinaryInputMinGap<true> {
    // disambiguate members inherited from both bases
    using NoZeroCheckerConfig<1, false>::get_inp_gen;
    using NoZeroCheckerConfig<1, false>::enable_binary_inp_swap;
    using BinaryInputMinGap<true>::update_checker;
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 0.003;
    }
    // only the gradient w.r.t. the dividend (input 0) is checked
    static constexpr bool allow_inp_grad(size_t idx) { return idx == 0; }
};
  563. template <>
  564. struct CheckerConfig<SHL> : public CheckerConfig<void> {
  565. static constexpr bool enable_binary_inp_swap() { return false; }
  566. static constexpr bool allow_inp_grad(size_t idx) { return false; }
  567. template <typename ctype>
  568. static InputGenerator get_inp_gen(size_t);
  569. };
  570. template <>
  571. struct CheckerConfig<SHR> : public CheckerConfig<SHL> {};
// generator for the shift amount of SHL (shared by SHR through
// inheritance): uniform ints drawn from {0, 32}.
// NOTE(review): if HostTensorGenerator's UNIFORM range is inclusive at
// the upper bound, a shift count of 32 would be UB for 32-bit ints in
// do_shl/do_shr -- verify the generator's bound semantics
template <>
InputGenerator CheckerConfig<SHL>::get_inp_gen<int>(size_t idx) {
    // input 0 (the shifted value) keeps the default generator
    if (!idx)
        return NONE_INPUT_GEN;
    auto gen = [](HostTensorND& dest) {
        HostTensorGenerator<dtype::Int32, RandomDistribution::UNIFORM> gen{0, 32};
        dest = *gen(dest.shape());
    };
    return gen;
}
// FUSE_ADD_RELU: alternating-sign nonzero inputs keep the sum away
// from the kink at 0
template <>
struct CheckerConfig<FUSE_ADD_RELU> : public CheckerConfig<void> {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return gen_nozero<ctype, true>;
    }
};
template <>
struct CheckerConfig<FAST_TANH> : public CheckerConfig<void> {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return get_inp_gen_f32_range<ctype>(0.1, 5);
    }
};
// FAST_TANH_GRAD: same input range, but gradients are not checked
template <>
struct CheckerConfig<FAST_TANH_GRAD> : public CheckerConfig<FAST_TANH> {
    static constexpr bool allow_inp_grad(size_t idx) {
        MGB_MARK_USED_VAR(idx);
        return false;
    }
};
template <>
struct CheckerConfig<SILU_GRAD> : public NoGradCheckerConfig {};
template <>
struct CheckerConfig<GELU_GRAD> : public NoGradCheckerConfig {};
// PRELU: kink where input 0 crosses 0
template <>
struct CheckerConfig<PRELU> : public NoZeroCheckerConfig<0> {};
template <>
struct CheckerConfig<ASINH_GRAD> : public NoGradCheckerConfig {};
// ACOSH_GRAD: input must lie in acosh's domain (> 1)
template <>
struct CheckerConfig<ACOSH_GRAD> : public NoGradCheckerConfig {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return get_inp_gen_f32_range<ctype>(1.05, 5);
    }
};
// ATANH_GRAD: input must lie strictly inside (-1, 1)
template <>
struct CheckerConfig<ATANH_GRAD> : public NoGradCheckerConfig {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return get_inp_gen_f32_range<ctype>(-0.95, 0.95);
    }
};
template <>
struct CheckerConfig<RELU6_GRAD> : public NoGradCheckerConfig {};
template <>
struct CheckerConfig<SOFTPLUS_GRAD> : public NoGradCheckerConfig {};
// HSIGMOID_GRAD: keep away from the kinks at +-3
template <>
struct CheckerConfig<HSIGMOID_GRAD> : public NoGradCheckerConfig {
    template <typename ctype>
    static InputGenerator get_inp_gen(size_t) {
        return get_inp_gen_f32_range<ctype>(-2.95, 2.95);
    }
};
  636. /* ======================= ternary config ======================= */
  637. template <>
  638. struct CheckerConfig<COND_LEQ_MOV> : public BinaryInputMinGap<false> {};
  639. template <>
  640. struct CheckerConfig<COND_LT_MOV> : public BinaryInputMinGap<false> {};
  641. struct CheckerConfig<PRELU_GRAD> : public NoGradCheckerConfig {};
// CLIP(x, lower, upper): massage random inputs so that lower < upper
// with a margin of at least 1, and x stays at least 1 away from both
// bounds -- numeric differentiation must not straddle the kinks
template <>
struct CheckerConfig<CLIP> : public CheckerConfig<void> {
    template <typename ctype, class Checker>
    static void do_update_checker(Checker& checker) {
        auto icoord = [](const typename Checker::NumInpArray& inp) {
            auto p0 = inp[0]->template ptr<ctype>(), p1 = inp[1]->template ptr<ctype>(),
                 p2 = inp[2]->template ptr<ctype>();
            for (size_t i = 0, it = inp[0]->shape().total_nr_elems(); i < it; ++i) {
                // ensure lower (p1) <= upper (p2)
                if (p1[i] > p2[i]) {
                    std::swap(p1[i], p2[i]);
                }
                // ensure a gap of at least 1 between the bounds
                if (p1[i] + 1 > p2[i]) {
                    p2[i] = p1[i] + 1;
                }
                // push x away from the lower bound
                if (std::abs(p1[i] - p0[i]) < 1) {
                    if (p1[i] < p0[i])
                        p0[i] += 1;
                    else
                        p0[i] -= 1;
                }
                // push x away from the upper bound
                if (std::abs(p2[i] - p0[i]) < 1) {
                    if (p2[i] < p0[i])
                        p0[i] += 1;
                    else
                        p0[i] -= 1;
                }
            }
        };
        checker.set_input_coordinator(icoord);
    }
    template <class Checker>
    static void update_checker(Checker& checker) {
        using ctype = typename Checker::ctype;
        return do_update_checker<ctype>(checker);
    }
    template <class Opt>
    static void update_opt(Opt& opt) {
        opt.numdiff_eps = 1e-3;
        opt.numdiff_max_err = 0.1;
    }
};
/* ======================= test runner ======================= */
namespace detail {
//! maps (dtype, Trait) to whether the mode supports that dtype via the
//! trait's ALLOW_* flags; the `void` specializations handle the
//! sentinel type appended to the gtest typelists below
template <typename dtype, class Trait>
struct enable_for_dtype_impl;
template <class Trait>
struct enable_for_dtype_impl<dtype::Float32, Trait> {
    static constexpr bool value = Trait::ALLOW_FLOAT;
};
template <>
struct enable_for_dtype_impl<dtype::Float32, void> {
    static constexpr bool value = false;
};
template <class Trait>
struct enable_for_dtype_impl<dtype::Int32, Trait> {
    static constexpr bool value = Trait::ALLOW_INT;
};
template <>
struct enable_for_dtype_impl<dtype::Int32, void> {
    static constexpr bool value = false;
};
template <class Trait>
struct enable_for_dtype_impl<dtype::Bool, Trait> {
    static constexpr bool value = Trait::ALLOW_BOOL;
};
} // namespace detail
//! whether to enable test for specific dtype and Trait
template <typename dtype, class Trait>
constexpr bool enable_for_dtype = detail::enable_for_dtype_impl<dtype, Trait>::value;
//! dispatches to a real test body only when the mode supports the
//! dtype; otherwise run() is a no-op
template <typename Trait, typename dtype, bool enable = enable_for_dtype<dtype, Trait>>
struct TestRunner;
template <typename Trait, typename dtype>
struct TestRunner<Trait, dtype, true> {
    static void run();
};
template <typename Trait, typename dtype>
struct TestRunner<Trait, dtype, false> {
    static void run() {}
};
// the sentinel `void` mode never runs anything
template <typename dtype>
struct TestRunner<void, dtype, false> {
    static void run() {}
};
//! typed-test fixtures, instantiated once per elemwise mode
template <typename Trait>
class TestOprBasicArithUnaryElemwise : public ::testing::Test {};
template <typename Trait>
class TestOprBasicArithBinaryElemwise : public ::testing::Test {};
template <typename Trait>
class TestOprBasicArithTernaryElemwise : public ::testing::Test {};
// build gtest typelists by expanding DEF_TRAIT to just the mode name
typedef ::testing::Types<
#define DEF_TRAIT(_mode, _expr) _mode,
#include "./elemwise_unary_trait_def.inl"
#undef DEF_TRAIT
        void // extra void to consume last comma
        >
        UnaryTraitTypes;
TYPED_TEST_CASE(TestOprBasicArithUnaryElemwise, UnaryTraitTypes);
typedef ::testing::Types<
#define DEF_TRAIT(_mode, _expr) _mode,
#include "./elemwise_binary_trait_def.inl"
#undef DEF_TRAIT
        void // extra void to consume last comma
        >
        BinaryTraitTypes;
TYPED_TEST_CASE(TestOprBasicArithBinaryElemwise, BinaryTraitTypes);
typedef ::testing::Types<
#define DEF_TRAIT(_mode, _expr) _mode,
#include "./elemwise_ternary_trait_def.inl"
#undef DEF_TRAIT
        void // extra void to consume last comma
        >
        TernaryTraitTypes;
TYPED_TEST_CASE(TestOprBasicArithTernaryElemwise, TernaryTraitTypes);
  755. } // anonymous namespace
/*!
 * Test body for an enabled (Trait, dtype) pair: builds an Elemwise graph
 * for Trait::MODE and uses AutoOprChecker to compare the graph result (and
 * gradients, where allowed by CheckerConfig) against the host-side
 * reference Trait::apply, over several input shapes.
 */
template <typename Trait, typename dtype>
void TestRunner<Trait, dtype, true>::run() {
{
Mode mode = Trait::MODE;
// copy to temporary var to avoid undefined reference when linking
tested_mode.insert(mode);
}
using ctype = typename DTypeTrait<dtype>::ctype;
// NOTE(review): `gen` appears unused in this function — confirm before removal
HostTensorGenerator<> gen;
using Config = CheckerConfig<Trait>;
// For binary modes whose inputs are interchangeable (per CheckerConfig),
// additionally test the opr with its two inputs swapped, as a second output.
static constexpr bool TEST_REV_INP =
Trait::ARITY == 2 &&
Config::allow_inp_grad(0) == Config::allow_inp_grad(1) &&
Config::enable_binary_inp_swap();
// one extra checker output when the swapped-input variant is tested
using Checker = AutoOprChecker<Trait::ARITY, TEST_REV_INP + 1, dtype>;
auto make_graph = [&](const typename Checker::SymInpArray& inputs) {
typename Checker::SymOutArray out;
SymbolVarArray vinp(inputs.begin(), inputs.end());
out[0] = opr::Elemwise::make(vinp, Trait::MODE);
if (TEST_REV_INP) {
// out[1] is only reached when SymOutArray actually has two slots
std::swap(vinp[0], vinp[1]);
out[1] = opr::Elemwise::make(vinp, Trait::MODE);
}
return out;
};
// host-side reference: apply Trait::apply element by element
auto fwd = [&](typename Checker::NumOutArray& dest,
typename Checker::NumInpArray inp) {
dest[0].resize(inp[0]->shape());
if (TEST_REV_INP)
dest[1].resize(inp[0]->shape());
std::array<const ctype*, Trait::ARITY> iptr;
for (size_t i = 0; i < Trait::ARITY; ++i)
iptr[i] = inp[i]->template ptr<ctype>();
size_t sz = dest[0].shape().total_nr_elems();
ctype* optr = dest[0].template ptr<ctype>();
for (size_t i = 0; i < sz; ++i)
optr[i] = Trait::apply(iptr, i);
if (TEST_REV_INP) {
// reference for the swapped-input output
std::swap(iptr[0], iptr[1]);
ctype* optr = dest[1].template ptr<ctype>();
for (size_t i = 0; i < sz; ++i)
optr[i] = Trait::apply(iptr, i);
}
};
Checker checker{make_graph, fwd};
checker.set_extra_err_msg(ssprintf("mode=%s", Trait::NAME));
// per-input generator / gradient settings supplied by the mode's config
for (size_t i = 0; i < Trait::ARITY; ++i) {
auto func = Config::template get_inp_gen<ctype>(i);
if (func.valid())
checker.set_input_generator(i, func.val());
checker.set_input_allow_grad(i, Config::allow_inp_grad(i));
}
// scalar, small 2-d and larger 1-d shapes
TensorShape shapes[] = {{1}, {23, 3}, {666}};
typename Checker::RunOptions opt;
Config::update_opt(opt);
Config::update_checker(checker);
for (auto&& ishp : shapes) {
typename Checker::ShapeInpArray inp;
std::fill(inp.begin(), inp.end(), ishp);
checker.run(inp, opt);
}
}
  818. TYPED_TEST(TestOprBasicArithUnaryElemwise, Int32) {
  819. TestRunner<TypeParam, dtype::Int32>::run();
  820. }
  821. TYPED_TEST(TestOprBasicArithBinaryElemwise, Int32) {
  822. TestRunner<TypeParam, dtype::Int32>::run();
  823. }
  824. TYPED_TEST(TestOprBasicArithTernaryElemwise, Int32) {
  825. TestRunner<TypeParam, dtype::Int32>::run();
  826. }
  827. TYPED_TEST(TestOprBasicArithUnaryElemwise, Float32) {
  828. set_rand_seed(19931102);
  829. TestRunner<TypeParam, dtype::Float32>::run();
  830. }
  831. TYPED_TEST(TestOprBasicArithBinaryElemwise, Float32) {
  832. set_rand_seed(19931150);
  833. TestRunner<TypeParam, dtype::Float32>::run();
  834. }
  835. TYPED_TEST(TestOprBasicArithTernaryElemwise, Float32) {
  836. set_rand_seed(19931102);
  837. TestRunner<TypeParam, dtype::Float32>::run();
  838. }
  839. TEST(TestOprBasicArithElemwise, CheckAllModeTested) {
  840. size_t nr_member = opr::Elemwise::Param::MODE_NR_MEMBER;
  841. ASSERT_EQ(nr_member, tested_mode.size() + 7);
  842. // Not using TestRunner: NOT, AND, OR, XOR, NEQ, ISNAN, ISINF
  843. }
// Check a unary elemwise mode on bool input: build a {2, 1} tensor holding
// [false, true], run Mode::_mode through a small graph, and compare each
// output element against the C++ operator `_op` applied on the host.
// (Comments cannot go inside the macro body: a `//` on a backslash-
// continued line would swallow the next line after splicing.)
#define TEST_OPR_BASIC_ARITH_UNARY_BOOL(_mode, _op) \
TEST(TestOprBasicArithElemwise, _mode) { \
HostTensorGenerator<dtype::Bool> gen; \
auto host_x = gen({2, 1}); \
auto ptr = host_x->ptr<dt_bool>(); \
for (size_t i = 0; i < 2; ++i) { \
ptr[i] = (i & 1); \
} \
auto graph = ComputingGraph::make(); \
using Mode = opr::Elemwise::Mode; \
auto x = opr::Host2DeviceCopy::make(*graph, host_x), \
y = opr::Elemwise::make({x}, Mode::_mode); \
HostTensorND host_y; \
auto func = graph->compile({make_callback_copy(y, host_y)}); \
func->execute(); \
ASSERT_EQ(TensorShape({2, 1}), host_y.shape()); \
auto ptry = host_y.ptr<dt_bool>(); \
for (int i = 0; i < 2; i++) { \
ASSERT_EQ(_op ptr[i], ptry[i]); \
} \
}
// logical negation is the only unary bool mode tested here
TEST_OPR_BASIC_ARITH_UNARY_BOOL(NOT, !)
// Check a binary elemwise mode on bool inputs: ptr1 = [1,1,0,0] and
// ptr2 = [0,1,0,1] together cover all four input combinations; each output
// element is compared against `ptr1[i] _op ptr2[i]` computed on the host.
// (No comments inside the macro body — see the unary macro above.)
#define TEST_OPR_BASIC_ARITH_BINARY_BOOL(_mode, _op) \
TEST(TestOprBasicArithElemwise, _mode) { \
HostTensorGenerator<dtype::Bool> gen; \
auto host_x1 = gen({2, 2}), host_x2 = gen({2, 2}); \
auto ptr1 = host_x1->ptr<dt_bool>(), ptr2 = host_x2->ptr<dt_bool>(); \
for (size_t i = 0; i < 4; ++i) { \
ptr1[i] = (i < 2); \
ptr2[i] = (i & 1); \
} \
auto graph = ComputingGraph::make(); \
using Mode = opr::Elemwise::Mode; \
auto x1 = opr::Host2DeviceCopy::make(*graph, host_x1), \
x2 = opr::Host2DeviceCopy::make(*graph, host_x2), \
y = opr::Elemwise::make({x1, x2}, Mode::_mode); \
HostTensorND host_y; \
auto func = graph->compile({make_callback_copy(y, host_y)}); \
func->execute(); \
ASSERT_EQ(TensorShape({2, 2}), host_y.shape()); \
auto ptry = host_y.ptr<dt_bool>(); \
for (int i = 0; i < 4; i++) { \
ASSERT_EQ(ptr1[i] _op ptr2[i], ptry[i]); \
} \
}
// logical modes
TEST_OPR_BASIC_ARITH_BINARY_BOOL(AND, &&)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(OR, ||)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(XOR, ^)
// comparison modes on bool operands
TEST_OPR_BASIC_ARITH_BINARY_BOOL(LT, <)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(LEQ, <=)
TEST_OPR_BASIC_ARITH_BINARY_BOOL(EQ, ==)
  895. TEST(TestOprBasicArithElemwise, FuseMulAdd3Shapes) {
  896. using Checker = AutoOprChecker<3, 1>;
  897. opr::Elemwise* opr;
  898. auto make_graph =
  899. [&](const typename Checker::SymInpArray& i) -> Checker::SymOutArray {
  900. i[0].node()->owner_graph()->options().graph_opt_level = 0;
  901. auto ret = opr::Elemwise::make(i, Mode::FUSE_MUL_ADD3);
  902. opr = &ret.node()->owner_opr()->cast_final_safe<opr::Elemwise>();
  903. return {ret};
  904. };
  905. auto fwd = [&](typename Checker::NumOutArray& dest,
  906. typename Checker::NumInpArray inp) {
  907. auto graph = ComputingGraph::make();
  908. graph->options().graph_opt_level = false;
  909. auto i = [&](size_t idx) {
  910. return opr::Host2DeviceCopy::make(*graph, inp[idx]);
  911. };
  912. auto ans = i(0) * i(1) + i(2);
  913. graph->compile({make_callback_copy(ans, dest[0])})->execute();
  914. };
  915. Checker checker{make_graph, fwd};
  916. checker.run({TensorShape{1, 2}, {2, 1}, {1, 2}})
  917. .run({TensorShape{1, 2}, {2, 1}, {1}});
  918. ASSERT_FALSE(opr->fuse_badlayout_warn_printed());
  919. checker.run({TensorShape{1, 1, 4}, {1, 3, 1}, {2, 1, 1}});
  920. ASSERT_TRUE(opr->fuse_badlayout_warn_printed());
  921. }
  922. TEST(TestOprBasicArithElemwise, FuseMulAdd4Shapes) {
  923. using Checker = AutoOprChecker<4, 1>;
  924. opr::Elemwise* opr;
  925. auto make_graph =
  926. [&](const typename Checker::SymInpArray& i) -> Checker::SymOutArray {
  927. i[0].node()->owner_graph()->options().graph_opt_level = 0;
  928. auto ret = opr::Elemwise::make(i, Mode::FUSE_MUL_ADD4);
  929. opr = &ret.node()->owner_opr()->cast_final_safe<opr::Elemwise>();
  930. return {ret};
  931. };
  932. auto fwd = [&](typename Checker::NumOutArray& dest,
  933. typename Checker::NumInpArray inp) {
  934. auto graph = ComputingGraph::make();
  935. graph->options().graph_opt_level = false;
  936. auto i = [&](size_t idx) {
  937. return opr::Host2DeviceCopy::make(*graph, inp[idx]);
  938. };
  939. auto ans = i(0) * i(1) + i(2) * i(3);
  940. graph->compile({make_callback_copy(ans, dest[0])})->execute();
  941. };
  942. Checker checker{make_graph, fwd};
  943. checker.run({TensorShape{1, 2}, {2, 1}, {1, 2}, {2, 1}})
  944. .run({TensorShape{1, 2, 1, 2, 1, 2},
  945. {2, 1, 2, 1, 2, 1},
  946. {2, 1, 2, 1, 2, 1},
  947. {1, 2, 1, 2, 1, 2}});
  948. ASSERT_FALSE(opr->fuse_badlayout_warn_printed());
  949. checker.run({TensorShape{1, 2}, {2, 1}, {2, 2}, {2, 2}});
  950. ASSERT_TRUE(opr->fuse_badlayout_warn_printed());
  951. }
// Memory forwarding when both elemwise inputs are views of the same
// storage: the output may be written in place over one input only when the
// two subtensor views do not overlap.
TEST(TestOprBasicArithElemwise, WritableFwdForSameStorage) {
HostTensorGenerator<> gen;
auto run = [&](int idx_val, bool should_overwrite) {
auto host_x = gen({100});
auto make_y = [&](ComputingGraph& graph) {
using S = opr::Subtensor;
// sub0 = x[:idx], sub1 = x[-idx:]; they overlap iff idx_val > 50
auto x = opr::Host2DeviceCopy::make_no_fwd(graph, host_x),
idx = x.make_scalar(idx_val),
sub0 = S::make(x, {S::AxisIndexer::make_interval(0, None, idx, None)}),
sub1 = S::make(
x, {S::AxisIndexer::make_interval(0, -idx, None, None)}),
y = sub0 + sub1;
// whether y's storage was forwarded onto one of its inputs
auto chk_overwrite = [sub0, sub1, y]() {
auto py = y.node()->prev_dev_ptr();
return sub0.node()->prev_dev_ptr() == py ||
sub1.node()->prev_dev_ptr() == py;
};
return std::make_pair(y, chk_overwrite);
};
// g1 disables mem-plan optimization as a reference: it must never forward
auto g0 = ComputingGraph::make(), g1 = ComputingGraph::make();
g1->options().seq_opt.enable_mem_plan_opt = false;
auto y0 = make_y(*g0), y1 = make_y(*g1);
HostTensorND host_y0, host_y1;
auto f0 = g0->compile({make_callback_copy(y0.first, host_y0)}),
f1 = g1->compile({make_callback_copy(y1.first, host_y1)});
f0->execute();
f1->execute();
ASSERT_EQ(host_y1.shape(), TensorShape{static_cast<size_t>(idx_val)});
// both graphs must agree on the numeric result
MGB_ASSERT_TENSOR_EQ(host_y1, host_y0);
ASSERT_EQ(should_overwrite, y0.second());
ASSERT_FALSE(y1.second());
};
// non-overlapping views: in-place forwarding allowed
run(10, true);
// overlapping views: forwarding must be suppressed
run(90, false);
}
  987. TEST(TestOprBasicArithElemwise, NonContigInput) {
  988. HostTensorGenerator<> gen;
  989. auto graph = ComputingGraph::make();
  990. constexpr size_t SIZE = 100;
  991. auto host_x = gen({SIZE});
  992. using S = opr::Subtensor;
  993. auto x = opr::Host2DeviceCopy::make(*graph, host_x),
  994. xsub = S::make(
  995. x, {S::AxisIndexer::make_interval(0, None, None, x.make_scalar(2))}),
  996. y = xsub + x.make_scalar(1.f);
  997. HostTensorND host_y;
  998. auto func = graph->compile({make_callback_copy(y, host_y)});
  999. func->execute();
  1000. ASSERT_FALSE(xsub.node()->dev_tensor().layout().is_contiguous());
  1001. ASSERT_EQ(SIZE / 2, host_y.layout().total_nr_elems());
  1002. auto px = host_x->ptr<float>(), py = host_y.ptr<float>();
  1003. for (size_t i = 0; i < SIZE / 2; ++i) {
  1004. MGB_ASSERT_FLOAT_EQ(px[i * 2] + 1, py[i]);
  1005. }
  1006. }
  1007. TEST(TestOprBasicArithElemwise, CommutableDedup) {
  1008. auto cn = CompNode::load("xpux");
  1009. auto graph = ComputingGraph::make();
  1010. auto host_x = std::make_shared<HostTensorND>(cn, TensorShape{100}),
  1011. host_y = std::make_shared<HostTensorND>(cn, TensorShape{100});
  1012. auto x = opr::Host2DeviceCopy::make(*graph, host_x),
  1013. y = opr::Host2DeviceCopy::make(*graph, host_y);
  1014. auto mk = [](Mode mode, SymbolVar x, SymbolVar y) {
  1015. return opr::Elemwise::make({x, y}, mode);
  1016. };
  1017. #define CHK(_a, _b) ASSERT_EQ((_a).node(), (_b).node())
  1018. CHK(x + y, y + x);
  1019. CHK(x * y, y * x);
  1020. CHK(mk(Mode::EQ, x, y), mk(Mode::EQ, y, x));
  1021. CHK(mk(Mode::MIN, x, y), mk(Mode::MIN, y, x));
  1022. CHK(mk(Mode::MAX, x, y), mk(Mode::MAX, y, x));
  1023. CHK(mk(Mode::LOG_SUM_EXP, x, y), mk(Mode::LOG_SUM_EXP, y, x));
  1024. CHK(x<y, y> x);
  1025. #undef CHK
  1026. ASSERT_NE((x - y).node(), (y - x).node());
  1027. }
  1028. TEST(TestLayoutUtil, CollectiveCollapse) {
  1029. using namespace opr;
  1030. auto shp2layout = [](const TensorShapeArray& tshps) {
  1031. TensorLayoutArray tlayouts(tshps.size());
  1032. for (size_t i = 0; i < tshps.size(); i++) {
  1033. tlayouts[i] = TensorLayout(tshps[i], dtype::Float32());
  1034. }
  1035. return tlayouts;
  1036. };
  1037. auto check = [](const TensorLayoutArray& res, const TensorLayoutArray& std) {
  1038. for (size_t i = 0; i < res.size(); i++) {
  1039. ASSERT_EQ(std[i], res[i]);
  1040. }
  1041. };
  1042. TensorShapeArray tshps1 = {{3, 3}, {3, 3}, {3, 3}};
  1043. auto cc_res1 = Elemwise::collective_collapse(shp2layout(tshps1));
  1044. TensorShapeArray std_res1 = {{9}, {9}, {9}};
  1045. check(cc_res1, shp2layout(std_res1));
  1046. TensorShapeArray tshps2 = {{3, 3, 3}, {1, 3, 3}};
  1047. auto cc_res2 = Elemwise::collective_collapse(shp2layout(tshps2));
  1048. TensorShapeArray std_res2{{3, 9}, {1, 9}};
  1049. check(cc_res2, shp2layout(std_res2));
  1050. TensorShapeArray tshp3 = {{3, 3, 3}, {3, 3, 1}};
  1051. auto cc_res3 = Elemwise::collective_collapse(shp2layout(tshp3));
  1052. TensorShapeArray std_res3{{9, 3}, {9, 1}};
  1053. check(cc_res3, shp2layout(std_res3));
  1054. TensorShapeArray tshp4 = {{3, 3, 3, 3}, {1, 3, 3, 1}};
  1055. auto cc_res4 = Elemwise::collective_collapse(shp2layout(tshp4));
  1056. TensorShapeArray std_res4{{3, 9, 3}, {1, 9, 1}};
  1057. check(cc_res4, shp2layout(std_res4));
  1058. TensorLayoutArray inp5 = {
  1059. TensorLayout(TensorShape{3, 3}, {1, 3}, dtype::Float32()),
  1060. TensorLayout(TensorShape{3, 3}, {1, 3}, dtype::Float32())};
  1061. auto cc_res5 = Elemwise::collective_collapse(inp5);
  1062. auto std_res5 = inp5;
  1063. check(cc_res5, std_res5);
  1064. }
  1065. TEST(TestOprBasicArithElemwise, EmptyInputOutputUnary) {
  1066. HostTensorGenerator<> gen;
  1067. auto graph = ComputingGraph::make();
  1068. auto host_x = gen({3, 0, 1, 3});
  1069. auto x = opr::Host2DeviceCopy::make(*graph, host_x),
  1070. y = opr::Elemwise::make(
  1071. {x}, opr::Elemwise::Param(opr::Elemwise::Param::Mode::RELU));
  1072. HostTensorND host_y;
  1073. auto func = graph->compile({make_callback_copy(y, host_y)});
  1074. ASSERT_NO_THROW(func->execute().wait());
  1075. ASSERT_TRUE(host_y.empty());
  1076. ASSERT_TRUE(host_y.shape().is_empty());
  1077. MGB_ASSERT_SHAPE_EQ(host_y.shape(), TensorShape({3, 0, 1, 3}));
  1078. }
// Binary elemwise with empty operands: broadcasting rules still apply and
// the (possibly empty) output shape must be derived correctly; the same
// compiled function is re-executed after resizing one input.
TEST(TestOprBasicArithElemwise, EmptyInputOutputBinary) {
HostTensorGenerator<> gen;
auto graph = ComputingGraph::make();
auto host_x = gen({0, 8, 1, 7}), host_y = gen({0, 8, 1, 7});
auto x = opr::Host2DeviceCopy::make(*graph, host_x),
y = opr::Host2DeviceCopy::make(*graph, host_y), z = x + y;
HostTensorND host_z;
auto func = graph->compile({make_callback_copy(z, host_z)});
// Invalid broadcast
// (8 vs 9 on axis 1 must fail even though the total size is zero)
host_y->resize({0, 9, 1, 7});
ASSERT_ANY_THROW(func->execute().wait());
// Broadcast to 0
host_y->resize({1, 8, 0, 7});
ASSERT_NO_THROW(func->execute().wait());
ASSERT_TRUE(host_z.empty());
ASSERT_TRUE(host_z.shape().is_empty());
MGB_ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 0, 7}));
// Broadcast to 0 (2)
// (axis 0: 0 vs 2 broadcasts to 0)
host_y->resize({2, 8, 1, 7});
ASSERT_NO_THROW(func->execute().wait());
ASSERT_TRUE(host_z.empty());
ASSERT_TRUE(host_z.shape().is_empty());
MGB_ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7}));
// Scalar broadcast
z = x + x.make_scalar(1.f);
func = graph->compile({make_callback_copy(z, host_z)});
ASSERT_NO_THROW(func->execute().wait());
ASSERT_TRUE(host_z.empty());
ASSERT_TRUE(host_z.shape().is_empty());
MGB_ASSERT_SHAPE_EQ(host_z.shape(), TensorShape({0, 8, 1, 7}));
}
// Elemwise::perform (the direct, graph-less entry point) with empty
// tensors: each call must succeed and yield an empty output matching the
// first input's shape.
TEST(TestOprBasicArithElemwise, PerformEmptyIO) {
auto cn = CompNode::load("xpu0");
HostTensorGenerator<> gen;
// x1 is empty (dim of size 0); x2 is a scalar-like {1} tensor
auto host_x1 = gen({2, 0, 3, 4}), host_x2 = gen({1});
auto dev_x1 = std::make_shared<DeviceTensorND>(cn),
dev_x2 = std::make_shared<DeviceTensorND>(cn);
dev_x1->copy_from(*host_x1);
dev_x2->copy_from(*host_x2);
auto dev_y = std::make_shared<DeviceTensorND>(cn, dev_x1->dtype());
dev_y->resize(dev_x1->shape());
auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(cn);
// test unary mode
for (auto mode : {Mode::NEGATE, Mode::EXP, Mode::LOG}) {
SmallVector<DeviceTensorND> inputs = {*dev_x1};
ASSERT_NO_THROW(opr::Elemwise::perform(mode, *dev_y, inputs, dnn_opr));
ASSERT_TRUE(dev_y->empty());
ASSERT_TRUE(dev_y->shape().is_empty());
MGB_ASSERT_SHAPE_EQ(dev_y->shape(), dev_x1->shape());
}
// test binary mode
// ({1} broadcast against an empty shape still yields the empty shape)
for (auto mode : {Mode::ADD, Mode::MUL, Mode::LT}) {
SmallVector<DeviceTensorND> inputs = {*dev_x1, *dev_x2};
ASSERT_NO_THROW(opr::Elemwise::perform(mode, *dev_y, inputs, dnn_opr));
ASSERT_TRUE(dev_y->empty());
ASSERT_TRUE(dev_y->shape().is_empty());
MGB_ASSERT_SHAPE_EQ(dev_y->shape(), dev_x1->shape());
}
}
  1138. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}