GitOrigin-RevId: 4738136e4a
@@ -183,7 +183,7 @@ namespace pooling {
 auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
     auto&& pool = static_cast<const Pooling&>(def);
     OperatorNodeConfig config{pool.make_name()};
-    return opr::Pooling::make(inputs[0], pool.param(), config);
+    return opr::Pooling::make(inputs[0], pool.param(), pool.policy(), config);
 }
 OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback();
 }  // namespace pooling

@@ -63,7 +63,7 @@ def DeformableConv : MgbHashableOp<"DeformableConv", [ConvolutionParam, Executio
 def GroupLocal: MgbHashableOp<"GroupLocal", [ConvolutionParam]>;

-def Pooling: MgbHashableOp<"Pooling", [PoolingParam]>;
+def Pooling: MgbHashableOp<"Pooling", [PoolingParam, ExecutionPolicyParamBase<"policy">]>;

 def AdaptivePooling : MgbHashableOp<"AdaptivePooling", [AdaptivePoolingParam]>;
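
With this change the Pooling OpDef carries an execution policy, and graph construction accepts it the same way the convolution family already does. A minimal sketch of the new call shape, assuming an existing `SymbolVar x` in a built graph (the param and policy values here are placeholders):

    opr::Pooling::Param param;             // window/stride/mode as needed
    opr::Pooling::ExecutionPolicy policy;  // e.g. strategy = PROFILE
    auto y = opr::Pooling::make(x, param, policy);  // config defaults to {}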

@@ -31,7 +31,21 @@ using namespace opr;
 namespace {
 template <class MegDNNConv = megdnn::Convolution>
-struct MakeConvCaller2 {
+struct MakeOprWithPolicyCaller1 {
+    template <typename Opr>
+    static VarNode* make(
+            const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,
+            const megdnn::param::ExecutionPolicy& execution_policy,
+            const OperatorNodeConfig& config) {
+        if (inputs.size() == 1) {
+            return Opr::make(inputs[0], param, execution_policy, config).node();
+        }
+        return nullptr;
+    }
+};
+
+template <class MegDNNConv = megdnn::Convolution>
+struct MakeOprWithPolicyCaller2 {
     template <typename Opr>
     static VarNode* make(
             const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,

@@ -46,7 +60,7 @@ struct MakeConvCaller2 {
 };
 template <class MegDNNConv = megdnn::Convolution>
-struct MakeConvCaller3 {
+struct MakeOprWithPolicyCaller3 {
     template <typename Opr>
     static VarNode* make(
             const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,

@@ -63,7 +77,7 @@ struct MakeConvCaller3 {
 };
 template <class MegDNNConv = megdnn::Convolution>
-struct MakeConvCaller4 {
+struct MakeOprWithPolicyCaller4 {
     template <typename Opr>
     static VarNode* make(
             const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,

@@ -80,7 +94,7 @@ struct MakeConvCaller4 {
 };
 template <class MegDNNConv = megdnn::Convolution>
-struct MakeConvCaller5 {
+struct MakeOprWithPolicyCaller5 {
     template <typename Opr>
     static VarNode* make(
             const cg::VarNodeArray& inputs, const typename MegDNNConv::Param& param,

@@ -97,7 +111,7 @@ struct MakeConvCaller5 {
 };
 template <class MegDNNConv = megdnn::Convolution>
-struct MakeConvCallerEmpty {
+struct MakeOprWithPolicyCallerEmpty {
     template <typename Opr>
     static VarNode* make(
             const cg::VarNodeArray&, const typename MegDNNConv::Param&,

@@ -108,10 +122,10 @@ struct MakeConvCallerEmpty {
 template <
         class Opr, class Maker0, class MegDNNConv,
-        class Maker1 = MakeConvCallerEmpty<MegDNNConv>,
-        class Maker2 = MakeConvCallerEmpty<MegDNNConv>,
-        typename ConvParam = megdnn::param::Convolution>
-struct ConvMakerImpl {
+        class Maker1 = MakeOprWithPolicyCallerEmpty<MegDNNConv>,
+        class Maker2 = MakeOprWithPolicyCallerEmpty<MegDNNConv>,
+        typename ConvParam = typename MegDNNConv::Param>
+struct OprWithPolicyMakerImpl {
     static VarNode* make(
             const cg::VarNodeArray& inputs, const ConvParam& param,
             const megdnn::param::ExecutionPolicy& execution_policy,

@@ -130,33 +144,43 @@ struct ConvMakerImpl {
 };
 template <typename Opr>
-struct ConvMaker;
+struct OprWithPolicyMaker;
+template <>
+struct OprWithPolicyMaker<opr::Pooling>
+        : public OprWithPolicyMakerImpl<
+                  opr::Pooling, MakeOprWithPolicyCaller1<megdnn::Pooling>,
+                  megdnn::Pooling> {};
 template <>
-struct ConvMaker<opr::Convolution>
-        : public ConvMakerImpl<
-                  opr::Convolution, MakeConvCaller2<megdnn::Convolution>,
+struct OprWithPolicyMaker<opr::Convolution>
+        : public OprWithPolicyMakerImpl<
+                  opr::Convolution, MakeOprWithPolicyCaller2<megdnn::Convolution>,
                   megdnn::Convolution> {};
 template <>
-struct ConvMaker<opr::ConvolutionBackwardData>
-        : public ConvMakerImpl<
-                  opr::ConvolutionBackwardData, MakeConvCaller2<megdnn::Convolution>,
-                  megdnn::Convolution, MakeConvCaller3<megdnn::Convolution>> {};
+struct OprWithPolicyMaker<opr::ConvolutionBackwardData>
+        : public OprWithPolicyMakerImpl<
+                  opr::ConvolutionBackwardData,
+                  MakeOprWithPolicyCaller2<megdnn::Convolution>, megdnn::Convolution,
+                  MakeOprWithPolicyCaller3<megdnn::Convolution>> {};
 template <>
-struct ConvMaker<opr::ConvBiasForward>
-        : public ConvMakerImpl<
-                  opr::ConvBiasForward, MakeConvCaller2<megdnn::ConvBiasForward>,
-                  megdnn::ConvBiasForward, MakeConvCaller3<megdnn::ConvBiasForward>,
-                  MakeConvCaller4<megdnn::ConvBiasForward>, megdnn::param::ConvBias> {};
+struct OprWithPolicyMaker<opr::ConvBiasForward>
+        : public OprWithPolicyMakerImpl<
+                  opr::ConvBiasForward,
+                  MakeOprWithPolicyCaller2<megdnn::ConvBiasForward>,
+                  megdnn::ConvBiasForward,
+                  MakeOprWithPolicyCaller3<megdnn::ConvBiasForward>,
+                  MakeOprWithPolicyCaller4<megdnn::ConvBiasForward>,
+                  megdnn::param::ConvBias> {};
 template <>
-struct ConvMaker<opr::BatchConvBiasForward>
-        : public ConvMakerImpl<
+struct OprWithPolicyMaker<opr::BatchConvBiasForward>
+        : public OprWithPolicyMakerImpl<
                   opr::BatchConvBiasForward,
-                  MakeConvCaller2<megdnn::BatchConvBiasForward>,
+                  MakeOprWithPolicyCaller2<megdnn::BatchConvBiasForward>,
                   megdnn::BatchConvBiasForward,
-                  MakeConvCaller3<megdnn::BatchConvBiasForward>,
-                  MakeConvCaller4<megdnn::BatchConvBiasForward>,
+                  MakeOprWithPolicyCaller3<megdnn::BatchConvBiasForward>,
+                  MakeOprWithPolicyCaller4<megdnn::BatchConvBiasForward>,
                   megdnn::param::BatchConvBias> {};

 #include "../../opr/impl/internal/invoke.h"

@@ -254,7 +278,7 @@ struct OprFormatModifier;
         auto&& opr = opr_->cast_final_safe<_Opr>();          \
         auto param = opr.param();                            \
         param.format = opr_format;                           \
-        return ConvMaker<_Opr>::make(                        \
+        return OprWithPolicyMaker<_Opr>::make(               \
                 i, param, opr.execution_policy(), opr.config()); \
         MIDOUT_E                                             \
     }                                                        \

@@ -263,6 +287,7 @@ INST(Convolution);
 INST(ConvBiasForward);
 INST(ConvolutionBackwardData);
 INST(BatchConvBiasForward);
+INST(Pooling);
 #undef INST
 template <>

@@ -303,7 +328,6 @@ struct OprFormatModifier<WarpPerspective> {
         MIDOUT_E \
     }            \
 };
-INST(PoolingForward, 1);
 INST(Resize, 2);
 #undef INST

@@ -1492,7 +1492,8 @@ std::unique_ptr<ConvertFormatPass> ConvertFormatPass::make_nhwcd4_converter() {
         }
         auto new_param = pooling_opr.param();
         new_param.format = megdnn::param::Pooling::Format::NHWCD4;
-        auto new_pooling_opr = opr::PoolingForward::make(inp, new_param, opr->config());
+        auto new_pooling_opr = opr::PoolingForward::make(
+                inp, new_param, pooling_opr.execution_policy(), opr->config());
         return new_pooling_opr.node()->owner_opr();
     };

@@ -525,8 +525,8 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass::
             }
             auto new_param = pooling.param();
             new_param.format = Format::NCHW32;
-            auto new_pooling =
-                    opr::PoolingForward::make(new_inp_var, new_param, opr->config());
+            auto new_pooling = opr::PoolingForward::make(
+                    new_inp_var, new_param, pooling.execution_policy(), opr->config());
             return new_pooling.node()->owner_opr();
         }
         return serialization::copy_opr_shallow(*opr, new_inp, opr->config());

@@ -795,8 +795,8 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
         if (varshape_changed.count(new_inp[0])) {
             auto new_param = pooling.param();
             new_param.format = Format::CHWN4;
-            auto new_pooling =
-                    opr::PoolingForward::make(new_inp[0], new_param, opr->config());
+            auto new_pooling = opr::PoolingForward::make(
+                    new_inp[0], new_param, pooling.execution_policy(), opr->config());
             varshape_changed.insert(new_pooling.node());
             return new_pooling.node()->owner_opr();
         }

@@ -1174,8 +1174,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
         mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8);
         auto new_param = pooling.param();
         new_param.format = Format::NCHW4;
-        auto new_pooling =
-                opr::PoolingForward::make(new_inp[0], new_param, opr->config());
+        auto new_pooling = opr::PoolingForward::make(
+                new_inp[0], new_param, pooling.execution_policy(), opr->config());
         mgb_assert(
                 new_pooling.shape().ndim == 5,
                 "out var of Pooling opr after transform must be 5 (got: "

@@ -1646,8 +1646,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
         if (inp->shape().ndim == 5) {
             auto new_param = pooling_opr.param();
             new_param.format = pooling_format;
-            auto new_pooling_opr =
-                    opr::PoolingForward::make(inp, new_param, opr->config());
+            auto new_pooling_opr = opr::PoolingForward::make(
+                    inp, new_param, pooling_opr.execution_policy(), opr->config());
             mgb_assert(
                     new_pooling_opr.shape().ndim == 5,
                     "The pooling dst dim is not trans to nchwxx");

@@ -3003,7 +3003,8 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
         auto target_format = cur == Format::NCHW64 ? cur : Format::NHWC;
         auto param = pooling.param();
         param.format = target_format;
-        auto new_pool = opr::PoolingForward::make(inps[0], param, pooling.config());
+        auto new_pool = opr::PoolingForward::make(
+                inps[0], param, pooling.execution_policy(), pooling.config());
         auto ret = new_pool.node()->owner_opr();
         format_map.insert(std::make_pair(ret, target_format));
         return ret;

@@ -3055,7 +3056,8 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
         auto param = pooling.param();
         param.format = out_format;
-        auto new_pool = opr::PoolingForward::make(inps[0], param, pooling.config());
+        auto new_pool = opr::PoolingForward::make(
+                inps[0], param, pooling.execution_policy(), pooling.config());
         auto ret = new_pool.node()->owner_opr();
         format_map.insert(std::make_pair(ret, out_format));
         return ret;

@@ -281,7 +281,7 @@ TEST(TestLayoutTransform, Resnet18_QS4) {
     auto new_out_var = new_output[0];
     /// check global layout transform pass
     auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var);
-    ASSERT_EQ(nr_dimshuffle, 3u);
+    ASSERT_EQ(nr_dimshuffle, 5u);
     /// check pass fuse conv bias with z
     auto nr_elemwise_mult_type = find_opr_num<opr::ElemwiseMultiType>(new_out_var);
     ASSERT_EQ(nr_elemwise_mult_type, 4u);

@@ -822,7 +822,7 @@ TEST(TestLayoutTransform, Resnet18_F16) {
     auto new_out_var = new_output[0];
     /// check global layout transform pass
     auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var);
-    ASSERT_EQ(nr_dimshuffle, 4u);
+    ASSERT_EQ(nr_dimshuffle, 2u);
     /// check pass fuse conv bias with z
     auto nr_elemwise = find_opr_num<opr::Elemwise>(new_out_var);
     ASSERT_EQ(nr_elemwise, 4u);

@@ -80,14 +80,26 @@ struct OprLoadDumpImpl<opr::BatchedMatrixMul, 2>
                   opr::BatchedMatrixMul, MakeMatrixMulCaller<megdnn::BatchedMatrixMul>,
                   megdnn::BatchedMatrixMul> {};

+template <typename Opr>
+cg::OperatorNodeBase* opr_shallow_copy_matmul(
+        const serialization::OprShallowCopyContext& ctx,
+        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
+        const OperatorNodeConfig& config) {
+    MGB_MARK_USED_VAR(ctx);
+    auto&& opr = opr_.cast_final_safe<Opr>();
+    return OprLoadDumpImpl<Opr, 2>::make(
+                   inputs, opr.param(), opr.execution_policy_transient(), config)
+            ->owner_opr();
+}
 }  // namespace serialization

 namespace opr {
 using MatrixMulV2 = MatrixMul;
 using BatchedMatrixMulV2 = BatchedMatrixMul;
-MGB_SEREG_OPR(MatrixMulV2, 2);
-MGB_SEREG_OPR(BatchedMatrixMulV2, 2);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(MatrixMulV2, 2, opr_shallow_copy_matmul);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(BatchedMatrixMulV2, 2, opr_shallow_copy_matmul);
 MGB_SEREG_OPR(Dot, 2);
 MGB_SEREG_OPR(MatrixInverse, 1);
 MGB_SEREG_OPR(SVD, 1);
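
The effect of registering this helper is that a shallow copy now carries the operator's algorithm-selection strategy over to the copy instead of dropping it. A minimal sketch, assuming `opr0` is a MatrixMul node built with a PROFILE-strategy `policy` and `inputs1` holds the replacement input vars (this mirrors the tests added at the end of this diff):

    auto opr1 = serialization::copy_opr_shallow(*opr0, inputs1, OperatorNodeConfig{});
    auto&& copied = opr1->cast_final_safe<opr::MatrixMul>();
    mgb_assert(copied.execution_policy().strategy == policy.strategy);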

@@ -36,9 +36,10 @@ struct MakePoolingCaller1 {
     template <typename Opr>
     static VarNode* make(
             const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param,
+            const megdnn::param::ExecutionPolicy& execution_policy,
             const OperatorNodeConfig& config) {
         if (inputs.size() == 1) {
-            return Opr::make(inputs[0], param, config).node();
+            return Opr::make(inputs[0], param, execution_policy, config).node();
         }
         return nullptr;
     }

@@ -78,9 +79,13 @@ struct MakePoolingBackwardCaller3 {
     template <typename Opr>
     static VarNode* make(
             const cg::VarNodeArray& inputs, const typename MegDNNPooling::Param& param,
+            const megdnn::param::ExecutionPolicy& execution_policy,
             const OperatorNodeConfig& config) {
         if (inputs.size() == 3) {
-            return Opr::make(inputs[0], inputs[1], inputs[2], param, config).node();
+            return Opr::make(
+                           inputs[0], inputs[1], inputs[2], param, execution_policy,
+                           config)
+                    .node();
         }
         return nullptr;
     }

@@ -223,8 +228,10 @@ struct PoolingLoadDumpImpl {
     static VarNode* make(
             const cg::VarNodeArray& inputs, const PoolingParam& param,
+            const megdnn::param::ExecutionPolicy& execution_policy,
             const OperatorNodeConfig& config) {
-        VarNode* ret = Maker0::template make<Opr>(inputs, param, config);
+        VarNode* ret =
+                Maker0::template make<Opr>(inputs, param, execution_policy, config);
         mgb_assert(ret);
         return ret;
     }

@@ -233,6 +240,29 @@ struct PoolingLoadDumpImpl {
             OprLoadContext& ctx, const cg::VarNodeArray& inputs,
             const OperatorNodeConfig& config) {
         auto param = ctx.read_param<PoolingParam>();
+        return make(inputs, param, {}, config)->owner_opr();
+    }
+};
+
+template <class Opr, class Maker0, typename GeneralOprParam = megdnn::param::ROIAlign>
+struct GeneralOprLoadDumpImpl {
+    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
+        auto&& opr = opr_.cast_final_safe<Opr>();
+        ctx.write_param<GeneralOprParam>(opr.param());
+    }
+    static VarNode* make(
+            const cg::VarNodeArray& inputs, const GeneralOprParam& param,
+            const OperatorNodeConfig& config) {
+        VarNode* ret = Maker0::template make<Opr>(inputs, param, config);
+        mgb_assert(ret);
+        return ret;
+    }
+    static cg::OperatorNodeBase* load(
+            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
+            const OperatorNodeConfig& config) {
+        auto param = ctx.read_param<GeneralOprParam>();
         return make(inputs, param, config)->owner_opr();
     }
 };

@@ -264,26 +294,26 @@ struct OprMaker<opr::LSQBackward, 5> {
 };
 template <>
 struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0>
-        : public PoolingLoadDumpImpl<
+        : public GeneralOprLoadDumpImpl<
                   opr::AdaptivePoolingBackward,
                   MakeAdaptivePoolingBackwardCaller3<megdnn::AdaptivePoolingBackward>,
                   megdnn::param::AdaptivePooling> {};
 template <>
 struct OprLoadDumpImpl<opr::AdaptivePooling, 0>
-        : public PoolingLoadDumpImpl<
+        : public GeneralOprLoadDumpImpl<
                   opr::AdaptivePooling, MakeROIAlignCaller1<megdnn::AdaptivePooling>,
                   megdnn::param::AdaptivePooling> {};
 template <>
 struct OprLoadDumpImpl<opr::ROIAlign, 0>
-        : public PoolingLoadDumpImpl<
+        : public GeneralOprLoadDumpImpl<
                   opr::ROIAlign, MakeROIAlignCaller1<megdnn::ROIAlign>,
                   megdnn::param::ROIAlign> {};
 template <>
 struct OprLoadDumpImpl<opr::ROIAlignBackward, 0>
-        : public PoolingLoadDumpImpl<
+        : public GeneralOprLoadDumpImpl<
                   opr::ROIAlignBackward, MakeROIAlignCaller4<megdnn::ROIAlignBackward>,
                   megdnn::param::ROIAlign> {};

@@ -500,15 +530,29 @@ struct OprLoadDumpImpl<opr::DeformableConvBackwardFilter, 0>
                   opr::DeformableConvBackwardFilter,
                   MakeConvCaller5<megdnn::DeformableConvBackwardFilter>,
                   megdnn::Convolution> {};

+template <typename Opr>
+cg::OperatorNodeBase* opr_shallow_copy_conv(
+        const serialization::OprShallowCopyContext& ctx,
+        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
+        const OperatorNodeConfig& config) {
+    MGB_MARK_USED_VAR(ctx);
+    auto&& opr = opr_.cast_final_safe<Opr>();
+    return OprLoadDumpImpl<Opr, 0>::make(
+                   inputs, opr.param(), opr.execution_policy_transient(), config)
+            ->owner_opr();
+}
 }  // namespace serialization

 namespace opr {
 using ConvolutionV2 = Convolution;
 using ConvolutionBackwardDataV2 = ConvolutionBackwardData;
 using ConvolutionBackwardFilterV2 = ConvolutionBackwardFilter;
-MGB_SEREG_OPR(ConvolutionV2, 0);
-MGB_SEREG_OPR(ConvolutionBackwardDataV2, 0);
-MGB_SEREG_OPR(ConvolutionBackwardFilterV2, 0);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvolutionV2, 0, opr_shallow_copy_conv);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvolutionBackwardDataV2, 0, opr_shallow_copy_conv);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
+        ConvolutionBackwardFilterV2, 0, opr_shallow_copy_conv);

 MGB_SEREG_OPR(Images2Neibs, 1);
 MGB_SEREG_OPR(Images2NeibsBackward, 2);

@@ -534,8 +578,8 @@ MGB_SEREG_OPR(LRN, 1);
 MGB_SEREG_OPR(LRNBackward, 3);
 using PoolingV1 = Pooling;
 using PoolingBackwardV1 = PoolingBackward;
-MGB_SEREG_OPR(PoolingV1, 1);
-MGB_SEREG_OPR(PoolingBackwardV1, 3);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(PoolingV1, 0, opr_shallow_copy_conv);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(PoolingBackwardV1, 0, opr_shallow_copy_conv);
 using AdaptivePoolingV1 = AdaptivePooling;
 using AdaptivePoolingBackwardV1 = AdaptivePoolingBackward;
 MGB_SEREG_OPR(AdaptivePoolingV1, 2);

@@ -548,12 +592,13 @@ using MaskConvolutionV2 = MaskConvolution;
 MGB_SEREG_OPR(MaskConvolutionV2, 3);
 MGB_SEREG_OPR(MaskPropagate, 1);

-MGB_SEREG_OPR(Convolution3D, 0);
-MGB_SEREG_OPR(Convolution3DBackwardData, 0);
-MGB_SEREG_OPR(Convolution3DBackwardFilter, 0);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(Convolution3D, 0, opr_shallow_copy_conv);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(Convolution3DBackwardData, 0, opr_shallow_copy_conv);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
+        Convolution3DBackwardFilter, 0, opr_shallow_copy_conv);

 using ConvBiasForwardV4 = ConvBiasForward;
-MGB_SEREG_OPR(ConvBiasForwardV4, 0);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(ConvBiasForwardV4, 0, opr_shallow_copy_conv);

 using BatchNormV1 = BatchNorm;
 using BatchNormBackwardV1 = BatchNormBackward;

@@ -563,9 +608,10 @@ MGB_SEREG_OPR(BatchNormBackwardV1, 6);
 using LocalShareForwardV1 = LocalShareForward;
 using LocalShareBackwardDataV1 = LocalShareBackwardData;
 using LocalShareBackwardFilterV1 = LocalShareBackwardFilter;
-MGB_SEREG_OPR(LocalShareForwardV1, 0);
-MGB_SEREG_OPR(LocalShareBackwardDataV1, 0);
-MGB_SEREG_OPR(LocalShareBackwardFilterV1, 0);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(LocalShareForwardV1, 0, opr_shallow_copy_conv);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(LocalShareBackwardDataV1, 0, opr_shallow_copy_conv);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
+        LocalShareBackwardFilterV1, 0, opr_shallow_copy_conv);

 using ROIAlignV1 = ROIAlign;
 using ROIAlignBackwardV1 = ROIAlignBackward;

@@ -574,9 +620,11 @@ MGB_SEREG_OPR(ROIAlignBackwardV1, 4);
 using DeformableConvForwardV1 = DeformableConvForward;
 using DeformableConvBackwardDataV1 = DeformableConvBackwardData;
 using DeformableConvBackwardFilterV1 = DeformableConvBackwardFilter;
-MGB_SEREG_OPR(DeformableConvForwardV1, 0);
-MGB_SEREG_OPR(DeformableConvBackwardDataV1, 0);
-MGB_SEREG_OPR(DeformableConvBackwardFilterV1, 0);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(DeformableConvForwardV1, 0, opr_shallow_copy_conv);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
+        DeformableConvBackwardDataV1, 0, opr_shallow_copy_conv);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(
+        DeformableConvBackwardFilterV1, 0, opr_shallow_copy_conv);

 MGB_SEREG_OPR(CorrelationForward, 2);
 MGB_SEREG_OPR(CorrelationBackwardData1, 3);

@@ -586,7 +634,7 @@ MGB_SEREG_OPR(DeformablePSROIPoolingForward, 3);
 MGB_SEREG_OPR(DeformablePSROIPoolingBackward, 5);

 using BatchConvBiasForwardV1 = BatchConvBiasForward;
-MGB_SEREG_OPR(BatchConvBiasForwardV1, 0);
+MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(BatchConvBiasForwardV1, 0, opr_shallow_copy_conv);
 MGB_SEREG_OPR(FakeQuant, 3);
 MGB_SEREG_OPR(FakeQuantBackward, 4);
 MGB_SEREG_OPR(TQT, 2);

@@ -32,8 +32,8 @@ PoolingForward::PoolingForward(
 }

 SymbolVar PoolingForward::make(
-        SymbolVar i0, const Param& param, const OperatorNodeConfig& config,
-        const ExecutionPolicy& policy) {
+        SymbolVar i0, const Param& param, const ExecutionPolicy& policy,
+        const OperatorNodeConfig& config) {
     intl::MegDNNOprInitInputsModifier<PoolingForward>::apply(param, {&i0});
     return i0.insert_single_output_opr<PoolingForward>(
             i0.node(), param, policy, config);

@@ -75,12 +75,13 @@ PoolingBackward::PoolingBackward(
                   0, true) {
     init_megdnn_opr(*this, param);
     add_input({i0, i1, i2});
+    m_policy = policy;
     intl::MegDNNOprInitPostCtor<PoolingBackward>::apply(*this);
 }

 SymbolVar PoolingBackward::make(
         SymbolVar i0, SymbolVar i1, SymbolVar i2, const Param& param,
-        const OperatorNodeConfig& config, const ExecutionPolicy& policy) {
+        const ExecutionPolicy& policy, const OperatorNodeConfig& config) {
     intl::MegDNNOprInitInputsModifier<PoolingBackward>::apply(param, {&i0, &i1, &i2});
     return i0.insert_single_output_opr<PoolingBackward>(
             i0.node(), i1.node(), i2.node(), param, policy, config);

@@ -26,8 +26,8 @@ MGE_WIN_DECLSPEC_FUC PoolingForward(
         VarNode* src, const Param& param, const ExecutionPolicy& policy,
         const OperatorNodeConfig& config);
 MGE_WIN_DECLSPEC_FUC static SymbolVar make(
-        SymbolVar src, const Param& param, const OperatorNodeConfig& config = {},
-        const ExecutionPolicy& policy = {});
+        SymbolVar src, const Param& param, const ExecutionPolicy& policy = {},
+        const OperatorNodeConfig& config = {});

 void init_output_static_infer_desc() override;

@@ -47,7 +47,7 @@ MGE_WIN_DECLSPEC_FUC PoolingBackward(
 MGE_WIN_DECLSPEC_FUC static SymbolVar make(
         SymbolVar src, SymbolVar dst, SymbolVar diff, const Param& param,
-        const OperatorNodeConfig& config = {}, const ExecutionPolicy& policy = {});
+        const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {});

 MGE_WIN_DECLSPEC_FUC size_t get_workspace_size_bytes(
         const TensorShapeArray& input_shapes,
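
With the reordered signatures, call sites pass the policy positionally and let the config default, which matches the test updates at the end of this diff. A minimal sketch, assuming a graph with a `SymbolVar input`:

    using Policy = opr::PoolingForward::ExecutionPolicy;
    opr::PoolingForward::Param param;  // pooling mode/window/stride as needed
    Policy policy;
    policy.strategy = Policy::Strategy::PROFILE;
    auto pooling = opr::PoolingForward::make(input, param, policy);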

@@ -15,7 +15,9 @@
 #include "megbrain/opr/basic_arith.h"
 #include "megbrain/opr/blas.h"
 #include "megbrain/opr/dnn/convolution.h"
+#include "megbrain/opr/dnn/pooling.h"
 #include "megbrain/opr/tensor_manip.h"
+#include "megbrain/serialization/opr_shallow_copy.h"
 #include "megbrain/serialization/serializer.h"
 #include "megbrain/test/autocheck.h"
 #include "megbrain/test/helper.h"

@@ -32,39 +34,24 @@ using namespace mgb;
 namespace {

-#if MGB_CUDA
-#if MGB_ENABLE_FASTRUN
 template <typename MgbOpr, int arith>
 struct GraphMaker;

-template <typename MgbOpr>
-struct GraphMaker<MgbOpr, 2> {
-    SymbolVar operator()(
-            const std::array<cg::SymbolVar, 2>& inputs, typename MgbOpr::Param& param,
-            typename MgbOpr::ExecutionPolicy& policy) {
-        return MgbOpr::make(inputs[0], inputs[1], param, policy);
-    }
-};
-
 template <>
-struct GraphMaker<opr::ConvolutionBackwardData, 2> {
+struct GraphMaker<opr::Pooling, 1> {
     SymbolVar operator()(
-            const std::array<cg::SymbolVar, 2>& inputs,
-            opr::ConvolutionBackwardData::Param& param,
-            opr::ConvolutionBackwardData::ExecutionPolicy& policy) {
-        return opr::ConvolutionBackwardData::make_deconv(
-                inputs[0], inputs[1], param, policy);
+            const std::array<cg::SymbolVar, 1>& inputs, opr::Pooling::Param& param,
+            opr::Pooling::ExecutionPolicy& policy) {
+        return opr::Pooling::make(inputs[0], param, policy);
     }
 };

-template <>
-struct GraphMaker<opr::Convolution3DBackwardData, 2> {
+template <typename MgbOpr>
+struct GraphMaker<MgbOpr, 2> {
     SymbolVar operator()(
-            const std::array<cg::SymbolVar, 2>& inputs,
-            opr::Convolution3DBackwardData::Param& param,
-            opr::Convolution3DBackwardData::ExecutionPolicy& policy) {
-        return opr::Convolution3DBackwardData::make_deconv(
-                inputs[0], inputs[1], param, policy);
+            const std::array<cg::SymbolVar, 2>& inputs, typename MgbOpr::Param& param,
+            typename MgbOpr::ExecutionPolicy& policy) {
+        return MgbOpr::make(inputs[0], inputs[1], param, policy);
     }
 };

@@ -98,6 +85,37 @@ struct GraphMaker<MgbOpr, 5> {
     }
 };

+template <typename MgbOpr, int arith, typename dtype = dtype::Float32>
+void test_execution_policy_shallow_copy(
+        std::array<TensorShape, arith> shapes, typename MgbOpr::Param param = {}) {
+    using Policy = typename MgbOpr::ExecutionPolicy;
+    Policy policy;
+    policy.strategy = Policy::Strategy::PROFILE;
+
+    auto cn = CompNode::load("cpu0");
+    auto graph0 = ComputingGraph::make(), graph1 = ComputingGraph::make();
+    std::array<cg::SymbolVar, arith> inputs0;
+    VarNodeArray inputs1;
+    for (size_t i = 0; i < arith; ++i) {
+        HostTensorND hi{cn, shapes[i], dtype()};
+        inputs0[i] = opr::ImmutableTensor::make(*graph0, hi);
+        inputs1.push_back(opr::ImmutableTensor::make(*graph1, hi).node());
+    }
+
+    GraphMaker<MgbOpr, arith> graph_maker;
+    auto opr0 = graph_maker(inputs0, param, policy).node()->owner_opr();
+    auto opr1 = serialization::copy_opr_shallow(*opr0, inputs1, OperatorNodeConfig{});
+
+    auto m0 = &(opr0->template cast_final<MgbOpr>());
+    auto m1 = &(opr1->template cast_final<MgbOpr>());
+    ASSERT_EQ(policy.strategy, m0->execution_policy().strategy);
+    ASSERT_EQ(policy.strategy, m1->execution_policy().strategy);
+}
+
+#if MGB_CUDA
+#if MGB_ENABLE_FASTRUN
 template <typename MgbOpr, int arith, typename dtype = dtype::Float32>
 void test_fastrun_opr(
         std::array<TensorShape, arith> inps0, std::array<TensorShape, arith> inps1,

@@ -162,16 +180,24 @@ void test_fastrun_opr(
     size_t nr_set_total = expect_nr_cache_set_inp1 + nr_set_inp0;
     ASSERT_EQ(cache_set_history.size(), nr_set_total);
 }
+#endif  // MGB_ENABLE_FASTRUN
+#endif  // MGB_CUDA
+
+}  // anonymous namespace
+
+#if MGB_CUDA
+#if MGB_ENABLE_FASTRUN
 TEST(TestOprDNN, FastrunIgnoreBatchSizeConvolution) {
     REQUIRE_GPU(1);

     test_fastrun_opr<opr::Convolution, 2>(
             {TensorShape{12, 3, 36, 36}, TensorShape{4, 3, 3, 3}},
             {TensorShape{1, 3, 36, 36}, TensorShape{4, 3, 3, 3}});

-    test_fastrun_opr<opr::ConvolutionBackwardData, 2>(
-            {TensorShape{12, 4, 23, 29}, TensorShape{4, 5, 3, 2}},
-            {TensorShape{2, 4, 23, 29}, TensorShape{4, 5, 3, 2}});
+    test_fastrun_opr<opr::ConvolutionBackwardData, 3>(
+            {TensorShape{4, 5, 3, 2}, TensorShape{12, 4, 23, 29},
+             TensorShape{12, 5, 25, 30}},
+            {TensorShape{4, 5, 3, 2}, TensorShape{2, 4, 23, 29},
+             TensorShape{2, 5, 25, 30}});

     test_fastrun_opr<opr::ConvolutionBackwardFilter, 3>(
             {TensorShape{12, 4, 23, 29}, TensorShape{12, 5, 21, 28},

@@ -195,9 +221,11 @@ TEST(TestOprDNN, FastrunIgnoreBatchSizeConvolution3D) {
             {TensorShape{8, 4, 12, 13, 14}, TensorShape{4, 4, 3, 3, 3}},
             {TensorShape{3, 4, 12, 13, 14}, TensorShape{4, 4, 3, 3, 3}});

-    test_fastrun_opr<opr::Convolution3DBackwardData, 2>(
-            {TensorShape{14, 5, 12, 12, 16}, TensorShape{5, 5, 3, 3, 3}},
-            {TensorShape{4, 5, 12, 12, 16}, TensorShape{5, 5, 3, 3, 3}});
+    test_fastrun_opr<opr::Convolution3DBackwardData, 3>(
+            {TensorShape{5, 5, 3, 3, 3}, TensorShape{14, 5, 12, 12, 16},
+             TensorShape{14, 5, 14, 14, 18}},
+            {TensorShape{5, 5, 3, 3, 3}, TensorShape{4, 5, 12, 12, 16},
+             TensorShape{4, 5, 14, 14, 18}});

     test_fastrun_opr<opr::Convolution3DBackwardFilter, 3>(
             {TensorShape{64, 16, 18, 18, 18}, TensorShape{64, 16, 18, 18, 18},

@@ -295,6 +323,87 @@ TEST(TestOprDNN, FastrunIgnoreBatchSizeBatchedMatrixMul) {
 #endif  // MGB_ENABLE_FASTRUN
 #endif  // MGB_CUDA

-}  // anonymous namespace
+TEST(TestOprDNN, ExecutionPolicyShallowCopyConvolution) {
+    test_execution_policy_shallow_copy<opr::Convolution, 2>(
+            {TensorShape{12, 3, 36, 36}, TensorShape{4, 3, 3, 3}});
+    test_execution_policy_shallow_copy<opr::ConvolutionBackwardData, 3>(
+            {TensorShape{4, 5, 3, 2}, TensorShape{12, 4, 23, 29},
+             TensorShape{12, 5, 25, 30}});
+    test_execution_policy_shallow_copy<opr::ConvolutionBackwardFilter, 3>(
+            {TensorShape{12, 4, 23, 29}, TensorShape{12, 5, 21, 28},
+             TensorShape{5, 4, 3, 2}});
+}
+
+TEST(TestOprDNN, ExecutionPolicyShallowCopyConvBias) {
+    test_execution_policy_shallow_copy<opr::ConvBias, 3>(
+            {TensorShape{20, 16, 50, 50}, TensorShape{24, 16, 3, 3},
+             TensorShape{1, 24, 1, 1}});
+}
+
+TEST(TestOprDNN, ExecutionPolicyShallowCopyConvolution3D) {
+    test_execution_policy_shallow_copy<opr::Convolution3D, 2>(
+            {TensorShape{8, 4, 12, 13, 14}, TensorShape{4, 4, 3, 3, 3}});
+    test_execution_policy_shallow_copy<opr::Convolution3DBackwardData, 3>(
+            {TensorShape{5, 5, 3, 3, 3}, TensorShape{14, 5, 12, 12, 16},
+             TensorShape{14, 5, 14, 14, 18}});
+    test_execution_policy_shallow_copy<opr::Convolution3DBackwardFilter, 3>(
+            {TensorShape{64, 16, 18, 18, 18}, TensorShape{64, 16, 18, 18, 18},
+             TensorShape{16, 16, 1, 1, 1}});
+}
+
+TEST(TestOprDNN, ExecutionPolicyShallowCopyLocalShare) {
+    opr::LocalShare::Param local_share_param;
+    local_share_param.mode = opr::LocalShare::Param::Mode::CROSS_CORRELATION;
+    local_share_param.pad_h = local_share_param.pad_w = 1;
+    local_share_param.stride_h = local_share_param.stride_w = 1;
+    local_share_param.spatial_groups_h = local_share_param.spatial_groups_w = 2;
+    test_execution_policy_shallow_copy<opr::LocalShareForward, 2>(
+            {TensorShape{32, 2, 23, 23}, TensorShape{2, 2, 2, 2, 2, 7}},
+            local_share_param);
+    test_execution_policy_shallow_copy<opr::LocalShareBackwardData, 3>(
+            {TensorShape{3, 3, 128, 1, 1, 128}, TensorShape{32, 128, 24, 24},
+             TensorShape{32, 128, 24, 24}});
+    test_execution_policy_shallow_copy<opr::LocalShareBackwardFilter, 3>(
+            {TensorShape{12, 3, 36, 36}, TensorShape{12, 4, 35, 35},
+             TensorShape{3, 3, 3, 3, 3, 4}});
+}
+
+TEST(TestOprDNN, ExecutionPolicyShallowCopyDeformableConv) {
+    test_execution_policy_shallow_copy<opr::DeformableConvForward, 4>(
+            {TensorShape{12, 6, 20, 20}, TensorShape{6, 6, 3, 3},
+             TensorShape{12, 18, 18, 18}, TensorShape{12, 9, 18, 18}});
+    test_execution_policy_shallow_copy<opr::DeformableConvBackwardData, 5>(
+            {TensorShape{12, 6, 20, 20}, TensorShape{6, 6, 3, 3},
+             TensorShape{12, 18, 18, 18}, TensorShape{12, 9, 18, 18},
+             TensorShape{12, 6, 18, 18}});
+    test_execution_policy_shallow_copy<opr::DeformableConvBackwardFilter, 5>(
+            {TensorShape{12, 6, 20, 20}, TensorShape{6, 6, 3, 3},
+             TensorShape{12, 18, 18, 18}, TensorShape{12, 9, 18, 18},
+             TensorShape{12, 6, 18, 18}});
+}
+
+TEST(TestOprDNN, ExecutionPolicyShallowCopyMatrixMul) {
+    test_execution_policy_shallow_copy<opr::MatrixMul, 2>(
+            {TensorShape{10, 12}, TensorShape{12, 12}});
+    test_execution_policy_shallow_copy<opr::BatchedMatrixMul, 2>(
+            {TensorShape{12, 6, 8}, TensorShape{12, 8, 4}});
+}
+
+TEST(TestOprDNN, ExecutionPolicyShallowCopyPooling) {
+    test_execution_policy_shallow_copy<opr::Pooling, 1>({TensorShape{1, 20, 24, 24}});
+    test_execution_policy_shallow_copy<opr::PoolingBackward, 3>(
+            {TensorShape{1, 20, 24, 24}, TensorShape{1, 20, 12, 12},
+             TensorShape{1, 20, 12, 12}});
+}

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

@@ -135,7 +135,7 @@ TEST(TestOprDNN, PoolingExePolicy) {
     Policy policy;
     policy.strategy = strategy;

-    auto pooling = opr::PoolingForward::make(input, param, {}, policy);
+    auto pooling = opr::PoolingForward::make(input, param, policy);
     auto loss0 = opr::reduce_sum_sqr(pooling, pooling.make_scalar(1));
     auto grad = cg::grad(loss0, input, true, false);

@@ -187,7 +187,7 @@ TEST(TestOprDNN, PoolingForwardFastrun) {
     Policy policy;
     policy.strategy = strategy;

-    auto pooling = opr::PoolingForward::make(input, param, {}, policy);
+    auto pooling = opr::PoolingForward::make(input, param, policy);
     auto func = graph->compile({make_callback_copy(pooling, host_y)});
     func->execute().wait();

@@ -253,4 +253,11 @@ struct IsComplete<T, decltype(void(sizeof(T)))> : std::true_type {};
             __caller_OprRegShallowCopy##_cls##_ins; \
 }

+/*!
+ * \brief register opr serialization and shallow copy methods
+ */
+#define MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(_cls, _arity, _copy) \
+    MGB_SEREG_OPR(_cls, _arity)                                 \
+    MGB_REG_OPR_SHALLOW_COPY(_cls, ::mgb::serialization::_copy<_cls>)
+
 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
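
For reference, one registration line written with the new macro, using a hypothetical operator `Foo` with arity 2 that reuses the conv shallow-copy helper:

    MGB_SEREG_OPR_AND_REG_SHALLOW_COPY(Foo, 2, opr_shallow_copy_conv);

expands to the two existing registrations:

    MGB_SEREG_OPR(Foo, 2)
    MGB_REG_OPR_SHALLOW_COPY(Foo, ::mgb::serialization::opr_shallow_copy_conv<Foo>)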