
basic_arith.cpp 64 kB

feat(bazel/windows/xp/sp2/inference): implement inference on Windows XP (OS version >= SP2) built with Bazel

* Bazel build support (define __DEPLOY_ON_XP_SP2__ when deploying on XP SP2):
  (dbg) ./bazel build //brain/megbrain:load_and_run --cpu='x86_windows_xp' --compiler='clang_cl' -c dbg --copt "-D__DEPLOY_ON_XP_SP2__=1"
  (opt) ./bazel build //brain/megbrain:load_and_run --cpu='x86_windows_xp' --compiler='clang_cl' -c opt --copt "-D__DEPLOY_ON_XP_SP2__=1"
* Internal behavior: MGB_HAVE_THREAD is defined to 0 when __DEPLOY_ON_XP_SP2__ is enabled.
* Background (see https://docs.microsoft.com/en-us/cpp/build/configuring-programs-for-windows-xp?view=msvc-160): XP SP2 (x86) does not fully support the VC runtime, because its KERNEL32.dll does not implement some base APIs needed by C++ standard facilities such as std::mutex/std::thread/std::condition_variable. As a workaround, we disable some MegEngine features (e.g. multi-threading) in the XP SP2 environment.
* About DNN_MUTEX/MGB_MUTEX: if your code is built into inference code (even CPU backends), replace std::mutex with DNN_MUTEX/MGB_MUTEX.
* About multi-threading: if your code needs multi-thread support, enable it only when MGB_HAVE_THREAD=1.
* Build environments tested:
  1. Visual Studio 2019 (MSVC version <= 14.26.28801) -- pass
  2. Visual Studio 2019 (MSVC version > 14.26.28801) -- failed: this newer version makes the VC runtime depend on the Windows 7 KERNEL32.DLL; a later Visual Studio 2019 release may fix this, but it was not tested at the time this MR was merged
  3. Visual Studio 2017 -- pass
  4. Visual Studio 2014 -- pass

GitOrigin-RevId: 65ac48b95e99f2c510fe5db449cc8182d682e113
4 years ago
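To make the mutex and threading guidance in the commit message concrete, below is a minimal standalone sketch of the intended pattern, not the actual MegBrain definitions: it assumes MGB_MUTEX expands to std::mutex on normal builds and to a do-nothing lock type when MGB_HAVE_THREAD is 0 (as on XP SP2 deployments); the fallback macro definitions, the Counter class, and main() are hypothetical illustrations only.

// Hypothetical illustration of the commit's guidance: use MGB_MUTEX instead of
// std::mutex in inference code, and gate thread usage on MGB_HAVE_THREAD.
// The fallback definitions below are assumptions for a standalone build; in the
// real tree these macros come from the MegBrain/MegDNN headers.
#include <mutex>
#include <thread>

#ifndef MGB_HAVE_THREAD
#define MGB_HAVE_THREAD 1          // assumed default when not deploying on XP SP2
#endif

#if MGB_HAVE_THREAD
#define MGB_MUTEX std::mutex       // normal builds: a real mutex
#else
struct DummyMutex {                // XP SP2 builds: no threads, so locking is a no-op
    void lock() {}
    void unlock() {}
};
#define MGB_MUTEX DummyMutex
#endif

class Counter {                    // hypothetical example class
    MGB_MUTEX m_mtx;
    int m_value = 0;

public:
    void add(int delta) {
        m_mtx.lock();              // becomes a no-op on single-thread (XP SP2) builds
        m_value += delta;
        m_mtx.unlock();
    }
    int value() const { return m_value; }
};

int main() {
    Counter c;
#if MGB_HAVE_THREAD
    // Multi-thread path is only taken when MGB_HAVE_THREAD == 1.
    std::thread t([&] { c.add(1); });
    c.add(1);
    t.join();
#else
    c.add(2);
#endif
    return c.value() == 2 ? 0 : 1;
}

The same pattern appears in the file below, where StaticInferOpr guards its cached operator with MGB_MUTEX and compiles the lock/unlock calls out under __DEPLOY_ON_XP_SP2__.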
  1. /**
  2. * \file src/opr/impl/basic_arith.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/opr/basic_arith.h"
  12. #include "megbrain/opr/basic_arith_wrapper.h"
  13. #include "megbrain/opr/utility.h"
  14. #include "megbrain/opr/io.h"
  15. #include "megbrain/opr/cond.h"
  16. #include "megbrain/opr/tensor_manip.h"
  17. #include "megbrain/gopt/basic_arith.h"
  18. #include "megbrain/gopt/gtrans.h"
  19. #include "megbrain/utils/arith_helper.h"
  20. #include "megbrain/graph/grad_impl.h"
  21. #include "./internal/megdnn_opr_wrapper.inl"
  22. #include <cmath>
  23. using namespace mgb;
  24. using namespace opr;
  25. namespace {
  26. //! global operator instance for static inference
  27. template<class Opr>
  28. class StaticInferOpr {
  29. intl::UniqPtrWithCN<Opr> m_opr;
  30. MGB_MUTEX m_mtx;
  31. public:
  32. class Lock {
  33. friend class StaticInferOpr;
  34. StaticInferOpr *m_owner;
  35. explicit Lock(StaticInferOpr *owner):
  36. m_owner{owner}
  37. {
  38. #if !__DEPLOY_ON_XP_SP2__
  39. m_owner->m_mtx.lock();
  40. #endif
  41. }
  42. public:
  43. Lock(Lock &&rhs):
  44. m_owner{rhs.m_owner}
  45. {
  46. rhs.m_owner = nullptr;
  47. }
  48. ~Lock() {
  49. #if !__DEPLOY_ON_XP_SP2__
  50. if (m_owner)
  51. m_owner->m_mtx.unlock();
  52. #endif
  53. }
  54. Lock& operator = (const Lock &) = delete;
  55. Lock& operator = (Lock&&) = delete;
  56. intl::UniqPtrWithCN<Opr>& operator() () {
  57. return m_owner->m_opr;
  58. }
  59. };
  60. //! lock and acquire the operator
  61. Lock lock() {
  62. Lock ret{this};
  63. if (!m_opr) {
  64. m_opr = intl::create_megdnn_opr<Opr>(
  65. CompNode::default_cpu());
  66. }
  67. return ret;
  68. }
  69. };
  70. } // anonymous namespace
  71. /* ========================= BatchedDTypePromotion ========================= */
  72. intl::BatchedDTypePromotion::BatchedDTypePromotion(const VarNodeArrayView& vars)
  73. : m_orig_vars{vars} {
  74. mgb_assert(!vars.empty());
  75. DType final_dtype;
  76. bool changed = false;
  77. for (size_t i = 0; i < vars.size(); ++i) {
  78. auto cur = vars[i]->dtype();
  79. if (!i) {
  80. final_dtype = cur;
  81. } else {
  82. auto promoted = dtype_promotion(final_dtype, cur);
  83. changed |= promoted != final_dtype || promoted != cur;
  84. final_dtype = promoted;
  85. }
  86. }
  87. m_changed = changed;
  88. m_final_dtype = final_dtype;
  89. }
  90. void intl::BatchedDTypePromotion::set_dtype(DType dtype) {
  91. mgb_assert(!m_finalized);
  92. if (m_final_dtype != dtype) {
  93. m_final_dtype = dtype;
  94. m_changed = true;
  95. }
  96. }
  97. const VarNodeArrayView& intl::BatchedDTypePromotion::get_vars() {
  98. m_finalized = true;
  99. if (!m_changed) {
  100. return m_orig_vars;
  101. }
  102. if (!m_cvt_vars_view.valid()) {
  103. m_cvt_vars.resize(m_orig_vars.size());
  104. auto dtype = m_final_dtype;
  105. for (size_t i = 0; i < m_cvt_vars.size(); ++i) {
  106. m_cvt_vars[i] = TypeCvt::make(m_orig_vars[i], dtype).node();
  107. }
  108. m_cvt_vars_view.emplace(m_cvt_vars);
  109. }
  110. return m_cvt_vars_view.val();
  111. }
  112. /* =========================== Elemwise =========================== */
  113. MGB_DYN_TYPE_OBJ_FINAL_IMPL(Elemwise);
  114. Elemwise::Elemwise(
  115. const ModeTrait &mode_trait,
  116. const VarNodeArrayView &inputs, Param param,
  117. const OperatorNodeConfig &config):
  118. Super{inputs.at(0)->owner_graph(), config, mode_trait.name, inputs}
  119. {
  120. init_megdnn_opr(*this, param);
  121. output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
  122. if (mode_trait.commutable) {
  123. mgb_assert(inputs.size() == 2);
  124. add_input({inputs[0], inputs[1]}, AddInputSortType::CUR_ADDED);
  125. } else {
  126. if (param.mode == Mode::FUSE_MUL_ADD3) {
  127. add_input({inputs[0], inputs[1]}, AddInputSortType::CUR_ADDED);
  128. add_input({inputs[2]});
  129. } else if (param.mode == Mode::FUSE_MUL_ADD4) {
  130. auto i0 = inputs[0], i1 = inputs[1], i2 = inputs[2], i3 = inputs[3];
  131. if (i0->id() > i1->id())
  132. std::swap(i0, i1);
  133. if (i2->id() > i3->id())
  134. std::swap(i2, i3);
  135. if (i0->id() > i2->id()) {
  136. std::swap(i0, i2);
  137. std::swap(i1, i3);
  138. }
  139. add_input({i0, i1, i2, i3});
  140. } else {
  141. for (auto i: inputs)
  142. add_input({i});
  143. }
  144. }
  145. mgb_assert(m_input_broadcastable.size() >= inputs.size());
  146. for (size_t i = 0; i < inputs.size(); ++ i) {
  147. if (input()[i]->owner_opr()->same_type<
  148. opr::MarkNoBroadcastElemwise>()) {
  149. m_input_broadcastable[i] = false;
  150. } else {
  151. m_input_broadcastable[i] = true;
  152. }
  153. }
  154. if (inputs.size() == 1) {
  155. m_input_broadcastable[0] = false;
  156. } else {
  157. Maybe<size_t> non_scalar;
  158. using namespace cg::static_infer;
  159. auto &&mgr = owner_graph()->static_infer_manager();
  160. for (size_t i = 0; i < input().size(); ++ i) {
  161. auto it = mgr.get_infer_type(input(i));
  162. if (!((it.shape & InferType::CONST) &&
  163. mgr.infer_shape(input(i)).is_scalar())) {
  164. if (non_scalar.valid()) {
  165. non_scalar.invalidate();
  166. break;
  167. }
  168. non_scalar = i;
  169. }
  170. }
  171. if (non_scalar.valid()) {
  172. // exactly one input is non-scalar
  173. m_input_broadcastable[non_scalar.val()] = false;
  174. }
  175. }
  176. if (inputs.size() &&
  177. inputs[0]->dtype().category() == DTypeCategory::QUANTIZED) {
  178. mgb_assert(param.mode == Param::Mode::ADD ||
  179. param.mode == Param::Mode::SUB ||
  180. param.mode == Param::Mode::NEGATE ||
  181. param.mode == Param::Mode::RELU ||
  182. param.mode == Param::Mode::MAX ||
  183. param.mode == Param::Mode::MIN,
  184. "Only ADD, SUB, NEGATE, RELU, MAX and MIN is guaranteed "
  185. "to be supported on Elemwise for quantized DType, no support %d", (int)param.mode);
  186. }
  187. }
  188. SymbolVar Elemwise::make(const VarNodeArrayView& inputs, Param param,
  189. const OperatorNodeConfig& config) {
  190. auto trait = ModeTrait::from_mode(param.mode);
  191. mgb_assert(inputs.size() == trait.arity,
  192. "%s expects %u inputs; got %zu actually", trait.name,
  193. trait.arity, inputs.size());
  194. intl::BatchedDTypePromotion dtp{inputs};
  195. if (dtp.get_dtype().category() == DTypeCategory::INT && !trait.allow_int) {
  196. dtp.set_dtype(dtype::Float32());
  197. }
  198. mgb_throw_if(dtp.get_dtype().category() == DTypeCategory::FLOAT &&
  199. !trait.allow_float,
  200. ConversionError,
  201. "elemwise mode %s does not allow float input; "
  202. "got inputs: %s",
  203. trait.name, cg::dump_var_info(inputs).c_str());
  204. #if !MGB_BUILD_SLIM_SERVING
  205. auto&& options = inputs[0]->owner_graph()->options();
  206. if (options.graph_opt_level && !(options.disable_inplace_arith_opt)) {
  207. auto repl = gopt::optimize_elemwise_expr_inplace(dtp.get_vars(), param,
  208. config);
  209. if (repl)
  210. return repl;
  211. }
  212. #endif
  213. return SymbolVar{inputs[0]}.insert_single_output_opr<Elemwise>(
  214. trait, dtp.get_vars(), param, config);
  215. }
  216. TensorShape Elemwise::get_output_var_shape(
  217. Mode mode, const TensorShapeArray &input_shapes) {
  218. mgb_assert(input_shapes.size() == ModeTrait::from_mode(mode).arity);
  219. TensorShape ret;
  220. megdnn::Elemwise::deduce_shape(input_shapes, ret);
  221. return ret;
  222. }
  223. void Elemwise::perform(
  224. Mode mode, DeviceTensorND &dest,
  225. const SmallVector<DeviceTensorND> &inputs,
  226. intl::UniqPtrWithCN<megdnn::Elemwise> &opr) {
  227. megdnn::TensorNDArray dnn_inputs(inputs.size());
  228. TensorShapeArray inp_shapes(inputs.size());
  229. DType out_dt;
  230. CompNode out_cn;
  231. for (size_t i = 0; i < inputs.size(); ++ i) {
  232. auto &&t = inputs[i];
  233. if (!i) {
  234. out_cn = t.comp_node();
  235. out_dt = t.dtype();
  236. } else {
  237. mgb_assert(t.comp_node() == out_cn);
  238. mgb_assert(t.dtype() == out_dt);
  239. }
  240. if (t.shape().is_empty()) {
  241. mgb_assert(dest.empty());
  242. return;
  243. }
  244. inp_shapes[i] = t.shape();
  245. }
  246. if (!opr) {
  247. opr = intl::create_megdnn_opr<megdnn::Elemwise>(out_cn);
  248. } else {
  249. mgb_assert(out_cn == opr.comp_node());
  250. }
  251. out_cn.activate();
  252. for (size_t i = 0; i < inputs.size(); ++ i)
  253. dnn_inputs[i] = inputs[i].as_megdnn();
  254. dest.comp_node(out_cn).dtype(out_dt).resize(
  255. get_output_var_shape(mode, inp_shapes));
  256. opr->param() = {mode};
  257. call_megdnn_opr_exec(out_cn, dnn_inputs, dest.as_megdnn(), opr.get(),
  258. nullptr);
  259. }
  260. TensorLayoutArray Elemwise::collective_collapse(
  261. const TensorLayoutArray& layouts) {
  262. TensorLayoutPtrArray inp(layouts.size());
  263. TensorLayoutArray result(inp.size());
  264. for (size_t i = 0; i < layouts.size(); ++ i) {
  265. result[i] = layouts[i];
  266. inp[i] = &result[i];
  267. }
  268. collective_collapse_inplace(inp);
  269. return result;
  270. }
  271. void Elemwise::collective_collapse_inplace(
  272. const TensorLayoutPtrArray& layouts) {
  273. mgb_assert(layouts.size());
  274. size_t ndim = layouts[0]->ndim;
  275. for (auto i: layouts) {
  276. if (i->ndim != ndim)
  277. mgb_throw(MegBrainError, "ndims must be same");
  278. }
  279. auto update_all = [&layouts](size_t axis) {
  280. for (auto i: layouts) {
  281. i->shape[axis] *= i->shape[axis + 1];
  282. i->stride[axis] = i->stride[axis + 1];
  283. i->remove_axis_inplace(axis + 1);
  284. }
  285. };
  286. auto check = [&layouts](size_t axis) -> bool {
  287. auto std_p = std::make_pair(
  288. layouts[0]->shape[axis], layouts[0]->shape[axis + 1]);
  289. for (auto i: layouts) {
  290. auto cur_p = std::make_pair(i->shape[axis], i->shape[axis + 1]);
  291. if (std_p != cur_p) return false;
  292. if (i->stride[axis] != i->stride[axis + 1] *
  293. static_cast<ptrdiff_t>(i->shape[axis+1]) )
  294. return false;
  295. }
  296. return true;
  297. };
  298. for (int i = static_cast<int>(ndim) - 2; i >= 0; i--) {
  299. if (check(i)) {
  300. update_all(i);
  301. }
  302. }
  303. }
  304. void Elemwise::broadcast_collective_collapse(
  305. const TensorLayoutPtrArray &inp_layouts, TensorLayout *target_layout) {
  306. for (auto &&p: inp_layouts) {
  307. *p = p->broadcast(*target_layout);
  308. }
  309. TensorLayoutPtrArray buf(inp_layouts.size() + 1);
  310. buf[0] = target_layout;
  311. for (size_t i = 0; i < inp_layouts.size(); i++) {
  312. buf[i+1] = inp_layouts[i];
  313. }
  314. collective_collapse_inplace(buf);
  315. }
  316. void Elemwise::mem_plan_fwd_in2out_writable() {
  317. mixin_mem_plan_fwd_in2out_writable(*this);
  318. }
  319. void Elemwise::scn_do_execute() {
  320. auto&& inp = input();
  321. megdnn::TensorNDArray dnn_inp;
  322. mgb_assert(dnn_inp.capacity() >= inp.size(),
  323. "heap allocation in elemwise exec");
  324. dnn_inp.resize(inp.size());
  325. for (size_t i = 0; i < inp.size(); ++i) {
  326. if (inp[i]->dev_tensor().empty()) {
  327. mgb_assert(output(0)->dev_tensor().empty());
  328. return;
  329. }
  330. dnn_inp[i] = (inp[i]->dev_tensor().as_megdnn());
  331. }
  332. mgb_assert(!output(0)->dev_tensor().empty());
  333. megdnn_opr()->param() = param();
  334. call_megdnn_opr_exec(comp_node(), dnn_inp,
  335. output(0)->dev_tensor().as_megdnn(), megdnn_opr(),
  336. this);
  337. }
  338. void Elemwise::init_output_static_infer_desc() {
  339. Super::init_output_static_infer_desc();
  340. static StaticInferOpr<megdnn::Elemwise> static_infer_opr;
  341. using namespace cg::static_infer;
  342. auto infer_value = [this](DeviceTensorND &dest, const InpVal &inp) {
  343. SmallVector<DeviceTensorND> inp_vals(inp.val.size());
  344. for (size_t i = 0; i < inp_vals.size(); ++ i)
  345. inp_vals[i] = inp.val[i].value();
  346. auto sopr = static_infer_opr.lock();
  347. perform(param().mode, dest, inp_vals, sopr());
  348. return true;
  349. };
  350. DepVal deps(input().size());
  351. for (size_t i = 0; i < input().size(); ++ i)
  352. deps[i] = {input(i), DepType::VALUE};
  353. owner_graph()->static_infer_manager().register_value_infer(
  354. output(0), {SourceType::DEP, deps, infer_value});
  355. }
  356. void Elemwise::get_output_var_shape(
  357. const TensorShapeArray &inp_shape, TensorShapeArray &out_shape) const {
  358. out_shape.at(0) = get_output_var_shape(param().mode, inp_shape);
  359. for (size_t i = 0; i < input().size(); ++ i) {
  360. mgb_throw_if(!m_input_broadcastable[i] &&
  361. !out_shape[0].eq_shape(inp_shape[i]), GraphError,
  362. "input %zu declared to be non-broadcastable but broacast "
  363. "actually happened", i);
  364. }
  365. }
  366. void Elemwise::add_input_layout_constraint() {
  367. for (auto i: input()) {
  368. i->add_layout_constraint_monotone();
  369. }
  370. }
  371. void Elemwise::call_megdnn_opr_exec(
  372. CompNode comp_node,
  373. megdnn::TensorNDArray &inp, const megdnn::TensorND &out,
  374. megdnn::Elemwise *opr, Elemwise *caller) {
  375. if (opr->param().mode == Mode::FUSE_MUL_ADD3 &&
  376. !(inp[2].layout.eq_layout(inp[0].layout) ||
  377. inp[2].layout.eq_layout(inp[1].layout) ||
  378. inp[2].layout.is_scalar())) {
  379. if (caller && !caller->fuse_badlayout_warn_printed()) {
  380. mgb_log_debug("%s: FUSE_MUL_ADD3 input layouts mismatch: %s %s %s; "
  381. "fallback to normal computing",
  382. caller->cname(),
  383. inp[0].layout.to_string().c_str(),
  384. inp[1].layout.to_string().c_str(),
  385. inp[2].layout.to_string().c_str()
  386. );
  387. caller->m_fuse_badlayout_warn_printed = true;
  388. }
  389. for (auto &&i: inp) {
  390. i.layout = i.layout.broadcast(out.layout);
  391. }
  392. megdnn::TensorNDArray run_inp(2);
  393. auto run = [&](Mode mode,
  394. const megdnn::TensorND &i0, const megdnn::TensorND &i1,
  395. const megdnn::TensorND &out) {
  396. run_inp[0] = i0;
  397. run_inp[1] = i1;
  398. opr->param() = {mode};
  399. opr->exec(run_inp, out);
  400. };
  401. auto tmp =
  402. intl::get_temp_tensor(caller ? caller->owner_graph() : nullptr,
  403. comp_node, out.layout);
  404. auto tmpv = tmp.as_megdnn();
  405. MGB_TRY {
  406. run(Mode::MUL, inp[0], inp[1], tmpv);
  407. run(Mode::ADD, inp[2], tmpv, out);
  408. } MGB_FINALLY(opr->param() = {Mode::FUSE_MUL_ADD3});
  409. return;
  410. }
  411. if (opr->param().mode == Mode::FUSE_MUL_ADD4 &&
  412. !(inp[0].layout.eq_layout(inp[2].layout) &&
  413. inp[1].layout.eq_layout(inp[3].layout)) &&
  414. !(inp[0].layout.eq_layout(inp[3].layout) &&
  415. inp[1].layout.eq_layout(inp[2].layout))) {
  416. if (caller && !caller->fuse_badlayout_warn_printed()) {
  417. mgb_log_debug(
  418. "%s: FUSE_MUL_ADD4 input layouts mismatch: %s %s %s %s; "
  419. "fallback to normal computing",
  420. caller->cname(),
  421. inp[0].layout.to_string().c_str(),
  422. inp[1].layout.to_string().c_str(),
  423. inp[2].layout.to_string().c_str(),
  424. inp[3].layout.to_string().c_str()
  425. );
  426. caller->m_fuse_badlayout_warn_printed = true;
  427. }
  428. for (auto &&i: inp) {
  429. i.layout = i.layout.broadcast(out.layout);
  430. }
  431. megdnn::TensorNDArray run_inp(2);
  432. auto run = [&](Mode mode,
  433. const megdnn::TensorND &i0, const megdnn::TensorND &i1,
  434. const megdnn::TensorND &out) {
  435. run_inp[0] = i0;
  436. run_inp[1] = i1;
  437. opr->param() = {mode};
  438. opr->exec(run_inp, out);
  439. };
  440. auto tmp =
  441. intl::get_temp_tensor(caller ? caller->owner_graph() : nullptr,
  442. comp_node, out.layout);
  443. auto tmpv = tmp.as_megdnn();
  444. MGB_TRY {
  445. run(Mode::MUL, inp[0], inp[1], tmpv);
  446. run(Mode::MUL, inp[2], inp[3], out);
  447. run(Mode::ADD, out, tmpv, out);
  448. } MGB_FINALLY(opr->param() = {Mode::FUSE_MUL_ADD4});
  449. return;
  450. }
  451. // All Elemwise operations on QuantizedS32/QuantizedS8 are not related to
  452. // scale. Since MegDNN does not support computing Elemwise for
  453. // QuantizedS32/QuantizedS8, we translate the data type to Int32/Int8 before
  454. // passing to MegDNN.
  455. if (inp.size() &&
  456. inp[0].layout.dtype.category() == DTypeCategory::QUANTIZED) {
  457. auto inp_dtype = inp[0].layout.dtype;
  458. DType compute_dtype;
  459. if (inp_dtype.enumv() == DTypeEnum::QuantizedS32) {
  460. compute_dtype = dtype::Int32();
  461. } else if (inp_dtype.enumv() == DTypeEnum::QuantizedS8) {
  462. compute_dtype = dtype::Int8();
  463. } else {
  464. mgb_throw(MegBrainError,
  465. "Unsupported Quantized Elemwise Mode %s: %d on %s",
  466. inp[0].layout.dtype.name(), int(opr->param().mode),
  467. comp_node.to_string().c_str());
  468. }
  469. megdnn::TensorNDArray run_inp(inp);
  470. for (size_t i = 0; i < inp.size(); i++) {
  471. run_inp[i].layout.dtype = compute_dtype;
  472. }
  473. megdnn::TensorND run_out = out;
  474. run_out.layout.dtype = compute_dtype;
  475. opr->exec(run_inp, run_out);
  476. return;
  477. }
  478. opr->exec(inp, out);
  479. }
  480. #if MGB_ENABLE_GRAD
  481. MGB_IMPL_OPR_GRAD(Elemwise) {
  482. SymbolVar i[5];
  483. SymbolVar i0(opr.input(0)), i1, i2, out(opr.output(0)),
  484. og{out_grad.at(0)}, result;
  485. for (size_t t = 0; t < opr.input().size(); ++ t)
  486. i[t] = opr.input()[t];
  487. if (opr.input().size() >= 2)
  488. i1 = opr.input(1);
  489. if (opr.input().size() >= 3)
  490. i2 = opr.input(2);
  491. // negate after reduce, for better performance
  492. bool negate_result = false;
  493. #define RET(_v) result = (_v); break
  494. #define EL1(_mode, _a) Elemwise::make({_a}, Mode::_mode)
  495. #define EL2(_mode, _a, _b) Elemwise::make({_a, _b}, Mode::_mode)
  496. #define EL3(_mode, _a, _b, _c) Elemwise::make({_a, _b, _c}, Mode::_mode)
  497. #define RET_INVALID() return InvalidGrad::make(opr, wrt_idx)
  498. using Mode = Elemwise::Mode;
  499. switch (opr.param().mode) {
  500. // unary
  501. case Mode::RELU:
  502. case Mode::FUSE_ADD_RELU:
  503. RET(EL2(SWITCH_GT0, out, og));
  504. case Mode::ABS:
  505. RET(EL2(ABS_GRAD, i0, og));
  506. case Mode::ACOS:
  507. negate_result = true;
  508. RET(og / EL1(SIN, out));
  509. case Mode::ASIN:
  510. RET(og / EL1(COS, out));
  511. case Mode::ATAN2:
  512. if (wrt_idx) {
  513. negate_result = true;
  514. }
  515. RET(og * i[!wrt_idx] / (i0 * i0 + i1 * i1));
  516. case Mode::CEIL:
  517. return nullptr;
  518. case Mode::COS:
  519. negate_result = true;
  520. RET(EL1(SIN, i0) * og);
  521. case Mode::EXP:
  522. RET(og * out);
  523. case Mode::EXPM1:
  524. RET(og * EL1(EXP, i0));
  525. case Mode::FLOOR:
  526. return nullptr;
  527. case Mode::LOG:
  528. RET(og / i0);
  529. case Mode::LOG1P:
  530. RET(og / (i0 + 1));
  531. case Mode::NEGATE:
  532. negate_result = true;
  533. RET(og);
  534. case Mode::SIGMOID:
  535. case Mode::FUSE_ADD_SIGMOID:
  536. RET(EL2(SIGMOID_GRAD, out, og));
  537. case Mode::SIN:
  538. RET(EL1(COS, i0) * og);
  539. case Mode::TANH:
  540. case Mode::FUSE_ADD_TANH:
  541. RET(EL2(TANH_GRAD, out, og));
  542. case Mode::FAST_TANH:
  543. RET(EL2(FAST_TANH_GRAD, i0, og));
  544. case Mode::ROUND:
  545. return nullptr;
  546. case Mode::ERF:
  547. RET(EL1(EXP, - i0 * i0) * 2 / static_cast<float>(sqrt(M_PI)) * og);
  548. case Mode::ERFINV:
  549. RET(EL1(EXP, out * out) * static_cast<float>(sqrt(M_PI)) / 2 * og);
  550. case Mode::ERFC:
  551. RET(-EL1(EXP, -i0 * i0) * 2 / static_cast<float>(sqrt(M_PI)) * og);
  552. case Mode::H_SWISH:
  553. RET(EL2(H_SWISH_GRAD, i0, og));
  554. case Mode::FUSE_ADD_H_SWISH:
  555. RET(EL2(H_SWISH_GRAD, (i0 + i1), og));
  556. case Mode::NOT:
  557. return nullptr;
  558. case Mode::SILU:
  559. RET(EL2(SILU_GRAD, i0, og));
  560. case Mode::GELU:
  561. RET(EL2(GELU_GRAD, i0, og));
  562. // binary
  563. case Mode::ABS_GRAD:
  564. if (wrt_idx == 0) {
  565. return nullptr;
  566. }
  567. RET(EL2(ABS_GRAD, i0, og));
  568. case Mode::ADD:
  569. RET(og);
  570. case Mode::FLOOR_DIV:
  571. return nullptr;
  572. case Mode::MAX:
  573. RET(EL3(COND_LEQ_MOV, i[!wrt_idx], i[wrt_idx], og));
  574. case Mode::MIN:
  575. RET(EL3(COND_LEQ_MOV, i[wrt_idx], i[!wrt_idx], og));
  576. case Mode::MOD:
  577. if (wrt_idx == 0) {
  578. RET(og);
  579. }
  580. RET_INVALID();
  581. case Mode::MUL:
  582. RET(og * i[!wrt_idx]);
  583. case Mode::POW:
  584. if (wrt_idx) {
  585. RET(out * EL1(LOG, i0) * og);
  586. }
  587. RET(og * i1 * EL2(POW, i0, i1 - 1));
  588. case Mode::SIGMOID_GRAD:
  589. if (wrt_idx == 0) {
  590. auto one = i0.make_scalar_dt(1), two = i0.make_scalar_dt(2);
  591. RET((one - i0 * two) * i1 * og);
  592. }
  593. RET(EL2(SIGMOID_GRAD, i0, og));
  594. case Mode::SUB:
  595. negate_result = wrt_idx;
  596. RET(og);
  597. case Mode::SWITCH_GT0:
  598. if (!wrt_idx)
  599. return nullptr;
  600. RET(EL2(SWITCH_GT0, i0, og));
  601. case Mode::TANH_GRAD:
  602. if (wrt_idx == 0) {
  603. auto mtwo = i0.make_scalar_dt(-2);
  604. RET(mtwo * i0 * i1 * og);
  605. }
  606. RET(EL2(TANH_GRAD, i0, og));
  607. case Mode::TRUE_DIV:
  608. if (wrt_idx == 0) {
  609. RET(og / i1);
  610. }
  611. negate_result = true;
  612. RET((og * i0) * EL2(POW, i1, i1.make_scalar(-2)));
  613. case Mode::LOG_SUM_EXP:
  614. if (wrt_idx == 0) {
  615. RET(og * EL1(SIGMOID, i0 - i1));
  616. }
  617. RET(og * EL1(SIGMOID, i1 - i0));
  618. case Mode::LT:
  619. case Mode::LEQ:
  620. return nullptr;
  621. case Mode::EQ:
  622. RET_INVALID();
  623. case Mode::OR:
  624. case Mode::XOR:
  625. case Mode::AND:
  626. return nullptr;
  627. // ternary
  628. case Mode::COND_LEQ_MOV:
  629. if (wrt_idx <= 1)
  630. return nullptr;
  631. RET(EL3(COND_LEQ_MOV, i0, i1, og));
  632. // fuse oprs
  633. case Mode::FUSE_MUL_ADD3:
  634. if (wrt_idx < 2) {
  635. RET(og * i[wrt_idx ^ 1]);
  636. } else {
  637. RET(og);
  638. }
  639. case Mode::FUSE_MUL_ADD4:
  640. RET(og * i[wrt_idx ^ 1]);
  641. default:
  642. mgb_throw(GraphError, "grad for elemwise mode %s unimplemented",
  643. megdnn::Elemwise::ModeTrait::from_mode(
  644. opr.param().mode).name);
  645. }
  646. #undef EL3
  647. #undef EL2
  648. #undef EL1
  649. #undef RET
  650. if (opr.input_broadcastable()[wrt_idx]) {
  651. result = reduce_sum(result,
  652. opr::GetVarShape::make(opr.input(wrt_idx)));
  653. } else if (result.node()->owner_opr()->same_type<Broadcast>()) {
  654. // forward broadcast for optimizer to work
  655. result = opr::Broadcast::make(result.node()->owner_opr()->input(0),
  656. opr::GetVarShape::make(i[wrt_idx]));
  657. }
  658. if (negate_result)
  659. result = -result;
  660. return result.node();
  661. }
  662. #endif
  663. VarNode* Elemwise::sum_grad_list(VarNode *wrt, VarNodeArray &grads) {
  664. mgb_assert(!grads.empty());
  665. if (grads.size() == 1)
  666. return grads[0];
  667. #if MGB_ENABLE_COND_EXEC
  668. CondExecMerge::modify_grad_sum_list(wrt, grads);
  669. #endif
  670. VarNodeArray mid_results;
  671. VarNode *ret;
  672. if (wrt->owner_graph()->options().graph_opt_level) {
  673. ret = gopt::GradSumListOptimizer{wrt, grads, mid_results}.get_sum();
  674. } else {
  675. ret = gopt::elemwise_reduce_var_list(
  676. grads, Elemwise::Mode::ADD, &mid_results);
  677. }
  678. mid_results.swap(grads);
  679. return ret;
  680. }
  681. void Elemwise::record_execute_deps(ExecDependencyArray& deps) {
  682. record_megdnn_opr(deps);
  683. }
  684. Elemwise::NodeProp* Elemwise::do_make_node_prop() const {
  685. auto ret = Super::do_make_node_prop();
  686. for (auto& inp : input()) {
  687. ret->add_dep_type_existing_var(inp,
  688. NodeProp::DepType::VALUE_ALLOW_EMPTY);
  689. }
  690. return ret;
  691. }
  692. /* =========================== TypeCvt =========================== */
  693. MGB_DYN_TYPE_OBJ_FINAL_IMPL(TypeCvt);
  694. TypeCvt::TypeCvt(
  695. VarNode *inp, DType dest_type, const OperatorNodeConfig &config):
  696. Super{inp->owner_graph(), config, std::string("as") + dest_type.name(),
  697. {inp}}
  698. {
  699. init_megdnn_opr(*this, {});
  700. mgb_assert(dest_type.valid());
  701. add_input({inp});
  702. add_equivalence_component<ScalarHash<const void*>>(dest_type.handle());
  703. output(0)->dtype(dest_type).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
  704. }
  705. SymbolVar TypeCvt::make(
  706. SymbolVar input, DType dest_type, const OperatorNodeConfig &config) {
  707. if (input.dtype() == dest_type)
  708. return input;
  709. return input.insert_single_output_opr<TypeCvt>(
  710. input.node(), dest_type, config);
  711. }
  712. void TypeCvt::perform(DeviceTensorND &dest,
  713. DType dest_type, const DeviceTensorND &src,
  714. intl::UniqPtrWithCN<megdnn::TypeCvt> &opr) {
  715. mgb_assert(src.comp_node() == opr.comp_node());
  716. mgb_assert(dest_type.valid());
  717. if (src.empty()) {
  718. mgb_assert(dest.empty());
  719. return;
  720. }
  721. if (src.dtype() == dest_type) {
  722. dest.copy_from(src);
  723. return;
  724. }
  725. src.comp_node().activate();
  726. dest.comp_node(src.comp_node()).dtype(dest_type).resize(src.shape());
  727. opr->exec(src.as_megdnn(), dest.as_megdnn());
  728. }
  729. void TypeCvt::add_input_layout_constraint() {
  730. for (auto i: input()) {
  731. i->add_layout_constraint_contiguous();
  732. }
  733. }
  734. TypeCvt::NodeProp* TypeCvt::do_make_node_prop() const {
  735. auto ret = Super::do_make_node_prop();
  736. ret->add_dep_type_existing_var(input(0),
  737. NodeProp::DepType::VALUE_ALLOW_EMPTY);
  738. return ret;
  739. }
  740. #if MGB_ENABLE_GRAD
  741. MGB_IMPL_OPR_GRAD(TypeCvt) {
  742. MGB_MARK_USED_VAR(wrt_idx);
  743. auto itype = opr.input(0)->dtype(), otype = opr.output(0)->dtype();
  744. if (itype.category() == DTypeCategory::FLOAT &&
  745. otype.category() == DTypeCategory::INT) {
  746. return nullptr;
  747. }
  748. if (itype.category() != DTypeCategory::FLOAT) {
  749. return InvalidGrad::make(opr, 0);
  750. }
  751. return TypeCvt::make(out_grad[0], opr.input(0)->dtype()).node();
  752. }
  753. #endif
  754. void TypeCvt::mem_plan_fwd_in2out_writable() {
  755. bool cond_low_bit =
  756. input(0)->dtype().is_low_bit() && output(0)->dtype().is_low_bit() &&
  757. input(0)->dtype().low_bit() == output(0)->dtype().low_bit();
  758. bool cond_normal = !input(0)->dtype().is_low_bit() &&
  759. !output(0)->dtype().is_low_bit() &&
  760. input(0)->dtype().size() == output(0)->dtype().size();
  761. if ((cond_low_bit || cond_normal) && input(0)->layout().is_contiguous()) {
  762. output(0)->set_fwd_in2out_writable(input(0));
  763. }
  764. }
  765. void TypeCvt::scn_do_execute() {
  766. auto ovar = output(0)->dev_tensor().as_megdnn();
  767. for (size_t i = 0; i < ovar.layout.ndim; ++i) {
  768. if (!ovar.layout[i]) {
  769. // skip execution for empty var
  770. return;
  771. }
  772. }
  773. megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), ovar);
  774. }
  775. void TypeCvt::init_output_static_infer_desc() {
  776. static StaticInferOpr<megdnn::TypeCvt> static_infer_opr;
  777. Super::init_output_static_infer_desc();
  778. using namespace cg::static_infer;
  779. auto infer_value = [this](DeviceTensorND &dest, const InpVal &inp) {
  780. auto sopr = static_infer_opr.lock();
  781. perform(dest, output(0)->dtype(), inp.val.at(0).value(), sopr());
  782. return true;
  783. };
  784. owner_graph()->static_infer_manager().register_value_infer(
  785. output(0), {SourceType::DEP, {{input(0), DepType::VALUE}},
  786. infer_value});
  787. }
  788. void TypeCvt::record_execute_deps(ExecDependencyArray& deps) {
  789. record_megdnn_opr(deps);
  790. }
  791. /* =========================== AddUpdate =========================== */
  792. MGB_DYN_TYPE_OBJ_FINAL_IMPL(AddUpdate);
  793. AddUpdate::AddUpdate(VarNode *dest, VarNode *delta,
  794. const Param &param,
  795. const OperatorNodeConfig &config):
  796. Super{dest->owner_graph(), config, "inplace_add", {dest, delta}},
  797. m_param{param}
  798. {
  799. auto dest_opr = dest->owner_opr();
  800. mgb_throw_if(dest_opr->same_type<ImmutableTensor>(),
  801. GraphError,
  802. "AddUpdate cannot be applied on ImmutableTensor; ");
  803. add_input({dest, delta});
  804. /*
  805. * here we tell the system that output(0) would force-update input(0); the
  806. * topo-sorting system would ensure that all the readers finish before
  807. * executing this AddUpdate operation
  808. */
  809. add_output(None)->
  810. set_fwd_in2out_writable_force(input(0)).
  811. add_flag(VarNode::Flag::NO_MEM_RECLAIM);
  812. mgb_assert(m_param.disable->dtype() == dtype::Int32{},
  813. "dtype of disable flag on AddUpdate must be Int32, got %s actually.",
  814. m_param.disable->dtype().name());
  815. add_equivalence_component<ScalarHash<void*>>(m_param.alpha.get());
  816. add_equivalence_component<ScalarHash<void*>>(m_param.beta.get());
  817. add_equivalence_component<ScalarHash<void*>>(m_param.bias.get());
  818. add_equivalence_component<ScalarHash<void*>>(m_param.disable.get());
  819. }
  820. SymbolVar AddUpdate::make(SymbolVar dest, SymbolVar delta,
  821. const Param &param, const OperatorNodeConfig &config) {
  822. delta = opr::TypeCvt::make(delta, dest.dtype());
  823. return dest.insert_single_output_opr<AddUpdate>(
  824. dest.node(), delta.node(), param, config);
  825. }
  826. cg::OperatorNodeBase::NodeProp* AddUpdate::do_make_node_prop() const {
  827. auto ret = Super::do_make_node_prop();
  828. ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR);
  829. return ret;
  830. }
  831. void AddUpdate::create_megdnn_opr() {
  832. set_megdnn_opr(intl::get_megdnn_handle(comp_node())->
  833. create_operator<megdnn::AddUpdate>());
  834. }
  835. void AddUpdate::scn_do_execute() {
  836. mgb_assert(m_param.disable->dtype() == dtype::Int32{},
  837. "dtype of disable flag on AddUpdate must be Int32, got %s actually.",
  838. m_param.disable->dtype().name());
  839. auto disable = m_param.disable->get_cast<int>();
  840. if(disable == 1) return;
  841. mgb_assert(disable == 0, "disable flag on AddUpdate can only be 0 or 1,"
  842. " got %d actually.", disable);
  843. auto &&dest = output(0)->dev_tensor();
  844. auto &&delta_nobrd = input(1)->dev_tensor();
  845. auto delta = delta_nobrd.sub(SubTensorSpec::make_from_offset_elem(
  846. delta_nobrd.layout().broadcast(dest.shape()), 0));
  847. mgb_assert(input(0)->dev_tensor().raw_ptr() == dest.raw_ptr());
  848. auto beta = m_param.beta->get_cast<float>();
  849. if (!m_param.alpha->get_cast<bool>() && beta == 1 &&
  850. !m_param.bias->get_cast<bool>()) {
  851. dest.copy_from_fixlayout(delta);
  852. } else {
  853. auto opr = static_cast<megdnn::AddUpdate*>(megdnn_opr());
  854. opr->param() = {
  855. m_param.alpha->get_cast<float>(),
  856. beta,
  857. m_param.bias->get_cast<float>()};
  858. opr->exec(dest.as_megdnn(), delta.as_megdnn());
  859. }
  860. }
  861. void AddUpdate::init_output_static_infer_desc() {
  862. using namespace cg::static_infer;
  863. owner_graph()->static_infer_manager().register_shape_infer(
  864. output(0), ShapeInferDesc::make_identity(input(0)));
  865. }
  866. void AddUpdate::record_execute_deps(ExecDependencyArray& deps) {
  867. record_megdnn_opr(deps);
  868. }
  869. #if MGB_ENABLE_GRAD
  870. MGB_IMPL_OPR_GRAD(AddUpdate) {
  871. // actually valid, just not implemented
  872. return InvalidGrad::make(opr, wrt_idx);
  873. }
  874. #endif
  875. /* =========================== Reduce =========================== */
  876. class Reduce::KernScheduler {
  877. class ValueDep final : public ExecDependency {
  878. DeviceTensorStorage m_val;
  879. public:
  880. explicit ValueDep(DeviceTensorStorage val) : m_val(std::move(val)) {}
  881. };
  882. public:
  883. bool has_actual_computing() const {
  884. mgb_assert(m_shape_computed);
  885. return !m_kern_param.empty() || m_apply_side_effect;
  886. }
  887. size_t workspace_size() const {
  888. return m_workspace_spec[2].end();
  889. }
  890. bool shape_computed() const {
  891. return m_shape_computed;
  892. }
  893. //! init shapes in kern param
  894. void init_shapes(
  895. megdnn::Reduce *opr, CompNode comp_node, DType dtype, Mode mode,
  896. TensorShape ishp, TensorShape oshp, const Param::DataType data_type);
  897. void setup_kern_params_layout_and_mode(Mode mode, DType inp_dtype,
  898. TensorShape& inp_shp,
  899. const Param::DataType);
  900. void check_shapes(const TensorShape &ishp, const TensorShape &oshp) {
  901. mgb_assert(m_prev_ishp.eq_shape(ishp) &&
  902. m_prev_oshp.eq_shape(oshp));
  903. }
  904. //! update pointers in kern param; the tensors must have been allocated
  905. void update_ptr(
  906. const DeviceTensorND &input, const DeviceTensorND &dest,
  907. const DeviceTensorND &workspace);
  908. void execute(megdnn::Reduce *opr,
  909. const DeviceTensorND &input, const DeviceTensorND &dest);
  910. void record_execute_deps(ExecDependencyArray& deps) {
  911. if (m_elemwise_trans_opr) {
  912. deps.emplace_back(std::make_unique<intl::MegDNNGraphDep>(
  913. std::move(m_elemwise_trans_opr)));
  914. }
  915. if (m_typecvt_opr) {
  916. deps.emplace_back(std::make_unique<intl::MegDNNGraphDep>(
  917. std::move(m_typecvt_opr)));
  918. }
  919. deps.emplace_back(
  920. std::make_unique<ValueDep>(m_side_affect_wkspc.storage()));
  921. }
  922. private:
  923. struct KernParam {
  924. megdnn::TensorND input, output;
  925. //! param passed to megdnn
  926. megdnn::param::Reduce kparam;
  927. megdnn::Workspace workspace;
  928. KernParam(Mode mode, int32_t ra):
  929. kparam{mode, ra}
  930. {
  931. }
  932. };
  933. struct SubWorkspace {
  934. size_t size, offset;
  935. size_t end() const {
  936. return size + offset;
  937. }
  938. };
  939. void update_kparam_for_elemwise_side_effect(
  940. CompNode comp_node, Mode mode, const Param::DataType data_type);
  941. bool m_shape_computed = false;
  942. std::vector<KernParam> m_kern_param;
  943. TensorShape m_prev_ishp, m_prev_oshp;
  944. SubWorkspace m_workspace_spec[3]; //! tmp output[2], kern workspace
  945. /*!
  946. * some reduce modes (like SUM_SQR) have a side effect of element-wise
  947. * transformation. If this is the case and there is no kernel param,
  948. * m_apply_side_effect would be non-null
  949. */
  950. thin_function<void(const DeviceTensorND &in,
  951. const DeviceTensorND &out)>
  952. m_apply_side_effect;
  953. std::unique_ptr<megdnn::Elemwise> m_elemwise_trans_opr;
  954. std::unique_ptr<megdnn::TypeCvt> m_typecvt_opr;
  955. DeviceTensorND m_side_affect_wkspc;
  956. };
  957. void Reduce::KernScheduler::setup_kern_params_layout_and_mode(Mode mode,
  958. DType inp_dtype,
  959. TensorShape& ishp,
  960. const Param::DataType data_type) {
  961. auto prev_dtype = inp_dtype;
  962. for (size_t idx = 0; idx < m_kern_param.size(); ++idx) {
  963. auto&& i = m_kern_param[idx];
  964. #if !MEGDNN_DISABLE_FLOAT16
  965. if (idx == 0 && data_type == Param::DataType::FLOAT_O32xC32) {
  966. i.input.layout.dtype = inp_dtype;
  967. i.output.layout.dtype = dtype::Float32();
  968. i.kparam.data_type = data_type;
  969. } else if (data_type == Param::DataType::FLOAT_O16xC32) {
  970. i.input.layout.dtype = prev_dtype;
  971. if (idx + 1 == m_kern_param.size()) {
  972. i.output.layout.dtype = dtype::Float16();
  973. i.kparam.data_type = data_type;
  974. }
  975. else {
  976. i.output.layout.dtype = dtype::Float32();
  977. i.kparam.data_type = Param::DataType::FLOAT_O32xC32;
  978. }
  979. } else
  980. #endif
  981. {
  982. mgb_assert(data_type == Param::DataType::DEFAULT || (
  983. data_type == Param::DataType::FLOAT_O32xC32 &&
  984. idx));
  985. i.input.layout.dtype = prev_dtype;
  986. i.output.layout.dtype = prev_dtype;
  987. i.kparam.data_type = Param::DataType::DEFAULT;
  988. }
  989. prev_dtype = i.output.layout.dtype;
  990. i.input.layout.init_contiguous_stride(ishp);
  991. ishp.shape[i.kparam.axis] = 1;
  992. i.output.layout.init_contiguous_stride(ishp);
  993. }
  994. if (mode == Mode::SUM_SQR) {
  995. for (size_t i = 1; i < m_kern_param.size(); ++ i)
  996. m_kern_param[i].kparam.mode = Mode::SUM;
  997. }
  998. }
  999. void Reduce::KernScheduler::init_shapes(
  1000. megdnn::Reduce *opr, CompNode comp_node, DType inp_dtype, Mode mode,
  1001. TensorShape ishp, TensorShape oshp, const Param::DataType data_type) {
  1002. mgb_assert(ishp.ndim && oshp.ndim);
  1003. if (ishp.eq_shape(m_prev_ishp) && oshp.eq_shape(m_prev_oshp))
  1004. return;
  1005. m_prev_ishp = ishp;
  1006. m_prev_oshp = oshp;
  1007. m_kern_param.clear();
  1008. if (oshp.is_scalar()) {
  1009. // if ishp is non-contiguous, add_layout_constraint_contiguous would be
  1010. // added; so we do not have to worry about this
  1011. ishp.shape[0] = ishp.total_nr_elems();
  1012. ishp.ndim = 1;
  1013. }
  1014. mgb_assert(oshp.ndim == ishp.ndim,
  1015. "input and output ndim mismatch for reduction: ishp=%s oshp=%s",
  1016. ishp.to_string().c_str(), oshp.to_string().c_str());
  1017. for (size_t i = 0; i < ishp.ndim; ++ i) {
  1018. if (ishp.shape[i] != oshp.shape[i]) {
  1019. mgb_assert(oshp.shape[i] == 1,
  1020. "input and output shape mismatch for reduction: "
  1021. "ishp=%s oshp=%s",
  1022. ishp.to_string().c_str(), oshp.to_string().c_str());
  1023. }
  1024. }
  1025. auto remove_axis = [](TensorShape &shp, size_t ax) {
  1026. mgb_assert(shp.ndim > 1);
  1027. for (auto i = ax + 1; i < shp.ndim; ++ i)
  1028. shp.shape[i - 1] = shp.shape[i];
  1029. -- shp.ndim;
  1030. };
  1031. // collapse consecutive shape-1 axes in oshp
  1032. for (size_t i = 0; i < oshp.ndim; ++ i) {
  1033. auto start = i;
  1034. while (i < oshp.ndim && oshp.shape[i] == 1)
  1035. ++ i;
  1036. if (start + 1 < i) {
  1037. for (auto j = start + 1; j < i; ++ j)
  1038. ishp.shape[start] *= ishp.shape[j];
  1039. for (auto j = start + 1; j < i; ++ j) {
  1040. remove_axis(ishp, start + 1);
  1041. remove_axis(oshp, start + 1);
  1042. }
  1043. i = start;
  1044. }
  1045. }
  1046. for (uint32_t i = 0; i < ishp.ndim; ++ i) {
  1047. if (ishp.shape[i] != oshp.shape[i]) {
  1048. mgb_assert(oshp.shape[i] == 1);
  1049. m_kern_param.push_back({mode, static_cast<int32_t>(i)});
  1050. }
  1051. }
  1052. // sort according to reduction size, so workspace can be smaller
  1053. small_sort(m_kern_param.begin(), m_kern_param.end(),
  1054. [&](const KernParam &a, const KernParam &b) {
  1055. return ishp.shape[a.kparam.axis] > ishp.shape[b.kparam.axis];
  1056. });
  1057. // init kparam input/output layout
  1058. setup_kern_params_layout_and_mode(mode, inp_dtype, ishp, data_type);
  1059. // init workspace size
  1060. memset(m_workspace_spec, 0, sizeof(m_workspace_spec));
  1061. for (auto&& i : m_kern_param) {
  1062. opr->param() = i.kparam;
  1063. i.workspace.size = opr->get_workspace_in_bytes(
  1064. i.input.layout, i.output.layout);
  1065. update_max(m_workspace_spec[2].size, i.workspace.size);
  1066. }
  1067. mgb_assert(ishp.eq_shape(oshp));
  1068. if (m_kern_param.size() >= 2) {
  1069. m_workspace_spec[0].size =
  1070. m_kern_param[1].input.layout.span().high_byte;
  1071. }
  1072. if (m_kern_param.size() >= 3) {
  1073. m_workspace_spec[1].size =
  1074. m_kern_param[2].input.layout.span().high_byte;
  1075. }
  1076. auto align = comp_node.get_mem_addr_alignment();
  1077. for (int i = 0; i < 2; ++ i) {
  1078. m_workspace_spec[i + 1].offset = get_aligned_power2(
  1079. m_workspace_spec[i].end(), align);
  1080. }
  1081. update_kparam_for_elemwise_side_effect(comp_node, mode, data_type);
  1082. m_shape_computed = true;
  1083. }
  1084. void Reduce::KernScheduler::update_kparam_for_elemwise_side_effect(
  1085. CompNode comp_node, Mode mode, const Param::DataType data_type) {
  1086. m_apply_side_effect = nullptr;
  1087. m_elemwise_trans_opr.reset();
  1088. m_typecvt_opr.reset();
  1089. if (!m_kern_param.empty()) {
  1090. // no need to set m_apply_side_effect
  1091. return;
  1092. } /* else */
  1093. // case A: input.layout == output.layout
  1094. // case B: input.total_nr_elems == 1 and output is a scalar
  1095. if (mode == Mode::SUM_SQR) {
  1096. m_elemwise_trans_opr = intl::get_megdnn_handle(comp_node)->
  1097. create_operator<megdnn::Elemwise>();
  1098. m_elemwise_trans_opr->param() = {Elemwise::Mode::MUL};
  1099. }
  1100. if (data_type != Param::DataType::DEFAULT) {
  1101. m_side_affect_wkspc = DeviceTensorND{comp_node, dtype::Float32()};
  1102. m_typecvt_opr = intl::get_megdnn_handle(comp_node)->
  1103. create_operator<megdnn::TypeCvt>();
  1104. }
  1105. if (!m_typecvt_opr && !m_elemwise_trans_opr)
  1106. return;
  1107. m_apply_side_effect = [this](const DeviceTensorND &in,
  1108. const DeviceTensorND &out) {
  1109. if (m_typecvt_opr) {
  1110. m_side_affect_wkspc.resize(in.shape());
  1111. }
  1112. if (!m_elemwise_trans_opr) {
  1113. mgb_assert(m_typecvt_opr);
  1114. m_typecvt_opr->exec(in.as_megdnn(), out.as_megdnn());
  1115. return;
  1116. }
  1117. auto im = in.as_megdnn();
  1118. megdnn::TensorND wm;
  1119. if (m_typecvt_opr && in.dtype() != m_side_affect_wkspc.dtype()) {
  1120. m_side_affect_wkspc.resize(in.shape());
  1121. wm = m_side_affect_wkspc.as_megdnn();
  1122. m_typecvt_opr->exec(im, wm);
  1123. } else {
  1124. wm = im;
  1125. }
  1126. if (m_typecvt_opr && wm.layout.dtype != out.dtype()) {
  1127. m_elemwise_trans_opr->exec({wm, wm}, wm);
  1128. m_typecvt_opr->exec(wm, out.as_megdnn());
  1129. } else {
  1130. auto &&wshp = wm.layout;
  1131. if (wshp.ndim != out.layout().ndim) {
  1132. // to ensure that wkspc.ndim equals out.ndim in the case:
  1133. // wkspc.shape=(1, 1, ..., 1) and out.shape=(1); otherwise it
  1134. // may cause the 'TensorShape Dimension' assertion to fail in
  1135. // the following broadcast operator
  1136. mgb_assert(wshp.total_nr_elems() == 1 && out.layout().ndim == 1);
  1137. wshp.ndim = 1;
  1138. }
  1139. m_elemwise_trans_opr->exec({wm, wm}, out.as_megdnn());
  1140. }
  1141. };
  1142. }
void Reduce::KernScheduler::update_ptr(
        const DeviceTensorND &input, const DeviceTensorND &dest,
        const DeviceTensorND &workspace) {
    auto dtype = dest.layout().dtype;
    mgb_assert(dtype.valid());
    mgb_assert(m_shape_computed);
    if (workspace_size()) {
        mgb_assert(workspace.layout().dtype == dtype::Byte() &&
                workspace.layout().ndim == 1 &&
                workspace.shape()[0] >= workspace_size());
    }
    if (m_kern_param.empty())
        return;
    mgb_assert(input.layout().total_nr_elems() ==
            m_kern_param[0].input.layout.total_nr_elems());
    mgb_assert(dest.shape().total_nr_elems() ==
            m_kern_param.back().output.layout.total_nr_elems());
    m_kern_param[0].input.raw_ptr = const_cast<dt_byte*>(input.raw_ptr());

    dt_byte
        *workspace_begin = workspace_size() ?
            const_cast<dt_byte*>(workspace.raw_ptr()) : nullptr,
        *tmp_reduce_ptr[2] = {
            workspace_begin + m_workspace_spec[0].offset,
            workspace_begin + m_workspace_spec[1].offset},
        *kern_workspace = workspace_begin + m_workspace_spec[2].offset;

    for (size_t i = 0; i < m_kern_param.size() - 1; ++ i) {
        auto optr = tmp_reduce_ptr[i % 2];
        m_kern_param[i].output.raw_ptr = optr;
        m_kern_param[i + 1].input.raw_ptr = optr;
    }
    for (auto &&i: m_kern_param)
        i.workspace.raw_ptr = kern_workspace;
    m_kern_param.back().output.raw_ptr = const_cast<dt_byte*>(dest.raw_ptr());
}
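/*
 * Workspace layout used by update_ptr(): m_workspace_spec describes three
 * sub-regions of the single byte workspace --
 *   [0] output buffer of the first partial reduction (input of the second),
 *   [1] output buffer of the second partial reduction (input of the third),
 *   [2] scratch space shared by all megdnn::Reduce kernel invocations.
 * Chained kernels alternate between regions [0] and [1] (tmp_reduce_ptr[i % 2]),
 * so any number of partial reductions needs only two intermediate buffers; the
 * final kernel writes directly into dest.
 */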
void Reduce::KernScheduler::execute(
        megdnn::Reduce *opr,
        const DeviceTensorND &input, const DeviceTensorND &dest) {
    if (m_apply_side_effect) {
        mgb_assert(m_kern_param.empty());
        m_apply_side_effect(input, dest);
        return;
    }
    mgb_assert(!m_kern_param.empty());
    mgb_assert(input.layout().is_contiguous() &&
            input.raw_ptr() == m_kern_param[0].input.raw_ptr &&
            dest.raw_ptr() == m_kern_param.back().output.raw_ptr);
    for (auto &&i: m_kern_param) {
        opr->param() = i.KernParam::kparam;
        opr->exec(i.input, i.output, i.workspace);
    }
}
class Reduce::OutTensorShapeExtender {
public:
    OutTensorShapeExtender(const TensorShape& ishp, const TensorShape& oshp)
            : m_oshp(oshp) {
        mgb_assert(oshp.ndim <= ishp.ndim,
                "output ndim should be less than or equal to input ndim for "
                "reduction: "
                "ishp=%s oshp=%s",
                ishp.to_string().c_str(), oshp.to_string().c_str());
        // Ex. ishp = (a, b, c, d), oshp = (c, d)
        if (!oshp.is_scalar() && ishp.ndim != oshp.ndim) {
            size_t ndim_diff = ishp.ndim - oshp.ndim;
            auto&& canonized_oshp = m_canonized_oshp_storage.emplace(oshp);
            for (size_t i = 0; i < ishp.ndim; ++i)
                if (i < ndim_diff)
                    canonized_oshp[i] = 1;
                else
                    canonized_oshp[i] = oshp[i - ndim_diff];
            canonized_oshp.ndim = ishp.ndim;
        }
    }

    const TensorShape& get() const {
        return m_canonized_oshp_storage.valid() ? m_canonized_oshp_storage.val()
                                                : m_oshp;
    }

private:
    Maybe<TensorShape> m_canonized_oshp_storage;
    const TensorShape& m_oshp;
};
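// For example, ishp=(2, 3, 4, 5) with oshp=(4, 5) yields the canonized output
// shape (1, 1, 4, 5), so input and output passed to KernScheduler always have
// matching ndim; a scalar oshp is left untouched.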
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Reduce);

Reduce::Reduce(VarNode *inp, VarNode *target_shape, const Param &param,
        const OperatorNodeConfig &config):
    Super{inp->owner_graph(), config,
          ssprintf("reduce%d", static_cast<int>(param.mode)), {inp}},
    m_param{param}, m_kern_scheduler{std::make_unique<KernScheduler>()}
{
    add_input({inp});
    if (inp->dtype().enumv() == DTypeEnum::Quantized8Asymm &&
            inp->dtype().category() == DTypeCategory::QUANTIZED) {
        mgb_assert(param.mode != Param::Mode::PRODUCT,
                "Reduce does not support PRODUCT mode on quantized input");
        mgb_assert(param.mode != Param::Mode::SUM_SQR,
                "Reduce does not support SUM_SQR mode on quantized input");
        mgb_assert(param.mode != Param::Mode::SUM,
                "Reduce does not support SUM mode on quantized input");
    }
    DType out_dtype;
    switch (param.data_type) {
        case Param::DataType::DEFAULT:
            out_dtype = inp->dtype();
            break;
#if !MEGDNN_DISABLE_FLOAT16
        case Param::DataType::FLOAT_O16xC32:
            out_dtype = dtype::Float16();
            break;
        case Param::DataType::FLOAT_IO16xC32:
            mgb_assert(false);
#endif
        case Param::DataType::FLOAT_O32xC32:
            out_dtype = dtype::Float32();
            break;
        case Param::DataType::QUINT_I8xO32:
            out_dtype = dtype::QuantizedS32(
                    inp->dtype().param<dtype::Quantized8Asymm>().scale);
            break;
        case Param::DataType::QINT_I8xO32:
            out_dtype = dtype::QuantizedS32(
                    inp->dtype().param<dtype::QuantizedS8>().scale);
            break;
        default:
            mgb_throw(GraphError, "invalid param data_type: %d",
                    int(param.data_type));
    }
    add_output(None)->dtype(out_dtype);
    cg::add_workspace_output(this);
    add_equivalence_component<PODHash<Param>>(&m_param);

    if (param.axis >= -MEGDNN_MAX_NDIM && param.axis < MEGDNN_MAX_NDIM) {
        mgb_throw_if(target_shape, GraphError,
                "could not specify both axis and target shape");
        m_is_symtshp = false;
    } else {
        mgb_throw_if(!target_shape, GraphError,
                "neither axis nor target_shape specified");
        add_input({target_shape});
        m_is_symtshp = true;
        outshape_by_symvar_enable(0, 1);
    }
}
Reduce::~Reduce() = default;

SymbolVar Reduce::make(
        SymbolVar src, Param param, SymbolVar target_shape,
        const OperatorNodeConfig &config) {
    if (param.data_type == Param::DataType::FLOAT_IO16xC32) {
        mgb_log_warn("DataType FLOAT_IO16xC32 has been deprecated; "
                "use FLOAT_O16xC32 instead");
        param.data_type = Param::DataType::FLOAT_O16xC32;
    }
    if (param.mode == Mode::SUM &&
            src.node()->owner_opr()->same_type<Elemwise>()) {
        // replace sum(x^2) by sum_sqr(x)
        auto &&opr = src.node()->owner_opr()->cast_final<Elemwise>();
        if (opr.param().mode == Elemwise::Mode::POW) {
            mgb_assert(opr.input().size() == 2);
            auto pow = SymbolVar{opr.input(1)}.as_immutable_scalar();
            if (pow.valid() && pow->get_cast<float>() == 2) {
                src = opr.input(0);
                param.mode = Mode::SUM_SQR;
            }
        }
    }
    return src.insert_single_output_opr<Reduce>(
            src.node(), target_shape.node(), param, config);
}
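/*
 * Usage sketch (hypothetical variable names; the brace-initialized Param and
 * the defaulted target_shape/config are assumptions about the declaration,
 * which is not shown in this file):
 *
 *     // sum over axis 1; the reduced axis is kept with size 1
 *     auto s = opr::Reduce::make(x, {Reduce::Mode::SUM, 1});
 *
 *     // an explicit POW(x, 2) feeding a SUM is rewritten above into a single
 *     // SUM_SQR reduction on x
 *     auto x2 = opr::Elemwise::make({x, x.make_scalar(2.f)},
 *                                   opr::Elemwise::Mode::POW);
 *     auto s2 = opr::Reduce::make(x2, {Reduce::Mode::SUM, 1});
 *
 * The rewrite only triggers when the producer of src is an Elemwise POW whose
 * exponent is an immutable scalar equal to 2.
 */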
void Reduce::outshape_by_symvar_do_get_output_shape(
        TensorShape &dest, const ShapeInferInfo &shpinfo) {
    cg::copy_tensor_value_to_shape(dest, *shpinfo.shpval_inp_val.at(0));
}

void Reduce::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto &&mgr = owner_graph()->static_infer_manager();

    // infer output shape
    if (m_is_symtshp) {
        // reduce to target shape
        Super::init_output_static_infer_desc();
    } else {
        // reduce along axis
        auto infer_shape = [this](TensorShape &dest, const InpVal &inp) {
            dest = inp.val.at(0).shape();
            mgb_assert(m_param.axis < static_cast<int>(dest.ndim) &&
                    m_param.axis >= -static_cast<int>(dest.ndim),
                    "invalid axis for reduction: shape=%s axis=%d",
                    dest.to_string().c_str(), m_param.axis);
            int real_axis = m_param.axis;
            if (real_axis < 0)
                real_axis += dest.ndim;
            dest.shape[real_axis] = 1;
            return true;
        };
        mgr.register_shape_infer(
                output(0),
                {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_shape});
    }

    // infer workspace
    auto infer_workspace = [this](TensorShape &dest, const InpVal &inp) {
        init_kern_sched_shape(inp.val[0].shape(), inp.val[1].shape());
        dest.ndim = 1;
        dest.shape[0] = m_kern_scheduler->workspace_size();
        return true;
    };
    mgr.register_shape_infer(output(1),
            {SourceType::DEP,
             {{input(0), DepType::SHAPE}, {output(0), DepType::SHAPE}},
             infer_workspace});

    // infer value
    static StaticInferOpr<megdnn::Reduce> static_infer_opr;
    auto infer_value = [this](DeviceTensorND &dest, const InpVal &inp) {
        DeviceTensorND workspace;
        auto sopr = static_infer_opr.lock();
        perform(m_param.mode, dest, workspace, inp.val[0].value(),
                output(0)->dtype(), inp.val.at(1).shape(), sopr(),
                m_param.data_type);
        return true;
    };
    mgr.register_value_infer(output(0),
            {SourceType::DEP,
             {{input(0), DepType::VALUE}, {output(0), DepType::SHAPE}},
             infer_value});
}
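// Summary of the registrations above:
//   output(0) shape -- derived from input(0)'s shape with the reduced axis set
//                      to 1, or taken from the target-shape var via Super when
//                      m_is_symtshp is set;
//   output(1) shape -- a 1-d byte workspace whose size comes from
//                      KernScheduler::workspace_size();
//   output(0) value -- computed eagerly by Reduce::perform() with a shared
//                      static megdnn::Reduce operator when input(0)'s value is
//                      statically known.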
void Reduce::init_kern_sched_shape(const TensorShape& ishp,
        const TensorShape& oshp) {
    OutTensorShapeExtender extender(ishp, oshp);
    auto&& canonized_oshp = extender.get();
    m_kern_scheduler->init_shapes(static_cast<megdnn::Reduce*>(megdnn_opr()),
            comp_node(), input(0)->dtype(), m_param.mode,
            ishp, canonized_oshp, m_param.data_type);
}

cg::OperatorNodeBase::OprEventCallback Reduce::get_opr_event_callback() {
    auto on_mem_status_changed = [this]() {
        auto&& ishp = input(0)->shape();
        auto&& oshp = output(0)->shape();
        OutTensorShapeExtender extender(ishp, oshp);
        auto&& canonized_oshp = extender.get();
        m_kern_scheduler->check_shapes(input(0)->shape(), canonized_oshp);
        m_kern_scheduler->update_ptr(
                input(0)->dev_tensor(), output(0)->dev_tensor(),
                output(1)->shape()[0] ? output(1)->dev_tensor()
                                      : DeviceTensorND{});
    };
    return {on_mem_status_changed};
}
void Reduce::mem_plan_fwd_in2out_readonly() {
    init_kern_sched_shape(input(0)->shape(), output(0)->shape());
    if (!m_kern_scheduler->has_actual_computing()) {
        // forward memory if no actual computing needed
        if (!output(0)->mem_plan().valid()) {
            // output(0) is dynamic but the current phase is static allocation
            // (for the workspace)
            return;
        }
        auto&& ily = input(0)->layout();
        auto&& oly = output(0)->layout();
        const TensorLayout* fwd_spec = nullptr;
        Maybe<TensorLayout> ily_modified_storage;
        if (!ily.eq_shape(oly)) {
            auto&& ily_modified = ily_modified_storage.emplace(ily);
            mgb_assert(ily.ndim > oly.ndim);
            for (size_t i = 0; i < ily.ndim - oly.ndim; ++i)
                mgb_assert(ily.shape[i] == 1);
            ily_modified = ily_modified.reshape(oly);
            fwd_spec = &ily_modified;
        } else {
            fwd_spec = &ily;
        }
        m_mem_fwd_success = output(0)->set_fwd_in2out_readonly(
                input(0), SubTensorSpec::make_from_layout(*fwd_spec));
    }
}

void Reduce::add_input_layout_constraint() {
    if (!cg::is_static_var_shape(output(0))) {
        // output shape can not be inferred; require a contiguous input to be safe
        input(0)->add_layout_constraint_contiguous();
    } else {
        auto check = [this](const TensorLayout &ily) {
            auto &&mgr = owner_graph()->static_infer_manager();
            auto oshp = mgr.infer_shape(output(0));
            init_kern_sched_shape(ily, oshp);
            if (m_kern_scheduler->has_actual_computing())
                return ily.is_contiguous();
            return true;
        };
        input(0)->add_layout_constraint(check);
    }
}
void Reduce::scn_do_execute() {
    auto&& inp = input(0)->dev_tensor();
    auto&& out = output(0)->dev_tensor();
    auto&& ishp = input(0)->shape();
    auto&& oshp = output(0)->shape();
    const DeviceTensorND* out_ptr;
    Maybe<DeviceTensorND> canonized_storage;
    OutTensorShapeExtender extender(ishp, oshp);
    auto&& canonized_oshp = extender.get();
    if (canonized_oshp.ndim != out.shape().ndim) {
        auto&& canonized_out = canonized_storage.emplace(out);
        canonized_out.reset(
                canonized_out.storage(),
                canonized_out.layout().reshape(canonized_oshp));
        out_ptr = &canonized_out;
    } else {
        out_ptr = &out;
    }
    // shapes were already initialized, either while deducing the workspace, in
    // mem_plan_fwd_in2out_readonly, or while checking the input layout
    m_kern_scheduler->check_shapes(inp.shape(), out_ptr->shape());
    if (m_kern_scheduler->has_actual_computing()) {
        m_kern_scheduler->execute(static_cast<megdnn::Reduce*>(megdnn_opr()),
                inp, *out_ptr);
    } else {
        // no reduction needed, just forward
        if (m_mem_fwd_success) {
            mgb_assert(inp.raw_ptr() == out_ptr->raw_ptr() &&
                    out_ptr->layout().total_nr_elems() ==
                    inp.layout().total_nr_elems());
        } else {
            if (!out_ptr->shape().eq_shape(inp.shape())) {
                mgb_assert(out_ptr->shape().is_scalar() &&
                        inp.shape().total_nr_elems() == 1);
                out_ptr->sub(SubTensorSpec::make_from_layout(inp.layout()))
                        .copy_from_fixlayout(inp);
            } else {
                out_ptr->copy_from_fixlayout(inp);
            }
        }
    }
}
void Reduce::perform(
        Mode mode,
        DeviceTensorND &dest, DeviceTensorND &workspace,
        const DeviceTensorND &input,
        const DType &target_dtype,
        const TensorShape &target_shape,
        intl::UniqPtrWithCN<megdnn::Reduce> &opr, const Param::DataType data_type) {
    mgb_assert(!dest.storage().comp_node_valid() ||
            opr.comp_node() == dest.comp_node());
    KernScheduler ksched;
    OutTensorShapeExtender extender(input.shape(), target_shape);
    auto&& canonized_oshp = extender.get();
    ksched.init_shapes(opr.get(), opr.comp_node(), input.layout().dtype,
            mode, input.shape(), canonized_oshp, data_type);

    if (!ksched.has_actual_computing()) {
        mgb_assert(target_shape.total_nr_elems() ==
                input.layout().total_nr_elems());
        dest.copy_from(input);
        dest.reset(dest.storage(), {target_shape, dest.dtype()});
        return;
    }

    workspace.
        comp_node(opr.comp_node()).
        dtype(dtype::Byte());
    size_t workspace_size = ksched.workspace_size();

    DeviceTensorND input_contig_storage;
    const DeviceTensorND *input_contig = &input;
    if (!input.layout().is_contiguous()) {
        auto offset = get_aligned_power2(
                workspace_size, opr.comp_node().get_mem_addr_alignment());
        workspace_size = offset +
            input.dtype().size(input.shape().total_nr_elems());
        workspace.resize({workspace_size});
        input_contig_storage.
            reset(workspace.storage().sub(offset),
                    {input.shape(), input.dtype()}).
            copy_from(input);
        input_contig = &input_contig_storage;
    } else {
        workspace.resize({workspace_size});
    }

    opr.comp_node().activate();
    dest.comp_node(opr.comp_node()).dtype(target_dtype).resize(target_shape);
    ksched.update_ptr(*input_contig, dest, workspace);
    ksched.execute(opr.get(), *input_contig, dest);
}
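/*
 * perform() is also usable outside normal graph execution; the static value
 * inference above is one such caller. A minimal sketch with hypothetical names,
 * assuming `opr` is an intl::UniqPtrWithCN<megdnn::Reduce> already bound to a
 * comp node and `src` is a DeviceTensorND on that comp node:
 *
 *     DeviceTensorND dest, workspace;
 *     Reduce::perform(Reduce::Mode::SUM, dest, workspace, src,
 *                     src.dtype(), TensorShape{1}, opr,
 *                     Reduce::Param::DataType::DEFAULT);
 *
 * dest ends up on opr's comp node with the target dtype/shape, and workspace is
 * resized as needed (it also holds a contiguous copy of src when src is
 * non-contiguous).
 */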
void Reduce::create_megdnn_opr() {
    set_megdnn_opr(intl::get_megdnn_handle(comp_node())->
            create_operator<megdnn::Reduce>());
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Reduce) {
    for (size_t i = 1; i < opr.output().size(); ++ i)
        mgb_assert(!out_grad[i]);
    if (wrt_idx || opr.input(0)->dtype().category() != DTypeCategory::FLOAT)
        return InvalidGrad::make(opr, wrt_idx);
    SymbolVar og{out_grad[0]}, iv{opr.input(0)}, ov{opr.output(0)};
    constexpr auto cmv = Elemwise::Mode::COND_LEQ_MOV;
    using Mode = Reduce::Mode;
    SymbolVar grad = [&]() {
        switch (opr.param().mode) {
            case Mode::SUM:
                return Broadcast::make(og, GetVarShape::make(iv));
            case Mode::SUM_SQR:
                return (og * og.make_scalar_dt(2) * iv);
            case Mode::PRODUCT:
                return ((og * ov) / iv);
            case Mode::MIN:
                return Elemwise::make({iv, ov, og}, cmv);
            case Mode::MAX:
                return Elemwise::make({ov, iv, og}, cmv);
            case Mode::MEAN: {
                auto og_shape = opr::GetVarShape::make(og),
                     iv_shape = opr::GetVarShape::make(iv),
                     scale = div(
                             opr::reduce_prod(og_shape, og_shape.make_scalar(1)),
                             opr::reduce_prod(iv_shape, iv_shape.make_scalar(1)));
                return scale * Broadcast::make(og, GetVarShape::make(iv));
            }
            default:
                mgb_throw(MegBrainError, "bad reduce mode");
        }
    }();
    grad = TypeCvt::make(grad, iv.dtype());
    return grad.node();
}
#endif
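/*
 * Gradient formulas implemented above, with x the input, y the reduced output
 * and g the incoming output gradient (broadcast over the reduced axes where
 * needed):
 *
 *   SUM:      dx = broadcast(g)
 *   SUM_SQR:  dx = 2 * g * x                 (y = sum(x^2))
 *   PRODUCT:  dx = g * y / x
 *   MIN/MAX:  dx = g where x attains the selected extremum (via COND_LEQ_MOV),
 *             0 elsewhere
 *   MEAN:     dx = broadcast(g) * nr_elems(y) / nr_elems(x)
 */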
void Reduce::record_execute_deps(ExecDependencyArray& deps) {
    record_megdnn_opr(deps);
    m_kern_scheduler->record_execute_deps(deps);
}

/* =========================== PowC =========================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(PowC);

PowC::PowC(VarNode *i0, const Param &param, const OperatorNodeConfig &config)
        : Super(OperatorNodeBaseCtorParam{
                  i0->owner_graph(), config,
                  ssprintf("powc_%g", param.exp), {i0}}) {
    init_megdnn_opr(*this, param);
    add_input({i0});
    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
    intl::MegDNNOprInitPostCtor<PowC>::apply(*this);
}

SymbolVar PowC::make(SymbolVar x, const Param& param,
        const OperatorNodeConfig& config) {
    if (almost_equal(param.exp, 1.f)) {
        return x;
    }
    if (almost_equal(param.exp, 0.f)) {
        return x.make_scalar_dt(1).broadcast(x.symshape());
    }
    return x.insert_single_output_opr<PowC>(x.node(), param, config);
}
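// PowC::make short-circuits the trivial exponents: x^1 is returned as x itself
// and x^0 becomes a scalar 1 of x's dtype broadcast to x's shape; only other
// exponents actually insert a PowC operator node.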
void PowC::add_input_layout_constraint() {
    input(0)->add_layout_constraint_monotone();
}

void PowC::mem_plan_fwd_in2out_writable() {
    output(0)->set_fwd_in2out_writable(input(0));
}

void PowC::init_output_static_infer_desc() {
    Super::init_output_static_infer_desc();

    static StaticInferOpr<megdnn::PowC> static_infer_opr;
    using namespace cg::static_infer;
    auto infer_value = [this](DeviceTensorND& dest, const InpVal& inp) {
        auto infer_opr_lock = static_infer_opr.lock();
        auto&& infer_opr = infer_opr_lock();
        infer_opr->param() = this->param();
        auto&& ival = inp.val[0].value().as_megdnn();
        infer_opr->exec(ival, dest.resize(ival.layout).as_megdnn());
        return true;
    };
    owner_graph()->static_infer_manager().register_value_infer(
            output(0),
            {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value});
}

void PowC::scn_do_execute() {
    if (input(0)->dev_tensor().empty()) {
        mgb_assert(output(0)->dev_tensor().empty());
        return;
    }
    mgb_assert(!output(0)->dev_tensor().empty());
    Super::scn_do_execute();
}

PowC::NodeProp* PowC::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_dep_type_existing_var(input(0),
            NodeProp::DepType::VALUE_ALLOW_EMPTY);
    return ret;
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(PowC) {
    auto exp = opr.param().exp;
    return (exp * SymbolVar{out_grad[0]} *
            PowC::make(opr.input(0), exp - 1, opr.config()))
            .node();
}
#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
