GitOrigin-RevId: 8c2b6a2aed
tags/v1.8.0
| @@ -158,70 +158,71 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| MGB_DEFINE_OPR_CLASS( | MGB_DEFINE_OPR_CLASS( | ||||
| ForceInplaceElemwise, | ForceInplaceElemwise, | ||||
| cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) //{ | |||||
| cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) // { | |||||
| public: | public: | ||||
| struct Param { | |||||
| using Mode = megdnn::Elemwise::Param::Mode; | |||||
| Mode mode; | |||||
| size_t inplace_index; | |||||
| }; | |||||
| using Mode = Param::Mode; | |||||
| ForceInplaceElemwise( | |||||
| const VarNodeArray& inputs, Param param, OperatorNodeConfig config = {}) | |||||
| : Super(inputs[0]->owner_graph(), config, "device_add_update", inputs), | |||||
| m_param{param} { | |||||
| for (auto* input : inputs) { | |||||
| add_input({input}); | |||||
| struct Param { | |||||
| using Mode = megdnn::Elemwise::Param::Mode; | |||||
| Mode mode; | |||||
| size_t inplace_index; | |||||
| }; | |||||
| using Mode = Param::Mode; | |||||
| ForceInplaceElemwise( | |||||
| const VarNodeArray& inputs, Param param, OperatorNodeConfig config = {}) | |||||
| : Super(inputs[0]->owner_graph(), config, "device_add_update", inputs), | |||||
| m_param{param} { | |||||
| for (auto* input : inputs) { | |||||
| add_input({input}); | |||||
| } | |||||
| add_output(None) | |||||
| ->set_fwd_in2out_writable_force(input(param.inplace_index)) | |||||
| .add_flag(VarNode::Flag::NO_MEM_RECLAIM); | |||||
| } | } | ||||
| add_output(None) | |||||
| ->set_fwd_in2out_writable_force(input(param.inplace_index)) | |||||
| .add_flag(VarNode::Flag::NO_MEM_RECLAIM); | |||||
| } | |||||
| static SymbolVar make(const VarNodeArray& inputs, Param param) { | |||||
| return SymbolVar{inputs[0]}.insert_single_output_opr<ForceInplaceElemwise>( | |||||
| inputs, param); | |||||
| } | |||||
| static cg::OperatorNodeBase* shallow_copy( | |||||
| const serialization::OprShallowCopyContext& ctx, | |||||
| const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | |||||
| const OperatorNodeConfig& config); | |||||
| static SymbolVar make(const VarNodeArray& inputs, Param param) { | |||||
| return SymbolVar{inputs[0]}.insert_single_output_opr<ForceInplaceElemwise>( | |||||
| inputs, param); | |||||
| } | |||||
| static cg::OperatorNodeBase* shallow_copy( | |||||
| const serialization::OprShallowCopyContext& ctx, | |||||
| const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs, | |||||
| const OperatorNodeConfig& config); | |||||
| protected: | protected: | ||||
| NodeProp* do_make_node_prop() const override { | |||||
| auto ret = Super::do_make_node_prop(); | |||||
| ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR); | |||||
| return ret; | |||||
| } | |||||
| void create_megdnn_opr() override { | |||||
| auto opr = DnnOprCaller<megdnn::Elemwise>::create_operator(comp_node()); | |||||
| opr->param().mode = m_param.mode; | |||||
| set_megdnn_opr(std::move(opr)); | |||||
| } | |||||
| void scn_do_execute() override { | |||||
| auto to_dnnnd = [&](auto* var) { return var->dev_tensor().as_megdnn(); }; | |||||
| megdnn::TensorNDArray inputs_dnnnd; | |||||
| for (auto* input : input()) { | |||||
| inputs_dnnnd.push_back(to_dnnnd(input)); | |||||
| NodeProp* do_make_node_prop() const override { | |||||
| auto ret = Super::do_make_node_prop(); | |||||
| ret->add_flag(NodeProp::Flag::FORCE_UPDATE_INPUT_VAR); | |||||
| return ret; | |||||
| } | } | ||||
| mgb_assert( | |||||
| input(m_param.inplace_index)->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC), | |||||
| "ForceInplaceElemwise cannot be applied in internal tensor"); | |||||
| auto* out_dest = output(0); | |||||
| auto* opr = static_cast<megdnn::Elemwise*>(megdnn_opr()); | |||||
| opr->exec(std::move(inputs_dnnnd), to_dnnnd(out_dest)); | |||||
| } | |||||
| void init_output_static_infer_desc() override { | |||||
| using namespace cg::static_infer; | |||||
| void create_megdnn_opr() override { | |||||
| auto opr = DnnOprCaller<megdnn::Elemwise>::create_operator(comp_node()); | |||||
| opr->param().mode = m_param.mode; | |||||
| set_megdnn_opr(std::move(opr)); | |||||
| } | |||||
| void scn_do_execute() override { | |||||
| auto to_dnnnd = [&](auto* var) { return var->dev_tensor().as_megdnn(); }; | |||||
| megdnn::TensorNDArray inputs_dnnnd; | |||||
| for (auto* input : input()) { | |||||
| inputs_dnnnd.push_back(to_dnnnd(input)); | |||||
| } | |||||
| mgb_assert( | |||||
| input(m_param.inplace_index) | |||||
| ->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC), | |||||
| "ForceInplaceElemwise cannot be applied in internal tensor"); | |||||
| auto* out_dest = output(0); | |||||
| auto* opr = static_cast<megdnn::Elemwise*>(megdnn_opr()); | |||||
| opr->exec(std::move(inputs_dnnnd), to_dnnnd(out_dest)); | |||||
| } | |||||
| void init_output_static_infer_desc() override { | |||||
| using namespace cg::static_infer; | |||||
| owner_graph()->static_infer_manager().register_shape_infer( | |||||
| output(0), ShapeInferDesc::make_identity(input(m_param.inplace_index))); | |||||
| } | |||||
| owner_graph()->static_infer_manager().register_shape_infer( | |||||
| output(0), ShapeInferDesc::make_identity(input(m_param.inplace_index))); | |||||
| } | |||||
| private: | private: | ||||
| Param m_param; | |||||
| void record_execute_deps(ExecDependencyArray& deps) override { | |||||
| record_megdnn_opr(deps); | |||||
| } | |||||
| Param m_param; | |||||
| void record_execute_deps(ExecDependencyArray& deps) override { | |||||
| record_megdnn_opr(deps); | |||||
| } | |||||
| }; | }; | ||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(ForceInplaceElemwise); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(ForceInplaceElemwise); | ||||
| @@ -1013,13 +1013,13 @@ using OprNodeArray = SmallVector<OperatorNodeBase*>; | |||||
| * | * | ||||
| * Note that opening brace is included | * Note that opening brace is included | ||||
| */ | */ | ||||
| #define MGB_DEFINE_OPR_CLASS(_name, _base, ...) \ | |||||
| MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \ | |||||
| MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||||
| #define MGB_DEFINE_OPR_CLASS(_name, _base, ...) \ | |||||
| MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \ | |||||
| MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||||
| #define MGB_DEFINE_OPR_CLASS_WITH_EXPORT(_name, _base, ...) \ | |||||
| MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \ | |||||
| MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT; | |||||
| #define MGB_DEFINE_OPR_CLASS_WITH_EXPORT(_name, _base, ...) \ | |||||
| MGB_DEFINE_CLS_WITH_SUPER(_name final, _base, ##__VA_ARGS__) \ | |||||
| MGB_DYN_TYPE_OBJ_FINAL_DECL_WITH_EXPORT; | |||||
| } // namespace cg | } // namespace cg | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -495,18 +495,18 @@ private: | |||||
| } // namespace mgb | } // namespace mgb | ||||
| #define _MGB_DEFINE_CLS_WITH_SUPER_IMPL(_tpl, _name, _base, ...) \ | |||||
| class _name : public _base, ##__VA_ARGS__ { \ | |||||
| public: \ | |||||
| using Super = _tpl _base; \ | |||||
| \ | |||||
| #define MGB_DEFINE_CLS_WITH_SUPER_IMPL(_tpl, _name, _base, ...) \ | |||||
| class _name : public _base, ##__VA_ARGS__ { \ | |||||
| public: \ | |||||
| using Super = _tpl _base; \ | |||||
| \ | |||||
| private: | private: | ||||
| /*! | /*! | ||||
| * \brief define a class which has Super defined to base | * \brief define a class which has Super defined to base | ||||
| */ | */ | ||||
| #define MGB_DEFINE_CLS_WITH_SUPER(_name, _base, ...) \ | #define MGB_DEFINE_CLS_WITH_SUPER(_name, _base, ...) \ | ||||
| _MGB_DEFINE_CLS_WITH_SUPER_IMPL(, _name, _base, ##__VA_ARGS__) | |||||
| MGB_DEFINE_CLS_WITH_SUPER_IMPL(, _name, _base, ##__VA_ARGS__) | |||||
| /*! | /*! | ||||
| * \brief define a class which has Super defined to base | * \brief define a class which has Super defined to base | ||||
| @@ -514,5 +514,5 @@ private: | |||||
| * Used when this class is a template and base class has template | * Used when this class is a template and base class has template | ||||
| */ | */ | ||||
| #define MGB_DEFINE_CLS_WITH_SUPER_TPL(_name, _base, ...) \ | #define MGB_DEFINE_CLS_WITH_SUPER_TPL(_name, _base, ...) \ | ||||
| _MGB_DEFINE_CLS_WITH_SUPER_IMPL(typename, _name, _base, ##__VA_ARGS__) | |||||
| MGB_DEFINE_CLS_WITH_SUPER_IMPL(typename, _name, _base, ##__VA_ARGS__) | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -99,7 +99,7 @@ float GraphPartitionProfiler::duration_in_usec() const { | |||||
| * \brief An operator that indicates its input var node is contiguous | * \brief An operator that indicates its input var node is contiguous | ||||
| */ | */ | ||||
| // clang-format off | // clang-format off | ||||
| MGB_DEFINE_OPR_CLASS(MarkInputContiguous, SingleCNOperatorNodeBase) //{ | |||||
| MGB_DEFINE_OPR_CLASS(MarkInputContiguous, SingleCNOperatorNodeBase) // { | |||||
| void scn_do_execute() override {}; | void scn_do_execute() override {}; | ||||
| void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
| void add_input_layout_constraint() override { | void add_input_layout_constraint() override { | ||||
| @@ -20,38 +20,38 @@ namespace opr { | |||||
| MGB_DEFINE_OPR_CLASS( | MGB_DEFINE_OPR_CLASS( | ||||
| PoolingForward, intl::MegDNNOprWrapperFwd<megdnn::PoolingForward>, | PoolingForward, intl::MegDNNOprWrapperFwd<megdnn::PoolingForward>, | ||||
| public mixin::AlgoChooserHelper) //{ | |||||
| public mixin::AlgoChooserHelper) // { | |||||
| public: | public: | ||||
| MGE_WIN_DECLSPEC_FUC PoolingForward( | |||||
| VarNode* src, const Param& param, const ExecutionPolicy& policy, | |||||
| const OperatorNodeConfig& config); | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||||
| SymbolVar src, const Param& param, const ExecutionPolicy& policy = {}, | |||||
| const OperatorNodeConfig& config = {}); | |||||
| void init_output_static_infer_desc() override; | |||||
| size_t get_workspace_size_bytes( | |||||
| const TensorShapeArray& input_shapes, | |||||
| const TensorShapeArray& output_shapes) const override; | |||||
| MGE_WIN_DECLSPEC_FUC PoolingForward( | |||||
| VarNode* src, const Param& param, const ExecutionPolicy& policy, | |||||
| const OperatorNodeConfig& config); | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||||
| SymbolVar src, const Param& param, const ExecutionPolicy& policy = {}, | |||||
| const OperatorNodeConfig& config = {}); | |||||
| void init_output_static_infer_desc() override; | |||||
| size_t get_workspace_size_bytes( | |||||
| const TensorShapeArray& input_shapes, | |||||
| const TensorShapeArray& output_shapes) const override; | |||||
| }; | }; | ||||
| using Pooling = PoolingForward; | using Pooling = PoolingForward; | ||||
| MGB_DEFINE_OPR_CLASS( | MGB_DEFINE_OPR_CLASS( | ||||
| PoolingBackward, intl::MegDNNOprWrapperBwd<megdnn::PoolingBackward>, | PoolingBackward, intl::MegDNNOprWrapperBwd<megdnn::PoolingBackward>, | ||||
| public mixin::AlgoChooserHelper) //{ | |||||
| public mixin::AlgoChooserHelper) // { | |||||
| public: | public: | ||||
| MGE_WIN_DECLSPEC_FUC PoolingBackward( | |||||
| VarNode* src, VarNode* dst, VarNode* diff, const Param& param, | |||||
| const ExecutionPolicy& policy, const OperatorNodeConfig& config); | |||||
| MGE_WIN_DECLSPEC_FUC PoolingBackward( | |||||
| VarNode* src, VarNode* dst, VarNode* diff, const Param& param, | |||||
| const ExecutionPolicy& policy, const OperatorNodeConfig& config); | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||||
| SymbolVar src, SymbolVar dst, SymbolVar diff, const Param& param, | |||||
| const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {}); | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||||
| SymbolVar src, SymbolVar dst, SymbolVar diff, const Param& param, | |||||
| const ExecutionPolicy& policy = {}, const OperatorNodeConfig& config = {}); | |||||
| MGE_WIN_DECLSPEC_FUC size_t get_workspace_size_bytes( | |||||
| const TensorShapeArray& input_shapes, | |||||
| const TensorShapeArray& output_shapes) const override final; | |||||
| MGE_WIN_DECLSPEC_FUC size_t get_workspace_size_bytes( | |||||
| const TensorShapeArray& input_shapes, | |||||
| const TensorShapeArray& output_shapes) const override final; | |||||
| }; | }; | ||||
| } // namespace opr | } // namespace opr | ||||
| @@ -86,7 +86,7 @@ MGE_WIN_DECLSPEC_FUC void add_input_layout_constraint_contig(OperatorNodeBase& o | |||||
| //! called in constructor to add output vars | //! called in constructor to add output vars | ||||
| MGE_WIN_DECLSPEC_FUC void add_output_vars( | MGE_WIN_DECLSPEC_FUC void add_output_vars( | ||||
| OperatorNodeBase& opr, size_t nr_output, bool add_workspace); | OperatorNodeBase& opr, size_t nr_output, bool add_workspace); | ||||
| } | |||||
| } // namespace megdnn_utils | |||||
| /*! | /*! | ||||
| * \brief mixin for infer workspace size based on input and output shapes | * \brief mixin for infer workspace size based on input and output shapes | ||||
| @@ -344,34 +344,34 @@ private: | |||||
| } // namespace mgb | } // namespace mgb | ||||
| //! define a megdnn opr wrapper class with 1 input for forward | //! define a megdnn opr wrapper class with 1 input for forward | ||||
| #define MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(_name) \ | |||||
| MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>) \ | |||||
| public: \ | |||||
| _name(VarNode* p0, const Param& param, const OperatorNodeConfig& config); \ | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVar make( \ | |||||
| SymbolVar p0, const Param& param = {}, \ | |||||
| const OperatorNodeConfig& config = {}); \ | |||||
| #define MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(_name) \ | |||||
| MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>) \ | |||||
| public: \ | |||||
| _name(VarNode* p0, const Param& param, const OperatorNodeConfig& config); \ | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVar make( \ | |||||
| SymbolVar p0, const Param& param = {}, \ | |||||
| const OperatorNodeConfig& config = {}); \ | |||||
| } | } | ||||
| //! define a megdnn opr wrapper class with 2 inputs for forward | //! define a megdnn opr wrapper class with 2 inputs for forward | ||||
| #define MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD2(_name) \ | |||||
| MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>) \ | |||||
| public: \ | |||||
| _name(VarNode* p0, VarNode* p1, const Param& param, \ | |||||
| const OperatorNodeConfig& config); \ | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVar make( \ | |||||
| SymbolVar p0, SymbolVar p1, const Param& param = {}, \ | |||||
| const OperatorNodeConfig& config = {}); \ | |||||
| #define MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD2(_name) \ | |||||
| MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperFwd<megdnn::_name>) \ | |||||
| public: \ | |||||
| _name(VarNode* p0, VarNode* p1, const Param& param, \ | |||||
| const OperatorNodeConfig& config); \ | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVar make( \ | |||||
| SymbolVar p0, SymbolVar p1, const Param& param = {}, \ | |||||
| const OperatorNodeConfig& config = {}); \ | |||||
| } | } | ||||
| //! define a megdnn opr wrapper class with 3 inputs for grad | //! define a megdnn opr wrapper class with 3 inputs for grad | ||||
| #define MGB_DEFINE_MEGDNN_OPR_WRAPPER_BWD3(_name, _extra...) \ | #define MGB_DEFINE_MEGDNN_OPR_WRAPPER_BWD3(_name, _extra...) \ | ||||
| MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperBwd<megdnn::_name>) \ | MGB_DEFINE_OPR_CLASS(_name, intl::MegDNNOprWrapperBwd<megdnn::_name>) \ | ||||
| _extra public : _name(VarNode* p0, VarNode* p1, VarNode* p2, const Param& param, \ | |||||
| const OperatorNodeConfig& config); \ | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVar make( \ | |||||
| SymbolVar p0, SymbolVar p1, SymbolVar p2, const Param& param = {}, \ | |||||
| const OperatorNodeConfig& config = {}); \ | |||||
| _extra public : _name(VarNode* p0, VarNode* p1, VarNode* p2, \ | |||||
| const Param& param, const OperatorNodeConfig& config); \ | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVar make( \ | |||||
| SymbolVar p0, SymbolVar p1, SymbolVar p2, const Param& param = {}, \ | |||||
| const OperatorNodeConfig& config = {}); \ | |||||
| } | } | ||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -40,25 +40,25 @@ protected: | |||||
| }; | }; | ||||
| /* ================= RNG with shape ================= */ | /* ================= RNG with shape ================= */ | ||||
| #define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \ | |||||
| MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \ | |||||
| cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \ | |||||
| \ | |||||
| public: \ | |||||
| RNG(VarNode* shape, const Param& param, const OperatorNodeConfig& config); \ | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVar make( \ | |||||
| SymbolVar shape, const Param& param = {}, \ | |||||
| const OperatorNodeConfig& config = {}); \ | |||||
| static SymbolVar make( \ | |||||
| ComputingGraph& graph, const TensorShape& shape, \ | |||||
| const OperatorNodeConfig& config, const Param& param = {}) { \ | |||||
| return make( \ | |||||
| var_from_tensor_shape(graph, config, "rng", shape), param, config); \ | |||||
| } \ | |||||
| void init_output_static_infer_desc() override; \ | |||||
| void scn_do_execute() override; \ | |||||
| } \ | |||||
| ; | |||||
| #define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \ | |||||
| MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \ | |||||
| cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \ | |||||
| \ | |||||
| public: \ | |||||
| RNG(VarNode* shape, const Param& param, const OperatorNodeConfig& config); \ | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVar make( \ | |||||
| SymbolVar shape, const Param& param = {}, \ | |||||
| const OperatorNodeConfig& config = {}); \ | |||||
| static SymbolVar make( \ | |||||
| ComputingGraph& graph, const TensorShape& shape, \ | |||||
| const OperatorNodeConfig& config, const Param& param = {}) { \ | |||||
| return make( \ | |||||
| var_from_tensor_shape(graph, config, "rng", shape), param, \ | |||||
| config); \ | |||||
| } \ | |||||
| void init_output_static_infer_desc() override; \ | |||||
| void scn_do_execute() override; \ | |||||
| }; | |||||
| _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(UniformRNG) | _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(UniformRNG) | ||||
| _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(GaussianRNG) | _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(GaussianRNG) | ||||
| @@ -66,20 +66,19 @@ _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(PermutationRNG) | |||||
| #undef _DEFINE_RNG_OPR_WITH_SHAPE_CLASS | #undef _DEFINE_RNG_OPR_WITH_SHAPE_CLASS | ||||
| /* ================= RNG with input ================= */ | /* ================= RNG with input ================= */ | ||||
| #define _DEFINE_RNG_OPR_WITH_INPUT_CLASS(RNG) \ | |||||
| MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \ | |||||
| void add_input_layout_constraint() override; \ | |||||
| cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \ | |||||
| \ | |||||
| public: \ | |||||
| RNG(_INPUTS(VarNode*), const Param& param, const OperatorNodeConfig& config); \ | |||||
| MGE_WIN_DECLSPEC_FUC static _OUTPUTS make( \ | |||||
| _INPUTS(SymbolVar), const Param& param = {}, \ | |||||
| const OperatorNodeConfig& config = {}); \ | |||||
| void init_output_static_infer_desc() override; \ | |||||
| void scn_do_execute() override; \ | |||||
| } \ | |||||
| ; | |||||
| #define _DEFINE_RNG_OPR_WITH_INPUT_CLASS(RNG) \ | |||||
| MGB_DEFINE_OPR_CLASS_WITH_EXPORT(RNG, RNGOprBase<megdnn::RNG>) \ | |||||
| void add_input_layout_constraint() override; \ | |||||
| cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \ | |||||
| \ | |||||
| public: \ | |||||
| RNG(_INPUTS(VarNode*), const Param& param, const OperatorNodeConfig& config); \ | |||||
| MGE_WIN_DECLSPEC_FUC static _OUTPUTS make( \ | |||||
| _INPUTS(SymbolVar), const Param& param = {}, \ | |||||
| const OperatorNodeConfig& config = {}); \ | |||||
| void init_output_static_infer_desc() override; \ | |||||
| void scn_do_execute() override; \ | |||||
| }; | |||||
| /* ================= 1 input ================= */ | /* ================= 1 input ================= */ | ||||
| #define _INPUTS(preifx) preifx i0 | #define _INPUTS(preifx) preifx i0 | ||||
| @@ -100,7 +99,7 @@ _DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG) | |||||
| #undef _INPUTS | #undef _INPUTS | ||||
| #undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS | #undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS | ||||
| } // intl | |||||
| } // namespace intl | |||||
| using UniformRNG = intl::UniformRNG; | using UniformRNG = intl::UniformRNG; | ||||
| using GaussianRNG = intl::GaussianRNG; | using GaussianRNG = intl::GaussianRNG; | ||||
| @@ -111,16 +110,15 @@ using BetaRNG = intl::BetaRNG; | |||||
| using ShuffleRNG = intl::ShuffleRNGForward; | using ShuffleRNG = intl::ShuffleRNGForward; | ||||
| MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | ||||
| ShuffleRNGBackward, | |||||
| intl::MegDNNOprWrapperBwd<megdnn::ShuffleRNGBackward>) //{ | |||||
| ShuffleRNGBackward, intl::MegDNNOprWrapperBwd<megdnn::ShuffleRNGBackward>) // { | |||||
| public: | public: | ||||
| ShuffleRNGBackward( | |||||
| VarNode* out_diff, VarNode* indices, VarNode* result_shape, const Param& param, | |||||
| const OperatorNodeConfig& config); | |||||
| ShuffleRNGBackward( | |||||
| VarNode* out_diff, VarNode* indices, VarNode* result_shape, | |||||
| const Param& param, const OperatorNodeConfig& config); | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||||
| SymbolVar out_diff, SymbolVar indices, SymbolVar result_shape, | |||||
| const Param& param = {}, const OperatorNodeConfig& config = {}); | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||||
| SymbolVar out_diff, SymbolVar indices, SymbolVar result_shape, | |||||
| const Param& param = {}, const OperatorNodeConfig& config = {}); | |||||
| }; | }; | ||||
| } // namespace opr | } // namespace opr | ||||
| @@ -19,7 +19,8 @@ failed_files = Manager().list() | |||||
| def process_file(file, clang_format, write): | def process_file(file, clang_format, write): | ||||
| source = open(file, "r").read() | source = open(file, "r").read() | ||||
| source = re.sub(r"MGB_DEFINE(?P<r>(.|\n)*?)// +{", "class MGB_DEFINE\g<r>{", source) | |||||
| source = re.sub(r"MGB_DEFINE(?P<r>([^\\]|\n)*?)// *{", r"class MGB_DEFINE\g<r>{", source) | |||||
| source, count = re.subn(r"(?<!#define )MGB_DEFINE(.*) +\\", r"class MGB_DEFINE\1{\\", source) | |||||
| result = subprocess.check_output( | result = subprocess.check_output( | ||||
| [ | [ | ||||
| @@ -33,6 +34,8 @@ def process_file(file, clang_format, write): | |||||
| ) | ) | ||||
| result = result.decode("utf-8") | result = result.decode("utf-8") | ||||
| if count: | |||||
| result = re.sub(r"class MGB_DEFINE(.*){( *)\\", r"MGB_DEFINE\1\2 \\", result) | |||||
| result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result) | result = re.sub(r"class MGB_DEFINE((.|\n)*?){", r"MGB_DEFINE\1// {", result) | ||||
| if write: | if write: | ||||