GitOrigin-RevId: a1b1e89b76
tags/v1.6.0
| @@ -120,10 +120,6 @@ Dimension Dimension::operator/(const Dimension& rhs) const { | |||
| static_cast<char>(m_name), static_cast<char>(rhs.m_name)); | |||
| if (operator==(rhs)) | |||
| return Dimension(m_name, 1, 1); | |||
| megdnn_assert( | |||
| !(*this < rhs), | |||
| "Divisor must be smaller than dividend(dividend:%s, divisor:%s)", | |||
| to_string().c_str(), rhs.to_string().c_str()); | |||
| if (m_stride == rhs.m_stride) { | |||
| if (m_extent == UNDETERMINED_EXTENT) { | |||
| megdnn_assert(rhs.m_extent != UNDETERMINED_EXTENT, | |||
| @@ -0,0 +1,431 @@ | |||
| /** | |||
| * \file src/gopt/impl/folding_conv_dimshuffle.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megbrain/gopt/inference.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/opr/utility.h" | |||
| #include "megbrain/serialization/opr_shallow_copy.h" | |||
| #include "megdnn/opr_param_defs.h" | |||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
| #include "megbrain/utils/hash_ct.h" | |||
| #include "midout.h" | |||
| #include "megbrain/gopt/reformat_manager.h" | |||
| #if CUDA_VERSION >= 10020 | |||
| MIDOUT_DECL(megbrain_folding_conv_dimshuffle) | |||
| #define MIDOUT_B(tag) \ | |||
| MIDOUT_BEGIN(megbrain_folding_conv_dimshuffle, \ | |||
| midout_iv(MGB_HASH_STR(tag))) { | |||
| #define MIDOUT_E \ | |||
| } \ | |||
| MIDOUT_END(); | |||
| using namespace mgb; | |||
| using namespace gopt; | |||
| using ReformatKey = ReformatManager::ReformatKey; | |||
| /* ==================== FoldingConvBiasDimshufflePass ================= */ | |||
| const char* FoldingConvBiasDimshufflePass::name() const { | |||
| return mgb_cstr_log("folding conv bias dimshuffle pass"); | |||
| } | |||
| void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { | |||
| MIDOUT_B("FoldingConvBiasDimshufflePass::apply"); | |||
| using DepType = cg::OperatorNodeProp::DepType; | |||
| ThinHashMap<OperatorNodeBase*, | |||
| SmallVector<std::pair<OperatorNodeBase*, DepType>>> | |||
| readers; | |||
| static const ThinHashSet<Typeinfo*> opr_type_list = { | |||
| opr::TypeCvt::typeinfo(), opr::Dimshuffle::typeinfo(), | |||
| opr::Reshape::typeinfo(), opr::ConvBias::typeinfo()}; | |||
| opt.graph().iter([&readers](OperatorNodeBase* opr) { | |||
| for (auto&& i : opr->node_prop().dep_map()) { | |||
| if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) { | |||
| readers[i.first->owner_opr()].emplace_back(opr, i.second); | |||
| } | |||
| } | |||
| }); | |||
| auto rewriter = opt.graph().make_rewriter(); | |||
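| // matcher: fold ConvBias(qs8, NCHW4) + Dimshuffle + Reshape + TypeCvt(qs8->f32), i.e. an | |||
| // NCHW4->NCHW relayout plus dequantize, into one ConvBias with NCHW4_NCHW format and f32 output | |||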
| auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers]( | |||
| OperatorNodeBase* opr) { | |||
| ThinHashSet<OperatorNodeBase*> opr_set; | |||
| ThinHashSet<OperatorNodeBase*> reader_set; | |||
| // check typecvt | |||
| auto typecvt = try_cast_as_op<opr::TypeCvt>(opr); | |||
| if (typecvt == nullptr) | |||
| return false; | |||
| auto inp_dtype = typecvt->input(0)->dtype(), | |||
| out_dtype = typecvt->output(0)->dtype(); | |||
| bool is_s82f32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| out_dtype.enumv() == DTypeEnum::Float32; | |||
| if (!is_s82f32) | |||
| return false; | |||
| opr_set.insert(opr); | |||
| // check reshape | |||
| auto reshape = | |||
| try_cast_as_op<opr::Reshape>(typecvt->input(0)->owner_opr()); | |||
| if (reshape == nullptr) | |||
| return false; | |||
| opr_set.insert(reshape); | |||
| for (auto&& i : readers[reshape]) { | |||
| if (i.second & DepType::DEV_VALUE) { | |||
| reader_set.insert(i.first); | |||
| } | |||
| } | |||
| // check shuffle | |||
| auto shuffle = | |||
| try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr()); | |||
| if (shuffle == nullptr) | |||
| return false; | |||
| auto&& param = shuffle->param(); | |||
| if (param.pattern_len != 5) | |||
| return false; | |||
| bool is_nchw42nchw = param.pattern[0] == 0 && param.pattern[1] == 1 && | |||
| param.pattern[2] == 4 && param.pattern[3] == 2 && | |||
| param.pattern[4] == 3 && | |||
| shuffle->input(0)->shape()[4] == 4; | |||
| if (!is_nchw42nchw) | |||
| return false; | |||
| opr_set.insert(shuffle); | |||
| for (auto&& i : readers[shuffle]) { | |||
| if (i.second & DepType::DEV_VALUE) { | |||
| reader_set.insert(i.first); | |||
| } | |||
| } | |||
| // check conv bias | |||
| auto conv_bias = | |||
| try_cast_as_op<opr::ConvBias>(shuffle->input(0)->owner_opr()); | |||
| if (conv_bias == nullptr) | |||
| return false; | |||
| inp_dtype = conv_bias->input(0)->dtype(); | |||
| bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| conv_bias->param().format == | |||
| megdnn::param::ConvBias::Format::NCHW4; | |||
| if (!is_s8nchw4) | |||
| return false; | |||
| if (conv_bias->input().size() != 3) | |||
| return false; | |||
| opr_set.insert(conv_bias); | |||
| for (auto&& i : readers[conv_bias]) { | |||
| if (i.second & DepType::DEV_VALUE) { | |||
| reader_set.insert(i.first); | |||
| } | |||
| } | |||
| for (auto reader : reader_set) { | |||
| if (opr_set.count(reader) <= 0) { | |||
| return false; | |||
| } | |||
| } | |||
| auto src = rewriter.get_var(conv_bias->input(0)), | |||
| filter = rewriter.get_var(conv_bias->input(1)), | |||
| bias = rewriter.get_var(conv_bias->input(2)); | |||
| auto new_bias = ReformatManager::instance().get(ReformatKey{ | |||
| TensorFormats::NCHWc4, TensorFormats::NCHW})({bias}); | |||
| new_bias = opr::TypeCvt::make(new_bias, dtype::Float32()).node(); | |||
| auto new_param = conv_bias->param(); | |||
| new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW; | |||
| auto conv_bias_shuffle = opr::ConvBias::make( | |||
| src, filter, new_bias, new_param, conv_bias->execution_policy(), | |||
| OperatorNodeConfig{dtype::Float32()}); | |||
| rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(), | |||
| mgb_cstr_log("replace conv_bias + typecvt + " | |||
| "dimshuffle + " | |||
| "reshape to conv_bias(NCHW4_NCHW)")); | |||
| return true; | |||
| }; | |||
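| // matcher: fold ConvBias(qs8, NCHW4) + Reshape + Dimshuffle + Reshape, i.e. an NCHW4->NCHW32 | |||
| // relayout, into one ConvBias with NCHW4_NCHW32 format | |||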
| auto try_conv_reformat_nchw42nchw32 = [&rewriter, | |||
| &readers](OperatorNodeBase* opr) { | |||
| ThinHashSet<OperatorNodeBase*> opr_set; | |||
| ThinHashSet<OperatorNodeBase*> reader_set; | |||
| // check reshape | |||
| auto reshape1 = try_cast_as_op<opr::Reshape>(opr); | |||
| if (reshape1 == nullptr) | |||
| return false; | |||
| opr_set.insert(opr); | |||
| // check dimshuffle | |||
| auto shuffle = try_cast_as_op<opr::Dimshuffle>( | |||
| reshape1->input(0)->owner_opr()); | |||
| if (shuffle == nullptr) | |||
| return false; | |||
| auto&& param = shuffle->param(); | |||
| if (param.pattern_len != 6) | |||
| return false; | |||
| bool is_nchw42nchw32 = param.pattern[0] == 0 && param.pattern[1] == 1 && | |||
| param.pattern[2] == 3 && param.pattern[3] == 4 && | |||
| param.pattern[4] == 2 && param.pattern[5] == 5 && | |||
| shuffle->output(0)->shape()[5] == 4 && | |||
| shuffle->output(0)->shape()[4] == 8; | |||
| if (!is_nchw42nchw32) | |||
| return false; | |||
| opr_set.insert(shuffle); | |||
| for (auto&& i : readers[shuffle]) { | |||
| if (i.second & DepType::DEV_VALUE) { | |||
| reader_set.insert(i.first); | |||
| } | |||
| } | |||
| // check reshape | |||
| auto reshape2 = | |||
| try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr()); | |||
| if (reshape2 == nullptr) | |||
| return false; | |||
| opr_set.insert(reshape2); | |||
| for (auto&& i : readers[reshape2]) { | |||
| if (i.second & DepType::DEV_VALUE) { | |||
| reader_set.insert(i.first); | |||
| } | |||
| } | |||
| // check conv bias | |||
| auto conv_bias = | |||
| try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr()); | |||
| if (conv_bias == nullptr) | |||
| return false; | |||
| auto inp_dtype = conv_bias->input(0)->dtype(); | |||
| bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| conv_bias->param().format == | |||
| megdnn::param::ConvBias::Format::NCHW4; | |||
| if (!is_s8nchw4) | |||
| return false; | |||
| if (conv_bias->input().size() != 3) | |||
| return false; | |||
| opr_set.insert(conv_bias); | |||
| for (auto&& i : readers[conv_bias]) { | |||
| if (i.second & DepType::DEV_VALUE) { | |||
| reader_set.insert(i.first); | |||
| } | |||
| } | |||
| for (auto reader : reader_set) { | |||
| if (opr_set.count(reader) <= 0) { | |||
| return false; | |||
| } | |||
| } | |||
| auto src = rewriter.get_var(conv_bias->input(0)), | |||
| filter = rewriter.get_var(conv_bias->input(1)), | |||
| bias = rewriter.get_var(conv_bias->input(2)); | |||
| auto new_bias = ReformatManager::instance().get(ReformatKey{ | |||
| TensorFormats::NCHWc4, TensorFormats::NCHWc32})({bias}); | |||
| auto new_param = conv_bias->param(); | |||
| new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW32; | |||
| auto conv_bias_shuffle = opr::ConvBias::make( | |||
| src, filter, new_bias, new_param, conv_bias->execution_policy(), | |||
| conv_bias->config()); | |||
| rewriter.replace_var( | |||
| opr->output(0), conv_bias_shuffle.node(), | |||
| mgb_cstr_log("replace conv_bias + " | |||
| "reformat to conv_bias(NCHW4_NCHW32)")); | |||
| return true; | |||
| }; | |||
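| // matcher: fold ConvBias(qs8, NCHW4) + TypeCvt(qs8->q4) + Dimshuffle + Reshape, i.e. an | |||
| // NCHW4->NHWC relayout with 4-bit output, into one ConvBias with NCHW4_NHWC format | |||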
| auto try_conv_reformat_nchw42nhwc = [&rewriter, | |||
| &readers](OperatorNodeBase* opr) { | |||
| ThinHashSet<OperatorNodeBase*> opr_set; | |||
| ThinHashSet<OperatorNodeBase*> reader_set; | |||
| // check reshape | |||
| auto reshape = try_cast_as_op<opr::Reshape>(opr); | |||
| if (reshape == nullptr) | |||
| return false; | |||
| opr_set.insert(opr); | |||
| // check dimshuffle | |||
| auto shuffle = | |||
| try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr()); | |||
| if (shuffle == nullptr) | |||
| return false; | |||
| auto&& param = shuffle->param(); | |||
| if (param.pattern_len != 5) | |||
| return false; | |||
| bool is_nchw42nhwc = param.pattern[0] == 0 && param.pattern[1] == 2 && | |||
| param.pattern[2] == 3 && param.pattern[3] == 1 && | |||
| param.pattern[4] == 4 && | |||
| shuffle->output(0)->shape()[4] == 4; | |||
| if (!is_nchw42nhwc) | |||
| return false; | |||
| opr_set.insert(shuffle); | |||
| for (auto&& i : readers[shuffle]) { | |||
| if (i.second & DepType::DEV_VALUE) { | |||
| reader_set.insert(i.first); | |||
| } | |||
| } | |||
| auto typecvt = | |||
| try_cast_as_op<opr::TypeCvt>(shuffle->input(0)->owner_opr()); | |||
| if (typecvt == nullptr) | |||
| return false; | |||
| auto in_dtype = typecvt->input(0)->dtype(), | |||
| out_dtype = typecvt->output(0)->dtype(); | |||
| bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| (out_dtype.enumv() == DTypeEnum::QuantizedS4 || | |||
| out_dtype.enumv() == DTypeEnum::Quantized4Asymm); | |||
| if (!is_s82s4) | |||
| return false; | |||
| opr_set.insert(typecvt); | |||
| for (auto&& i : readers[typecvt]) { | |||
| if (i.second & DepType::DEV_VALUE) { | |||
| reader_set.insert(i.first); | |||
| } | |||
| } | |||
| // check conv bias | |||
| auto conv_bias = | |||
| try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr()); | |||
| if (conv_bias == nullptr) | |||
| return false; | |||
| auto inp_dtype = conv_bias->input(0)->dtype(); | |||
| bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| conv_bias->param().format == | |||
| megdnn::param::ConvBias::Format::NCHW4; | |||
| if (!is_s8nchw4) | |||
| return false; | |||
| if (conv_bias->input().size() != 3) | |||
| return false; | |||
| opr_set.insert(conv_bias); | |||
| for (auto&& i : readers[conv_bias]) { | |||
| if (i.second & DepType::DEV_VALUE) { | |||
| reader_set.insert(i.first); | |||
| } | |||
| } | |||
| for (auto reader : reader_set) { | |||
| if (opr_set.count(reader) <= 0) { | |||
| return false; | |||
| } | |||
| } | |||
| auto src = rewriter.get_var(conv_bias->input(0)), | |||
| filter = rewriter.get_var(conv_bias->input(1)), | |||
| bias = rewriter.get_var(conv_bias->input(2)); | |||
| auto new_bias = ReformatManager::instance().get(ReformatKey{ | |||
| TensorFormats::NCHWc4, TensorFormats::NHWC})({bias}); | |||
| auto new_param = conv_bias->param(); | |||
| new_param.format = megdnn::param::ConvBias::Format::NCHW4_NHWC; | |||
| auto conv_bias_shuffle = opr::ConvBias::make( | |||
| src, filter, new_bias, new_param, conv_bias->execution_policy(), | |||
| OperatorNodeConfig{out_dtype}); | |||
| rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(), | |||
| mgb_cstr_log("replace conv_bias + " | |||
| "reformat to conv_bias(NCHW4_NHWC)")); | |||
| return true; | |||
| }; | |||
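| // matcher: fold ConvBias(qs8, NCHW32) + Reshape + Dimshuffle + Reshape, i.e. an NCHW32->NCHW4 | |||
| // relayout, into one ConvBias with NCHW32_NCHW4 format | |||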
| auto try_conv_reformat_nchw322nchw4 = [&rewriter, | |||
| &readers](OperatorNodeBase* opr) { | |||
| ThinHashSet<OperatorNodeBase*> opr_set; | |||
| ThinHashSet<OperatorNodeBase*> reader_set; | |||
| // check reshape | |||
| auto reshape1 = try_cast_as_op<opr::Reshape>(opr); | |||
| if (reshape1 == nullptr) | |||
| return false; | |||
| opr_set.insert(opr); | |||
| // check dimshuffle | |||
| auto shuffle = try_cast_as_op<opr::Dimshuffle>( | |||
| reshape1->input(0)->owner_opr()); | |||
| if (shuffle == nullptr) | |||
| return false; | |||
| auto&& param = shuffle->param(); | |||
| if (param.pattern_len != 6) | |||
| return false; | |||
| bool is_nchw322nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 && | |||
| param.pattern[2] == 4 && param.pattern[3] == 2 && | |||
| param.pattern[4] == 3 && param.pattern[5] == 5 && | |||
| shuffle->input(0)->shape()[5] == 4 && | |||
| shuffle->input(0)->shape()[4] == 8; | |||
| if (!is_nchw322nchw4) | |||
| return false; | |||
| opr_set.insert(shuffle); | |||
| for (auto&& i : readers[shuffle]) { | |||
| if (i.second & DepType::DEV_VALUE) { | |||
| reader_set.insert(i.first); | |||
| } | |||
| } | |||
| // check reshape | |||
| auto reshape2 = | |||
| try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr()); | |||
| if (reshape2 == nullptr) | |||
| return false; | |||
| opr_set.insert(reshape2); | |||
| for (auto&& i : readers[reshape2]) { | |||
| if (i.second & DepType::DEV_VALUE) { | |||
| reader_set.insert(i.first); | |||
| } | |||
| } | |||
| // check conv bias | |||
| auto conv_bias = | |||
| try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr()); | |||
| if (conv_bias == nullptr) | |||
| return false; | |||
| auto inp_dtype = conv_bias->input(0)->dtype(); | |||
| bool is_s8nchw32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| conv_bias->param().format == | |||
| megdnn::param::ConvBias::Format::NCHW32; | |||
| if (!is_s8nchw32) | |||
| return false; | |||
| if (conv_bias->input().size() != 3) | |||
| return false; | |||
| opr_set.insert(conv_bias); | |||
| for (auto&& i : readers[conv_bias]) { | |||
| if (i.second & DepType::DEV_VALUE) { | |||
| reader_set.insert(i.first); | |||
| } | |||
| } | |||
| for (auto reader : reader_set) { | |||
| if (opr_set.count(reader) <= 0) { | |||
| return false; | |||
| } | |||
| } | |||
| auto src = rewriter.get_var(conv_bias->input(0)), | |||
| filter = rewriter.get_var(conv_bias->input(1)), | |||
| bias = rewriter.get_var(conv_bias->input(2)); | |||
| auto new_bias = ReformatManager::instance().get(ReformatKey{ | |||
| TensorFormats::NCHWc32, TensorFormats::NCHWc4})({bias}); | |||
| auto new_param = conv_bias->param(); | |||
| new_param.format = megdnn::param::ConvBias::Format::NCHW32_NCHW4; | |||
| auto conv_bias_shuffle = opr::ConvBias::make( | |||
| src, filter, new_bias, new_param, conv_bias->execution_policy(), | |||
| conv_bias->config()); | |||
| rewriter.replace_var( | |||
| opr->output(0), conv_bias_shuffle.node(), | |||
| mgb_cstr_log("replace conv_bias + " | |||
| "reformat to conv_bias(NCHW32_NCHW4)")); | |||
| return true; | |||
| }; | |||
| MGB_MARK_USED_VAR(try_conv_reformat_nchw322nchw4); | |||
| MGB_MARK_USED_VAR(try_conv_reformat_nchw42nchw32); | |||
| auto on_opr = [&try_conv_dimshuffle_reshape_typecvt, | |||
| &try_conv_reformat_nchw42nchw32, | |||
| &try_conv_reformat_nchw42nhwc, | |||
| &try_conv_reformat_nchw322nchw4, | |||
| &rewriter](OperatorNodeBase* opr) { | |||
| if (!try_conv_dimshuffle_reshape_typecvt(opr) && | |||
| !try_conv_reformat_nchw42nchw32(opr) && | |||
| !try_conv_reformat_nchw42nhwc(opr) && | |||
| !try_conv_reformat_nchw322nchw4(opr)) { | |||
| rewriter.auto_replace_outputs(opr); | |||
| } | |||
| }; | |||
| opt.graph().iter(on_opr); | |||
| rewriter.apply_inplace(); | |||
| MIDOUT_E | |||
| } | |||
| #endif | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -0,0 +1,451 @@ | |||
| /** | |||
| * \file src/gopt/impl/padding_channel.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megbrain/gopt/inference.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/imgproc.h" | |||
| #include "megbrain/opr/misc.h" | |||
| #include "megbrain/opr/nn_int.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/opr/utility.h" | |||
| #include "megbrain/serialization/opr_shallow_copy.h" | |||
| #include "megdnn/opr_param_defs.h" | |||
| #include "megdnn/tensor_format.h" | |||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
| #include "megbrain/gopt/misc.h" | |||
| #include "megbrain/utils/hash_ct.h" | |||
| #include "midout.h" | |||
| #include "megbrain/gopt/reformat_manager.h" | |||
| MIDOUT_DECL(megbrain_padding_channel) | |||
| #define MIDOUT_B(tag) \ | |||
| MIDOUT_BEGIN(megbrain_padding_channel, midout_iv(MGB_HASH_STR(tag))) { | |||
| #define MIDOUT_E \ | |||
| } \ | |||
| MIDOUT_END(); | |||
| using namespace mgb; | |||
| using namespace gopt; | |||
| using ReformatKey = ReformatManager::ReformatKey; | |||
| /* ==================== PaddingChannelPass ================= */ | |||
| const char* PaddingChannelPass::name() const { | |||
| return mgb_cstr_log("padding output channel to multiple of 4/32"); | |||
| } | |||
| void PaddingChannelPass::apply(OptState& opt) const { | |||
| MIDOUT_B("PaddingChannelPass::apply"); | |||
| // do not check shape | |||
| opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^ | |||
| VarReplaceCheckFlag::CHECK_SHAPE); | |||
| ThinHashSet<OperatorNodeBase*> padding_oprs; | |||
| ThinHashMap<Typeinfo*, thin_function<OperatorNodeBase*( | |||
| OperatorNodeBase*, const VarNodeArray&)>> | |||
| opr_replace_funcs; | |||
| auto rewriter = opt.graph().make_rewriter(); | |||
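| // helper: append pad_channels zero-filled channels along axis 1 of a 4-D quantized tensor | |||
| // (used for feature maps, bias and the input-channel dim of weights) | |||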
| auto pad_in_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* { | |||
| mgb_assert(inp->shape().ndim == 4); | |||
| mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 || | |||
| inp->dtype().enumv() == DTypeEnum::Quantized4Asymm || | |||
| inp->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||
| inp->dtype().enumv() == DTypeEnum::QuantizedS32); | |||
| TensorShape shape{inp->shape()[0], pad_channels, inp->shape()[2], | |||
| inp->shape()[3]}; | |||
| std::shared_ptr<HostTensorND> host_val = | |||
| std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype()); | |||
| host_val->resize(shape); | |||
| auto ptr = host_val->raw_ptr(); | |||
| size_t size_bytes = | |||
| TensorLayout{shape, inp->dtype()}.span().dist_byte(); | |||
| std::memset(ptr, 0, size_bytes); | |||
| auto padding = | |||
| opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); | |||
| auto out = opr::Concat::make({inp, padding}, 1); | |||
| return out.node(); | |||
| }; | |||
| auto pad_out_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* { | |||
| mgb_assert(inp->shape().ndim == 4); | |||
| mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 || | |||
| inp->dtype().enumv() == DTypeEnum::Quantized4Asymm || | |||
| inp->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||
| inp->dtype().enumv() == DTypeEnum::QuantizedS32); | |||
| TensorShape shape{pad_channels, inp->shape()[1], inp->shape()[2], | |||
| inp->shape()[3]}; | |||
| std::shared_ptr<HostTensorND> host_val = | |||
| std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype()); | |||
| host_val->resize(shape); | |||
| auto ptr = host_val->raw_ptr(); | |||
| size_t size_bytes = | |||
| TensorLayout{shape, inp->dtype()}.span().dist_byte(); | |||
| std::memset(ptr, 0, size_bytes); | |||
| auto padding = | |||
| opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); | |||
| auto out = opr::Concat::make({inp, padding}, 0); | |||
| return out.node(); | |||
| }; | |||
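| // helper: slice a padded tensor back to its original channel count with a Subtensor on axis 1 | |||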
| auto extract_subtensor = [](VarNode* inp, | |||
| const TensorShape& orig_shape) -> VarNode* { | |||
| mgb_assert(inp->shape().ndim == 4); | |||
| mgb_assert(inp->shape()[0] == orig_shape[0]); | |||
| mgb_assert(inp->shape()[2] == orig_shape[2]); | |||
| mgb_assert(inp->shape()[3] == orig_shape[3]); | |||
| size_t orig_channels = orig_shape[1]; | |||
| auto x = SymbolVar(inp); | |||
| auto cv = [&x](int v) { return x.make_scalar(v); }; | |||
| using AIdx = opr::Subtensor::AxisIndexer; | |||
| auto sub = opr::Subtensor::make( | |||
| x, {AIdx::make_interval(0, None, None, cv(1)), | |||
| AIdx::make_interval(1, None, cv(orig_channels), None), | |||
| AIdx::make_interval(2, None, None, cv(1)), | |||
| AIdx::make_interval(3, None, None, cv(1))}); | |||
| return sub.node(); | |||
| }; | |||
| // padding policy for conv bias with data type qint8 | |||
| auto padding_policy_qint8 = [&padding_oprs, &pad_in_channels, | |||
| &pad_out_channels]( | |||
| OperatorNodeBase* opr, | |||
| const VarNodeArray& new_inp) { | |||
| mgb_assert(opr->input().size() == new_inp.size()); | |||
| mgb_assert(new_inp.size() == 3); | |||
| mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); | |||
| auto inps = new_inp; | |||
| size_t out_channels = opr->input(1)->shape()[0]; | |||
| size_t in_channels = opr->input(1)->shape()[1]; | |||
| size_t new_in_channels = new_inp[0]->shape()[1]; | |||
| // pad input channels | |||
| if (padding_oprs.count(opr->input(0)->owner_opr())) { | |||
| size_t pad_channels = new_in_channels - in_channels; | |||
| inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||
| } else { | |||
| size_t pad_channels = 0; | |||
| mgb_assert(new_in_channels == in_channels); | |||
| if (in_channels <= 16) { | |||
| if (in_channels % 4) | |||
| pad_channels = 4 - (in_channels % 4); // pad to use dp4a | |||
| } else { | |||
| if (in_channels % 32) | |||
| pad_channels = | |||
| 32 - (in_channels % 32); // pad to use tensorcore | |||
| } | |||
| if (pad_channels > 0) { | |||
| inps[0] = pad_in_channels(new_inp[0], pad_channels); | |||
| inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||
| } | |||
| } | |||
| out_channels = inps[1]->shape()[0]; | |||
| in_channels = inps[1]->shape()[1]; | |||
| size_t pad_channels = 0; | |||
| if (out_channels <= 16) { | |||
| if (out_channels % 4) | |||
| pad_channels = 4 - (out_channels % 4); | |||
| } else { | |||
| if (out_channels % 32) | |||
| pad_channels = 32 - (out_channels % 32); | |||
| } | |||
| if (pad_channels > 0) { | |||
| inps[1] = pad_out_channels(inps[1], pad_channels); | |||
| inps[2] = pad_in_channels(inps[2], pad_channels); | |||
| padding_oprs.insert(opr); | |||
| } | |||
| return serialization::copy_opr_shallow(*opr, inps, opr->config()); | |||
| }; | |||
| // padding policy for conv bias with data type qint4 and quint4 | |||
| auto padding_policy_int4 = [&padding_oprs, &pad_in_channels, | |||
| &pad_out_channels]( | |||
| OperatorNodeBase* opr, | |||
| const VarNodeArray& new_inp) { | |||
| mgb_assert(opr->input().size() == new_inp.size()); | |||
| mgb_assert(new_inp.size() == 3); | |||
| mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); | |||
| auto inps = new_inp; | |||
| size_t out_channels = opr->input(1)->shape()[0]; | |||
| size_t in_channels = opr->input(1)->shape()[1]; | |||
| size_t new_in_channels = new_inp[0]->shape()[1]; | |||
| // pad input channels | |||
| if (padding_oprs.count(opr->input(0)->owner_opr())) { | |||
| if (new_in_channels <= 32) { | |||
| if (new_in_channels % 8 == 0) { | |||
| size_t pad_channels = new_in_channels - in_channels; | |||
| inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||
| } else { | |||
| size_t pad_channels_0 = 8 - (new_in_channels % 8); | |||
| size_t pad_channels_1 = 8 - (in_channels % 8); | |||
| inps[0] = pad_in_channels(new_inp[0], pad_channels_0); | |||
| inps[1] = pad_in_channels(new_inp[1], pad_channels_1); | |||
| } | |||
| } else { | |||
| if (new_in_channels % 64 == 0) { | |||
| size_t pad_channels = new_in_channels - in_channels; | |||
| inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||
| } else { | |||
| size_t pad_channels_0 = 64 - (new_in_channels % 64); | |||
| size_t pad_channels_1 = 64 - (in_channels % 64); | |||
| inps[0] = pad_in_channels(new_inp[0], pad_channels_0); | |||
| inps[1] = pad_in_channels(new_inp[1], pad_channels_1); | |||
| } | |||
| } | |||
| } else { | |||
| size_t pad_channels = 0; | |||
| mgb_assert(new_in_channels == in_channels); | |||
| if (in_channels <= 32) { | |||
| if (in_channels % 8) | |||
| pad_channels = 8 - (in_channels % 8); | |||
| } else { | |||
| if (in_channels % 64) | |||
| pad_channels = 64 - (in_channels % 64); | |||
| } | |||
| if (pad_channels > 0) { | |||
| inps[0] = pad_in_channels(new_inp[0], pad_channels); | |||
| inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||
| } | |||
| } | |||
| out_channels = inps[1]->shape()[0]; | |||
| in_channels = inps[1]->shape()[1]; | |||
| size_t pad_channels = 0; | |||
| if (out_channels <= 32) { | |||
| if (out_channels % 8) | |||
| pad_channels = 8 - (out_channels % 8); | |||
| } else { | |||
| if (out_channels % 64) | |||
| pad_channels = 64 - (out_channels % 64); | |||
| } | |||
| if (pad_channels > 0) { | |||
| inps[1] = pad_out_channels(inps[1], pad_channels); | |||
| inps[2] = pad_in_channels(inps[2], pad_channels); | |||
| padding_oprs.insert(opr); | |||
| } | |||
| return serialization::copy_opr_shallow(*opr, inps, opr->config()); | |||
| }; | |||
| opr_replace_funcs[opr::ConvBiasForward::typeinfo()] = | |||
| [&padding_oprs, &padding_policy_qint8, &padding_policy_int4]( | |||
| OperatorNodeBase* opr, const VarNodeArray& new_inp) { | |||
| if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) { | |||
| return padding_policy_qint8(opr, new_inp); | |||
| } else if (opr->input(0)->dtype().enumv() == | |||
| DTypeEnum::QuantizedS4 || | |||
| opr->input(0)->dtype().enumv() == | |||
| DTypeEnum::Quantized4Asymm) { | |||
| return padding_policy_int4(opr, new_inp); | |||
| } else { | |||
| mgb_assert( | |||
| padding_oprs.count(opr->input(0)->owner_opr()) == 0, | |||
| "conv bias operator for data type(%s) cannot be " | |||
| "padded channel. " | |||
| "consumer(%s), producer(%s)", | |||
| opr->input(0)->dtype().name(), opr->cname(), | |||
| opr->input(0)->owner_opr()->cname()); | |||
| return serialization::copy_opr_shallow(*opr, new_inp, | |||
| opr->config()); | |||
| } | |||
| }; | |||
| opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] = | |||
| [&padding_oprs, &pad_in_channels, &pad_out_channels]( | |||
| OperatorNodeBase* opr, const VarNodeArray& new_inp) { | |||
| if (opr->input(1)->dtype().enumv() != DTypeEnum::QuantizedS8) { | |||
| mgb_assert( | |||
| padding_oprs.count(opr->input(0)->owner_opr()) == 0, | |||
| "conv bwd data operator for data type(%s) cannot " | |||
| "be " | |||
| "padded channel. " | |||
| "consumer(%s), producer(%s)", | |||
| opr->input(0)->dtype().name(), opr->cname(), | |||
| opr->input(0)->owner_opr()->cname()); | |||
| return serialization::copy_opr_shallow(*opr, new_inp, | |||
| opr->config()); | |||
| } | |||
| mgb_assert(opr->input().size() == new_inp.size()); | |||
| mgb_assert(new_inp.size() == 2, | |||
| "deconv (conv bwd data) operator for inference can " | |||
| "only have 2 input vars(got:%zu)", | |||
| new_inp.size()); | |||
| mgb_assert( | |||
| opr->input(0)->shape().eq_shape(new_inp[0]->shape())); | |||
| auto inps = new_inp; | |||
| size_t out_channels = opr->input(0)->shape()[0]; | |||
| size_t in_channels = opr->input(0)->shape()[1]; | |||
| size_t new_out_channels = new_inp[1]->shape()[1]; | |||
| // pad output channels | |||
| if (padding_oprs.count(opr->input(1)->owner_opr())) { | |||
| size_t pad_channels = new_out_channels - out_channels; | |||
| inps[0] = pad_out_channels(new_inp[0], pad_channels); | |||
| } else { | |||
| size_t pad_channels = 0; | |||
| if (out_channels % 4) | |||
| pad_channels = 4 - (out_channels % 4); | |||
| if (pad_channels > 0) { | |||
| inps[0] = pad_out_channels(new_inp[0], pad_channels); | |||
| inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||
| } | |||
| } | |||
| out_channels = inps[0]->shape()[0]; | |||
| in_channels = inps[0]->shape()[1]; | |||
| // pad input channels | |||
| size_t pad_channels = 0; | |||
| if (in_channels % 4) | |||
| pad_channels = 4 - (in_channels % 4); | |||
| if (pad_channels > 0) { | |||
| inps[0] = pad_in_channels(inps[0], pad_channels); | |||
| padding_oprs.insert(opr); | |||
| } | |||
| return serialization::copy_opr_shallow(*opr, inps, | |||
| opr->config()); | |||
| }; | |||
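| // format-aware oprs: keep the padded channels and propagate the padding mark when the input | |||
| // comes from a padded opr (only for quantized s8/s4/u4 data) | |||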
| auto replace_format_aware_opr = [&padding_oprs]( | |||
| OperatorNodeBase* opr, | |||
| const VarNodeArray& new_inp) { | |||
| if (opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 && | |||
| opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS4 && | |||
| opr->input(0)->dtype().enumv() != DTypeEnum::Quantized4Asymm) { | |||
| mgb_assert(padding_oprs.count(opr->input(0)->owner_opr()) == 0, | |||
| "operator(type:%s,name:%s) for data type(%s) cannot be " | |||
| "padded channel. extra info:" | |||
| "consumer(%s), producer(%s)", | |||
| opr->dyn_typeinfo()->name, opr->cname(), | |||
| opr->input(0)->dtype().name(), opr->cname(), | |||
| opr->input(0)->owner_opr()->cname()); | |||
| return serialization::copy_opr_shallow(*opr, new_inp, | |||
| opr->config()); | |||
| } | |||
| mgb_assert(opr->input().size() == new_inp.size()); | |||
| if (padding_oprs.count(opr->input(0)->owner_opr())) { | |||
| padding_oprs.insert(opr); | |||
| } | |||
| return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | |||
| }; | |||
| opr_replace_funcs[opr::PoolingForward::typeinfo()] = | |||
| replace_format_aware_opr; | |||
| opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] = | |||
| replace_format_aware_opr; | |||
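| // elemwise-like oprs: if every input is padded to the same channel count, keep the padding and | |||
| // mark this opr as padded; otherwise slice padded inputs back to their original shapes | |||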
| auto replace_elemwise_like_opr = [&padding_oprs, &extract_subtensor]( | |||
| OperatorNodeBase* opr, | |||
| const VarNodeArray& new_inp) { | |||
| mgb_assert(opr->input().size() == new_inp.size()); | |||
| bool have_padding_inp = false; | |||
| bool padding_all_inps = true; | |||
| bool same_padding = true; | |||
| size_t channels_after_padding = 0; | |||
| size_t i = 0; | |||
| for (auto&& cur_inp : opr->input()) { | |||
| bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0; | |||
| if (padding_cur_inp) { | |||
| if (!have_padding_inp) | |||
| have_padding_inp = true; | |||
| if (channels_after_padding == 0) { | |||
| channels_after_padding = new_inp[i]->shape()[1]; | |||
| } else { | |||
| same_padding = | |||
| channels_after_padding == new_inp[i]->shape()[1]; | |||
| } | |||
| } | |||
| if (padding_all_inps && (!padding_cur_inp || !same_padding)) | |||
| padding_all_inps = false; | |||
| ++i; | |||
| } | |||
| if (have_padding_inp && !padding_all_inps) { | |||
| auto inps = new_inp; | |||
| for (size_t i = 0; i < new_inp.size(); ++i) { | |||
| auto cur_inp = opr->input(i); | |||
| bool padding_cur_inp = | |||
| padding_oprs.count(cur_inp->owner_opr()) > 0; | |||
| if (padding_cur_inp) { | |||
| inps[i] = extract_subtensor(inps[i], cur_inp->shape()); | |||
| } | |||
| } | |||
| return serialization::copy_opr_shallow(*opr, inps, opr->config()); | |||
| } | |||
| if (padding_all_inps) { | |||
| padding_oprs.insert(opr); | |||
| } | |||
| return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | |||
| }; | |||
| opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] = | |||
| replace_elemwise_like_opr; | |||
| opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr; | |||
| opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr; | |||
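| // shape-sensitive oprs: always restore padded inputs to their original channel count first | |||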
| auto replace_nonpadding_oprs = [&padding_oprs, &extract_subtensor]( | |||
| OperatorNodeBase* opr, | |||
| const VarNodeArray& new_inp) { | |||
| mgb_assert(opr->input().size() == new_inp.size()); | |||
| auto inps = new_inp; | |||
| for (size_t i = 0; i < new_inp.size(); ++i) { | |||
| auto cur_inp = opr->input(i); | |||
| bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0; | |||
| if (padding_cur_inp) { | |||
| inps[i] = extract_subtensor(inps[i], cur_inp->shape()); | |||
| } | |||
| } | |||
| return serialization::copy_opr_shallow(*opr, inps, opr->config()); | |||
| }; | |||
| opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs; | |||
| opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs; | |||
| opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs; | |||
| opr_replace_funcs[opr::Reduce::typeinfo()] = replace_nonpadding_oprs; | |||
| opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_nonpadding_oprs; | |||
| auto on_opr = [&opt, &rewriter, &opr_replace_funcs, | |||
| &extract_subtensor](OperatorNodeBase* opr) { | |||
| auto it = opr_replace_funcs.find(opr->dyn_typeinfo()); | |||
| if (it != opr_replace_funcs.end()) { | |||
| VarNodeArray new_inp; | |||
| new_inp.reserve(opr->input().size()); | |||
| for (auto&& inp : opr->input()) { | |||
| new_inp.push_back(rewriter.get_var(inp)); | |||
| } | |||
| auto new_opr = (it->second)(opr, new_inp); | |||
| auto &&out0 = opr->output(), &&out1 = new_opr->output(); | |||
| mgb_assert(out0.size() == out1.size(), | |||
| "bad opr replace: src=%s{%s} dst=%s{%s}, " | |||
| "src.size=%zu " | |||
| "dst.size=%zu", | |||
| opr->cname(), opr->dyn_typeinfo()->name, | |||
| new_opr->cname(), new_opr->dyn_typeinfo()->name, | |||
| out0.size(), out1.size()); | |||
| for (size_t i = 0; i < out0.size(); ++i) { | |||
| if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | |||
| mgb_assert(!out1[i]->contain_flag( | |||
| VarNode::Flag::VOLATILE_CONTENT)); | |||
| auto src = out0[i]; | |||
| auto dst = out1[i]; | |||
| if (opt.graph().endpoint_contain(src) && | |||
| !src->shape().eq_shape(dst->shape())) { | |||
| dst = extract_subtensor(dst, src->shape()); | |||
| } | |||
| rewriter.replace_var(src, dst, nullptr); | |||
| } | |||
| } | |||
| } else { | |||
| rewriter.auto_replace_outputs(opr); | |||
| } | |||
| }; | |||
| opt.graph().iter(on_opr); | |||
| rewriter.apply_inplace(); | |||
| MIDOUT_E | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -11,7 +11,6 @@ | |||
| */ | |||
| #include "megbrain/gopt/reformat_manager.h" | |||
| #include <numeric> | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| using namespace mgb; | |||
| @@ -65,6 +64,10 @@ NamedTensorShape tensor_formats_to_named_tensor_shape(TensorFormats format) { | |||
| return {{"C//8"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}, {"C%8"}}; | |||
| case TensorFormats::KRSCk8: | |||
| return {{"K//8"}, {"R"}, {"S"}, {"C"}, {"K%8"}}; | |||
| case TensorFormats::KCRSc4: | |||
| return {{"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}}; | |||
| case TensorFormats::GKCRSc4: | |||
| return {{"G"}, {"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}}; | |||
| case TensorFormats::KCRS: | |||
| return {{"K"}, {"C"}, {"R"}, {"S"}}; | |||
| case TensorFormats::GKCRS: | |||
| @@ -130,70 +133,40 @@ bool ReformatManager::ReformatKey::Equal::operator()( | |||
| lhs.attribute == rhs.attribute; | |||
| } | |||
| ReformatManager::ReformatKey& | |||
| ReformatManager::ReformatKey::deduce_reformat_dtype_enum(const DType& dt) { | |||
| static const ThinHashSet<std::pair<TensorFormats, TensorFormats>> set = { | |||
| {TensorFormats::NCHW, TensorFormats::NCHWc64}, | |||
| {TensorFormats::NCHWc64, TensorFormats::NCHW}, | |||
| {TensorFormats::NCHW, TensorFormats::NHWC}, | |||
| {TensorFormats::NHWC, TensorFormats::NCHW}}; | |||
| if (set.count({input_format, output_format}) > 0 && | |||
| (dt.enumv() == DTypeEnum::QuantizedS4 || | |||
| dt.enumv() == DTypeEnum::Quantized4Asymm)) { | |||
| input_dtype = output_dtype = dt.enumv(); | |||
| } | |||
| return *this; | |||
| } | |||
| /* =================== ReformatManager ==================== */ | |||
| #define FOREACH_FEATURE_TENSOR_FORMATS(cb) \ | |||
| cb(NCHW) cb(NHWC) cb(NCHWc4) cb(NCHWc8) cb(NCHWc32) cb(NCHWc64) cb(CHWNc4) \ | |||
| cb(NHCWc4) | |||
| #define FOREACH_WEIGHT_TENSOR_FORMATS(cb) \ | |||
| cb(KRSCk4) cb(KRSCk4c4) cb(KCRSk4c4) cb(KCRSc4k4) cb(KCRSc8k8) cb(KRSCk8) \ | |||
| cb(GKRSCk4) cb(GKRSCk4c4) cb(GKCRSc4k4) cb(GKCRSk4c4) \ | |||
| cb(GKCRSc8k8) cb(C11RSc4) cb(C11RSc8) | |||
| ReformatManager::ReformatManager() { | |||
| static constexpr TensorFormats feature_tensor_formats[] = { | |||
| #define cb(_fmt) TensorFormats::_fmt, | |||
| FOREACH_FEATURE_TENSOR_FORMATS(cb) | |||
| #undef cb | |||
| }; | |||
| static constexpr int nr_feature_tensor_formats = | |||
| sizeof(feature_tensor_formats) / sizeof(TensorFormats); | |||
| for (int i = 0; i < nr_feature_tensor_formats; ++i) { | |||
| for (int o = 0; o < nr_feature_tensor_formats; ++o) { | |||
| if (i == o) | |||
| continue; | |||
| NamedTensorShape input_shape = tensor_formats_to_named_tensor_shape( | |||
| feature_tensor_formats[i]); | |||
| NamedTensorShape output_shape = | |||
| tensor_formats_to_named_tensor_shape( | |||
| feature_tensor_formats[o]); | |||
| auto impl = std::get<0>( | |||
| ReformatEmitter{input_shape, output_shape}.emit()); | |||
| m_cache.emplace(ReformatKey{feature_tensor_formats[i], | |||
| feature_tensor_formats[o]}, | |||
| impl); | |||
| } | |||
| } | |||
| static constexpr TensorFormats default_weight_tensor_formats = | |||
| TensorFormats::KCRS; | |||
| static constexpr TensorFormats default_group_conv_weight_tensor_formats = | |||
| TensorFormats::GKCRS; | |||
| static constexpr TensorFormats default_chan_conv_weight_tensor_formats = | |||
| TensorFormats::C11RS; | |||
| static constexpr TensorFormats weight_tensor_formats[] = { | |||
| #define cb(_fmt) TensorFormats::_fmt, | |||
| FOREACH_WEIGHT_TENSOR_FORMATS(cb) | |||
| #undef cb | |||
| }; | |||
| static constexpr int nr_weight_tensor_formats = | |||
| sizeof(weight_tensor_formats) / sizeof(TensorFormats); | |||
| using Name = megdnn::Dimension::Name; | |||
| for (int o = 0; o < nr_weight_tensor_formats; ++o) { | |||
| NamedTensorShape output_shape = | |||
| tensor_formats_to_named_tensor_shape(weight_tensor_formats[o]); | |||
| TensorFormats input_format; | |||
| if (output_shape[0].name() == Name::G) { | |||
| input_format = default_group_conv_weight_tensor_formats; | |||
| } else if (output_shape[0].name() == Name::C) { | |||
| input_format = default_chan_conv_weight_tensor_formats; | |||
| } else { | |||
| mgb_assert(output_shape[0].name() == Name::K); | |||
| input_format = default_weight_tensor_formats; | |||
| } | |||
| NamedTensorShape input_shape = | |||
| tensor_formats_to_named_tensor_shape(input_format); | |||
| auto impl = | |||
| std::get<0>(ReformatEmitter{input_shape, output_shape}.emit()); | |||
| m_cache.emplace(ReformatKey{input_format, weight_tensor_formats[o]}, | |||
| impl); | |||
| using Attribute = ReformatKey::Attribute; | |||
| { | |||
| auto i = TensorFormats::NCHWc4, o = TensorFormats::CHWNc4; | |||
| auto&& impl1 = [](const VarNodeArray& vars) { | |||
| return opr::RelayoutFormat::make( | |||
| vars[0], | |||
| megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4) | |||
| .node(); | |||
| }; | |||
| m_cache.emplace(ReformatKey{i, o}, impl1); | |||
| auto&& impl2 = [](const VarNodeArray& vars) { | |||
| return opr::RelayoutFormat::make( | |||
| vars[0], | |||
| megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4) | |||
| .node(); | |||
| }; | |||
| m_cache.emplace(ReformatKey{o, i}, impl2); | |||
| } | |||
| { | |||
| auto i = TensorFormats::NCHW, o = TensorFormats::NCHWc4; | |||
| @@ -206,7 +179,7 @@ ReformatManager::ReformatManager() { | |||
| m_cache.emplace(ReformatKey{i, o, Attribute::IC_SMALL}, impl); | |||
| } | |||
| { | |||
| auto i = TensorFormats::KCRS, o = TensorFormats::KCRSc4k4; | |||
| auto i = TensorFormats::KCRS, o = TensorFormats::KCRSc4; | |||
| auto&& impl = [](const VarNodeArray& vars) { | |||
| return opr::RelayoutFormat::make( | |||
| vars[0], | |||
| @@ -238,7 +211,7 @@ ReformatManager::ReformatManager() { | |||
| auto&& impl = [](const VarNodeArray& vars) { | |||
| return opr::RelayoutFormat::make( | |||
| vars[0], | |||
| megdnn::param::RelayoutFormat::Mode::NCHW_NCHW64) | |||
| megdnn::param::RelayoutFormat::Mode::NCHW64_NCHW) | |||
| .node(); | |||
| }; | |||
| m_cache.emplace( | |||
| @@ -272,7 +245,7 @@ ReformatManager::ReformatManager() { | |||
| auto&& impl = [](const VarNodeArray& vars) { | |||
| return opr::RelayoutFormat::make( | |||
| vars[0], | |||
| megdnn::param::RelayoutFormat::Mode::NCHW_NHWC) | |||
| megdnn::param::RelayoutFormat::Mode::NHWC_NCHW) | |||
| .node(); | |||
| }; | |||
| m_cache.emplace( | |||
| @@ -371,14 +344,23 @@ ReformatManager::ReformatManager() { | |||
| impl); | |||
| } | |||
| } | |||
| #undef FOREACH_FEATURE_TENSOR_FORMATS | |||
| #undef FOREACH_WEIGHT_TENSOR_FORMATS | |||
| const ReformatManager::ReformatImpl& ReformatManager::get( | |||
| ReformatManager::ReformatImpl ReformatManager::get( | |||
| const ReformatKey& key) const { | |||
| using Attribute = ReformatKey::Attribute; | |||
| MGB_TRY { | |||
| auto&& impl = m_cache.at(key); | |||
| return impl; | |||
| auto find = m_cache.find(key); | |||
| if (find != m_cache.end()) { | |||
| auto rst = find->second; | |||
| return rst; | |||
| } | |||
| mgb_assert(key.attribute == Attribute::DEFAULT); | |||
| auto&& i = key.input_format; | |||
| auto&& o = key.output_format; | |||
| auto ishp = tensor_formats_to_named_tensor_shape(i); | |||
| auto oshp = tensor_formats_to_named_tensor_shape(o); | |||
| auto builder = std::get<0>(ReformatEmitter{ishp, oshp}.emit()); | |||
| return builder; | |||
| } | |||
| MGB_CATCH(std::exception & exc, { | |||
| mgb_log_error( | |||
| @@ -390,10 +372,7 @@ const ReformatManager::ReformatImpl& ReformatManager::get( | |||
| } | |||
| const ReformatManager& ReformatManager::instance() { | |||
| static ReformatManager* inst = nullptr; | |||
| if (inst == nullptr) { | |||
| inst = new ReformatManager(); | |||
| } | |||
| return *inst; | |||
| static ReformatManager inst; | |||
| return inst; | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -227,6 +227,7 @@ namespace gopt { | |||
| VarReplaceCheckFlag m_var_replace_check_flag = | |||
| VarReplaceCheckFlag::CHECK_ALL; | |||
| class RelayoutPlaceholder; | |||
| friend class ShuffleShuffleRemovePass; | |||
| public: | |||
| TensorReformatPass& set_var_replace_check_flag(VarReplaceCheckFlag flag) { | |||
| @@ -49,10 +49,14 @@ enum class TensorFormats : uint32_t { | |||
| KRSCk8 = 21, ///< [K/8, R, S, C, K%8] | |||
| // NCHW4 | |||
| KCRSc4 = 22, ///< [K, C/4, R, S, C%4] | |||
| GKCRSc4 = 23, ///< [G, K, C/4, R, S, C%4] | |||
| // default weight format | |||
| KCRS = 22, ///< [K, C, R, S] | |||
| GKCRS = 23, ///< [G, K, C, R, S] | |||
| C11RS = 24, ///< [C, 1, 1, R, S] | |||
| KCRS = 24, ///< [K, C, R, S] | |||
| GKCRS = 25, ///< [G, K, C, R, S] | |||
| C11RS = 26, ///< [C, 1, 1, R, S] | |||
| }; | |||
| class ReformatManager : public NonCopyableObj { | |||
| @@ -60,16 +64,20 @@ class ReformatManager : public NonCopyableObj { | |||
| public: | |||
| using ReformatImpl = thin_function<VarNode*(const VarNodeArray&)>; | |||
| enum class Attribute : uint32_t { | |||
| DEFAULT = 0, | |||
| IMAGE2D = 1 << 0, | |||
| IC_SMALL = 1 << 1, | |||
| }; | |||
| struct ReformatKey { | |||
| enum class Attribute : uint32_t { | |||
| DEFAULT = 0, | |||
| IMAGE2D = 1 << 0, | |||
| IC_SMALL = 1 << 1, | |||
| }; | |||
| TensorFormats input_format, output_format; | |||
| DTypeEnum input_dtype, output_dtype; | |||
| Attribute attribute; | |||
| std::string to_string() const; | |||
| ReformatKey() | |||
| : input_dtype{DTypeEnum::Float32}, | |||
| output_dtype{DTypeEnum::Float32}, | |||
| attribute{Attribute::DEFAULT} {} | |||
| ReformatKey(TensorFormats input_format_, TensorFormats output_format_, | |||
| Attribute attribute_ = Attribute::DEFAULT, | |||
| DTypeEnum input_dtype_ = DTypeEnum::Float32, | |||
| @@ -86,11 +94,13 @@ public: | |||
| bool operator()(const ReformatKey& lhs, | |||
| const ReformatKey& rhs) const; | |||
| }; | |||
| ReformatKey& deduce_reformat_dtype_enum(const DType& dt); | |||
| }; | |||
| using ReformatCache = | |||
| std::unordered_map<ReformatKey, ReformatImpl, ReformatKey::Hash, | |||
| ReformatKey::Equal>; | |||
| const ReformatImpl& get(const ReformatKey& key) const; | |||
| ReformatImpl get(const ReformatKey& key) const; | |||
| ReformatImpl get(ReformatKey&& key) const { return get(key); } | |||
| static const ReformatManager& instance(); | |||
| private: | |||
| @@ -0,0 +1,171 @@ | |||
| /** | |||
| * \file src/gopt/test/reformat_manager.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "./helper.h" | |||
| #include "megbrain/gopt/reformat_manager.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| using namespace mgb; | |||
| using namespace gopt; | |||
| TEST(TestReformatManager, Feature) { | |||
| constexpr size_t N = 16, C = 128, H = 7, W = 7; | |||
| HostTensorGenerator<> gen; | |||
| using ReformatKey = ReformatManager::ReformatKey; | |||
| auto src_format = TensorFormats::NHWC, dst_format = TensorFormats::NCHWc64; | |||
| ReformatKey key{src_format, dst_format}; | |||
| auto reformat = ReformatManager::instance().get(key); | |||
| auto graph = ComputingGraph::make(); | |||
| graph->options().graph_opt_level = 0; | |||
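| // reference: hand-written NHWC -> NCHWc64 relayout (reshape + dimshuffle) to compare against | |||
| // the impl returned by ReformatManager | |||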
| auto r = [](VarNode* inp) { | |||
| auto x = SymbolVar(inp); | |||
| auto xshp = opr::GetVarShape::make(x); | |||
| auto cv = [&x](int v) { return x.make_scalar(v); }; | |||
| auto sub = [&xshp, &cv](int idx) { | |||
| return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | |||
| }; | |||
| auto tshp0 = opr::Concat::make( | |||
| {sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0); | |||
| auto y0 = opr::Reshape::make(x, tshp0); | |||
| auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4}); | |||
| return y1; | |||
| }; | |||
| auto mkvar = [&](const char* name, const TensorShape& shp) { | |||
| return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); | |||
| }; | |||
| auto x = mkvar("x", {N, H, W, C}); | |||
| auto y1 = SymbolVar(reformat({x.node()})); | |||
| auto y2 = r(x.node()); | |||
| size_t nr_shapeof = 0; | |||
| size_t nr_reshape = 0; | |||
| cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) { | |||
| if (o->same_type<opr::GetVarShape>()) | |||
| nr_shapeof++; | |||
| if (o->same_type<opr::Reshape>()) | |||
| nr_reshape++; | |||
| }} | |||
| .add(y1.node()->owner_opr()); | |||
| ASSERT_EQ(nr_shapeof, 1); | |||
| ASSERT_EQ(nr_reshape, 1); | |||
| HostTensorND t1, t2; | |||
| auto func1 = graph->compile({make_callback_copy(y1, t1)}); | |||
| func1->execute(); | |||
| auto func2 = graph->compile({make_callback_copy(y2, t2)}); | |||
| func2->execute(); | |||
| MGB_ASSERT_TENSOR_EQ(t1, t2); | |||
| } | |||
| TEST(TestReformatManager, Weight) { | |||
| constexpr size_t G = 8, K = 128, C = 128, R = 3, S = 3; | |||
| HostTensorGenerator<> gen; | |||
| using ReformatKey = ReformatManager::ReformatKey; | |||
| auto src_format = TensorFormats::GKCRS, | |||
| dst_format = TensorFormats::GKCRSk4c4; | |||
| ReformatKey key{src_format, dst_format}; | |||
| auto reformat = ReformatManager::instance().get(key); | |||
| auto graph = ComputingGraph::make(); | |||
| graph->options().graph_opt_level = 0; | |||
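| // reference: hand-written GKCRS -> GKCRSk4c4 relayout (reshape + dimshuffle + reshape) | |||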
| auto r = [](VarNode* inp) { | |||
| auto x = SymbolVar(inp); | |||
| auto xshp = opr::GetVarShape::make(x); | |||
| auto cv = [&x](int v) { return x.make_scalar(v); }; | |||
| auto sub = [&xshp, &cv](int idx) { | |||
| return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | |||
| }; | |||
| auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4, | |||
| cv(4), sub(3), sub(4)}, | |||
| 0), | |||
| tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3), | |||
| sub(4), cv(4), cv(4)}, | |||
| 0); | |||
| auto y0 = opr::Reshape::make(x, tshp0); | |||
| auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 2, 4}); | |||
| auto y2 = opr::Reshape::make(y1, tshp1); | |||
| return y2; | |||
| }; | |||
| auto mkvar = [&](const char* name, const TensorShape& shp) { | |||
| return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); | |||
| }; | |||
| auto w = mkvar("w", {G, K / G, C / G, R, S}); | |||
| auto y1 = SymbolVar(reformat({w.node()})); | |||
| auto y2 = r(w.node()); | |||
| size_t nr_shapeof = 0; | |||
| size_t nr_reshape = 0; | |||
| cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) { | |||
| if (o->same_type<opr::GetVarShape>()) | |||
| nr_shapeof++; | |||
| if (o->same_type<opr::Reshape>()) | |||
| nr_reshape++; | |||
| }} | |||
| .add(y1.node()->owner_opr()); | |||
| ASSERT_EQ(nr_shapeof, 1); | |||
| ASSERT_EQ(nr_reshape, 1); | |||
| HostTensorND t1, t2; | |||
| auto func1 = graph->compile({make_callback_copy(y1, t1)}); | |||
| func1->execute(); | |||
| auto func2 = graph->compile({make_callback_copy(y2, t2)}); | |||
| func2->execute(); | |||
| MGB_ASSERT_TENSOR_EQ(t1, t2); | |||
| } | |||
| TEST(TestReformatManager, InvalidKey) { | |||
| using ReformatKey = ReformatManager::ReformatKey; | |||
| using Attribute = ReformatKey::Attribute; | |||
| auto src_format = TensorFormats::GKCRS, | |||
| dst_format = TensorFormats::GKCRSk4c4; | |||
| Attribute attribute = Attribute::IMAGE2D; | |||
| ReformatKey key{src_format, dst_format, attribute}; | |||
| ASSERT_THROW(ReformatManager::instance().get(key), AssertionError); | |||
| } | |||
| TEST(TestReformatManager, InputChannelSmall) { | |||
| constexpr size_t N = 16, C = 3, H = 224, W = 224; | |||
| auto cn = CompNode::load("cpux"); | |||
| HostTensorGenerator<> gen; | |||
| using ReformatKey = ReformatManager::ReformatKey; | |||
| using Attribute = ReformatKey::Attribute; | |||
| auto src_format = TensorFormats::NCHW, dst_format = TensorFormats::NCHWc4; | |||
| ReformatKey key{src_format, dst_format, Attribute::IC_SMALL}; | |||
| auto reformat = ReformatManager::instance().get(key); | |||
| auto graph = ComputingGraph::make(); | |||
| graph->options().graph_opt_level = 0; | |||
| auto r = [](VarNode* inp) { | |||
| auto x = SymbolVar(inp); | |||
| auto y = opr::RelayoutFormat::make( | |||
| x, megdnn::param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL); | |||
| return y; | |||
| }; | |||
| auto mkvar = [&](const char* name, const TensorShape& shp) { | |||
| return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name); | |||
| }; | |||
| auto x = mkvar("x", {N, C, H, W}); | |||
| auto y1 = SymbolVar(reformat({x.node()})); | |||
| auto y2 = r(x.node()); | |||
| HostTensorND t1, t2; | |||
| auto func1 = graph->compile({make_callback_copy(y1, t1)}); | |||
| func1->execute(); | |||
| auto func2 = graph->compile({make_callback_copy(y2, t2)}); | |||
| func2->execute(); | |||
| MGB_ASSERT_TENSOR_EQ(t1, t2); | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||