GitOrigin-RevId: 6f359b5b29
@@ -1939,6 +1939,11 @@ class LayerNormBase : public OperatorBase {
     DEF_OPR_IMPL_CTOR(LayerNormBase, OperatorBase);
     DEF_OPR_PARAM(LayerNorm);
 
+public:
+    MGE_WIN_DECLSPEC_FUC static void deduce_layout_fwd_impl(
+            const TensorLayout& data, const Param& p, TensorLayout& dst,
+            TensorLayout& mean, TensorLayout& rstd);
+
 protected:
     void deduce_layout_fwd(
             const TensorLayout& data, const TensorLayout& weight,
@@ -4,12 +4,11 @@
 namespace megdnn {
 
-void LayerNormBase::deduce_layout_fwd(
-        const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
-        TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
-    MEGDNN_MARK_USED_VAR(weight);
-    MEGDNN_MARK_USED_VAR(bias);
-    auto p = param();
+using Param = LayerNormBase::Param;
+
+void LayerNormBase::deduce_layout_fwd_impl(
+        const TensorLayout& data, const Param& p, TensorLayout& dst, TensorLayout& mean,
+        TensorLayout& rstd) {
     TensorShape unnormalized_shape;
     unnormalized_shape.ndim = data.ndim - p.normalized_dim;
     for (size_t i = 0; i < unnormalized_shape.ndim; ++i) {
@@ -22,6 +21,14 @@ void LayerNormBase::deduce_layout_fwd(
     rstd = unnormalized_layout;
 }
 
+void LayerNormBase::deduce_layout_fwd(
+        const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
+        TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
+    MEGDNN_MARK_USED_VAR(weight);
+    MEGDNN_MARK_USED_VAR(bias);
+    deduce_layout_fwd_impl(data, param(), dst, mean, rstd);
+}
+
 void LayerNormBase::check_layout_fwd(
         const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
         const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd) {
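
Note: the refactor above turns the layout deduction into a static helper (`deduce_layout_fwd_impl`) so the imperative runtime can deduce output layouts from the input layout and the param alone, without a constructed operator. The shape rule itself is simple; a minimal Python sketch (hypothetical helper, not part of this diff) for illustration:

```python
# Hypothetical sketch of the shape rule implemented by deduce_layout_fwd_impl.
def layer_norm_output_shapes(data_shape, normalized_dim):
    # dst keeps the full input shape; mean and rstd keep only the
    # leading (unnormalized) dimensions.
    unnormalized = tuple(data_shape[: len(data_shape) - normalized_dim])
    return tuple(data_shape), unnormalized, unnormalized

# e.g. data of shape (2, 4, 8) with normalized_dim=1
# -> dst (2, 4, 8), mean (2, 4), rstd (2, 4)
```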
@@ -63,6 +63,7 @@ __all__ = [
     "hsigmoid",
     "hswish",
     "indexing_one_hot",
+    "layer_norm",
     "leaky_relu",
     "linear",
     "local_conv2d",
@@ -1135,9 +1136,6 @@ def layer_norm(
         bias: must not be None when the affine is true
         eps: a value added to the denominator for numerical stability. Default: 1e-5
     """
-    if amp._enabled:
-        inp, weight, bias = cast_tensors(inp, weight, bias, promote=True)
-
     if isinstance(normalized_shape, int):
         normalized_shape = [normalized_shape]
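
Note: with `layer_norm` now exported via `__all__` and the Python-side amp cast removed (casting is handled by the dtype-promotion rule added further down), here is a minimal usage sketch, assuming the functional signature documented above; the tensor values are illustrative only:

```python
import numpy as np
import megengine
import megengine.functional as F

x = megengine.tensor(np.random.randn(2, 4, 8).astype("float32"))
w = megengine.tensor(np.ones((8,), dtype="float32"))   # weight over the normalized axes
b = megengine.tensor(np.zeros((8,), dtype="float32"))  # bias over the normalized axes

# normalized_shape may be an int; it is wrapped into a list internally
y = F.nn.layer_norm(x, normalized_shape=8, affine=True, weight=w, bias=b, eps=1e-5)
```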
@@ -0,0 +1,115 @@
+#include "megbrain/opr/dnn/layer_norm.h"
+#include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+
+#include "../blob_manager_impl.h"
+#include "../dnn_op_helper.h"
+#include "../op_trait.h"
+
+namespace mgb::imperative {
+namespace layer_norm {
+
+cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& op = static_cast<const LayerNorm&>(def);
+    size_t nr_inp = inputs.size();
+    auto p = op.param();
+    mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
+    OperatorNodeConfig config{op.make_name()};
+    if (nr_inp == 3) {
+        return opr::LayerNorm::make(
+                       inputs[0], inputs[1], inputs[2], op.param(), config)[0]
+                .node()
+                ->owner_opr();
+    } else {
+        return opr::LayerNorm::make(inputs[0], op.param(), config)[0]
+                .node()
+                ->owner_opr();
+    }
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& op_def = def.cast_final_safe<LayerNorm>();
+    size_t nr_inp = inputs.size();
+    auto p = op_def.param();
+    mgb_assert(
+            (nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine),
+            "num of inputs of layer_norm should be 1 or 3 but got %zu",
+            inputs.size());
+    auto&& inp = inputs[0];
+    auto& inp_cn = inp.comp_node;
+
+    if (inp.layout.ndim == 0) {
+        return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}},
+                 {TensorLayout{dtype::Float32()}, inp_cn, {}},
+                 {TensorLayout{dtype::Float32()}, inp_cn, {}}},
+                false};
+    }
+
+    TensorLayout oup_layout, mean_layout, rstd_layout;
+    megdnn::LayerNorm::deduce_layout_fwd_impl(
+            inp.layout, p, oup_layout, mean_layout, rstd_layout);
+    return {{{oup_layout, inp_cn, {}},
+             {mean_layout, inp_cn, {}},
+             {rstd_layout, inp_cn, {}}},
+            true};
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    auto&& op_def = def.cast_final_safe<LayerNorm>();
+    size_t nr_inp = inputs.size();
+    auto p = op_def.param();
+    mgb_assert(
+            (nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine),
+            "num of inputs of layer_norm should be 1 or 3 but got %zu",
+            inputs.size());
+
+    auto cn = inputs[0]->comp_node();
+    DnnOprCaller<megdnn::LayerNorm> caller(cn);
+    auto&& dnn_opr = caller.op;
+    dnn_opr->param() = p;
+
+    TensorLayout oup_layout, mean_layout, rstd_layout;
+    megdnn::LayerNorm::deduce_layout_fwd_impl(
+            inputs[0]->dnn_tensor().layout, p, oup_layout, mean_layout, rstd_layout);
+    DeviceTensorND out_devtensor =
+            BlobManager::inst()->alloc_workspace_with_defrag(cn, oup_layout);
+    DeviceTensorND mean_devtensor =
+            BlobManager::inst()->alloc_workspace_with_defrag(cn, mean_layout);
+    DeviceTensorND rstd_devtensor =
+            BlobManager::inst()->alloc_workspace_with_defrag(cn, rstd_layout);
+
+    megdnn::Workspace dnn_wk;
+    auto wk_size = caller.op->get_workspace_in_bytes(
+            inputs[0]->dnn_tensor().layout,
+            p.affine ? inputs[1]->dnn_tensor().layout : TensorLayout(),
+            p.affine ? inputs[2]->dnn_tensor().layout : TensorLayout(), oup_layout,
+            mean_layout, rstd_layout);
+    if (wk_size != 0) {
+        TensorLayout w_layout({wk_size}, dtype::Byte());
+        dnn_wk = caller.create_workspace(w_layout);
+    }
+
+    dnn_opr->exec(
+            inputs[0]->dnn_tensor(),
+            p.affine ? inputs[1]->dnn_tensor() : megdnn::TensorND(),
+            p.affine ? inputs[2]->dnn_tensor() : megdnn::TensorND(),
+            out_devtensor.as_megdnn(), mean_devtensor.as_megdnn(),
+            rstd_devtensor.as_megdnn(), dnn_wk);
+
+    return {Tensor::make(out_devtensor), Tensor::make(mean_devtensor),
+            Tensor::make(rstd_devtensor)};
+}
+
+OP_TRAIT_REG(LayerNorm, LayerNorm)
+        .apply_on_var_node(apply_on_var_node)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
+        .fallback();
+
+}  // namespace layer_norm
+}  // namespace mgb::imperative
@@ -8,7 +8,6 @@
 #include "megbrain/opr/dnn/correlation.h"
 #include "megbrain/opr/dnn/fake_quant.h"
 #include "megbrain/opr/dnn/images2neibs.h"
-#include "megbrain/opr/dnn/layer_norm.h"
 #include "megbrain/opr/dnn/local.h"
 #include "megbrain/opr/dnn/lrn.h"
 #include "megbrain/opr/dnn/lsq.h"
@@ -729,28 +728,4 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
 OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback();
 }  // namespace lrn
 
-namespace layer_norm {
-
-cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
-    auto&& op = static_cast<const LayerNorm&>(def);
-    size_t nr_inp = inputs.size();
-    auto p = op.param();
-    mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
-    OperatorNodeConfig config{op.make_name()};
-    if (nr_inp == 3) {
-        return opr::LayerNorm::make(
-                       inputs[0], inputs[1], inputs[2], op.param(), config)[0]
-                .node()
-                ->owner_opr();
-    } else {
-        return opr::LayerNorm::make(inputs[0], op.param(), config)[0]
-                .node()
-                ->owner_opr();
-    }
-}
-
-OP_TRAIT_REG(LayerNorm, LayerNorm).apply_on_var_node(apply_on_var_node).fallback();
-
-}  // namespace layer_norm
-
 }  // namespace mgb::imperative
@@ -289,6 +289,28 @@ ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
     return imperative::apply(op, inputs);
 }
 
+ValueRefList layer_norm_rule(const OpDef& op, Span<ValueRef> inputs) {
+    // under amp autocast, promote every input to the amp high-precision dtype
+    // instead of applying the default autocast behaviour
+    if (DTypePromoteCfg::amp_dtype_autocast_enabled) {
+        SmallVector<DType> dtypes = get_value_dtypes(inputs);
+        ValueRefList converted(inputs.size());
+
+        for (size_t i = 0; i < inputs.size(); ++i) {
+            mgb::DType target_dtype = DTypePromoteCfg::amp_high_prec_dtype;
+            if (dtypes[i] != target_dtype) {
+                converted[i] = imperative::apply(
+                        ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0];
+            } else {
+                converted[i] = inputs[i];
+            }
+        }
+
+        return imperative::apply(op, converted);
+    }
+
+    return imperative::apply(op, inputs);
+}
+
 ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) {
     SmallVector<DType> dtypes = get_value_dtypes(inputs);
     mgb::DType target_dtype = get_promoted_dtype(dtypes);
@@ -319,6 +341,7 @@ struct DTypePromoteRuleRegistry {
         register_dtype_promote_rule<BatchNorm>(batch_norm_rule);
         register_dtype_promote_rule<Convolution3D>(naive_promote_rule);
         register_dtype_promote_rule<Convolution3DBackwardData>(naive_promote_rule);
+        register_dtype_promote_rule<LayerNorm>(layer_norm_rule);
     }
 } register_helper;
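
Note: with the rule registered above, the cast to the amp high-precision dtype now happens on the dispatch side rather than in the Python `layer_norm` wrapper. A rough behavioural sketch, assuming `megengine.amp.autocast` is used as a context manager and that the amp high-precision dtype defaults to float32:

```python
import numpy as np
import megengine
import megengine.functional as F
from megengine import amp

x = megengine.tensor(np.random.randn(2, 8).astype("float16"))
w = megengine.tensor(np.ones((8,), dtype="float16"))
b = megengine.tensor(np.zeros((8,), dtype="float16"))

with amp.autocast(enabled=True):
    # the LayerNorm promotion rule casts the float16 inputs to the amp
    # high-precision dtype before the kernel executes
    y = F.nn.layer_norm(x, normalized_shape=8, affine=True, weight=w, bias=b)
```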