Merge Target::ARM and Target::X86 into Target::CPU to make global layout transform easier to use
GitOrigin-RevId: cc9363fa38
tags/v1.7.0
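
For context, a minimal caller-side sketch of the merged target (argument order is assumed from the `LayoutTransformContext::make` switch in this diff, and is not otherwise confirmed; `Target::CPU` resolves to `make_cpu_ctx`, which registers the NCHW44/NCHW44_DOT int8 configs alongside the NCHW88 float32 configs):

    // Sketch only: the exact make() signature is an assumption.
    using Target = gopt::LayoutTransformContext::Target;
    auto ctx = gopt::LayoutTransformContext::make(
            Target::CPU,               // previously Target::X86 or Target::ARM
            OprFormatConfigID::NCHW,   // base operator format config
            TensorFormats::NCHW);      // base tensor format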
@@ -830,9 +830,9 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
                 src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
         dst[4] = 32;
     } else if (param().format == Param::Format::NCHW88) {
-        megdnn_assert(src.ndim == 5 || src.ndim == 4,
-                      "invalid src ndim for NCHW88, expected=5 or 4, got=%zu",
-                      src.ndim);
+        megdnn_assert(
+                src.ndim == 5 || src.ndim == 4,
+                "invalid src ndim for NCHW88, expected=5 or 4, got=%zu", src.ndim);
         dst.ndim = 5;
         dst[0] = src[0];
         auto oc = cflt.ocpg * cflt.group;
@@ -850,11 +850,12 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
                     "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
         }
-    } else if (param().format == Param::Format::NCHW44 ||
-               param().format == Param::Format::NCHW44_DOT) {
-        megdnn_assert(src.ndim == 5 || src.ndim == 4,
-                      "invalid src ndim for NCHW44, expected=5 or 4, got=%zu",
-                      src.ndim);
+    } else if (
+            param().format == Param::Format::NCHW44 ||
+            param().format == Param::Format::NCHW44_DOT) {
+        megdnn_assert(
+                src.ndim == 5 || src.ndim == 4,
+                "invalid src ndim for NCHW44, expected=5 or 4, got=%zu", src.ndim);
         dst.ndim = 5;
         dst[0] = src[0];
         auto oc = cflt.ocpg * cflt.group;
@@ -491,7 +491,6 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
             auto& states = cuts.back().states;
             prune(states, edges[cur], ctx);
-            force_prune(states);
         }
         cur++;
     }
@@ -32,8 +32,7 @@ const char* target_to_string(Target target) {
     return #_target
     switch (target) {
         cb(CUDA);
-        cb(X86);
-        cb(ARM);
+        cb(CPU);
         cb(UNSPEC);
         default:
             mgb_assert(
@@ -89,7 +88,7 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
     return ctx;
 }

-std::unique_ptr<LayoutTransformContext> make_arm_ctx(
+std::unique_ptr<LayoutTransformContext> make_cpu_ctx(
         OprFormatConfigID base_config_id, TensorFormats base_tensor_format) {
     OprList opr_list = {
             opr::ConvBiasForward::typeinfo(),
@@ -104,34 +103,30 @@ std::unique_ptr<LayoutTransformContext> make_arm_ctx(
     };
     SmallVector<TensorFormats> available_tensor_formats = {
-            TensorFormats::NCHW, TensorFormats::NCHWc4,
-            DNN_INC_FLOAT16(TensorFormats::NCHWc8)};
-    Attribute attribute = {base_config_id, base_tensor_format, Target::ARM};
+            TensorFormats::NCHW, TensorFormats::NCHWc4, TensorFormats::NCHWc8};
+    Attribute attribute = {base_config_id, base_tensor_format, Target::CPU};
     auto ctx = std::make_unique<LayoutTransformContext>(
             std::move(opr_list), std::move(available_tensor_formats), attribute);
     ctx->add_opr_config(
                opr::ConvBiasForward::typeinfo(),
                {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
-                OprFormatConfigID::NCHW44_HYBRID,
-                DNN_INC_FLOAT16(OprFormatConfigID::NCHW88),
-                DNN_INC_FLOAT16(OprFormatConfigID::NCHW88_HYBRID),
-                OprFormatConfigID::NCHW44_DOT, OprFormatConfigID::NCHW44_DOT_HYBRID})
+                OprFormatConfigID::NCHW44_HYBRID, OprFormatConfigID::NCHW88,
+                OprFormatConfigID::NCHW88_HYBRID, OprFormatConfigID::NCHW44_DOT,
+                OprFormatConfigID::NCHW44_DOT_HYBRID})
             .add_opr_config(
                     opr::ConvolutionForward::typeinfo(),
                     {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
-                     OprFormatConfigID::NCHW44_HYBRID,
-                     DNN_INC_FLOAT16(OprFormatConfigID::NCHW88),
-                     DNN_INC_FLOAT16(OprFormatConfigID::NCHW88_HYBRID),
-                     OprFormatConfigID::NCHW44_DOT,
+                     OprFormatConfigID::NCHW44_HYBRID, OprFormatConfigID::NCHW88,
+                     OprFormatConfigID::NCHW88_HYBRID, OprFormatConfigID::NCHW44_DOT,
                      OprFormatConfigID::NCHW44_DOT_HYBRID})
             .add_opr_config(
                     opr::PoolingForward::typeinfo(),
                     {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
-                     DNN_INC_FLOAT16(OprFormatConfigID::NCHW88)})
+                     OprFormatConfigID::NCHW88})
             .add_opr_config(
                     opr::ResizeForward::typeinfo(),
                     {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
-                     DNN_INC_FLOAT16(OprFormatConfigID::NCHW88)});
+                     OprFormatConfigID::NCHW88});
     return ctx;
 }
 } // namespace
@@ -162,8 +157,8 @@ std::unique_ptr<LayoutTransformContext> LayoutTransformContext::make(
     switch (target) {
         case Target::CUDA:
             return make_cuda_ctx(base_config_id, base_tensor_format);
-        case Target::ARM:
-            return make_arm_ctx(base_config_id, base_tensor_format);
+        case Target::CPU:
+            return make_cpu_ctx(base_config_id, base_tensor_format);
         default:
             mgb_assert(false, "unsupported target %s\n", target_to_string(target));
     }
@@ -82,6 +82,8 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW> {
     }
 };

+/* \remark: Here, maybe we needn't check data type of input and output tensors. Because
+ * algo available checker will skip the configuration that has no underlying impls. */
 template <>
 struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW44> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
@@ -89,8 +91,9 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW44> {
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW44;
         config.config_id = OprFormatConfigID::NCHW44;
-        bool available = true;
-        available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float32;
+        bool f32_config = opr->input(0)->dtype().enumv() == DTypeEnum::Float32;
+        bool i8_config = opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
+        bool available = f32_config || i8_config;
         config.input_dtypes = {opr->input(0)->dtype().enumv()};
         config.input_tensor_types = {TensorType::FEATURE};
         config.output_dtypes = {opr->output(0)->dtype().enumv()};
@@ -102,7 +105,6 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW44> {
     }
 };

-#if !MEGDNN_DISABLE_FLOAT16
 template <>
 struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW88> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
@@ -110,8 +112,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW88> {
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW88;
         config.config_id = OprFormatConfigID::NCHW88;
-        bool available = true;
-        available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float16;
+        bool available = opr->input(0)->dtype().enumv() == DTypeEnum::Float32;
         config.input_dtypes = {opr->input(0)->dtype().enumv()};
         config.input_tensor_types = {TensorType::FEATURE};
         config.output_dtypes = {opr->output(0)->dtype().enumv()};
@@ -122,7 +123,6 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW88> {
         return config;
     }
 };
-#endif

 template <>
 struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW4> {
@@ -131,8 +131,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW4> {
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW4;
         config.config_id = OprFormatConfigID::NCHW4;
-        bool available = true;
-        available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
+        bool available = opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
         config.input_dtypes = {opr->input(0)->dtype().enumv()};
         config.input_tensor_types = {TensorType::FEATURE};
         available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
@@ -152,8 +151,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::CHWN4> {
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::CHWN4;
         config.config_id = OprFormatConfigID::CHWN4;
-        bool available = true;
-        available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
+        bool available = opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
         config.input_dtypes = {opr->input(0)->dtype().enumv()};
         config.input_tensor_types = {TensorType::FEATURE};
         available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
@@ -173,8 +171,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW32> {
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW32;
         config.config_id = OprFormatConfigID::NCHW32;
-        bool available = true;
-        available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
+        bool available = opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
         config.input_dtypes = {opr->input(0)->dtype().enumv()};
         config.input_tensor_types = {TensorType::FEATURE};
         available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
@@ -194,9 +191,8 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NHWC> {
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NHWC;
         config.config_id = OprFormatConfigID::NHWC;
-        bool available = true;
-        available &= opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
-                     opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
+        bool available = opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
+                         opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
         config.input_dtypes = {opr->input(0)->dtype().enumv()};
         config.input_tensor_types = {TensorType::FEATURE};
         available &= opr->output(0)->dtype().enumv() == opr->input(0)->dtype().enumv();
@@ -216,9 +212,8 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW64> {
         config.typeinfo = opr->dyn_typeinfo();
         config.opr_format = OprFormat::NCHW64;
         config.config_id = OprFormatConfigID::NCHW64;
-        bool available = true;
-        available &= opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
-                     opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
+        bool available = opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
+                         opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
         config.input_dtypes = {opr->input(0)->dtype().enumv()};
         config.input_tensor_types = {TensorType::FEATURE};
         available &= opr->output(0)->dtype().enumv() == opr->input(0)->dtype().enumv();
@@ -552,14 +547,24 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44> {
         config.opr_format = OprFormat::NCHW44;
         config.config_id = OprFormatConfigID::NCHW44;
         bool available = true;
+        auto check_dtype = [](DType dt, bool is_bias) {
+            bool f32_config = dt.enumv() == DTypeEnum::Float32;
+            auto i8_dtype = DTypeEnum::QuantizedS8;
+            if (is_bias)
+                i8_dtype = DTypeEnum::QuantizedS32;
+            bool i8_config = dt.enumv() == i8_dtype;
+            return f32_config || i8_config;
+        };
         // setup dtypes
         for (size_t i = 0; i < opr->input().size(); ++i) {
-            available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32;
+            bool is_bias =
+                    ConvParamTrait<Opr>::has_bias && i == ConvParamTrait<Opr>::bias_idx;
+            available &= check_dtype(opr->input(i)->dtype(), is_bias);
             config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
             TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
             config.input_tensor_types.emplace_back(tensor_type);
         }
-        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
+        available &= check_dtype(opr->output(0)->dtype(), false);
         config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
         // setup tensor formats
         if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
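
The `check_dtype` helper added above encodes one rule: float32 always passes, while the int8 config expects QuantizedS8 for features and weights but QuantizedS32 for the bias input. A minimal standalone restatement with a simplified dtype enum (illustrative only, not MegDNN's actual DType API):

    #include <cstdio>

    enum class DTypeEnum { Float32, QuantizedS8, QuantizedS32 };

    // Mirrors the check_dtype lambda from the hunk above.
    bool check_dtype(DTypeEnum dt, bool is_bias) {
        bool f32_config = dt == DTypeEnum::Float32;
        DTypeEnum i8_dtype =
                is_bias ? DTypeEnum::QuantizedS32 : DTypeEnum::QuantizedS8;
        return f32_config || dt == i8_dtype;
    }

    int main() {
        std::printf("%d\n", check_dtype(DTypeEnum::QuantizedS32, /*is_bias=*/true));   // 1
        std::printf("%d\n", check_dtype(DTypeEnum::QuantizedS32, /*is_bias=*/false));  // 0
        std::printf("%d\n", check_dtype(DTypeEnum::Float32, /*is_bias=*/false));       // 1
    }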
@@ -594,14 +599,24 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_HYBRID> {
         config.opr_format = OprFormat::NCHW44;
         config.config_id = OprFormatConfigID::NCHW44_HYBRID;
         bool available = true;
+        auto check_dtype = [](DType dt, bool is_bias) {
+            bool f32_config = dt.enumv() == DTypeEnum::Float32;
+            auto i8_dtype = DTypeEnum::QuantizedS8;
+            if (is_bias)
+                i8_dtype = DTypeEnum::QuantizedS32;
+            bool i8_config = dt.enumv() == i8_dtype;
+            return f32_config || i8_config;
+        };
         // setup dtypes
         for (size_t i = 0; i < opr->input().size(); ++i) {
-            available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32;
+            bool is_bias =
+                    ConvParamTrait<Opr>::has_bias && i == ConvParamTrait<Opr>::bias_idx;
+            available &= check_dtype(opr->input(i)->dtype(), is_bias);
             config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
             TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
             config.input_tensor_types.emplace_back(tensor_type);
         }
-        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
+        available &= check_dtype(opr->output(0)->dtype(), false);
         config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
         available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
         config.input_tensor_formats = {
@@ -614,7 +629,6 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_HYBRID> {
     }
 };

-#if !MEGDNN_DISABLE_FLOAT16
 template <typename Opr>
 struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88> {
     static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
@@ -626,12 +640,12 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88> {
         bool available = true;
         // setup dtypes
         for (size_t i = 0; i < opr->input().size(); ++i) {
-            available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16;
+            available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32;
             config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
             TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
             config.input_tensor_types.emplace_back(tensor_type);
         }
-        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16;
+        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
         config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
         // setup tensor formats
         if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
@@ -668,12 +682,12 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88_HYBRID> {
         bool available = true;
         // setup dtypes
         for (size_t i = 0; i < opr->input().size(); ++i) {
-            available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16;
+            available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32;
             config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
             TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
             config.input_tensor_types.emplace_back(tensor_type);
         }
-        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16;
+        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
         config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
         available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
         // setup tensor formats
@@ -686,7 +700,6 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88_HYBRID> {
         return config;
     }
 };
-#endif

 template <typename Opr>
 struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_DOT> {
@@ -914,10 +927,8 @@ StaticData::StaticData() {
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44);
-#if !MEGDNN_DISABLE_FLOAT16
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88_HYBRID);
-#endif
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_HYBRID);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT_HYBRID);
@@ -925,10 +936,8 @@ StaticData::StaticData() {
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44);
-#if !MEGDNN_DISABLE_FLOAT16
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88_HYBRID);
-#endif
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_HYBRID);
     OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT_HYBRID);
@@ -949,15 +958,11 @@ StaticData::StaticData() {
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW32);
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64);
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW44);
-#if !MEGDNN_DISABLE_FLOAT16
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW88);
-#endif

     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW);
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44);
-#if !MEGDNN_DISABLE_FLOAT16
     OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88);
-#endif

 #undef OPR_TENSOR_FORMATS_CONFIG_REG
 #undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
@@ -357,9 +357,8 @@ struct GraphTuningOptions {
     enum class Target : uint32_t {
         UNSPEC = 0,  ///< unspecific device target
         CUDA = 1,    ///< CUDA device, usually refer to GPU devices of Nvidia
-        X86 = 2,     ///< x86 cpu
-        ARM = 3,     ///< arm cpu
-        OPENCL = 4,  ///< opencl, usually run on mobile devices
+        CPU = 2,     ///< cpu
+        OPENCL = 3,  ///< opencl, usually run on mobile devices
     };
     Target target;
     bool layout_transform = false;  ///< whether to enable graph level
@@ -23,7 +23,7 @@
 #include "megbrain/plugin/profiler.h"
 #include "megbrain/serialization/serializer.h"

-#define MGB_WITH_CACHED_TEST 0
+#define MGB_WITH_CACHED_TEST 1

 #if MGB_WITH_CACHED_TEST
 #include "./cache_data.h"
@@ -923,9 +923,196 @@ TEST(TestLayoutTransform, MobileNetV2) {
     HostTensorND t2;
     auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)});
     func2->execute();
-    gprof.to_json_full(func2.get())->writeto_fpath(output_file("mobilenet_v2_f32.json"));
+    gprof.to_json_full(func2.get())
+            ->writeto_fpath(output_file("mobilenet_v2_f32.json"));
     /// check correct
     MGB_ASSERT_TENSOR_EQ(t1, t2);
 }
+
+TEST(TestLayoutTransform, MobileNetV2_NCHW88) {
+    auto cn = CompNode::load("cpu0");
+    Network network(cn);
+    auto output = make_mobilenet_v2(network, 1);
+    HostTensorND t1;
+    auto func1 = network.graph->compile({make_callback_copy(output, t1)});
+    func1->execute();
+
+    using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
+    using OprList = LayoutTransformContext::OprList;
+    using Target = LayoutTransformContext::Target;
+    using Attribute = LayoutTransformContext::Attribute;
+    OprList opr_list = {
+            opr::ConvBiasForward::typeinfo(),
+            opr::ConvolutionForward::typeinfo(),
+            opr::ElemwiseMultiType::typeinfo(),
+            opr::Elemwise::typeinfo(),
+            opr::TypeCvt::typeinfo(),
+            opr::Concat::typeinfo(),
+            opr::PoolingForward::typeinfo(),
+            opr::WarpPerspectiveForward::typeinfo(),
+            opr::Resize::typeinfo(),
+    };
+    SmallVector<TensorFormats> available_tensor_formats = {
+            TensorFormats::NCHW,
+            TensorFormats::NCHWc4,
+            TensorFormats::NCHWc8,
+    };
+    Attribute attribute = {
+            OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC};
+    auto ctx = std::make_unique<LayoutTransformContext>(
+            std::move(opr_list), std::move(available_tensor_formats), attribute);
+    ctx->add_opr_config(
+               opr::ConvBiasForward::typeinfo(),
+               {
+                       OprFormatConfigID::NCHW88,
+                       OprFormatConfigID::NCHW,
+                       OprFormatConfigID::NCHW88_HYBRID,
+               })
+            .add_opr_config(
+                    opr::ConvolutionForward::typeinfo(),
+                    {
+                            OprFormatConfigID::NCHW88,
+                            OprFormatConfigID::NCHW,
+                            OprFormatConfigID::NCHW88_HYBRID,
+                    })
+            .add_opr_config(
+                    opr::PoolingForward::typeinfo(), {
+                            OprFormatConfigID::NCHW,
+                            OprFormatConfigID::NCHW88,
+                    });
+#if MGB_WITH_CACHED_TEST
+    auto profiler = std::make_unique<ProfilerMock>(
+            static_cast<const uint8_t*>(TestLayoutTransform_MobileNetV2_NCHW88.data()),
+            TestLayoutTransform_MobileNetV2_NCHW88.size());
+#else
+    auto profiler = ProfilerBase::make_cached_profiler(
+            "TestLayoutTransform.MobileNetV2_NCHW88.cache");
+#endif
+    std::unique_ptr<SolverBase> solver{
+            new DynamicProgrammingSolver(std::move(profiler))};
+    auto new_output =
+            gopt::GraphOptimizer{}
+                    .add_pass<FuseConvBiasNonlinPass>()
+                    .add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
+                    .add_pass<ShuffleShuffleRemovePass>()
+                    .add_pass<ParamFusePass>()
+                    .add_pass<ParamMergePass>()
+                    .apply({{output}})
+                    .endpoint_vars();
+    auto new_out_var = new_output[0];
+    /// check global layout transform pass
+    auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var);
+    ASSERT_EQ(nr_dimshuffle, 1u);
+    /// check first conv format
+    const auto& first_conv = find_opr<opr::ConvBiasForward>(new_out_var);
+    const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>();
+    ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW88);
+
+    GraphProfiler gprof{network.graph.get()};
+    HostTensorND t2;
+    auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)});
+    func2->execute();
+    gprof.to_json_full(func2.get())
+            ->writeto_fpath(output_file("mobilenet_v2_nchw88.json"));
+    /// check correct
+    MGB_ASSERT_TENSOR_EQ(t1, t2);
+}
+
+TEST(TestLayoutTransform, MobileNetV2_NCHW44_DOT) {
+    auto cn = CompNode::load("cpu0");
+    Network network(cn);
+    auto output = make_mobilenet_v2(network, 1, dtype::QuantizedS8{1.f});
+    HostTensorND t1;
+    auto func1 = network.graph->compile({make_callback_copy(output, t1)});
+    func1->execute();
+
+    using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
+    using OprList = LayoutTransformContext::OprList;
+    using Target = LayoutTransformContext::Target;
+    using Attribute = LayoutTransformContext::Attribute;
+    OprList opr_list = {
+            opr::ConvBiasForward::typeinfo(),
+            opr::ConvolutionForward::typeinfo(),
+            opr::ElemwiseMultiType::typeinfo(),
+            opr::Elemwise::typeinfo(),
+            opr::TypeCvt::typeinfo(),
+            opr::Concat::typeinfo(),
+            opr::PoolingForward::typeinfo(),
+            opr::WarpPerspectiveForward::typeinfo(),
+            opr::Resize::typeinfo(),
+    };
+    SmallVector<TensorFormats> available_tensor_formats = {
+            TensorFormats::NCHW,
+            TensorFormats::NCHWc4,
+            TensorFormats::NCHWc8,
+    };
+    Attribute attribute = {
+            OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC};
+    auto ctx = std::make_unique<LayoutTransformContext>(
+            std::move(opr_list), std::move(available_tensor_formats), attribute);
+    ctx->add_opr_config(
+               opr::ConvBiasForward::typeinfo(),
+               {
+                       OprFormatConfigID::NCHW,
+                       OprFormatConfigID::NCHW44,
+                       OprFormatConfigID::NCHW44_HYBRID,
+                       OprFormatConfigID::NCHW44_DOT,
+                       OprFormatConfigID::NCHW44_DOT_HYBRID,
+               })
+            .add_opr_config(
+                    opr::ConvolutionForward::typeinfo(),
+                    {
+                            OprFormatConfigID::NCHW,
+                            OprFormatConfigID::NCHW44,
+                            OprFormatConfigID::NCHW44_HYBRID,
+                            OprFormatConfigID::NCHW44_DOT,
+                            OprFormatConfigID::NCHW44_DOT_HYBRID,
+                    })
+            .add_opr_config(
+                    opr::PoolingForward::typeinfo(), {
+                            OprFormatConfigID::NCHW,
+                            OprFormatConfigID::NCHW44,
+                    });
+#if MGB_WITH_CACHED_TEST
+    auto profiler = std::make_unique<ProfilerMock>(
+            static_cast<const uint8_t*>(
+                    TestLayoutTransform_MobileNetV2_NCHW44_DOT.data()),
+            TestLayoutTransform_MobileNetV2_NCHW44_DOT.size());
+#else
+    auto profiler = ProfilerBase::make_cached_profiler(
+            "TestLayoutTransform.MobileNetV2_NCHW44_DOT.cache");
+#endif
+    std::unique_ptr<SolverBase> solver{
+            new DynamicProgrammingSolver(std::move(profiler))};
+    auto new_output =
+            gopt::GraphOptimizer{}
+                    .add_pass<FuseConvBiasNonlinPass>()
+                    .add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
+                    .add_pass<ShuffleShuffleRemovePass>()
+                    .add_pass<ParamFusePass>()
+                    .add_pass<ParamMergePass>()
+                    .apply({{output}})
+                    .endpoint_vars();
+    auto new_out_var = new_output[0];
+    /// check global layout transform pass
+    auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var);
+    ASSERT_EQ(nr_dimshuffle, 1u);
+    /// check first conv format
+    const auto& first_conv = find_opr<opr::ConvBiasForward>(new_out_var);
+    const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>();
+    ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW44_DOT);
+
+    GraphProfiler gprof{network.graph.get()};
+    HostTensorND t2;
+    auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)});
+    func2->execute();
+    gprof.to_json_full(func2.get())
+            ->writeto_fpath(output_file("mobilenet_v2_nchw44_dot.json"));
+    /// check correct
+    MGB_ASSERT_TENSOR_EQ(t1, t2);
+}

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -57,7 +57,10 @@ SymbolVar Network::add_group_conv(
             {groups, output_channels / groups, input_channels / groups, kern_size[0],
              kern_size[1]});
     auto bias = add_cvar(ssprintf("b%d", bias_idx).c_str(), {1, output_channels, 1, 1});
-    mgb_assert(out_dtype.category() == DTypeCategory::FLOAT);
+    if (out_dtype.category() == DTypeCategory::QUANTIZED) {
+        weight = add_type_cvt(weight, out_dtype);
+        bias = add_type_cvt(bias, dtype::QuantizedS32{1.f});
+    }
     opr::ConvBias::Param param;
     param.sparse = opr::ConvBias::Param::Sparse::GROUP;
     param.stride_h = stride[0], param.stride_w = stride[1];
@@ -68,8 +71,15 @@ SymbolVar Network::add_group_conv(
         param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
     }

-    auto conv = opr::ConvBias::make(
-            f, weight, bias, param, {}, OperatorNodeConfig{out_dtype});
+    SymbolVar conv;
+    if (out_dtype.category() == DTypeCategory::QUANTIZED) {
+        conv = opr::ConvBias::make(
+                f, weight, bias, param, {}, OperatorNodeConfig{out_dtype});
+    } else {
+        conv = opr::ConvBias::make(f, weight, bias, param, {});
+    }
     weight_idx++;
     bias_idx++;
     return conv;
@@ -269,17 +279,17 @@ SymbolVarArray mgb::make_det(Network& network, size_t batch, DType out_dtype) {
 SymbolVar mgb::bottleneck(
         Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t,
-        size_t stride) {
+        size_t stride, DType out_dtype) {
     size_t in_channels = f.node()->shape()[1];
     SymbolVar x = f;
     if (t != 1) {
         x = network.add_conv(
-                f, input_channels * t, {1, 1}, dtype::Float32(), true, {1, 1}, {0, 0});
+                f, input_channels * t, {1, 1}, out_dtype, true, {1, 1}, {0, 0});
     }
     x = network.add_group_conv(
-            x, input_channels * t, input_channels * t, {3, 3}, dtype::Float32(), true,
+            x, input_channels * t, input_channels * t, {3, 3}, out_dtype, true,
             {stride, stride}, {1, 1});
-    x = network.add_conv(x, channels, {1, 1}, dtype::Float32(), false, {1, 1}, {0, 0});
+    x = network.add_conv(x, channels, {1, 1}, out_dtype, false, {1, 1}, {0, 0});
     if (stride == 1 && in_channels == channels)
         x = f + x;
     return x;
@@ -287,11 +297,11 @@ SymbolVar mgb::bottleneck(
 SymbolVar mgb::bottleneck_group(
         Network& network, SymbolVar f, size_t input_channels, size_t channels,
-        size_t stages, size_t s, size_t t) {
+        size_t stages, size_t s, size_t t, DType out_dtype) {
     SymbolVar x = f;
     for (size_t i = 0; i < stages; ++i) {
         size_t stride = i == 0 ? s : 1;
-        x = bottleneck(network, x, input_channels, channels, t, stride);
+        x = bottleneck(network, x, input_channels, channels, t, stride, out_dtype);
         input_channels = channels;
     }
     return x;
@@ -307,22 +317,34 @@ size_t make_divisible(size_t v, size_t divisor) {
 }
 } // namespace

-SymbolVar mgb::make_mobilenet_v2(Network& network, size_t batch) {
+SymbolVar mgb::make_mobilenet_v2(Network& network, size_t batch, DType out_dtype) {
     auto data = network.add_var("data", {batch, 3, 224, 224});
+    if (out_dtype.category() == DTypeCategory::QUANTIZED) {
+        data = network.add_type_cvt(data, dtype::QuantizedS8{1.f});
+    }
     constexpr size_t round_nearest = 8;
     auto x = network.add_conv(
-            data, make_divisible(32, round_nearest), {3, 3}, dtype::Float32(), true,
-            {2, 2}, {1, 1});
-    x = bottleneck(network, x, 32, make_divisible(16, round_nearest), 1, 1);
-    x = bottleneck_group(network, x, 16, make_divisible(24, round_nearest), 2, 2, 6);
-    x = bottleneck_group(network, x, 24, make_divisible(32, round_nearest), 3, 2, 6);
-    x = bottleneck_group(network, x, 32, make_divisible(64, round_nearest), 4, 2, 6);
-    x = bottleneck_group(network, x, 64, make_divisible(96, round_nearest), 3, 1, 6);
-    x = bottleneck_group(network, x, 96, make_divisible(160, round_nearest), 3, 2, 6);
-    x = bottleneck_group(network, x, 160, make_divisible(320, round_nearest), 1, 1, 6);
+            data, make_divisible(32, round_nearest), {3, 3}, out_dtype, true, {2, 2},
+            {1, 1});
+    x = bottleneck(network, x, 32, make_divisible(16, round_nearest), 1, 1, out_dtype);
+    x = bottleneck_group(
+            network, x, 16, make_divisible(24, round_nearest), 2, 2, 6, out_dtype);
+    x = bottleneck_group(
+            network, x, 24, make_divisible(32, round_nearest), 3, 2, 6, out_dtype);
+    x = bottleneck_group(
+            network, x, 32, make_divisible(64, round_nearest), 4, 2, 6, out_dtype);
+    x = bottleneck_group(
+            network, x, 64, make_divisible(96, round_nearest), 3, 1, 6, out_dtype);
+    x = bottleneck_group(
+            network, x, 96, make_divisible(160, round_nearest), 3, 2, 6, out_dtype);
+    x = bottleneck_group(
+            network, x, 160, make_divisible(320, round_nearest), 1, 1, 6, out_dtype);
     x = network.add_conv(
-            x, make_divisible(1280, round_nearest), {1, 1}, dtype::Float32(), true,
-            {1, 1}, {0, 0});
+            x, make_divisible(1280, round_nearest), {1, 1}, out_dtype, true, {1, 1},
+            {0, 0});
+    if (out_dtype.category() == DTypeCategory::QUANTIZED) {
+        x = network.add_type_cvt(x, dtype::Float32());
+    }
     return x;
 }
@@ -79,13 +79,14 @@ SymbolVarArray make_det(
 SymbolVar bottleneck(
         Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t,
-        size_t stride);
+        size_t stride, DType out_dtype = dtype::Float32());

 SymbolVar bottleneck_group(
         Network& network, SymbolVar f, size_t input_channels, size_t channels,
-        size_t stages, size_t s, size_t t);
+        size_t stages, size_t s, size_t t, DType out_dtype = dtype::Float32());

-SymbolVar make_mobilenet_v2(Network& network, size_t batch = 1);
+SymbolVar make_mobilenet_v2(
+        Network& network, size_t batch = 1, DType out_dtype = dtype::Float32());

 } // namespace mgb
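
A minimal usage sketch of the extended test-network helpers, mirroring the new tests in this diff (`Network`, `make_mobilenet_v2`, and the dtypes are the test utilities shown above; float32 remains the default, while a quantized dtype inserts the TypeCvt boundaries shown in `make_mobilenet_v2`):

    Network network(CompNode::load("cpu0"));
    auto out_f32 = make_mobilenet_v2(network);                             // float32 path (default)
    auto out_q8 = make_mobilenet_v2(network, 1, dtype::QuantizedS8{1.f});  // quantized int8 path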