GitOrigin-RevId: bb7ab8fa9d
tags/v1.10.0
@@ -70,15 +70,6 @@ def _matmul(
     maxdim = dim1 if dim1 > dim2 else dim2
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
-    Strategy = builtin.ops.MatrixMul.Strategy
-    strategy = Strategy(0)
-    if _config._benchmark_kernel:
-        strategy |= Strategy.PROFILE
-    else:
-        strategy |= Strategy.HEURISTIC
-    if _config._deterministic_kernel:
-        strategy |= Strategy.REPRODUCIBLE
     if dim1 == 1 and dim2 == 1:  # dispatch to Dot
         (result,) = apply(builtin.Dot(), inp1, inp2)
         return result
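Note: the strategy selection deleted above is not lost. It moves into the `_matmul_cpp` and `_batched_matmul_cpp` bindings further down, which now receive `profile` and `deterministic` handles and compose the megdnn execution strategy on the C++ side (the pooling functions likewise start passing `strategy=get_execution_strategy()`). A minimal standalone sketch of that flag composition, with a stand-in `Strategy` enum rather than the real `megdnn::param::ExecutionPolicy::Strategy`:

#include <cstdint>
#include <cstdio>

// Stand-in for megdnn::param::ExecutionPolicy::Strategy (assumption: the real
// type is a bit-flag enum supporting bitwise-or, as the hunks below suggest).
enum class Strategy : uint32_t {
    NONE = 0,
    HEURISTIC = 1 << 0,
    PROFILE = 1 << 1,
    REPRODUCIBLE = 1 << 2,
};
inline Strategy operator|(Strategy a, Strategy b) {
    return static_cast<Strategy>(
            static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
}

Strategy make_strategy(bool profile, bool deterministic) {
    // PROFILE benchmarks candidate kernels; HEURISTIC picks one by rule.
    Strategy s = profile ? Strategy::PROFILE : Strategy::HEURISTIC;
    if (deterministic)
        s = s | Strategy::REPRODUCIBLE;  // restrict to reproducible algorithms
    return s;
}

int main() {
    std::printf("%u\n", static_cast<uint32_t>(make_strategy(true, true)));
}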
@@ -621,6 +621,7 @@ def max_pool2d(
         pad_h=padding_h,
         pad_w=padding_w,
         mode="max",
+        strategy=get_execution_strategy(),
         format=conv_format,
     )
     (output,) = apply(op, inp)
@@ -665,6 +666,7 @@ def avg_pool2d(
         pad_h=padding_h,
         pad_w=padding_w,
         mode=mode,
+        strategy=get_execution_strategy(),
         format=conv_format,
     )
     (output,) = apply(op, inp)
@@ -1493,7 +1493,7 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
 py::object _matmul_cpp(
         py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
         py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
-        py::handle profile, py::handle determistic) {
+        py::handle profile, py::handle deterministic) {
     ::megdnn::param::MatrixMul::ComputeMode mode =
             ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
     if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
@@ -1506,7 +1506,7 @@ py::object _matmul_cpp(
     } else {
         cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
     }
-    if (determistic.cast<bool>()) {
+    if (deterministic.cast<bool>()) {
         cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
     }
     std::shared_ptr<OpDef> op = MatrixMul::make(
@@ -1523,7 +1523,7 @@ py::object _matmul_cpp(
 py::object _batched_matmul_cpp(
         py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
         py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
-        py::handle profile, py::handle determistic) {
+        py::handle profile, py::handle deterministic) {
     ::megdnn::param::MatrixMul::ComputeMode mode =
             ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
     if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
@@ -1536,7 +1536,7 @@ py::object _batched_matmul_cpp(
     } else {
         cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
     }
-    if (determistic.cast<bool>()) {
+    if (deterministic.cast<bool>()) {
         cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
     }
     std::shared_ptr<OpDef> op = BatchedMatrixMul::make(
@@ -10,6 +10,10 @@
  */
 #include "megbrain/imperative/ops/opr_attr.h"
+#include "megbrain/opr/blas.h"
+#include "megbrain/opr/dnn/convolution.h"
+#include "megbrain/opr/dnn/pooling.h"
+#include "megbrain/rdnn/profiler.h"
 #include "megbrain/serialization/opr_load_dump.h"

 #include "../op_trait.h"
@@ -65,6 +69,42 @@ public:
     const serialization::GraphDumpConfig& config() const { mgb_assert(0); }
 };

+#define cb(FASTRUN_OPR)                                                           \
+    megdnn::param::ExecutionPolicy get_strategy_##FASTRUN_OPR(                    \
+            cg::OperatorNodeBase* opr) {                                          \
+        auto policy =                                                             \
+                opr->cast_final<opr::FASTRUN_OPR>().execution_policy_transient(); \
+        return policy;                                                            \
+    }                                                                             \
+    void set_strategy_##FASTRUN_OPR(                                              \
+            cg::OperatorNodeBase* opr, megdnn::param::ExecutionPolicy policy) {   \
+        auto&& p = opr->cast_final<opr::FASTRUN_OPR>();                           \
+        p.set_execution_policy(policy);                                           \
+    }
+DNN_FOREACH_FASTRUN_OPR(cb)
+#undef cb
+
+typedef thin_function<megdnn::param::ExecutionPolicy(cg::OperatorNodeBase*)> get_func;
+typedef thin_function<void(cg::OperatorNodeBase*, megdnn::param::ExecutionPolicy)>
+        set_func;
+
+static const mgb::thin_hash_table::ThinHashMap<
+        mgb::Typeinfo*, std::pair<get_func, set_func>>&
+get_type2policy() {
+    static mgb::thin_hash_table::ThinHashMap<
+            mgb::Typeinfo*, std::pair<get_func, set_func>>
+            sl_type2policy;
+    static std::once_flag flag;
+    std::call_once(flag, [&]() {
+#define cb(FASTRUN_OPR)                                                        \
+    sl_type2policy[opr::FASTRUN_OPR::typeinfo()] =                             \
+            std::make_pair(get_strategy_##FASTRUN_OPR, set_strategy_##FASTRUN_OPR);
+        DNN_FOREACH_FASTRUN_OPR(cb)
+#undef cb
+    });
+    return std::as_const(sl_type2policy);
+}
+
 VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& attr = def.cast_final_safe<OprAttr>();
     auto config = attr.config;
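`get_type2policy()` above is a lazily built, type-keyed registry of getter/setter pairs, stamped out per fast-run operator by the `DNN_FOREACH_FASTRUN_OPR` X-macro and initialized exactly once via `std::call_once`. A self-contained sketch of the same pattern using only the standard library (`std::type_index` standing in for `mgb::Typeinfo*`, `std::function` for `thin_function`):

#include <functional>
#include <iostream>
#include <mutex>
#include <typeindex>
#include <unordered_map>
#include <utility>

struct Node { virtual ~Node() = default; };
struct ConvNode : Node { int policy = 0; };
struct PoolNode : Node { int policy = 0; };

using Getter = std::function<int(Node*)>;
using Setter = std::function<void(Node*, int)>;

// X-macro: one getter/setter pair per registered node type.
#define FOREACH_FASTRUN_NODE(cb) cb(ConvNode) cb(PoolNode)

#define cb(T)                                                          \
    int get_policy_##T(Node* n) { return static_cast<T*>(n)->policy; } \
    void set_policy_##T(Node* n, int p) { static_cast<T*>(n)->policy = p; }
FOREACH_FASTRUN_NODE(cb)
#undef cb

static const std::unordered_map<std::type_index, std::pair<Getter, Setter>>&
type2policy() {
    static std::unordered_map<std::type_index, std::pair<Getter, Setter>> map;
    static std::once_flag flag;
    std::call_once(flag, [&] {
#define cb(T) map[typeid(T)] = {get_policy_##T, set_policy_##T};
        FOREACH_FASTRUN_NODE(cb)
#undef cb
    });
    return map;
}

int main() {
    ConvNode conv;
    Node* n = &conv;
    auto it = type2policy().find(typeid(*n));  // look up by dynamic type
    if (it != type2policy().end())
        it->second.second(n, 42);  // invoke the setter
    std::cout << conv.policy << "\n";  // prints 42
}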
@@ -73,7 +113,12 @@ VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto registry = serialization::OprRegistry::find_by_name(attr.type);
     mgb_assert(registry, "operator %s not found", attr.type.c_str());
     OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()};
-    return registry->loader(ctx, inputs, config).usable_output();
+    auto opr_with_accessor = registry->loader(ctx, inputs, config);
+    auto&& opr = opr_with_accessor.opr();
+    if (get_type2policy().find(opr->dyn_typeinfo()) != get_type2policy().end()) {
+        get_type2policy().at(opr->dyn_typeinfo()).second(opr, attr.policy);
+    }
+    return opr_with_accessor.usable_output();
 }

 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
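One small observation on the lookup pattern used here and in make_from_op_node below: `find` followed by `at` hashes the key twice. Reusing the iterator does the same job with a single lookup; a tiny sketch of the two forms (a side note, not part of the patch):

#include <cstdio>
#include <string>
#include <unordered_map>

int main() {
    std::unordered_map<std::string, int> m{{"MatrixMul", 1}};
    // Double lookup, as in the patch: find, then at.
    if (m.find("MatrixMul") != m.end())
        std::printf("%d\n", m.at("MatrixMul"));
    // Single lookup: keep the iterator (C++17 if-initializer).
    if (auto it = m.find("MatrixMul"); it != m.end())
        std::printf("%d\n", it->second);
}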
@@ -84,7 +129,11 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
             registry->dumper, "operator %s cannot be serialized",
             opr->dyn_typeinfo()->name);
     registry->dumper(ctx, *opr);
-    return OprAttr::make(registry->name, std::move(ctx.m_param), opr->config());
+    megdnn::param::ExecutionPolicy policy;
+    if (get_type2policy().find(opr->dyn_typeinfo()) != get_type2policy().end()) {
+        policy = get_type2policy().at(opr->dyn_typeinfo()).first(opr);
+    }
+    return OprAttr::make(registry->name, std::move(ctx.m_param), policy, opr->config());
 }

 std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
@@ -108,6 +157,8 @@ OP_TRAIT_REG(OprAttr, OprAttr)
 bool OprAttr::is_same_st(const Hashable& rhs_) const {
     auto&& rhs = static_cast<const OprAttr&>(rhs_);
     return type == rhs.type && param == rhs.param &&
+           policy.strategy == rhs.policy.strategy &&
+           policy.workspace_limit == rhs.policy.workspace_limit &&
            config.comp_node() == rhs.config.comp_node() &&
            config.output_dtype() == rhs.config.output_dtype();
 }
@@ -115,7 +166,12 @@ bool OprAttr::is_same_st(const Hashable& rhs_) const {
 size_t OprAttr::hash() const {
     return hash_pair_combine(
             hash_pair_combine(
-                    mgb::hash(type), mgb::hash(static_cast<std::vector<char>>(param))),
+                    hash_pair_combine(
+                            mgb::hash(type),
+                            mgb::hash(static_cast<std::vector<char>>(param))),
+                    hash_pair_combine(
+                            static_cast<size_t>(policy.strategy),
+                            policy.workspace_limit)),
             config.hash());
 }
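The restructured hash mirrors the equality check above: `policy.strategy` and `policy.workspace_limit` now participate in both, preserving the invariant that equal OprAttrs hash equally. `hash_pair_combine` is MegBrain's own helper; below is a generic boost-style sketch of pairwise hash combining, shown only to illustrate the nesting, not MegBrain's exact formula:

#include <cstddef>
#include <cstdio>
#include <functional>
#include <string>

// Boost-style combine; hash_pair_combine in MegBrain may differ in detail.
size_t hash_pair_combine(size_t seed, size_t v) {
    return seed ^ (v + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

int main() {
    std::string type = "MatrixMulV2";  // hypothetical attr type name
    size_t param_hash = 42, config_hash = 7;
    size_t strategy = 2, workspace_limit = 1 << 20;
    // Same nesting as OprAttr::hash(): (type+param, strategy+limit), config.
    size_t h = hash_pair_combine(
            hash_pair_combine(
                    hash_pair_combine(
                            std::hash<std::string>{}(type), param_hash),
                    hash_pair_combine(strategy, workspace_limit)),
            config_hash);
    std::printf("%zu\n", h);
}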
@@ -12,6 +12,7 @@
 #pragma once

 #include "megbrain/imperative/op_def.h"
+#include "megbrain/opr/param_defs.h"

 namespace mgb {
 namespace imperative {
@@ -38,12 +39,16 @@ public:
     Type type;
     Param param;
+    megdnn::param::ExecutionPolicy policy;
     cg::OperatorNodeConfig config;

     OprAttr() = default;
     OprAttr(const Type& t) : type(t) {}
     OprAttr(const Type& t, const Param& p, const cg::OperatorNodeConfig& c)
             : type(t), param(p), config(c) {}
+    OprAttr(const Type& t, const Param& p, const megdnn::param::ExecutionPolicy ps,
+            const cg::OperatorNodeConfig& c)
+            : type(t), param(p), policy(ps), config(c) {}

     std::string repr() const;
@@ -157,6 +157,51 @@ TEST(TestImperative, BackwardGraphBasic) {
     }
 }

+TEST(TestImperative, ProfileBackward) {
+    auto cn = CompNode::load("xpux");
+    using Policy = megdnn::param::ExecutionPolicy;
+    using S = Policy::Strategy;
+    Policy policy;
+    policy.strategy = S::PROFILE;
+    {
+        megdnn::param::Convolution param;
+        auto op = std::shared_ptr<OpDef>(Convolution::make(param, policy));
+        LogicalTensorDesc inp_desc = {
+                TensorLayout({16, 3, 16, 16}, dtype::Float32()), cn};
+        LogicalTensorDesc weight_desc = {
+                TensorLayout({16, 3, 5, 5}, dtype::Float32()), cn};
+        auto bg = OpDef::make_backward_graph(
+                *op, {inp_desc, weight_desc}, {true, false}, {true});
+        auto&& bop = (bg.graph.exprs.at(0)).op;
+        auto&& attr = bop->cast_final_safe<OprAttr>();
+        // attr.type = ConvolutionBackwardDataV2
+        mgb_assert(attr.policy.strategy == S::PROFILE);
+    }
+    {
+        megdnn::param::Pooling param;
+        auto op = std::shared_ptr<OpDef>(Pooling::make(param, policy));
+        LogicalTensorDesc inp_desc = {
+                TensorLayout({16, 3, 16, 16}, dtype::Float32()), cn};
+        auto bg = OpDef::make_backward_graph(*op, {inp_desc}, {true}, {true});
+        auto&& bop = (bg.graph.exprs.at(0)).op;
+        auto&& attr = bop->cast_final_safe<OprAttr>();
+        // attr.type = PoolingBackwardV1
+        mgb_assert(attr.policy.strategy == S::PROFILE);
+    }
+    {
+        megdnn::param::MatrixMul param;
+        auto op = std::shared_ptr<OpDef>(MatrixMul::make(param, policy, 2, 2));
+        LogicalTensorDesc inp1_desc = {TensorLayout({12, 16}, dtype::Float32()), cn};
+        LogicalTensorDesc inp2_desc = {TensorLayout({16, 20}, dtype::Float32()), cn};
+        auto bg = OpDef::make_backward_graph(
+                *op, {inp1_desc, inp2_desc}, {true, false}, {true});
+        auto&& bop = (bg.graph.exprs.at(0)).op;
+        auto&& attr = bop->cast_final_safe<OprAttr>();
+        // attr.type = MatrixMulV2
+        mgb_assert(attr.policy.strategy == S::PROFILE);
+    }
+}
+
 TEST(TestImperative, BackwardGraphIdentity) {
     HostTensorGenerator<> gen;
     auto host_a = gen({42}), host_dc = gen({42});
@@ -185,17 +185,21 @@ MGB_IMPL_OPR_GRAD(MatrixMul) {
     if (wrt_idx == 0) {
         // A * B = C, A' = C' * Bt
         if (opr.param().transposeA) {
-            grad = MatrixMul::make(i1, og, {opr.param().transposeB, true});
+            grad = MatrixMul::make(
+                    i1, og, {opr.param().transposeB, true}, opr.execution_policy());
         } else {
-            grad = MatrixMul::make(og, i1, {false, !opr.param().transposeB});
+            grad = MatrixMul::make(
+                    og, i1, {false, !opr.param().transposeB}, opr.execution_policy());
         }
     } else {
         mgb_assert(wrt_idx == 1);
         // A * B = C, B' = At * C'
        if (opr.param().transposeB) {
-            grad = MatrixMul::make(og, i0, {true, opr.param().transposeA});
+            grad = MatrixMul::make(
+                    og, i0, {true, opr.param().transposeA}, opr.execution_policy());
         } else {
-            grad = MatrixMul::make(i0, og, {!opr.param().transposeA, false});
+            grad = MatrixMul::make(
+                    i0, og, {!opr.param().transposeA, false}, opr.execution_policy());
         }
     }
     return grad.node();
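The four branches above, like the mirrored ones in BatchedMatrixMul below, all instantiate the standard matmul gradient identities; the transpose flag pairs only fold each operand's stored layout into the same two formulas. For C = AB with loss L:

% Gradients of C = AB; the code's flag pairs account for the optional
% transposition of the stored operands.
\frac{\partial L}{\partial A} = \frac{\partial L}{\partial C}\, B^{\top},
\qquad
\frac{\partial L}{\partial B} = A^{\top}\, \frac{\partial L}{\partial C}

The substantive change here is simply threading `opr.execution_policy()` through, so the backward matmuls are selected under the same PROFILE/HEURISTIC/REPRODUCIBLE strategy as the forward op.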
@@ -358,17 +362,21 @@ MGB_IMPL_OPR_GRAD(BatchedMatrixMul) {
     if (wrt_idx == 0) {
         // A * B = C, A' = C' * Bt
         if (opr.param().transposeA) {
-            grad = BatchedMatrixMul::make(i1, og, {opr.param().transposeB, true});
+            grad = BatchedMatrixMul::make(
+                    i1, og, {opr.param().transposeB, true}, opr.execution_policy());
         } else {
-            grad = BatchedMatrixMul::make(og, i1, {false, !opr.param().transposeB});
+            grad = BatchedMatrixMul::make(
+                    og, i1, {false, !opr.param().transposeB}, opr.execution_policy());
         }
     } else {
         mgb_assert(wrt_idx == 1);
         // A * B = C, B' = At * C'
         if (opr.param().transposeB) {
-            grad = BatchedMatrixMul::make(og, i0, {true, opr.param().transposeA});
+            grad = BatchedMatrixMul::make(
+                    og, i0, {true, opr.param().transposeA}, opr.execution_policy());
         } else {
-            grad = BatchedMatrixMul::make(i0, og, {!opr.param().transposeA, false});
+            grad = BatchedMatrixMul::make(
+                    i0, og, {!opr.param().transposeA, false}, opr.execution_policy());
         }
     }
     return grad.node();
@@ -59,7 +59,8 @@ size_t PoolingForward::get_workspace_size_bytes(
 MGB_IMPL_OPR_GRAD(PoolingForward) {
     mgb_assert(wrt_idx == 0);
     SymbolVar grad = PoolingBackward::make(
-            opr.input(0), opr.output(0), out_grad[0], opr.param());
+            opr.input(0), opr.output(0), out_grad[0], opr.param(),
+            opr.execution_policy());
     return grad.node();
 }
 #endif
@@ -26,7 +26,7 @@ namespace opr {
 /*!
  * \brief matrix_mul(trans0(opr0), trans1(opr1))
  */
-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         MatrixMul, intl::MegDNNOprWrapperFwd<megdnn::MatrixMul>,
         public mixin::AlgoChooserHelper) // {
 public:
@@ -57,7 +57,7 @@ private:
 /*!
  * \brief batched matrix multiplication on 3D inputs
  */
-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         BatchedMatrixMul, intl::MegDNNOprWrapperFwd<megdnn::BatchedMatrixMul>,
         public mixin::AlgoChooserHelper) // {
 public:
@@ -18,7 +18,7 @@
 namespace mgb {
 namespace opr {

-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         PoolingForward, intl::MegDNNOprWrapperFwd<megdnn::PoolingForward>,
         public mixin::AlgoChooserHelper) // {
 public:
@@ -37,7 +37,7 @@ public:
 };
 using Pooling = PoolingForward;

-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         PoolingBackward, intl::MegDNNOprWrapperBwd<megdnn::PoolingBackward>,
         public mixin::AlgoChooserHelper) // {
 public:
@@ -51,7 +51,7 @@ public:
      * Exception would be thrown if execution_policy() has been accessed,
      * since it would influence cache and many other decisions.
      */
-    void set_execution_policy(const ExecutionPolicy& policy);
+    MGE_WIN_DECLSPEC_FUC void set_execution_policy(const ExecutionPolicy& policy);

     /*!
      * \brief register a hook to implement custom algo chooser
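`MGE_WIN_DECLSPEC_FUC` here, like `MGB_DEFINE_OPR_CLASS_WITH_EXPORT` in the hunks above, makes these symbols visible across DLL boundaries on Windows, where symbols are hidden by default; the imperative layer now calls `set_execution_policy` from outside the graph module. A sketch of the usual shape of such an export macro, stated as an assumption (the real definition lives in MegBrain's platform headers, and MGE_BUILDING_DLL below is a hypothetical build flag):

// Typical Windows export/import macro; MegBrain's actual definition may
// differ (e.g. extra build-time switches).
#if defined(_WIN32)
#if defined(MGE_BUILDING_DLL)  // hypothetical: set while building the DLL
#define MGE_WIN_DECLSPEC_FUC __declspec(dllexport)
#else
#define MGE_WIN_DECLSPEC_FUC __declspec(dllimport)
#endif
#else
#define MGE_WIN_DECLSPEC_FUC  // ELF/Mach-O: symbols are visible by default
#endif

// Usage mirrors the hunk above: annotate declarations that must cross the
// DLL boundary.
MGE_WIN_DECLSPEC_FUC void set_execution_policy_example();

int main() {}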