| @@ -32,6 +32,7 @@ MIDOUT_DECL(megbrain_opr_convolution) | |||||
| MIDOUT_END(); | MIDOUT_END(); | ||||
| #include "../internal/megdnn_opr_wrapper.inl" | #include "../internal/megdnn_opr_wrapper.inl" | ||||
| #include "../internal/invoke.h" | |||||
| #include <array> | #include <array> | ||||
| #include <chrono> | #include <chrono> | ||||
| @@ -109,104 +110,74 @@ struct OprAttributeTrait<opr::ConvBias> { | |||||
| } | } | ||||
| }; | }; | ||||
// Compile-time predicate: evaluates to true only when Opr is exactly
// megdnn::ConvolutionForward or megdnn::ConvBias, and false for every other
// operator type.
// NOTE(review): presumably these are the operators whose megdnn entry points
// take an extra preprocessed-filter argument -- confirm against the megdnn API.
| template <typename Opr> | |||||
| constexpr bool opr_supports_preprocess() { | |||||
| return std::is_same<Opr, megdnn::ConvolutionForward>::value || | |||||
| std::is_same<Opr, megdnn::ConvBias>::value; | |||||
| } | |||||
| template <typename Opr> | template <typename Opr> | ||||
| struct OprArityTrait; | struct OprArityTrait; | ||||
// Element adaptors intended to be passed as the `cb` argument of list macros:
//   cb      - yields the element unchanged
//   cb_ref  - yields the address of the element
//   cb_dnn  - yields the result of calling as_megdnn() on the element
| #define cb(x) (x) | |||||
| #define cb_ref(x) (&(x)) | |||||
| #define cb_dnn(x) ((x).as_megdnn()) | |||||
// APPLY(statement, containers...): concatenates the given tuple-like
// containers with std::tuple_cat and hands the result to mgb::apply together
// with a generic lambda that captures by reference. Inside `statement` the
// unpacked elements are available as the parameter pack `args`, and the value
// of `statement` is what the lambda returns.
// NOTE(review): mgb::apply is assumed to behave like std::apply (invoke the
// callable with the tuple elements as arguments) -- confirm in invoke.h.
| #define APPLY(statement, ...) \ | |||||
| mgb::apply([&](const auto&... args) { return statement; }, \ | |||||
| std::tuple_cat(__VA_ARGS__)) | |||||
// Generic arity trait for a megdnn operator type.
//   Opr        - the megdnn operator class
//   _arity_in  - number of input tensors
//   _arity_out - number of output tensors
// Exposes the arities as constants, a fixed-size layout array type covering
// all inputs followed by all outputs, and static helpers that expand such
// arrays into the operator's positional arguments through APPLY.
| template <typename Opr, int _arity_in, int _arity_out> | |||||
| struct OprArityTraitTmpl { | |||||
| static constexpr int arity_in = _arity_in; | |||||
| static constexpr int arity_out = _arity_out; | |||||
| static constexpr int arity = arity_in + arity_out; | |||||
| using Algorithm = typename Opr::Algorithm; | |||||
| using TensorLayoutArray = std::array<TensorLayout, arity>; | |||||
// Assigns {algo} to the operator's execution policy, then asks the operator
// for its workspace size with every entry of `layouts` passed as a separate
// argument. When opr_supports_preprocess<Opr>() is true, one additional
// trailing nullptr argument is passed. Returns the reported size in bytes.
// NOTE(review): if_constexpr is defined elsewhere; it is assumed to invoke the
// first lambda when the template condition is true and the second otherwise.
| static size_t get_workspace_in_bytes(Opr* opr, Algorithm* algo, | |||||
| const TensorLayoutArray& layouts) { | |||||
| opr->execution_policy() = {algo}; | |||||
| size_t workspace_size; | |||||
| if_constexpr<opr_supports_preprocess<Opr>()>([&](auto) { | |||||
| workspace_size = APPLY( | |||||
| opr->get_workspace_in_bytes(args..., nullptr), layouts); | |||||
| }, /* else */ [&](auto) { | |||||
| workspace_size = | |||||
| APPLY(opr->get_workspace_in_bytes(args...), layouts); | |||||
| }); | |||||
| return workspace_size; | |||||
| } | |||||
// NOTE(review): the rows from here down to the macro's exec() appear to be the
// older macro-based definition of OprArityTrait (an INST_ARITY taking a fourth
// _has_preprocessed_filter flag), shown interleaved with the struct template by
// the diff rendering -- confirm against the actual file before relying on it.
// WS_ARG_true / WS_ARG_false expand to an optional trailing ", nullptr" that is
// selected by token-pasting the flag onto WS_ARG_.
| #define WS_ARG_true ,nullptr | |||||
| #define WS_ARG_false | |||||
| #define INST_ARITY(_Opr, _in, _out, _has_preprocessed_filter) \ | |||||
| template <> \ | |||||
| struct OprArityTrait<_Opr> { \ | |||||
| static constexpr int arity_in = _in; \ | |||||
| static constexpr int arity_out = _out; \ | |||||
| static constexpr int arity = _in + _out; \ | |||||
| using TensorLayoutArray = std::array<TensorLayout, arity>; \ | |||||
| static size_t get_workspace_in_bytes( \ | |||||
| _Opr* opr, typename _Opr::Algorithm* algo, \ | |||||
| const TensorLayoutArray& layouts) { \ | |||||
| opr->execution_policy() = {algo}; \ | |||||
| return opr->get_workspace_in_bytes( \ | |||||
| LAYOUTS(cb) WS_ARG_##_has_preprocessed_filter); \ | |||||
| } \ | |||||
| \ | |||||
| static std::vector<typename _Opr::Algorithm*> get_all_algorithms( \ | |||||
| _Opr* opr, const TensorLayoutArray& layouts) { \ | |||||
| return opr->get_all_algorithms(LAYOUTS(cb)); \ | |||||
| } \ | |||||
| \ | |||||
| static typename _Opr::Algorithm* get_algorithm_heuristic( \ | |||||
| _Opr* opr, const TensorLayoutArray& layouts, \ | |||||
| size_t workspace_limit, bool reproducible) { \ | |||||
| return opr->get_algorithm_heuristic(LAYOUTS(cb), workspace_limit, \ | |||||
| reproducible); \ | |||||
| } \ | |||||
| \ | |||||
| static void exec(_Opr* opr, const DeviceTensorND* inp_val, \ | |||||
| const DeviceTensorND* out_val, \ | |||||
| megdnn::Workspace& workspace) { \ | |||||
| opr->exec(TENSORS(cb_dnn), workspace); \ | |||||
| } \ | |||||
// Runs the operator: every element of inp_val, then every element of out_val,
// is converted with as_megdnn() and passed positionally, followed by
// `workspace`. When opr_supports_preprocess<Opr>() is true, a nullptr argument
// is inserted between the tensors and the workspace.
| static void exec(Opr* opr, | |||||
| const std::array<DeviceTensorND, arity_in>& inp_val, | |||||
| const std::array<DeviceTensorND, arity_out>& out_val, | |||||
| megdnn::Workspace& workspace) { | |||||
| if_constexpr<opr_supports_preprocess<Opr>()>([&](auto) { | |||||
| APPLY(opr->exec(args.as_megdnn()..., nullptr, workspace), inp_val, | |||||
| out_val); | |||||
| }, /* else */ [&](auto) { | |||||
| APPLY(opr->exec(args.as_megdnn()..., workspace), inp_val, out_val); | |||||
| }); | |||||
| } | } | ||||
| }; | |||||
// INST_ARITY(_Opr, _in, _out): explicitly specializes OprArityTrait for _Opr
// as an empty struct that publicly inherits from
// OprArityTraitTmpl<_Opr, _in, _out>, so the specialization exposes exactly the
// members of the template with the given input/output counts.
| #define INST_ARITY(_Opr, _in, _out) \ | |||||
| template <> \ | |||||
| struct OprArityTrait<_Opr> : public OprArityTraitTmpl<_Opr, _in, _out> {}; | |||||
// Arity table: one specialization per operator, written as
// (operator, number of inputs, number of outputs).
| INST_ARITY(megdnn::ConvolutionBackwardData, 2, 1); | |||||
| INST_ARITY(megdnn::ConvolutionBackwardFilter, 2, 1); | |||||
| INST_ARITY(megdnn::Convolution3DForward, 2, 1); | |||||
| INST_ARITY(megdnn::Convolution3DBackwardData, 2, 1); | |||||
| INST_ARITY(megdnn::Convolution3DBackwardFilter, 2, 1); | |||||
| INST_ARITY(megdnn::LocalShareForward, 2, 1); | |||||
| INST_ARITY(megdnn::LocalShareBackwardData, 2, 1); | |||||
| INST_ARITY(megdnn::LocalShareBackwardFilter, 2, 1); | |||||
| INST_ARITY(megdnn::Convolution, 2, 1); | |||||
| INST_ARITY(megdnn::DeformableConvForward, 4, 1); | |||||
| INST_ARITY(megdnn::DeformableConvBackwardFilter, 4, 1); | |||||
| INST_ARITY(megdnn::BatchConvBiasForward, 4, 1); | |||||
| INST_ARITY(megdnn::ConvBias, 4, 1); | |||||
| INST_ARITY(megdnn::DeformableConvBackwardData, 5, 3); | |||||
// NOTE(review): the rows below use the four-argument form of INST_ARITY and
// appear to be the older, macro-list based instantiation scheme shown by the
// diff rendering -- confirm which side of the change they belong to.
// For each arity group, TENSORS(cb) and LAYOUTS(cb) are (re)defined to spell
// out the array elements one by one, each wrapped in the adaptor `cb`; the
// group's operators are instantiated; then the helper macros are #undef'd so
// the next group can redefine them.
//
// Group: 2 inputs, 1 output, instantiated with the flag set to false.
| #define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(out_val[0]) | |||||
| #define LAYOUTS(cb) cb(layouts[0]), cb(layouts[1]), cb(layouts[2]) | |||||
| #define INST_ARITY_2_1(Opr) INST_ARITY(Opr, 2, 1, false) | |||||
| INST_ARITY_2_1(megdnn::ConvolutionBackwardData); | |||||
| INST_ARITY_2_1(megdnn::ConvolutionBackwardFilter); | |||||
| INST_ARITY_2_1(megdnn::Convolution3DForward); | |||||
| INST_ARITY_2_1(megdnn::Convolution3DBackwardData); | |||||
| INST_ARITY_2_1(megdnn::Convolution3DBackwardFilter); | |||||
| INST_ARITY_2_1(megdnn::LocalShareForward); | |||||
| INST_ARITY_2_1(megdnn::LocalShareBackwardData); | |||||
| INST_ARITY_2_1(megdnn::LocalShareBackwardFilter); | |||||
| #undef TENSORS | |||||
// megdnn::Convolution is instantiated with the flag set to true, and its
// TENSORS list carries an extra trailing nullptr after the output tensor.
| #define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(out_val[0]), nullptr | |||||
| INST_ARITY(megdnn::Convolution, 2, 1, true); | |||||
| #undef TENSORS | |||||
| #undef LAYOUTS | |||||
| #undef INST_ARITY_2_1 | |||||
// Group: 4 inputs, 1 output, instantiated with the flag set to false.
| #define TENSORS(cb) \ | |||||
| cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), cb(inp_val[3]), \ | |||||
| cb(out_val[0]) | |||||
| #define LAYOUTS(cb) \ | |||||
| cb(layouts[0]), cb(layouts[1]), cb(layouts[2]), cb(layouts[3]), \ | |||||
| cb(layouts[4]) | |||||
| #define INST_ARITY_4_1(Opr) INST_ARITY(Opr, 4, 1, false) | |||||
| INST_ARITY_4_1(megdnn::DeformableConvForward); | |||||
| INST_ARITY_4_1(megdnn::DeformableConvBackwardFilter); | |||||
| INST_ARITY_4_1(megdnn::BatchConvBiasForward); | |||||
| #undef TENSORS | |||||
// megdnn::ConvBias is instantiated with the flag set to true, and its TENSORS
// list carries an extra trailing nullptr after the output tensor.
| #define TENSORS(cb) \ | |||||
| cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), cb(inp_val[3]), \ | |||||
| cb(out_val[0]), nullptr | |||||
| INST_ARITY(megdnn::ConvBias, 4, 1, true); | |||||
| #undef TENSORS | |||||
| #undef LAYOUTS | |||||
| #undef INST_ARITY_4_1 | |||||
// Group: 5 inputs, 3 outputs, instantiated with the flag set to false.
| #define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), \ | |||||
| cb(inp_val[3]), cb(inp_val[4]), cb(out_val[0]), \ | |||||
| cb(out_val[1]), cb(out_val[2]) | |||||
| #define LAYOUTS(cb) cb(layouts[0]), cb(layouts[1]), cb(layouts[2]), \ | |||||
| cb(layouts[3]), cb(layouts[4]), cb(layouts[5]), \ | |||||
| cb(layouts[6]), cb(layouts[7]) | |||||
| #define INST_ARITY_5_3(Opr) INST_ARITY(Opr, 5, 3, false) | |||||
| INST_ARITY_5_3(megdnn::DeformableConvBackwardData); | |||||
| #undef TENSORS | |||||
| #undef LAYOUTS | |||||
| #undef INST_ARITY_5_3 | |||||
// Remove the adaptor and instantiation macros so they do not leak into the
// rest of the translation unit.
| #undef cb | |||||
| #undef cb_ref | |||||
| #undef cb_dnn | |||||
| #undef INST_ARITY | #undef INST_ARITY | ||||
| #undef WS_ARG_true | |||||
| #undef WS_ARG_false | |||||
| // timeout delta to be added with fastest known algorithm for new algos | // timeout delta to be added with fastest known algorithm for new algos | ||||
| constexpr double TIMEOUT_TOLERANCE = 2; | constexpr double TIMEOUT_TOLERANCE = 2; | ||||
| @@ -343,8 +314,7 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl( | |||||
| megdnn_opr->param() = param.opr_param; | megdnn_opr->param() = param.opr_param; | ||||
| { | { | ||||
| typename Opr::Algorithm* algo = nullptr; | typename Opr::Algorithm* algo = nullptr; | ||||
| for (auto i : OprArityTrait<Opr>::get_all_algorithms(megdnn_opr.get(), | |||||
| layouts)) { | |||||
| for (auto i : APPLY(megdnn_opr->get_all_algorithms(args...), layouts)) { | |||||
| if (!strcmp(i->name(), param.algo_name)) { | if (!strcmp(i->name(), param.algo_name)) { | ||||
| algo = i; | algo = i; | ||||
| break; | break; | ||||
| @@ -368,7 +338,9 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl( | |||||
| } | } | ||||
| // allocate input and output memory | // allocate input and output memory | ||||
| DeviceTensorND inp_val[arity_in], out_val[arity_out], workspace; | |||||
| std::array<DeviceTensorND, arity_in> inp_val; | |||||
| std::array<DeviceTensorND, arity_out> out_val; | |||||
| DeviceTensorND workspace; | |||||
| for (int i = 0; i < arity_in; ++i) { | for (int i = 0; i < arity_in; ++i) { | ||||
| inp_val[i] | inp_val[i] | ||||
| .comp_node(cn) | .comp_node(cn) | ||||
| @@ -484,16 +456,17 @@ class AlgoChooser { | |||||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | ||||
| opr->owner_graph(), opr->comp_node(), | opr->owner_graph(), opr->comp_node(), | ||||
| opr->execution_policy().workspace_limit); | opr->execution_policy().workspace_limit); | ||||
| return OprArityTrait<Opr>::get_algorithm_heuristic( | |||||
| m_megdnn_opr, m_layouts, workspace_limit, reproducible); | |||||
| return APPLY(m_megdnn_opr->get_algorithm_heuristic( | |||||
| args..., workspace_limit, reproducible), | |||||
| m_layouts); | |||||
| } | } | ||||
| //! get all candidate algos, and the one choose_by_heuristic() is | //! get all candidate algos, and the one choose_by_heuristic() is | ||||
| //! put first | //! put first | ||||
| std::vector<ImplAlgo> get_all_candidates() const { | std::vector<ImplAlgo> get_all_candidates() const { | ||||
| auto heu = choose_by_heuristic(); | auto heu = choose_by_heuristic(); | ||||
| auto&& ret = OprArityTrait<Opr>::get_all_algorithms(m_megdnn_opr, | |||||
| m_layouts); | |||||
| auto&& ret = | |||||
| APPLY(m_megdnn_opr->get_all_algorithms(args...), m_layouts); | |||||
| bool found = false; | bool found = false; | ||||
| for (size_t i = 0; i < ret.size(); ++i) { | for (size_t i = 0; i < ret.size(); ++i) { | ||||
| if (ret[i] == heu) { | if (ret[i] == heu) { | ||||