GitOrigin-RevId: ddc8af79af
tags/v1.0.0-rc1
@@ -43,6 +43,7 @@ add_library(megbrain OBJECT EXCLUDE_FROM_ALL ${SOURCES})
 target_link_libraries(megbrain PUBLIC mgb_opr_param_defs)
 target_include_directories(megbrain
     PUBLIC $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>
+    PRIVATE ${PROJECT_SOURCE_DIR}/third_party/midout/src
     )
 foreach (INCPATH IN LISTS MGB_INC)
     target_include_directories(megbrain
@@ -15,6 +15,20 @@
 #include <deque>
+//! TODO: some megdnn::opr types have to be known here when midout.h is
+//! generated; fix this if there is a more graceful way.
| #include "megdnn/oprs.h" | |||||
| #include "megbrain/utils/hash_ct.h" | |||||
| #include "midout.h" | |||||
| MIDOUT_DECL(megbrain_chain) | |||||
| #define MIDOUT_B(tag) \ | |||||
| MIDOUT_BEGIN(megbrain_chain, midout_iv(MGB_HASH_STR(tag))) { | |||||
| #define MIDOUT_E \ | |||||
| } \ | |||||
| MIDOUT_END(); | |||||
| using namespace mgb; | using namespace mgb; | ||||
| using namespace gopt; | using namespace gopt; | ||||
| using namespace opr; | using namespace opr; | ||||
| @@ -132,6 +146,7 @@ const char* ExpandFusedArithPass::name() const { | |||||
| } | } | ||||
| void ExpandFusedArithPass::apply(OptState &opt) const { | void ExpandFusedArithPass::apply(OptState &opt) const { | ||||
| MIDOUT_B("ExpandFusedArithPass::apply") | |||||
| auto rewriter = opt.graph().make_rewriter(); | auto rewriter = opt.graph().make_rewriter(); | ||||
| auto on_opr = [&](OperatorNodeBase *opr) { | auto on_opr = [&](OperatorNodeBase *opr) { | ||||
| using Mode = Elemwise::Mode; | using Mode = Elemwise::Mode; | ||||
| @@ -172,6 +187,7 @@ void ExpandFusedArithPass::apply(OptState &opt) const { | |||||
| opt.graph().iter(on_opr); | opt.graph().iter(on_opr); | ||||
| rewriter.apply_inplace(); | rewriter.apply_inplace(); | ||||
| MIDOUT_E | |||||
| } | } | ||||
| /* ================ NormalizeArithChainPass ================ */ | /* ================ NormalizeArithChainPass ================ */ | ||||
| @@ -529,7 +545,9 @@ const char* NormalizeArithChainPass::name() const { | |||||
| } | } | ||||
| void NormalizeArithChainPass::apply(OptState &opt) const { | void NormalizeArithChainPass::apply(OptState &opt) const { | ||||
| MIDOUT_B("NormalizeArithChainPass::apply") | |||||
| Impl{opt}; | Impl{opt}; | ||||
| MIDOUT_E | |||||
| } | } | ||||
| /* ================ ReorderArithChainPass ================ */ | /* ================ ReorderArithChainPass ================ */ | ||||
| @@ -737,7 +755,9 @@ const char* ReorderArithChainPass::name() const { | |||||
| } | } | ||||
| void ReorderArithChainPass::apply(OptState &opt) const { | void ReorderArithChainPass::apply(OptState &opt) const { | ||||
| MIDOUT_B("ReorderArithChainPass::apply") | |||||
| Impl{*this, opt}; | Impl{*this, opt}; | ||||
| MIDOUT_E | |||||
| } | } | ||||
| /* ================ ArithFusePass ================ */ | /* ================ ArithFusePass ================ */ | ||||
| @@ -944,8 +964,9 @@ const char* ArithFusePass::name() const { | |||||
| } | } | ||||
| void ArithFusePass::apply(OptState &opt) const { | void ArithFusePass::apply(OptState &opt) const { | ||||
| MIDOUT_B("ArithFusePass::apply") | |||||
| Impl{opt}; | Impl{opt}; | ||||
| MIDOUT_E | |||||
| } | } | ||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
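The block above is the instrumentation pattern that repeats, with a different region name, in every source file touched by this diff: MIDOUT_DECL names a midout region for the translation unit, MGB_HASH_STR (from megbrain/utils/hash_ct.h) hashes the tag string at compile time, and midout_iv lifts that hash into a type so MIDOUT_BEGIN/MIDOUT_END can key a distinct trace point per tag. Midout appears to be used here to record which optimizer passes actually run so that unused ones can later be compiled out of size-sensitive builds. A minimal sketch of the pattern in isolation, assuming only that the midout macros behave as their names suggest (the file, region, and pass names below are placeholders, not part of the diff):

    // example.cpp -- hypothetical illustration, not part of the diff
    #include "megbrain/utils/hash_ct.h"  // MGB_HASH_STR: compile-time string hash
    #include "midout.h"                  // MIDOUT_DECL / MIDOUT_BEGIN / MIDOUT_END

    MIDOUT_DECL(example_region)          // one region family per source file

    #define MIDOUT_B(tag) \
        MIDOUT_BEGIN(example_region, midout_iv(MGB_HASH_STR(tag))) {
    #define MIDOUT_E \
        }            \
        MIDOUT_END();

    void example_apply() {
        MIDOUT_B("ExamplePass::apply")   // open the traced region
        // ... pass body ...
        MIDOUT_E                         // close the traced region
    }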
@@ -19,6 +19,16 @@
 #include <cmath>
+#include "megbrain/utils/hash_ct.h"
+#include "midout.h"
+MIDOUT_DECL(megbrain_inplace)
+#define MIDOUT_B(tag) \
+    MIDOUT_BEGIN(megbrain_inplace, midout_iv(MGB_HASH_STR(tag))) {
+#define MIDOUT_E \
+    } \
+    MIDOUT_END();
 using namespace mgb;
 using namespace opr;
 using namespace gopt;
@@ -150,8 +160,10 @@ bool gopt::has_inplace_basic_arith_opt(const cg::OperatorNodeBase& opr) {
 const inplace_optimize::OptimizerRegistry&
 inplace_optimize::optimizer_registry() {
+    MIDOUT_B("inplace_optimize::optimizer_registry")
     static OptimizerRegistry ret = make_optimizer_registry();
     return ret;
+    MIDOUT_E
 }
 inplace_optimize::OptimizerRegistry
@@ -13,6 +13,20 @@
 #include "megbrain/gopt/basic_arith.h"
 #include "megbrain/serialization/serializer.h"
+//! TODO: some megdnn::opr types have to be known here when midout.h is
+//! generated; fix this if there is a more graceful way.
| #include "megdnn/oprs.h" | |||||
| #include "megbrain/utils/hash_ct.h" | |||||
| #include "midout.h" | |||||
| MIDOUT_DECL(megbrain_trans) | |||||
| #define MIDOUT_B(tag) \ | |||||
| MIDOUT_BEGIN(megbrain_trans, midout_iv(MGB_HASH_STR(tag))) { | |||||
| #define MIDOUT_E \ | |||||
| } \ | |||||
| MIDOUT_END(); | |||||
| using namespace mgb; | using namespace mgb; | ||||
| using namespace gopt; | using namespace gopt; | ||||
| @@ -284,7 +298,9 @@ const char* ArithMulDistributePass::name() const { | |||||
| } | } | ||||
| void ArithMulDistributePass::apply(OptState &opt) const { | void ArithMulDistributePass::apply(OptState &opt) const { | ||||
| MIDOUT_B("ArithMulDistributePass::apply") | |||||
| Impl{*this, opt}; | Impl{*this, opt}; | ||||
| MIDOUT_E | |||||
| } | } | ||||
| /* ================ FinalArithTransformPass ================ */ | /* ================ FinalArithTransformPass ================ */ | ||||
| @@ -488,7 +504,9 @@ const char* FinalArithTransformPass::name() const { | |||||
| } | } | ||||
| void FinalArithTransformPass::apply(OptState &opt) const { | void FinalArithTransformPass::apply(OptState &opt) const { | ||||
| MIDOUT_B("FinalArithTransformPass::apply") | |||||
| Impl{*this, opt}; | Impl{*this, opt}; | ||||
| MIDOUT_E | |||||
| } | } | ||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
@@ -27,6 +27,7 @@
 #include "megbrain/opr/imgproc.h"
 #include "megbrain/opr/nn_int.h"
 #include "megbrain/opr/tensor_gen.h"
+#include "megbrain/utils/hash_ct.h"
 #include "megdnn/tensor_format.h"
@@ -36,6 +37,16 @@
 #include "megbrain/gopt/misc.h"
+#include "megbrain/utils/hash_ct.h"
+#include "midout.h"
+MIDOUT_DECL(megbrain_inference)
+#define MIDOUT_B(tag) \
+    MIDOUT_BEGIN(megbrain_inference, midout_iv(MGB_HASH_STR(tag))) {
+#define MIDOUT_E \
+    } \
+    MIDOUT_END();
 using namespace mgb;
 using namespace gopt;
@@ -430,7 +441,9 @@ ParamRedistributePass::Impl::Impl(OptState &state):
 }
 void ParamRedistributePass::apply(OptState &state) const {
+    MIDOUT_B("ParamRedistributePass::apply")
     Impl{state};
+    MIDOUT_E
 }
 /* ================ ParamFusePass ================ */
@@ -512,6 +525,7 @@ const char* ParamFusePass::name() const {
 }
 void ParamFusePass::apply(OptState &state) const {
+    MIDOUT_B("ParamFusePass::apply")
     auto rewriter = state.graph().make_rewriter();
     auto cg = state.graph().comp_graph();
@@ -613,6 +627,7 @@ void ParamFusePass::apply(OptState &state) const {
     state.graph().iter(replace_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 /* ================ One2OneOprReplacePass ================ */
@@ -621,6 +636,7 @@ const char* ConvertF32ToF16Pass::name() const {
 }
 void ConvertF32ToF16Pass::apply(OptState& state) const {
+    MIDOUT_B("ConvertF32ToF16Pass::apply")
     state.set_var_replace_check_flag(m_var_replace_check_flag);
     auto rewriter = state.graph().make_rewriter();
     VarNodeArray new_inp_cache;
@@ -674,6 +690,7 @@ void ConvertF32ToF16Pass::apply(OptState& state) const {
     auto opr = endpoints[0].node()->owner_opr();
     state.call_with_opr(opr, replace_output, OprPropertyFlag::NONE);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
@@ -940,6 +957,7 @@ std::unique_ptr<ConvertF32ToF16Pass> ConvertF32ToF16Pass::make(
 /* ================ ConvertFormatPass ================ */
 void ConvertFormatPass::apply(OptState& state) const {
+    MIDOUT_B("ConvertFormatPass::apply")
     state.set_var_replace_check_flag(m_var_replace_check_flag);
     auto rewriter = state.graph().make_rewriter();
     VarNodeArray new_inp_cache;
@@ -994,9 +1012,11 @@ void ConvertFormatPass::apply(OptState& state) const {
     };
     state.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
+    MIDOUT_B("ConvertFormatPass::make")
     auto filter_mode =
             [](const megdnn::param::Convolution::Sparse conv_mode,
                const VarNode* filter) -> megdnn::param::RelayoutFormat::Mode {
@@ -1551,6 +1571,7 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
     replace_func[opr::GroupLocalForward::typeinfo()] =
             relayout_first_inp_to_chw;
     return ret;
+    MIDOUT_E
 }
 /* ================ ConvertBatchNormPass ================ */
@@ -1559,6 +1580,7 @@ const char* ConvertBatchNormToElemwisePass::name() const {
 }
 void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
+    MIDOUT_B("ConvertBatchNormToElemwisePass::apply")
     auto rewriter = state.graph().make_rewriter();
     auto on_opr = [&](OperatorNodeBase* opr) {
         if (auto bn = try_cast_as_op<opr::BatchNorm>(opr)) {
@@ -1586,6 +1608,7 @@ void ConvertBatchNormToElemwisePass::apply(OptState& state) const {
     state.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 /* ================ FuseConvBiasNonlinPass ================ */
@@ -1594,6 +1617,7 @@ const char* FuseConvBiasNonlinPass::name() const {
 }
 void FuseConvBiasNonlinPass::apply(OptState& state) const {
+    MIDOUT_B("FuseConvBiasNonlinPass::apply")
     std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
     state.graph().iter([&m_deps](OperatorNodeBase* opr) {
         for (auto& inp : opr->input()) {
@@ -1843,6 +1867,7 @@ void FuseConvBiasNonlinPass::apply(OptState& state) const {
     state.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 /* ================ FuseConvBiasZPass ================ */
@@ -1851,6 +1876,7 @@ const char* FuseConvBiasZPass::name() const {
 }
 void FuseConvBiasZPass::apply(OptState& state) const {
+    MIDOUT_B("FuseConvBiasZPass::apply")
     UniqReaderCheck uniq_reader_check{state.graph()};
     auto rewriter = state.graph().make_rewriter();
@@ -1977,6 +2003,7 @@ void FuseConvBiasZPass::apply(OptState& state) const {
     state.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 /* ================ FuseDeconvCvtPass ================ */
@@ -1986,6 +2013,7 @@ const char* FuseDeconvCvtPass::name() const {
 void FuseDeconvCvtPass::apply(OptState& state) const {
+    MIDOUT_B("FuseDeconvCvtPass::apply")
     std::unordered_map<VarNode*, std::vector<OperatorNodeBase*>> m_deps;
     state.graph().iter([&m_deps](OperatorNodeBase* opr) {
         for (auto& inp : opr->input()) {
@@ -2036,6 +2064,7 @@ void FuseDeconvCvtPass::apply(OptState& state) const {
     state.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 /* ================ ParamMergePass ================ */
@@ -2044,10 +2073,12 @@ const char* ParamMergePass::name() const {
 }
 void ParamMergePass::apply(OptState& opt_state) const {
+    MIDOUT_B("ParamMergePass::apply")
     param_merge<opr::SharedDeviceTensor, opr::MultipleDeviceTensorHolder>(
             opt_state);
     param_merge<opr::SharedDeviceTensorWithFormat,
                 opr::MultipleDeviceTensorWithFormatHolder>(opt_state);
+    MIDOUT_E
 }
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
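Every apply() instrumented in this file follows the same rewriter skeleton, which is what makes the uniform MIDOUT_B/MIDOUT_E bracketing possible: build a rewriter over the graph held by the OptState, visit each operator with a callback, then commit the accumulated replacements. A condensed sketch of that skeleton, pieced together from the context lines above (an approximation of the pass structure, not the exact API):

    // Hypothetical pass body; identifiers mirror the calls visible in the diff.
    void SomePass::apply(OptState& state) const {
        MIDOUT_B("SomePass::apply")
        auto rewriter = state.graph().make_rewriter();   // records var replacements
        auto on_opr = [&](OperatorNodeBase* opr) {
            // inspect opr; if it matches the pattern, build a replacement
            // and register it with the rewriter
        };
        state.graph().iter(on_opr);                      // visit every operator
        rewriter.apply_inplace();                        // commit the rewrites
        MIDOUT_E
    }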
@@ -19,6 +19,16 @@
 #include "megbrain/serialization/opr_shallow_copy.h"
 #include "../../core/impl/graph/cg_impl.h"
+#include "megbrain/utils/hash_ct.h"
+#include "midout.h"
+MIDOUT_DECL(megbrain_misc)
+#define MIDOUT_B(tag) \
+    MIDOUT_BEGIN(megbrain_misc, midout_iv(MGB_HASH_STR(tag))) {
+#define MIDOUT_E \
+    } \
+    MIDOUT_END();
 using namespace mgb;
 using namespace gopt;
@@ -29,6 +39,7 @@ const char* RemoveNonComputingOprPass::name() const {
 }
 void RemoveNonComputingOprPass::apply(OptState& opt) const {
+    MIDOUT_B("RemoveNonComputingOprPass::apply")
     auto rewriter = opt.graph().make_rewriter();
     auto on_opr = [&](OperatorNodeBase* opr) {
         auto type = opr->dyn_typeinfo();
@@ -75,6 +86,7 @@ void RemoveNonComputingOprPass::apply(OptState& opt) const {
     opt.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 /* ================ ExpandVirtualGradPass ================ */
@@ -84,6 +96,7 @@ const char* ExpandVirtualGradPass::name() const {
 }
 void ExpandVirtualGradPass::apply(OptState& opt) const {
+    MIDOUT_B("ExpandVirtualGradPass::apply")
 #if MGB_ENABLE_GRAD
     opt.set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     auto rewriter = opt.graph().make_rewriter();
@@ -111,6 +124,7 @@ void ExpandVirtualGradPass::apply(OptState& opt) const {
 #else
     MGB_MARK_USED_VAR(opt);
 #endif
+    MIDOUT_E
 }
 /* ================= DelayBroadcastPass ================ */
@@ -144,6 +158,7 @@ void DelayBroadcastPass::apply(OptState& opt) const {
     // remove them from the chain, and add them back right after the endpoint.
     // TypeCvt's order may change, so disable the check.
+    MIDOUT_B("DelayBroadcastPass::apply")
     opt.set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     auto unique_reader_chk = UniqReaderCheck{opt.graph()};
@@ -325,6 +340,7 @@ void DelayBroadcastPass::apply(OptState& opt) const {
     opt.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 /* ======================= RecompTypeCvtPass ====================== */
@@ -334,6 +350,7 @@ const char* RecompTypeCvtPass::name() const {
 }
 void RecompTypeCvtPass::apply(OptState& opt) const {
+    MIDOUT_B("RecompTypeCvtPass::apply")
     auto rewriter = opt.graph().make_rewriter();
     auto allowed_typecvt = [](OperatorNodeBase* opr) -> OperatorNodeBase* {
@@ -399,6 +416,7 @@ void RecompTypeCvtPass::apply(OptState& opt) const {
     };
     opt.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 /* ======================= CombineAstypeAndReducePass ====================== */
@@ -408,6 +426,7 @@ const char* CombineAstypeAndReducePass::name() const {
 }
 void CombineAstypeAndReducePass::apply(OptState& opt) const {
+    MIDOUT_B("CombineAstypeAndReducePass::apply")
     auto rewriter = opt.graph().make_rewriter();
     using DataType = opr::Reduce::Param::DataType;
@@ -453,6 +472,7 @@ void CombineAstypeAndReducePass::apply(OptState& opt) const {
     opt.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 /* ================ CondExecConstPredicateFolding ================ */
@@ -462,6 +482,7 @@ const char* CondExecConstPredicateFolding::name() const {
 void CondExecConstPredicateFolding::apply(OptState& opt) const {
 #if MGB_ENABLE_COND_EXEC
+    MIDOUT_B("CondExecConstPredicateFolding::apply")
     if (!cg::ExecutionMask::have_alive_instance()) {
         return;
     }
@@ -605,6 +626,7 @@ void CondExecConstPredicateFolding::apply(OptState& opt) const {
     }
     rewriter.apply_inplace();
+    MIDOUT_E
 #endif  // MGB_ENABLE_COND_EXEC
 }
@@ -632,6 +654,7 @@ bool RemoveRedundantTypeCvtPass::should_remove(DType A, DType B) {
 }
 void RemoveRedundantTypeCvtPass::apply(OptState& opt) const {
+    MIDOUT_B("RemoveRedundantTypeCvtPass::apply")
     auto rewriter = opt.graph().make_rewriter();
     auto on_opr = [&](OperatorNodeBase* opr) {
@@ -656,6 +679,7 @@ void RemoveRedundantTypeCvtPass::apply(OptState& opt) const {
     opt.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 #if MGB_ENABLE_OPR_MM
@@ -668,6 +692,7 @@ const char* PackAllReduceScanPass::name() const {
 }
 void PackAllReduceScanPass::apply(OptState& opt) const {
+    MIDOUT_B("PackAllReduceScanPass::apply")
     auto comp_graph = opt.graph().comp_graph();
     if (comp_graph->options().allreduce_pack_max_size == 0) return;
     auto cb_scan = [this] (OperatorNodeBase* opr) {
@@ -682,6 +707,7 @@ void PackAllReduceScanPass::apply(OptState& opt) const {
         }
     };
     opt.graph().iter(cb_scan);
+    MIDOUT_E
 }
 bool PackAllReduceScanPass::check_pattern(OperatorNodeBase* opr) {
@@ -856,6 +882,7 @@ void PackAllReduceReplacePass::insert_packed_oprs(
 }
 void PackAllReduceReplacePass::apply(OptState& opt) const {
+    MIDOUT_B("PackAllReduceReplacePass::apply")
     // get graph options
     auto comp_graph = opt.graph().comp_graph();
     size_t max_size = comp_graph->options().allreduce_pack_max_size * 1024 * 1024;
@@ -917,6 +944,7 @@ void PackAllReduceReplacePass::apply(OptState& opt) const {
     };
     opt.graph().iter(cb_replace);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 #else
@@ -36,6 +36,16 @@
 #endif
 #include "megbrain/gopt/misc.h"
+#include "megbrain/utils/hash_ct.h"
+#include "midout.h"
+MIDOUT_DECL(megbrain_tensor_reformat)
+#define MIDOUT_B(tag) \
+    MIDOUT_BEGIN(megbrain_tensor_reformat, midout_iv(MGB_HASH_STR(tag))) {
+#define MIDOUT_E \
+    } \
+    MIDOUT_END();
 using namespace mgb;
 using namespace gopt;
@@ -755,8 +765,10 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
 }
 void TensorReformatPass::apply(OptState& opt) const {
+    MIDOUT_B("TensorReformatPass::apply")
     insert_pass(opt);
     translate_pass(opt);
+    MIDOUT_E
 }
 /* ================ EnableTensorCorePass =============== */
@@ -773,6 +785,7 @@ VarNode* EnableTensorCorePass::on_graph_endpoint_var(VarNode* new_var,
 std::unique_ptr<EnableTensorCorePass>
 EnableTensorCorePass::make_tensorcore_converter() {
+    MIDOUT_B("EnableTensorCorePass::make")
     // replace rule for conv bias opr
     auto replace_conv_bias_opr = [](OperatorNodeBase* opr,
                                     const VarNodeArray& new_inp) {
@@ -1111,6 +1124,7 @@ EnableTensorCorePass::make_tensorcore_converter() {
     replace_func[opr::GetVarShape::typeinfo()] = replace_inps_to_nchw4;
     replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4;
     return ret;
+    MIDOUT_E
 }
 /* ================ EnableCHWN4Pass =============== */
@@ -1125,6 +1139,7 @@ VarNode* EnableCHWN4Pass::on_graph_endpoint_var(VarNode* new_var,
 }
 std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
+    MIDOUT_B("EnableCHWN4Pass::make")
     auto ret = std::make_unique<EnableCHWN4Pass>();
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     auto&& replace_func = ret->m_opr_replace_func;
@@ -1381,6 +1396,7 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
     replace_func[opr::Dimshuffle::typeinfo()] = replace_inps_to_nchw4;
     replace_func[opr::BatchConvBias::typeinfo()] = replace_inps_to_nchw4;
     return ret;
+    MIDOUT_E
 }
 /* ================ EnableNCHW4Pass ================ */
@@ -1395,6 +1411,7 @@ VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var,
 }
 std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
+    MIDOUT_B("EnableNCHW4Pass::make")
     auto ret = std::make_unique<EnableNCHW4Pass>();
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     using RelayoutMode = RelayoutPlaceholder::LayoutType;
@@ -1772,6 +1789,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
     replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
     replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
     return ret;
+    MIDOUT_E
 }
 /* ================ EnableNchwxxPass =============== */
@@ -2140,6 +2158,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){
 std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
         size_t pack_c_size) {
+    MIDOUT_B("EnableNchwxxPass::make")
     auto ret = std::make_unique<EnableNchwxxPass>(pack_c_size);
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     std::string convter_pass_name = "conv_format_nchw88";
@@ -2149,6 +2168,7 @@ std::unique_ptr<EnableNchwxxPass> EnableNchwxxPass::make_nchwxx_converter(
     ret->fill_opr_convert_fun(pack_c_size);
     ret->set_name(convter_pass_name);
     return ret;
+    MIDOUT_E
 }
 /* ================ EnableNchw44DotPass =============== */
@@ -2164,6 +2184,7 @@ VarNode* EnableNchw44DotPass::on_graph_endpoint_var(VarNode* new_var,
 std::unique_ptr<EnableNchw44DotPass>
 EnableNchw44DotPass::make_nchw44_dot_converter() {
+    MIDOUT_B("EnableNchw44DotPass::make")
     auto ret = std::make_unique<EnableNchw44DotPass>();
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
     //! First is whether the conv can trans to nchwxx, second is the filter
@@ -2384,6 +2405,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
     replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
     replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
     return ret;
+    MIDOUT_E
 }
 /* ==================== ShuffleShuffleRemovePass ================= */
@@ -2961,9 +2983,11 @@ const char* ShuffleShuffleRemovePass::name() const {
 }
 void ShuffleShuffleRemovePass::apply(OptState& opt) const {
+    MIDOUT_B("ShuffleShuffleRemovePass::apply")
     opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_SHAPE |
                                    VarReplaceCheckFlag::CHECK_DTYPE);
     Impl{opt};
+    MIDOUT_E
 }
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -14,6 +14,16 @@
 #include "megbrain/opr/dnn/convolution.h"
 #include "megbrain/opr/tensor_manip.h"
+#include "megbrain/utils/hash_ct.h"
+#include "midout.h"
+MIDOUT_DECL(megbrain_weight_preprocess)
+#define MIDOUT_B(tag) \
+    MIDOUT_BEGIN(megbrain_weight_preprocess, midout_iv(MGB_HASH_STR(tag))) {
+#define MIDOUT_E \
+    } \
+    MIDOUT_END();
 using namespace mgb;
 using namespace gopt;
 using namespace cg;
@@ -23,6 +33,7 @@ const char* WinogradTransformReplacePass::name() const {
 }
 void WinogradTransformReplacePass::apply(OptState& opt) const {
+    MIDOUT_B("WinogradTransformReplacePass::apply")
     auto rewriter = opt.graph().make_rewriter();
     ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
     opt.graph().iter([&cvprop](OperatorNodeBase *opr) {
@@ -174,6 +185,7 @@ void WinogradTransformReplacePass::apply(OptState& opt) const {
     opt.graph().iter(on_opr);
     rewriter.apply_inplace();
+    MIDOUT_E
 }
 /**
@@ -855,10 +855,12 @@ VarNode* CollectiveComm::grad(VarNode* out_grad) const {
     return ModeTrait::from_mode(m_param.mode).grad(out_grad, this);
 }
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CollectiveComm) {
     mgb_assert(out_grad.size() == 1, "CollectiveComm should only have one grad");
     return opr.grad(out_grad[0]);
 }
+#endif
 /* ===================== shallow copy ===================== */
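A side note on the guards introduced here and in the files that follow: the pre-existing code earlier in this diff tests the flag with `#if MGB_ENABLE_GRAD` (see ExpandVirtualGradPass), which suggests the macro is always defined, to either 0 or 1. If that is so, `#ifdef MGB_ENABLE_GRAD` is satisfied even when the value is 0, and these guards would not actually compile the grad implementations out. A minimal sketch of the difference, using a hypothetical FLAG macro rather than the real one:

    #define FLAG 0      // defined, but with value 0

    #ifdef FLAG         // taken: #ifdef only checks whether FLAG is defined
    // still compiled
    #endif

    #if FLAG            // skipped: #if evaluates the value, and 0 is false
    // compiled out
    #endif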
@@ -109,6 +109,7 @@ cg::OperatorNodeBase::NodeProp* RemoteSend::do_make_node_prop() const {
     return prop;
 }
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(RemoteSend) {
     mgb_assert(opr.is_grad());
     return RemoteRecv::make(opr.key() + ":grad",
@@ -118,6 +119,7 @@ MGB_IMPL_OPR_GRAD(RemoteSend) {
                             opr.input(0)->shape(), opr.input(0)->dtype())
             .node();
 }
+#endif
 /* ===================== RemoteRecv ===================== */
@@ -552,6 +552,7 @@ void Elemwise::call_megdnn_opr_exec(
     opr->exec(inp, out);
 }
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Elemwise) {
     SymbolVar i[5];
     SymbolVar i0(opr.input(0)), i1, i2, out(opr.output(0)),
@@ -730,6 +731,7 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
         result = -result;
     return result.node();
 }
+#endif
 VarNode* Elemwise::sum_grad_list(VarNode *wrt, VarNodeArray &grads) {
     mgb_assert(!grads.empty());
@@ -814,6 +816,7 @@ TypeCvt::NodeProp* TypeCvt::do_make_node_prop() const {
     return ret;
 }
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(TypeCvt) {
     MGB_MARK_USED_VAR(wrt_idx);
     auto itype = opr.input(0)->dtype(), otype = opr.output(0)->dtype();
@@ -826,6 +829,7 @@ MGB_IMPL_OPR_GRAD(TypeCvt) {
     }
     return TypeCvt::make(out_grad[0], opr.input(0)->dtype()).node();
 }
+#endif
 void TypeCvt::mem_plan_fwd_in2out_writable() {
     if (input(0)->dtype().size() == output(0)->dtype().size() &&
@@ -963,10 +967,12 @@ void AddUpdate::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
 }
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(AddUpdate) {
     // actually valid, just not implemented
     return InvalidGrad::make(opr, wrt_idx);
 }
+#endif
 /* =========================== Reduce =========================== */
@@ -1698,6 +1704,7 @@ void Reduce::create_megdnn_opr() {
             create_operator<megdnn::Reduce>());
 }
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Reduce) {
     for (size_t i = 1; i < opr.output().size(); ++ i)
         mgb_assert(!out_grad[i]);
@@ -1733,7 +1740,7 @@ MGB_IMPL_OPR_GRAD(Reduce) {
         grad = TypeCvt::make(grad, iv.dtype());
     return grad.node();
 }
+#endif
 void Reduce::record_execute_deps(ExecDependencyArray& deps) {
     record_megdnn_opr(deps);
@@ -1783,11 +1790,13 @@ void PowC::init_output_static_infer_desc() {
             {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value});
 }
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(PowC) {
     auto exp = opr.param().exp;
     return (exp * SymbolVar{out_grad[0]} *
             PowC::make(opr.input(0), exp - 1, opr.config()))
             .node();
 }
+#endif
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -106,6 +106,7 @@ void MatrixMul::scn_do_execute() {
     MGB_FINALLY({ tparam = this->param(); });
 }
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(MatrixMul) {
     mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
                "only float data type supported for grad");
@@ -128,6 +129,7 @@ MGB_IMPL_OPR_GRAD(MatrixMul) {
     }
     return grad.node();
 }
+#endif
 /* ================= BatchedMatrixMul ================= */
@@ -224,6 +226,7 @@ void BatchedMatrixMul::scn_do_execute() {
     MGB_FINALLY({ tparam = this->param(); });
 }
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchedMatrixMul) {
     mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT,
                "only float data type supported for grad");
@@ -251,6 +254,7 @@ MGB_IMPL_OPR_GRAD(BatchedMatrixMul) {
     }
     return grad.node();
 }
+#endif
 /* ================= Dot ================= */
@@ -327,6 +331,7 @@ void Dot::add_input_layout_constraint() {
     input(1)->add_layout_constraint(check);
 }
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(Dot) {
     auto other_input = opr.input(wrt_idx == 0 ? 1 : 0);
     auto ishp0 = opr::GetVarShape::make(opr.input(0)),
@@ -336,6 +341,7 @@ MGB_IMPL_OPR_GRAD(Dot) {
             Broadcast::make(mul(out_grad[0], other_input), max_ishp),
             wrt_idx ? ishp1 : ishp0).node();
 }
+#endif
 SymbolVar Dot::make(SymbolVar opr0, SymbolVar opr1,
                     const OperatorNodeConfig &config) {
@@ -350,6 +356,8 @@ void Dot::record_execute_deps(ExecDependencyArray &deps) {
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(MatrixInverse);
 MEGDNN_OPR_INIT1(MatrixInverse, "matrix_inv")
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(MatrixInverse) {
     SymbolVar a = opr.output(0);
     // TODO: use unified MatrixMul interface when we have it
@@ -364,6 +372,7 @@ MGB_IMPL_OPR_GRAD(MatrixInverse) {
                        a_bnn);
     return da.reshape(a.symshape()).node();
 }
+#endif
 /* ================= SVD ================= */
@@ -386,6 +395,7 @@ SVD::SVD(VarNode* src, const Param& param, const OperatorNodeConfig& config) :
     }
 }
+#ifdef MGB_ENABLE_GRAD
 namespace {
 /*!
@@ -477,7 +487,9 @@ OP(*, {}, {})
 #undef OP
 } // anonymous namespace
+#endif
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(SVD) {
     /**
      * The formula is copied from
@@ -555,6 +567,7 @@ MGB_IMPL_OPR_GRAD(SVD) {
                            I_n - matmul(v, v, param01)));
     return ret.reshape(a.symshape()).node();
 }
+#endif
 SymbolVarArray SVD::make(const SymbolVar& src, const Param& param,
                          const OperatorNodeConfig& config) {
@@ -818,6 +818,7 @@ SymbolVar CondExecMark::mark_if_need(SymbolVar maybe_ppv, SymbolVar input,
     return input;
 }
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CondExecMark) {
     if (wrt_idx == opr.input().size() - 1 || !out_grad.at(wrt_idx)) {
         return nullptr;
@@ -841,6 +842,7 @@ MGB_IMPL_OPR_GRAD(CondExecMark) {
                     {1, grad_mode}, OperatorNodeConfig{})
             ->output(0);
 }
+#endif
 /* ============================= CondExecMerge ============================= */
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(CondExecMerge);
@@ -1225,6 +1227,7 @@ CondExecMerge::NodeProp* CondExecMerge::do_make_node_prop() const {
     return ret;
 }
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CondExecMerge) {
     using Mode = CondExecMerge::Param::Mode;
     if (opr.param().mode == Mode::SUM_COND_OUT &&
@@ -1259,6 +1262,7 @@ MGB_IMPL_OPR_GRAD(CondExecMerge) {
                        OperatorNodeConfig{og->comp_node()})
             ->output(0);
 }
+#endif
 void CondExecMerge::modify_grad_sum_list(VarNode* wrt, VarNodeArray& grads) {
     if (!ExecutionMask::have_alive_instance()) {
@@ -230,6 +230,7 @@ void BatchNormForward::mem_plan_fwd_in2out_writable() {
     }
 }
+#ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(BatchNormForward) {
     mgb_assert(wrt_idx < 5);
     if (wrt_idx < 3) {
@@ -242,6 +243,7 @@ MGB_IMPL_OPR_GRAD(BatchNormForward) {
         return nullptr;
     }
 }
+#endif
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(BatchNormBackward);
@@ -18,6 +18,19 @@
 #include "megdnn/oprs/utils.h"
+//! TODO: some megdnn::opr types have to be known here when midout.h is
+//! generated; fix this if there is a more graceful way.
| #include "megdnn/oprs.h" | |||||
| #include "midout.h" | |||||
| MIDOUT_DECL(megbrain_opr_convolution) | |||||
| #define MIDOUT_B(...) \ | |||||
| MIDOUT_BEGIN(megbrain_opr_convolution, __VA_ARGS__) { | |||||
| #define MIDOUT_E \ | |||||
| } \ | |||||
| MIDOUT_END(); | |||||
| #include "../internal/megdnn_opr_wrapper.inl" | #include "../internal/megdnn_opr_wrapper.inl" | ||||
| #include <array> | #include <array> | ||||
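Unlike the per-pass files above, which hard-code a single string tag, this file's MIDOUT_B takes `__VA_ARGS__`, so the call sites below can key a region by the operator type as well as a hashed name: MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("TimedProfiler::prof_impl"))) gives a separate region for every Opr the template is instantiated with. A sketch of what such a call site looks like in isolation (the function and its names are placeholders, not from this diff):

    // Hypothetical template instrumented with the variadic MIDOUT_B defined above.
    template <typename Opr>
    void profile_entry() {
        MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile_entry")))
        // ... body traced separately for each Opr instantiation ...
        MIDOUT_E
    }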
@@ -230,6 +243,7 @@ class TimedProfiler {
     static constexpr int arity_in = OprArityTrait<Opr>::arity_in;
     static constexpr int arity_out = OprArityTrait<Opr>::arity_out;
     static constexpr int arity = OprArityTrait<Opr>::arity;
     using ConvTensorShapes = std::array<TensorShape, arity>;
 public:
@@ -295,6 +309,7 @@ double TimedProfiler<Opr>::init_timeout_setting() {
 template <typename Opr>
 typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
         const TParam& raw_param) {
+    MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("TimedProfiler::prof_impl")))
     auto&& param = raw_param.as_single_pod<Param>();
     CompNode cn = CompNode::load(param.comp_node_loc, param.comp_node_loc);
     auto megdnn_opr = intl::create_megdnn_opr<Opr>(cn);
@@ -401,14 +416,17 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl(
     mgb_assert(ev_start->finished());
     return TResult::from_pod(Result{ev_start->elapsed_time_until(*ev_end)});
+    MIDOUT_E
 };
 template <typename Opr>
 void TimedProfiler<Opr>::prof_init_device(const TParam& raw_param) {
+    MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("TimedProfiler::prof_init_device")))
     auto&& param = raw_param.as_single_pod<Param>();
     CompNode cn = CompNode::load(param.comp_node_loc, param.comp_node_loc);
     // wait for cuda init, so its time does not get accounted in timeout
     cn.sync();
+    MIDOUT_E
 }
 /* =================== AlgoChooser =================== */
@@ -426,6 +444,7 @@ class AlgoChooser {
     static constexpr int arity_in = OprArityTrait<Opr>::arity_in;
     static constexpr int arity_out = OprArityTrait<Opr>::arity_out;
     static constexpr int arity = OprArityTrait<Opr>::arity;
     using ImplAlgo = typename Opr::Algorithm*;
     using MGBOpr = typename MegDNNOpr2MGBOpr<Opr>::MGBOpr;
     using ConvTensorLayouts = std::array<TensorLayout, arity>;
@@ -473,8 +492,8 @@ class AlgoChooser {
     //! put first
     std::vector<ImplAlgo> get_all_candidates() const {
         auto heu = choose_by_heuristic();
-        auto&& ret = OprArityTrait<Opr>::get_all_algorithms(
-                m_megdnn_opr, m_layouts);
+        auto&& ret = OprArityTrait<Opr>::get_all_algorithms(m_megdnn_opr,
+                                                            m_layouts);
         bool found = false;
         for (size_t i = 0; i < ret.size(); ++i) {
             if (ret[i] == heu) {
@@ -491,7 +510,7 @@ class AlgoChooser {
     //! get candidate algos with workspace limit.
     std::vector<ImplAlgo> get_all_candidates_with_workspace_limit() const {
-        auto && all_algos = get_all_candidates();
+        auto&& all_algos = get_all_candidates();
         auto opr = m_mgb_opr;
         auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
                 opr->owner_graph(), opr->comp_node(),
@@ -633,16 +652,16 @@ AlgoChooserProfileCache::Result AlgoChooser<Opr>::get_profile_result(
                     algo->name(), str_on_inp_shape.c_str());
             timer.reset();
             MGB_TRY { cur_rst = ctx.profile_single_algo(algo, cur_timeout); }
-            MGB_CATCH(std::exception & exc,
-            {
-                mgb_log_warn("caught exception during %s: %s",
-                        msg.c_str(), exc.what());
-                continue;
-            })
+            MGB_CATCH(std::exception & exc, {
+                mgb_log_warn("caught exception during %s: %s", msg.c_str(),
+                             exc.what());
+                continue;
+            })
             MGB_CATCH(..., {
                 mgb_log_warn("caught exception during %s", msg.c_str());
                 continue;
-            }) if (!cur_rst.valid()) {
+            })
+            if (!cur_rst.valid()) {
                 mgb_log_warn("timeout when %s; timeout setting: %.3fsec",
                              msg.c_str(), cur_timeout);
                 continue;
@@ -680,6 +699,7 @@ void AlgoChooser<megdnn::ConvBias>::get_origin_param_and_layouts(
 template <typename Opr>
 typename AlgoChooser<Opr>::ImplAlgo AlgoChooser<Opr>::choose_by_profile(
         ExeContext& ctx, bool require_reproducible, bool enable_update) {
+    MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile")))
     auto opr = ctx.mgb_opr();
     if (opr->owner_graph()->options().no_profiling_on_shape_change) {
         auto algo = ctx.megdnn_opr()->execution_policy().algorithm;
@@ -720,6 +740,7 @@ typename AlgoChooser<Opr>::ImplAlgo AlgoChooser<Opr>::choose_by_profile(
                     opr->owner_graph(), opr->comp_node(),
                     opr->execution_policy().workspace_limit));
     mgb_trap();
+    MIDOUT_E
 }
 template <>
@@ -748,7 +769,7 @@ void AlgoChooser<megdnn::ConvBias>::ExeContext::
     if (m_layouts[1].dtype.enumv() == DTypeEnum::QuantizedS8 &&
         param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW44) {
         if (winograd_preprocess_opr->param().format ==
-                megdnn::param::MatrixMul::Format::MK4){
+                megdnn::param::MatrixMul::Format::MK4) {
             winograd_preprocess_opr->param().compute_mode =
                     ConvBias::Param::ComputeMode::FLOAT32;
             param.opr_param.compute_mode =
@@ -941,6 +962,7 @@ void ConvolutionForward::init_output_dtype() {
| output(0)->dtype(output_dtype); | output(0)->dtype(output_dtype); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(ConvolutionForward) { | MGB_IMPL_OPR_GRAD(ConvolutionForward) { | ||||
| mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, | mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, | ||||
| "only float data type supported for grad"); | "only float data type supported for grad"); | ||||
| @@ -960,6 +982,7 @@ MGB_IMPL_OPR_GRAD(ConvolutionForward) { | |||||
| return grad.node(); | return grad.node(); | ||||
| } | } | ||||
| } | } | ||||
| #endif | |||||
| size_t ConvolutionForward::get_workspace_size_bytes( | size_t ConvolutionForward::get_workspace_size_bytes( | ||||
| const TensorShapeArray& input_shapes, | const TensorShapeArray& input_shapes, | ||||
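These hunks fence every MGB_IMPL_OPR_GRAD body behind #ifdef MGB_ENABLE_GRAD so the gradient rules can drop out of inference-only builds. One hedged caveat: the #else branch preserved as context in utility.cpp further down appears to pair with a value-style test (#if MGB_ENABLE_GRAD), and if that flag is defined to 0 rather than left undefined in slim builds, an #ifdef test never actually disables these blocks. The snippet below demonstrates the difference; FEATURE_GRAD is an illustrative name, not a megbrain macro.

    // Why the guard style matters when a feature flag is always #define'd to 0 or 1.
    #include <cstdio>

    #define FEATURE_GRAD 0   // "disabled", but still defined

    int main() {
    #ifdef FEATURE_GRAD
        std::puts("#ifdef branch taken: the macro is defined, even though it is 0");
    #endif

    #if FEATURE_GRAD
        std::puts("#if branch taken");
    #else
        std::puts("#if branch skipped: the value test sees 0");
    #endif
        return 0;
    }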
| @@ -1086,6 +1109,7 @@ void ConvolutionBackwardData::scn_do_execute() { | |||||
| intl::get_megdnn_workspace_from_var(output(1))); | intl::get_megdnn_workspace_from_var(output(1))); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) { | MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) { | ||||
| mgb_assert(!out_grad[1]); | mgb_assert(!out_grad[1]); | ||||
| if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
| @@ -1101,6 +1125,7 @@ MGB_IMPL_OPR_GRAD(ConvolutionBackwardData) { | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| #endif | |||||
| /* ==================== ConvolutionBackwardFilter ==================== */ | /* ==================== ConvolutionBackwardFilter ==================== */ | ||||
| IMPL_CONV(ConvolutionBackwardFilter, "conv_bwd_filter"); | IMPL_CONV(ConvolutionBackwardFilter, "conv_bwd_filter"); | ||||
| @@ -1138,6 +1163,7 @@ size_t ConvolutionBackwardFilter::get_workspace_size_bytes( | |||||
| megdnn_opr(), this); | megdnn_opr(), this); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) { | MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) { | ||||
| mgb_assert(!out_grad[1]); | mgb_assert(!out_grad[1]); | ||||
| if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
| @@ -1153,6 +1179,7 @@ MGB_IMPL_OPR_GRAD(ConvolutionBackwardFilter) { | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| #endif | |||||
| /* ==================== Convolution3DForward ==================== */ | /* ==================== Convolution3DForward ==================== */ | ||||
| IMPL_CONV(Convolution3DForward, "conv3d_fwd"); | IMPL_CONV(Convolution3DForward, "conv3d_fwd"); | ||||
| @@ -1192,6 +1219,7 @@ void Convolution3DForward::init_output_dtype() { | |||||
| } | } | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Convolution3DForward) { | MGB_IMPL_OPR_GRAD(Convolution3DForward) { | ||||
| mgb_assert(opr.param().data_type == | mgb_assert(opr.param().data_type == | ||||
| Convolution3DForward::Param::DataType::FLOAT, | Convolution3DForward::Param::DataType::FLOAT, | ||||
| @@ -1212,6 +1240,7 @@ MGB_IMPL_OPR_GRAD(Convolution3DForward) { | |||||
| return grad.node(); | return grad.node(); | ||||
| } | } | ||||
| } | } | ||||
| #endif | |||||
| size_t Convolution3DForward::get_workspace_size_bytes( | size_t Convolution3DForward::get_workspace_size_bytes( | ||||
| const TensorShapeArray& input_shapes, | const TensorShapeArray& input_shapes, | ||||
| @@ -1285,6 +1314,7 @@ void Convolution3DBackwardData::scn_do_execute() { | |||||
| intl::get_megdnn_workspace_from_var(output(1))); | intl::get_megdnn_workspace_from_var(output(1))); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) { | MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) { | ||||
| mgb_assert(!out_grad[1]); | mgb_assert(!out_grad[1]); | ||||
| if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
| @@ -1300,6 +1330,7 @@ MGB_IMPL_OPR_GRAD(Convolution3DBackwardData) { | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| #endif | |||||
| /* ==================== Convolution3DBackwardFilter ==================== */ | /* ==================== Convolution3DBackwardFilter ==================== */ | ||||
| IMPL_CONV(Convolution3DBackwardFilter, "conv3d_bwd_filter"); | IMPL_CONV(Convolution3DBackwardFilter, "conv3d_bwd_filter"); | ||||
| @@ -1658,6 +1689,7 @@ size_t LocalShareForward::get_workspace_size_bytes( | |||||
| megdnn_opr(), this); | megdnn_opr(), this); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(LocalShareForward) { | MGB_IMPL_OPR_GRAD(LocalShareForward) { | ||||
| mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, | mgb_assert(opr.input(0)->dtype().category() == DTypeCategory::FLOAT, | ||||
| "only float data type supported for grad"); | "only float data type supported for grad"); | ||||
| @@ -1677,6 +1709,7 @@ MGB_IMPL_OPR_GRAD(LocalShareForward) { | |||||
| return grad.node(); | return grad.node(); | ||||
| } | } | ||||
| } | } | ||||
| #endif | |||||
| /* ===================== LocalShareBackwardData ==================== */ | /* ===================== LocalShareBackwardData ==================== */ | ||||
| @@ -1737,6 +1770,7 @@ void LocalShareBackwardData::scn_do_execute() { | |||||
| intl::get_megdnn_workspace_from_var(output(1))); | intl::get_megdnn_workspace_from_var(output(1))); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(LocalShareBackwardData) { | MGB_IMPL_OPR_GRAD(LocalShareBackwardData) { | ||||
| mgb_assert(!out_grad[1]); | mgb_assert(!out_grad[1]); | ||||
| if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
| @@ -1752,6 +1786,7 @@ MGB_IMPL_OPR_GRAD(LocalShareBackwardData) { | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| #endif | |||||
| /* ==================== LocalShareBackwardFilter ==================== */ | /* ==================== LocalShareBackwardFilter ==================== */ | ||||
| @@ -1792,6 +1827,7 @@ size_t LocalShareBackwardFilter::get_workspace_size_bytes( | |||||
| megdnn_opr(), this); | megdnn_opr(), this); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) { | MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) { | ||||
| mgb_assert(!out_grad[1]); | mgb_assert(!out_grad[1]); | ||||
| if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
| @@ -1805,6 +1841,7 @@ MGB_IMPL_OPR_GRAD(LocalShareBackwardFilter) { | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| #endif | |||||
| /* ===================== DeformableConvForward ==================== */ | /* ===================== DeformableConvForward ==================== */ | ||||
| @@ -1869,6 +1906,7 @@ size_t DeformableConvForward::get_workspace_size_bytes( | |||||
| megdnn_opr(), this); | megdnn_opr(), this); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(DeformableConvForward) { | MGB_IMPL_OPR_GRAD(DeformableConvForward) { | ||||
| mgb_assert(opr.input(0)->dtype() == dtype::Float32(), | mgb_assert(opr.input(0)->dtype() == dtype::Float32(), | ||||
| "only float data type supported for grad"); | "only float data type supported for grad"); | ||||
| @@ -1888,6 +1926,7 @@ MGB_IMPL_OPR_GRAD(DeformableConvForward) { | |||||
| SymbolVarArray grads = {grad_arr[0], filter_grad, grad_arr[1], grad_arr[2]}; | SymbolVarArray grads = {grad_arr[0], filter_grad, grad_arr[1], grad_arr[2]}; | ||||
| return grads[wrt_idx].node(); | return grads[wrt_idx].node(); | ||||
| } | } | ||||
| #endif | |||||
| /* ==================== DeformableConvBackwardData ==================== */ | /* ==================== DeformableConvBackwardData ==================== */ | ||||
| @@ -2265,4 +2304,4 @@ void BatchConvBiasForward::init_output_format() { | |||||
| #undef IMPL_CONV | #undef IMPL_CONV | ||||
| #undef MGB_FOREACH_FASTRUN_OPR | #undef MGB_FOREACH_FASTRUN_OPR | ||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -20,11 +20,13 @@ using namespace opr; | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsForward); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsForward); | ||||
| MEGDNN_OPR_INIT1(Images2NeibsForward, "images2neibs") | MEGDNN_OPR_INIT1(Images2NeibsForward, "images2neibs") | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Images2NeibsForward) { | MGB_IMPL_OPR_GRAD(Images2NeibsForward) { | ||||
| mgb_assert(wrt_idx == 0 && out_grad.size() == 2 && !out_grad[1]); | mgb_assert(wrt_idx == 0 && out_grad.size() == 2 && !out_grad[1]); | ||||
| return Images2NeibsBackward::make( | return Images2NeibsBackward::make( | ||||
| out_grad[0], opr.input(0), opr.param()).node(); | out_grad[0], opr.input(0), opr.param()).node(); | ||||
| } | } | ||||
| #endif | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsBackward); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(Images2NeibsBackward); | ||||
| MEGDNN_OPR_INIT2(Images2NeibsBackward, "images2neibs_grad", 1, false); | MEGDNN_OPR_INIT2(Images2NeibsBackward, "images2neibs_grad", 1, false); | ||||
| @@ -20,10 +20,13 @@ using namespace opr; | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalForward); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalForward); | ||||
| MEGDNN_OPR_INIT2(LocalForward, "local") | MEGDNN_OPR_INIT2(LocalForward, "local") | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(LocalForward) { | MGB_IMPL_OPR_GRAD(LocalForward) { | ||||
| return intl::conv_grad<LocalBackwardData, LocalBackwardFilter>( | return intl::conv_grad<LocalBackwardData, LocalBackwardFilter>( | ||||
| opr, wrt_idx, out_grad); | opr, wrt_idx, out_grad); | ||||
| } | } | ||||
| #endif | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalBackwardData); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(LocalBackwardData); | ||||
| MEGDNN_OPR_INIT3(LocalBackwardData, "local_bwd_data", 2, false); | MEGDNN_OPR_INIT3(LocalBackwardData, "local_bwd_data", 2, false); | ||||
| @@ -34,10 +37,13 @@ MEGDNN_OPR_INIT3(LocalBackwardFilter, "local_bwd_filter", 2, false); | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalForward); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalForward); | ||||
| MEGDNN_OPR_INIT2(GroupLocalForward, "glocal") | MEGDNN_OPR_INIT2(GroupLocalForward, "glocal") | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(GroupLocalForward) { | MGB_IMPL_OPR_GRAD(GroupLocalForward) { | ||||
| return intl::conv_grad<GroupLocalBackwardData, GroupLocalBackwardFilter>( | return intl::conv_grad<GroupLocalBackwardData, GroupLocalBackwardFilter>( | ||||
| opr, wrt_idx, out_grad); | opr, wrt_idx, out_grad); | ||||
| } | } | ||||
| #endif | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalBackwardData); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupLocalBackwardData); | ||||
| MEGDNN_OPR_INIT3(GroupLocalBackwardData, "glocal_bwd_data", 2, false); | MEGDNN_OPR_INIT3(GroupLocalBackwardData, "glocal_bwd_data", 2, false); | ||||
| @@ -20,12 +20,14 @@ using namespace opr; | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNForward); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNForward); | ||||
| MEGDNN_OPR_INIT1(LRNForward, "lrn") | MEGDNN_OPR_INIT1(LRNForward, "lrn") | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(LRNForward) { | MGB_IMPL_OPR_GRAD(LRNForward) { | ||||
| mgb_assert(wrt_idx == 0); | mgb_assert(wrt_idx == 0); | ||||
| SymbolVar grad = LRNBackward::make( | SymbolVar grad = LRNBackward::make( | ||||
| opr.input(0), opr.output(0), out_grad[0], opr.param()); | opr.input(0), opr.output(0), out_grad[0], opr.param()); | ||||
| return grad.node(); | return grad.node(); | ||||
| } | } | ||||
| #endif | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNBackward); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(LRNBackward); | ||||
| MEGDNN_OPR_INIT3(LRNBackward, "lrn_bwd", 0, true); | MEGDNN_OPR_INIT3(LRNBackward, "lrn_bwd", 0, true); | ||||
| @@ -19,12 +19,14 @@ using namespace opr; | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingForward); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingForward); | ||||
| MEGDNN_OPR_INIT1(PoolingForward, "pooling") | MEGDNN_OPR_INIT1(PoolingForward, "pooling") | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(PoolingForward) { | MGB_IMPL_OPR_GRAD(PoolingForward) { | ||||
| mgb_assert(wrt_idx == 0); | mgb_assert(wrt_idx == 0); | ||||
| SymbolVar grad = PoolingBackward::make( | SymbolVar grad = PoolingBackward::make( | ||||
| opr.input(0), opr.output(0), out_grad[0], opr.param()); | opr.input(0), opr.output(0), out_grad[0], opr.param()); | ||||
| return grad.node(); | return grad.node(); | ||||
| } | } | ||||
| #endif | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingBackward); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(PoolingBackward); | ||||
| MEGDNN_OPR_INIT3(PoolingBackward, "pooling_bwd", 0, true); | MEGDNN_OPR_INIT3(PoolingBackward, "pooling_bwd", 0, true); | ||||
| @@ -40,6 +40,7 @@ SymbolVar ROIAlignForward::make(SymbolVar src, SymbolVar rois, | |||||
| src.node(), rois.node(), param, config); | src.node(), rois.node(), param, config); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(ROIAlignForward) { | MGB_IMPL_OPR_GRAD(ROIAlignForward) { | ||||
| if (out_grad[1]) { | if (out_grad[1]) { | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -55,6 +56,7 @@ MGB_IMPL_OPR_GRAD(ROIAlignForward) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| } | } | ||||
| #endif | |||||
| /* ==================== ROIAlignBackward ==================== */ | /* ==================== ROIAlignBackward ==================== */ | ||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(ROIAlignBackward); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(ROIAlignBackward); | ||||
| @@ -84,6 +84,7 @@ size_t ROIPoolingForward::get_workspace_size_bytes( | |||||
| input_shapes, output_shapes); | input_shapes, output_shapes); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(ROIPoolingForward) { | MGB_IMPL_OPR_GRAD(ROIPoolingForward) { | ||||
| if (out_grad[1] || wrt_idx == 2) { | if (out_grad[1] || wrt_idx == 2) { | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -98,6 +99,7 @@ MGB_IMPL_OPR_GRAD(ROIPoolingForward) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| } | } | ||||
| #endif | |||||
| void ROIPoolingForward::scn_do_execute() { | void ROIPoolingForward::scn_do_execute() { | ||||
| return intl::MegDNNOprMethInvoker<megdnn::ROIPoolingForward>:: | return intl::MegDNNOprMethInvoker<megdnn::ROIPoolingForward>:: | ||||
| @@ -146,6 +148,7 @@ SymbolVar DeformablePSROIPoolingForward::make( | |||||
| return all[0]; | return all[0]; | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(DeformablePSROIPooling) { | MGB_IMPL_OPR_GRAD(DeformablePSROIPooling) { | ||||
| mgb_assert(wrt_idx <= 2); // wrt_idx = 0 or 1 or 2 | mgb_assert(wrt_idx <= 2); // wrt_idx = 0 or 1 or 2 | ||||
| @@ -168,6 +171,7 @@ MGB_IMPL_OPR_GRAD(DeformablePSROIPooling) { | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| #endif | |||||
| /* ==================== DeformablePSROIPoolingBackward ==================== */ | /* ==================== DeformablePSROIPoolingBackward ==================== */ | ||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(DeformablePSROIPoolingBackward); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(DeformablePSROIPoolingBackward); | ||||
| @@ -127,6 +127,7 @@ void WarpPerspectiveForward::record_execute_deps(ExecDependencyArray& deps) { | |||||
| record_megdnn_opr(deps); | record_megdnn_opr(deps); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) { | MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) { | ||||
| mgb_assert(opr.input().size() == 3, | mgb_assert(opr.input().size() == 3, | ||||
| "backward with mat_idx is currently unsupported"); | "backward with mat_idx is currently unsupported"); | ||||
| @@ -145,6 +146,7 @@ MGB_IMPL_OPR_GRAD(WarpPerspectiveForward) { | |||||
| } else | } else | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| } | } | ||||
| #endif | |||||
| /* ====================== WarpPerspectiveBackwardData ====================== */ | /* ====================== WarpPerspectiveBackwardData ====================== */ | ||||
| @@ -234,6 +236,7 @@ void ResizeForward::record_execute_deps(ExecDependencyArray &deps) { | |||||
| record_megdnn_opr(deps); | record_megdnn_opr(deps); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(ResizeForward) { | MGB_IMPL_OPR_GRAD(ResizeForward) { | ||||
| mgb_assert(opr.input().size() == 2); | mgb_assert(opr.input().size() == 2); | ||||
| if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
| @@ -243,6 +246,7 @@ MGB_IMPL_OPR_GRAD(ResizeForward) { | |||||
| } else | } else | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| } | } | ||||
| #endif | |||||
| /* ====================== ResizeBackward ====================== */ | /* ====================== ResizeBackward ====================== */ | ||||
| @@ -83,6 +83,7 @@ void IndexingOneHot::init_output_dtype() { | |||||
| output(0)->dtype(input(0)->dtype()); | output(0)->dtype(input(0)->dtype()); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(IndexingOneHot) { | MGB_IMPL_OPR_GRAD(IndexingOneHot) { | ||||
| if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
| return IndexingSetOneHot::make( | return IndexingSetOneHot::make( | ||||
| @@ -91,6 +92,7 @@ MGB_IMPL_OPR_GRAD(IndexingOneHot) { | |||||
| } | } | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| } | } | ||||
| #endif | |||||
| /* ==================== IndexingSetOneHot ==================== */ | /* ==================== IndexingSetOneHot ==================== */ | ||||
| @@ -133,6 +135,7 @@ void IndexingSetOneHot::scn_do_execute() { | |||||
| intl::get_megdnn_workspace_from_var(output(1))); | intl::get_megdnn_workspace_from_var(output(1))); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(IndexingSetOneHot) { | MGB_IMPL_OPR_GRAD(IndexingSetOneHot) { | ||||
| SymbolVar index{opr.input(1)}, sub{opr.input(2)}, og{out_grad.at(0)}; | SymbolVar index{opr.input(1)}, sub{opr.input(2)}, og{out_grad.at(0)}; | ||||
| if (wrt_idx == 0) { | if (wrt_idx == 0) { | ||||
| @@ -144,6 +147,7 @@ MGB_IMPL_OPR_GRAD(IndexingSetOneHot) { | |||||
| } | } | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| } | } | ||||
| #endif | |||||
| size_t IndexingSetOneHot::get_workspace_size_bytes( | size_t IndexingSetOneHot::get_workspace_size_bytes( | ||||
| const TensorShapeArray &input_shapes, | const TensorShapeArray &input_shapes, | ||||
| @@ -165,6 +169,7 @@ void IndexingRemap::init_output_dtype() { | |||||
| output(0)->dtype(input(0)->dtype()); | output(0)->dtype(input(0)->dtype()); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(IndexingRemap) { | MGB_IMPL_OPR_GRAD(IndexingRemap) { | ||||
| if (wrt_idx == 1) | if (wrt_idx == 1) | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -172,6 +177,7 @@ MGB_IMPL_OPR_GRAD(IndexingRemap) { | |||||
| return IndexingRemapBackward::make( | return IndexingRemapBackward::make( | ||||
| out_grad[0], opr.input(1), opr.input(0), opr.param()).node(); | out_grad[0], opr.input(1), opr.input(0), opr.param()).node(); | ||||
| } | } | ||||
| #endif | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingRemapBackward); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingRemapBackward); | ||||
| MEGDNN_OPR_INIT3(IndexingRemapBackward, "indexing_remap_bwd", 2, false); | MEGDNN_OPR_INIT3(IndexingRemapBackward, "indexing_remap_bwd", 2, false); | ||||
| @@ -460,6 +466,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( | |||||
| MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( | MGB_IMPL_FANCY_INDEXING_OPR_MODIFY( | ||||
| IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false); | IndexingIncrMultiAxisVec, "indexing_incr_multi_axis_vec", false); | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) { | MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) { | ||||
| if (wrt_idx) | if (wrt_idx) | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -468,7 +475,9 @@ MGB_IMPL_OPR_GRAD(IndexingMultiAxisVec) { | |||||
| SymbolVar{opr.input(0)}.fill_retain_dtype(0), | SymbolVar{opr.input(0)}.fill_retain_dtype(0), | ||||
| out_grad.at(0), opr.index_desc()).node(); | out_grad.at(0), opr.index_desc()).node(); | ||||
| } | } | ||||
| #endif | |||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) { | MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) { | ||||
| if (wrt_idx >= 2) | if (wrt_idx >= 2) | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -479,7 +488,9 @@ MGB_IMPL_OPR_GRAD(IndexingSetMultiAxisVec) { | |||||
| } | } | ||||
| return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc()).node(); | return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc()).node(); | ||||
| } | } | ||||
| #endif | |||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) { | MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) { | ||||
| if (wrt_idx >= 2) | if (wrt_idx >= 2) | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -488,6 +499,7 @@ MGB_IMPL_OPR_GRAD(IndexingIncrMultiAxisVec) { | |||||
| } | } | ||||
| return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc()).node(); | return IndexingMultiAxisVec::make(out_grad.at(0), opr.index_desc()).node(); | ||||
| } | } | ||||
| #endif | |||||
| /* ============================= Mesh Indexing ============================ */ | /* ============================= Mesh Indexing ============================ */ | ||||
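All of the fancy-indexing grad rules above follow one pattern: the gradient of a gather is a scatter-add of the incoming gradient into a zero tensor shaped like the input (fill_retain_dtype(0) followed by the Incr variant of the same indexing op). A minimal 1-D sketch of that duality, with no megbrain types assumed:

    // 1-D sketch of the gather / scatter-add duality used by these grad rules.
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> x = {10.f, 20.f, 30.f, 40.f};
        std::vector<int> idx = {3, 1, 3};            // y = x[idx], note the repeated index

        // forward: gather
        std::vector<float> y;
        for (int i : idx) y.push_back(x[i]);

        // backward: scatter-add dL/dy into a zero tensor shaped like x
        std::vector<float> dy = {1.f, 1.f, 1.f};     // incoming gradient
        std::vector<float> dx(x.size(), 0.f);        // fill_retain_dtype(0) analogue
        for (std::size_t k = 0; k < idx.size(); ++k) dx[idx[k]] += dy[k];

        for (float v : dx) std::printf("%g ", v);    // prints: 0 1 0 2
        std::printf("\n");
        return 0;
    }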
| @@ -498,6 +510,7 @@ MGB_IMPL_FANCY_INDEXING_OPR_GET( | |||||
| BatchedMeshIndexing, "batched_mesh_indexing", false, | BatchedMeshIndexing, "batched_mesh_indexing", false, | ||||
| output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);); | output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);); | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(MeshIndexing) { | MGB_IMPL_OPR_GRAD(MeshIndexing) { | ||||
| if (wrt_idx != 0) { | if (wrt_idx != 0) { | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -507,6 +520,9 @@ MGB_IMPL_OPR_GRAD(MeshIndexing) { | |||||
| opr.index_desc()) | opr.index_desc()) | ||||
| .node(); | .node(); | ||||
| } | } | ||||
| #endif | |||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) { | MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) { | ||||
| if (wrt_idx != 0) { | if (wrt_idx != 0) { | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -516,11 +532,14 @@ MGB_IMPL_OPR_GRAD(BatchedMeshIndexing) { | |||||
| opr.index_desc()) | opr.index_desc()) | ||||
| .node(); | .node(); | ||||
| } | } | ||||
| #endif | |||||
| /* ========================= IncrMeshIndexing ========================= */ | /* ========================= IncrMeshIndexing ========================= */ | ||||
| MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrMeshIndexing, "incr_mesh_indexing", | MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(IncrMeshIndexing, "incr_mesh_indexing", | ||||
| false); | false); | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(IncrMeshIndexing) { | MGB_IMPL_OPR_GRAD(IncrMeshIndexing) { | ||||
| if (wrt_idx > 2) { | if (wrt_idx > 2) { | ||||
| return opr::InvalidGrad::make(opr, wrt_idx); | return opr::InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -530,9 +549,11 @@ MGB_IMPL_OPR_GRAD(IncrMeshIndexing) { | |||||
| } | } | ||||
| return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node(); | return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node(); | ||||
| } | } | ||||
| #endif | |||||
| MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedIncrMeshIndexing, | MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedIncrMeshIndexing, | ||||
| "batched_incr_mesh_indexing", false); | "batched_incr_mesh_indexing", false); | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) { | MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) { | ||||
| if (wrt_idx > 2) { | if (wrt_idx > 2) { | ||||
| return opr::InvalidGrad::make(opr, wrt_idx); | return opr::InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -542,10 +563,12 @@ MGB_IMPL_OPR_GRAD(BatchedIncrMeshIndexing) { | |||||
| } | } | ||||
| return BatchedMeshIndexing::make(out_grad.at(0), opr.index_desc()).node(); | return BatchedMeshIndexing::make(out_grad.at(0), opr.index_desc()).node(); | ||||
| } | } | ||||
| #endif | |||||
| /* ======================== SetMeshIndexing =========================== */ | /* ======================== SetMeshIndexing =========================== */ | ||||
| MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetMeshIndexing, "set_mesh_indexing", false); | MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(SetMeshIndexing, "set_mesh_indexing", false); | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(SetMeshIndexing) { | MGB_IMPL_OPR_GRAD(SetMeshIndexing) { | ||||
| if (wrt_idx >= 2) { | if (wrt_idx >= 2) { | ||||
| return opr::InvalidGrad::make(opr, wrt_idx); | return opr::InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -560,9 +583,11 @@ MGB_IMPL_OPR_GRAD(SetMeshIndexing) { | |||||
| return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node(); | return MeshIndexing::make(out_grad.at(0), opr.index_desc()).node(); | ||||
| } | } | ||||
| } | } | ||||
| #endif | |||||
| MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedSetMeshIndexing, | MGB_IMPL_FANCY_INDEXING_OPR_MODIFY(BatchedSetMeshIndexing, | ||||
| "batched_set_mesh_indexing", false); | "batched_set_mesh_indexing", false); | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) { | MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) { | ||||
| if (wrt_idx > 2) { | if (wrt_idx > 2) { | ||||
| return opr::InvalidGrad::make(opr, wrt_idx); | return opr::InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -578,5 +603,6 @@ MGB_IMPL_OPR_GRAD(BatchedSetMeshIndexing) { | |||||
| .node(); | .node(); | ||||
| } | } | ||||
| } | } | ||||
| #endif | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -764,11 +764,13 @@ Copy::NodeProp* Copy::do_make_node_prop() const { | |||||
| return rst; | return rst; | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Copy) { | MGB_IMPL_OPR_GRAD(Copy) { | ||||
| mgb_assert(wrt_idx == 0); | mgb_assert(wrt_idx == 0); | ||||
| return Copy::make(out_grad[0], | return Copy::make(out_grad[0], | ||||
| OperatorNodeConfig{}.follow_comp_node(opr.input(0))).node(); | OperatorNodeConfig{}.follow_comp_node(opr.input(0))).node(); | ||||
| } | } | ||||
| #endif | |||||
| void Copy::add_input_layout_constraint() { | void Copy::add_input_layout_constraint() { | ||||
| if (input(0)->comp_node() != output(0)->comp_node()) { | if (input(0)->comp_node() != output(0)->comp_node()) { | ||||
| @@ -268,9 +268,11 @@ VarNode* Loop::grad(Loop &opr, size_t wrt_idx, const VarNodeArray &out_grad) { | |||||
| return gopr->get_grad_var(wrt_idx); | return gopr->get_grad_var(wrt_idx); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Loop) { | MGB_IMPL_OPR_GRAD(Loop) { | ||||
| return Loop::grad(const_cast<Loop&>(opr), wrt_idx, out_grad); | return Loop::grad(const_cast<Loop&>(opr), wrt_idx, out_grad); | ||||
| } | } | ||||
| #endif | |||||
| cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const { | cg::OperatorNodeBase::NodeProp* Loop::do_make_node_prop() const { | ||||
| auto prop = LoopImpl::do_make_node_prop(); | auto prop = LoopImpl::do_make_node_prop(); | ||||
| @@ -48,23 +48,26 @@ namespace intl { | |||||
| /* ================= Argmxx ================= */ | /* ================= Argmxx ================= */ | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Argmax) { | MGB_IMPL_OPR_GRAD(Argmax) { | ||||
| MGB_MARK_USED_VAR(out_grad); | MGB_MARK_USED_VAR(out_grad); | ||||
| MGB_MARK_USED_VAR(opr); | MGB_MARK_USED_VAR(opr); | ||||
| mgb_assert(!wrt_idx); | mgb_assert(!wrt_idx); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| #endif | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmax); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmax); | ||||
| MEGDNN_OPR_INIT1(Argmax, "argmax") | MEGDNN_OPR_INIT1(Argmax, "argmax") | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Argmin) { | MGB_IMPL_OPR_GRAD(Argmin) { | ||||
| MGB_MARK_USED_VAR(out_grad); | MGB_MARK_USED_VAR(out_grad); | ||||
| MGB_MARK_USED_VAR(opr); | MGB_MARK_USED_VAR(opr); | ||||
| mgb_assert(!wrt_idx); | mgb_assert(!wrt_idx); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| #endif | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmin); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(Argmin); | ||||
| MEGDNN_OPR_INIT1(Argmin, "argmin") | MEGDNN_OPR_INIT1(Argmin, "argmin") | ||||
| @@ -84,12 +87,14 @@ std::array<SymbolVar, 2> ArgsortForward::make( | |||||
| return {node->output(0), node->output(1)}; | return {node->output(0), node->output(1)}; | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(ArgsortForward) { | MGB_IMPL_OPR_GRAD(ArgsortForward) { | ||||
| mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); | mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]); | ||||
| if (!out_grad[0]) | if (!out_grad[0]) | ||||
| return nullptr; | return nullptr; | ||||
| return ArgsortBackward::make(out_grad[0], opr.output(1)).node(); | return ArgsortBackward::make(out_grad[0], opr.output(1)).node(); | ||||
| } | } | ||||
| #endif | |||||
| /* ================= ArgsortBackward ================= */ | /* ================= ArgsortBackward ================= */ | ||||
| @@ -107,12 +112,14 @@ Cumsum::Cumsum(VarNode* opr, const Param& param, | |||||
| add_input({opr}, AddInputSortType::CUR_ADDED); | add_input({opr}, AddInputSortType::CUR_ADDED); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Cumsum) { | MGB_IMPL_OPR_GRAD(Cumsum) { | ||||
| mgb_assert(out_grad[0] && !out_grad[1]); | mgb_assert(out_grad[0] && !out_grad[1]); | ||||
| auto param = opr.param(); | auto param = opr.param(); | ||||
| param.reverse = !param.reverse; | param.reverse = !param.reverse; | ||||
| return Cumsum::make(out_grad[0], param).node(); | return Cumsum::make(out_grad[0], param).node(); | ||||
| } | } | ||||
| #endif | |||||
| SymbolVar Cumsum::make(SymbolVar opr, const Param& param, | SymbolVar Cumsum::make(SymbolVar opr, const Param& param, | ||||
| const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
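The Cumsum grad rule only flips param.reverse and reapplies Cumsum to the incoming gradient. That works because y[i] sums x[0..i], so dL/dx[j] is the sum of dL/dy over i >= j, which is exactly a reversed cumulative sum. A small numeric check:

    // Numeric check that grad of a forward cumsum is the reverse cumsum of dL/dy.
    #include <cstdio>
    #include <vector>

    int main() {
        // y[i] = sum_{j<=i} x[j], so dL/dx[j] = sum_{i>=j} dL/dy[i].
        std::vector<float> dy = {1.f, 2.f, 3.f};
        std::vector<float> dx(dy.size());
        float acc = 0.f;
        for (int j = static_cast<int>(dy.size()) - 1; j >= 0; --j) {
            acc += dy[j];
            dx[j] = acc;                  // reverse cumulative sum
        }
        for (float v : dx) std::printf("%g ", v);   // prints: 6 5 3
        std::printf("\n");
        return 0;
    }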
| @@ -170,6 +177,7 @@ CondTake::CondTake(VarNode *data, VarNode *mask, | |||||
| } | } | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(CondTake) { | MGB_IMPL_OPR_GRAD(CondTake) { | ||||
| mgb_assert(out_grad.size() == 3 && !out_grad[2]); | mgb_assert(out_grad.size() == 3 && !out_grad[2]); | ||||
| if (wrt_idx == 0 && out_grad[0]) { | if (wrt_idx == 0 && out_grad[0]) { | ||||
| @@ -181,6 +189,7 @@ MGB_IMPL_OPR_GRAD(CondTake) { | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| #endif | |||||
| std::array<SymbolVar, 2> CondTake::make( | std::array<SymbolVar, 2> CondTake::make( | ||||
| SymbolVar data, SymbolVar mask, | SymbolVar data, SymbolVar mask, | ||||
| @@ -318,6 +327,7 @@ void TopK::record_execute_deps(ExecDependencyArray& deps) { | |||||
| record_megdnn_opr(deps); | record_megdnn_opr(deps); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(TopK) { | MGB_IMPL_OPR_GRAD(TopK) { | ||||
| if (opr.param().mode == TopK::Param::Mode::KTH_ONLY) { | if (opr.param().mode == TopK::Param::Mode::KTH_ONLY) { | ||||
| mgb_assert(out_grad[0] && !out_grad[1] && !out_grad[2]); | mgb_assert(out_grad[0] && !out_grad[1] && !out_grad[2]); | ||||
| @@ -334,5 +344,6 @@ MGB_IMPL_OPR_GRAD(TopK) { | |||||
| return ArgsortBackward::make(out_grad[0], opr.output(1), opr.input(0)) | return ArgsortBackward::make(out_grad[0], opr.output(1), opr.input(0)) | ||||
| .node(); | .node(); | ||||
| } | } | ||||
| #endif | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -316,9 +316,11 @@ VarNodeArray AllGather::grad(const VarNodeArray &out_grad) { | |||||
| OperatorNodeConfig().comp_node_arr(sp_cn))); | OperatorNodeConfig().comp_node_arr(sp_cn))); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(AllGather) { | MGB_IMPL_OPR_GRAD(AllGather) { | ||||
| return const_cast<AllGather&>(opr).grad(out_grad); | return const_cast<AllGather&>(opr).grad(out_grad); | ||||
| } | } | ||||
| #endif | |||||
| void AllGather::on_output_comp_node_stream_changed() { | void AllGather::on_output_comp_node_stream_changed() { | ||||
| } | } | ||||
| @@ -112,19 +112,21 @@ UniqPtrWithCN<megdnn::RNGBase> RNGOpr<MegDNNOpr>::create_megdnn_opr() { | |||||
| return opr; | return opr; | ||||
| } | } | ||||
| #define IMPL(_cls) \ | |||||
| template class RNGOpr<::megdnn::_cls>; \ | |||||
| MGB_IMPL_OPR_GRAD(_cls) { \ | |||||
| MGB_MARK_USED_VAR(out_grad); \ | |||||
| return InvalidGrad::make(opr, wrt_idx); \ | |||||
| } \ | |||||
| #define IMPL(_cls) \ | |||||
| MGB_IMPL_OPR_GRAD(_cls) { \ | |||||
| MGB_MARK_USED_VAR(out_grad); \ | |||||
| return InvalidGrad::make(opr, wrt_idx); \ | |||||
| } | |||||
| namespace mgb { | namespace mgb { | ||||
| namespace opr { | namespace opr { | ||||
| namespace intl { | namespace intl { | ||||
| template class RNGOpr<::megdnn::GaussianRNG>; | |||||
| template class RNGOpr<::megdnn::UniformRNG>; | |||||
| #ifdef MGB_ENABLE_GRAD | |||||
| IMPL(GaussianRNG); | IMPL(GaussianRNG); | ||||
| IMPL(UniformRNG); | IMPL(UniformRNG); | ||||
| #endif | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
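In this hunk the explicit instantiations template class RNGOpr<...> are moved out of the IMPL macro and left unconditional, while the macro keeps only the grad rule that now sits inside the #ifdef block; otherwise disabling grad would also silently drop the operator instantiations. A generic sketch of that split, using placeholder names (Op, DEMO_ENABLE_GRAD):

    // Keep explicit template instantiations unconditional; only the optional
    // (grad) piece sits behind the build flag.
    #include <cstdio>

    template <typename T>
    struct Op {
        T run(T v) const { return v + T(1); }
    };

    #define DEMO_ENABLE_GRAD 0

    // Always instantiated, so the symbols exist in every build flavour.
    template struct Op<float>;
    template struct Op<double>;

    #if DEMO_ENABLE_GRAD
    float grad_of_op(float g) { return g; }   // compiled only when grad is enabled
    #endif

    int main() {
        Op<float> op;
        std::printf("%g\n", op.run(1.f));   // prints: 2
        return 0;
    }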
| @@ -46,11 +46,13 @@ void Alloc::outshape_by_symvar_do_get_output_shape( | |||||
| void Alloc::scn_do_execute() { | void Alloc::scn_do_execute() { | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Alloc) { | MGB_IMPL_OPR_GRAD(Alloc) { | ||||
| MGB_MARK_USED_VAR(wrt_idx); | MGB_MARK_USED_VAR(wrt_idx); | ||||
| MGB_MARK_USED_VAR(out_grad); | MGB_MARK_USED_VAR(out_grad); | ||||
| return InvalidGrad::make(opr, 0); | return InvalidGrad::make(opr, 0); | ||||
| } | } | ||||
| #endif | |||||
| /* ======================= Linspace ======================= */ | /* ======================= Linspace ======================= */ | ||||
| @@ -123,6 +125,7 @@ void Linspace::record_execute_deps(ExecDependencyArray& deps) { | |||||
| std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr))); | std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr))); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Linspace) { | MGB_IMPL_OPR_GRAD(Linspace) { | ||||
| if (wrt_idx == 2) | if (wrt_idx == 2) | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -134,6 +137,7 @@ MGB_IMPL_OPR_GRAD(Linspace) { | |||||
| return opr::Dot::make(og, | return opr::Dot::make(og, | ||||
| opr::Linspace::make(i0, i1, opr.input(2), opr.param())).node(); | opr::Linspace::make(i0, i1, opr.input(2), opr.param())).node(); | ||||
| } | } | ||||
| #endif | |||||
| /* ======================= Eye ======================= */ | /* ======================= Eye ======================= */ | ||||
| @@ -195,9 +199,10 @@ void Eye::record_execute_deps(ExecDependencyArray& deps) { | |||||
| std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr))); | std::make_unique<intl::MegDNNGraphDep>(std::move(m_megdnn_opr))); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Eye) { | MGB_IMPL_OPR_GRAD(Eye) { | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| } | } | ||||
| #endif | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -165,12 +165,13 @@ void GetVarShape::init_output_static_infer_desc() { | |||||
| mgr.register_value_infer(output(0), | mgr.register_value_infer(output(0), | ||||
| {SourceType::DEP, deps, infer_value}); | {SourceType::DEP, deps, infer_value}); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(GetVarShape) { | MGB_IMPL_OPR_GRAD(GetVarShape) { | ||||
| MGB_MARK_USED_VAR(wrt_idx); | MGB_MARK_USED_VAR(wrt_idx); | ||||
| MGB_MARK_USED_VAR(out_grad); | MGB_MARK_USED_VAR(out_grad); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| #endif | |||||
| SymbolVar GetVarShape::make(const VarNodeArrayView& inp, Param param, | SymbolVar GetVarShape::make(const VarNodeArrayView& inp, Param param, | ||||
| const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
| @@ -362,11 +363,13 @@ SymbolVar Reshape::make(SymbolVar inp, SymbolVar tshp, | |||||
| inp.node(), tshp.node(), unspec_axis, config); | inp.node(), tshp.node(), unspec_axis, config); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Reshape) { | MGB_IMPL_OPR_GRAD(Reshape) { | ||||
| if (wrt_idx) | if (wrt_idx) | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node(); | return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node(); | ||||
| } | } | ||||
| #endif | |||||
| Maybe<TensorLayout> Reshape::reshapebrdcast_get_dest_layout( | Maybe<TensorLayout> Reshape::reshapebrdcast_get_dest_layout( | ||||
| const TensorLayout &src, const TensorShape &tshape) const { | const TensorLayout &src, const TensorShape &tshape) const { | ||||
| @@ -429,12 +432,14 @@ SymbolVar Broadcast::make(SymbolVar inp, SymbolVar tshp, | |||||
| inp.node(), tshp.node(), config); | inp.node(), tshp.node(), config); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Broadcast) { | MGB_IMPL_OPR_GRAD(Broadcast) { | ||||
| if (wrt_idx) | if (wrt_idx) | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| return Reduce::make(out_grad.at(0), Reduce::Mode::SUM, | return Reduce::make(out_grad.at(0), Reduce::Mode::SUM, | ||||
| GetVarShape::make(opr.input(0))).node(); | GetVarShape::make(opr.input(0))).node(); | ||||
| } | } | ||||
| #endif | |||||
| Maybe<TensorLayout> Broadcast::reshapebrdcast_get_dest_layout( | Maybe<TensorLayout> Broadcast::reshapebrdcast_get_dest_layout( | ||||
| const TensorLayout &src, const TensorShape &tshape) const { | const TensorLayout &src, const TensorShape &tshape) const { | ||||
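The Broadcast grad rule reduces the incoming gradient back to the input shape with a SUM reduction: every input element is copied into several output positions by the broadcast, so its gradient collects the sum over those positions. A tiny numeric sketch:

    // Why grad(Broadcast) is a SUM reduction to the input shape: broadcasting
    // x (shape [2]) to y (shape [3,2]) copies each x[j] into 3 slots, so
    // dL/dx[j] is the sum of dL/dy over those slots.
    #include <cstdio>

    int main() {
        const int rows = 3, cols = 2;
        float dy[rows][cols] = {{1, 2}, {3, 4}, {5, 6}};   // incoming gradient
        float dx[cols] = {0, 0};
        for (int i = 0; i < rows; ++i)
            for (int j = 0; j < cols; ++j)
                dx[j] += dy[i][j];                         // reduce over broadcast axis
        std::printf("%g %g\n", dx[0], dx[1]);              // prints: 9 12
        return 0;
    }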
| @@ -562,9 +567,11 @@ VarNode* Dimshuffle::grad( | |||||
| return Dimshuffle::make(out_grad.at(0), back, m_pattern.size()).node(); | return Dimshuffle::make(out_grad.at(0), back, m_pattern.size()).node(); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Dimshuffle) { | MGB_IMPL_OPR_GRAD(Dimshuffle) { | ||||
| return opr.grad(wrt_idx, out_grad); | return opr.grad(wrt_idx, out_grad); | ||||
| } | } | ||||
| #endif | |||||
| // f}}} | // f}}} | ||||
| @@ -631,10 +638,12 @@ AxisAddRemove::NodeProp* AxisAddRemove::do_make_node_prop() const { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(AxisAddRemove) { | MGB_IMPL_OPR_GRAD(AxisAddRemove) { | ||||
| MGB_MARK_USED_VAR(wrt_idx); | MGB_MARK_USED_VAR(wrt_idx); | ||||
| return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node(); | return Reshape::make(out_grad[0], GetVarShape::make(opr.input(0))).node(); | ||||
| } | } | ||||
| #endif | |||||
| // f}}} | // f}}} | ||||
| @@ -642,6 +651,7 @@ MGB_IMPL_OPR_GRAD(AxisAddRemove) { | |||||
| MGB_IMPL_FANCY_INDEXING_OPR_GET(Subtensor, "subtensor", true); | MGB_IMPL_FANCY_INDEXING_OPR_GET(Subtensor, "subtensor", true); | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Subtensor) { | MGB_IMPL_OPR_GRAD(Subtensor) { | ||||
| if (wrt_idx) | if (wrt_idx) | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -650,6 +660,7 @@ MGB_IMPL_OPR_GRAD(Subtensor) { | |||||
| SymbolVar{opr.input(0)}.fill_retain_dtype(0), | SymbolVar{opr.input(0)}.fill_retain_dtype(0), | ||||
| out_grad.at(0), opr.index_desc()).node(); | out_grad.at(0), opr.index_desc()).node(); | ||||
| } | } | ||||
| #endif | |||||
| void Subtensor::init_output_static_infer_desc() { | void Subtensor::init_output_static_infer_desc() { | ||||
| using namespace cg::static_infer; | using namespace cg::static_infer; | ||||
| @@ -783,6 +794,7 @@ void SetSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) { | |||||
| sub.copy_from_fixlayout(val); | sub.copy_from_fixlayout(val); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(SetSubtensor) { | MGB_IMPL_OPR_GRAD(SetSubtensor) { | ||||
| if (wrt_idx >= 2) | if (wrt_idx >= 2) | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -793,6 +805,7 @@ MGB_IMPL_OPR_GRAD(SetSubtensor) { | |||||
| } | } | ||||
| return Subtensor::make(out_grad.at(0), opr.index_desc()).node(); | return Subtensor::make(out_grad.at(0), opr.index_desc()).node(); | ||||
| } | } | ||||
| #endif | |||||
| // f}}} | // f}}} | ||||
| @@ -813,6 +826,7 @@ void IncrSubtensor::modify(DeviceTensorND &sub, const DeviceTensorND &val) { | |||||
| opr->exec(sub.as_megdnn(), val.as_megdnn()); | opr->exec(sub.as_megdnn(), val.as_megdnn()); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(IncrSubtensor) { | MGB_IMPL_OPR_GRAD(IncrSubtensor) { | ||||
| if (wrt_idx >= 2) | if (wrt_idx >= 2) | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -821,6 +835,7 @@ MGB_IMPL_OPR_GRAD(IncrSubtensor) { | |||||
| } | } | ||||
| return Subtensor::make(out_grad.at(0), opr.index_desc()).node(); | return Subtensor::make(out_grad.at(0), opr.index_desc()).node(); | ||||
| } | } | ||||
| #endif | |||||
| // f}}} | // f}}} | ||||
| @@ -1085,6 +1100,7 @@ void Split::do_execute(ExecEnv &env) { | |||||
| } | } | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Split) { | MGB_IMPL_OPR_GRAD(Split) { | ||||
| if (wrt_idx) | if (wrt_idx) | ||||
| return InvalidGrad::make(opr, wrt_idx); | return InvalidGrad::make(opr, wrt_idx); | ||||
| @@ -1100,6 +1116,7 @@ MGB_IMPL_OPR_GRAD(Split) { | |||||
| return Concat::make(grad, opr.options().axis, | return Concat::make(grad, opr.options().axis, | ||||
| OperatorNodeConfig{}.follow_comp_node(opr.input(0))).node(); | OperatorNodeConfig{}.follow_comp_node(opr.input(0))).node(); | ||||
| } | } | ||||
| #endif | |||||
| void Split::mem_plan_fwd_in2out_readonly() { | void Split::mem_plan_fwd_in2out_readonly() { | ||||
| m_readonly_fwd_called = true; | m_readonly_fwd_called = true; | ||||
| @@ -1236,6 +1253,7 @@ SymbolVar Concat::make(const VarNodeArrayView& inp, int axis, | |||||
| axis, config); | axis, config); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Concat) { | MGB_IMPL_OPR_GRAD(Concat) { | ||||
| auto axis = opr.axis(); | auto axis = opr.axis(); | ||||
| mgb_assert(out_grad.size() == 1); | mgb_assert(out_grad.size() == 1); | ||||
| @@ -1250,6 +1268,7 @@ MGB_IMPL_OPR_GRAD(Concat) { | |||||
| OperatorNodeConfig().comp_node_arr(comp_node)); | OperatorNodeConfig().comp_node_arr(comp_node)); | ||||
| return cg::to_var_node_array(ret); | return cg::to_var_node_array(ret); | ||||
| } | } | ||||
| #endif | |||||
| void Concat::scn_do_execute() { | void Concat::scn_do_execute() { | ||||
| auto&& out = output(0)->dev_tensor(); | auto&& out = output(0)->dev_tensor(); | ||||
| @@ -1507,6 +1526,7 @@ void ParamPackSplit::init_output_static_infer_desc() { | |||||
| } | } | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(ParamPackSplit) { | MGB_IMPL_OPR_GRAD(ParamPackSplit) { | ||||
| mgb_assert(out_grad.size() == opr.output().size()); | mgb_assert(out_grad.size() == opr.output().size()); | ||||
| SmallVector<SymbolVar> grad; | SmallVector<SymbolVar> grad; | ||||
| @@ -1531,6 +1551,7 @@ MGB_IMPL_OPR_GRAD(ParamPackSplit) { | |||||
| OperatorNodeConfig{}.follow_comp_node(opr.input(0))) | OperatorNodeConfig{}.follow_comp_node(opr.input(0))) | ||||
| .node(); | .node(); | ||||
| } | } | ||||
| #endif | |||||
| // f}}} | // f}}} | ||||
| /* f{{{ ======================= RelayoutFormat ======================= */ | /* f{{{ ======================= RelayoutFormat ======================= */ | ||||
| @@ -255,9 +255,11 @@ void MarkDynamicVar::scn_do_execute() { | |||||
| o->dev_tensor().copy_from_fixlayout(i->dev_tensor()); | o->dev_tensor().copy_from_fixlayout(i->dev_tensor()); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(MarkDynamicVar) { | MGB_IMPL_OPR_GRAD(MarkDynamicVar) { | ||||
| return MarkDynamicVar::make(out_grad.at(0)).node(); | return MarkDynamicVar::make(out_grad.at(0)).node(); | ||||
| } | } | ||||
| #endif | |||||
| MarkDynamicVar::MarkDynamicVar(VarNode *node, const OperatorNodeConfig &config): | MarkDynamicVar::MarkDynamicVar(VarNode *node, const OperatorNodeConfig &config): | ||||
| Super{node->owner_graph(), config, "mark_dyn", {node}} | Super{node->owner_graph(), config, "mark_dyn", {node}} | ||||
| @@ -381,10 +383,12 @@ CallbackInjector::mixin_get_static_infer_desc(OperatorNodeBase &opr) { | |||||
| } | } | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(CallbackInjector) { | MGB_IMPL_OPR_GRAD(CallbackInjector) { | ||||
| MGB_MARK_USED_VAR(wrt_idx); | MGB_MARK_USED_VAR(wrt_idx); | ||||
| return out_grad.at(0); | return out_grad.at(0); | ||||
| } | } | ||||
| #endif | |||||
| /* ===================== MarkNoBroadcastElemwise ===================== */ | /* ===================== MarkNoBroadcastElemwise ===================== */ | ||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(MarkNoBroadcastElemwise); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(MarkNoBroadcastElemwise); | ||||
| @@ -404,9 +408,11 @@ SymbolVar MarkNoBroadcastElemwise::make( | |||||
| input.node(), config); | input.node(), config); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(MarkNoBroadcastElemwise) { | MGB_IMPL_OPR_GRAD(MarkNoBroadcastElemwise) { | ||||
| return out_grad.at(0); | return out_grad.at(0); | ||||
| } | } | ||||
| #endif | |||||
| /* ===================== Identity ===================== */ | /* ===================== Identity ===================== */ | ||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(Identity); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(Identity); | ||||
| @@ -429,9 +435,11 @@ SymbolVar Identity::make( | |||||
| return input.insert_single_output_opr<Identity>(input.node(), config); | return input.insert_single_output_opr<Identity>(input.node(), config); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(Identity) { | MGB_IMPL_OPR_GRAD(Identity) { | ||||
| return out_grad.at(0); | return out_grad.at(0); | ||||
| } | } | ||||
| #endif | |||||
| /* ===================== AssertEqual ===================== */ | /* ===================== AssertEqual ===================== */ | ||||
| @@ -530,6 +538,7 @@ SymbolVar SetGrad::make(SymbolVar input, const GradGetter& grad_getter, | |||||
| input.node(), grad_getter, config); | input.node(), grad_getter, config); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(SetGrad) { | MGB_IMPL_OPR_GRAD(SetGrad) { | ||||
| MGB_MARK_USED_VAR(wrt_idx); | MGB_MARK_USED_VAR(wrt_idx); | ||||
| MGB_MARK_USED_VAR(out_grad); | MGB_MARK_USED_VAR(out_grad); | ||||
| @@ -538,6 +547,7 @@ MGB_IMPL_OPR_GRAD(SetGrad) { | |||||
| "var returned by grad_getter belongs to a different comp graph"); | "var returned by grad_getter belongs to a different comp graph"); | ||||
| return grad.node(); | return grad.node(); | ||||
| } | } | ||||
| #endif | |||||
| /* ===================== InvalidGrad ===================== */ | /* ===================== InvalidGrad ===================== */ | ||||
| @@ -690,6 +700,7 @@ VirtualLoss::NodeProp* VirtualLoss::do_make_node_prop() const { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| MGB_IMPL_OPR_GRAD(VirtualLoss) { | MGB_IMPL_OPR_GRAD(VirtualLoss) { | ||||
| mgb_assert(out_grad.size() == 1); | mgb_assert(out_grad.size() == 1); | ||||
| auto mid = opr.input().size() / 2; | auto mid = opr.input().size() / 2; | ||||
| @@ -698,6 +709,7 @@ MGB_IMPL_OPR_GRAD(VirtualLoss) { | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| #endif | |||||
| #else | #else | ||||
| VarNode* InvalidGrad::make(const OperatorNodeBase&, size_t) { | VarNode* InvalidGrad::make(const OperatorNodeBase&, size_t) { | ||||
| @@ -24,6 +24,16 @@ | |||||
| #include "megdnn/opr_param_json.h" | #include "megdnn/opr_param_json.h" | ||||
| #endif | #endif | ||||
| #include "megbrain/utils/hash_ct.h" | |||||
| #include "midout.h" | |||||
| MIDOUT_DECL(megbrain_opr_footprint) | |||||
| #define MIDOUT_B(...) \ | |||||
| MIDOUT_BEGIN(megbrain_opr_footprint, __VA_ARGS__) { | |||||
| #define MIDOUT_E \ | |||||
| } \ | |||||
| MIDOUT_END(); | |||||
| using namespace mgb; | using namespace mgb; | ||||
| namespace { | namespace { | ||||
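The new MIDOUT_B in opr_footprint.cpp keys each region on the operator type plus midout_iv(MGB_HASH_STR("...")), that is, a string literal hashed at compile time into an integer that can serve as a template argument. The sketch below shows the technique with a plain FNV-1a hash; that choice is illustrative and not necessarily what megbrain/utils/hash_ct.h implements.

    // Turn a string literal into a compile-time integer usable as a template argument.
    #include <cstdint>
    #include <cstdio>

    constexpr std::uint64_t fnv1a(const char* s,
                                  std::uint64_t h = 1469598103934665603ull) {
        return *s ? fnv1a(s + 1, (h ^ static_cast<unsigned char>(*s)) * 1099511628211ull)
                  : h;
    }

    template <std::uint64_t V>
    struct IntValue {                 // analogous in spirit to midout_iv
        static constexpr std::uint64_t value = V;
    };

    int main() {
        using Tag = IntValue<fnv1a("OprFootprint::add_single_comp_footprint")>;
        std::printf("%llx\n", static_cast<unsigned long long>(Tag::value));
        return 0;
    }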
| @@ -581,9 +591,12 @@ std::shared_ptr<json::Value> opr_param_json_func<opr::Subtensor>( | |||||
| template <class OprType> | template <class OprType> | ||||
| void OprFootprint::add_single_comp_footprint() { | void OprFootprint::add_single_comp_footprint() { | ||||
| MIDOUT_B(OprType, | |||||
| midout_iv(MGB_HASH_STR("OprFootprint::add_single_comp_footprint"))) | |||||
| auto&& record = m_type2comp_footprint.emplace(OprType::typeinfo(), | auto&& record = m_type2comp_footprint.emplace(OprType::typeinfo(), | ||||
| opr_footprint_func<OprType>); | opr_footprint_func<OprType>); | ||||
| mgb_assert(record.second, "duplicate opr typeinfo"); | mgb_assert(record.second, "duplicate opr typeinfo"); | ||||
| MIDOUT_E | |||||
| } | } | ||||
| #if MGB_ENABLE_JSON | #if MGB_ENABLE_JSON | ||||
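add_single_comp_footprint itself is a small registration helper: it inserts one footprint callback per operator typeinfo and asserts that nothing registers twice. A standalone sketch of the same pattern, using std::type_index and placeholder operator types instead of megbrain's Typeinfo:

    // One callback per operator type; duplicate registration is a programming error.
    #include <cassert>
    #include <cstdint>
    #include <cstdio>
    #include <functional>
    #include <typeindex>
    #include <unordered_map>

    struct ConvFwd {};
    struct Pooling {};

    using FootprintFn = std::function<std::uint64_t(int /*elems*/)>;

    template <class Opr>
    std::uint64_t footprint(int elems) { return 2ull * elems; }   // dummy cost model

    std::unordered_map<std::type_index, FootprintFn> type2footprint;

    template <class Opr>
    void add_single_comp_footprint() {
        auto rst = type2footprint.emplace(std::type_index(typeid(Opr)), footprint<Opr>);
        assert(rst.second && "duplicate opr typeinfo");
        (void)rst;
    }

    int main() {
        add_single_comp_footprint<ConvFwd>();
        add_single_comp_footprint<Pooling>();
        std::printf("%llu\n",
                    (unsigned long long)type2footprint[typeid(ConvFwd)](8));  // prints: 16
        return 0;
    }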