GitOrigin-RevId: 132605c7d9
tags/v1.7.2.m1
| @@ -246,6 +246,8 @@ NamedTensorShape NamedTensorShape::make_named_tensor_shape(Format format) { | |||
| return {{"N//8"}, {"C//8"}, {"H"}, {"W"}, {"C%8"}, {"N%8"}}; | |||
| case Format::NCHW44_DOT: | |||
| return {{"N//4"}, {"C//4"}, {"H"}, {"W"}, {"N%4"}, {"C%4"}}; | |||
| case Format::NHWCD4: | |||
| return {{"N"}, {"H"}, {"C//4"}, {"W"}, {"C%4"}}; | |||
| default: | |||
| megdnn_throw(ssprintf("Format unimplement(%d)", static_cast<int>(format)) | |||
| .c_str()); | |||
| @@ -229,6 +229,30 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW64> { | |||
| } | |||
| }; | |||
| template <> | |||
| struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NHWCD4> { | |||
| static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | |||
| OprTensorFormatsConfiguration config; | |||
| config.typeinfo = opr->dyn_typeinfo(); | |||
| config.opr_format = OprFormat::NHWCD4; | |||
| config.config_id = OprFormatConfigID::NHWCD4; | |||
| bool available = | |||
| opr->input(0)->dtype().enumv() == DTypeEnum::Float32 || | |||
| DNN_FLOAT16_SELECT( | |||
| (opr->input(0)->dtype().enumv() == DTypeEnum::Float16), true) || | |||
| opr->input(0)->dtype().enumv() == DTypeEnum::Int8 || | |||
| opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8; | |||
| config.input_dtypes = {opr->input(0)->dtype().enumv()}; | |||
| config.input_tensor_types = {TensorType::FEATURE}; | |||
| config.output_dtypes = {opr->output(0)->dtype().enumv()}; | |||
| config.input_tensor_formats = {TensorFormats::NHCWc4}; | |||
| config.output_tensor_formats = {TensorFormats::NHCWc4}; | |||
| if (available) | |||
| return config; | |||
| return None; | |||
| } | |||
| }; | |||
| template <typename Opr, OprFormatConfigID config_id> | |||
| struct ConvTensorFormatsDispatcherImpl; | |||
| @@ -814,6 +838,55 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_DOT_HYBRID | |||
| } | |||
| }; | |||
| template <typename Opr> | |||
| struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NHWCD4> { | |||
| static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | |||
| const auto& conv = opr->cast_final_safe<Opr>(); | |||
| OprTensorFormatsConfiguration config; | |||
| config.typeinfo = opr->dyn_typeinfo(); | |||
| config.opr_format = OprFormat::NHWCD4; | |||
| config.config_id = OprFormatConfigID::NHWCD4; | |||
| for (size_t i = 0; i < opr->input().size(); ++i) { | |||
| config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); | |||
| TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; | |||
| config.input_tensor_types.emplace_back(tensor_type); | |||
| } | |||
| config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | |||
| if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | |||
| if (opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||
| opr->input(1)->dtype().enumv() == DTypeEnum::Quantized8Asymm) { | |||
| config.input_tensor_formats = { | |||
| TensorFormats::NHCWc4, TensorFormats::KRSCk4c4, | |||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||
| } else { | |||
| config.input_tensor_formats = { | |||
| TensorFormats::NHCWc4, TensorFormats::KRSCk4, | |||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||
| } | |||
| } else { | |||
| mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | |||
| if (is_channel_wise_conv<Opr>(opr)) { | |||
| config.input_tensor_formats = { | |||
| TensorFormats::NHCWc4, TensorFormats::C1RSc4, | |||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||
| } else { | |||
| if (opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||
| opr->input(1)->dtype().enumv() == DTypeEnum::Quantized8Asymm) { | |||
| config.input_tensor_formats = { | |||
| TensorFormats::NHCWc4, TensorFormats::GKRSCk4c4, | |||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||
| } else { | |||
| config.input_tensor_formats = { | |||
| TensorFormats::NHCWc4, TensorFormats::GKRSCk4, | |||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||
| } | |||
| } | |||
| } | |||
| config.output_tensor_formats = {TensorFormats::NHCWc4}; | |||
| return config; | |||
| } | |||
| }; | |||
| template <> | |||
| struct ConvTensorFormatsDispatcherImpl< | |||
| opr::ConvolutionBackwardData, OprFormatConfigID::NCHW> { | |||
| @@ -919,6 +992,57 @@ struct ConvTensorFormatsDispatcherImpl< | |||
| } | |||
| }; | |||
| template <> | |||
| struct ConvTensorFormatsDispatcherImpl< | |||
| opr::ConvolutionBackwardData, OprFormatConfigID::NHWCD4> { | |||
| using Opr = opr::ConvolutionBackwardData; | |||
| static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | |||
| const auto& conv = opr->cast_final_safe<Opr>(); | |||
| OprTensorFormatsConfiguration config; | |||
| config.typeinfo = opr->dyn_typeinfo(); | |||
| config.opr_format = OprFormat::NHWCD4; | |||
| config.config_id = OprFormatConfigID::NHWCD4; | |||
| for (size_t i = 0; i < opr->input().size(); ++i) { | |||
| config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); | |||
| TensorType tensor_type = i == 0 ? TensorType::WEIGHT : TensorType::FEATURE; | |||
| config.input_tensor_types.emplace_back(tensor_type); | |||
| } | |||
| config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | |||
| if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | |||
| if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||
| opr->input(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm) { | |||
| config.input_tensor_formats = { | |||
| TensorFormats::KRSCk4c4, TensorFormats::NHCWc4, | |||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||
| } else { | |||
| config.input_tensor_formats = { | |||
| TensorFormats::KRSCk4, TensorFormats::NHCWc4, | |||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||
| } | |||
| } else { | |||
| mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | |||
| if (is_channel_wise_conv<Opr>(opr)) { | |||
| config.input_tensor_formats = { | |||
| TensorFormats::C1RSc4, TensorFormats::NHCWc4, | |||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||
| } else { | |||
| if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||
| opr->input(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm) { | |||
| config.input_tensor_formats = { | |||
| TensorFormats::GKRSCk4c4, TensorFormats::NHCWc4, | |||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||
| } else { | |||
| config.input_tensor_formats = { | |||
| TensorFormats::GKRSCk4, TensorFormats::NHCWc4, | |||
| TensorFormats::NHCWc4, TensorFormats::NHCWc4}; | |||
| } | |||
| } | |||
| } | |||
| config.output_tensor_formats = {TensorFormats::NHCWc4}; | |||
| return config; | |||
| } | |||
| }; | |||
| struct StaticData { | |||
| struct KeyHash { | |||
| size_t operator()(const std::pair<Typeinfo*, OprFormatConfigID>& val) const { | |||
| @@ -969,6 +1093,7 @@ StaticData::StaticData() { | |||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT); | |||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_HYBRID); | |||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT_HYBRID); | |||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NHWCD4); | |||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW); | |||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NHWC); | |||
| @@ -979,15 +1104,18 @@ StaticData::StaticData() { | |||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT); | |||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_HYBRID); | |||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT_HYBRID); | |||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NHWCD4); | |||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW); | |||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC); | |||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW4); | |||
| OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWCD4); | |||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW); | |||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NHWC); | |||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW4); | |||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NCHW64); | |||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(WarpPerspectiveForward, NHWCD4); | |||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW); | |||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NHWC); | |||
| @@ -997,10 +1125,12 @@ StaticData::StaticData() { | |||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64); | |||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW44); | |||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW88); | |||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NHWCD4); | |||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW); | |||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44); | |||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88); | |||
| OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NHWCD4); | |||
| #undef OPR_TENSOR_FORMATS_CONFIG_REG | |||
| #undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG | |||
| @@ -22,6 +22,7 @@ | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/plugin/base.h" | |||
| #include "megbrain/serialization/sereg.h" | |||
| #include "megdnn/tensor_format.h" | |||
| using namespace mgb; | |||
| using namespace cg; | |||
| @@ -281,9 +282,6 @@ float ProfilerImpl::profile_operator( | |||
| std::min(config.input_tensor_formats.size(), opr->input().size()); | |||
| for (; i < nr_input_tensor; ++i) { | |||
| auto&& var = opr->input(i); | |||
| auto&& cn = var->comp_node(); | |||
| auto&& dtype = var->dtype(); | |||
| auto dval = std::make_shared<DeviceTensorND>(cn, dtype); | |||
| TensorShape aligned_shape; | |||
| if (config.input_tensor_types[i] == TensorType::WEIGHT) { | |||
| mgb_assert(base_config.input_tensor_types[i] == TensorType::WEIGHT); | |||
| @@ -299,9 +297,12 @@ float ProfilerImpl::profile_operator( | |||
| var, base_config.input_tensor_formats[i], | |||
| config.input_tensor_formats[i], extra_attribute); | |||
| } | |||
| dval->resize(aligned_shape); | |||
| std::shared_ptr<DeviceTensorND> dval = create_device_tensor_helper( | |||
| config, i, var, aligned_shape, extra_attribute); | |||
| if (config.input_tensor_types[i] == TensorType::WEIGHT) { | |||
| new_inps[i] = opr::SharedDeviceTensor::make_const(*graph, dval).node(); | |||
| new_inps[i] = | |||
| opr::SharedDeviceTensorWithFormat::make_const(*graph, dval).node(); | |||
| } else { | |||
| new_inps[i] = opr::VolatileSharedDeviceTensor::make(*graph, dval).node(); | |||
| } | |||
| @@ -368,10 +369,27 @@ float ProfilerImpl::profile_var_node( | |||
| const VarNode* var, TensorFormats base_format, const ReformatKey& key) const { | |||
| auto&& cn = var->comp_node(); | |||
| auto&& dtype = var->dtype(); | |||
| auto dval = std::make_shared<DeviceTensorND>(cn, dtype); | |||
| auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape( | |||
| var, base_format, key.input_format, key.attribute); | |||
| dval->resize(aligned_tensor_shape); | |||
| std::shared_ptr<DeviceTensorND> dval; | |||
| if (key.input_format == TensorFormats::NHCWc4 && | |||
| key.attribute & ReformatAttribute::IMAGE2D) { | |||
| size_t align_axis = 2; | |||
| auto named_tensor = tensor_formats_to_named_tensor_shape(key.input_format); | |||
| for (size_t n = 0; n < named_tensor.ndim; n++) { | |||
| if (named_tensor[n].name() == megdnn::Dimension::Name::C) { | |||
| align_axis = n; | |||
| break; | |||
| } | |||
| } | |||
| dval = std::make_shared<DeviceTensorND>( | |||
| cn, aligned_tensor_shape, dtype, | |||
| megdnn::Image2DPack4TensorFormat::make( | |||
| align_axis, opr::intl::get_megdnn_handle(cn))); | |||
| } else | |||
| dval = std::make_shared<DeviceTensorND>(cn, aligned_tensor_shape, dtype); | |||
| auto graph = ComputingGraph::make(); | |||
| graph->options().graph_opt_level = 0; | |||
| graph->options().var_sanity_check_first_run = false; | |||
| @@ -516,6 +534,8 @@ ProfilerImpl::OprFormatConfigID ProfilerImpl::tensor_formats_to_config_id( | |||
| return OprFormatConfigID::NHWC; | |||
| case TensorFormats::CHWNc4: | |||
| return OprFormatConfigID::CHWN4; | |||
| case TensorFormats::NHCWc4: | |||
| return OprFormatConfigID::NHWCD4; | |||
| default: | |||
| mgb_throw( | |||
| MegBrainError, "tensor format(%u) is not supported", | |||
| @@ -523,6 +543,39 @@ ProfilerImpl::OprFormatConfigID ProfilerImpl::tensor_formats_to_config_id( | |||
| } | |||
| } | |||
| std::shared_ptr<DeviceTensorND> ProfilerImpl::create_device_tensor_helper( | |||
| const OprTensorFormatsConfiguration& config, const size_t inp_idx, | |||
| const VarNode* var, const TensorShape aligned_shape, | |||
| ReformatAttribute extra_attribute) const { | |||
| auto&& cn = var->comp_node(); | |||
| auto&& dtype = var->dtype(); | |||
| std::shared_ptr<DeviceTensorND> dval; | |||
| if (config.config_id == OprFormatConfigID::NHWCD4 && | |||
| extra_attribute & ReformatAttribute::IMAGE2D) { | |||
| size_t align_axis = 2; | |||
| auto named_tensor = tensor_formats_to_named_tensor_shape( | |||
| config.input_tensor_formats[inp_idx]); | |||
| for (size_t n = 0; n < named_tensor.ndim; n++) { | |||
| if (named_tensor[n].name() == megdnn::Dimension::Name::C) { | |||
| align_axis = n; | |||
| break; | |||
| } | |||
| } | |||
| // channel wise weight | |||
| bool is_channel_wise = | |||
| config.input_tensor_formats[inp_idx] == TensorFormats::C1RSc4; | |||
| if (is_channel_wise) | |||
| align_axis = 1; | |||
| dval = std::make_shared<DeviceTensorND>( | |||
| cn, aligned_shape, dtype, | |||
| megdnn::Image2DPack4TensorFormat::make( | |||
| align_axis, opr::intl::get_megdnn_handle(cn))); | |||
| } else { | |||
| dval = std::make_shared<DeviceTensorND>(cn, aligned_shape, dtype); | |||
| } | |||
| return dval; | |||
| } | |||
| /* ================== ProfilerBase =================*/ | |||
| std::string ProfilerBase::OperatorNodeRecord::to_string() const { | |||
| auto str = ssprintf( | |||
| @@ -249,7 +249,7 @@ ReformatManager::ReformatManager() { | |||
| m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl); | |||
| } | |||
| { | |||
| auto i = TensorFormats::KCRS, o = TensorFormats::GKRSCk4; | |||
| auto i = TensorFormats::GKCRS, o = TensorFormats::GKRSCk4; | |||
| auto&& impl = [](const VarNodeArray& vars) { | |||
| return opr::RelayoutFormat::make( | |||
| vars[0], | |||
| @@ -259,7 +259,7 @@ ReformatManager::ReformatManager() { | |||
| m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl); | |||
| } | |||
| { | |||
| auto i = TensorFormats::KCRS, o = TensorFormats::C1RSc4; | |||
| auto i = TensorFormats::C11RS, o = TensorFormats::C1RSc4; | |||
| auto&& impl = [](const VarNodeArray& vars) { | |||
| return opr::RelayoutFormat::make( | |||
| vars[0], | |||
| @@ -268,6 +268,21 @@ ReformatManager::ReformatManager() { | |||
| }; | |||
| m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl); | |||
| } | |||
| { | |||
| auto i = TensorFormats::NCHW, o = TensorFormats::NHCWc4; | |||
| auto&& impl1 = [](const VarNodeArray& vars) { | |||
| return opr::RelayoutFormat::make( | |||
| vars[0], megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4) | |||
| .node(); | |||
| }; | |||
| m_cache.emplace(ReformatKey{i, o}, impl1); | |||
| auto&& impl2 = [](const VarNodeArray& vars) { | |||
| return opr::RelayoutFormat::make( | |||
| vars[0], megdnn::param::RelayoutFormat::Mode::NHWCD4_NCHW) | |||
| .node(); | |||
| }; | |||
| m_cache.emplace(ReformatKey{o, i}, impl2); | |||
| } | |||
| { | |||
| auto i = TensorFormats::NCHW, o = TensorFormats::NHCWc4; | |||
| auto&& impl = [](const VarNodeArray& vars) { | |||
| @@ -281,7 +296,7 @@ ReformatManager::ReformatManager() { | |||
| auto i = TensorFormats::NHCWc4, o = TensorFormats::NCHW; | |||
| auto&& impl = [](const VarNodeArray& vars) { | |||
| return opr::RelayoutFormat::make( | |||
| vars[0], megdnn::param::RelayoutFormat::Mode::NCHW_NHWCD4I) | |||
| vars[0], megdnn::param::RelayoutFormat::Mode::NHWCD4I_NCHW) | |||
| .node(); | |||
| }; | |||
| m_cache.emplace(ReformatKey{i, o, Attribute::IMAGE2D}, impl); | |||
| @@ -346,6 +361,15 @@ ReformatManager::ReformatImpl ReformatManager::get(const ReformatKey& key) const | |||
| return rst; | |||
| } | |||
| } | |||
| if (key.attribute == Attribute::IMAGE2D) { | |||
| auto key_ = key; | |||
| key_.input_dtype = DTypeEnum::Float32; | |||
| key_.output_dtype = DTypeEnum::Float32; | |||
| auto find = m_cache.find(key_); | |||
| if (find != m_cache.end()) { | |||
| return find->second; | |||
| } | |||
| } | |||
| mgb_assert( | |||
| !(key.attribute & Attribute::IMAGE2D) && | |||
| !(key.attribute & Attribute::IC_SMALL)); | |||
| @@ -682,7 +706,8 @@ TensorShape ReformatManager::make_aligned_weight_shape( | |||
| auto target_shape = tensor_formats_to_named_tensor_shape(target_formats); | |||
| for (size_t i = 0; i < target_shape.ndim; ++i) { | |||
| auto name = target_shape[i].name(); | |||
| if ((name == Dimension::Name::K || name == Dimension::Name::N) && | |||
| if ((name == Dimension::Name::K || name == Dimension::Name::N || | |||
| (extra_formats == TensorFormats::NHCWc4 && name == Dimension::Name::C)) && | |||
| target_shape[i].extent() == UNDETERMINED_EXTENT) { | |||
| size_t out_channels = tshp[i] * target_shape[i].stride(); | |||
| tshp[i] = divup(out_channels, out_channel_alignment) * | |||
| @@ -32,6 +32,7 @@ static inline const char* opr_format_to_string( | |||
| cb(NCHW44); | |||
| cb(NCHW88); | |||
| cb(NCHW44_DOT); | |||
| cb(NHWCD4); | |||
| default: | |||
| mgb_assert( | |||
| false, "Invalid opr format(got:%u)", | |||
| @@ -63,6 +64,7 @@ static inline const char* config_id_to_string( | |||
| cb(NCHW88_HYBRID); | |||
| cb(NCHW44_DOT); | |||
| cb(NCHW44_DOT_HYBRID); | |||
| cb(NHWCD4); | |||
| default: | |||
| mgb_assert( | |||
| false, "Invalid config id(got:%u)", | |||
| @@ -95,6 +97,8 @@ static inline TensorFormats opr_format_to_tensor_formats( | |||
| return TensorFormats::NCHWc8; | |||
| case OprFormat::NCHW44_DOT: | |||
| return TensorFormats::NCHWc4; | |||
| case OprFormat::NHWCD4: | |||
| return TensorFormats::NHCWc4; | |||
| default: | |||
| mgb_throw( | |||
| AssertionError, "format(%s) is not supported", | |||
| @@ -202,6 +202,11 @@ protected: | |||
| const ReformatKey& key) const; | |||
| OprFormatConfigID tensor_formats_to_config_id(TensorFormats tensor_format) const; | |||
| std::shared_ptr<DeviceTensorND> create_device_tensor_helper( | |||
| const OprTensorFormatsConfiguration& config, const size_t inp_idx, | |||
| const VarNode* var, const TensorShape aligned_shape, | |||
| ReformatAttribute extra_attribute) const; | |||
| OprFootprint m_opr_footprint; | |||
| float m_opr_threshold; /// a threshold, when the computation of the newly | |||
| /// created operator that is built in some opr | |||
| @@ -336,6 +336,10 @@ cg::OperatorNodeBase::NodeProp* VolatileSharedDeviceTensor::do_make_node_prop() | |||
| return ret; | |||
| } | |||
| void VolatileSharedDeviceTensor::init_output_format() { | |||
| output(0)->format(get_dev_tensor().format()); | |||
| } | |||
| SymbolVar VolatileSharedDeviceTensor::make( | |||
| ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data, | |||
| const OperatorNodeConfig& config) { | |||
| @@ -337,6 +337,8 @@ MGB_DEFINE_OPR_CLASS_WITH_EXPORT( | |||
| public: | |||
| using Super::Super; | |||
| void init_output_format() override; | |||
| MGE_WIN_DECLSPEC_FUC static SymbolVar make( | |||
| ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data, | |||
| const OperatorNodeConfig& config = {}); | |||