| @@ -35,9 +35,10 @@ pdef('Axis').add_fields('int32', 'axis', 0) | |||||
| ). | ). | ||||
| add_enum(Doc('Format', 'convolution data/filter/output format; see ' | add_enum(Doc('Format', 'convolution data/filter/output format; see ' | ||||
| ':class:`RelayoutFormat` for more details'), | ':class:`RelayoutFormat` for more details'), | ||||
| 'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88', | |||||
| 'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88', 'NCHW44', | |||||
| Doc('NCHW_WINOGRAD', 'NCHW layout with weights tranformed by winograd'), | Doc('NCHW_WINOGRAD', 'NCHW layout with weights tranformed by winograd'), | ||||
| Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights tranformed by winograd'), | Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights tranformed by winograd'), | ||||
| Doc('NCHW44_WINOGRAD', 'NCHW44 layout with weights tranformed by winograd'), | |||||
| Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' | Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' | ||||
| 'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.')) | 'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.')) | ||||
| ) | ) | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #include "src/common/conv_bias.h" | #include "src/common/conv_bias.h" | ||||
| @@ -33,7 +34,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||||
| const TensorLayout& bias, const TensorLayout& z, | const TensorLayout& bias, const TensorLayout& z, | ||||
| const TensorLayout& dst, size_t workspace_in_bytes) { | const TensorLayout& dst, size_t workspace_in_bytes) { | ||||
| if ((param().format == param::ConvBias::Format::NCHW_WINOGRAD || | if ((param().format == param::ConvBias::Format::NCHW_WINOGRAD || | ||||
| param().format == param::ConvBias::Format::NCHW88_WINOGRAD) && | |||||
| param().format == param::ConvBias::Format::NCHW88_WINOGRAD || | |||||
| param().format == param::ConvBias::Format::NCHW44_WINOGRAD) && | |||||
| src.dtype.category() == DTypeCategory::QUANTIZED) { | src.dtype.category() == DTypeCategory::QUANTIZED) { | ||||
| megdnn_assert(filter.dtype.enumv() == DTypeEnum::QuantizedS16); | megdnn_assert(filter.dtype.enumv() == DTypeEnum::QuantizedS16); | ||||
| megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8 || | megdnn_assert(src.dtype.enumv() == DTypeEnum::QuantizedS8 || | ||||
| @@ -45,7 +47,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||||
| float scale_src = src.dtype.param<dtype::QuantizedS8>().scale; | float scale_src = src.dtype.param<dtype::QuantizedS8>().scale; | ||||
| float scale_filter = 0.f; | float scale_filter = 0.f; | ||||
| if (param().format == param::ConvBias::Format::NCHW_WINOGRAD || | if (param().format == param::ConvBias::Format::NCHW_WINOGRAD || | ||||
| param().format == param::ConvBias::Format::NCHW88_WINOGRAD) { | |||||
| param().format == param::ConvBias::Format::NCHW88_WINOGRAD || | |||||
| param().format == param::ConvBias::Format::NCHW44_WINOGRAD) { | |||||
| scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale; | scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale; | ||||
| } else { | } else { | ||||
| scale_filter = filter.dtype.param<dtype::QuantizedS8>().scale; | scale_filter = filter.dtype.param<dtype::QuantizedS8>().scale; | ||||
| @@ -58,7 +61,8 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||||
| float scale_src = src.dtype.param<dtype::Quantized8Asymm>().scale; | float scale_src = src.dtype.param<dtype::Quantized8Asymm>().scale; | ||||
| float scale_filter = 0.f; | float scale_filter = 0.f; | ||||
| if (param().format == param::ConvBias::Format::NCHW_WINOGRAD || | if (param().format == param::ConvBias::Format::NCHW_WINOGRAD || | ||||
| param().format == param::ConvBias::Format::NCHW88_WINOGRAD) { | |||||
| param().format == param::ConvBias::Format::NCHW88_WINOGRAD || | |||||
| param().format == param::ConvBias::Format::NCHW44_WINOGRAD) { | |||||
| scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale; | scale_filter = filter.dtype.param<dtype::QuantizedS16>().scale; | ||||
| } else { | } else { | ||||
| scale_filter = filter.dtype.param<dtype::Quantized8Asymm>().scale; | scale_filter = filter.dtype.param<dtype::Quantized8Asymm>().scale; | ||||
| @@ -98,7 +102,9 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||||
| megdnn_assert(bias.shape[2] == 1); | megdnn_assert(bias.shape[2] == 1); | ||||
| megdnn_assert(bias.shape[3] == dst.shape[3], "bias:%s, dst:%s", | megdnn_assert(bias.shape[3] == dst.shape[3], "bias:%s, dst:%s", | ||||
| bias.to_string().c_str(), dst.to_string().c_str()); | bias.to_string().c_str(), dst.to_string().c_str()); | ||||
| } else if (param().format == param::ConvBias::Format::NCHW4) { | |||||
| } else if (param().format == param::ConvBias::Format::NCHW4 || | |||||
| param().format == param::ConvBias::Format::NCHW44 || | |||||
| param().format == param::ConvBias::Format::NCHW44_WINOGRAD) { | |||||
| megdnn_assert(bias.shape[0] == 1); | megdnn_assert(bias.shape[0] == 1); | ||||
| megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", | megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", | ||||
| bias.to_string().c_str(), dst.to_string().c_str()); | bias.to_string().c_str(), dst.to_string().c_str()); | ||||
| @@ -141,7 +147,10 @@ ConvBiasForward::CanonizedFilterMeta ConvBiasForward::check_exec( | |||||
| if (z.ndim != 0) { | if (z.ndim != 0) { | ||||
| megdnn_assert(param().format != param::ConvBias::Format::NCHW_WINOGRAD); | megdnn_assert(param().format != param::ConvBias::Format::NCHW_WINOGRAD); | ||||
| megdnn_assert(param().format != param::ConvBias::Format::NCHW88_WINOGRAD); | |||||
| megdnn_assert(param().format != | |||||
| param::ConvBias::Format::NCHW88_WINOGRAD); | |||||
| megdnn_assert(param().format != | |||||
| param::ConvBias::Format::NCHW44_WINOGRAD); | |||||
| megdnn_assert(z.dtype.enumv() == dst.dtype.enumv()); | megdnn_assert(z.dtype.enumv() == dst.dtype.enumv()); | ||||
| megdnn_assert(z.eq_shape(dst)); | megdnn_assert(z.eq_shape(dst)); | ||||
| } | } | ||||
| @@ -163,10 +172,7 @@ std::string ConvBias::algo_name(const std::string& base, const T& p) { | |||||
| } | } | ||||
| #define FOREACH_CONV_BIAS_PARAM(cb) \ | #define FOREACH_CONV_BIAS_PARAM(cb) \ | ||||
| cb(WinogradParam) \ | |||||
| cb(DirectParam) \ | |||||
| cb(MatmulParam) \ | |||||
| cb(DefaultParam) | |||||
| cb(WinogradParam) cb(DirectParam) cb(MatmulParam) cb(DefaultParam) | |||||
| #define cb(pt) \ | #define cb(pt) \ | ||||
| template <> \ | template <> \ | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #include "megdnn/oprs/nn.h" | #include "megdnn/oprs/nn.h" | ||||
| @@ -55,7 +56,13 @@ spatial_getter<param::ConvBias, param::ConvBias::Format::NCHW88_WINOGRAD>( | |||||
| //! f = m + r - 1 -> r = f + 1 - m | //! f = m + r - 1 -> r = f + 1 - m | ||||
| return filter - param.output_block_size + 1; | return filter - param.output_block_size + 1; | ||||
| } | } | ||||
| //! spatial_getter specialization for the NCHW44_WINOGRAD ConvBias format. | |||||
| //! | |||||
| //! \param filter one filter extent, `f` in the formula below | |||||
| //! \param param ConvBias parameter; only output_block_size (`m`) is read | |||||
| //! \return filter - param.output_block_size + 1, i.e. `r` in the formula | |||||
| //! | |||||
| //! NOTE(review): presumably `r` is the spatial size of the untransformed | |||||
| //! kernel -- confirm against the winograd filter layout. | |||||
| //! NOTE(review): the result is returned as uint32_t, so a value of | |||||
| //! output_block_size larger than filter + 1 would presumably wrap around | |||||
| //! instead of going negative -- confirm callers rule this out. | |||||
| template <> | |||||
| uint32_t | |||||
| spatial_getter<param::ConvBias, param::ConvBias::Format::NCHW44_WINOGRAD>( | |||||
| uint32_t filter, const param::ConvBias& param) { | |||||
| //! f = m + r - 1 -> r = f + 1 - m | |||||
| return filter - param.output_block_size + 1; | |||||
| } | |||||
| template <typename Parameter, typename Param> | template <typename Parameter, typename Param> | ||||
| void make_canonized_filter_meta_nchw_nhwc( | void make_canonized_filter_meta_nchw_nhwc( | ||||
| @@ -273,7 +280,7 @@ void make_canonized_filter_meta_nchwxx( | |||||
| /** | /** | ||||
| * input: N IC/pack_size, H, W, pack_size | * input: N IC/pack_size, H, W, pack_size | ||||
| * | * | ||||
| * NCHW88 mode | |||||
| * NCHW88 and NCHW44 mode | |||||
| * filter: | * filter: | ||||
| * {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)} | * {OC/pack_size, IC/pack_size, FH, FW, pack_size(IC), pack_size(OC)} | ||||
| * [dense] | * [dense] | ||||
| @@ -281,7 +288,7 @@ void make_canonized_filter_meta_nchwxx( | |||||
| * FH, FW, pack_size(IC), pack_size(OC)} [group] | * FH, FW, pack_size(IC), pack_size(OC)} [group] | ||||
| * {GROUP/pack_size, 1, 1, FH, FW, pack_size} [chan] | * {GROUP/pack_size, 1, 1, FH, FW, pack_size} [chan] | ||||
| * | * | ||||
| ** NCHW88_WINOGRAD mode | |||||
| ** NCHW88_WINOGRAD and NCHW44_WINOGRAD mode | |||||
| * filter: | * filter: | ||||
| * {alpha, alpha, OC/pack_size, IC/pack_size, pack_size(IC), | * {alpha, alpha, OC/pack_size, IC/pack_size, pack_size(IC), | ||||
| *pack_size(OC)} [dense] | *pack_size(OC)} [dense] | ||||
| @@ -291,6 +298,7 @@ void make_canonized_filter_meta_nchwxx( | |||||
| */ | */ | ||||
| megdnn_assert(param.format == Param::Format::NCHW88 || | megdnn_assert(param.format == Param::Format::NCHW88 || | ||||
| param.format == Param::Format::NCHW44 || | |||||
| param.format == Param::Format::NCHW88_WINOGRAD); | param.format == Param::Format::NCHW88_WINOGRAD); | ||||
| size_t img_ndim = 2; | size_t img_ndim = 2; | ||||
| size_t flt_start = 0; | size_t flt_start = 0; | ||||
| @@ -305,7 +313,8 @@ void make_canonized_filter_meta_nchwxx( | |||||
| filter[filter.ndim - 1]); | filter[filter.ndim - 1]); | ||||
| ret.group = 1; | ret.group = 1; | ||||
| flt_start = 0; | flt_start = 0; | ||||
| if (param.format == Param::Format::NCHW88_WINOGRAD) { | |||||
| if (param.format == Param::Format::NCHW88_WINOGRAD || | |||||
| param.format == Param::Format::NCHW44_WINOGRAD) { | |||||
| flt_start = 2; | flt_start = 2; | ||||
| } | } | ||||
| ret.ocpg = filter[flt_start] * pack_size; | ret.ocpg = filter[flt_start] * pack_size; | ||||
| @@ -314,6 +323,8 @@ void make_canonized_filter_meta_nchwxx( | |||||
| // ohwi8o | // ohwi8o | ||||
| megdnn_assert(param.format != Param::Format::NCHW88_WINOGRAD, | megdnn_assert(param.format != Param::Format::NCHW88_WINOGRAD, | ||||
| "Hybrid nchw88 mode in not support winograd"); | "Hybrid nchw88 mode in not support winograd"); | ||||
| megdnn_assert(param.format != Param::Format::NCHW44_WINOGRAD, | |||||
| "Hybrid nchw44 mode in not support winograd"); | |||||
| flt_start = 0; | flt_start = 0; | ||||
| flt_spatial_start = 1; | flt_spatial_start = 1; | ||||
| ret.group = 1; | ret.group = 1; | ||||
| @@ -321,20 +332,22 @@ void make_canonized_filter_meta_nchwxx( | |||||
| ret.icpg = filter[flt_start + 3]; | ret.icpg = filter[flt_start + 3]; | ||||
| } else { | } else { | ||||
| megdnn_assert(0, "not support nchw88 filter dim = %zu", | |||||
| megdnn_assert(0, "not support nchwxx filter dim = %zu", | |||||
| filter.ndim); | filter.ndim); | ||||
| } | } | ||||
| } else { | } else { | ||||
| megdnn_assert(param.sparse == Param::Sparse::GROUP, | megdnn_assert(param.sparse == Param::Sparse::GROUP, | ||||
| "invalid convolution sparse type"); | "invalid convolution sparse type"); | ||||
| flt_start = 1; | flt_start = 1; | ||||
| if (param.format == Param::Format::NCHW88_WINOGRAD) { | |||||
| if (param.format == Param::Format::NCHW88_WINOGRAD || | |||||
| param.format == Param::Format::NCHW44_WINOGRAD) { | |||||
| flt_start = 3; | flt_start = 3; | ||||
| } | } | ||||
| auto filter_oc = filter[flt_start]; | auto filter_oc = filter[flt_start]; | ||||
| auto filter_ic = filter[flt_start + 1]; | auto filter_ic = filter[flt_start + 1]; | ||||
| if (filter_oc == 1 && filter_ic == 1 && filter.ndim == (img_ndim + 4) && | if (filter_oc == 1 && filter_ic == 1 && filter.ndim == (img_ndim + 4) && | ||||
| param.format != Param::Format::NCHW88_WINOGRAD) { | |||||
| param.format != Param::Format::NCHW88_WINOGRAD && | |||||
| param.format != Param::Format::NCHW44_WINOGRAD) { | |||||
| // Depthwise case goihw8g | // Depthwise case goihw8g | ||||
| megdnn_assert(filter.ndim == img_ndim + 4, | megdnn_assert(filter.ndim == img_ndim + 4, | ||||
| "bad filter ndim for group convolution: " | "bad filter ndim for group convolution: " | ||||
| @@ -343,7 +356,7 @@ void make_canonized_filter_meta_nchwxx( | |||||
| megdnn_assert(filter[filter.ndim - 1] == pack_size, | megdnn_assert(filter[filter.ndim - 1] == pack_size, | ||||
| "last dim of filter must be %zu, but %zu", pack_size, | "last dim of filter must be %zu, but %zu", pack_size, | ||||
| filter[filter.ndim - 1]); | filter[filter.ndim - 1]); | ||||
| ret.group = filter[0] * 8; | |||||
| ret.group = filter[0] * pack_size; | |||||
| ret.ocpg = filter_oc; | ret.ocpg = filter_oc; | ||||
| ret.icpg = filter_ic; | ret.icpg = filter_ic; | ||||
| @@ -381,6 +394,10 @@ void make_canonized_filter_meta_nchwxx( | |||||
| ret.spatial[i] = | ret.spatial[i] = | ||||
| spatial_getter<Param, Param::Format::NCHW88_WINOGRAD>( | spatial_getter<Param, Param::Format::NCHW88_WINOGRAD>( | ||||
| filter[i + flt_start - 2], param); | filter[i + flt_start - 2], param); | ||||
| } else if (param.format == Param::Format::NCHW44_WINOGRAD) { | |||||
| ret.spatial[i] = | |||||
| spatial_getter<Param, Param::Format::NCHW44_WINOGRAD>( | |||||
| filter[i + flt_start - 2], param); | |||||
| } else { | } else { | ||||
| ret.spatial[i] = filter[i + flt_start + flt_spatial_start]; | ret.spatial[i] = filter[i + flt_start + flt_spatial_start]; | ||||
| } | } | ||||
| @@ -535,6 +552,10 @@ ConvolutionBase<Parameter>::make_canonized_filter_meta( | |||||
| param().format == Param::Format::NCHW88_WINOGRAD) { | param().format == Param::Format::NCHW88_WINOGRAD) { | ||||
| make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter, | make_canonized_filter_meta_nchwxx<8, Parameter>(src_ndim, filter, | ||||
| param(), ret); | param(), ret); | ||||
| } else if (param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::NCHW44_WINOGRAD) { | |||||
| make_canonized_filter_meta_nchwxx<4, Parameter>(src_ndim, filter, | |||||
| param(), ret); | |||||
| } else if (param().format == Param::Format::NCHW32) { | } else if (param().format == Param::Format::NCHW32) { | ||||
| make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter, | make_canonized_filter_meta_nchwx<32, Parameter>(src_ndim, filter, | ||||
| param(), ret); | param(), ret); | ||||
| @@ -629,18 +650,22 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||||
| } else { | } else { | ||||
| megdnn_assert(param().format == Param::Format::NHWCD4 || | megdnn_assert(param().format == Param::Format::NHWCD4 || | ||||
| param().format == Param::Format::NCHW4 || | param().format == Param::Format::NCHW4 || | ||||
| param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::NCHW8 || | param().format == Param::Format::NCHW8 || | ||||
| param().format == Param::Format::NCHW32 || | param().format == Param::Format::NCHW32 || | ||||
| param().format == Param::Format::NCHW88 || | param().format == Param::Format::NCHW88 || | ||||
| param().format == Param::Format::NCHW88_WINOGRAD || | param().format == Param::Format::NCHW88_WINOGRAD || | ||||
| param().format == Param::Format::CHWN4); | param().format == Param::Format::CHWN4); | ||||
| img_dim = src.ndim - 3; | img_dim = src.ndim - 3; | ||||
| if (param().format == Param::Format::NCHW88 && filter.ndim == 5) { | |||||
| if ((param().format == Param::Format::NCHW88 || | |||||
| param().format == Param::Format::NCHW44) && | |||||
| filter.ndim == 5) { | |||||
| img_dim = src.ndim - 2; | img_dim = src.ndim - 2; | ||||
| } | } | ||||
| megdnn_assert(filter.ndim == img_dim + 3 || | megdnn_assert(filter.ndim == img_dim + 3 || | ||||
| (filter.ndim == img_dim + 2 && | (filter.ndim == img_dim + 2 && | ||||
| param().format == Param::Format::NCHW88) || | |||||
| (param().format == Param::Format::NCHW88 || | |||||
| param().format == Param::Format::NCHW44)) || | |||||
| filter.ndim == img_dim + 4 || | filter.ndim == img_dim + 4 || | ||||
| filter.ndim == img_dim + 5, | filter.ndim == img_dim + 5, | ||||
| "%s", errmsg().c_str()); | "%s", errmsg().c_str()); | ||||
| @@ -691,6 +716,21 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||||
| ", and last shape two is 8 but got src %s, filter %s", | ", and last shape two is 8 but got src %s, filter %s", | ||||
| src.to_string().c_str(), filter.to_string().c_str()); | src.to_string().c_str(), filter.to_string().c_str()); | ||||
| } | } | ||||
| if (param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::NCHW44_WINOGRAD) { | |||||
| megdnn_assert((src.ndim == 4 && filter.ndim == 5 && | |||||
| filter[filter.ndim - 1] == 4) || | |||||
| (src.ndim == 5 && | |||||
| ((filter.ndim == 6 && | |||||
| filter[filter.ndim - 1] == 4) || | |||||
| (filter.ndim == 7 && | |||||
| filter[filter.ndim - 1] == 4 && | |||||
| filter[filter.ndim - 2] == 4)) && | |||||
| src[src.ndim - 1] == 4), | |||||
| "NCHW44 require src ndim is 5 and filter's ndim is 6 " | |||||
| ", and last shape two is 4 but got src %s, filter %s", | |||||
| src.to_string().c_str(), filter.to_string().c_str()); | |||||
| } | |||||
| if (param().format == Param::Format::CHWN4) { | if (param().format == Param::Format::CHWN4) { | ||||
| megdnn_assert( | megdnn_assert( | ||||
| src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) && | src.ndim == 5 && (filter.ndim == 5 || filter.ndim == 6) && | ||||
| @@ -808,6 +848,27 @@ ConvolutionBase<Parameter>::deduce_layout_fwd(const TensorLayout& src, | |||||
| cflt.group); | cflt.group); | ||||
| } | } | ||||
| } else if (param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::NCHW44_WINOGRAD) { | |||||
| megdnn_assert(src.ndim == 5 || (src.ndim == 4 && src[1] <= 8), | |||||
| "invalid src ndim for NCHW44, expected=5 or 4, got=%zu", | |||||
| src.ndim); | |||||
| dst.ndim = 5; | |||||
| dst[0] = src[0]; | |||||
| auto oc = cflt.ocpg * cflt.group; | |||||
| megdnn_assert(oc % 4 == 0); | |||||
| dst[1] = oc / 4; | |||||
| dst[2] = infer_conv_shape(src[2], cflt.dilated_spatial[0], | |||||
| cflt.stride[0], cflt.padding[0]); | |||||
| dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], | |||||
| cflt.stride[1], cflt.padding[1]); | |||||
| dst[4] = 4; | |||||
| if (cflt.group == 1) { | |||||
| megdnn_assert(cflt.icpg * cflt.group == src[1] * 4 || | |||||
| (cflt.icpg * cflt.group == src[1]), | |||||
| "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, | |||||
| cflt.group); | |||||
| } | |||||
| } else if (param().format == Param::Format::CHWN4) { | } else if (param().format == Param::Format::CHWN4) { | ||||
| megdnn_assert(src.ndim == 5, | megdnn_assert(src.ndim == 5, | ||||
| "invalid src ndim for CHWN4, expected=5, got=%zu", | "invalid src ndim for CHWN4, expected=5, got=%zu", | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| @@ -47,6 +48,7 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, | |||||
| spatial_pos = 1; | spatial_pos = 1; | ||||
| c_pos = 3; | c_pos = 3; | ||||
| } else if (param().format == Param::Format::NCHW4 || | } else if (param().format == Param::Format::NCHW4 || | ||||
| param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::NCHW88 || | param().format == Param::Format::NCHW88 || | ||||
| param().format == Param::Format::NCHW32) { | param().format == Param::Format::NCHW32) { | ||||
| megdnn_assert(src.ndim == 5_z, "%s", errmsg_c); | megdnn_assert(src.ndim == 5_z, "%s", errmsg_c); | ||||
| @@ -73,6 +75,7 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, | |||||
| iw = src[spatial_pos + 2]; | iw = src[spatial_pos + 2]; | ||||
| } | } | ||||
| if (param().format == Param::Format::NCHW4 || | if (param().format == Param::Format::NCHW4 || | ||||
| param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::CHWN4) { | param().format == Param::Format::CHWN4) { | ||||
| c *= 4; | c *= 4; | ||||
| } | } | ||||
| @@ -96,7 +99,8 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, | |||||
| megdnn_assert(param().format == Param::Format::NHWC, | megdnn_assert(param().format == Param::Format::NHWC, | ||||
| "invalid pooling format"); | "invalid pooling format"); | ||||
| dst = TensorLayout({n, oh, ow, c}, src.dtype, src.format); | dst = TensorLayout({n, oh, ow, c}, src.dtype, src.format); | ||||
| } else if (param().format == Param::Format::NCHW4) { | |||||
| } else if (param().format == Param::Format::NCHW4 || | |||||
| param().format == Param::Format::NCHW44) { | |||||
| dst = TensorLayout{{n, c / 4, oh, ow, 4}, src.dtype, src.format}; | dst = TensorLayout{{n, c / 4, oh, ow, 4}, src.dtype, src.format}; | ||||
| } else if (param().format == Param::Format::NCHW88) { | } else if (param().format == Param::Format::NCHW88) { | ||||
| dst = TensorLayout{{n, c / 8, oh, ow, 8}, src.dtype, src.format}; | dst = TensorLayout{{n, c / 8, oh, ow, 8}, src.dtype, src.format}; | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #include "src/fallback/convolution/opr_impl.h" | #include "src/fallback/convolution/opr_impl.h" | ||||
| #include "src/common/algo_chooser.h" | #include "src/common/algo_chooser.h" | ||||
| @@ -157,9 +158,11 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param( | |||||
| if (param().format == Param::Format::NCHW88 || | if (param().format == Param::Format::NCHW88 || | ||||
| param().format == Param::Format::NCHW8 || | param().format == Param::Format::NCHW8 || | ||||
| param().format == Param::Format::NCHW4 || | param().format == Param::Format::NCHW4 || | ||||
| param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::NCHW || | param().format == Param::Format::NCHW || | ||||
| param().format == Param::Format::NCHW_WINOGRAD || | param().format == Param::Format::NCHW_WINOGRAD || | ||||
| param().format == Param::Format::NCHW88_WINOGRAD) { | |||||
| param().format == Param::Format::NCHW88_WINOGRAD || | |||||
| param().format == Param::Format::NCHW44_WINOGRAD) { | |||||
| spatial_pos = 2; | spatial_pos = 2; | ||||
| } else if (param().format == Param::Format::NHWC) { | } else if (param().format == Param::Format::NHWC) { | ||||
| spatial_pos = 1; | spatial_pos = 1; | ||||
| @@ -188,7 +191,8 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param( | |||||
| param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT; | param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT; | ||||
| if (param().format == Param::Format::NCHW_WINOGRAD || | if (param().format == Param::Format::NCHW_WINOGRAD || | ||||
| param().format == Param::Format::NCHW88_WINOGRAD) { | |||||
| param().format == Param::Format::NCHW88_WINOGRAD || | |||||
| param().format == Param::Format::NCHW44_WINOGRAD) { | |||||
| size_t flt_start = 0; | size_t flt_start = 0; | ||||
| if (param().sparse == Param::Sparse::GROUP) { | if (param().sparse == Param::Sparse::GROUP) { | ||||
| flt_start = 1; | flt_start = 1; | ||||
| @@ -325,7 +329,7 @@ const char* ConvBiasImpl::get_algorithm_set_name() const { | |||||
| return "F0"; | return "F0"; | ||||
| } | } | ||||
| namespace megdnn{ | |||||
| namespace megdnn { | |||||
| namespace fallback { | namespace fallback { | ||||
| template <typename T> | template <typename T> | ||||
| @@ -342,7 +346,6 @@ const T* ConvBiasImpl::NCBKernParam::src(size_t batch_id, size_t group_pack_id, | |||||
| batch_offset + group_offset + channel_offset); | batch_offset + group_offset + channel_offset); | ||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| const T* ConvBiasImpl::NCBKernParam::filter(size_t group_pack_id, | const T* ConvBiasImpl::NCBKernParam::filter(size_t group_pack_id, | ||||
| size_t pack_group_size) const { | size_t pack_group_size) const { | ||||
| @@ -453,5 +456,4 @@ INST(void) | |||||
| } // namespace fallback | } // namespace fallback | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| @@ -87,7 +88,9 @@ class ConvBias { | |||||
| if (param.filter_meta.format != | if (param.filter_meta.format != | ||||
| param::ConvBias::Format::NCHW_WINOGRAD && | param::ConvBias::Format::NCHW_WINOGRAD && | ||||
| param.filter_meta.format != | param.filter_meta.format != | ||||
| param::ConvBias::Format::NCHW88_WINOGRAD) { | |||||
| param::ConvBias::Format::NCHW88_WINOGRAD && | |||||
| param.filter_meta.format != | |||||
| param::ConvBias::Format::NCHW44_WINOGRAD) { | |||||
| filter_transform_buf_size = Strategy::ALPHA * Strategy::ALPHA * OC * | filter_transform_buf_size = Strategy::ALPHA * Strategy::ALPHA * OC * | ||||
| IC * sizeof(input_filter_compute_type); | IC * sizeof(input_filter_compute_type); | ||||
| } | } | ||||
| @@ -95,7 +98,8 @@ class ConvBias { | |||||
| get_wbundle_compute(param, matmul_algo).total_size_in_bytes() * | get_wbundle_compute(param, matmul_algo).total_size_in_bytes() * | ||||
| nr_threads; | nr_threads; | ||||
| if (param.filter_meta.format == param::ConvBias::Format::NCHW || | if (param.filter_meta.format == param::ConvBias::Format::NCHW || | ||||
| param.filter_meta.format == param::ConvBias::Format::NCHW88) { | |||||
| param.filter_meta.format == param::ConvBias::Format::NCHW88 || | |||||
| param.filter_meta.format == param::ConvBias::Format::NCHW44) { | |||||
| return WorkspaceBundle( | return WorkspaceBundle( | ||||
| nullptr, | nullptr, | ||||
| {winograd_comput_size, filter_transform_buf_size * GROUP}); | {winograd_comput_size, filter_transform_buf_size * GROUP}); | ||||
| @@ -103,7 +107,9 @@ class ConvBias { | |||||
| megdnn_assert(param.filter_meta.format == | megdnn_assert(param.filter_meta.format == | ||||
| param::ConvBias::Format::NCHW_WINOGRAD || | param::ConvBias::Format::NCHW_WINOGRAD || | ||||
| param.filter_meta.format == | param.filter_meta.format == | ||||
| param::ConvBias::Format::NCHW88_WINOGRAD); | |||||
| param::ConvBias::Format::NCHW88_WINOGRAD || | |||||
| param.filter_meta.format == | |||||
| param::ConvBias::Format::NCHW44_WINOGRAD); | |||||
| return WorkspaceBundle(nullptr, {winograd_comput_size}); | return WorkspaceBundle(nullptr, {winograd_comput_size}); | ||||
| } | } | ||||
| } | } | ||||
| @@ -210,11 +216,17 @@ public: | |||||
| reinterpret_cast<input_filter_compute_type*>( | reinterpret_cast<input_filter_compute_type*>( | ||||
| reinterpret_cast<uintptr_t>(bundle_compute.get(2)) + | reinterpret_cast<uintptr_t>(bundle_compute.get(2)) + | ||||
| compute_workspace_size_per_thread * thread_id); | compute_workspace_size_per_thread * thread_id); | ||||
| const stype* filter_ptr = kern_param.filter<stype>(group_id); | const stype* filter_ptr = kern_param.filter<stype>(group_id); | ||||
| size_t oc_start = oc_id, oc_end = oc_id+1; | |||||
| size_t oc_start = oc_id, oc_end = oc_id + 1; | |||||
| if (kern_param.filter_meta.format == param::ConvBias::Format::NCHW88) { | if (kern_param.filter_meta.format == param::ConvBias::Format::NCHW88) { | ||||
| oc_start = 8 * oc_id; | oc_start = 8 * oc_id; | ||||
| oc_end = oc_start + 8; | oc_end = oc_start + 8; | ||||
| } else if (kern_param.filter_meta.format == | |||||
| param::ConvBias::Format::NCHW44) { | |||||
| oc_start = 4 * oc_id; | |||||
| oc_end = oc_start + 4; | |||||
| } | } | ||||
| strategy.filter(filter_ptr, filter_transform_buf, transform_mid_buf, OC, | strategy.filter(filter_ptr, filter_transform_buf, transform_mid_buf, OC, | ||||
| IC, oc_start, oc_end); | IC, oc_start, oc_end); | ||||
| @@ -279,7 +291,8 @@ public: | |||||
| static_cast<const input_filter_compute_type*>( | static_cast<const input_filter_compute_type*>( | ||||
| ncb_param.filter<input_filter_compute_type>(group_id)); | ncb_param.filter<input_filter_compute_type>(group_id)); | ||||
| if (ncb_param.filter_meta.format == param::ConvBias::Format::NCHW || | if (ncb_param.filter_meta.format == param::ConvBias::Format::NCHW || | ||||
| ncb_param.filter_meta.format == param::ConvBias::Format::NCHW88) { | |||||
| ncb_param.filter_meta.format == param::ConvBias::Format::NCHW88 || | |||||
| ncb_param.filter_meta.format == param::ConvBias::Format::NCHW44) { | |||||
| filter_transform_buf = reinterpret_cast<input_filter_compute_type*>( | filter_transform_buf = reinterpret_cast<input_filter_compute_type*>( | ||||
| reinterpret_cast<uintptr_t>(bundle_top.get(1)) + | reinterpret_cast<uintptr_t>(bundle_top.get(1)) + | ||||
| group_id * filter_group_size); | group_id * filter_group_size); | ||||
| @@ -404,14 +417,18 @@ public: | |||||
| param.filter_meta.stride[1] == 1 && | param.filter_meta.stride[1] == 1 && | ||||
| (param.filter_meta.format == param::ConvBias::Format::NCHW || | (param.filter_meta.format == param::ConvBias::Format::NCHW || | ||||
| param.filter_meta.format == param::ConvBias::Format::NCHW88 || | param.filter_meta.format == param::ConvBias::Format::NCHW88 || | ||||
| param.filter_meta.format == param::ConvBias::Format::NCHW44 || | |||||
| param.filter_meta.format == | param.filter_meta.format == | ||||
| param::ConvBias::Format::NCHW_WINOGRAD || | param::ConvBias::Format::NCHW_WINOGRAD || | ||||
| param.filter_meta.format == | param.filter_meta.format == | ||||
| param::ConvBias::Format::NCHW88_WINOGRAD)); | |||||
| param::ConvBias::Format::NCHW88_WINOGRAD || | |||||
| param.filter_meta.format == | |||||
| param::ConvBias::Format::NCHW44_WINOGRAD)); | |||||
| SmallVector<NCBKern> kerns; | SmallVector<NCBKern> kerns; | ||||
| if (param.filter_meta.format == param::ConvBias::Format::NCHW || | if (param.filter_meta.format == param::ConvBias::Format::NCHW || | ||||
| param.filter_meta.format == param::ConvBias::Format::NCHW88) { | |||||
| param.filter_meta.format == param::ConvBias::Format::NCHW88 || | |||||
| param.filter_meta.format == param::ConvBias::Format::NCHW44) { | |||||
| //! probably a gcc bug, lambda requires capturing 'this' to call | //! probably a gcc bug, lambda requires capturing 'this' to call | ||||
| //! static member function | //! static member function | ||||
| auto filter_process_kern = [this, strategy, bundle_top, | auto filter_process_kern = [this, strategy, bundle_top, | ||||
| @@ -426,6 +443,10 @@ public: | |||||
| if (param.filter_meta.format == param::ConvBias::Format::NCHW88) { | if (param.filter_meta.format == param::ConvBias::Format::NCHW88) { | ||||
| megdnn_assert(OC % 8 == 0); | megdnn_assert(OC % 8 == 0); | ||||
| oc_parallelism = OC / 8; | oc_parallelism = OC / 8; | ||||
| } else if (param.filter_meta.format == | |||||
| param::ConvBias::Format::NCHW44) { | |||||
| megdnn_assert(OC % 4 == 0); | |||||
| oc_parallelism = OC / 4; | |||||
| } | } | ||||
| kerns.push_back({filter_process_kern, {GROUP, 1, oc_parallelism}}); | kerns.push_back({filter_process_kern, {GROUP, 1, oc_parallelism}}); | ||||
| } | } | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #include "src/fallback/convolution/opr_impl.h" | #include "src/fallback/convolution/opr_impl.h" | ||||
| @@ -142,7 +143,8 @@ ConvolutionImpl::NCBKernSizeParam ConvolutionImpl::make_ncb_kern_size_param( | |||||
| size_t spatial_pos; | size_t spatial_pos; | ||||
| if (param().format == Param::Format::NCHW88 || | if (param().format == Param::Format::NCHW88 || | ||||
| param().format == Param::Format::NCHW8 || | param().format == Param::Format::NCHW8 || | ||||
| param().format == Param::Format::NCHW4) { | |||||
| param().format == Param::Format::NCHW4 || | |||||
| param().format == Param::Format::NCHW44) { | |||||
| spatial_pos = 2; | spatial_pos = 2; | ||||
| } else if (param().format == Param::Format::NCHW || | } else if (param().format == Param::Format::NCHW || | ||||
| param().format == Param::Format::NCHW_WINOGRAD) { | param().format == Param::Format::NCHW_WINOGRAD) { | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| @@ -145,6 +146,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
| using Format = param::Convolution::Format; | using Format = param::Convolution::Format; | ||||
| if (filter_meta.format == Format::NCHW || | if (filter_meta.format == Format::NCHW || | ||||
| filter_meta.format == Format::NCHW88 || | filter_meta.format == Format::NCHW88 || | ||||
| filter_meta.format == Format::NCHW44 || | |||||
| filter_meta.format == Format::NCHW4 || | filter_meta.format == Format::NCHW4 || | ||||
| filter_meta.format == Format::NCHW8 || | filter_meta.format == Format::NCHW8 || | ||||
| filter_meta.format == Format::NCHW32) { | filter_meta.format == Format::NCHW32) { | ||||
| @@ -171,7 +173,8 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
| OW = dst.layout.shape[spatial_start + 1]; | OW = dst.layout.shape[spatial_start + 1]; | ||||
| if (filter_meta.format == Format::NCHW4 || | if (filter_meta.format == Format::NCHW4 || | ||||
| filter_meta.format == Format::CHWN4) { | |||||
| filter_meta.format == Format::CHWN4 || | |||||
| filter_meta.format == Format::NCHW44) { | |||||
| OC *= 4; | OC *= 4; | ||||
| } else if (filter_meta.format == Format::NCHW8 || | } else if (filter_meta.format == Format::NCHW8 || | ||||
| filter_meta.format == Format::NCHW88) { | filter_meta.format == Format::NCHW88) { | ||||
| @@ -216,6 +219,26 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
| FS_G = FS_OC * filter_meta.ocpg / 8; | FS_G = FS_OC * filter_meta.ocpg / 8; | ||||
| } | } | ||||
| } | } | ||||
| } else if (filter_meta.format == Format::NCHW44) { | |||||
| if (filter_meta.group > 1 && filter_meta.icpg == 1 && | |||||
| src.layout.ndim == 5 && filter_meta.ocpg == 1) { | |||||
| FS_SPATIAL = 4; | |||||
| FS_IC = FH * FW * FS_SPATIAL; | |||||
| FS_OC = FS_IC * filter_meta.icpg; | |||||
| FS_G = FS_OC * filter_meta.ocpg; | |||||
| } else { | |||||
| if (src.layout.ndim == 4 && dst.layout.ndim == 5) { | |||||
| FS_IC = 4; | |||||
| FS_SPATIAL = filter_meta.icpg * FS_IC; | |||||
| FS_OC = FH * FW * FS_SPATIAL; | |||||
| FS_G = FS_OC * filter_meta.ocpg / 4; | |||||
| } else { | |||||
| FS_SPATIAL = 4 * 4; | |||||
| FS_IC = FH * FW * FS_SPATIAL; | |||||
| FS_OC = FS_IC * filter_meta.icpg / 4; | |||||
| FS_G = FS_OC * filter_meta.ocpg / 4; | |||||
| } | |||||
| } | |||||
| } else { | } else { | ||||
| // g, oc, fh, fw, ic | // g, oc, fh, fw, ic | ||||
| megdnn_assert(filter_meta.format == Format::NHWC); | megdnn_assert(filter_meta.format == Format::NHWC); | ||||
| @@ -259,6 +282,16 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
| h * layout.stride[2] + w * layout.stride[3] + | h * layout.stride[2] + w * layout.stride[3] + | ||||
| (c & 0b111) * layout.stride[4]; | (c & 0b111) * layout.stride[4]; | ||||
| } | } | ||||
| } else if (filter_meta.format == Format::NCHW44) { | |||||
| if (filter_meta.format == Format::NCHW44 && !is_output && | |||||
| src.layout.ndim == 4) { | |||||
| return n * layout.stride[0] + c * layout.stride[1] + | |||||
| h * layout.stride[2] + w * layout.stride[3]; | |||||
| } else { | |||||
| return n * layout.stride[0] + (c / 4) * layout.stride[1] + | |||||
| h * layout.stride[2] + w * layout.stride[3] + | |||||
| (c % 4) * layout.stride[4]; | |||||
| } | |||||
| } else if (filter_meta.format == Format::NCHW32) { | } else if (filter_meta.format == Format::NCHW32) { | ||||
| return n * layout.stride[0] + (c >> 5) * layout.stride[1] + | return n * layout.stride[0] + (c >> 5) * layout.stride[1] + | ||||
| h * layout.stride[2] + w * layout.stride[3] + | h * layout.stride[2] + w * layout.stride[3] + | ||||
| @@ -315,6 +348,27 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||||
| megdnn_assert( | megdnn_assert( | ||||
| 0, "nchw88 naive not support this input and output\n"); | 0, "nchw88 naive not support this input and output\n"); | ||||
| } | } | ||||
| } else if (filter_meta.format == Format::NCHW44) { | |||||
| if (src.layout.ndim == 4) { | |||||
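| // input is nchw (src is 4-dim); NOTE(review): "ic < 8" looks copied from the NCHW88 branch, presumably ic < 4 here -- confirm | |||||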
| return gc_out.cur_grp * FS_G + gc_out.cur_off / 4 * FS_OC + | |||||
| (fh * FW + fw) * FS_SPATIAL + (ic - ic0) * FS_IC + | |||||
| gc_out.cur_off % 4; | |||||
| } else if (filter_meta.group > 1 && filter_meta.icpg == 1 && | |||||
| filter_meta.ocpg == 1 && src.layout.ndim == 5) { | |||||
| // dw case | |||||
| return gc_out.cur_grp / 4 * FS_G + gc_out.cur_off * FS_OC + | |||||
| (ic - ic0) * FS_IC + (fh * FW + fw) * FS_SPATIAL + | |||||
| gc_out.cur_grp % 4; | |||||
| } else if (src.layout.ndim == 5) { | |||||
| // normal case | |||||
| return gc_out.cur_grp * FS_G + gc_out.cur_off / 4 * FS_OC + | |||||
| (ic - ic0) / 4 * FS_IC + (fh * FW + fw) * FS_SPATIAL + | |||||
| ((ic - ic0) % 4) * 4 + gc_out.cur_off % 4; | |||||
| } else { | |||||
| megdnn_assert( | |||||
| 0, "nchw44 naive not support this input and output\n"); | |||||
| } | |||||
| } else { | } else { | ||||
| return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | ||||
| (ic - ic0) * FS_IC + (fh * FW + fw) * FS_SPATIAL; | (ic - ic0) * FS_IC + (fh * FW + fw) * FS_SPATIAL; | ||||
| @@ -504,6 +558,7 @@ void forward(_megdnn_tensor_in src, const ftype* fptr, _megdnn_tensor_out dst, | |||||
| megdnn_assert(filter_meta.format == param::Convolution::Format::NCHW || | megdnn_assert(filter_meta.format == param::Convolution::Format::NCHW || | ||||
| filter_meta.format == param::Convolution::Format::NHWC || | filter_meta.format == param::Convolution::Format::NHWC || | ||||
| filter_meta.format == param::Convolution::Format::NCHW88 || | filter_meta.format == param::Convolution::Format::NCHW88 || | ||||
| filter_meta.format == param::Convolution::Format::NCHW44 || | |||||
| filter_meta.format == param::Convolution::Format::NCHW4); | filter_meta.format == param::Convolution::Format::NCHW4); | ||||
| compute2d<stype, ftype, dtype, comp_type, StrategyFwd>( | compute2d<stype, ftype, dtype, comp_type, StrategyFwd>( | ||||
| src, const_cast<ftype*>(fptr), dst, filter_meta); | src, const_cast<ftype*>(fptr), dst, filter_meta); | ||||
| @@ -557,6 +612,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
| switch (filter_meta.format) { | switch (filter_meta.format) { | ||||
| case param::Convolution::Format::NCHW: | case param::Convolution::Format::NCHW: | ||||
| case param::Convolution::Format::NCHW88: | case param::Convolution::Format::NCHW88: | ||||
| case param::Convolution::Format::NCHW44: | |||||
| case param::Convolution::Format::NHWC: | case param::Convolution::Format::NHWC: | ||||
| case param::Convolution::Format::NCHW4: | case param::Convolution::Format::NCHW4: | ||||
| case param::Convolution::Format::NCHW8: | case param::Convolution::Format::NCHW8: | ||||
| @@ -633,6 +689,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||||
| } \ | } \ | ||||
| } \ | } \ | ||||
| } while (0) | } while (0) | ||||
| case Format::NCHW44: | |||||
| case Format::NCHW4: { | case Format::NCHW4: { | ||||
| BIAS_ADD_NCHWx(4); | BIAS_ADD_NCHWx(4); | ||||
| break; | break; | ||||
| @@ -6,7 +6,8 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #include "src/naive/pooling/opr_impl.h" | #include "src/naive/pooling/opr_impl.h" | ||||
| @@ -168,6 +169,13 @@ struct NCHW88IdxGetter { | |||||
| return id; | return id; | ||||
| } | } | ||||
| }; | }; | ||||
| struct NCHW44IdxGetter { | |||||
| static size_t get_idx(size_t n, size_t c, size_t h, size_t w, size_t, | |||||
| size_t C, size_t H, size_t W) { | |||||
| size_t id = (((n * (C >> 2) + (c >> 2)) * H + h) * W + w) * 4 + (c % 4); | |||||
| return id; | |||||
| } | |||||
| }; | |||||
| struct CHWN4IdxGetter { | struct CHWN4IdxGetter { | ||||
| static size_t get_idx(size_t n, size_t c, size_t h, size_t w, size_t N, | static size_t get_idx(size_t n, size_t c, size_t h, size_t w, size_t N, | ||||
| @@ -375,6 +383,7 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| if (param().format == Param::Format::NCHW || | if (param().format == Param::Format::NCHW || | ||||
| param().format == Param::Format::NCHW4 || | param().format == Param::Format::NCHW4 || | ||||
| param().format == Param::Format::NCHW88 || | param().format == Param::Format::NCHW88 || | ||||
| param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::NCHW32) { | param().format == Param::Format::NCHW32) { | ||||
| c_pos = 1; | c_pos = 1; | ||||
| spatial_pos = 2; | spatial_pos = 2; | ||||
| @@ -401,6 +410,7 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| OW = dst.layout.shape[spatial_pos + 2]; | OW = dst.layout.shape[spatial_pos + 2]; | ||||
| } | } | ||||
| if (param().format == Param::Format::NCHW4 || | if (param().format == Param::Format::NCHW4 || | ||||
| param().format == Param::Format::NCHW44 || | |||||
| param().format == Param::Format::CHWN4) { | param().format == Param::Format::CHWN4) { | ||||
| C *= 4; | C *= 4; | ||||
| } | } | ||||
| @@ -437,6 +447,9 @@ void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| case Param::Format::NCHW88: \ | case Param::Format::NCHW88: \ | ||||
| DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW88IdxGetter); \ | DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW88IdxGetter); \ | ||||
| break; \ | break; \ | ||||
| case Param::Format::NCHW44: \ | |||||
| DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW44IdxGetter); \ | |||||
| break; \ | |||||
| case Param::Format::NCHW32: \ | case Param::Format::NCHW32: \ | ||||
| DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW32IdxGetter); \ | DISPATCH_WITH_POOLER_AND_IDX_GETTER(Pooler, NCHW32IdxGetter); \ | ||||
| break; \ | break; \ | ||||
| @@ -6,13 +6,14 @@ | |||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | */ | ||||
| #include "test/naive/fixture.h" | |||||
| #include "megdnn/oprs/nn.h" | #include "megdnn/oprs/nn.h" | ||||
| #include "test/common/checker.h" | #include "test/common/checker.h" | ||||
| #include "test/common/workspace_wrapper.h" | #include "test/common/workspace_wrapper.h" | ||||
| #include "test/naive/fixture.h" | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace test; | using namespace test; | ||||
| @@ -35,55 +36,39 @@ private: | |||||
| } // namespace | } // namespace | ||||
| TEST_F(NAIVE, CONV_BIAS_QUANTIZED8x8x32) { | TEST_F(NAIVE, CONV_BIAS_QUANTIZED8x8x32) { | ||||
| Checker<ConvBias> checker(handle(), /* check_dispatch */false); | |||||
| Checker<ConvBias> checker(handle(), /* check_dispatch */ false); | |||||
| ConvBias::Param param; | ConvBias::Param param; | ||||
| param.format = ConvBias::Param::Format::NCHW; | param.format = ConvBias::Param::Format::NCHW; | ||||
| checker.set_param(param).exect( | checker.set_param(param).exect( | ||||
| Testcase{ | |||||
| TensorValue({1, 1, 4, 4}, dtype::QuantizedS8(0.1f), | |||||
| {90-128, 136-128, 85-128, 204-128, | |||||
| 48-128, 9-128, 226-128, 25-128, | |||||
| 118-128, 109-128, 87-128, 132-128, | |||||
| 104-128, 163-128, 25-128, 90-128}), | |||||
| TensorValue({3, 1, 3, 3}, dtype::QuantizedS8(0.2f), | |||||
| {153-124, 170-124, 102-124, | |||||
| 103-124, 23-124, 213-124, | |||||
| 116-124, 195-124, 191-124, | |||||
| 44-124, 50-124, 247-124, | |||||
| 172-124, 42-124, 32-124, | |||||
| 233-124, 163-124, 247-124, | |||||
| 120-124, 241-124, 209-124, | |||||
| 83-124, 201-124, 115-124, | |||||
| 32-124, 140-124, 147-124}), | |||||
| TensorValue({1, 3, 1, 1}, dtype::QuantizedS32(0.02f), | |||||
| {0, 0, 0}), | |||||
| TensorValue({1, 3, 2, 2}, dtype::QuantizedS32(0.3f), | |||||
| {1234, 0, | |||||
| 0, 0, | |||||
| 0, 0, | |||||
| 0, 0, | |||||
| 0, -234, | |||||
| 0, 0}), | |||||
| {}}, | |||||
| Testcase{ | |||||
| {}, | |||||
| {}, | |||||
| {}, | |||||
| {}, | |||||
| TensorValue({1, 3, 2, 2}, dtype::QuantizedS32(0.1f * 0.2f), | |||||
| {37127, -22475, | |||||
| -15694, -1920, | |||||
| -12813, 4440, | |||||
| 18190, -13195, | |||||
| -9659, 12423, | |||||
| -5558, -4969})}); | |||||
| Testcase{TensorValue({1, 1, 4, 4}, dtype::QuantizedS8(0.1f), | |||||
| {90 - 128, 136 - 128, 85 - 128, 204 - 128, | |||||
| 48 - 128, 9 - 128, 226 - 128, 25 - 128, | |||||
| 118 - 128, 109 - 128, 87 - 128, 132 - 128, | |||||
| 104 - 128, 163 - 128, 25 - 128, 90 - 128}), | |||||
| TensorValue({3, 1, 3, 3}, dtype::QuantizedS8(0.2f), | |||||
| {153 - 124, 170 - 124, 102 - 124, 103 - 124, | |||||
| 23 - 124, 213 - 124, 116 - 124, 195 - 124, | |||||
| 191 - 124, 44 - 124, 50 - 124, 247 - 124, | |||||
| 172 - 124, 42 - 124, 32 - 124, 233 - 124, | |||||
| 163 - 124, 247 - 124, 120 - 124, 241 - 124, | |||||
| 209 - 124, 83 - 124, 201 - 124, 115 - 124, | |||||
| 32 - 124, 140 - 124, 147 - 124}), | |||||
| TensorValue({1, 3, 1, 1}, dtype::QuantizedS32(0.02f), | |||||
| {0, 0, 0}), | |||||
| TensorValue({1, 3, 2, 2}, dtype::QuantizedS32(0.3f), | |||||
| {1234, 0, 0, 0, 0, 0, 0, 0, 0, -234, 0, 0}), | |||||
| {}}, | |||||
| Testcase{{}, | |||||
| {}, | |||||
| {}, | |||||
| {}, | |||||
| TensorValue({1, 3, 2, 2}, dtype::QuantizedS32(0.1f * 0.2f), | |||||
| {37127, -22475, -15694, -1920, | |||||
| -12813, 4440, 18190, -13195, | |||||
| -9659, 12423, -5558, -4969})}); | |||||
| } | } | ||||
| TEST_F(NAIVE, CONV_BIAS_QUANTIZED4x4x32) { | TEST_F(NAIVE, CONV_BIAS_QUANTIZED4x4x32) { | ||||
| @@ -175,10 +160,8 @@ TEST_F(NAIVE, CONV_BIAS_QUANTIZED4x4x32) { | |||||
| {0, 0, 0, 0, 0, 0, 0, 0}), | {0, 0, 0, 0, 0, 0, 0, 0}), | ||||
| TensorValue( | TensorValue( | ||||
| {1, 1, 2, 2, 8}, dtype::QuantizedS32(0.3f), | {1, 1, 2, 2, 8}, dtype::QuantizedS32(0.3f), | ||||
| {0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, -87, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0}), | |||||
| {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, -87, 0, 0, 0, 0, 0, 0, 0, 0, 0}), | |||||
| {}}, | {}}, | ||||
| Testcase{ | Testcase{ | ||||
| {}, | {}, | ||||
| @@ -316,8 +299,221 @@ TEST_F(NAIVE, CONV_BIAS_QUANTIZED8x8x32_NCHW32) { | |||||
| TensorNDArray{src_ts_32.tensornd(), | TensorNDArray{src_ts_32.tensornd(), | ||||
| filter_ts_32.tensornd(), | filter_ts_32.tensornd(), | ||||
| bias_ts_32.tensornd(), | bias_ts_32.tensornd(), | ||||
| z_ts_32.tensornd(), {}}, | |||||
| z_ts_32.tensornd(), | |||||
| {}}, | |||||
| TensorNDArray{{}, {}, {}, {}, dst_ts_32.tensornd()}); | TensorNDArray{{}, {}, {}, {}, dst_ts_32.tensornd()}); | ||||
| } | } | ||||
| TEST_F(NAIVE, CONV_BIAS_NCHW44) { | |||||
| Checker<ConvBias> checker(handle(), /* check_dispatch */ false); | |||||
| ConvBias::Param param; | |||||
| param.format = ConvBias::Param::Format::NCHW44; | |||||
| size_t n = 1; | |||||
| size_t ic = 4; | |||||
| size_t oc = 8; | |||||
| size_t h = 2; | |||||
| size_t w = 2; | |||||
| size_t filter_size = 3; | |||||
| size_t pad = 1; | |||||
| auto src_tensor_shape = TensorShape{n, ic / 4, h, w, 4}; | |||||
| auto weight_tensor_shape = | |||||
| TensorShape{oc / 4, ic / 4, filter_size, filter_size, 4, 4}; | |||||
| auto bias_tensor_shape = TensorShape{1, oc / 4, 1, 1, 4}; | |||||
| param.pad_h = pad; | |||||
| param.pad_w = pad; | |||||
| UniformIntRNG rng{-127, 127}; | |||||
| checker.set_dtype(0, dtype::Float32()) | |||||
| .set_dtype(1, dtype::Float32()) | |||||
| .set_dtype(2, dtype::Float32()) | |||||
| .set_dtype(4, dtype::Float32()) | |||||
| .set_rng(0, &rng) | |||||
| .set_rng(1, &rng) | |||||
| .set_rng(2, &rng) | |||||
| .set_epsilon(1e-3) | |||||
| .set_param(param) | |||||
| .execs({src_tensor_shape, | |||||
| weight_tensor_shape, | |||||
| bias_tensor_shape, | |||||
| {}, | |||||
| {}}); | |||||
| checker.set_dtype(0, dtype::QuantizedS8(2.f)) | |||||
| .set_dtype(1, dtype::QuantizedS8(3.f)) | |||||
| .set_dtype(2, dtype::QuantizedS32(6.f)) | |||||
| .set_dtype(4, dtype::QuantizedS32(6.f)) | |||||
| .set_rng(0, &rng) | |||||
| .set_rng(1, &rng) | |||||
| .set_rng(2, &rng) | |||||
| .set_epsilon(1e-3) | |||||
| .set_param(param) | |||||
| .execs({src_tensor_shape, | |||||
| weight_tensor_shape, | |||||
| bias_tensor_shape, | |||||
| {}, | |||||
| {}}); | |||||
| { | |||||
| // test normal conv with Float32 tensors | |||||
| ConvBias::Param param; | |||||
| param.format = ConvBias::Param::Format::NCHW44; | |||||
| param.sparse = ConvBias::Param::Sparse::DENSE; | |||||
| param.pad_h = 1; | |||||
| param.pad_w = 1; | |||||
| checker.set_param(param).exect( | |||||
| Testcase{TensorValue({1, 1, 2, 2, 4}, dtype::Float32(), | |||||
| {7, 2, 2, 1, 7, 5, 6, 3, 1, 2, 8, 3, 7, 7, | |||||
| 6, 4}), | |||||
| TensorValue( | |||||
| {1, 1, 3, 3, 4, 4}, dtype::Float32(), | |||||
| {3, 5, 5, 2, 0, 1, 4, 8, 3, 5, 0, 7, 1, 7, 0, | |||||
| 7, 6, 4, 7, 7, 5, 2, 2, 4, 7, 6, 6, 3, 3, 2, | |||||
| 2, 8, 5, 0, 4, 4, 0, 5, 1, 0, 0, 4, 8, 4, 7, | |||||
| 7, 2, 0, 4, 8, 7, 3, 6, 2, 3, 0, 0, 6, 4, 4, | |||||
| 1, 4, 3, 8, 8, 8, 7, 2, 2, 5, 5, 1, 3, 2, 8, | |||||
| 1, 7, 0, 2, 7, 1, 6, 1, 5, 0, 6, 3, 0, 2, 4, | |||||
| 1, 1, 4, 2, 7, 5, 7, 8, 4, 5, 5, 7, 0, 3, 3, | |||||
| 2, 8, 6, 0, 1, 4, 6, 6, 6, 0, 1, 2, 4, 4, 1, | |||||
| 1, 7, 8, 2, 5, 2, 8, 3, 8, 3, 5, 0, 6, 3, 4, | |||||
| 3, 3, 7, 2, 8, 1, 1, 1, 4}), | |||||
| TensorValue({1, 1, 1, 1, 4}, dtype::Float32(), | |||||
| {7, 2, 8, 1}), | |||||
| TensorValue({1, 1, 2, 2, 4}, dtype::Float32(), | |||||
| {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0}), | |||||
| {}}, | |||||
| Testcase{ | |||||
| {}, | |||||
| {}, | |||||
| {}, | |||||
| {}, | |||||
| TensorValue({1, 1, 2, 2, 4}, dtype::Float32(), | |||||
| {264, 338, 309, 195, 276, 332, 390, 199, | |||||
| 224, 268, 311, 218, 288, 311, 346, 277})}); | |||||
| } | |||||
| { | |||||
| // test dw conv | |||||
| ConvBias::Param param; | |||||
| param.format = ConvBias::Param::Format::NCHW44; | |||||
| param.sparse = ConvBias::Param::Sparse::GROUP; | |||||
| param.pad_h = 1; | |||||
| param.pad_w = 1; | |||||
| checker.set_param(param).exect( | |||||
| Testcase{TensorValue({1, 1, 2, 2, 4}, dtype::Float32(), | |||||
| {5, 8, 3, 2, 4, 6, 1, 5, 0, 8, 2, 6, 8, 6, | |||||
| 5, 7}), | |||||
| TensorValue({1, 1, 1, 3, 3, 4}, dtype::Float32(), | |||||
| {3, 0, 3, 1, 6, 5, 7, 3, 5, 0, 0, 7, | |||||
| 4, 6, 0, 1, 8, 2, 3, 7, 1, 0, 2, 4, | |||||
| 7, 5, 3, 0, 6, 2, 1, 5, 8, 6, 3, 1}), | |||||
| TensorValue({1, 1, 1, 1, 4}, dtype::Float32(), | |||||
| {4, 3, 5, 6}), | |||||
| TensorValue({1, 1, 2, 2, 4}, dtype::Float32(), | |||||
| {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0}), | |||||
| {}}, | |||||
| Testcase{{}, | |||||
| {}, | |||||
| {}, | |||||
| {}, | |||||
| TensorValue({1, 1, 2, 2, 4}, dtype::Float32(), | |||||
| {112, 71, 33, 77, 104, 115, 19, 78, 62, 59, | |||||
| 42, 117, 107, 93, 36, 78})}); | |||||
| } | |||||
| { | |||||
| // test group conv | |||||
| ConvBias::Param param; | |||||
| param.format = ConvBias::Param::Format::NCHW44; | |||||
| param.sparse = ConvBias::Param::Sparse::GROUP; | |||||
| param.pad_h = 1; | |||||
| param.pad_w = 1; | |||||
| checker.set_param(param).exect( | |||||
| Testcase{TensorValue({1, 2, 2, 2, 4}, dtype::Float32(), | |||||
| {6, 3, 2, 7, 7, 6, 4, 5, 8, 6, 3, | |||||
| 1, 1, 2, 8, 3, 1, 0, 6, 1, 3, 3, | |||||
| 6, 0, 0, 5, 6, 7, 2, 2, 4, 4}), | |||||
| TensorValue( | |||||
| {2, 1, 1, 3, 3, 4, 4}, dtype::Float32(), | |||||
| {3, 5, 5, 2, 0, 1, 4, 8, 3, 5, 0, 7, 1, 7, 0, | |||||
| 7, 6, 4, 7, 7, 5, 2, 2, 4, 7, 6, 6, 3, 3, 2, | |||||
| 2, 8, 5, 0, 4, 4, 0, 5, 1, 0, 0, 4, 8, 4, 7, | |||||
| 7, 2, 0, 4, 8, 7, 3, 6, 2, 3, 0, 0, 6, 4, 4, | |||||
| 1, 4, 3, 8, 8, 8, 7, 2, 2, 5, 5, 1, 3, 2, 8, | |||||
| 1, 7, 0, 2, 7, 1, 6, 1, 5, 0, 6, 3, 0, 2, 4, | |||||
| 1, 1, 4, 2, 7, 5, 7, 8, 4, 5, 5, 7, 0, 3, 3, | |||||
| 2, 8, 6, 0, 1, 4, 6, 6, 6, 0, 1, 2, 4, 4, 1, | |||||
| 1, 7, 8, 2, 5, 2, 8, 3, 8, 3, 5, 0, 6, 3, 4, | |||||
| 3, 3, 7, 2, 8, 1, 1, 1, 4, 7, 4, 5, 0, 6, 8, | |||||
| 7, 4, 8, 1, 3, 5, 3, 0, 0, 3, 7, 7, 7, 3, 8, | |||||
| 1, 2, 0, 1, 1, 2, 1, 3, 0, 0, 1, 1, 3, 0, 5, | |||||
| 6, 3, 0, 5, 4, 1, 4, 7, 0, 2, 1, 6, 7, 8, 0, | |||||
| 2, 1, 6, 7, 6, 3, 2, 7, 6, 5, 1, 1, 1, 2, 4, | |||||
| 6, 3, 3, 8, 0, 7, 1, 3, 7, 3, 2, 2, 4, 3, 5, | |||||
| 5, 6, 3, 3, 1, 2, 3, 0, 4, 0, 3, 3, 5, 5, 5, | |||||
| 2, 3, 1, 5, 4, 5, 8, 1, 7, 2, 1, 0, 1, 8, 2, | |||||
| 6, 7, 8, 4, 4, 7, 8, 4, 5, 8, 1, 1, 0, 7, 8, | |||||
| 4, 2, 2, 8, 6, 5, 2, 4, 8, 4, 0, 4, 0, 2, 1, | |||||
| 7, 1, 6}), | |||||
| TensorValue({1, 2, 1, 1, 4}, dtype::Float32(), | |||||
| {1, 8, 5, 6, 2, 8, 7, 7}), | |||||
| TensorValue({1, 2, 2, 2, 4}, dtype::Float32(), | |||||
| {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}), | |||||
| {}}, | |||||
| Testcase{ | |||||
| {}, | |||||
| {}, | |||||
| {}, | |||||
| {}, | |||||
| TensorValue({1, 2, 2, 2, 4}, dtype::Float32(), | |||||
| {260, 342, 244, 241, 293, 385, 362, 257, | |||||
| 278, 301, 303, 226, 273, 306, 318, 307, | |||||
| 180, 244, 169, 156, 210, 244, 206, 167, | |||||
| 126, 165, 156, 207, 191, 141, 209, 172})}); | |||||
| } | |||||
| { | |||||
| // test normal conv with Int8 src/filter and Int32 bias/z/dst | |||||
| ConvBias::Param param; | |||||
| param.format = ConvBias::Param::Format::NCHW44; | |||||
| param.sparse = ConvBias::Param::Sparse::DENSE; | |||||
| param.pad_h = 1; | |||||
| param.pad_w = 1; | |||||
| checker.set_param(param).exect( | |||||
| Testcase{TensorValue({1, 1, 2, 2, 4}, dtype::Int8(), | |||||
| {7, 2, 2, 1, 7, 5, 6, 3, 1, 2, 8, 3, 7, 7, | |||||
| 6, 4}), | |||||
| TensorValue( | |||||
| {1, 1, 3, 3, 4, 4}, dtype::Int8(), | |||||
| {3, 5, 5, 2, 0, 1, 4, 8, 3, 5, 0, 7, 1, 7, 0, | |||||
| 7, 6, 4, 7, 7, 5, 2, 2, 4, 7, 6, 6, 3, 3, 2, | |||||
| 2, 8, 5, 0, 4, 4, 0, 5, 1, 0, 0, 4, 8, 4, 7, | |||||
| 7, 2, 0, 4, 8, 7, 3, 6, 2, 3, 0, 0, 6, 4, 4, | |||||
| 1, 4, 3, 8, 8, 8, 7, 2, 2, 5, 5, 1, 3, 2, 8, | |||||
| 1, 7, 0, 2, 7, 1, 6, 1, 5, 0, 6, 3, 0, 2, 4, | |||||
| 1, 1, 4, 2, 7, 5, 7, 8, 4, 5, 5, 7, 0, 3, 3, | |||||
| 2, 8, 6, 0, 1, 4, 6, 6, 6, 0, 1, 2, 4, 4, 1, | |||||
| 1, 7, 8, 2, 5, 2, 8, 3, 8, 3, 5, 0, 6, 3, 4, | |||||
| 3, 3, 7, 2, 8, 1, 1, 1, 4}), | |||||
| TensorValue({1, 1, 1, 1, 4}, dtype::Int32(), | |||||
| {7, 2, 8, 1}), | |||||
| TensorValue({1, 1, 2, 2, 4}, dtype::Int32(), | |||||
| {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |||||
| 0, 0}), | |||||
| {}}, | |||||
| Testcase{ | |||||
| {}, | |||||
| {}, | |||||
| {}, | |||||
| {}, | |||||
| TensorValue({1, 1, 2, 2, 4}, dtype::Int32(), | |||||
| {264, 338, 309, 195, 276, 332, 390, 199, | |||||
| 224, 268, 311, 218, 288, 311, 346, 277})}); | |||||
| } | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||