GitOrigin-RevId: a1b1e89b76
tags/v1.6.0
@@ -120,10 +120,6 @@ Dimension Dimension::operator/(const Dimension& rhs) const {
             static_cast<char>(m_name), static_cast<char>(rhs.m_name));
     if (operator==(rhs))
         return Dimension(m_name, 1, 1);
-    megdnn_assert(
-            !(*this < rhs),
-            "Divisor must be smaller than dividend(dividend:%s, divisor:%s)",
-            to_string().c_str(), rhs.to_string().c_str());
     if (m_stride == rhs.m_stride) {
         if (m_extent == UNDETERMINED_EXTENT) {
             megdnn_assert(rhs.m_extent != UNDETERMINED_EXTENT,
@@ -0,0 +1,431 @@
/**
 * \file src/gopt/impl/folding_conv_dimshuffle.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
| #include "megbrain/gopt/inference.h" | |||||
| #include "megbrain/opr/basic_arith.h" | |||||
| #include "megbrain/opr/dnn/convolution.h" | |||||
| #include "megbrain/opr/tensor_manip.h" | |||||
| #include "megbrain/opr/utility.h" | |||||
| #include "megbrain/serialization/opr_shallow_copy.h" | |||||
| #include "megdnn/opr_param_defs.h" | |||||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||||
| #include "megbrain/utils/hash_ct.h" | |||||
| #include "midout.h" | |||||
| #include "megbrain/gopt/reformat_manager.h" | |||||
| #if CUDA_VERSION >= 10020 | |||||
| MIDOUT_DECL(megbrain_folding_conv_dimshuffle) | |||||
| #define MIDOUT_B(tag) \ | |||||
| MIDOUT_BEGIN(megbrain_folding_conv_dimshuffle, \ | |||||
| midout_iv(MGB_HASH_STR(tag))) { | |||||
| #define MIDOUT_E \ | |||||
| } \ | |||||
| MIDOUT_END(); | |||||
| using namespace mgb; | |||||
| using namespace gopt; | |||||
| using ReformatKey = ReformatManager::ReformatKey; | |||||
| /* ==================== FoldingConvBiasDimshufflePass ================= */ | |||||
| const char* FoldingConvBiasDimshufflePass::name() const { | |||||
| return mgb_cstr_log("folding conv bias dimshuffle pass"); | |||||
| } | |||||
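// The apply() below looks for layout-conversion chains (typecvt / dimshuffle /
// reshape) hanging off a quantized ConvBias and folds them into the ConvBias
// by switching its param().format to a fused variant (NCHW4_NCHW,
// NCHW4_NCHW32, NCHW4_NHWC or NCHW32_NCHW4), relayouting the bias to match,
// so the layout conversion happens inside the conv kernel itself.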
void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
    MIDOUT_B("FoldingConvBiasDimshufflePass::apply");
    using DepType = cg::OperatorNodeProp::DepType;
    ThinHashMap<OperatorNodeBase*,
                SmallVector<std::pair<OperatorNodeBase*, DepType>>>
            readers;
    static const ThinHashSet<Typeinfo*> opr_type_list = {
            opr::TypeCvt::typeinfo(), opr::Dimshuffle::typeinfo(),
            opr::Reshape::typeinfo(), opr::ConvBias::typeinfo()};
    opt.graph().iter([&readers](OperatorNodeBase* opr) {
        for (auto&& i : opr->node_prop().dep_map()) {
            if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) {
                readers[i.first->owner_opr()].emplace_back(opr, i.second);
            }
        }
    });
    auto rewriter = opt.graph().make_rewriter();
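    // Matches typecvt(qs8 -> f32) <- reshape <- dimshuffle(0,1,4,2,3) <-
    // conv_bias in NCHW4 and replaces the chain with one ConvBias in
    // NCHW4_NCHW format whose bias has been relayouted to NCHW and converted
    // to float32.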
    auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers](
                                                       OperatorNodeBase* opr) {
        ThinHashSet<OperatorNodeBase*> opr_set;
        ThinHashSet<OperatorNodeBase*> reader_set;
        // check typecvt
        auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
        if (typecvt == nullptr)
            return false;
        auto inp_dtype = typecvt->input(0)->dtype(),
             out_dtype = typecvt->output(0)->dtype();
        bool is_s82f32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                         out_dtype.enumv() == DTypeEnum::Float32;
        if (!is_s82f32)
            return false;
        opr_set.insert(opr);
        // check reshape
        auto reshape =
                try_cast_as_op<opr::Reshape>(typecvt->input(0)->owner_opr());
        if (reshape == nullptr)
            return false;
        opr_set.insert(reshape);
        for (auto&& i : readers[reshape]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check shuffle
        auto shuffle =
                try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr());
        if (shuffle == nullptr)
            return false;
        auto&& param = shuffle->param();
        if (param.pattern_len != 5)
            return false;
        bool is_nchw42nchw = param.pattern[0] == 0 && param.pattern[1] == 1 &&
                             param.pattern[2] == 4 && param.pattern[3] == 2 &&
                             param.pattern[4] == 3 &&
                             shuffle->input(0)->shape()[4] == 4;
        if (!is_nchw42nchw)
            return false;
        opr_set.insert(shuffle);
        for (auto&& i : readers[shuffle]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check conv bias
        auto conv_bias =
                try_cast_as_op<opr::ConvBias>(shuffle->input(0)->owner_opr());
        if (conv_bias == nullptr)
            return false;
        inp_dtype = conv_bias->input(0)->dtype();
        bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                          conv_bias->param().format ==
                                  megdnn::param::ConvBias::Format::NCHW4;
        if (!is_s8nchw4)
            return false;
        if (conv_bias->input().size() != 3)
            return false;
        opr_set.insert(conv_bias);
        for (auto&& i : readers[conv_bias]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        for (auto reader : reader_set) {
            if (opr_set.count(reader) <= 0) {
                return false;
            }
        }
        auto src = rewriter.get_var(conv_bias->input(0)),
             filter = rewriter.get_var(conv_bias->input(1)),
             bias = rewriter.get_var(conv_bias->input(2));
        auto new_bias = ReformatManager::instance().get(ReformatKey{
                TensorFormats::NCHWc4, TensorFormats::NCHW})({bias});
        new_bias = opr::TypeCvt::make(new_bias, dtype::Float32()).node();
        auto new_param = conv_bias->param();
        new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW;
        auto conv_bias_shuffle = opr::ConvBias::make(
                src, filter, new_bias, new_param, conv_bias->execution_policy(),
                OperatorNodeConfig{dtype::Float32()});
        rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(),
                             mgb_cstr_log("replace conv_bias + typecvt + "
                                          "dimshuffle + "
                                          "reshape to conv_bias(NCHW4_NCHW)"));
        return true;
    };
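    // Matches reshape <- dimshuffle(0,1,3,4,2,5) <- reshape <- conv_bias in
    // NCHW4 (the two innermost shuffled axes have extents 8 and 4, i.e. an
    // NCHW4 -> NCHW32 relayout) and folds it into a ConvBias with format
    // NCHW4_NCHW32, relayouting the bias from NCHWc4 to NCHWc32.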
    auto try_conv_reformat_nchw42nchw32 = [&rewriter,
                                           &readers](OperatorNodeBase* opr) {
        ThinHashSet<OperatorNodeBase*> opr_set;
        ThinHashSet<OperatorNodeBase*> reader_set;
        // check reshape
        auto reshape1 = try_cast_as_op<opr::Reshape>(opr);
        if (reshape1 == nullptr)
            return false;
        opr_set.insert(opr);
        // check dimshuffle
        auto shuffle = try_cast_as_op<opr::Dimshuffle>(
                reshape1->input(0)->owner_opr());
        if (shuffle == nullptr)
            return false;
        auto&& param = shuffle->param();
        if (param.pattern_len != 6)
            return false;
        bool is_nchw42nchw32 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
                               param.pattern[2] == 3 && param.pattern[3] == 4 &&
                               param.pattern[4] == 2 && param.pattern[5] == 5 &&
                               shuffle->output(0)->shape()[5] == 4 &&
                               shuffle->output(0)->shape()[4] == 8;
        if (!is_nchw42nchw32)
            return false;
        opr_set.insert(shuffle);
        for (auto&& i : readers[shuffle]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check reshape
        auto reshape2 =
                try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
        if (reshape2 == nullptr)
            return false;
        opr_set.insert(reshape2);
        for (auto&& i : readers[reshape2]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check conv bias
        auto conv_bias =
                try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr());
        if (conv_bias == nullptr)
            return false;
        auto inp_dtype = conv_bias->input(0)->dtype();
        bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                          conv_bias->param().format ==
                                  megdnn::param::ConvBias::Format::NCHW4;
        if (!is_s8nchw4)
            return false;
        if (conv_bias->input().size() != 3)
            return false;
        opr_set.insert(conv_bias);
        for (auto&& i : readers[conv_bias]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        for (auto reader : reader_set) {
            if (opr_set.count(reader) <= 0) {
                return false;
            }
        }
        auto src = rewriter.get_var(conv_bias->input(0)),
             filter = rewriter.get_var(conv_bias->input(1)),
             bias = rewriter.get_var(conv_bias->input(2));
        auto new_bias = ReformatManager::instance().get(ReformatKey{
                TensorFormats::NCHWc4, TensorFormats::NCHWc32})({bias});
        auto new_param = conv_bias->param();
        new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW32;
        auto conv_bias_shuffle = opr::ConvBias::make(
                src, filter, new_bias, new_param, conv_bias->execution_policy(),
                conv_bias->config());
        rewriter.replace_var(
                opr->output(0), conv_bias_shuffle.node(),
                mgb_cstr_log("replace conv_bias + "
                             "reformat to conv_bias(NCHW4_NCHW32)"));
        return true;
    };
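    // Matches reshape <- dimshuffle(0,2,3,1,4) <- typecvt(qs8 -> q4/quint4) <-
    // conv_bias in NCHW4 (an NCHW4 -> NHWC relayout of the requantized output)
    // and folds it into a ConvBias with format NCHW4_NHWC that directly emits
    // the 4-bit output dtype.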
    auto try_conv_reformat_nchw42nhwc = [&rewriter,
                                         &readers](OperatorNodeBase* opr) {
        ThinHashSet<OperatorNodeBase*> opr_set;
        ThinHashSet<OperatorNodeBase*> reader_set;
        // check reshape
        auto reshape = try_cast_as_op<opr::Reshape>(opr);
        if (reshape == nullptr)
            return false;
        opr_set.insert(opr);
        // check dimshuffle
        auto shuffle =
                try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr());
        if (shuffle == nullptr)
            return false;
        auto&& param = shuffle->param();
        if (param.pattern_len != 5)
            return false;
        bool is_nchw42nhwc = param.pattern[0] == 0 && param.pattern[1] == 2 &&
                             param.pattern[2] == 3 && param.pattern[3] == 1 &&
                             param.pattern[4] == 4 &&
                             shuffle->output(0)->shape()[4] == 4;
        if (!is_nchw42nhwc)
            return false;
        opr_set.insert(shuffle);
        for (auto&& i : readers[shuffle]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        auto typecvt =
                try_cast_as_op<opr::TypeCvt>(shuffle->input(0)->owner_opr());
        if (typecvt == nullptr)
            return false;
        auto in_dtype = typecvt->input(0)->dtype(),
             out_dtype = typecvt->output(0)->dtype();
        bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                        (out_dtype.enumv() == DTypeEnum::QuantizedS4 ||
                         out_dtype.enumv() == DTypeEnum::Quantized4Asymm);
        if (!is_s82s4)
            return false;
        opr_set.insert(typecvt);
        for (auto&& i : readers[typecvt]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check conv bias
        auto conv_bias =
                try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr());
        if (conv_bias == nullptr)
            return false;
        auto inp_dtype = conv_bias->input(0)->dtype();
        bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                          conv_bias->param().format ==
                                  megdnn::param::ConvBias::Format::NCHW4;
        if (!is_s8nchw4)
            return false;
        if (conv_bias->input().size() != 3)
            return false;
        opr_set.insert(conv_bias);
        for (auto&& i : readers[conv_bias]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        for (auto reader : reader_set) {
            if (opr_set.count(reader) <= 0) {
                return false;
            }
        }
        auto src = rewriter.get_var(conv_bias->input(0)),
             filter = rewriter.get_var(conv_bias->input(1)),
             bias = rewriter.get_var(conv_bias->input(2));
        auto new_bias = ReformatManager::instance().get(ReformatKey{
                TensorFormats::NCHWc4, TensorFormats::NHWC})({bias});
        auto new_param = conv_bias->param();
        new_param.format = megdnn::param::ConvBias::Format::NCHW4_NHWC;
        auto conv_bias_shuffle = opr::ConvBias::make(
                src, filter, new_bias, new_param, conv_bias->execution_policy(),
                OperatorNodeConfig{out_dtype});
        rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(),
                             mgb_cstr_log("replace conv_bias + "
                                          "reformat to conv_bias(NCHW4_NHWC)"));
        return true;
    };
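    // The reverse direction: matches reshape <- dimshuffle(0,1,4,2,3,5) <-
    // reshape <- conv_bias in NCHW32 (an NCHW32 -> NCHW4 relayout) and folds
    // it into a ConvBias with format NCHW32_NCHW4.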
    auto try_conv_reformat_nchw322nchw4 = [&rewriter,
                                           &readers](OperatorNodeBase* opr) {
        ThinHashSet<OperatorNodeBase*> opr_set;
        ThinHashSet<OperatorNodeBase*> reader_set;
        // check reshape
        auto reshape1 = try_cast_as_op<opr::Reshape>(opr);
        if (reshape1 == nullptr)
            return false;
        opr_set.insert(opr);
        // check dimshuffle
        auto shuffle = try_cast_as_op<opr::Dimshuffle>(
                reshape1->input(0)->owner_opr());
        if (shuffle == nullptr)
            return false;
        auto&& param = shuffle->param();
        if (param.pattern_len != 6)
            return false;
        bool is_nchw322nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
                               param.pattern[2] == 4 && param.pattern[3] == 2 &&
                               param.pattern[4] == 3 && param.pattern[5] == 5 &&
                               shuffle->input(0)->shape()[5] == 4 &&
                               shuffle->input(0)->shape()[4] == 8;
        if (!is_nchw322nchw4)
            return false;
        opr_set.insert(shuffle);
        for (auto&& i : readers[shuffle]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check reshape
        auto reshape2 =
                try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
        if (reshape2 == nullptr)
            return false;
        opr_set.insert(reshape2);
        for (auto&& i : readers[reshape2]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check conv bias
        auto conv_bias =
                try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr());
        if (conv_bias == nullptr)
            return false;
        auto inp_dtype = conv_bias->input(0)->dtype();
        bool is_s8nchw32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                           conv_bias->param().format ==
                                   megdnn::param::ConvBias::Format::NCHW32;
        if (!is_s8nchw32)
            return false;
        if (conv_bias->input().size() != 3)
            return false;
        opr_set.insert(conv_bias);
        for (auto&& i : readers[conv_bias]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        for (auto reader : reader_set) {
            if (opr_set.count(reader) <= 0) {
                return false;
            }
        }
        auto src = rewriter.get_var(conv_bias->input(0)),
             filter = rewriter.get_var(conv_bias->input(1)),
             bias = rewriter.get_var(conv_bias->input(2));
        auto new_bias = ReformatManager::instance().get(ReformatKey{
                TensorFormats::NCHWc32, TensorFormats::NCHWc4})({bias});
        auto new_param = conv_bias->param();
        new_param.format = megdnn::param::ConvBias::Format::NCHW32_NCHW4;
        auto conv_bias_shuffle = opr::ConvBias::make(
                src, filter, new_bias, new_param, conv_bias->execution_policy(),
                conv_bias->config());
        rewriter.replace_var(
                opr->output(0), conv_bias_shuffle.node(),
                mgb_cstr_log("replace conv_bias + "
                             "reformat to conv_bias(NCHW32_NCHW4)"));
        return true;
    };
    MGB_MARK_USED_VAR(try_conv_reformat_nchw322nchw4);
    MGB_MARK_USED_VAR(try_conv_reformat_nchw42nchw32);
    auto on_opr = [&try_conv_dimshuffle_reshape_typecvt,
                   &try_conv_reformat_nchw42nchw32,
                   &try_conv_reformat_nchw42nhwc,
                   &try_conv_reformat_nchw322nchw4,
                   &rewriter](OperatorNodeBase* opr) {
        if (!try_conv_dimshuffle_reshape_typecvt(opr) &&
            !try_conv_reformat_nchw42nchw32(opr) &&
            !try_conv_reformat_nchw42nhwc(opr) &&
            !try_conv_reformat_nchw322nchw4(opr)) {
            rewriter.auto_replace_outputs(opr);
        }
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
    MIDOUT_E
}
#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -0,0 +1,451 @@
/**
 * \file src/gopt/impl/padding_channel.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
| #include "megbrain/gopt/inference.h" | |||||
| #include "megbrain/opr/basic_arith.h" | |||||
| #include "megbrain/opr/dnn/convolution.h" | |||||
| #include "megbrain/opr/dnn/pooling.h" | |||||
| #include "megbrain/opr/imgproc.h" | |||||
| #include "megbrain/opr/misc.h" | |||||
| #include "megbrain/opr/nn_int.h" | |||||
| #include "megbrain/opr/tensor_manip.h" | |||||
| #include "megbrain/opr/utility.h" | |||||
| #include "megbrain/serialization/opr_shallow_copy.h" | |||||
| #include "megdnn/opr_param_defs.h" | |||||
| #include "megdnn/tensor_format.h" | |||||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||||
| #include "megbrain/gopt/misc.h" | |||||
| #include "megbrain/utils/hash_ct.h" | |||||
| #include "midout.h" | |||||
| #include "megbrain/gopt/reformat_manager.h" | |||||
| MIDOUT_DECL(megbrain_padding_channel) | |||||
| #define MIDOUT_B(tag) \ | |||||
| MIDOUT_BEGIN(megbrain_padding_channel, midout_iv(MGB_HASH_STR(tag))) { | |||||
| #define MIDOUT_E \ | |||||
| } \ | |||||
| MIDOUT_END(); | |||||
| using namespace mgb; | |||||
| using namespace gopt; | |||||
| using ReformatKey = ReformatManager::ReformatKey; | |||||
| /* ==================== PaddingChannelPass ================= */ | |||||
| const char* PaddingChannelPass::name() const { | |||||
| return mgb_cstr_log("padding output channel to multiple of 4/32"); | |||||
| } | |||||
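// The apply() below zero-pads the channel dimensions of quantized convolutions
// (and propagates the padding through the operators that consume them) so that
// channel counts reach the alignment the fast CUDA kernels expect; readers
// that cannot consume padded tensors, and graph endpoints, get a Subtensor
// that crops back to the original channel count.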
void PaddingChannelPass::apply(OptState& opt) const {
    MIDOUT_B("PaddingChannelPass::apply");
    // do not check shape
    opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^
                                   VarReplaceCheckFlag::CHECK_SHAPE);
    ThinHashSet<OperatorNodeBase*> padding_oprs;
    ThinHashMap<Typeinfo*, thin_function<OperatorNodeBase*(
                                   OperatorNodeBase*, const VarNodeArray&)>>
            opr_replace_funcs;
    auto rewriter = opt.graph().make_rewriter();
    auto pad_in_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* {
        mgb_assert(inp->shape().ndim == 4);
        mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 ||
                   inp->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
                   inp->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                   inp->dtype().enumv() == DTypeEnum::QuantizedS32);
        TensorShape shape{inp->shape()[0], pad_channels, inp->shape()[2],
                          inp->shape()[3]};
        std::shared_ptr<HostTensorND> host_val =
                std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
        host_val->resize(shape);
        auto ptr = host_val->raw_ptr();
        size_t size_bytes =
                TensorLayout{shape, inp->dtype()}.span().dist_byte();
        std::memset(ptr, 0, size_bytes);
        auto padding =
                opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
        auto out = opr::Concat::make({inp, padding}, 1);
        return out.node();
    };
    auto pad_out_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* {
        mgb_assert(inp->shape().ndim == 4);
        mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 ||
                   inp->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
                   inp->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                   inp->dtype().enumv() == DTypeEnum::QuantizedS32);
        TensorShape shape{pad_channels, inp->shape()[1], inp->shape()[2],
                          inp->shape()[3]};
        std::shared_ptr<HostTensorND> host_val =
                std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
        host_val->resize(shape);
        auto ptr = host_val->raw_ptr();
        size_t size_bytes =
                TensorLayout{shape, inp->dtype()}.span().dist_byte();
        std::memset(ptr, 0, size_bytes);
        auto padding =
                opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
        auto out = opr::Concat::make({inp, padding}, 0);
        return out.node();
    };
    auto extract_subtensor = [](VarNode* inp,
                                const TensorShape& orig_shape) -> VarNode* {
        mgb_assert(inp->shape().ndim == 4);
        mgb_assert(inp->shape()[0] == orig_shape[0]);
        mgb_assert(inp->shape()[2] == orig_shape[2]);
        mgb_assert(inp->shape()[3] == orig_shape[3]);
        size_t orig_channels = orig_shape[1];
        auto x = SymbolVar(inp);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        using AIdx = opr::Subtensor::AxisIndexer;
        auto sub = opr::Subtensor::make(
                x, {AIdx::make_interval(0, None, None, cv(1)),
                    AIdx::make_interval(1, None, cv(orig_channels), None),
                    AIdx::make_interval(2, None, None, cv(1)),
                    AIdx::make_interval(3, None, None, cv(1))});
        return sub.node();
    };
    // padding policy for conv bias with data type qint8
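    // Alignment rule used below: channel counts up to 16 are padded to a
    // multiple of 4 (dp4a kernels), larger ones to a multiple of 32 (tensor
    // core kernels). For example, a conv with 24 output channels gets 8 zero
    // output channels appended so it runs with 32.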
    auto padding_policy_qint8 = [&padding_oprs, &pad_in_channels,
                                 &pad_out_channels](
                                        OperatorNodeBase* opr,
                                        const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        mgb_assert(new_inp.size() == 3);
        mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()));
        auto inps = new_inp;
        size_t out_channels = opr->input(1)->shape()[0];
        size_t in_channels = opr->input(1)->shape()[1];
        size_t new_in_channels = new_inp[0]->shape()[1];
        // pad input channels
        if (padding_oprs.count(opr->input(0)->owner_opr())) {
            size_t pad_channels = new_in_channels - in_channels;
            inps[1] = pad_in_channels(new_inp[1], pad_channels);
        } else {
            size_t pad_channels = 0;
            mgb_assert(new_in_channels == in_channels);
            if (in_channels <= 16) {
                if (in_channels % 4)
                    pad_channels = 4 - (in_channels % 4);  // pad to use dp4a
            } else {
                if (in_channels % 32)
                    pad_channels =
                            32 - (in_channels % 32);  // pad to use tensorcore
            }
            if (pad_channels > 0) {
                inps[0] = pad_in_channels(new_inp[0], pad_channels);
                inps[1] = pad_in_channels(new_inp[1], pad_channels);
            }
        }
        out_channels = inps[1]->shape()[0];
        in_channels = inps[1]->shape()[1];
        size_t pad_channels = 0;
        if (out_channels <= 16) {
            if (out_channels % 4)
                pad_channels = 4 - (out_channels % 4);
        } else {
            if (out_channels % 32)
                pad_channels = 32 - (out_channels % 32);
        }
        if (pad_channels > 0) {
            inps[1] = pad_out_channels(inps[1], pad_channels);
            inps[2] = pad_in_channels(inps[2], pad_channels);
            padding_oprs.insert(opr);
        }
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    };
    // padding policy for conv bias with data type qint4 and quint4
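    // Same idea for 4-bit data, but with stricter alignment: channel counts up
    // to 32 are padded to a multiple of 8, larger ones to a multiple of 64;
    // when the producer was already padded, the filter is padded to match the
    // producer's (already aligned) input-channel count.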
    auto padding_policy_int4 = [&padding_oprs, &pad_in_channels,
                                &pad_out_channels](
                                       OperatorNodeBase* opr,
                                       const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        mgb_assert(new_inp.size() == 3);
        mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()));
        auto inps = new_inp;
        size_t out_channels = opr->input(1)->shape()[0];
        size_t in_channels = opr->input(1)->shape()[1];
        size_t new_in_channels = new_inp[0]->shape()[1];
        // pad input channels
        if (padding_oprs.count(opr->input(0)->owner_opr())) {
            if (new_in_channels <= 32) {
                if (new_in_channels % 8 == 0) {
                    size_t pad_channels = new_in_channels - in_channels;
                    inps[1] = pad_in_channels(new_inp[1], pad_channels);
                } else {
                    size_t pad_channels_0 = 8 - (new_in_channels % 8);
                    size_t pad_channels_1 = 8 - (in_channels % 8);
                    inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
                    inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
                }
            } else {
                if (new_in_channels % 64 == 0) {
                    size_t pad_channels = new_in_channels - in_channels;
                    inps[1] = pad_in_channels(new_inp[1], pad_channels);
                } else {
                    size_t pad_channels_0 = 64 - (new_in_channels % 64);
                    size_t pad_channels_1 = 64 - (in_channels % 64);
                    inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
                    inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
                }
            }
        } else {
            size_t pad_channels = 0;
            mgb_assert(new_in_channels == in_channels);
            if (in_channels <= 32) {
                if (in_channels % 8)
                    pad_channels = 8 - (in_channels % 8);
            } else {
                if (in_channels % 64)
                    pad_channels = 64 - (in_channels % 64);
            }
            if (pad_channels > 0) {
                inps[0] = pad_in_channels(new_inp[0], pad_channels);
                inps[1] = pad_in_channels(new_inp[1], pad_channels);
            }
        }
        out_channels = inps[1]->shape()[0];
        in_channels = inps[1]->shape()[1];
        size_t pad_channels = 0;
        if (out_channels <= 32) {
            if (out_channels % 8)
                pad_channels = 8 - (out_channels % 8);
        } else {
            if (out_channels % 64)
                pad_channels = 64 - (out_channels % 64);
        }
        if (pad_channels > 0) {
            inps[1] = pad_out_channels(inps[1], pad_channels);
            inps[2] = pad_in_channels(inps[2], pad_channels);
            padding_oprs.insert(opr);
        }
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    };
    opr_replace_funcs[opr::ConvBiasForward::typeinfo()] =
            [&padding_oprs, &padding_policy_qint8, &padding_policy_int4](
                    OperatorNodeBase* opr, const VarNodeArray& new_inp) {
                if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
                    return padding_policy_qint8(opr, new_inp);
                } else if (opr->input(0)->dtype().enumv() ==
                                   DTypeEnum::QuantizedS4 ||
                           opr->input(0)->dtype().enumv() ==
                                   DTypeEnum::Quantized4Asymm) {
                    return padding_policy_int4(opr, new_inp);
                } else {
                    mgb_assert(
                            padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                            "conv bias operator for data type(%s) cannot be "
                            "padded channel. "
                            "consumer(%s), producer(%s)",
                            opr->input(0)->dtype().name(), opr->cname(),
                            opr->input(0)->owner_opr()->cname());
                    return serialization::copy_opr_shallow(*opr, new_inp,
                                                           opr->config());
                }
            };
    opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] =
            [&padding_oprs, &pad_in_channels, &pad_out_channels](
                    OperatorNodeBase* opr, const VarNodeArray& new_inp) {
                if (opr->input(1)->dtype().enumv() != DTypeEnum::QuantizedS8) {
                    mgb_assert(
                            padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                            "conv bwd data operator for data type(%s) cannot "
                            "be "
                            "padded channel. "
                            "consumer(%s), producer(%s)",
                            opr->input(0)->dtype().name(), opr->cname(),
                            opr->input(0)->owner_opr()->cname());
                    return serialization::copy_opr_shallow(*opr, new_inp,
                                                           opr->config());
                }
                mgb_assert(opr->input().size() == new_inp.size());
                mgb_assert(new_inp.size() == 2,
                           "deconv (conv bwd data) operator for inference can "
                           "only have 2 input vars(got:%zu)",
                           new_inp.size());
                mgb_assert(
                        opr->input(0)->shape().eq_shape(new_inp[0]->shape()));
                auto inps = new_inp;
                size_t out_channels = opr->input(0)->shape()[0];
                size_t in_channels = opr->input(0)->shape()[1];
                size_t new_out_channels = new_inp[1]->shape()[1];
                // pad output channels
                if (padding_oprs.count(opr->input(1)->owner_opr())) {
                    size_t pad_channels = new_out_channels - out_channels;
                    inps[0] = pad_out_channels(new_inp[0], pad_channels);
                } else {
                    size_t pad_channels = 0;
                    if (out_channels % 4)
                        pad_channels = 4 - (out_channels % 4);
                    if (pad_channels > 0) {
                        inps[0] = pad_out_channels(new_inp[0], pad_channels);
                        inps[1] = pad_in_channels(new_inp[1], pad_channels);
                    }
                }
                out_channels = inps[0]->shape()[0];
                in_channels = inps[0]->shape()[1];
                // pad input channels
                size_t pad_channels = 0;
                if (in_channels % 4)
                    pad_channels = 4 - (in_channels % 4);
                if (pad_channels > 0) {
                    inps[0] = pad_in_channels(inps[0], pad_channels);
                    padding_oprs.insert(opr);
                }
                return serialization::copy_opr_shallow(*opr, inps,
                                                       opr->config());
            };
    auto replace_format_aware_opr = [&padding_oprs](
                                            OperatorNodeBase* opr,
                                            const VarNodeArray& new_inp) {
        if (opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 &&
            opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS4 &&
            opr->input(0)->dtype().enumv() != DTypeEnum::Quantized4Asymm) {
            mgb_assert(padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                       "operator(type:%s,name:%s) for data type(%s) cannot be "
                       "padded channel. extra info:"
                       "consumer(%s), producer(%s)",
                       opr->dyn_typeinfo()->name, opr->cname(),
                       opr->input(0)->dtype().name(), opr->cname(),
                       opr->input(0)->owner_opr()->cname());
            return serialization::copy_opr_shallow(*opr, new_inp,
                                                   opr->config());
        }
        mgb_assert(opr->input().size() == new_inp.size());
        if (padding_oprs.count(opr->input(0)->owner_opr())) {
            padding_oprs.insert(opr);
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    opr_replace_funcs[opr::PoolingForward::typeinfo()] =
            replace_format_aware_opr;
    opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] =
            replace_format_aware_opr;
    auto replace_elemwise_like_opr = [&padding_oprs, &extract_subtensor](
                                             OperatorNodeBase* opr,
                                             const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        bool have_padding_inp = false;
        bool padding_all_inps = true;
        bool same_padding = true;
        size_t channels_after_padding = 0;
        size_t i = 0;
        for (auto&& cur_inp : opr->input()) {
            bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0;
            if (padding_cur_inp) {
                if (!have_padding_inp)
                    have_padding_inp = true;
                if (channels_after_padding == 0) {
                    channels_after_padding = new_inp[i]->shape()[1];
                } else {
                    same_padding =
                            channels_after_padding == new_inp[i]->shape()[1];
                }
            }
            if (padding_all_inps && (!padding_cur_inp || !same_padding))
                padding_all_inps = false;
            ++i;
        }
        if (have_padding_inp && !padding_all_inps) {
            auto inps = new_inp;
            for (size_t i = 0; i < new_inp.size(); ++i) {
                auto cur_inp = opr->input(i);
                bool padding_cur_inp =
                        padding_oprs.count(cur_inp->owner_opr()) > 0;
                if (padding_cur_inp) {
                    inps[i] = extract_subtensor(inps[i], cur_inp->shape());
                }
            }
            return serialization::copy_opr_shallow(*opr, inps, opr->config());
        }
        if (padding_all_inps) {
            padding_oprs.insert(opr);
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] =
            replace_elemwise_like_opr;
    opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
    opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr;
    auto replace_nonpadding_oprs = [&padding_oprs, &extract_subtensor](
                                           OperatorNodeBase* opr,
                                           const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto inps = new_inp;
        for (size_t i = 0; i < new_inp.size(); ++i) {
            auto cur_inp = opr->input(i);
            bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0;
            if (padding_cur_inp) {
                inps[i] = extract_subtensor(inps[i], cur_inp->shape());
            }
        }
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    };
    opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs;
    opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs;
    opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs;
    opr_replace_funcs[opr::Reduce::typeinfo()] = replace_nonpadding_oprs;
    opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_nonpadding_oprs;
    auto on_opr = [&opt, &rewriter, &opr_replace_funcs,
                   &extract_subtensor](OperatorNodeBase* opr) {
        auto it = opr_replace_funcs.find(opr->dyn_typeinfo());
        if (it != opr_replace_funcs.end()) {
            VarNodeArray new_inp;
            new_inp.reserve(opr->input().size());
            for (auto&& inp : opr->input()) {
                new_inp.push_back(rewriter.get_var(inp));
            }
            auto new_opr = (it->second)(opr, new_inp);
            auto &&out0 = opr->output(), &&out1 = new_opr->output();
            mgb_assert(out0.size() == out1.size(),
                       "bad opr replace: src=%s{%s} dst=%s{%s}, "
                       "src.size=%zu "
                       "dst.size=%zu",
                       opr->cname(), opr->dyn_typeinfo()->name,
                       new_opr->cname(), new_opr->dyn_typeinfo()->name,
                       out0.size(), out1.size());
            for (size_t i = 0; i < out0.size(); ++i) {
                if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                    mgb_assert(!out1[i]->contain_flag(
                            VarNode::Flag::VOLATILE_CONTENT));
                    auto src = out0[i];
                    auto dst = out1[i];
                    if (opt.graph().endpoint_contain(src) &&
                        !src->shape().eq_shape(dst->shape())) {
                        dst = extract_subtensor(dst, src->shape());
                    }
                    rewriter.replace_var(src, dst, nullptr);
                }
            }
        } else {
            rewriter.auto_replace_outputs(opr);
        }
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
    MIDOUT_E
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -11,7 +11,6 @@
  */
 #include "megbrain/gopt/reformat_manager.h"
-#include <numeric>
 #include "megbrain/opr/tensor_manip.h"
 using namespace mgb;
@@ -65,6 +64,10 @@ NamedTensorShape tensor_formats_to_named_tensor_shape(TensorFormats format) {
             return {{"C//8"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}, {"C%8"}};
         case TensorFormats::KRSCk8:
             return {{"K//8"}, {"R"}, {"S"}, {"C"}, {"K%8"}};
+        case TensorFormats::KCRSc4:
+            return {{"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}};
+        case TensorFormats::GKCRSc4:
+            return {{"G"}, {"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}};
         case TensorFormats::KCRS:
             return {{"K"}, {"C"}, {"R"}, {"S"}};
         case TensorFormats::GKCRS:
@@ -130,70 +133,40 @@ bool ReformatManager::ReformatKey::Equal::operator()(
            lhs.attribute == rhs.attribute;
 }
+ReformatManager::ReformatKey&
+ReformatManager::ReformatKey::deduce_reformat_dtype_enum(const DType& dt) {
+    static const ThinHashSet<std::pair<TensorFormats, TensorFormats>> set = {
+            {TensorFormats::NCHW, TensorFormats::NCHWc64},
+            {TensorFormats::NCHWc64, TensorFormats::NCHW},
+            {TensorFormats::NCHW, TensorFormats::NHWC},
+            {TensorFormats::NHWC, TensorFormats::NCHW}};
+    if (set.count({input_format, output_format}) > 0 &&
+        (dt.enumv() == DTypeEnum::QuantizedS4 ||
+         dt.enumv() == DTypeEnum::Quantized4Asymm)) {
+        input_dtype = output_dtype = dt.enumv();
+    }
+    return *this;
+}
 // =================== ReformatManager ====================*/
-#define FOREACH_FEATURE_TENSOR_FORMATS(cb) \
-    cb(NCHW) cb(NHWC) cb(NCHWc4) cb(NCHWc8) cb(NCHWc32) cb(NCHWc64) cb(CHWNc4) \
-            cb(NHCWc4)
-#define FOREACH_WEIGHT_TENSOR_FORMATS(cb) \
-    cb(KRSCk4) cb(KRSCk4c4) cb(KCRSk4c4) cb(KCRSc4k4) cb(KCRSc8k8) cb(KRSCk8) \
-            cb(GKRSCk4) cb(GKRSCk4c4) cb(GKCRSc4k4) cb(GKCRSk4c4) \
-            cb(GKCRSc8k8) cb(C11RSc4) cb(C11RSc8)
 ReformatManager::ReformatManager() {
-    static constexpr TensorFormats feature_tensor_formats[] = {
-#define cb(_fmt) TensorFormats::_fmt,
-            FOREACH_FEATURE_TENSOR_FORMATS(cb)
-#undef cb
-    };
-    static constexpr int nr_feature_tensor_formats =
-            sizeof(feature_tensor_formats) / sizeof(TensorFormats);
-    for (int i = 0; i < nr_feature_tensor_formats; ++i) {
-        for (int o = 0; o < nr_feature_tensor_formats; ++o) {
-            if (i == o)
-                continue;
-            NamedTensorShape input_shape = tensor_formats_to_named_tensor_shape(
-                    feature_tensor_formats[i]);
-            NamedTensorShape output_shape =
-                    tensor_formats_to_named_tensor_shape(
-                            feature_tensor_formats[o]);
-            auto impl = std::get<0>(
-                    ReformatEmitter{input_shape, output_shape}.emit());
-            m_cache.emplace(ReformatKey{feature_tensor_formats[i],
-                                        feature_tensor_formats[o]},
-                            impl);
-        }
-    }
-    static constexpr TensorFormats default_weight_tensor_formats =
-            TensorFormats::KCRS;
-    static constexpr TensorFormats default_group_conv_weight_tensor_formats =
-            TensorFormats::GKCRS;
-    static constexpr TensorFormats default_chan_conv_weight_tensor_formats =
-            TensorFormats::C11RS;
-    static constexpr TensorFormats weight_tensor_formats[] = {
-#define cb(_fmt) TensorFormats::_fmt,
-            FOREACH_WEIGHT_TENSOR_FORMATS(cb)
-#undef cb
-    };
-    static constexpr int nr_weight_tensor_formats =
-            sizeof(weight_tensor_formats) / sizeof(TensorFormats);
-    using Name = megdnn::Dimension::Name;
-    for (int o = 0; o < nr_weight_tensor_formats; ++o) {
-        NamedTensorShape output_shape =
-                tensor_formats_to_named_tensor_shape(weight_tensor_formats[o]);
-        TensorFormats input_format;
-        if (output_shape[0].name() == Name::G) {
-            input_format = default_group_conv_weight_tensor_formats;
-        } else if (output_shape[0].name() == Name::C) {
-            input_format = default_chan_conv_weight_tensor_formats;
-        } else {
-            mgb_assert(output_shape[0].name() == Name::K);
-            input_format = default_weight_tensor_formats;
-        }
-        NamedTensorShape input_shape =
-                tensor_formats_to_named_tensor_shape(input_format);
-        auto impl =
-                std::get<0>(ReformatEmitter{input_shape, output_shape}.emit());
-        m_cache.emplace(ReformatKey{input_format, weight_tensor_formats[o]},
-                        impl);
+    using Attribute = ReformatKey::Attribute;
+    {
+        auto i = TensorFormats::NCHWc4, o = TensorFormats::CHWNc4;
+        auto&& impl1 = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(
+                           vars[0],
+                           megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4)
+                    .node();
+        };
+        m_cache.emplace(ReformatKey{i, o}, impl1);
+        auto&& impl2 = [](const VarNodeArray& vars) {
+            return opr::RelayoutFormat::make(
+                           vars[0],
+                           megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4)
+                    .node();
+        };
+        m_cache.emplace(ReformatKey{o, i}, impl2);
     }
     {
         auto i = TensorFormats::NCHW, o = TensorFormats::NCHWc4;
@@ -206,7 +179,7 @@ ReformatManager::ReformatManager() {
         m_cache.emplace(ReformatKey{i, o, Attribute::IC_SMALL}, impl);
     }
     {
-        auto i = TensorFormats::KCRS, o = TensorFormats::KCRSc4k4;
+        auto i = TensorFormats::KCRS, o = TensorFormats::KCRSc4;
         auto&& impl = [](const VarNodeArray& vars) {
             return opr::RelayoutFormat::make(
                     vars[0],
@@ -238,7 +211,7 @@ ReformatManager::ReformatManager() {
         auto&& impl = [](const VarNodeArray& vars) {
             return opr::RelayoutFormat::make(
                     vars[0],
-                    megdnn::param::RelayoutFormat::Mode::NCHW_NCHW64)
+                    megdnn::param::RelayoutFormat::Mode::NCHW64_NCHW)
                     .node();
         };
         m_cache.emplace(
@@ -272,7 +245,7 @@ ReformatManager::ReformatManager() {
         auto&& impl = [](const VarNodeArray& vars) {
            return opr::RelayoutFormat::make(
                    vars[0],
-                    megdnn::param::RelayoutFormat::Mode::NCHW_NHWC)
+                    megdnn::param::RelayoutFormat::Mode::NHWC_NCHW)
                    .node();
         };
         m_cache.emplace(
@@ -371,14 +344,23 @@ ReformatManager::ReformatManager() {
                         impl);
     }
 }
-#undef FOREACH_FEATURE_TENSOR_FORMATS
-#undef FOREACH_WEIGHT_TENSOR_FORMATS
-const ReformatManager::ReformatImpl& ReformatManager::get(
+ReformatManager::ReformatImpl ReformatManager::get(
         const ReformatKey& key) const {
+    using Attribute = ReformatKey::Attribute;
     MGB_TRY {
-        auto&& impl = m_cache.at(key);
-        return impl;
+        auto find = m_cache.find(key);
+        if (find != m_cache.end()) {
+            auto rst = find->second;
+            return rst;
+        }
+        mgb_assert(key.attribute == Attribute::DEFAULT);
+        auto&& i = key.input_format;
+        auto&& o = key.output_format;
+        auto ishp = tensor_formats_to_named_tensor_shape(i);
+        auto oshp = tensor_formats_to_named_tensor_shape(o);
+        auto builder = std::get<0>(ReformatEmitter{ishp, oshp}.emit());
+        return builder;
     }
     MGB_CATCH(std::exception & exc, {
         mgb_log_error(
@@ -390,10 +372,7 @@ const ReformatManager::ReformatImpl& ReformatManager::get(
 }
 const ReformatManager& ReformatManager::instance() {
-    static ReformatManager* inst = nullptr;
-    if (inst == nullptr) {
-        inst = new ReformatManager();
-    }
-    return *inst;
+    static ReformatManager inst;
+    return inst;
 }
 // vim: syntax=cpp.doxygen
@@ -227,6 +227,7 @@ namespace gopt {
     VarReplaceCheckFlag m_var_replace_check_flag =
             VarReplaceCheckFlag::CHECK_ALL;
     class RelayoutPlaceholder;
+    friend class ShuffleShuffleRemovePass;
 public:
     TensorReformatPass& set_var_replace_check_flag(VarReplaceCheckFlag flag) {
@@ -49,10 +49,14 @@ enum class TensorFormats : uint32_t {
     KRSCk8 = 21,   ///< [K/8, R, S, C, K%8]
+    // NCHW4
+    KCRSc4 = 22,   ///< [K, C/4, R, S, C%4]
+    GKCRSc4 = 23,  ///< [G, K, C/4, R, S, C%4]
     // default weight format
-    KCRS = 22,   ///< [K, C, R, S]
-    GKCRS = 23,  ///< [G, K, C, R, S]
-    C11RS = 24,  ///< [C, 1, 1, R, S]
+    KCRS = 24,   ///< [K, C, R, S]
+    GKCRS = 25,  ///< [G, K, C, R, S]
+    C11RS = 26,  ///< [C, 1, 1, R, S]
 };
 class ReformatManager : public NonCopyableObj {
@@ -60,16 +64,20 @@ class ReformatManager : public NonCopyableObj {
 public:
     using ReformatImpl = thin_function<VarNode*(const VarNodeArray&)>;
-    enum class Attribute : uint32_t {
-        DEFAULT = 0,
-        IMAGE2D = 1 << 0,
-        IC_SMALL = 1 << 1,
-    };
     struct ReformatKey {
+        enum class Attribute : uint32_t {
+            DEFAULT = 0,
+            IMAGE2D = 1 << 0,
+            IC_SMALL = 1 << 1,
+        };
         TensorFormats input_format, output_format;
         DTypeEnum input_dtype, output_dtype;
         Attribute attribute;
         std::string to_string() const;
+        ReformatKey()
+                : input_dtype{DTypeEnum::Float32},
+                  output_dtype{DTypeEnum::Float32},
+                  attribute{Attribute::DEFAULT} {}
         ReformatKey(TensorFormats input_format_, TensorFormats output_format_,
                     Attribute attribute_ = Attribute::DEFAULT,
                     DTypeEnum input_dtype_ = DTypeEnum::Float32,
@@ -86,11 +94,13 @@ public:
             bool operator()(const ReformatKey& lhs,
                             const ReformatKey& rhs) const;
         };
+        ReformatKey& deduce_reformat_dtype_enum(const DType& dt);
     };
     using ReformatCache =
             std::unordered_map<ReformatKey, ReformatImpl, ReformatKey::Hash,
                                ReformatKey::Equal>;
-    const ReformatImpl& get(const ReformatKey& key) const;
+    ReformatImpl get(const ReformatKey& key) const;
+    ReformatImpl get(ReformatKey&& key) const { return get(key); }
     static const ReformatManager& instance();
 private:
@@ -0,0 +1,171 @@
/**
 * \file src/gopt/test/reformat_manager.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
| #include "./helper.h" | |||||
| #include "megbrain/gopt/reformat_manager.h" | |||||
| #include "megbrain/opr/tensor_manip.h" | |||||
| using namespace mgb; | |||||
| using namespace gopt; | |||||
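// The tests below cross-check ReformatManager::get() against hand-written
// reshape/dimshuffle reference graphs (Feature and Weight), verify that an
// unsupported key triggers an assertion (InvalidKey), and exercise the
// RelayoutFormat-based path for small input channels (InputChannelSmall).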
TEST(TestReformatManager, Feature) {
    constexpr size_t N = 16, C = 128, H = 7, W = 7;
    HostTensorGenerator<> gen;
    using ReformatKey = ReformatManager::ReformatKey;
    auto src_format = TensorFormats::NHWC, dst_format = TensorFormats::NCHWc64;
    ReformatKey key{src_format, dst_format};
    auto reformat = ReformatManager::instance().get(key);
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    auto r = [](VarNode* inp) {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                {sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
        return y1;
    };
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
    };
    auto x = mkvar("x", {N, H, W, C});
    auto y1 = SymbolVar(reformat({x.node()}));
    auto y2 = r(x.node());
    size_t nr_shapeof = 0;
    size_t nr_reshape = 0;
    cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) {
        if (o->same_type<opr::GetVarShape>())
            nr_shapeof++;
        if (o->same_type<opr::Reshape>())
            nr_reshape++;
    }}
            .add(y1.node()->owner_opr());
    ASSERT_EQ(nr_shapeof, 1);
    ASSERT_EQ(nr_reshape, 1);
    HostTensorND t1, t2;
    auto func1 = graph->compile({make_callback_copy(y1, t1)});
    func1->execute();
    auto func2 = graph->compile({make_callback_copy(y2, t2)});
    func2->execute();
    MGB_ASSERT_TENSOR_EQ(t1, t2);
}
TEST(TestReformatManager, Weight) {
    constexpr size_t G = 8, K = 128, C = 128, R = 3, S = 3;
    HostTensorGenerator<> gen;
    using ReformatKey = ReformatManager::ReformatKey;
    auto src_format = TensorFormats::GKCRS,
         dst_format = TensorFormats::GKCRSk4c4;
    ReformatKey key{src_format, dst_format};
    auto reformat = ReformatManager::instance().get(key);
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    auto r = [](VarNode* inp) {
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4,
                                        cv(4), sub(3), sub(4)},
                                       0),
             tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3),
                                        sub(4), cv(4), cv(4)},
                                       0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 2, 4});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2;
    };
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
    };
    auto w = mkvar("w", {G, K / G, C / G, R, S});
    auto y1 = SymbolVar(reformat({w.node()}));
    auto y2 = r(w.node());
    size_t nr_shapeof = 0;
    size_t nr_reshape = 0;
    cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) {
        if (o->same_type<opr::GetVarShape>())
            nr_shapeof++;
        if (o->same_type<opr::Reshape>())
            nr_reshape++;
    }}
            .add(y1.node()->owner_opr());
    ASSERT_EQ(nr_shapeof, 1);
    ASSERT_EQ(nr_reshape, 1);
    HostTensorND t1, t2;
    auto func1 = graph->compile({make_callback_copy(y1, t1)});
    func1->execute();
    auto func2 = graph->compile({make_callback_copy(y2, t2)});
    func2->execute();
    MGB_ASSERT_TENSOR_EQ(t1, t2);
}
TEST(TestReformatManager, InvalidKey) {
    using ReformatKey = ReformatManager::ReformatKey;
    using Attribute = ReformatKey::Attribute;
    auto src_format = TensorFormats::GKCRS,
         dst_format = TensorFormats::GKCRSk4c4;
    Attribute attribute = Attribute::IMAGE2D;
    ReformatKey key{src_format, dst_format, attribute};
    ASSERT_THROW(ReformatManager::instance().get(key), AssertionError);
}

TEST(TestReformatManager, InputChannelSmall) {
    constexpr size_t N = 16, C = 3, H = 224, W = 224;
    auto cn = CompNode::load("cpux");
    HostTensorGenerator<> gen;
    using ReformatKey = ReformatManager::ReformatKey;
    using Attribute = ReformatKey::Attribute;
    auto src_format = TensorFormats::NCHW, dst_format = TensorFormats::NCHWc4;
    ReformatKey key{src_format, dst_format, Attribute::IC_SMALL};
    auto reformat = ReformatManager::instance().get(key);
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    auto r = [](VarNode* inp) {
        auto x = SymbolVar(inp);
        auto y = opr::RelayoutFormat::make(
                x, megdnn::param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL);
        return y;
    };
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
    };
    auto x = mkvar("x", {N, C, H, W});
    auto y1 = SymbolVar(reformat({x.node()}));
    auto y2 = r(x.node());
    HostTensorND t1, t2;
    auto func1 = graph->compile({make_callback_copy(y1, t1)});
    func1->execute();
    auto func2 = graph->compile({make_callback_copy(y2, t2)});
    func2->execute();
    MGB_ASSERT_TENSOR_EQ(t1, t2);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}