GitOrigin-RevId: 132605c7d9
tags/v1.7.2.m1
| @@ -246,6 +246,8 @@ NamedTensorShape NamedTensorShape::make_named_tensor_shape(Format format) { | |||||
| return {{"N//8"}, {"C//8"}, {"H"}, {"W"}, {"C%8"}, {"N%8"}}; | return {{"N//8"}, {"C//8"}, {"H"}, {"W"}, {"C%8"}, {"N%8"}}; | ||||
| case Format::NCHW44_DOT: | case Format::NCHW44_DOT: | ||||
| return {{"N//4"}, {"C//4"}, {"H"}, {"W"}, {"N%4"}, {"C%4"}}; | return {{"N//4"}, {"C//4"}, {"H"}, {"W"}, {"N%4"}, {"C%4"}}; | ||||
| case Format::NHWCD4: | |||||
| return {{"N"}, {"H"}, {"C//4"}, {"W"}, {"C%4"}}; | |||||
| default: | default: | ||||
| megdnn_throw(ssprintf("Format unimplement(%d)", static_cast<int>(format)) | megdnn_throw(ssprintf("Format unimplement(%d)", static_cast<int>(format)) | ||||
| .c_str()); | .c_str()); | ||||
| @@ -229,6 +229,30 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW64> { | |||||
| } | } | ||||
| }; | }; | ||||
| template <> | |||||
| struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NHWCD4> { | |||||
| static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | |||||
| OprTensorFormatsConfiguration config; | |||||
| config.typeinfo = opr->dyn_typeinfo(); | |||||
| config.opr_format = OprFormat::NHWCD4; | |||||
| config.config_id = OprFormatConfigID::NHWCD4; | |||||
| bool available = | |||||
| opr->input(0)->dtype().enumv() == DTypeEnum::Float32 || | |||||
| DNN_FLOAT16_SELECT( | |||||
| (opr->input(0)->dtype().enumv() == DTypeEnum::Float16), true) || | |||||
| opr->input(0)->dtype().enumv() == DTypeEnum::Int8 || | |||||
| opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8; | |||||
| config.input_dtypes = {opr->input(0)->dtype().enumv()}; | |||||
| config.input_tensor_types = {TensorType::FEATURE}; | |||||
| config.output_dtypes = {opr->output(0)->dtype().enumv()}; | |||||
| config.input_tensor_formats = {TensorFormats::NHCWc4}; | |||||
| config.output_tensor_formats = {TensorFormats::NHCWc4}; | |||||
| if (available) | |||||
| return config; | |||||
| return None; | |||||
| } | |||||
| }; | |||||
| template <typename Opr, OprFormatConfigID config_id> | template <typename Opr, OprFormatConfigID config_id> | ||||
| struct ConvTensorFormatsDispatcherImpl; | struct ConvTensorFormatsDispatcherImpl; | ||||
| @@ -814,6 +838,55 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_DOT_HYBRID | |||||
| } | } | ||||
| }; | }; | ||||
| template <typename Opr> | |||||
| struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NHWCD4> { | |||||
| static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | |||||
| const auto& conv = opr->cast_final_safe<Opr>(); | |||||
| OprTensorFormatsConfiguration config; | |||||
| config.typeinfo = opr->dyn_typeinfo(); | |||||
| config.opr_format = OprFormat::NHWCD4; | |||||
| config.config_id = OprFormatConfigID::NHWCD4; | |||||
| for (size_t i = 0; i < opr->input().size(); ++i) { | |||||
| config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); | |||||
| TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; | |||||
| config.input_tensor_types.emplace_back(tensor_type); | |||||
| } | |||||
| config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | |||||
| if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | |||||
| if (opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||||
| opr->input(1)->dtype().enumv() == DTypeEnum::Quantized8Asymm) { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::NHCWc4, TensorFormats::KRSCk4c4, | |||||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||||
| } else { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::NHCWc4, TensorFormats::KRSCk4, | |||||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||||
| } | |||||
| } else { | |||||
| mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | |||||
| if (is_channel_wise_conv<Opr>(opr)) { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::NHCWc4, TensorFormats::C1RSc4, | |||||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||||
| } else { | |||||
| if (opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||||
| opr->input(1)->dtype().enumv() == DTypeEnum::Quantized8Asymm) { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::NHCWc4, TensorFormats::GKRSCk4c4, | |||||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||||
| } else { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::NHCWc4, TensorFormats::GKRSCk4, | |||||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||||
| } | |||||
| } | |||||
| } | |||||
| config.output_tensor_formats = {TensorFormats::NHCWc4}; | |||||
| return config; | |||||
| } | |||||
| }; | |||||
| template <> | template <> | ||||
| struct ConvTensorFormatsDispatcherImpl< | struct ConvTensorFormatsDispatcherImpl< | ||||
| opr::ConvolutionBackwardData, OprFormatConfigID::NCHW> { | opr::ConvolutionBackwardData, OprFormatConfigID::NCHW> { | ||||
| @@ -919,6 +992,57 @@ struct ConvTensorFormatsDispatcherImpl< | |||||
| } | } | ||||
| }; | }; | ||||
| template <> | |||||
| struct ConvTensorFormatsDispatcherImpl< | |||||
| opr::ConvolutionBackwardData, OprFormatConfigID::NHWCD4> { | |||||
| using Opr = opr::ConvolutionBackwardData; | |||||
| static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | |||||
| const auto& conv = opr->cast_final_safe<Opr>(); | |||||
| OprTensorFormatsConfiguration config; | |||||
| config.typeinfo = opr->dyn_typeinfo(); | |||||
| config.opr_format = OprFormat::NHWCD4; | |||||
| config.config_id = OprFormatConfigID::NHWCD4; | |||||
| for (size_t i = 0; i < opr->input().size(); ++i) { | |||||
| config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); | |||||
| TensorType tensor_type = i == 0 ? TensorType::WEIGHT : TensorType::FEATURE; | |||||
| config.input_tensor_types.emplace_back(tensor_type); | |||||
| } | |||||
| config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | |||||
| if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | |||||
| if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||||
| opr->input(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm) { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::KRSCk4c4, TensorFormats::NHCWc4, | |||||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||||
| } else { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::KRSCk4, TensorFormats::NHCWc4, | |||||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||||
| } | |||||
| } else { | |||||
| mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | |||||
| if (is_channel_wise_conv<Opr>(opr)) { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::C1RSc4, TensorFormats::NHCWc4, | |||||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||||
| } else { | |||||
| if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||||
| opr->input(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm) { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::GKRSCk4c4, TensorFormats::NHCWc4, | |||||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||||
| } else { | |||||
| config.input_tensor_formats = { | |||||
| TensorFormats::GKRSCk4, TensorFormats::NHCWc4, | |||||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||||
| } | |||||
| } | |||||
| } | |||||
| config.output_tensor_formats = {TensorFormats::NHCWc4}; | |||||
| return config; | |||||
| } | |||||
| }; | |||||
| struct StaticData { | struct StaticData { | ||||
| struct KeyHash { | struct KeyHash { | ||||
| size_t operator()(const std::pair<Typeinfo*, OprFormatConfigID>& val) const { | size_t operator()(const std::pair<Typeinfo*, OprFormatConfigID>& val) const { | ||||
| @@ -969,6 +1093,7 @@ StaticData::StaticData() { | |||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_HYBRID); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_HYBRID); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT_HYBRID); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT_HYBRID); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NHWCD4); | |||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NHWC); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NHWC); | ||||
| @@ -979,15 +1104,18 @@ StaticData::StaticData() { | |||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_HYBRID); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_HYBRID); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT_HYBRID); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT_HYBRID); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NHWCD4); | |||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW4); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW4); | ||||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWCD4); | |||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW); | ||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NHWC); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NHWC); | ||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW4); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW4); | ||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW64); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW64); | ||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NHWCD4); | |||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW); | ||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NHWC); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NHWC); | ||||
| @@ -997,10 +1125,12 @@ StaticData::StaticData() { | |||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64); | ||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW44); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW44); | ||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW88); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW88); | ||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NHWCD4); | |||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW); | ||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44); | ||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88); | ||||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NHWCD4); | |||||
| #undef OPR_TENSOR_FORMATS_CONFIG_REG | #undef OPR_TENSOR_FORMATS_CONFIG_REG | ||||
| #undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG | #undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "megbrain/opr/tensor_manip.h" | #include "megbrain/opr/tensor_manip.h" | ||||
| #include "megbrain/plugin/base.h" | #include "megbrain/plugin/base.h" | ||||
| #include "megbrain/serialization/sereg.h" | #include "megbrain/serialization/sereg.h" | ||||
| #include "megdnn/tensor_format.h" | |||||
| using namespace mgb; | using namespace mgb; | ||||
| using namespace cg; | using namespace cg; | ||||
| @@ -281,9 +282,6 @@ float ProfilerImpl::profile_operator( | |||||
| std::min(config.input_tensor_formats.size(), opr->input().size()); | std::min(config.input_tensor_formats.size(), opr->input().size()); | ||||
| for (; i < nr_input_tensor; ++i) { | for (; i < nr_input_tensor; ++i) { | ||||
| auto&& var = opr->input(i); | auto&& var = opr->input(i); | ||||
| auto&& cn = var->comp_node(); | |||||
| auto&& dtype = var->dtype(); | |||||
| auto dval = std::make_shared<DeviceTensorND>(cn, dtype); | |||||
| TensorShape aligned_shape; | TensorShape aligned_shape; | ||||
| if (config.input_tensor_types[i] == TensorType::WEIGHT) { | if (config.input_tensor_types[i] == TensorType::WEIGHT) { | ||||
| mgb_assert(base_config.input_tensor_types[i] == TensorType::WEIGHT); | mgb_assert(base_config.input_tensor_types[i] == TensorType::WEIGHT); | ||||
| @@ -299,9 +297,12 @@ float ProfilerImpl::profile_operator( | |||||
| var, base_config.input_tensor_formats[i], | var, base_config.input_tensor_formats[i], | ||||
| config.input_tensor_formats[i], extra_attribute); | config.input_tensor_formats[i], extra_attribute); | ||||
| } | } | ||||
| dval->resize(aligned_shape); | |||||
| std::shared_ptr<DeviceTensorND> dval = create_device_tensor_helper( | |||||
| config, i, var, aligned_shape, extra_attribute); | |||||
| if (config.input_tensor_types[i] == TensorType::WEIGHT) { | if (config.input_tensor_types[i] == TensorType::WEIGHT) { | ||||
| new_inps[i] = opr::SharedDeviceTensor::make_const(*graph, dval).node(); | |||||
| new_inps[i] = | |||||
| opr::SharedDeviceTensorWithFormat::make_const(*graph, dval).node(); | |||||
| } else { | } else { | ||||
| new_inps[i] = opr::VolatileSharedDeviceTensor::make(*graph, dval).node(); | new_inps[i] = opr::VolatileSharedDeviceTensor::make(*graph, dval).node(); | ||||
| } | } | ||||
| @@ -368,10 +369,27 @@ float ProfilerImpl::profile_var_node( | |||||
| const VarNode* var, TensorFormats base_format, const ReformatKey& key) const { | const VarNode* var, TensorFormats base_format, const ReformatKey& key) const { | ||||
| auto&& cn = var->comp_node(); | auto&& cn = var->comp_node(); | ||||
| auto&& dtype = var->dtype(); | auto&& dtype = var->dtype(); | ||||
| auto dval = std::make_shared<DeviceTensorND>(cn, dtype); | |||||
| auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape( | auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape( | ||||
| var, base_format, key.input_format, key.attribute); | var, base_format, key.input_format, key.attribute); | ||||
| dval->resize(aligned_tensor_shape); | |||||
| std::shared_ptr<DeviceTensorND> dval; | |||||
| if (key.input_format == TensorFormats::NHCWc4 && | |||||
| key.attribute & ReformatAttribute::IMAGE2D) { | |||||
| size_t align_axis = 2; | |||||
| auto named_tensor = tensor_formats_to_named_tensor_shape(key.input_format); | |||||
| for (size_t n = 0; n < named_tensor.ndim; n++) { | |||||
| if (named_tensor[n].name() == megdnn::Dimension::Name::C) { | |||||
| align_axis = n; | |||||
| break; | |||||
| } | |||||
| } | |||||
| dval = std::make_shared<DeviceTensorND>( | |||||
| cn, aligned_tensor_shape, dtype, | |||||
| megdnn::Image2DPack4TensorFormat::make( | |||||
| align_axis, opr::intl::get_megdnn_handle(cn))); | |||||
| } else | |||||
| dval = std::make_shared<DeviceTensorND>(cn, aligned_tensor_shape, dtype); | |||||
| auto graph = ComputingGraph::make(); | auto graph = ComputingGraph::make(); | ||||
| graph->options().graph_opt_level = 0; | graph->options().graph_opt_level = 0; | ||||
| graph->options().var_sanity_check_first_run = false; | graph->options().var_sanity_check_first_run = false; | ||||
| @@ -516,6 +534,8 @@ ProfilerImpl::OprFormatConfigID ProfilerImpl::tensor_formats_to_config_id( | |||||
| return OprFormatConfigID::NHWC; | return OprFormatConfigID::NHWC; | ||||
| case TensorFormats::CHWNc4: | case TensorFormats::CHWNc4: | ||||
| return OprFormatConfigID::CHWN4; | return OprFormatConfigID::CHWN4; | ||||
| case TensorFormats::NHCWc4: | |||||
| return OprFormatConfigID::NHWCD4; | |||||
| default: | default: | ||||
| mgb_throw( | mgb_throw( | ||||
| MegBrainError, "tensor format(%u) is not supported", | MegBrainError, "tensor format(%u) is not supported", | ||||
| @@ -523,6 +543,39 @@ ProfilerImpl::OprFormatConfigID ProfilerImpl::tensor_formats_to_config_id( | |||||
| } | } | ||||
| } | } | ||||
| std::shared_ptr<DeviceTensorND> ProfilerImpl::create_device_tensor_helper( | |||||
| const OprTensorFormatsConfiguration& config, const size_t inp_idx, | |||||
| const VarNode* var, const TensorShape aligned_shape, | |||||
| ReformatAttribute extra_attribute) const { | |||||
| auto&& cn = var->comp_node(); | |||||
| auto&& dtype = var->dtype(); | |||||
| std::shared_ptr<DeviceTensorND> dval; | |||||
| if (config.config_id == OprFormatConfigID::NHWCD4 && | |||||
| extra_attribute & ReformatAttribute::IMAGE2D) { | |||||
| size_t align_axis = 2; | |||||
| auto named_tensor = tensor_formats_to_named_tensor_shape( | |||||
| config.input_tensor_formats[inp_idx]); | |||||
| for (size_t n = 0; n < named_tensor.ndim; n++) { | |||||
| if (named_tensor[n].name() == megdnn::Dimension::Name::C) { | |||||
| align_axis = n; | |||||
| break; | |||||
| } | |||||
| } | |||||
| // channel wise weight | |||||
| bool is_channel_wise = | |||||
| config.input_tensor_formats[inp_idx] == TensorFormats::C1RSc4; | |||||
| if (is_channel_wise) | |||||
| align_axis = 1; | |||||
| dval = std::make_shared<DeviceTensorND>( | |||||
| cn, aligned_shape, dtype, | |||||
| megdnn::Image2DPack4TensorFormat::make( | |||||
| align_axis, opr::intl::get_megdnn_handle(cn))); | |||||
| } else { | |||||
| dval = std::make_shared<DeviceTensorND>(cn, aligned_shape, dtype); | |||||
| } | |||||
| return dval; | |||||
| } | |||||
| /* ================== ProfilerBase =================*/ | /* ================== ProfilerBase =================*/ | ||||
| std::string ProfilerBase::OperatorNodeRecord::to_string() const { | std::string ProfilerBase::OperatorNodeRecord::to_string() const { | ||||
| auto str = ssprintf( | auto str = ssprintf( | ||||
| @@ -249,7 +249,7 @@ ReformatManager::ReformatManager() { | |||||
| m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl); | m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl); | ||||
| } | } | ||||
| { | { | ||||
| auto i = TensorFormats::KCRS, o = TensorFormats::GKRSCk4; | |||||
| auto i = TensorFormats::GKCRS, o = TensorFormats::GKRSCk4; | |||||
| auto&& impl = [](const VarNodeArray& vars) { | auto&& impl = [](const VarNodeArray& vars) { | ||||
| return opr::RelayoutFormat::make( | return opr::RelayoutFormat::make( | ||||
| vars[0], | vars[0], | ||||
| @@ -259,7 +259,7 @@ ReformatManager::ReformatManager() { | |||||
| m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl); | m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl); | ||||
| } | } | ||||
| { | { | ||||
| auto i = TensorFormats::KCRS, o = TensorFormats::C1RSc4; | |||||
| auto i = TensorFormats::C11RS, o = TensorFormats::C1RSc4; | |||||
| auto&& impl = [](const VarNodeArray& vars) { | auto&& impl = [](const VarNodeArray& vars) { | ||||
| return opr::RelayoutFormat::make( | return opr::RelayoutFormat::make( | ||||
| vars[0], | vars[0], | ||||
| @@ -268,6 +268,21 @@ ReformatManager::ReformatManager() { | |||||
| }; | }; | ||||
| m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl); | m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl); | ||||
| } | } | ||||
| { | |||||
| auto i = TensorFormats::NCHW, o = TensorFormats::NHCWc4; | |||||
| auto&& impl1 = [](const VarNodeArray& vars) { | |||||
| return opr::RelayoutFormat::make( | |||||
| vars[0], megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4) | |||||
| .node(); | |||||
| }; | |||||
| m_cache.emplace(ReformatKey{i, o}, impl1); | |||||
| auto&& impl2 = [](const VarNodeArray& vars) { | |||||
| return opr::RelayoutFormat::make( | |||||
| vars[0], megdnn::param::RelayoutFormat::Mode::NHWCD4_NCHW) | |||||
| .node(); | |||||
| }; | |||||
| m_cache.emplace(ReformatKey{o, i}, impl2); | |||||
| } | |||||
| { | { | ||||
| auto i = TensorFormats::NCHW, o = TensorFormats::NHCWc4; | auto i = TensorFormats::NCHW, o = TensorFormats::NHCWc4; | ||||
| auto&& impl = [](const VarNodeArray& vars) { | auto&& impl = [](const VarNodeArray& vars) { | ||||
| @@ -281,7 +296,7 @@ ReformatManager::ReformatManager() { | |||||
| auto i = TensorFormats::NHCWc4, o = TensorFormats::NCHW; | auto i = TensorFormats::NHCWc4, o = TensorFormats::NCHW; | ||||
| auto&& impl = [](const VarNodeArray& vars) { | auto&& impl = [](const VarNodeArray& vars) { | ||||
| return opr::RelayoutFormat::make( | return opr::RelayoutFormat::make( | ||||
| vars[0], megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I) | |||||
| vars[0], megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW) | |||||
| .node(); | .node(); | ||||
| }; | }; | ||||
| m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl); | m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl); | ||||
| @@ -346,6 +361,15 @@ ReformatManager::ReformatImpl ReformatManager::get(const ReformatKey& key) const | |||||
| return rst; | return rst; | ||||
| } | } | ||||
| } | } | ||||
| if (key.attribute == Attribute::IMAGE2D) { | |||||
| auto key_ = key; | |||||
| key_.input_dtype = DTypeEnum::Float32; | |||||
| key_.output_dtype = DTypeEnum::Float32; | |||||
| auto find = m_cache.find(key_); | |||||
| if (find != m_cache.end()) { | |||||
| return find->second; | |||||
| } | |||||
| } | |||||
| mgb_assert( | mgb_assert( | ||||
| !(key.attribute & Attribute::IMAGE2D) && | !(key.attribute & Attribute::IMAGE2D) && | ||||
| !(key.attribute & Attribute::IC_SMALL)); | !(key.attribute & Attribute::IC_SMALL)); | ||||
| @@ -682,7 +706,8 @@ TensorShape ReformatManager::make_aligned_weight_shape( | |||||
| auto target_shape = tensor_formats_to_named_tensor_shape(target_formats); | auto target_shape = tensor_formats_to_named_tensor_shape(target_formats); | ||||
| for (size_t i = 0; i < target_shape.ndim; ++i) { | for (size_t i = 0; i < target_shape.ndim; ++i) { | ||||
| auto name = target_shape[i].name(); | auto name = target_shape[i].name(); | ||||
| if ((name == Dimension::Name::K || name == Dimension::Name::N) && | |||||
| if ((name == Dimension::Name::K || name == Dimension::Name::N || | |||||
| (extra_formats == TensorFormats::NHCWc4 && name == Dimension::Name::C)) && | |||||
| target_shape[i].extent() == UNDETERMINED_EXTENT) { | target_shape[i].extent() == UNDETERMINED_EXTENT) { | ||||
| size_t out_channels = tshp[i] * target_shape[i].stride(); | size_t out_channels = tshp[i] * target_shape[i].stride(); | ||||
| tshp[i] = divup(out_channels, out_channel_alignment) * | tshp[i] = divup(out_channels, out_channel_alignment) * | ||||
| @@ -32,6 +32,7 @@ static inline const char* opr_format_to_string( | |||||
| cb(NCHW44); | cb(NCHW44); | ||||
| cb(NCHW88); | cb(NCHW88); | ||||
| cb(NCHW44_DOT); | cb(NCHW44_DOT); | ||||
| cb(NHWCD4); | |||||
| default: | default: | ||||
| mgb_assert( | mgb_assert( | ||||
| false, "Invalid opr format(got:%u)", | false, "Invalid opr format(got:%u)", | ||||
| @@ -63,6 +64,7 @@ static inline const char* config_id_to_string( | |||||
| cb(NCHW88_HYBRID); | cb(NCHW88_HYBRID); | ||||
| cb(NCHW44_DOT); | cb(NCHW44_DOT); | ||||
| cb(NCHW44_DOT_HYBRID); | cb(NCHW44_DOT_HYBRID); | ||||
| cb(NHWCD4); | |||||
| default: | default: | ||||
| mgb_assert( | mgb_assert( | ||||
| false, "Invalid config id(got:%u)", | false, "Invalid config id(got:%u)", | ||||
| @@ -95,6 +97,8 @@ static inline TensorFormats opr_format_to_tensor_formats( | |||||
| return TensorFormats::NCHWc8; | return TensorFormats::NCHWc8; | ||||
| case OprFormat::NCHW44_DOT: | case OprFormat::NCHW44_DOT: | ||||
| return TensorFormats::NCHWc4; | return TensorFormats::NCHWc4; | ||||
| case OprFormat::NHWCD4: | |||||
| return TensorFormats::NHCWc4; | |||||
| default: | default: | ||||
| mgb_throw( | mgb_throw( | ||||
| AssertionError, "format(%s) is not supported", | AssertionError, "format(%s) is not supported", | ||||
| @@ -202,6 +202,11 @@ protected: | |||||
| const ReformatKey& key) const; | const ReformatKey& key) const; | ||||
| OprFormatConfigID tensor_formats_to_config_id(TensorFormats tensor_format) const; | OprFormatConfigID tensor_formats_to_config_id(TensorFormats tensor_format) const; | ||||
| std::shared_ptr<DeviceTensorND> create_device_tensor_helper( | |||||
| const OprTensorFormatsConfiguration& config, const size_t inp_idx, | |||||
| const VarNode* var, const TensorShape aligned_shape, | |||||
| ReformatAttribute extra_attribute) const; | |||||
| OprFootprint m_opr_footprint; | OprFootprint m_opr_footprint; | ||||
| float m_opr_threshold; /// a threshold, when the computation of the newly | float m_opr_threshold; /// a threshold, when the computation of the newly | ||||
| /// created operator that is built in some opr | /// created operator that is built in some opr | ||||
| @@ -336,6 +336,10 @@ cg::OperatorNodeBase::NodeProp* VolatileSharedDeviceTensor::do_make_node_prop() | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| void VolatileSharedDeviceTensor::init_output_format() { | |||||
| output(0)->format(get_dev_tensor().format()); | |||||
| } | |||||
| SymbolVar VolatileSharedDeviceTensor::make( | SymbolVar VolatileSharedDeviceTensor::make( | ||||
| ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data, | ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data, | ||||
| const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
| @@ -337,6 +337,8 @@ MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||||
| public: | public: | ||||
| using Super::Super; | using Super::Super; | ||||
| void init_output_format() override; | |||||
| MGE_WIN_DECLSPEC_FUC static SymbolVar make( | MGE_WIN_DECLSPEC_FUC static SymbolVar make( | ||||
| ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data, | ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data, | ||||
| const OperatorNodeConfig& config = {}); | const OperatorNodeConfig& config = {}); | ||||