| @@ -32,6 +32,7 @@ MIDOUT_DECL(megbrain_opr_convolution) | |||
| MIDOUT_END(); | |||
| #include "../internal/megdnn_opr_wrapper.inl" | |||
| #include "../internal/invoke.h" | |||
| #include <array> | |||
| #include <chrono> | |||
| @@ -109,104 +110,74 @@ struct OprAttributeTrait<opr::ConvBias> { | |||
| } | |||
| }; | |||
| template <typename Opr> | |||
| constexpr bool opr_supports_preprocess() { | |||
| return std::is_same<Opr, megdnn::ConvolutionForward>::value || | |||
| std::is_same<Opr, megdnn::ConvBias>::value; | |||
| } | |||
//! Per-operator arity trait; only declared here, the specializations are
//! generated by the INST_ARITY macro.
template <class Opr>
struct OprArityTrait;
| #define cb(x) (x) | |||
| #define cb_ref(x) (&(x)) | |||
| #define cb_dnn(x) ((x).as_megdnn()) | |||
//! Concatenate the trailing arguments with std::tuple_cat and hand the result
//! to mgb::apply together with a generic lambda that evaluates \p statement;
//! inside \p statement the elements are available as the pack `args`.
//! NOTE(review): mgb::apply is presumably a std::apply equivalent -- confirm.
#define APPLY(statement, ...)                         \
    mgb::apply(                                       \
            [&](const auto&... args) { return statement; }, \
            std::tuple_cat(__VA_ARGS__))
