GitOrigin-RevId: 1c65ba87d7
tags/v1.5.0
| @@ -330,6 +330,7 @@ bool check_bias_share_in_channel(const TensorLayout& bias, | |||
| } else if (format == param::ConvBias::Format::NCHW4 || | |||
| format == param::ConvBias::Format::NCHW8 || | |||
| format == param::ConvBias::Format::NCHW32 || | |||
| format == param::ConvBias::Format::NCHW64 || | |||
| format == param::ConvBias::Format::NCHW4_NCHW32 || | |||
| format == param::ConvBias::Format::NCHW32_NCHW4) { | |||
| share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[2] == 1 && | |||
| @@ -559,7 +559,7 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src, | |||
| supported_dst_dtype.push_back( | |||
| dtype::QuantizedS8(src.param<dtype::QuantizedS32>().scale / | |||
| filter.param<dtype::QuantizedS8>().scale)); | |||
| } else { | |||
| }else { | |||
| megdnn_throw(ssprintf("unsupported input / filter DType: %s x %s", | |||
| src.name(), filter.name())); | |||
| } | |||
| @@ -488,10 +488,6 @@ void LowbitsAlignedTensorFormatBase::assert_valid( | |||
| "bad stride:%s, %zu", layout.to_string().c_str(), | |||
| layout.stride[i]); | |||
| } | |||
| /// FIXME | |||
| if (layout.ndim == 0) { | |||
| printf("%s\n", layout.to_string().c_str()); | |||
| } | |||
| megdnn_assert(layout.ndim == 0 || has_dim_unity_stride, | |||
| "innermost dim not contiguous"); | |||
| } | |||
| @@ -553,7 +549,7 @@ bool LowbitsAlignedTensorFormatBase::is_contiguous_spec( | |||
| if (layout.shape[i] != 1 && layout.stride[i] != expected) | |||
| return false; | |||
| auto multiplier = layout.shape[i]; | |||
| if (i == layout.ndim - 1) | |||
| if (i == static_cast<int>(layout.ndim) - 1) | |||
| multiplier = round_up(multiplier, m_align_size_in_elements); | |||
| expected *= multiplier; | |||
| } | |||
| @@ -67,6 +67,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( | |||
| } | |||
| if (param.format == param::ConvBias::Format::NCHW8 || | |||
| param.format == param::ConvBias::Format::NCHW64 || | |||
| param.format == param::ConvBias::Format::CHWN4) | |||
| return false; | |||
| if (param.format == param::ConvBias::Format::NCHW32) { | |||
| @@ -6,18 +6,19 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/naive/conv_bias/opr_impl.h" | |||
| #include "src/naive/convolution/helper.h" | |||
| #include <cstring> | |||
| #include "megdnn/dtype.h" | |||
| #include "src/common/conv_bias.h" | |||
| #include "src/common/opr_delegate.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| #include "src/naive/lowbit_utils.h" | |||
| #include "src/common/conv_bias.h" | |||
| #include "src/common/opr_delegate.h" | |||
| #include "midout.h" | |||
| MIDOUT_DECL(megdnn_naive_conv_bias_fwd) | |||
| @@ -32,7 +33,7 @@ void handle_z_inp_and_activation_naive( | |||
| const TensorND& conv_bias_tensor, const TensorND& z_tensor, | |||
| const TensorND& dst_tensor, dt_byte* workspace_ptr) { | |||
| auto res = dst_tensor, z_float = z_tensor; | |||
| //!create naive inplace handle | |||
| //! create naive inplace handle | |||
| auto handle = inplace_cpu_handle(2); | |||
| if (z_tensor.layout.ndim > 0 && | |||
| z_tensor.layout.dtype.category() != DTypeCategory::FLOAT) { | |||
| @@ -121,6 +122,7 @@ void forward_bias<dt_quint4, dt_quint4, dt_qint32, dt_qint32>( | |||
| auto ret = layout; | |||
| auto param = layout.dtype.param<dtype::Quantized4Asymm>(); | |||
| ret.dtype = dtype::Quantized8Asymm(param.scale, param.zero_point); | |||
| ret.format = TensorFormat(ret.dtype); | |||
| return ret; | |||
| }; | |||
| TensorND new_src = {workspace_ptr, convert_layout(src.layout)}; | |||
| @@ -134,6 +136,29 @@ void forward_bias<dt_quint4, dt_quint4, dt_qint32, dt_qint32>( | |||
| forward_bias<dt_quint8, dt_quint8, dt_qint32, dt_qint32>( | |||
| new_src, new_flt, bias, dst, nullptr, new_filter_meta); | |||
| } | |||
| template <> | |||
| void forward_bias<dt_qint4, dt_qint4, dt_qint32, dt_qint32>( | |||
| _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias, | |||
| _megdnn_tensor_out dst, dt_byte* workspace_ptr, | |||
| const ConvBiasForward::CanonizedFilterMeta& filter_meta) { | |||
| auto convert_layout = [](const TensorLayout& layout) { | |||
| auto ret = layout; | |||
| auto param = layout.dtype.param<dtype::QuantizedS4>(); | |||
| ret.dtype = dtype::QuantizedS8(param.scale); | |||
| ret.format = TensorFormat(ret.dtype); | |||
| return ret; | |||
| }; | |||
| TensorND new_src = {workspace_ptr, convert_layout(src.layout)}; | |||
| TensorND new_flt = {workspace_ptr + new_src.layout.span().dist_byte(), | |||
| convert_layout(filter.layout)}; | |||
| int4_to_int8(src, new_src); | |||
| int4_to_int8(filter, new_flt); | |||
| auto new_filter_meta = filter_meta; | |||
| new_filter_meta.dtype = new_flt.layout.dtype; | |||
| forward_bias<dt_qint8, dt_qint8, dt_qint32, dt_qint32>( | |||
| new_src, new_flt, bias, dst, nullptr, new_filter_meta); | |||
| } | |||
| } // namespace convolution | |||
| size_t ConvBiasForwardImpl::get_workspace_in_bytes(const TensorLayout& src, | |||
| @@ -150,15 +175,17 @@ size_t ConvBiasForwardImpl::get_workspace_in_bytes(const TensorLayout& src, | |||
| float_workspace_size = | |||
| 2 * TensorLayout{z, dtype::Float32()}.span().dist_byte(); | |||
| } | |||
| if ((src.dtype.enumv() == DTypeEnum::Quantized4Asymm || | |||
| src.dtype.enumv() == DTypeEnum::QuantizedS4) && | |||
| bias.dtype.enumv() == DTypeEnum::QuantizedS32) { | |||
| float_workspace_size += | |||
| (src.total_nr_elems() + flt.total_nr_elems()) * sizeof(uint8_t); | |||
| } | |||
| if (bias.dtype.enumv() != dst.dtype.enumv()) { | |||
| return float_workspace_size + | |||
| TensorLayout{dst, bias.dtype}.span().dist_byte(); | |||
| } else if (src.dtype.enumv() == DTypeEnum::Quantized4Asymm && | |||
| dst.dtype.enumv() == DTypeEnum::QuantizedS32) { | |||
| return float_workspace_size + | |||
| (src.span().dist_elem() + flt.span().dist_elem()) * | |||
| sizeof(uint8_t); | |||
| float_workspace_size += | |||
| TensorLayout{dst, bias.dtype}.span().dist_byte(); | |||
| } | |||
| return float_workspace_size; | |||
| } | |||
| @@ -169,7 +196,7 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
| const PreprocessedFilter* preprocessed_filter, | |||
| _megdnn_workspace workspace) { | |||
| MIDOUT_BEGIN(megdnn_naive_conv_bias_fwd) { | |||
| dt_byte *workspace_ptr = workspace.raw_ptr; | |||
| dt_byte* workspace_ptr = workspace.raw_ptr; | |||
| // ============================w * f + b================================ | |||
| auto filter_meta = | |||
| @@ -198,7 +225,8 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
| DTypeTrait<dtype::in_dt>::ctype, \ | |||
| DTypeTrait<dtype::out_dt>::ctype, \ | |||
| DTypeTrait<dtype::out_dt>::ctype>)) | |||
| if (0) {} | |||
| if (0) { | |||
| } | |||
| DISPATCH(Float32, Float32) | |||
| DISPATCH(Int8, Int16) | |||
| DISPATCH(Int8, Int32) | |||
| @@ -209,6 +237,7 @@ void ConvBiasForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
| DISPATCH_RAW(QuantizedS8, QuantizedS32, QuantizedS32, FLOAT32, | |||
| (convolution::forward_bias<dt_int8, dt_int8, dt_int32, | |||
| dt_int32>)) | |||
| DISPATCH(QuantizedS4, QuantizedS32) | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| DISPATCH(Float16, Float16) | |||
| DISPATCH_RAW(Float16, Float16, Float16, FLOAT32, | |||
| @@ -254,8 +283,7 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||
| return algo; | |||
| } | |||
| ConvBiasForward::Algorithm* | |||
| ConvBiasForwardImpl::get_algorithm_from_desc( | |||
| ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_from_desc( | |||
| const AlgorithmDesc& desc) { | |||
| Algorithm* ret = | |||
| static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo(); | |||
| @@ -162,7 +162,8 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||
| filter_meta.format == Format::NCHW4_NCHW32 || | |||
| filter_meta.format == Format::NCHW8 || | |||
| filter_meta.format == Format::NCHW32 || | |||
| filter_meta.format == Format::NCHW32_NCHW4) { | |||
| filter_meta.format == Format::NCHW32_NCHW4 || | |||
| filter_meta.format == Format::NCHW64) { | |||
| spatial_start = 2; | |||
| channel_pos = 1; | |||
| batch_pos = 0; | |||
| @@ -197,6 +198,8 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||
| } else if (filter_meta.format == Format::NCHW32 || | |||
| filter_meta.format == Format::NCHW4_NCHW32) { | |||
| OC *= 32; | |||
| } else if (filter_meta.format == Format::NCHW64) { | |||
| OC *= 64; | |||
| } | |||
| size_t FS_G, FS_OC, FS_IC, FS_SPATIAL; | |||
| @@ -206,7 +209,8 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||
| filter_meta.format == Format::NCHW4_NCHW32 || | |||
| filter_meta.format == Format::NCHW8 || | |||
| filter_meta.format == Format::NCHW32 || | |||
| filter_meta.format == Format::NCHW32_NCHW4) { | |||
| filter_meta.format == Format::NCHW32_NCHW4 || | |||
| filter_meta.format == Format::NCHW64) { | |||
| // g, oc, ic, fh, fw | |||
| FS_SPATIAL = 1; | |||
| FS_IC = FH * FW; | |||
| @@ -349,6 +353,10 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||
| h * layout.stride[2] + w * layout.stride[3] + | |||
| (c & 0b11) * layout.stride[4]; | |||
| } | |||
| } else if (filter_meta.format == Format::NCHW64) { | |||
| return n * layout.stride[0] + (c >> 6) * layout.stride[1] + | |||
| h * layout.stride[2] + w * layout.stride[3] + | |||
| (c & 0x3F) * layout.stride[4]; | |||
| } else { | |||
| megdnn_assert(filter_meta.format == Format::NCHW4, | |||
| "invalid conv format"); | |||
| @@ -432,6 +440,10 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, | |||
| megdnn_throw( | |||
| "nchw44_dot naive not support this input and output\n"); | |||
| } | |||
| } else if (filter_meta.format == Format::NCHW64) { | |||
| return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | |||
| (ic - ic0) / 64 * FS_IC * 64 + | |||
| (fh * FW + fw) * FS_SPATIAL * 64 + ((ic - ic0) & 0x3F); | |||
| } else { | |||
| return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + | |||
| (ic - ic0) * FS_IC + (fh * FW + fw) * FS_SPATIAL; | |||
| @@ -690,6 +702,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
| case param::Convolution::Format::NCHW32: | |||
| case param::Convolution::Format::NCHW32_NCHW4: | |||
| case param::Convolution::Format::CHWN4: | |||
| case param::Convolution::Format::NCHW64: | |||
| compute2d<stype, ftype, dtype, comp_type, StrategyFwd, FilterMeta, | |||
| FilterVisitor>(src, filter.compatible_ptr<ftype>(), dst, | |||
| filter_meta); | |||
| @@ -782,6 +795,10 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
| BIAS_ADD_NCHWx(8); | |||
| break; | |||
| }; | |||
| case Format::NCHW64: { | |||
| BIAS_ADD_NCHWx(64); | |||
| break; | |||
| }; | |||
| #define BIAS_ADD_CHWNx(_pack_size) \ | |||
| do { \ | |||
| megdnn_assert(dst.layout.is_contiguous()); \ | |||
| @@ -179,7 +179,7 @@ void ElemwiseMultiTypeImpl::dispatch_add_qint_op( | |||
| auto size = param.size; | |||
| auto param0 = param[0].layout.dtype | |||
| .param<typename DTypeTrait<src_ctype>::dtype>(); | |||
| auto dst = dst_tensor.ptr<dst_ctype>(); | |||
| auto dst = tensor_iter_valonly<dst_ctype>(dst_tensor).begin(); | |||
| auto dst_param = dst_tensor.layout.dtype | |||
| .param<typename DTypeTrait<dst_ctype>::dtype>(); | |||
| @@ -205,7 +205,7 @@ void ElemwiseMultiTypeImpl::dispatch_add_qint_op( | |||
| .param<typename DTypeTrait<src_ctype>::dtype>(); | |||
| auto param1 = param[1].layout.dtype | |||
| .param<typename DTypeTrait<src_ctype>::dtype>(); | |||
| auto dst = dst_tensor.ptr<dst_ctype>(); | |||
| auto dst = tensor_iter_valonly<dst_ctype>(dst_tensor).begin(); | |||
| auto dst_param = dst_tensor.layout.dtype | |||
| .param<typename DTypeTrait<dst_ctype>::dtype>(); | |||
| @@ -238,7 +238,7 @@ void ElemwiseMultiTypeImpl::dispatch_add_qint_op( | |||
| .param<typename DTypeTrait<src_ctype>::dtype>(); | |||
| auto param2 = param[2].layout.dtype | |||
| .param<typename DTypeTrait<src_ctype>::dtype>(); | |||
| auto dst = dst_tensor.ptr<dst_ctype>(); | |||
| auto dst = tensor_iter_valonly<dst_ctype>(dst_tensor).begin(); | |||
| auto dst_param = dst_tensor.layout.dtype | |||
| .param<typename DTypeTrait<dst_ctype>::dtype>(); | |||
| @@ -272,10 +272,13 @@ void ElemwiseMultiTypeImpl::dispatch_add_qint_op_dst(const ElemParam& param, | |||
| typename DTypeTrait<_dt>::ctype>(param, dst); \ | |||
| break; | |||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||
| MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
| #undef cb | |||
| default: | |||
| megdnn_assert_internal(0); | |||
| megdnn_assert(0, "not support %s %s\n", | |||
| param[0].layout.dtype.name(), | |||
| dst.layout.dtype.name()); | |||
| } | |||
| } | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/naive/lowbit_utils.h" | |||
| @@ -40,6 +41,7 @@ void megdnn::naive::int4_to_int8(const TensorND& in, const TensorND& out) { | |||
| auto out_ptr = | |||
| static_cast<int8_t*>(out.raw_ptr) + out.layout.span().low_byte; | |||
| megdnn_assert(in.layout.span().dist_elem() % 2 == 0); | |||
| for (size_t i = 0; i < in.layout.span().dist_elem(); i += 2) { | |||
| int8_t cur = in_ptr[i / 2]; | |||
| out_ptr[i] = cur << 4; | |||
| @@ -29,7 +29,7 @@ namespace { | |||
| double error_sum = 0; | |||
| double error_sum_biased = 0; | |||
| for (size_t i = 0; i < nr_elem; ++ i) { | |||
| ctype iv0 = ctype(*it0), iv1 = ctype(*it1); | |||
| ctype iv0 = *it0, iv1 = *it1; | |||
| float err = diff(iv0, iv1); | |||
| error_sum += std::abs(err); | |||
| error_sum_biased += err; | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| @@ -419,13 +420,28 @@ TensorND TensorValueLowbit4(const TensorShape& shape, T dtype, | |||
| tensor.raw_ptr = | |||
| static_cast<dt_byte*>(malloc(tensor.layout.span().dist_byte())); | |||
| megdnn_assert(values.size() == tensor.layout.total_nr_elems()); | |||
| auto ptr = static_cast<U*>(tensor.raw_ptr); | |||
| for (size_t i = 0; i < values.size(); i += 2) { | |||
| auto ptr = tensor.ptr<typename DTypeTrait<T>::ctype>(); | |||
| size_t i; | |||
| for (i = 0; i + 1 < values.size(); i += 2) { | |||
| U val0 = values[i], val1 = values[i + 1]; | |||
| megdnn_assert(val0 >= DTypeTrait<T>::min()); | |||
| megdnn_assert(val1 <= DTypeTrait<T>::max()); | |||
| ptr[i / 2] = (val0 & 0xF) | (val1 << 4); | |||
| ptr[i / 2] = typename DTypeTrait<T>::ctype((val0 & 0xF) | (val1 << 4)); | |||
| } | |||
| if (i < values.size()) { | |||
| U val0 = values[i]; | |||
| megdnn_assert(val0 >= DTypeTrait<T>::min() && | |||
| val0 <= DTypeTrait<T>::max()); | |||
| if (i + 1 < values.size()) { | |||
| U val1 = values[i + 1]; | |||
| megdnn_assert(val1 >= DTypeTrait<T>::min() && | |||
| val1 <= DTypeTrait<T>::max()); | |||
| ptr[i / 2] = typename DTypeTrait<T>::ctype((val0 & 0xF) | (val1 << 4)); | |||
| } else { | |||
| ptr[i / 2] = typename DTypeTrait<T>::ctype(val0 & 0xF); | |||
| } | |||
| } | |||
| return tensor; | |||
| } | |||
| @@ -466,7 +482,6 @@ struct ExecutionPolicyAlgoName { | |||
| template <class Opr, typename OprAlgoProxy = OprAlgoProxy<Opr>> | |||
| class AlgoChecker { | |||
| public: | |||
| AlgoChecker(ExecutionPolicyAlgoName name, bool* require_algo = nullptr) | |||
| : m_policy_name{name}, m_require_algo{require_algo} {} | |||
| @@ -554,7 +569,8 @@ void construct_sub_execution_policy_heuristic(ExecutionPolicy& policy, | |||
| opr->param() = Algorithm::deserialize_read_pod<typename Opr::Param>(param); | |||
| if (!policy.algo.valid()) { | |||
| policy.algo = AlgoProxy<Opr, OprTrait<Opr>::arity>:: | |||
| get_algorithm_info_heuristic(opr.get(), layouts).desc; | |||
| get_algorithm_info_heuristic(opr.get(), layouts) | |||
| .desc; | |||
| } | |||
| Algorithm* algo = opr->get_algorithm_from_desc(policy.algo); | |||
| @@ -563,8 +579,7 @@ void construct_sub_execution_policy_heuristic(ExecutionPolicy& policy, | |||
| FOREACH_OPR_TYPE_DISPATCH(sub_items, { | |||
| policy.sub_policy.push_back(ExecutionPolicy{}); | |||
| construct_sub_execution_policy_heuristic<_Opr>( | |||
| policy.sub_policy.back(), _item.layouts, _item.param, | |||
| handle); | |||
| policy.sub_policy.back(), _item.layouts, _item.param, handle); | |||
| }); | |||
| } | |||
| @@ -752,6 +752,15 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, | |||
| checker.set_epsilon(1 + 1e-3) | |||
| .set_max_avg_error(1e-1) | |||
| .set_max_avg_biased_error(1e-3); | |||
| } else if (src_dtype.enumv() == DTypeEnum::QuantizedS4) { | |||
| rng = std::make_unique<UniformIntRNG>(-3, 3); | |||
| const_rng = std::make_unique<UniformIntRNG>(1, 1); | |||
| zero_rng = std::make_unique<UniformIntRNG>(0, 0); | |||
| megdnn_assert(bias_dtype.enumv() == DTypeEnum::QuantizedS32); | |||
| bias_rng = std::make_unique<UniformIntRNG>(-50, 50); | |||
| checker.set_epsilon(1 + 1e-3) | |||
| .set_max_avg_error(1e-1) | |||
| .set_max_avg_biased_error(1e-3); | |||
| } else if (src_dtype.enumv() == DTypeEnum::Float16) { | |||
| rng = std::make_unique<NormalRNG>(2.f); | |||
| megdnn_assert(bias_dtype.enumv() == DTypeEnum::Float16); | |||
| @@ -783,6 +792,12 @@ void check_conv_bias(DType src_dtype, DType filter_dtype, DType bias_dtype, | |||
| fh = arg.filter[2]; | |||
| fw = arg.filter[3]; | |||
| z[1] = arg.filter[0] / 32; | |||
| } else if (format == Format::NCHW64) { | |||
| hi = arg.src[2]; | |||
| wi = arg.src[3]; | |||
| fh = arg.filter[2]; | |||
| fw = arg.filter[3]; | |||
| z[1] = arg.filter[0] / 64; | |||
| } else { | |||
| megdnn_assert(format == Format::CHWN4); | |||
| hi = arg.src[1]; | |||
| @@ -20,408 +20,13 @@ | |||
| #include "test/cuda/utils.h" | |||
| #include "test/common/tensor.h" | |||
| #include "test/common/workspace_wrapper.h" | |||
| #include "test/cuda/conv_test_utils.h" | |||
| #define V1(x) #x | |||
| #define V(x) V1(x) | |||
| namespace megdnn { | |||
| namespace test { | |||
| namespace { | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| struct BenchArgs { | |||
| size_t n, ci, hi, wi, co, f, s; | |||
| }; | |||
| std::vector<BenchArgs> get_resnet50_bench_args(size_t batch = 64) { | |||
| std::vector<BenchArgs> args; | |||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 256, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 32, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 32, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 4, 256, 256, 32, 7, 2}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 64, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 64, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 512, 1, 2}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 128, 1, 2}); | |||
| args.emplace_back(BenchArgs{batch, 512, 28, 28, 128, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 128, 28, 28, 128, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 128, 28, 28, 512, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 512, 28, 28, 1024, 1, 2}); | |||
| args.emplace_back(BenchArgs{batch, 512, 28, 28, 256, 1, 2}); | |||
| args.emplace_back(BenchArgs{batch, 1024, 14, 14, 256, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 256, 14, 14, 256, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 256, 14, 14, 1024, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 256, 14, 14, 1024, 1, 2}); | |||
| args.emplace_back(BenchArgs{batch, 1024, 14, 14, 2048, 1, 2}); | |||
| args.emplace_back(BenchArgs{batch, 1024, 14, 14, 512, 1, 2}); | |||
| args.emplace_back(BenchArgs{batch, 2048, 7, 7, 512, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 512, 7, 7, 512, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 512, 7, 7, 2048, 1, 1}); | |||
| return args; | |||
| } | |||
| std::vector<BenchArgs> get_detection_bench_args(size_t batch = 16) { | |||
| std::vector<BenchArgs> args; | |||
| args.emplace_back(BenchArgs{batch, 4, 736, 1280, 8, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 32, 184, 320, 16, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 16, 184, 320, 32, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 8, 184, 320, 16, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 8, 184, 320, 32, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 92, 160, 32, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 32, 184, 320, 64, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 32, 184, 320, 32, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 32, 92, 160, 64, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 92, 160, 8, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 92, 160, 128, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 128, 46, 80, 32, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 128, 46, 80, 256, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 128, 46, 80, 8, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 92, 160, 32, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 32, 46, 80, 128, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 8, 46, 80, 32, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 23, 40, 256, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 256, 23, 40, 64, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 128, 46, 80, 64, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 256, 23, 40, 8, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 8, 23, 40, 32, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 8, 12, 20, 8, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 8, 12, 20, 8, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 8, 6, 10, 8, 3, 1}); | |||
| return args; | |||
| } | |||
| std::vector<BenchArgs> get_det_first_bench_args(size_t batch = 16) { | |||
| std::vector<BenchArgs> args; | |||
| args.emplace_back(BenchArgs{batch, 4, 736, 1280, 16, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 16, 384, 640, 16, 3, 1}); | |||
| return args; | |||
| } | |||
| void benchmark_target_algo( | |||
| Handle* handle, const std::vector<BenchArgs>& args, DType src_dtype, | |||
| DType filter_dtype, DType bias_dtype, DType dst_dtype, | |||
| const char* algo = nullptr, | |||
| param::ConvBias::Format format = param::ConvBias::Format::NCHW4) { | |||
| megdnn_assert(src_dtype.enumv() == filter_dtype.enumv()); | |||
| CUBenchmarker<ConvBiasForward> benchmarker(handle); | |||
| CUBenchmarker<ConvBiasForward> benchmarker_cudnn(handle); | |||
| size_t RUNS = 1000; | |||
| benchmarker.set_display(false).set_times(RUNS); | |||
| benchmarker_cudnn.set_display(false).set_times(RUNS); | |||
| #define CUDNN_VERSION_STRING \ | |||
| "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL) | |||
| benchmarker_cudnn.set_before_exec_callback( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
| "DEFAULT:CUDNN:ConvBiasActivation:CUDNN_CONVOLUTION_FWD_" | |||
| "ALGO_IMPLICIT_PRECOMP_" | |||
| "GEMM" CUDNN_VERSION_STRING)); | |||
| benchmarker.set_dtype(0, src_dtype) | |||
| .set_dtype(1, filter_dtype) | |||
| .set_dtype(2, bias_dtype) | |||
| .set_dtype(3, dst_dtype) | |||
| .set_dtype(4, dst_dtype); | |||
| benchmarker_cudnn.set_dtype(0, src_dtype) | |||
| .set_dtype(1, filter_dtype) | |||
| .set_dtype(2, bias_dtype) | |||
| .set_dtype(3, dst_dtype) | |||
| .set_dtype(4, dst_dtype); | |||
| using Param = ConvBias::Param; | |||
| using Format = Param::Format; | |||
| // helper function to change format | |||
| auto get_tensor_shape = [](TensorShape shape, | |||
| Format format) -> TensorShape { | |||
| TensorShape ret; | |||
| if (format == Format::NCHW4) { | |||
| ret = static_cast<TensorShape>( | |||
| TensorLayout{shape, dtype::Int8()} | |||
| .reshape({shape[0], shape[1] / 4, 4, shape[2], | |||
| shape[3]}) | |||
| .dimshuffle({0, 1, 3, 4, 2})); | |||
| } else if (format == Format::CHWN4) { | |||
| ret = static_cast<TensorShape>( | |||
| TensorLayout{shape, dtype::Int8()} | |||
| .reshape({shape[0], shape[1] / 4, 4, shape[2], | |||
| shape[3]}) | |||
| .dimshuffle({1, 3, 4, 0, 2})); | |||
| } | |||
| return ret; | |||
| }; | |||
| for (auto&& arg : args) { | |||
| Param param; | |||
| param.pad_h = param.pad_w = arg.f / 2; | |||
| param.stride_h = param.stride_w = arg.s; | |||
| param.format = format; | |||
| size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); | |||
| size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); | |||
| benchmarker.set_param(param); | |||
| if (!algo) { | |||
| benchmarker.proxy()->target_execution_policy.algo.reset(); | |||
| } | |||
| TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, | |||
| filter{arg.co, arg.ci, arg.f, arg.f}, bias{1, arg.co, 1, 1}, | |||
| z{arg.n, arg.co, ho, wo}, dst = z; | |||
| float time_in_ms = 0.f; | |||
| if (algo) { | |||
| time_in_ms = | |||
| algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, | |||
| CUTimer>(benchmarker, | |||
| {get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| {}, | |||
| {}}, | |||
| algo) / | |||
| RUNS; | |||
| } else { | |||
| time_in_ms = benchmarker.execs({get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| } | |||
| Format format_cudnn = Format::NCHW4; | |||
| param.format = format_cudnn; | |||
| benchmarker_cudnn.set_param(param); | |||
| auto time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs({get_tensor_shape(src, format_cudnn), | |||
| get_tensor_shape(filter, format_cudnn), | |||
| get_tensor_shape(bias, format_cudnn), | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * arg.f / | |||
| (1e12); | |||
| printf("src=%s, filter=%s, dst=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), | |||
| dst.to_string().c_str(), algo, time_in_ms, | |||
| (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| printf("bench with z tensor\n"); | |||
| if (algo) { | |||
| time_in_ms = | |||
| algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, | |||
| CUTimer>(benchmarker, | |||
| {get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| get_tensor_shape(z, format), | |||
| {}}, | |||
| algo) / | |||
| RUNS; | |||
| } else { | |||
| time_in_ms = benchmarker.execs({get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| get_tensor_shape(z, format), | |||
| {}}) / | |||
| RUNS; | |||
| } | |||
| time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs({get_tensor_shape(src, format_cudnn), | |||
| get_tensor_shape(filter, format_cudnn), | |||
| get_tensor_shape(bias, format_cudnn), | |||
| get_tensor_shape(z, format_cudnn), | |||
| {}}) / | |||
| RUNS; | |||
| printf("src=%s, filter=%s, dst=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), | |||
| dst.to_string().c_str(), algo, time_in_ms, | |||
| (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| } | |||
| } | |||
| void benchmark_target_algo_with_cudnn_tsc( | |||
| Handle* handle, const std::vector<BenchArgs>& args, DType src_dtype, | |||
| DType filter_dtype, DType bias_dtype, DType dst_dtype, | |||
| const char* algo = nullptr, | |||
| param::ConvBias::Format format = param::ConvBias::Format::NCHW4) { | |||
| megdnn_assert(src_dtype.enumv() == filter_dtype.enumv()); | |||
| CUBenchmarker<ConvBiasForward> benchmarker(handle); | |||
| CUBenchmarker<ConvBiasForward> benchmarker_cudnn(handle); | |||
| size_t RUNS = 1000; | |||
| benchmarker.set_display(false).set_times(RUNS); | |||
| benchmarker_cudnn.set_display(false).set_times(RUNS); | |||
| std::unique_ptr<OprProxy<ConvBiasForward>> proxy{ | |||
| new OprProxy<ConvBiasForward>{true}}; | |||
| if (!algo) { | |||
| benchmarker.set_proxy(proxy); | |||
| } | |||
| benchmarker_cudnn.set_before_exec_callback( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
| "DEFAULT:CUDNN:ConvBiasActivation:CUDNN_CONVOLUTION_FWD_" | |||
| "ALGO_IMPLICIT_PRECOMP_" | |||
| "GEMM" CUDNN_VERSION_STRING)); | |||
| #undef CUDNN_VERSION_STRING | |||
| benchmarker.set_dtype(0, src_dtype) | |||
| .set_dtype(1, filter_dtype) | |||
| .set_dtype(2, bias_dtype) | |||
| .set_dtype(3, dst_dtype) | |||
| .set_dtype(4, dst_dtype); | |||
| benchmarker_cudnn.set_dtype(0, src_dtype) | |||
| .set_dtype(1, filter_dtype) | |||
| .set_dtype(2, bias_dtype) | |||
| .set_dtype(3, dst_dtype) | |||
| .set_dtype(4, dst_dtype); | |||
| using Param = ConvBias::Param; | |||
| using Format = Param::Format; | |||
| // helper function to change format | |||
| auto get_tensor_shape = [](TensorShape shape, | |||
| Format format) -> TensorShape { | |||
| TensorShape ret; | |||
| if (format == Format::NCHW4) { | |||
| ret = static_cast<TensorShape>( | |||
| TensorLayout{shape, dtype::Int8()} | |||
| .reshape({shape[0], shape[1] / 4, 4, shape[2], | |||
| shape[3]}) | |||
| .dimshuffle({0, 1, 3, 4, 2})); | |||
| } else if (format == Format::NCHW32) { | |||
| ret = static_cast<TensorShape>( | |||
| TensorLayout{shape, dtype::Int8()} | |||
| .reshape({shape[0], shape[1] / 32, 32, shape[2], | |||
| shape[3]}) | |||
| .dimshuffle({0, 1, 3, 4, 2})); | |||
| } else if (format == Format::CHWN4) { | |||
| ret = static_cast<TensorShape>( | |||
| TensorLayout{shape, dtype::Int8()} | |||
| .reshape({shape[0], shape[1] / 4, 4, shape[2], | |||
| shape[3]}) | |||
| .dimshuffle({1, 3, 4, 0, 2})); | |||
| } | |||
| return ret; | |||
| }; | |||
| for (auto&& arg : args) { | |||
| Param param; | |||
| param.pad_h = param.pad_w = arg.f / 2; | |||
| param.stride_h = param.stride_w = arg.s; | |||
| param.format = format; | |||
| size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); | |||
| size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); | |||
| benchmarker.set_param(param); | |||
| if (!algo) { | |||
| benchmarker.proxy()->target_execution_policy.algo.reset(); | |||
| } | |||
| TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, | |||
| filter{arg.co, arg.ci, arg.f, arg.f}, bias{1, arg.co, 1, 1}, | |||
| z{arg.n, arg.co, ho, wo}, dst = z; | |||
| // skip testcase which cannot enable nchw32 tensorcore | |||
| if (format == Format::NCHW32 && (arg.co % 32 != 0 || arg.ci % 32 != 0)) | |||
| continue; | |||
| // skip testcase which cannot enable nchw4/chwn4 tensorcore | |||
| if ((format == Format::CHWN4 || format == Format::NCHW4) && | |||
| (arg.ci % 16 != 0)) | |||
| continue; | |||
| Format format_cudnn = arg.ci % 32 == 0 && arg.co % 32 == 0 | |||
| ? Format::NCHW32 | |||
| : Format::NCHW4; | |||
| param.format = format_cudnn; | |||
| benchmarker_cudnn.set_param(param); | |||
| float time_in_ms = 0.f; | |||
| if (algo) { | |||
| time_in_ms = | |||
| algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, | |||
| CUTimer>(benchmarker, | |||
| {get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| {}, | |||
| {}}, | |||
| algo) / | |||
| RUNS; | |||
| } else { | |||
| time_in_ms = benchmarker.execs({get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| } | |||
| float time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs({get_tensor_shape(src, format_cudnn), | |||
| get_tensor_shape(filter, format_cudnn), | |||
| get_tensor_shape(bias, format_cudnn), | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * arg.f / | |||
| (1e12); | |||
| printf("src=%s, filter=%s, dst=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), | |||
| dst.to_string().c_str(), algo, time_in_ms, | |||
| (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| printf("bench with z tensor\n"); | |||
| if (algo) { | |||
| time_in_ms = | |||
| algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, | |||
| CUTimer>(benchmarker, | |||
| {get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| get_tensor_shape(z, format), | |||
| {}}, | |||
| algo) / | |||
| RUNS; | |||
| } else { | |||
| time_in_ms = benchmarker.execs({get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| get_tensor_shape(z, format), | |||
| {}}) / | |||
| RUNS; | |||
| } | |||
| time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs({get_tensor_shape(src, format_cudnn), | |||
| get_tensor_shape(filter, format_cudnn), | |||
| get_tensor_shape(bias, format_cudnn), | |||
| get_tensor_shape(z, format_cudnn), | |||
| {}}) / | |||
| RUNS; | |||
| printf("src=%s, filter=%s, dst=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), | |||
| dst.to_string().c_str(), algo, time_in_ms, | |||
| (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| } | |||
| } | |||
| #endif | |||
| } // namespace | |||
| namespace conv{ | |||
| TEST_F(CUDA, CONV_BIAS_INT8_NCHW4_1x1) { | |||
| require_compute_capability(6, 1); | |||
| @@ -1410,10 +1015,10 @@ TEST_F(CUDA, BENCHMARK_CUTLASS_CONV_BIAS_INT8_NCHW4_DET_FIRST) { | |||
| } | |||
| #endif | |||
| } | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| #undef V1 | |||
| #undef V | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,458 @@ | |||
| /** | |||
| * \file dnn/test/cuda/conv_test_utils.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megdnn/oprs/nn.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/cuda/cudnn_with_check.h" | |||
| #include "test/common/checker.h" | |||
| #include "test/common/conv_bias.h" | |||
| #include "test/common/tensor.h" | |||
| #include "test/common/workspace_wrapper.h" | |||
| #include "test/cuda/benchmark.h" | |||
| #include "test/cuda/conv_test_utils.h" | |||
| #include "test/cuda/fixture.h" | |||
| #include "test/cuda/utils.h" | |||
| #define V1(x) #x | |||
| #define V(x) V1(x) | |||
| namespace megdnn { | |||
| namespace test { | |||
| namespace conv { | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| std::vector<BenchArgs> get_resnet50_bench_args(size_t batch) { | |||
| std::vector<BenchArgs> args; | |||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 256, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 32, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 32, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 4, 256, 256, 32, 7, 2}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 64, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 56, 56, 64, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 64, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 512, 1, 2}); | |||
| args.emplace_back(BenchArgs{batch, 256, 56, 56, 128, 1, 2}); | |||
| args.emplace_back(BenchArgs{batch, 512, 28, 28, 128, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 128, 28, 28, 128, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 128, 28, 28, 512, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 512, 28, 28, 1024, 1, 2}); | |||
| args.emplace_back(BenchArgs{batch, 512, 28, 28, 256, 1, 2}); | |||
| args.emplace_back(BenchArgs{batch, 1024, 14, 14, 256, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 256, 14, 14, 256, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 256, 14, 14, 1024, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 256, 14, 14, 1024, 1, 2}); | |||
| args.emplace_back(BenchArgs{batch, 1024, 14, 14, 2048, 1, 2}); | |||
| args.emplace_back(BenchArgs{batch, 1024, 14, 14, 512, 1, 2}); | |||
| args.emplace_back(BenchArgs{batch, 2048, 7, 7, 512, 1, 1}); | |||
| args.emplace_back(BenchArgs{batch, 512, 7, 7, 512, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 512, 7, 7, 2048, 1, 1}); | |||
| return args; | |||
| } | |||
| std::vector<BenchArgs> get_detection_bench_args(size_t batch) { | |||
| std::vector<BenchArgs> args; | |||
| args.emplace_back(BenchArgs{batch, 4, 736, 1280, 8, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 32, 184, 320, 16, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 16, 184, 320, 32, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 8, 184, 320, 16, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 8, 184, 320, 32, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 92, 160, 32, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 32, 184, 320, 64, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 32, 184, 320, 32, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 32, 92, 160, 64, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 92, 160, 8, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 92, 160, 128, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 128, 46, 80, 32, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 128, 46, 80, 256, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 128, 46, 80, 8, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 92, 160, 32, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 32, 46, 80, 128, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 8, 46, 80, 32, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 64, 23, 40, 256, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 256, 23, 40, 64, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 128, 46, 80, 64, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 256, 23, 40, 8, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 8, 23, 40, 32, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 8, 12, 20, 8, 3, 1}); | |||
| args.emplace_back(BenchArgs{batch, 8, 12, 20, 8, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 8, 6, 10, 8, 3, 1}); | |||
| return args; | |||
| } | |||
| std::vector<BenchArgs> get_det_first_bench_args(size_t batch) { | |||
| std::vector<BenchArgs> args; | |||
| args.emplace_back(BenchArgs{batch, 4, 736, 1280, 16, 3, 2}); | |||
| args.emplace_back(BenchArgs{batch, 16, 384, 640, 16, 3, 1}); | |||
| return args; | |||
| } | |||
| void benchmark_target_algo(Handle* handle, const std::vector<BenchArgs>& args, | |||
| DType src_dtype, DType filter_dtype, | |||
| DType bias_dtype, DType dst_dtype, const char* algo, | |||
| param::ConvBias::Format format) { | |||
| megdnn_assert(src_dtype.enumv() == filter_dtype.enumv()); | |||
| CUBenchmarker<ConvBiasForward> benchmarker(handle); | |||
| CUBenchmarker<ConvBiasForward> benchmarker_cudnn(handle); | |||
| size_t RUNS = 1000; | |||
| benchmarker.set_display(false).set_times(RUNS); | |||
| benchmarker_cudnn.set_display(false).set_times(RUNS); | |||
| #define CUDNN_VERSION_STRING \ | |||
| "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL) | |||
| benchmarker_cudnn.set_before_exec_callback( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
| "DEFAULT:CUDNN:ConvBiasActivation:CUDNN_CONVOLUTION_FWD_" | |||
| "ALGO_IMPLICIT_PRECOMP_" | |||
| "GEMM" CUDNN_VERSION_STRING)); | |||
| benchmarker.set_dtype(0, src_dtype) | |||
| .set_dtype(1, filter_dtype) | |||
| .set_dtype(2, bias_dtype) | |||
| .set_dtype(3, dst_dtype) | |||
| .set_dtype(4, dst_dtype); | |||
| benchmarker_cudnn.set_dtype(0, src_dtype) | |||
| .set_dtype(1, filter_dtype) | |||
| .set_dtype(2, bias_dtype) | |||
| .set_dtype(3, dst_dtype) | |||
| .set_dtype(4, dst_dtype); | |||
| using Param = ConvBias::Param; | |||
| using Format = Param::Format; | |||
| // helper function to change format | |||
| auto get_tensor_shape = [](TensorShape shape, | |||
| Format format) -> TensorShape { | |||
| TensorShape ret; | |||
| if (format == Format::NCHW4) { | |||
| ret = static_cast<TensorShape>( | |||
| TensorLayout{shape, dtype::Int8()} | |||
| .reshape({shape[0], shape[1] / 4, 4, shape[2], | |||
| shape[3]}) | |||
| .dimshuffle({0, 1, 3, 4, 2})); | |||
| } else if (format == Format::CHWN4) { | |||
| ret = static_cast<TensorShape>( | |||
| TensorLayout{shape, dtype::Int8()} | |||
| .reshape({shape[0], shape[1] / 4, 4, shape[2], | |||
| shape[3]}) | |||
| .dimshuffle({1, 3, 4, 0, 2})); | |||
| } | |||
| return ret; | |||
| }; | |||
| for (auto&& arg : args) { | |||
| Param param; | |||
| param.pad_h = param.pad_w = arg.f / 2; | |||
| param.stride_h = param.stride_w = arg.s; | |||
| param.format = format; | |||
| size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); | |||
| size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); | |||
| benchmarker.set_param(param); | |||
| if (!algo) { | |||
| benchmarker.proxy()->target_execution_policy.algo.reset(); | |||
| } | |||
| TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, | |||
| filter{arg.co, arg.ci, arg.f, arg.f}, bias{1, arg.co, 1, 1}, | |||
| z{arg.n, arg.co, ho, wo}, dst = z; | |||
| float time_in_ms = 0.f; | |||
| if (algo) { | |||
| time_in_ms = | |||
| algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, | |||
| CUTimer>(benchmarker, | |||
| {get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| {}, | |||
| {}}, | |||
| algo) / | |||
| RUNS; | |||
| } else { | |||
| time_in_ms = benchmarker.execs({get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| } | |||
| Format format_cudnn = Format::NCHW4; | |||
| param.format = format_cudnn; | |||
| benchmarker_cudnn.set_param(param); | |||
| auto time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs({get_tensor_shape(src, format_cudnn), | |||
| get_tensor_shape(filter, format_cudnn), | |||
| get_tensor_shape(bias, format_cudnn), | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * arg.f / | |||
| (1e12); | |||
| printf("src=%s, filter=%s, dst=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), | |||
| dst.to_string().c_str(), algo, time_in_ms, | |||
| (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| printf("bench with z tensor\n"); | |||
| if (algo) { | |||
| time_in_ms = | |||
| algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, | |||
| CUTimer>(benchmarker, | |||
| {get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| get_tensor_shape(z, format), | |||
| {}}, | |||
| algo) / | |||
| RUNS; | |||
| } else { | |||
| time_in_ms = benchmarker.execs({get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| get_tensor_shape(z, format), | |||
| {}}) / | |||
| RUNS; | |||
| } | |||
| time_in_ms_cudnn = | |||
| benchmarker_cudnn.execs({get_tensor_shape(src, format_cudnn), | |||
| get_tensor_shape(filter, format_cudnn), | |||
| get_tensor_shape(bias, format_cudnn), | |||
| get_tensor_shape(z, format_cudnn), | |||
| {}}) / | |||
| RUNS; | |||
| printf("src=%s, filter=%s, dst=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), | |||
| dst.to_string().c_str(), algo, time_in_ms, | |||
| (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| } | |||
| } | |||
| void benchmark_target_algo_with_cudnn_tsc( | |||
| Handle* handle, const std::vector<BenchArgs>& args, DType src_dtype, | |||
| DType filter_dtype, DType bias_dtype, DType dst_dtype, const char* algo, | |||
| param::ConvBias::Format format, bool with_cudnn, | |||
| const char* change_cudnn_algo, | |||
| param::ConvBias::Format change_cudnn_format, | |||
| DType change_cudnn_src_dtype, DType change_cudnn_filter_dtype, | |||
| DType change_cudnn_bias_dtype, DType change_cudnn_dst_dtype) { | |||
| megdnn_assert(src_dtype.enumv() == filter_dtype.enumv()); | |||
| CUBenchmarker<ConvBiasForward> benchmarker(handle); | |||
| CUBenchmarker<ConvBiasForward> benchmarker_cudnn(handle); | |||
| size_t RUNS = 1000; | |||
| benchmarker.set_display(false).set_times(RUNS); | |||
| benchmarker.set_dtype(0, src_dtype) | |||
| .set_dtype(1, filter_dtype) | |||
| .set_dtype(2, bias_dtype) | |||
| .set_dtype(3, dst_dtype) | |||
| .set_dtype(4, dst_dtype); | |||
| benchmarker_cudnn.set_display(false).set_times(RUNS); | |||
| std::unique_ptr<OprProxy<ConvBiasForward>> proxy{ | |||
| new OprProxy<ConvBiasForward>{true}}; | |||
| if (!algo) { | |||
| benchmarker.set_proxy(proxy); | |||
| } | |||
| if (change_cudnn_algo) { | |||
| benchmarker_cudnn.set_dtype(0, change_cudnn_src_dtype) | |||
| .set_dtype(1, change_cudnn_filter_dtype) | |||
| .set_dtype(2, change_cudnn_bias_dtype) | |||
| .set_dtype(3, change_cudnn_dst_dtype) | |||
| .set_dtype(4, change_cudnn_dst_dtype); | |||
| benchmarker_cudnn.set_before_exec_callback( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
| change_cudnn_algo)); | |||
| } else { | |||
| benchmarker_cudnn.set_dtype(0, src_dtype) | |||
| .set_dtype(1, filter_dtype) | |||
| .set_dtype(2, bias_dtype) | |||
| .set_dtype(3, dst_dtype) | |||
| .set_dtype(4, dst_dtype); | |||
| benchmarker_cudnn.set_before_exec_callback( | |||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>( | |||
| "DEFAULT:CUDNN:ConvBiasActivation:CUDNN_CONVOLUTION_" | |||
| "FWD_" | |||
| "ALGO_IMPLICIT_PRECOMP_GEMM" CUDNN_VERSION_STRING)); | |||
| } | |||
| #undef CUDNN_VERSION_STRING | |||
| using Param = ConvBias::Param; | |||
| using Format = Param::Format; | |||
| // helper function to change format | |||
| auto get_tensor_shape = [](TensorShape shape, | |||
| Format format) -> TensorShape { | |||
| TensorShape ret; | |||
| if (format == Format::NCHW4) { | |||
| ret = static_cast<TensorShape>( | |||
| TensorLayout{shape, dtype::Int8()} | |||
| .reshape({shape[0], shape[1] / 4, 4, shape[2], | |||
| shape[3]}) | |||
| .dimshuffle({0, 1, 3, 4, 2})); | |||
| } else if (format == Format::NCHW32) { | |||
| ret = static_cast<TensorShape>( | |||
| TensorLayout{shape, dtype::Int8()} | |||
| .reshape({shape[0], shape[1] / 32, 32, shape[2], | |||
| shape[3]}) | |||
| .dimshuffle({0, 1, 3, 4, 2})); | |||
| } else if (format == Format::NCHW64) { | |||
| ret = static_cast<TensorShape>( | |||
| TensorLayout{shape, dtype::QuantizedS4(1.f)} | |||
| .reshape({shape[0], shape[1] / 64, 64, shape[2], | |||
| shape[3]}) | |||
| .dimshuffle({0, 1, 3, 4, 2})); | |||
| } else if (format == Format::CHWN4) { | |||
| ret = static_cast<TensorShape>( | |||
| TensorLayout{shape, dtype::Int8()} | |||
| .reshape({shape[0], shape[1] / 4, 4, shape[2], | |||
| shape[3]}) | |||
| .dimshuffle({1, 3, 4, 0, 2})); | |||
| } | |||
| return ret; | |||
| }; | |||
| for (auto&& arg : args) { | |||
| Param param; | |||
| param.pad_h = param.pad_w = arg.f / 2; | |||
| param.stride_h = param.stride_w = arg.s; | |||
| param.format = format; | |||
| size_t ho = infer_conv_shape(arg.hi, arg.f, arg.s, arg.f / 2); | |||
| size_t wo = infer_conv_shape(arg.wi, arg.f, arg.s, arg.f / 2); | |||
| benchmarker.set_param(param); | |||
| if (!algo) { | |||
| benchmarker.proxy()->target_execution_policy.algo.reset(); | |||
| } | |||
| TensorShape src{arg.n, arg.ci, arg.hi, arg.wi}, | |||
| filter{arg.co, arg.ci, arg.f, arg.f}, bias{1, arg.co, 1, 1}, | |||
| z{arg.n, arg.co, ho, wo}, dst = z; | |||
| // skip testcase which cannot enable nchw32 tensorcore | |||
| if (format == Format::NCHW32 && (arg.co % 32 != 0 || arg.ci % 32 != 0)) | |||
| continue; | |||
| // skip testcase which cannot enable nchw32 tensorcore | |||
| if (format == Format::NCHW64 && (arg.co % 64 != 0 || arg.ci % 64 != 0)) | |||
| continue; | |||
| // skip testcase which cannot enable nchw4/chwn4 tensorcore | |||
| if ((format == Format::CHWN4 || format == Format::NCHW4) && | |||
| (arg.ci % 16 != 0)) | |||
| continue; | |||
| Format format_cudnn = arg.ci % 32 == 0 && arg.co % 32 == 0 | |||
| ? Format::NCHW32 | |||
| : Format::NCHW4; | |||
| if (change_cudnn_algo) { | |||
| format_cudnn = change_cudnn_format; | |||
| } | |||
| param.format = format_cudnn; | |||
| benchmarker_cudnn.set_param(param); | |||
| float time_in_ms = 0.f; | |||
| if (algo) { | |||
| time_in_ms = | |||
| algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, | |||
| CUTimer>(benchmarker, | |||
| {get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| {}, | |||
| {}}, | |||
| algo) / | |||
| RUNS; | |||
| } else { | |||
| time_in_ms = benchmarker.execs({get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| } | |||
| float time_in_ms_cudnn = 0; | |||
| if (with_cudnn) { | |||
| time_in_ms_cudnn = benchmarker_cudnn.execs( | |||
| {get_tensor_shape(src, format_cudnn), | |||
| get_tensor_shape(filter, format_cudnn), | |||
| get_tensor_shape(bias, format_cudnn), | |||
| {}, | |||
| {}}) / | |||
| RUNS; | |||
| } | |||
| float flo = 2.0 * arg.n * arg.co * ho * wo * arg.ci * arg.f * arg.f / | |||
| (1e12); | |||
| printf("src=%s, filter=%s, dst=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), | |||
| dst.to_string().c_str(), algo, time_in_ms, | |||
| (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| printf("bench with z tensor\n"); | |||
| if (algo) { | |||
| time_in_ms = | |||
| algo_benchmark<ConvBiasForward, OprProxy<ConvBiasForward>, | |||
| CUTimer>(benchmarker, | |||
| {get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| get_tensor_shape(z, format), | |||
| {}}, | |||
| algo) / | |||
| RUNS; | |||
| } else { | |||
| time_in_ms = benchmarker.execs({get_tensor_shape(src, format), | |||
| get_tensor_shape(filter, format), | |||
| get_tensor_shape(bias, format), | |||
| get_tensor_shape(z, format), | |||
| {}}) / | |||
| RUNS; | |||
| } | |||
| time_in_ms_cudnn = 0; | |||
| if (with_cudnn) { | |||
| time_in_ms_cudnn = benchmarker_cudnn.execs( | |||
| {get_tensor_shape(src, format_cudnn), | |||
| get_tensor_shape(filter, format_cudnn), | |||
| get_tensor_shape(bias, format_cudnn), | |||
| get_tensor_shape(z, format_cudnn), | |||
| {}}) / | |||
| RUNS; | |||
| } | |||
| printf("src=%s, filter=%s, dst=%s, time(algo=%s)=%.2f %.2fTops, " | |||
| "time(cudnn)=%.2f %.2fTops, " | |||
| "perf(algo=%s)/perf(cudnn)=%.2f\n", | |||
| src.to_string().c_str(), filter.to_string().c_str(), | |||
| dst.to_string().c_str(), algo, time_in_ms, | |||
| (flo / (time_in_ms * 1e-3)), time_in_ms_cudnn, | |||
| (flo / (time_in_ms_cudnn * 1e-3)), algo, | |||
| time_in_ms_cudnn / time_in_ms); | |||
| } | |||
| } | |||
| #endif | |||
| } // namespace conv | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| #undef V1 | |||
| #undef V | |||
| @@ -0,0 +1,66 @@ | |||
| /** | |||
| * \file dnn/test/cuda/conv_test_utils.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs/nn.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/cuda/cudnn_with_check.h" | |||
| #include "test/common/checker.h" | |||
| #include "test/common/conv_bias.h" | |||
| #include "test/common/tensor.h" | |||
| #include "test/common/workspace_wrapper.h" | |||
| #include "test/cuda/benchmark.h" | |||
| #include "test/cuda/fixture.h" | |||
| #include "test/cuda/utils.h" | |||
| #define V1(x) #x | |||
| #define V(x) V1(x) | |||
| namespace megdnn { | |||
| namespace test { | |||
| namespace conv { | |||
| #if MEGDNN_WITH_BENCHMARK | |||
| struct BenchArgs { | |||
| size_t n, ci, hi, wi, co, f, s; | |||
| }; | |||
| std::vector<BenchArgs> get_resnet50_bench_args(size_t batch = 64); | |||
| std::vector<BenchArgs> get_detection_bench_args(size_t batch = 16); | |||
| std::vector<BenchArgs> get_det_first_bench_args(size_t batch = 16); | |||
| void benchmark_target_algo( | |||
| Handle* handle, const std::vector<BenchArgs>& args, DType src_dtype, | |||
| DType filter_dtype, DType bias_dtype, DType dst_dtype, | |||
| const char* algo = nullptr, | |||
| param::ConvBias::Format format = param::ConvBias::Format::NCHW4); | |||
| void benchmark_target_algo_with_cudnn_tsc( | |||
| Handle* handle, const std::vector<BenchArgs>& args, DType src_dtype, | |||
| DType filter_dtype, DType bias_dtype, DType dst_dtype, | |||
| const char* algo = nullptr, | |||
| param::ConvBias::Format format = param::ConvBias::Format::NCHW4, | |||
| bool with_cudnn = true, const char* change_cudnn_algo = nullptr, | |||
| param::ConvBias::Format change_cudnn_format = | |||
| param::ConvBias::Format::NCHW4, | |||
| DType change_cudnn_src_dtype = dtype::Int8(), | |||
| DType change_cudnn_filter_dtype = dtype::Int8(), | |||
| DType change_cudnn_bias_dtype = dtype::Int8(), | |||
| DType change_cudnn_dst_dtype = dtype::Int8()); | |||
| #endif | |||
| } // namespace conv | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| #undef V1 | |||
| #undef V | |||
| @@ -147,8 +147,23 @@ TEST_F(NAIVE, ELEMWISE_QUANTIZED_MODE_UNARY) { | |||
| checker.execs({{10, 4, 5, 6}, {}}); | |||
| checker.execs({{1, 4, 5, 6}, {}}); | |||
| checker.execs({{1, 4, 5, 1}, {}}); | |||
| checker.set_dtype(1, dtype::QuantizedS4(0.35f)); | |||
| checker.execs({{3, 4, 5, 6}, {}}); | |||
| checker.execs({{10, 4, 5, 6}, {}}); | |||
| checker.execs({{1, 4, 5, 6}, {}}); | |||
| checker.execs({{1, 4, 5, 1}, {}}); | |||
| } | |||
| } | |||
| TEST_F(NAIVE, ELEMWISE_QUANTIZED_MODE_UNARY_Q4) { | |||
| using Param = ElemwiseMultiType::Param; | |||
| Checker<ElemwiseMultiType> checker(handle()); | |||
| checker.set_param(Param::Mode::QRELU); | |||
| checker.set_dtype(0, dtype::QuantizedS32(1.f)); | |||
| checker.set_dtype(1, dtype::QuantizedS4(1.f)); | |||
| checker.execs({{3, 4, 5, 6}, {}}); | |||
| } | |||
| TEST_F(NAIVE, ELEMWISE_QUANTIZED_MODE_BINARY) { | |||
| using Param = ElemwiseMultiType::Param; | |||
| @@ -225,6 +240,13 @@ TEST_F(NAIVE, ELEMWISE_QUANTIZED_MODE_BINARY) { | |||
| checker.execs({{10, 4, 5, 6}, {10, 4, 5, 6}, {}}); | |||
| checker.execs({{1, 4, 5, 6}, {20, 4, 5, 6}, {}}); | |||
| checker.execs({{1, 4, 5, 1}, {2, 1, 1, 2}, {}}); | |||
| checker.set_dtype(2, dtype::QuantizedS4(0.35f)); | |||
| checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {}}); | |||
| checker.execs({{10, 4, 5, 6}, {10, 4, 5, 6}, {}}); | |||
| checker.execs({{1, 4, 5, 6}, {20, 4, 5, 6}, {}}); | |||
| checker.execs({{1, 4, 5, 1}, {2, 1, 1, 2}, {}}); | |||
| } | |||
| } | |||
| @@ -275,6 +297,10 @@ TEST_F(NAIVE, ELEMWISE_QUANTIZED_MODE_TERNARY) { | |||
| checker.set_dtype(3, dtype::QuantizedS32(0.35f)); | |||
| checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}}); | |||
| checker.execs({{10, 4, 5, 6}, {10, 4, 5, 6}, {10, 4, 5, 6}, {}}); | |||
| checker.set_dtype(3, dtype::QuantizedS4(0.35f)); | |||
| checker.execs({{3, 4, 5, 6}, {3, 4, 5, 6}, {3, 4, 5, 6}, {}}); | |||
| checker.execs({{10, 4, 5, 6}, {10, 4, 5, 6}, {10, 4, 5, 6}, {}}); | |||
| } | |||
| } | |||