GitOrigin-RevId: 038e372cbe
tags/v1.6.0
| @@ -200,6 +200,15 @@ static inline bool is_nchw_nchw4_shuffle_vec( | |||
| param.pattern[4] == 2; | |||
| } | |||
| static inline bool is_shape_before_nhwc(const TensorShape& shape) { | |||
| return shape.ndim == 4 && shape[1] == 4; | |||
| } | |||
| static inline bool is_nchw_nhwc_shuffle(const opr::Dimshuffle::Param param) { | |||
| return param.ndim == 4 && param.pattern[0] == 0 && param.pattern[1] == 2 && | |||
| param.pattern[2] == 3 && param.pattern[3] == 1; | |||
| } | |||
| template <typename T> | |||
| static inline bool is_immutable_equal(OperatorNodeBase* opr, T val, | |||
| DTypeEnum dtype_enum) { | |||
| @@ -276,14 +285,20 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() { | |||
| auto inp0 = opr->input()[0]; | |||
| return is_shape_nchw(inp0->shape()); | |||
| }}; | |||
| SGM::Node shuffle_root{ | |||
| opr::Dimshuffle::typeinfo(), | |||
| {{nchwx_reshape}}, | |||
| {{nchwx_reshape}, {broadcast_concat}}, | |||
| [](OperatorNodeBase* opr) { | |||
| auto& shuffle_opr = opr->cast_final<opr::Dimshuffle>(); | |||
| auto& input_vec = shuffle_opr.input(); | |||
| return is_shape_before_nchw4(input_vec[0]->shape()) && | |||
| is_nchw_nchw4_shuffle_vec(shuffle_opr.param()); | |||
| bool nchw_nchw4_ok = | |||
| is_shape_before_nchw4(input_vec[0]->shape()) && | |||
| is_nchw_nchw4_shuffle_vec(shuffle_opr.param()); | |||
| bool nchw_nhwc_ok = | |||
| is_shape_before_nhwc(input_vec[0]->shape()) && | |||
| is_nchw_nhwc_shuffle(shuffle_opr.param()); | |||
| return nchw_nchw4_ok || nchw_nhwc_ok; | |||
| }}; | |||
| return shuffle_root; | |||
| }; | |||
| @@ -382,6 +397,19 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() { | |||
| auto out_node = opr::RelayoutFormat::make( | |||
| rewriter.get_var(src_node->output()[0]), param.mode, | |||
| config); | |||
| const auto& outshp = opr->output(0)->shape(); | |||
| if (outshp.ndim == 4) { | |||
| auto shpvar = opr::GetVarShape::make(out_node); | |||
| auto cv = [&out_node](int v) { | |||
| return out_node.make_scalar(v); | |||
| }; | |||
| auto sub = [&shpvar, &cv](int idx) { | |||
| return opr::IndexAt::make(shpvar, {{0, cv(idx)}}); | |||
| }; | |||
| auto nhwc_shp = | |||
| opr::Concat::make({sub(0), sub(2), sub(3), sub(4)}, 0); | |||
| out_node = opr::Reshape::make(out_node, nhwc_shp); | |||
| } | |||
| return out_node.node()->owner_opr(); | |||
| } else { | |||
| return serialization::copy_opr_shallow(*opr, new_inp, | |||
| @@ -740,4 +768,4 @@ void FuseWarpPerspectiveDimshufflePass::apply(OptState& opt) const { | |||
| }; | |||
| opt.graph().iter(on_opr); | |||
| rewriter.apply_inplace(); | |||
| } | |||
| } | |||
| @@ -92,19 +92,24 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||
| bool is_parameter = | |||
| fmtcfg.valid() && fmtcfg.val().input_tensor_types[i] == | |||
| TensorType::WEIGHT; | |||
| ReformatManager::ReformatImpl reformat; | |||
| ReformatManager::ReformatKey key{from, to, reformat_attribute, | |||
| var->dtype().enumv(), | |||
| var->dtype().enumv()}; | |||
| if (is_parameter) { | |||
| auto aligned_desc = make_aligned_desc(base_fmt, out_fmt); | |||
| reformat = ReformatManager::instance() | |||
| .auto_aligned_reformat_weight( | |||
| var, key, aligned_desc); | |||
| } else { | |||
| reformat = ReformatManager::instance() | |||
| .auto_aligned_reformat_featrue( | |||
| var, base_fmt, key); | |||
| // need relayout | |||
| if (from != to && !new_var->shape().is_scalar()) { | |||
| ReformatManager::ReformatImpl reformat; | |||
| ReformatManager::ReformatKey key{ | |||
| from, to, reformat_attribute, var->dtype().enumv(), | |||
| var->dtype().enumv()}; | |||
| if (is_parameter) { | |||
| auto aligned_desc = ReformatManager::make_aligned_desc( | |||
| base_fmt, out_fmt); | |||
| reformat = ReformatManager::instance() | |||
| .auto_aligned_reformat_weight( | |||
| var, key, aligned_desc); | |||
| } else { | |||
| reformat = ReformatManager::instance() | |||
| .auto_aligned_reformat_featrue( | |||
| var, base_fmt, key); | |||
| } | |||
| new_var = reformat({new_var}); | |||
| } | |||
| if (from != to && !new_var->shape().is_scalar()) | |||
| new_var = reformat({new_var}); | |||
| @@ -165,6 +165,7 @@ public: | |||
| private: | |||
| static constexpr float PROFILE_TIME_OUT = 1e7; | |||
| using ReformatAttribute = ReformatKey::Attribute; | |||
| /*! | |||
| * \brief profile opr format agnostic operators (like elemwise, elemwise multi type, typecvt etc.) | |||
| * | |||
| @@ -175,40 +176,48 @@ private: | |||
| */ | |||
| OperatorNodeRecord profile_operator( | |||
| const OperatorNodeBase* opr, TensorFormats base_format, | |||
| const SmallVector<TensorFormats>& available_tensor_formats) const; | |||
| const SmallVector<TensorFormats>& available_tensor_formats, | |||
| ReformatAttribute extra_attribute = | |||
| ReformatAttribute::DEFAULT) const; | |||
| float profile_operator(const OperatorNodeBase* opr, | |||
| TensorFormats base_format, | |||
| TensorFormats tensor_format) const; | |||
| TensorFormats tensor_format, | |||
| ReformatAttribute extra_attribute = | |||
| ReformatAttribute::DEFAULT) const; | |||
| /*! | |||
| * \brief profile opr format aware operators (like conv, deconv, conv_bias, etc.) | |||
| * \brief profile opr format aware operators (like conv, deconv, conv_bias, | |||
| * etc.) | |||
| * | |||
| * \param opr pointer to the operator node to be profiled | |||
| * \param base_config the tensor formats configuration of base opr format | |||
| * \param config all the available configuration | |||
| * \param config all the available configuration | |||
| * \return the operator node record | |||
| */ | |||
| OperatorNodeRecord profile_operator( | |||
| const OperatorNodeBase* opr, | |||
| const OprTensorFormatsConfiguration& base_config, | |||
| const SmallVector<OprTensorFormatsConfiguration>& available_configs) | |||
| const; | |||
| const SmallVector<OprTensorFormatsConfiguration>& available_configs, | |||
| ReformatAttribute extra_attribute = | |||
| ReformatAttribute::DEFAULT) const; | |||
| float profile_operator(const OperatorNodeBase* opr, | |||
| const OprTensorFormatsConfiguration& base_config, | |||
| const OprTensorFormatsConfiguration& config) const; | |||
| const OprTensorFormatsConfiguration& config, | |||
| ReformatAttribute extra_attribute = | |||
| ReformatAttribute::DEFAULT) const; | |||
| /*! | |||
| * \brief profile layout transform of the var node | |||
| * | |||
| * \param var pointer to the var node to be profiled | |||
| * \param base_format the original tensor formats in which the var node is stored | |||
| * \param available_tensor_formats the available tensor formats | |||
| * \param base_format the original tensor formats in which the var node is | |||
| * stored \param available_tensor_formats the available tensor formats | |||
| * \param extra_attribute the extra attributes (options) of the problem | |||
| * \return the var node record | |||
| */ | |||
| VarNodeRecord profile_var_node( | |||
| const VarNode* var, TensorFormats base_format, | |||
| const SmallVector<TensorFormats>& available_tensor_formats, | |||
| ReformatKey::Attribute extra_attribute = | |||
| ReformatKey::Attribute::DEFAULT) const; | |||
| ReformatAttribute extra_attribute = | |||
| ReformatAttribute::DEFAULT) const; | |||
| float profile_var_node(const VarNode* var, TensorFormats base_format, | |||
| const ReformatKey& key) const; | |||
| int m_runs; /// sample times of the profiler | |||
| @@ -216,20 +225,23 @@ private: | |||
| ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( | |||
| const OperatorNodeBase* opr, TensorFormats base_format, | |||
| const SmallVector<TensorFormats>& available_tensor_formats) const { | |||
| const SmallVector<TensorFormats>& available_tensor_formats, | |||
| ReformatAttribute extra_attribute) const { | |||
| OperatorNodeRecord record; | |||
| record.opr = opr; | |||
| auto& costs = record.costs; | |||
| for (auto&& f : available_tensor_formats) { | |||
| auto opr_format = tensor_formats_to_opr_format(f); | |||
| costs[opr_format] = profile_operator(opr, base_format, f); | |||
| costs[opr_format] = | |||
| profile_operator(opr, base_format, f, extra_attribute); | |||
| } | |||
| return record; | |||
| } | |||
| float ProfilerImpl::profile_operator(const OperatorNodeBase* opr, | |||
| TensorFormats base_format, | |||
| TensorFormats tensor_format) const { | |||
| TensorFormats tensor_format, | |||
| ReformatAttribute extra_attribute) const { | |||
| auto graph = ComputingGraph::make(); | |||
| graph->options().graph_opt_level = 0; | |||
| graph->options().var_sanity_check_first_run = false; | |||
| @@ -239,8 +251,8 @@ float ProfilerImpl::profile_operator(const OperatorNodeBase* opr, | |||
| auto&& cn = var->comp_node(); | |||
| auto&& dtype = var->dtype(); | |||
| auto dval = std::make_shared<DeviceTensorND>(cn, dtype); | |||
| auto aligned_tensor_shape = | |||
| make_aligned_tensor_shape(var, base_format, tensor_format); | |||
| auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape( | |||
| var, base_format, tensor_format, extra_attribute); | |||
| dval->resize(aligned_tensor_shape); | |||
| auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval); | |||
| new_inps[i] = aligned_var.node(); | |||
| @@ -263,8 +275,8 @@ float ProfilerImpl::profile_operator(const OperatorNodeBase* opr, | |||
| ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( | |||
| const OperatorNodeBase* opr, | |||
| const OprTensorFormatsConfiguration& base_config, | |||
| const SmallVector<OprTensorFormatsConfiguration>& available_configs) | |||
| const { | |||
| const SmallVector<OprTensorFormatsConfiguration>& available_configs, | |||
| ReformatAttribute extra_attribute) const { | |||
| OperatorNodeRecord record; | |||
| record.opr = opr; | |||
| auto& costs = record.costs; | |||
| @@ -273,7 +285,8 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( | |||
| if (i.opr_format == OprFormat::NCHW && | |||
| opr->input(0)->dtype().enumv() != DTypeEnum::Float32) | |||
| continue; | |||
| costs[i.opr_format] = profile_operator(opr, base_config, i); | |||
| costs[i.opr_format] = | |||
| profile_operator(opr, base_config, i, extra_attribute); | |||
| } | |||
| return record; | |||
| } | |||
| @@ -281,7 +294,8 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator( | |||
| float ProfilerImpl::profile_operator( | |||
| const OperatorNodeBase* opr, | |||
| const OprTensorFormatsConfiguration& base_config, | |||
| const OprTensorFormatsConfiguration& config) const { | |||
| const OprTensorFormatsConfiguration& config, | |||
| ReformatAttribute extra_attribute) const { | |||
| auto graph = ComputingGraph::make(); | |||
| graph->options().graph_opt_level = 0; | |||
| graph->options().var_sanity_check_first_run = false; | |||
| @@ -297,18 +311,18 @@ float ProfilerImpl::profile_operator( | |||
| TensorShape aligned_shape; | |||
| if (config.input_tensor_types[i] == TensorType::WEIGHT) { | |||
| mgb_assert(base_config.input_tensor_types[i] == TensorType::WEIGHT); | |||
| aligned_shape = make_aligned_weight_shape( | |||
| aligned_shape = ReformatManager::make_aligned_weight_shape( | |||
| var, base_config.input_tensor_formats[i], | |||
| config.input_tensor_formats[i], | |||
| config.output_tensor_formats[0]); | |||
| config.output_tensor_formats[0], extra_attribute); | |||
| } else { | |||
| mgb_assert(base_config.input_tensor_types[i] == | |||
| config.input_tensor_types[i]); | |||
| mgb_assert(base_config.input_tensor_types[i] == | |||
| TensorType::FEATURE); | |||
| aligned_shape = make_aligned_tensor_shape( | |||
| aligned_shape = ReformatManager::make_aligned_tensor_shape( | |||
| var, base_config.input_tensor_formats[i], | |||
| config.input_tensor_formats[i]); | |||
| config.input_tensor_formats[i], extra_attribute); | |||
| } | |||
| dval->resize(aligned_shape); | |||
| auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval); | |||
| @@ -357,7 +371,7 @@ float ProfilerImpl::profile_operator( | |||
| ProfilerImpl::VarNodeRecord ProfilerImpl::profile_var_node( | |||
| const VarNode* var, TensorFormats base_format, | |||
| const SmallVector<TensorFormats>& available_tensor_formats, | |||
| ReformatKey::Attribute attribute) const { | |||
| ReformatAttribute attribute) const { | |||
| VarNodeRecord record; | |||
| record.var = var; | |||
| auto& costs = record.costs; | |||
| @@ -379,8 +393,8 @@ float ProfilerImpl::profile_var_node(const VarNode* var, | |||
| auto&& cn = var->comp_node(); | |||
| auto&& dtype = var->dtype(); | |||
| auto dval = std::make_shared<DeviceTensorND>(cn, dtype); | |||
| auto aligned_tensor_shape = | |||
| make_aligned_tensor_shape(var, base_format, key.input_format); | |||
| auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape( | |||
| var, base_format, key.input_format, key.attribute); | |||
| dval->resize(aligned_tensor_shape); | |||
| auto graph = ComputingGraph::make(); | |||
| graph->options().graph_opt_level = 0; | |||
| @@ -468,13 +482,14 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile( | |||
| auto base_format = problem.base_format(); | |||
| auto&& available_tensor_formats = problem.available_tensor_formats(); | |||
| auto&& reformat_attribute = problem.attribute().reformat_attribute; | |||
| ProfilingResult profiling_result; | |||
| auto& opr_record = profiling_result.opr_record; | |||
| auto& var_record = profiling_result.var_record; | |||
| for (auto&& var : vars) { | |||
| var_record[var] = | |||
| profile_var_node(var, base_format, available_tensor_formats); | |||
| var_record[var] = profile_var_node( | |||
| var, base_format, available_tensor_formats, reformat_attribute); | |||
| } | |||
| for (auto&& opr : oprs) { | |||
| auto&& opr_configs = problem.opr_configs(); | |||
| @@ -482,11 +497,12 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile( | |||
| if (find == opr_configs.end()) { | |||
| if (skip_oprs.count(opr) > 0) { | |||
| SmallVector<TensorFormats> tensor_formats = {base_format}; | |||
| opr_record[opr] = | |||
| profile_operator(opr, base_format, tensor_formats); | |||
| opr_record[opr] = profile_operator( | |||
| opr, base_format, tensor_formats, reformat_attribute); | |||
| } else { | |||
| opr_record[opr] = profile_operator(opr, base_format, | |||
| available_tensor_formats); | |||
| available_tensor_formats, | |||
| reformat_attribute); | |||
| } | |||
| } else { | |||
| auto&& dispatchers = find->second; | |||
| @@ -498,7 +514,8 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile( | |||
| } | |||
| } | |||
| auto base_config = problem.base_config(opr); | |||
| opr_record[opr] = profile_operator(opr, base_config, configs); | |||
| opr_record[opr] = profile_operator(opr, base_config, configs, | |||
| reformat_attribute); | |||
| } | |||
| } | |||
| for (auto&& rpair : opr_record) { | |||
| @@ -21,7 +21,7 @@ using NamedTensorShape = megdnn::NamedTensorShape; | |||
| using Dimension = megdnn::Dimension; | |||
| namespace { | |||
| int gcd(const int& p, const int& q) { | |||
| static inline int gcd(const int& p, const int& q) { | |||
| int x = p, y = q; | |||
| while (y != 0) { | |||
| if (x < y) { | |||
| @@ -33,6 +33,47 @@ int gcd(const int& p, const int& q) { | |||
| } | |||
| return x; | |||
| } | |||
| static inline size_t extra_alignment( | |||
| ReformatManager::ReformatKey::Attribute attr, | |||
| TensorFormats target_formats, DType dt, size_t channel_alignment) { | |||
| using Attribute = ReformatManager::ReformatKey::Attribute; | |||
| if (attr & Attribute::AUTO_PADDING_NHWC) { | |||
| constexpr size_t alignment_in_bits = 32; | |||
| size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8; | |||
| size_t extra_alignment = alignment_in_bits >= dtype_bits | |||
| ? alignment_in_bits / dtype_bits | |||
| : 1; | |||
| if (target_formats == TensorFormats::NHWC) | |||
| channel_alignment = extra_alignment * channel_alignment / | |||
| gcd(channel_alignment, extra_alignment); | |||
| return channel_alignment; | |||
| } | |||
| return channel_alignment; | |||
| } | |||
| static inline std::tuple<size_t, size_t> extra_alignment( | |||
| const ReformatManager::ReformatKey& key, DType dt, | |||
| size_t input_channel_alignment, size_t output_channel_alignment) { | |||
| using Attribute = ReformatManager::ReformatKey::Attribute; | |||
| if (key.attribute & Attribute::AUTO_PADDING_NHWC) { | |||
| constexpr size_t alignment_in_bits = 32; | |||
| size_t dtype_bits = dt.is_low_bit() ? dt.low_bit() : dt.size(1) * 8; | |||
| size_t extra_alignment = alignment_in_bits >= dtype_bits | |||
| ? alignment_in_bits / dtype_bits | |||
| : 1; | |||
| if (key.input_format == TensorFormats::NHWC) | |||
| input_channel_alignment = | |||
| input_channel_alignment * extra_alignment / | |||
| gcd(input_channel_alignment, extra_alignment); | |||
| if (key.output_format == TensorFormats::NHWC) | |||
| output_channel_alignment = | |||
| output_channel_alignment * extra_alignment / | |||
| gcd(output_channel_alignment, extra_alignment); | |||
| return {input_channel_alignment, output_channel_alignment}; | |||
| } | |||
| return {input_channel_alignment, output_channel_alignment}; | |||
| } | |||
| }; // namespace | |||
| // =================== ReformatManager::ReformatKey ====================*/ | |||
| @@ -293,7 +334,8 @@ ReformatManager::ReformatImpl ReformatManager::get( | |||
| auto rst = find->second; | |||
| return rst; | |||
| } | |||
| mgb_assert(key.attribute == Attribute::DEFAULT); | |||
| mgb_assert(!(key.attribute & Attribute::IMAGE2D) && | |||
| !(key.attribute & Attribute::IC_SMALL)); | |||
| auto&& i = key.input_format; | |||
| auto&& o = key.output_format; | |||
| auto ishp = tensor_formats_to_named_tensor_shape(i); | |||
| @@ -346,6 +388,8 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_featrue( | |||
| "invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)", | |||
| input_alignment, output_alignment, | |||
| input_shape.to_string().c_str()); | |||
| std::tie(input_alignment, output_alignment) = extra_alignment( | |||
| key, orig_var->dtype(), input_alignment, output_alignment); | |||
| NamedTensorShape orig_shape = | |||
| tensor_formats_to_named_tensor_shape(orig_format); | |||
| size_t orig_channel = 0; | |||
| @@ -451,6 +495,12 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight( | |||
| "invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)", | |||
| in_channel_alignment, out_channel_alignment, | |||
| output_shape.to_string().c_str()); | |||
| in_channel_alignment = | |||
| ::extra_alignment(key.attribute, key.output_format, | |||
| orig_var->dtype(), in_channel_alignment); | |||
| out_channel_alignment = | |||
| ::extra_alignment(key.attribute, key.output_format, | |||
| orig_var->dtype(), out_channel_alignment); | |||
| size_t aligned_in_channel = | |||
| divup(in_channels, in_channel_alignment) * in_channel_alignment; | |||
| if (extra_alignment.name == out_channel_name) { | |||
| @@ -506,9 +556,9 @@ const ReformatManager& ReformatManager::instance() { | |||
| return inst; | |||
| } | |||
| TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var, | |||
| TensorFormats orig_formats, | |||
| TensorFormats target_formats) { | |||
| TensorShape ReformatManager::make_aligned_tensor_shape( | |||
| const VarNode* var, TensorFormats orig_formats, | |||
| TensorFormats target_formats, ReformatKey::Attribute extra_attribute) { | |||
| using Dimension = megdnn::Dimension; | |||
| static constexpr uint32_t UNDETERMINED_EXTENT = | |||
| Dimension::UNDETERMINED_EXTENT; | |||
| @@ -545,6 +595,15 @@ TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var, | |||
| tshp[i] = oshp[idx] * factor; | |||
| else | |||
| tshp[i] = divup(oshp[idx], factor); | |||
| if (name == Dimension::Name::C) { | |||
| size_t channel_alignment = target_shape[i].stride(); | |||
| size_t channels = tshp[i] * channel_alignment; | |||
| size_t new_channel_alignment = | |||
| extra_alignment(extra_attribute, target_formats, | |||
| var->dtype(), channel_alignment); | |||
| tshp[i] = divup(channels, new_channel_alignment) * | |||
| new_channel_alignment / channel_alignment; | |||
| } | |||
| } else { | |||
| tshp[i] = target_shape[i].extent(); | |||
| } | |||
| @@ -552,11 +611,12 @@ TensorShape mgb::gopt::make_aligned_tensor_shape(const VarNode* var, | |||
| return tshp; | |||
| } | |||
| TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var, | |||
| TensorFormats orig_formats, | |||
| TensorFormats target_formats, | |||
| TensorFormats extra_formats) { | |||
| auto tshp = make_aligned_tensor_shape(var, orig_formats, target_formats); | |||
| TensorShape ReformatManager::make_aligned_weight_shape( | |||
| const VarNode* var, TensorFormats orig_formats, | |||
| TensorFormats target_formats, TensorFormats extra_formats, | |||
| ReformatKey::Attribute extra_attribute) { | |||
| auto tshp = make_aligned_tensor_shape(var, orig_formats, target_formats, | |||
| extra_attribute); | |||
| auto extra_shape = tensor_formats_to_named_tensor_shape(extra_formats); | |||
| using Dimension = megdnn::Dimension; | |||
| static constexpr uint32_t UNDETERMINED_EXTENT = | |||
| @@ -567,6 +627,9 @@ TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var, | |||
| if (name == Dimension::Name::C && | |||
| extra_shape[i].extent() == UNDETERMINED_EXTENT) { | |||
| out_channel_alignment = extra_shape[i].stride(); | |||
| out_channel_alignment = | |||
| extra_alignment(extra_attribute, target_formats, | |||
| var->dtype(), out_channel_alignment); | |||
| } | |||
| } | |||
| @@ -583,9 +646,8 @@ TensorShape mgb::gopt::make_aligned_weight_shape(const VarNode* var, | |||
| return tshp; | |||
| } | |||
| ReformatManager::AlignmentDesc mgb::gopt::make_aligned_desc( | |||
| ReformatManager::AlignmentDesc ReformatManager::make_aligned_desc( | |||
| TensorFormats weight_format, TensorFormats out_feature_format) { | |||
| using AlignmentDesc = ReformatManager::AlignmentDesc; | |||
| using Name = Dimension::Name; | |||
| auto weight_shape = tensor_formats_to_named_tensor_shape(weight_format); | |||
| auto out_shape = tensor_formats_to_named_tensor_shape(out_feature_format); | |||
| @@ -143,6 +143,7 @@ public: | |||
| TensorFormats base_format() const { | |||
| return m_ctx.attribute().base_tensor_formats; | |||
| } | |||
| Attribute attribute() const { return m_ctx.attribute(); } | |||
| /*! | |||
| * \brief return the tensor formats configuration of an operator in the | |||
| * default op format | |||
| @@ -74,6 +74,7 @@ public: | |||
| DEFAULT = 0, | |||
| IMAGE2D = 1 << 0, | |||
| IC_SMALL = 1 << 1, | |||
| AUTO_PADDING_NHWC = 1 << 2, | |||
| }; | |||
| TensorFormats input_format, output_format; | |||
| DTypeEnum input_dtype, output_dtype; | |||
| @@ -124,23 +125,40 @@ public: | |||
| ReformatImpl auto_aligned_reformat_weight( | |||
| const VarNode* orig_var, const ReformatKey& key, | |||
| const AlignmentDesc& extra_alignment = {}) const; | |||
| static TensorShape make_aligned_tensor_shape( | |||
| const VarNode* var, TensorFormats orig_formats, | |||
| TensorFormats target_formats, | |||
| ReformatKey::Attribute extra_attribute = | |||
| ReformatKey::Attribute::DEFAULT); | |||
| static TensorShape make_aligned_weight_shape( | |||
| const VarNode* var, TensorFormats orig_formats, | |||
| TensorFormats target_formats, TensorFormats extra_formats, | |||
| ReformatKey::Attribute extra_attribute = | |||
| ReformatKey::Attribute::DEFAULT); | |||
| static AlignmentDesc make_aligned_desc(TensorFormats weight_format, | |||
| TensorFormats out_feature_format); | |||
| static const ReformatManager& instance(); | |||
| private: | |||
| ReformatCache m_cache; | |||
| }; | |||
| TensorShape make_aligned_tensor_shape(const VarNode* var, | |||
| TensorFormats orig_formats, | |||
| TensorFormats target_formats); | |||
| TensorShape make_aligned_weight_shape(const VarNode* var, | |||
| TensorFormats orig_formats, | |||
| TensorFormats target_formats, | |||
| TensorFormats extra_formats); | |||
| MGB_DEF_ENUM_CLASS_BIT_OPR(ReformatManager::ReformatKey::Attribute); | |||
| // | |||
| //TensorShape make_aligned_tensor_shape( | |||
| // const VarNode* var, TensorFormats orig_formats, | |||
| // TensorFormats target_formats, | |||
| // ReformatManager::ReformatKey::Attribute extra_attribute = | |||
| // ReformatManager::ReformatKey::Attribute::DEFAULT); | |||
| // | |||
| //TensorShape make_aligned_weight_shape( | |||
| // const VarNode* var, TensorFormats orig_formats, | |||
| // TensorFormats target_formats, TensorFormats extra_formats, | |||
| // ReformatManager::ReformatKey::Attribute extra_attribute = | |||
| // ReformatManager::ReformatKey::Attribute::DEFAULT); | |||
| ReformatManager::AlignmentDesc make_aligned_desc( | |||
| TensorFormats weight_format, TensorFormats out_feature_format); | |||
| } // namespace gopt | |||
| } // namespace mgb | |||