GitOrigin-RevId: 6f359b5b29
tags/v1.10.0
| @@ -1939,6 +1939,11 @@ class LayerNormBase : public OperatorBase { | |||
| DEF_OPR_IMPL_CTOR(LayerNormBase, OperatorBase); | |||
| DEF_OPR_PARAM(LayerNorm); | |||
| public: | |||
| MGE_WIN_DECLSPEC_FUC static void deduce_layout_fwd_impl( | |||
| const TensorLayout& data, const Param& p, TensorLayout& dst, | |||
| TensorLayout& mean, TensorLayout& rstd); | |||
| protected: | |||
| void deduce_layout_fwd( | |||
| const TensorLayout& data, const TensorLayout& weight, | |||
| @@ -4,12 +4,11 @@ | |||
| namespace megdnn { | |||
| void LayerNormBase::deduce_layout_fwd( | |||
| const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||
| TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { | |||
| MEGDNN_MARK_USED_VAR(weight); | |||
| MEGDNN_MARK_USED_VAR(bias); | |||
| auto p = param(); | |||
| using Param = LayerNormBase::Param; | |||
| void LayerNormBase::deduce_layout_fwd_impl( | |||
| const TensorLayout& data, const Param& p, TensorLayout& dst, TensorLayout& mean, | |||
| TensorLayout& rstd) { | |||
| TensorShape unnormalized_shape; | |||
| unnormalized_shape.ndim = data.ndim - p.normalized_dim; | |||
| for (size_t i = 0; i < unnormalized_shape.ndim; ++i) { | |||
| @@ -22,6 +21,14 @@ void LayerNormBase::deduce_layout_fwd( | |||
| rstd = unnormalized_layout; | |||
| } | |||
| void LayerNormBase::deduce_layout_fwd( | |||
| const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||
| TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { | |||
| MEGDNN_MARK_USED_VAR(weight); | |||
| MEGDNN_MARK_USED_VAR(bias); | |||
| deduce_layout_fwd_impl(data, param(), dst, mean, rstd); | |||
| } | |||
| void LayerNormBase::check_layout_fwd( | |||
| const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, | |||
| const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd) { | |||
| @@ -63,6 +63,7 @@ __all__ = [ | |||
| "hsigmoid", | |||
| "hswish", | |||
| "indexing_one_hot", | |||
| "layer_norm", | |||
| "leaky_relu", | |||
| "linear", | |||
| "local_conv2d", | |||
| @@ -1135,9 +1136,6 @@ def layer_norm( | |||
| bias: must not be None when the affine is true | |||
| eps: a value added to the denominator for numerical stability. Default: 1e-5 | |||
| """ | |||
| if amp._enabled: | |||
| inp, weight, bias = cast_tensors(inp, weight, bias, promote=True) | |||
| if isinstance(normalized_shape, int): | |||
| normalized_shape = [normalized_shape] | |||
| @@ -0,0 +1,115 @@ | |||
| #include "megbrain/opr/dnn/layer_norm.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
| #include "../blob_manager_impl.h" | |||
| #include "../dnn_op_helper.h" | |||
| #include "../op_trait.h" | |||
| namespace mgb::imperative { | |||
| namespace layer_norm { | |||
| cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const LayerNorm&>(def); | |||
| size_t nr_inp = inputs.size(); | |||
| auto p = op.param(); | |||
| mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine)); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| if (nr_inp == 3) { | |||
| return opr::LayerNorm::make( | |||
| inputs[0], inputs[1], inputs[2], op.param(), config)[0] | |||
| .node() | |||
| ->owner_opr(); | |||
| } else { | |||
| return opr::LayerNorm::make(inputs[0], op.param(), config)[0] | |||
| .node() | |||
| ->owner_opr(); | |||
| } | |||
| } | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||
| auto&& op_def = def.cast_final_safe<LayerNorm>(); | |||
| size_t nr_inp = inputs.size(); | |||
| auto p = op_def.param(); | |||
| mgb_assert( | |||
| (nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine), | |||
| "num of inputs of pooling should be 1 or 3 but you give %zu", | |||
| inputs.size()); | |||
| auto&& inp = inputs[0]; | |||
| auto& inp_cn = inp.comp_node; | |||
| if (inp.layout.ndim == 0) { | |||
| return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}}, | |||
| {TensorLayout{dtype::Float32()}, inp_cn, {}}, | |||
| {TensorLayout{dtype::Float32()}, inp_cn, {}}}, | |||
| false}; | |||
| } | |||
| TensorLayout oup_layout, mean_layout, rstd_layout; | |||
| megdnn::LayerNorm::deduce_layout_fwd_impl( | |||
| inp.layout, p, oup_layout, mean_layout, rstd_layout); | |||
| return {{{oup_layout, inp_cn, {}}, | |||
| {mean_layout, inp_cn, {}}, | |||
| {rstd_layout, inp_cn, {}}}, | |||
| true}; | |||
| } | |||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) { | |||
| auto&& op_def = def.cast_final_safe<LayerNorm>(); | |||
| size_t nr_inp = inputs.size(); | |||
| auto p = op_def.param(); | |||
| mgb_assert( | |||
| (nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine), | |||
| "num of inputs of pooling should be 1 or 3 but you give %zu", | |||
| inputs.size()); | |||
| auto cn = inputs[0]->comp_node(); | |||
| DnnOprCaller<megdnn::LayerNorm> caller(cn); | |||
| auto&& dnn_opr = caller.op; | |||
| dnn_opr->param() = p; | |||
| TensorLayout oup_layout, mean_layout, rstd_layout; | |||
| megdnn::LayerNorm::deduce_layout_fwd_impl( | |||
| inputs[0]->dnn_tensor().layout, p, oup_layout, mean_layout, rstd_layout); | |||
| DeviceTensorND out_devtensor = | |||
| BlobManager::inst()->alloc_workspace_with_defrag(cn, oup_layout); | |||
| DeviceTensorND mean_devtensor = | |||
| BlobManager::inst()->alloc_workspace_with_defrag(cn, mean_layout); | |||
| DeviceTensorND rstd_devtensor = | |||
| BlobManager::inst()->alloc_workspace_with_defrag(cn, rstd_layout); | |||
| megdnn::Workspace dnn_wk; | |||
| auto wk_size = caller.op->get_workspace_in_bytes( | |||
| inputs[0]->dnn_tensor().layout, | |||
| p.affine ? inputs[1]->dnn_tensor().layout : TensorLayout(), | |||
| p.affine ? inputs[2]->dnn_tensor().layout : TensorLayout(), oup_layout, | |||
| mean_layout, rstd_layout); | |||
| if (wk_size != 0) { | |||
| TensorLayout w_layout({wk_size}, dtype::Byte()); | |||
| dnn_wk = caller.create_workspace(w_layout); | |||
| } | |||
| dnn_opr->exec( | |||
| inputs[0]->dnn_tensor(), | |||
| p.affine ? inputs[1]->dnn_tensor() : megdnn::TensorND(), | |||
| p.affine ? inputs[2]->dnn_tensor() : megdnn::TensorND(), | |||
| out_devtensor.as_megdnn(), mean_devtensor.as_megdnn(), | |||
| rstd_devtensor.as_megdnn(), dnn_wk); | |||
| return {Tensor::make(out_devtensor), Tensor::make(mean_devtensor), | |||
| Tensor::make(rstd_devtensor)}; | |||
| } | |||
| OP_TRAIT_REG(LayerNorm, LayerNorm) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||
| .fallback(); | |||
| } // namespace layer_norm | |||
| } // namespace mgb::imperative | |||
| @@ -8,7 +8,6 @@ | |||
| #include "megbrain/opr/dnn/correlation.h" | |||
| #include "megbrain/opr/dnn/fake_quant.h" | |||
| #include "megbrain/opr/dnn/images2neibs.h" | |||
| #include "megbrain/opr/dnn/layer_norm.h" | |||
| #include "megbrain/opr/dnn/local.h" | |||
| #include "megbrain/opr/dnn/lrn.h" | |||
| #include "megbrain/opr/dnn/lsq.h" | |||
| @@ -729,28 +728,4 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace lrn | |||
| namespace layer_norm { | |||
| cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const LayerNorm&>(def); | |||
| size_t nr_inp = inputs.size(); | |||
| auto p = op.param(); | |||
| mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine)); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| if (nr_inp == 3) { | |||
| return opr::LayerNorm::make( | |||
| inputs[0], inputs[1], inputs[2], op.param(), config)[0] | |||
| .node() | |||
| ->owner_opr(); | |||
| } else { | |||
| return opr::LayerNorm::make(inputs[0], op.param(), config)[0] | |||
| .node() | |||
| ->owner_opr(); | |||
| } | |||
| } | |||
| OP_TRAIT_REG(LayerNorm, LayerNorm).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace layer_norm | |||
| } // namespace mgb::imperative | |||
| @@ -289,6 +289,28 @@ ValueRefList batch_norm_rule(const OpDef& op, Span<ValueRef> inputs) { | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| ValueRefList layer_norm_rule(const OpDef& op, Span<ValueRef> inputs) { | |||
| // avoid the amp_dtype_autocast | |||
| if (DTypePromoteCfg::amp_dtype_autocast_enabled) { | |||
| SmallVector<DType> dtypes = get_value_dtypes(inputs); | |||
| ValueRefList converted(inputs.size()); | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| mgb::DType target_dtype = DTypePromoteCfg::amp_high_prec_dtype; | |||
| if (dtypes[i] != target_dtype) { | |||
| converted[i] = imperative::apply( | |||
| ApplyOp(*TypeCvt::make(target_dtype)), inputs[i])[0]; | |||
| } else { | |||
| converted[i] = inputs[i]; | |||
| } | |||
| } | |||
| return imperative::apply(op, converted); | |||
| } | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| ValueRefList naive_promote_rule(const OpDef& op, Span<ValueRef> inputs) { | |||
| SmallVector<DType> dtypes = get_value_dtypes(inputs); | |||
| mgb::DType target_dtype = get_promoted_dtype(dtypes); | |||
| @@ -319,6 +341,7 @@ struct DTypePromoteRuleRegistry { | |||
| register_dtype_promote_rule<BatchNorm>(batch_norm_rule); | |||
| register_dtype_promote_rule<Convolution3D>(naive_promote_rule); | |||
| register_dtype_promote_rule<Convolution3DBackwardData>(naive_promote_rule); | |||
| register_dtype_promote_rule<LayerNorm>(layer_norm_rule); | |||
| } | |||
| } register_helper; | |||