GitOrigin-RevId: bb7ab8fa9d
tags/v1.10.0
@@ -70,15 +70,6 @@ def _matmul(
     maxdim = dim1 if dim1 > dim2 else dim2
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
-    Strategy = builtin.ops.MatrixMul.Strategy
-    strategy = Strategy(0)
-    if _config._benchmark_kernel:
-        strategy |= Strategy.PROFILE
-    else:
-        strategy |= Strategy.HEURISTIC
-    if _config._deterministic_kernel:
-        strategy |= Strategy.REPRODUCIBLE
     if dim1 == 1 and dim2 == 1:  # dispatch to Dot
         (result,) = apply(builtin.Dot(), inp1, inp2)
         return result
@@ -621,6 +621,7 @@ def max_pool2d(
         pad_h=padding_h,
         pad_w=padding_w,
         mode="max",
+        strategy=get_execution_strategy(),
         format=conv_format,
     )
     (output,) = apply(op, inp)
@@ -665,6 +666,7 @@ def avg_pool2d(
         pad_h=padding_h,
         pad_w=padding_w,
         mode=mode,
+        strategy=get_execution_strategy(),
         format=conv_format,
     )
     (output,) = apply(op, inp)
@@ -1493,7 +1493,7 @@ py::object _transpose_cpp(py::handle inp_hdl, py::handle args) {
 py::object _matmul_cpp(
         py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
         py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
-        py::handle profile, py::handle determistic) {
+        py::handle profile, py::handle deterministic) {
     ::megdnn::param::MatrixMul::ComputeMode mode =
             ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
     if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
@@ -1506,7 +1506,7 @@ py::object _matmul_cpp(
     } else {
         cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
     }
-    if (determistic.cast<bool>()) {
+    if (deterministic.cast<bool>()) {
         cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
     }
     std::shared_ptr<OpDef> op = MatrixMul::make(
@@ -1523,7 +1523,7 @@ py::object _matmul_cpp(
 py::object _batched_matmul_cpp(
         py::handle inp1, py::handle inp2, py::handle dim1, py::handle dim2,
         py::handle transpose_a, py::handle transpose_b, py::handle compute_mode,
-        py::handle profile, py::handle determistic) {
+        py::handle profile, py::handle deterministic) {
     ::megdnn::param::MatrixMul::ComputeMode mode =
             ::megdnn::param::MatrixMul::ComputeMode::DEFAULT;
     if (compute_mode.cast<std::string>().compare(std::string("float32")) == 0) {
@@ -1536,7 +1536,7 @@ py::object _batched_matmul_cpp(
     } else {
         cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::HEURISTIC;
     }
-    if (determistic.cast<bool>()) {
+    if (deterministic.cast<bool>()) {
         cstrategy |= ::megdnn::param::ExecutionPolicy::Strategy::REPRODUCIBLE;
     }
     std::shared_ptr<OpDef> op = BatchedMatrixMul::make(
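
Both `_matmul_cpp` and `_batched_matmul_cpp` fold the two boolean flags into the same ExecutionPolicy strategy bitmask: PROFILE when benchmarking is requested, HEURISTIC otherwise, with REPRODUCIBLE OR-ed in when deterministic kernels are required. A self-contained sketch of that mapping (the enum values and the helper `make_strategy` below are illustrative stand-ins, not the megdnn definitions):

#include <cstdint>
#include <cstdio>

// Sketch only: how (profile, deterministic) combine into a strategy bitmask,
// mirroring the branches in _matmul_cpp / _batched_matmul_cpp above.
enum class Strategy : uint32_t {
    HEURISTIC = 1u << 0,
    PROFILE = 1u << 1,
    REPRODUCIBLE = 1u << 2,
};

constexpr Strategy operator|(Strategy a, Strategy b) {
    return static_cast<Strategy>(
            static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
}

Strategy make_strategy(bool profile, bool deterministic) {
    // PROFILE benchmarks candidate kernels, HEURISTIC picks one by rule;
    // REPRODUCIBLE restricts the choice to deterministic algorithms.
    Strategy s = profile ? Strategy::PROFILE : Strategy::HEURISTIC;
    if (deterministic) {
        s = s | Strategy::REPRODUCIBLE;
    }
    return s;
}

int main() {
    std::printf("strategy bits: %u\n",
                static_cast<uint32_t>(make_strategy(true, true)));
    return 0;
}
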
@@ -10,6 +10,10 @@
  */
 #include "megbrain/imperative/ops/opr_attr.h"
+#include "megbrain/opr/blas.h"
+#include "megbrain/opr/dnn/convolution.h"
+#include "megbrain/opr/dnn/pooling.h"
+#include "megbrain/rdnn/profiler.h"
 #include "megbrain/serialization/opr_load_dump.h"
 #include "../op_trait.h"
@@ -65,6 +69,42 @@ public:
     const serialization::GraphDumpConfig& config() const { mgb_assert(0); }
 };
+#define cb(FASTRUN_OPR)                                                            \
+    megdnn::param::ExecutionPolicy get_strategy_##FASTRUN_OPR(                     \
+            cg::OperatorNodeBase* opr) {                                           \
+        auto policy =                                                              \
+                opr->cast_final<opr::FASTRUN_OPR>().execution_policy_transient();  \
+        return policy;                                                             \
+    }                                                                              \
+    void set_strategy_##FASTRUN_OPR(                                               \
+            cg::OperatorNodeBase* opr, megdnn::param::ExecutionPolicy policy) {    \
+        auto&& p = opr->cast_final<opr::FASTRUN_OPR>();                            \
+        p.set_execution_policy(policy);                                            \
+    }
+DNN_FOREACH_FASTRUN_OPR(cb)
+#undef cb
+typedef thin_function<megdnn::param::ExecutionPolicy(cg::OperatorNodeBase*)> get_func;
+typedef thin_function<void(cg::OperatorNodeBase*, megdnn::param::ExecutionPolicy)>
+        set_func;
+static const mgb::thin_hash_table::ThinHashMap<
+        mgb::Typeinfo*, std::pair<get_func, set_func>>&
+get_type2policy() {
+    static mgb::thin_hash_table::ThinHashMap<
+            mgb::Typeinfo*, std::pair<get_func, set_func>>
+            sl_type2policy;
+    static std::once_flag flag;
+    std::call_once(flag, [&]() {
+#define cb(FASTRUN_OPR)                                                            \
+    sl_type2policy[opr::FASTRUN_OPR::typeinfo()] =                                 \
+            std::make_pair(get_strategy_##FASTRUN_OPR, set_strategy_##FASTRUN_OPR);
+        DNN_FOREACH_FASTRUN_OPR(cb)
+#undef cb
+    });
+    return std::as_const(sl_type2policy);
+}
 VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& attr = def.cast_final_safe<OprAttr>();
     auto config = attr.config;
@@ -73,7 +113,12 @@ VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto registry = serialization::OprRegistry::find_by_name(attr.type);
     mgb_assert(registry, "operator %s not found", attr.type.c_str());
     OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()};
-    return registry->loader(ctx, inputs, config).usable_output();
+    auto opr_with_accessor = registry->loader(ctx, inputs, config);
+    auto&& opr = opr_with_accessor.opr();
+    if (get_type2policy().find(opr->dyn_typeinfo()) != get_type2policy().end()) {
+        get_type2policy().at(opr->dyn_typeinfo()).second(opr, attr.policy);
+    }
+    return opr_with_accessor.usable_output();
 }
 std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
@@ -84,7 +129,11 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
             registry->dumper, "operator %s cannot be serialized",
             opr->dyn_typeinfo()->name);
     registry->dumper(ctx, *opr);
-    return OprAttr::make(registry->name, std::move(ctx.m_param), opr->config());
+    megdnn::param::ExecutionPolicy policy;
+    if (get_type2policy().find(opr->dyn_typeinfo()) != get_type2policy().end()) {
+        policy = get_type2policy().at(opr->dyn_typeinfo()).first(opr);
+    }
+    return OprAttr::make(registry->name, std::move(ctx.m_param), policy, opr->config());
 }
 std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
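
Taken together, the two hunks above implement a policy round trip through OprAttr: get_type2policy() maps each FASTRUN operator's Typeinfo* to a (getter, setter) pair generated by the cb macro, apply_on_var_node uses the setter to push attr.policy into the rebuilt graph operator, and make_from_op_node uses the getter to carry the operator's policy back into the OprAttr it produces. A self-contained sketch of the same dispatch pattern with standard containers (Policy, MatMulOp, PoolingOp and type2policy below are stand-ins, not the MegBrain types):

#include <functional>
#include <iostream>
#include <typeindex>
#include <unordered_map>
#include <utility>

// Sketch only: a type-keyed registry of (get, set) accessors, analogous to
// get_type2policy() keyed by mgb::Typeinfo*.
struct Policy {
    int strategy = 0;
};

struct OperatorBase {
    virtual ~OperatorBase() = default;
};
struct MatMulOp : OperatorBase {
    Policy policy;
};
struct PoolingOp : OperatorBase {
    Policy policy;
};

using Getter = std::function<Policy(OperatorBase*)>;
using Setter = std::function<void(OperatorBase*, Policy)>;

static const std::unordered_map<std::type_index, std::pair<Getter, Setter>>&
type2policy() {
    // Built once on first use, like the std::call_once-guarded table above.
    static const auto table = [] {
        std::unordered_map<std::type_index, std::pair<Getter, Setter>> t;
        t[typeid(MatMulOp)] = {
                [](OperatorBase* o) { return static_cast<MatMulOp*>(o)->policy; },
                [](OperatorBase* o, Policy p) { static_cast<MatMulOp*>(o)->policy = p; }};
        t[typeid(PoolingOp)] = {
                [](OperatorBase* o) { return static_cast<PoolingOp*>(o)->policy; },
                [](OperatorBase* o, Policy p) { static_cast<PoolingOp*>(o)->policy = p; }};
        return t;
    }();
    return table;
}

int main() {
    MatMulOp op;
    auto it = type2policy().find(typeid(op));
    if (it != type2policy().end()) {
        it->second.second(&op, Policy{2});            // set, as in apply_on_var_node
        std::cout << it->second.first(&op).strategy;  // get, as in make_from_op_node
    }
    return 0;
}
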
@@ -108,6 +157,8 @@ OP_TRAIT_REG(OprAttr, OprAttr)
 bool OprAttr::is_same_st(const Hashable& rhs_) const {
     auto&& rhs = static_cast<const OprAttr&>(rhs_);
     return type == rhs.type && param == rhs.param &&
+           policy.strategy == rhs.policy.strategy &&
+           policy.workspace_limit == rhs.policy.workspace_limit &&
            config.comp_node() == rhs.config.comp_node() &&
            config.output_dtype() == rhs.config.output_dtype();
 }
@@ -115,7 +166,12 @@ bool OprAttr::is_same_st(const Hashable& rhs_) const {
 size_t OprAttr::hash() const {
     return hash_pair_combine(
             hash_pair_combine(
-                    mgb::hash(type), mgb::hash(static_cast<std::vector<char>>(param))),
+                    hash_pair_combine(
+                            mgb::hash(type),
+                            mgb::hash(static_cast<std::vector<char>>(param))),
+                    hash_pair_combine(
+                            static_cast<size_t>(policy.strategy),
+                            policy.workspace_limit)),
             config.hash());
 }
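
With the policy folded in, the attribute hash is combined as hash = combine(combine(combine(H(type), H(param)), combine(strategy, workspace_limit)), config.hash()), matching the equality check in is_same_st. A small sketch of that nesting (the mixer below is a boost-style hash_combine used only as a stand-in; mgb::hash_pair_combine's exact mixing is not shown in this diff):

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>

// Sketch only: reproduce the combination order used by OprAttr::hash() above.
static size_t pair_combine(size_t seed, size_t h) {
    // Illustrative mixer (boost::hash_combine-style), not mgb::hash_pair_combine.
    return seed ^ (h + 0x9e3779b97f4a7c15ull + (seed << 6) + (seed >> 2));
}

int main() {
    size_t h_type = std::hash<std::string>{}("MatrixMulV2");  // stands in for mgb::hash(type)
    size_t h_param = 0x1234;            // stands in for mgb::hash(param bytes)
    size_t strategy = 1u << 1;          // e.g. the PROFILE bit of the strategy enum
    size_t workspace_limit = SIZE_MAX;  // ExecutionPolicy::workspace_limit
    size_t h_config = 0x5678;           // stands in for config.hash()

    size_t h = pair_combine(
            pair_combine(pair_combine(h_type, h_param),
                         pair_combine(strategy, workspace_limit)),
            h_config);
    std::cout << std::hex << h << "\n";
    return 0;
}
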
@@ -12,6 +12,7 @@
 #pragma once
 #include "megbrain/imperative/op_def.h"
+#include "megbrain/opr/param_defs.h"
 namespace mgb {
 namespace imperative {
@@ -38,12 +39,16 @@ public:
     Type type;
     Param param;
+    megdnn::param::ExecutionPolicy policy;
     cg::OperatorNodeConfig config;
     OprAttr() = default;
     OprAttr(const Type& t) : type(t) {}
     OprAttr(const Type& t, const Param& p, const cg::OperatorNodeConfig& c)
             : type(t), param(p), config(c) {}
+    OprAttr(const Type& t, const Param& p, const megdnn::param::ExecutionPolicy ps,
+            const cg::OperatorNodeConfig& c)
+            : type(t), param(p), policy(ps), config(c) {}
     std::string repr() const;
@@ -157,6 +157,51 @@ TEST(TestImperative, BackwardGraphBasic) {
     }
 }
+TEST(TestImperative, ProfileBackward) {
+    auto cn = CompNode::load("xpux");
+    using Policy = megdnn::param::ExecutionPolicy;
+    using S = Policy::Strategy;
+    Policy policy;
+    policy.strategy = S::PROFILE;
+    {
+        megdnn::param::Convolution param;
+        auto op = std::shared_ptr<OpDef>(Convolution::make(param, policy));
+        LogicalTensorDesc inp_desc = {
+                TensorLayout({16, 3, 16, 16}, dtype::Float32()), cn};
+        LogicalTensorDesc weight_desc = {
+                TensorLayout({16, 3, 5, 5}, dtype::Float32()), cn};
+        auto bg = OpDef::make_backward_graph(
+                *op, {inp_desc, weight_desc}, {true, false}, {true});
+        auto&& bop = (bg.graph.exprs.at(0)).op;
+        auto&& attr = bop->cast_final_safe<OprAttr>();
+        // attr.type = ConvolutionBackwardDataV2
+        mgb_assert(attr.policy.strategy == S::PROFILE);
+    }
+    {
+        megdnn::param::Pooling param;
+        auto op = std::shared_ptr<OpDef>(Pooling::make(param, policy));
+        LogicalTensorDesc inp_desc = {
+                TensorLayout({16, 3, 16, 16}, dtype::Float32()), cn};
+        auto bg = OpDef::make_backward_graph(*op, {inp_desc}, {true}, {true});
+        auto&& bop = (bg.graph.exprs.at(0)).op;
+        auto&& attr = bop->cast_final_safe<OprAttr>();
+        // attr.type = PoolingBackwardV1
+        mgb_assert(attr.policy.strategy == S::PROFILE);
+    }
+    {
+        megdnn::param::MatrixMul param;
+        auto op = std::shared_ptr<OpDef>(MatrixMul::make(param, policy, 2, 2));
+        LogicalTensorDesc inp1_desc = {TensorLayout({12, 16}, dtype::Float32()), cn};
+        LogicalTensorDesc inp2_desc = {TensorLayout({16, 20}, dtype::Float32()), cn};
+        auto bg = OpDef::make_backward_graph(
+                *op, {inp1_desc, inp2_desc}, {true, false}, {true});
+        auto&& bop = (bg.graph.exprs.at(0)).op;
+        auto&& attr = bop->cast_final_safe<OprAttr>();
+        // attr.type = MatrixMulV2
+        mgb_assert(attr.policy.strategy == S::PROFILE);
+    }
+}
 TEST(TestImperative, BackwardGraphIdentity) {
     HostTensorGenerator<> gen;
     auto host_a = gen({42}), host_dc = gen({42});
@@ -185,17 +185,21 @@ MGB_IMPL_OPR_GRAD(MatrixMul) {
     if (wrt_idx == 0) {
         // A * B = C, A' = C' * Bt
         if (opr.param().transposeA) {
-            grad = MatrixMul::make(i1, og, {opr.param().transposeB, true});
+            grad = MatrixMul::make(
+                    i1, og, {opr.param().transposeB, true}, opr.execution_policy());
         } else {
-            grad = MatrixMul::make(og, i1, {false, !opr.param().transposeB});
+            grad = MatrixMul::make(
+                    og, i1, {false, !opr.param().transposeB}, opr.execution_policy());
         }
     } else {
         mgb_assert(wrt_idx == 1);
         // A * B = C, B' = At * C'
         if (opr.param().transposeB) {
-            grad = MatrixMul::make(og, i0, {true, opr.param().transposeA});
+            grad = MatrixMul::make(
+                    og, i0, {true, opr.param().transposeA}, opr.execution_policy());
        } else {
-            grad = MatrixMul::make(i0, og, {!opr.param().transposeA, false});
+            grad = MatrixMul::make(
+                    i0, og, {!opr.param().transposeA, false}, opr.execution_policy());
         }
     }
     return grad.node();
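
The gradient rules encoded by these branches follow from the forward product: for C = A B with upstream gradient C',

    A' = C' B^{\mathsf T}, \qquad B' = A^{\mathsf T} C'.

When the forward operator already transposes an input, the branches fold that transpose into the backward MatrixMul's transpose flags instead of materializing a transposed tensor; the only new piece in this hunk is the extra argument forwarding opr.execution_policy(), so the backward matmul reuses the forward operator's algorithm-selection strategy.
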
@@ -358,17 +362,21 @@ MGB_IMPL_OPR_GRAD(BatchedMatrixMul) {
     if (wrt_idx == 0) {
         // A * B = C, A' = C' * Bt
         if (opr.param().transposeA) {
-            grad = BatchedMatrixMul::make(i1, og, {opr.param().transposeB, true});
+            grad = BatchedMatrixMul::make(
+                    i1, og, {opr.param().transposeB, true}, opr.execution_policy());
         } else {
-            grad = BatchedMatrixMul::make(og, i1, {false, !opr.param().transposeB});
+            grad = BatchedMatrixMul::make(
+                    og, i1, {false, !opr.param().transposeB}, opr.execution_policy());
         }
     } else {
         mgb_assert(wrt_idx == 1);
         // A * B = C, B' = At * C'
         if (opr.param().transposeB) {
-            grad = BatchedMatrixMul::make(og, i0, {true, opr.param().transposeA});
+            grad = BatchedMatrixMul::make(
+                    og, i0, {true, opr.param().transposeA}, opr.execution_policy());
         } else {
-            grad = BatchedMatrixMul::make(i0, og, {!opr.param().transposeA, false});
+            grad = BatchedMatrixMul::make(
+                    i0, og, {!opr.param().transposeA, false}, opr.execution_policy());
         }
     }
     return grad.node();
@@ -59,7 +59,8 @@ size_t PoolingForward::get_workspace_size_bytes(
 MGB_IMPL_OPR_GRAD(PoolingForward) {
     mgb_assert(wrt_idx == 0);
     SymbolVar grad = PoolingBackward::make(
-            opr.input(0), opr.output(0), out_grad[0], opr.param());
+            opr.input(0), opr.output(0), out_grad[0], opr.param(),
+            opr.execution_policy());
     return grad.node();
 }
 #endif
@@ -26,7 +26,7 @@ namespace opr {
 /*!
  * \brief matrix_mul(trans0(opr0), trans1(opr1))
  */
-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         MatrixMul, intl::MegDNNOprWrapperFwd<megdnn::MatrixMul>,
         public mixin::AlgoChooserHelper) // {
 public:
@@ -57,7 +57,7 @@ private:
 /*!
  * \brief batched matrix multiplication on 3D inputs
  */
-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         BatchedMatrixMul, intl::MegDNNOprWrapperFwd<megdnn::BatchedMatrixMul>,
         public mixin::AlgoChooserHelper) // {
 public:
@@ -18,7 +18,7 @@
 namespace mgb {
 namespace opr {
-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         PoolingForward, intl::MegDNNOprWrapperFwd<megdnn::PoolingForward>,
         public mixin::AlgoChooserHelper) // {
 public:
@@ -37,7 +37,7 @@ public:
 };
 using Pooling = PoolingForward;
-MGB_DEFINE_OPR_CLASS(
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
         PoolingBackward, intl::MegDNNOprWrapperBwd<megdnn::PoolingBackward>,
         public mixin::AlgoChooserHelper) // {
 public:
@@ -51,7 +51,7 @@ public:
      * Exception would be thrown if execution_policy() has been accessed,
      * since it would influence cache and many other decisions.
      */
-    void set_execution_policy(const ExecutionPolicy& policy);
+    MGE_WIN_DECLSPEC_FUC void set_execution_policy(const ExecutionPolicy& policy);
     /*!
      * \brief register a hook to implement custom algo chooser