| template <typename Opr, int _arity_in, int _arity_out> | |||
| struct OprArityTraitTmpl { | |||
| static constexpr int arity_in = _arity_in; | |||
| static constexpr int arity_out = _arity_out; | |||
| static constexpr int arity = arity_in + arity_out; | |||
| using Algorithm = typename Opr::Algorithm; | |||
| using TensorLayoutArray = std::array<TensorLayout, arity>; | |||
| static size_t get_workspace_in_bytes(Opr* opr, Algorithm* algo, | |||
| const TensorLayoutArray& layouts) { | |||
| opr->execution_policy() = {algo}; | |||
| size_t workspace_size; | |||
| if_constexpr<opr_supports_preprocess<Opr>()>([&](auto) { | |||
| workspace_size = APPLY( | |||
| opr->get_workspace_in_bytes(args..., nullptr), layouts); | |||
| }, /* else */ [&](auto) { | |||
| workspace_size = | |||
| APPLY(opr->get_workspace_in_bytes(args...), layouts); | |||
| }); | |||
| return workspace_size; | |||
| } | |||
| #define WS_ARG_true ,nullptr | |||
| #define WS_ARG_false | |||
| #define INST_ARITY(_Opr, _in, _out, _has_preprocessed_filter) \ | |||
| template <> \ | |||
| struct OprArityTrait<_Opr> { \ | |||
| static constexpr int arity_in = _in; \ | |||
| static constexpr int arity_out = _out; \ | |||
| static constexpr int arity = _in + _out; \ | |||
| using TensorLayoutArray = std::array<TensorLayout, arity>; \ | |||
| static size_t get_workspace_in_bytes( \ | |||
| _Opr* opr, typename _Opr::Algorithm* algo, \ | |||
| const TensorLayoutArray& layouts) { \ | |||
| opr->execution_policy() = {algo}; \ | |||
| return opr->get_workspace_in_bytes( \ | |||
| LAYOUTS(cb) WS_ARG_##_has_preprocessed_filter); \ | |||
| } \ | |||
| \ | |||
| static std::vector<typename _Opr::Algorithm*> get_all_algorithms( \ | |||
| _Opr* opr, const TensorLayoutArray& layouts) { \ | |||
| return opr->get_all_algorithms(LAYOUTS(cb)); \ | |||
| } \ | |||
| \ | |||
| static typename _Opr::Algorithm* get_algorithm_heuristic( \ | |||
| _Opr* opr, const TensorLayoutArray& layouts, \ | |||
| size_t workspace_limit, bool reproducible) { \ | |||
| return opr->get_algorithm_heuristic(LAYOUTS(cb), workspace_limit, \ | |||
| reproducible); \ | |||
| } \ | |||
| \ | |||
| static void exec(_Opr* opr, const DeviceTensorND* inp_val, \ | |||
| const DeviceTensorND* out_val, \ | |||
| megdnn::Workspace& workspace) { \ | |||
| opr->exec(TENSORS(cb_dnn), workspace); \ | |||
| } \ | |||
| static void exec(Opr* opr, | |||
| const std::array<DeviceTensorND, arity_in>& inp_val, | |||
| const std::array<DeviceTensorND, arity_out>& out_val, | |||
| megdnn::Workspace& workspace) { | |||
| if_constexpr<opr_supports_preprocess<Opr>()>([&](auto) { | |||
| APPLY(opr->exec(args.as_megdnn()..., nullptr, workspace), inp_val, | |||
| out_val); | |||
| }, /* else */ [&](auto) { | |||
| APPLY(opr->exec(args.as_megdnn()..., workspace), inp_val, out_val); | |||
| }); | |||
| } | |||
| }; | |||
| #define INST_ARITY(_Opr, _in, _out) \ | |||
| template <> \ | |||
| struct OprArityTrait<_Opr> : public OprArityTraitTmpl<_Opr, _in, _out> {}; | |||
| INST_ARITY(megdnn::ConvolutionBackwardData, 2, 1); | |||
| INST_ARITY(megdnn::ConvolutionBackwardFilter, 2, 1); | |||
| INST_ARITY(megdnn::Convolution3DForward, 2, 1); | |||
| INST_ARITY(megdnn::Convolution3DBackwardData, 2, 1); | |||
| INST_ARITY(megdnn::Convolution3DBackwardFilter, 2, 1); | |||
| INST_ARITY(megdnn::LocalShareForward, 2, 1); | |||
| INST_ARITY(megdnn::LocalShareBackwardData, 2, 1); | |||
| INST_ARITY(megdnn::LocalShareBackwardFilter, 2, 1); | |||
| INST_ARITY(megdnn::Convolution, 2, 1); | |||
| INST_ARITY(megdnn::DeformableConvForward, 4, 1); | |||
| INST_ARITY(megdnn::DeformableConvBackwardFilter, 4, 1); | |||
| INST_ARITY(megdnn::BatchConvBiasForward, 4, 1); | |||
| INST_ARITY(megdnn::ConvBias, 4, 1); | |||
| INST_ARITY(megdnn::DeformableConvBackwardData, 5, 3); | |||
| #define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(out_val[0]) | |||
| #define LAYOUTS(cb) cb(layouts[0]), cb(layouts[1]), cb(layouts[2]) | |||
| #define INST_ARITY_2_1(Opr) INST_ARITY(Opr, 2, 1, false) | |||
| INST_ARITY_2_1(megdnn::ConvolutionBackwardData); | |||
| INST_ARITY_2_1(megdnn::ConvolutionBackwardFilter); | |||
| INST_ARITY_2_1(megdnn::Convolution3DForward); | |||
| INST_ARITY_2_1(megdnn::Convolution3DBackwardData); | |||
| INST_ARITY_2_1(megdnn::Convolution3DBackwardFilter); | |||
| INST_ARITY_2_1(megdnn::LocalShareForward); | |||
| INST_ARITY_2_1(megdnn::LocalShareBackwardData); | |||
| INST_ARITY_2_1(megdnn::LocalShareBackwardFilter); | |||
| #undef TENSORS | |||
| #define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(out_val[0]), nullptr | |||
| INST_ARITY(megdnn::Convolution, 2, 1, true); | |||
| #undef TENSORS | |||
| #undef LAYOUTS | |||
| #undef INST_ARITY_2_1 | |||
| #define TENSORS(cb) \ | |||
| cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), cb(inp_val[3]), \ | |||
| cb(out_val[0]) | |||
| #define LAYOUTS(cb) \ | |||
| cb(layouts[0]), cb(layouts[1]), cb(layouts[2]), cb(layouts[3]), \ | |||
| cb(layouts[4]) | |||
| #define INST_ARITY_4_1(Opr) INST_ARITY(Opr, 4, 1, false) | |||
| INST_ARITY_4_1(megdnn::DeformableConvForward); | |||
| INST_ARITY_4_1(megdnn::DeformableConvBackwardFilter); | |||
| INST_ARITY_4_1(megdnn::BatchConvBiasForward); | |||
| #undef TENSORS | |||
| #define TENSORS(cb) \ | |||
| cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), cb(inp_val[3]), \ | |||
| cb(out_val[0]), nullptr | |||
| INST_ARITY(megdnn::ConvBias, 4, 1, true); | |||
| #undef TENSORS | |||
| #undef LAYOUTS | |||
| #undef INST_ARITY_4_1 | |||
| #define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), \ | |||
| cb(inp_val[3]), cb(inp_val[4]), cb(out_val[0]), \ | |||
| cb(out_val[1]), cb(out_val[2]) | |||
| #define LAYOUTS(cb) cb(layouts[0]), cb(layouts[1]), cb(layouts[2]), \ | |||
| cb(layouts[3]), cb(layouts[4]), cb(layouts[5]), \ | |||
| cb(layouts[6]), cb(layouts[7]) | |||
| #define INST_ARITY_5_3(Opr) INST_ARITY(Opr, 5, 3, false) | |||
| INST_ARITY_5_3(megdnn::DeformableConvBackwardData); | |||
| #undef TENSORS | |||
| #undef LAYOUTS | |||
| #undef INST_ARITY_5_3 | |||
| #undef cb | |||
| #undef cb_ref | |||
| #undef cb_dnn | |||
| #undef INST_ARITY | |||
| #undef WS_ARG_true | |||
| #undef WS_ARG_false | |||
| // timeout delta to be added with fastest known algorithm for new algos | |||
| constexpr double TIMEOUT_TOLERANCE = 2; | |||
| @@ -343,8 +314,7 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl( | |||
| megdnn_opr->param() = param.opr_param; | |||
| { | |||
| typename Opr::Algorithm* algo = nullptr; | |||
| for (auto i : OprArityTrait<Opr>::get_all_algorithms(megdnn_opr.get(), | |||
| layouts)) { | |||
| for (auto i : APPLY(megdnn_opr->get_all_algorithms(args...), layouts)) { | |||
| if (!strcmp(i->name(), param.algo_name)) { | |||
| algo = i; | |||
| break; | |||
| @@ -368,7 +338,9 @@ typename TimedProfiler<Opr>::TResult TimedProfiler<Opr>::prof_impl( | |||
| } | |||
| // allocate input and output memory | |||
| DeviceTensorND inp_val[arity_in], out_val[arity_out], workspace; | |||
| std::array<DeviceTensorND, arity_in> inp_val; | |||
| std::array<DeviceTensorND, arity_out> out_val; | |||
| DeviceTensorND workspace; | |||
| for (int i = 0; i < arity_in; ++i) { | |||
| inp_val[i] | |||
| .comp_node(cn) | |||
| @@ -484,16 +456,17 @@ class AlgoChooser { | |||
| auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( | |||
| opr->owner_graph(), opr->comp_node(), | |||
| opr->execution_policy().workspace_limit); | |||
| return OprArityTrait<Opr>::get_algorithm_heuristic( | |||
| m_megdnn_opr, m_layouts, workspace_limit, reproducible); | |||
| return APPLY(m_megdnn_opr->get_algorithm_heuristic( | |||
| args..., workspace_limit, reproducible), | |||
| m_layouts); | |||
| } | |||
| //! get all candidate algos, and the one choose_by_heuristic() is | |||
| //! put first | |||
| std::vector<ImplAlgo> get_all_candidates() const { | |||
| auto heu = choose_by_heuristic(); | |||
| auto&& ret = OprArityTrait<Opr>::get_all_algorithms(m_megdnn_opr, | |||
| m_layouts); | |||
| auto&& ret = | |||
| APPLY(m_megdnn_opr->get_all_algorithms(args...), m_layouts); | |||
| bool found = false; | |||
| for (size_t i = 0; i < ret.size(); ++i) { | |||
| if (ret[i] == heu) { | |||