GitOrigin-RevId: 444429a625
@@ -1186,7 +1186,7 @@ void ConvolutionBackwardData::deduce_layout(
     MEGDNN_MARK_USED_VAR(errmsg);
     megdnn_assert_contiguous(filter);
     megdnn_assert_contiguous(diff);
-    megdnn_assert(filter.ndim == 4_z || filter.ndim == 5_z, "%s", errmsg().c_str());
+    megdnn_assert(filter.ndim >= 4_z && filter.ndim <= 7_z, "%s", errmsg().c_str());
     megdnn_assert(diff.ndim == 4_z || diff.ndim == 5_z, "%s", errmsg().c_str());
     deduce_dtype(filter.dtype, diff.dtype, grad.dtype);
@@ -1223,11 +1223,12 @@ void ConvolutionBackwardData::deduce_layout(
             deduce(diff[i + src_or_dst_spatial_start], cflt.dilated_spatial[i],
                    cflt.stride[i], cflt.padding[i]);
         }
-    } else if (param().format == Param::Format::NCHW4) {
+    } else if (
+            param().format == Param::Format::NCHW4 ||
+            param().format == Param::Format::NCHW44) {
         megdnn_assert(
-                diff.ndim == 5, "valid diff ndim for NCHW4, expected=5, got=%zu",
-                diff.ndim);
-        megdnn_assert(cflt.group == 1, "%s", errmsg().c_str());
+                diff.ndim == 5,
+                "valid diff ndim for NCHW4 and NCHW44, expected=5, got=%zu", diff.ndim);
         megdnn_assert(cflt.ocpg * cflt.group == diff[1] * 4, "%s", errmsg().c_str());
         grad.ndim = diff.ndim;
         grad[0] = diff[0];
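
Note: the relaxed `filter.ndim` bound above (4 to 7 instead of exactly 4 or 5) accommodates the NCHW44 filters used later in this change: a dense NCHW44 filter is 6-dimensional ({oc/4, ic/4, fh, fw, 4, 4}) and a grouped one is 7-dimensional ({group, oc/4, ic/4, fh, fw, 4, 4}). For reference, a minimal standalone sketch (not MegDNN code) of the per-dimension size relation that `deduce()` is assumed to apply, i.e. the standard transposed-convolution formula:

```cpp
#include <cassert>
#include <cstddef>

// grad_spatial = (diff_spatial - 1) * stride + dilated_filter - 2 * pad
static size_t deduce_deconv_spatial(
        size_t diff, size_t dilated_filter, size_t stride, size_t pad) {
    return (diff - 1) * stride + dilated_filter - 2 * pad;
}

int main() {
    // Matches the NCHW44 test cases below: diff 2x2, filter 3x3, stride 1, pad 0 -> grad 4x4.
    assert(deduce_deconv_spatial(2, 3, 1, 0) == 4);
    // With pad 1 the spatial size is preserved: 2x2 -> 2x2.
    assert(deduce_deconv_spatial(2, 3, 1, 1) == 2);
    return 0;
}
```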
@@ -29,14 +29,19 @@ Relayout* get_relayout_opr() {
 MatrixMul* get_matmul_opr(const NCBKernSizeParam& param) {
     using ConvCM = param::Convolution::ComputeMode;
     using MmCM = param::MatrixMul::ComputeMode;
-    static CpuOprDelegationStorage<2> storage;
+    static CpuOprDelegationStorage<3> storage;
+    if (param.filter_meta.format == param::Convolution::Format::NCHW44) {
+        MatrixMul::Param p;
+        p.format = param::MatrixMul::Format::MK4;
+        return storage.get<MatrixMul, 0>(p);
+    }
     switch (param.compute_mode) {
         default:
-            return storage.get<MatrixMul, 0>({});
+            return storage.get<MatrixMul, 1>({});
         case ConvCM::FLOAT32: {
             MatrixMul::Param p;
             p.compute_mode = MmCM::FLOAT32;
-            return storage.get<MatrixMul, 1>(p);
+            return storage.get<MatrixMul, 2>(p);
         }
     }
 }
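
Note: `CpuOprDelegationStorage<N>` appears to cache N lazily created delegate operators, one per slot index, which is why its template argument grows from 2 to 3 here: the NCHW44 path claims slot 0 for an MK4-format MatrixMul, and the two pre-existing delegates (default parameters, and `ComputeMode::FLOAT32`) shift from slots 0/1 to 1/2.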
@@ -58,7 +63,14 @@ WorkspaceBundle get_bundle(const NCBKernSizeParam& param) {
         part0 = (IC * FH * FW * IH * IW) * param.grad_type.size();
     }
     part2 = (OC * IC * FH * FW) * param.filter_type.size();
-    {
+    if (param.filter_meta.format == param::Convolution::Format::NCHW44) {
+        TensorLayout A_, B_, C_;
+        A_ = TensorLayout({IC / 4 * FH * FW, OC / 4, 4, 4}, param.filter_type);
+        B_ = TensorLayout({OC / 4, IH * IW}, param.diff_type);
+        C_ = TensorLayout({IC / 4 * FH * FW, IH * IW, 4}, param.grad_type);
+        auto matmul_algo = get_matmul_opr(param);
+        part1 = matmul_algo->get_workspace_in_bytes(A_, B_, C_);
+    } else {
         TensorLayout A_, B_, C_;
         A_ = TensorLayout({IC * FH * FW, OC}, param.filter_type);
         B_ = TensorLayout({OC, IH * IW}, param.diff_type);
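
Note on the bundle used by the NCHW44 kernel further down: part 0 holds the col-format intermediate (`IC * FH * FW * IH * IW` elements of `grad_type`; on the 1x1 stride-1 path the matmul writes straight into `grad` instead), part 1 is the workspace required by the delegated matmul (queried here via `get_workspace_in_bytes`), and part 2 holds the relayouted filter (`OC * IC * FH * FW` elements of `filter_type`). The A and C layouts of the query match the ones the kernel executes with; the B layout here is `{OC / 4, IH * IW}`, whereas the kernel builds B as `{OC / 4, IH * IW, 4}`.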
@@ -573,4 +585,119 @@ bool ConvolutionBackwardDataImpl::AlgoMatrixMul::is_preferred(
     return is_matrix_mul_preferred(param);
 }
+/* ===================== Matrix mul nchw44 algo ===================== */
+namespace {
+void kern_matmul_nchw44(const NCBKernParam& param) {
+    bool is_xcorr = !param.filter_meta.should_flip;
+    UNPACK_CONV_F32_NCB_KERN_SIZES(param);
+    auto bundle = get_bundle(param);
+    bundle.set(param.workspace_ptr);
+    bool is1X1 = (FH == 1 && FW == 1 && SH == 1 && SW == 1 && PH == 0 && PW == 0);
+    typedef void (*Func1)(const float*, float*, int, int, int, int, int, int, int);
+    typedef void (*Func2)(
+            const float*, float*, int, int, int, int, int, int, int, int, int, int,
+            int);
+    Func1 f1 = nullptr;
+    Func2 f2 = nullptr;
+    if (is_xcorr) {
+        f1 = col2img_nchw44<true>;
+        f2 = col2img_stride_padding_nchw44<true>;
+    } else {
+        f1 = col2img_nchw44<false>;
+        f2 = col2img_stride_padding_nchw44<false>;
+    }
+    float* filter = const_cast<float*>(param.filter<float>());
+    TensorND A_src, A_dst;
+    {
+        A_src.layout = TensorLayout(
+                {IC / 4 * FH * FW, OC / 4, 4, 4},
+                {
+                        static_cast<std::ptrdiff_t>(16),
+                        static_cast<std::ptrdiff_t>(IC * FH * FW * 4),
+                        static_cast<std::ptrdiff_t>(1),
+                        static_cast<std::ptrdiff_t>(4),
+                },
+                param.filter_type);
+        A_src.reset_ptr(static_cast<void*>(filter));
+        A_dst.layout =
+                TensorLayout({IC / 4 * FH * FW, OC / 4, 4, 4}, param.filter_type);
+        A_dst.reset_ptr(static_cast<void*>(bundle.get(2)));
+        // TODO: should be removed once armv8 convolution supports transpose.
+        get_relayout_opr()->exec(A_src, A_dst, inplace_cpu_handle().get());
+    }
+    TensorND B_, C_;
+    for (size_t n = 0; n < N; ++n) {
+        float *C_src, *C_dst;
+        float* diff = const_cast<float*>(param.diff<float>() + n * param.inp_bs);
+        float* grad = param.grad<float>() + n * param.out_bs;
+        if (is1X1) {
+            C_src = grad;
+        } else {
+            C_src = static_cast<float*>(bundle.get(0));
+        }
+        {
+            B_.layout = TensorLayout({OC / 4, IH * IW, 4}, param.diff_type);
+            B_.reset_ptr(static_cast<void*>(diff));
+            C_.layout = TensorLayout({IC / 4 * FH * FW, IH * IW, 4}, param.grad_type);
+            C_.reset_ptr(C_src);
+            Workspace workspace(
+                    static_cast<dt_byte*>(bundle.get(1)), bundle.get_size(1));
+            auto matmul_opr = get_matmul_opr(param);
+            matmul_opr->exec(A_dst, B_, C_, workspace);
+        }
+        if (!is1X1) {
+            C_dst = grad;
+            std::memset(C_dst, 0, param.grad_type.size() * IC * OH * OW);
+            if (PH == 0 && PW == 0 && SH == 1 && SW == 1) {
+                f1(C_src, C_dst, OH, OW, IC, IH, IW, FH, FW);
+            } else {
+                f2(C_src, C_dst, OH, OW, IC, IH, IW, FH, FW, SH, SW, PH, PW);
+            }
+        }
+    }
+}
+} // namespace
+bool ConvolutionBackwardDataImpl::AlgoMatrixMulNCHW44::usable(
+        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
+    auto&& fm = param.filter_meta;
+    return fm.format == param::Convolution::Format::NCHW44 &&
+           param.diff_type.enumv() == DTypeTrait<dtype::Float32>::enumv &&
+           param.filter_type.enumv() == DTypeTrait<dtype::Float32>::enumv &&
+           param.grad_type.enumv() == DTypeTrait<dtype::Float32>::enumv &&
+           fm.spatial_ndim == 2 && fm.group == 1 && fm.dilation[0] == 1 &&
+           fm.dilation[1] == 1 && fm.icpg % 4 == 0 && fm.ocpg % 4 == 0;
+}
+size_t ConvolutionBackwardDataImpl::AlgoMatrixMulNCHW44::get_workspace(
+        ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
+    MIDOUT_BEGIN(
+            megdnn_fallback_deconv,
+            midout_iv("AlgoMatrixMulNCHW44::get_workspace"_hash)) {
+        return get_bundle(param).total_size_in_bytes();
+    }
+    MIDOUT_END();
+    return 0;
+}
+ConvolutionBackwardDataImpl::ncb_kern_t ConvolutionBackwardDataImpl::
+        AlgoMatrixMulNCHW44::dispatch_kern(
+                ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const {
+    if (param.filter_type.enumv() == DTypeTrait<dtype::Float32>::enumv) {
+        MIDOUT_BEGIN(megdnn_fallback_deconv, midout_iv("FLOAT_NCHW44"_hash)) {
+            return kern_matmul_nchw44;
+        }
+        MIDOUT_END();
+    }
+    megdnn_throw("unsupported data type on matrix mul");
+}
+bool ConvolutionBackwardDataImpl::AlgoMatrixMulNCHW44::is_preferred(
+        const NCBKernSizeParam& param) const {
+    return is_matrix_mul_preferred(param);
+}
 // vim: syntax=cpp.doxygen
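
For each sample, the kernel above computes the col-format gradient with a single MK4 matmul C = A * B: A is the filter relayouted into bundle part 2 as `{IC / 4 * FH * FW, OC / 4, 4, 4}` (the strided `A_src` view gathers it from the dense NCHW44 filter), B is the sample's diff viewed as `{OC / 4, IH * IW, 4}`, and C is `{IC / 4 * FH * FW, IH * IW, 4}` (bundle part 0, or `grad` directly when FH == FW == 1 with unit stride and no padding). This is the usual matmul-based deconvolution (gradient columns = filter^T * diff) specialized to 4-channel blocks; `col2img_nchw44` / `col2img_stride_padding_nchw44` then scatter-accumulate C back into the NCHW44 `grad` tensor, which is zeroed first because overlapping filter taps add into the same output positions.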
@@ -198,6 +198,20 @@ public:
     MEGDNN_DECL_ALGO_TYPE(FB_MATMUL)
 };
+class ConvolutionBackwardDataImpl::AlgoMatrixMulNCHW44 final : public AlgoBase {
+public:
+    const char* name() const override { return "DeconvMatmulNchw44"; }
+    bool usable(ConvolutionBackwardDataImpl* opr, const NCBKernSizeParam& param)
+            const override;
+    size_t get_workspace(
+            ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const override;
+    ncb_kern_t dispatch_kern(
+            ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const override;
+    bool is_preferred(const NCBKernSizeParam& param) const override;
+    AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
+    MEGDNN_DECL_ALGO_TYPE(FB_MATMUL_NCHW44)
+};
 } // namespace fallback
 } // namespace megdnn
@@ -1,5 +1,6 @@
 #include <cstddef>
 #include "src/common/utils.h"
+#include "src/fallback/general_intrinsic/gi_float.h"
 namespace {
@@ -61,6 +62,72 @@ void col2img(
         }
     }
 }
+template <bool is_xcorr>
+void col2img_stride_padding_nchw44(
+        const float* __restrict src, float* __restrict dst, const int OH, const int OW,
+        const int IC, const int IH, const int IW, const int FH, const int FW,
+        const int SH, const int SW, int PH, int PW) {
+    size_t i = 0;
+    rep(ic, IC / 4) {
+        rep(fh, FH) {
+            rep(fw, FW) {
+                int fh2, fw2;
+                if (is_xcorr) {
+                    fh2 = fh;
+                    fw2 = fw;
+                } else {
+                    fh2 = FH - fh - 1;
+                    fw2 = FW - fw - 1;
+                }
+                rep(ih, IH) {
+                    int h = ih * SH - PH + fh2;
+                    rep(iw, IW) {
+                        int w = iw * SW - PW + fw2;
+                        if (h >= 0 && h < OH && w >= 0 && w < OW) {
+                            float* dst_ptr = dst + (ic * OH * OW + h * OW + w) * 4;
+                            GI_FLOAT32_t dst_data = GiLoadFloat32(dst_ptr);
+                            GI_FLOAT32_t src_data = GiLoadFloat32(src + i);
+                            GiStoreFloat32(dst_ptr, GiAddFloat32(dst_data, src_data));
+                        }
+                        i += 4;
+                    }
+                }
+            }
+        }
+    }
+}
+template <bool is_xcorr>
+void col2img_nchw44(
+        const float* __restrict src, float* __restrict dst, const int OH, const int OW,
+        const int IC, const int IH, const int IW, const int FH, const int FW) {
+    size_t i = 0;
+    rep(ic, IC / 4) {
+        rep(fh, FH) {
+            rep(fw, FW) {
+                int fh2, fw2;
+                if (is_xcorr) {
+                    fh2 = fh;
+                    fw2 = fw;
+                } else {
+                    fh2 = FH - fh - 1;
+                    fw2 = FW - fw - 1;
+                }
+                rep(ih, IH) {
+                    rep(iw, IW) {
+                        float* dst_ptr = dst + ic * OH * OW * 4 + (ih + fh2) * OW * 4 +
+                                         iw * 4 + fw2 * 4;
+                        GI_FLOAT32_t dst_data = GiLoadFloat32(dst_ptr);
+                        GI_FLOAT32_t src_data = GiLoadFloat32(src + i);
+                        GiStoreFloat32(dst_ptr, GiAddFloat32(dst_data, src_data));
+                        i += 4;
+                    }
+                }
+            }
+        }
+    }
+}
 } // anonymous namespace
 // vim: syntax=cpp.doxygen
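
The two helpers above differ from the existing scalar `col2img` mainly in that each source element is a 4-float channel block, so every `GiLoadFloat32` / `GiAddFloat32` / `GiStoreFloat32` triple is a 4-lane vector accumulate. A standalone scalar sketch of the same accumulation (illustrative only: plain loops instead of the `rep` macro, a runtime flag instead of the `is_xcorr` template parameter, and no GI intrinsics):

```cpp
#include <cstddef>

// Scalar equivalent of col2img_stride_padding_nchw44: src holds the matmul output in
// (IC/4 * FH * FW, IH * IW, 4) order, dst is the NCHW44 grad tensor (IC/4, OH, OW, 4);
// each 4-float block is accumulated into its strided/padded output position.
void col2img_stride_padding_nchw44_scalar(
        const float* src, float* dst, int OH, int OW, int IC, int IH, int IW, int FH,
        int FW, int SH, int SW, int PH, int PW, bool is_xcorr) {
    size_t i = 0;
    for (int ic = 0; ic < IC / 4; ++ic)
        for (int fh = 0; fh < FH; ++fh)
            for (int fw = 0; fw < FW; ++fw) {
                int fh2 = is_xcorr ? fh : FH - fh - 1;
                int fw2 = is_xcorr ? fw : FW - fw - 1;
                for (int ih = 0; ih < IH; ++ih) {
                    int h = ih * SH - PH + fh2;
                    for (int iw = 0; iw < IW; ++iw, i += 4) {
                        int w = iw * SW - PW + fw2;
                        if (h < 0 || h >= OH || w < 0 || w >= OW)
                            continue;  // out-of-bounds taps are skipped, but i still advances
                        float* dst_ptr = dst + (ic * OH * OW + h * OW + w) * 4;
                        for (int lane = 0; lane < 4; ++lane)  // what the GI load/add/store does
                            dst_ptr[lane] += src[i + lane];
                    }
                }
            }
}
```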
@@ -437,11 +437,13 @@ class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
     AlgoNaive algo_naive;
     AlgoDirect algo_direct;
     AlgoMatrixMul algo_matmul;
+    AlgoMatrixMulNCHW44 algo_matmul_nchw44;
     SmallVector<AlgoBase*> m_all_algos;
     AlgoBase::Mapper m_all_algos_map;
 public:
     AlgoPack() {
+        m_all_algos.emplace_back(&algo_matmul_nchw44);
         m_all_algos.emplace_back(&algo_matmul);
         m_all_algos.emplace_back(&algo_direct);
         m_all_algos.emplace_back(&algo_naive);
@@ -557,7 +559,8 @@ ConvolutionBackwardDataImpl::NCBKernSizeParam ConvolutionBackwardDataImpl::
         return v;
     };
     size_t spatial_pos;
-    if (param().format == Param::Format::NCHW) {
+    if (param().format == Param::Format::NCHW ||
+        param().format == Param::Format::NCHW44) {
         spatial_pos = 2;
     } else {
         megdnn_assert(param().format == Param::Format::NHWC, "invalid conv format");
@@ -622,7 +625,8 @@ void ConvolutionBackwardDataImpl::exec_with_ncb_kern(const NCBKernParam& param)
     } else {
         megdnn_assert(
                 p1g.filter_meta.format == Param::Format::NCHW ||
-                        p1g.filter_meta.format == Param::Format::NHWC,
+                        p1g.filter_meta.format == Param::Format::NHWC ||
+                        p1g.filter_meta.format == Param::Format::NCHW44,
                 "invalid conv format");
         auto run = [kptr, p1g_orig = p1g, group]() {
             auto p1g = p1g_orig;
@@ -640,7 +644,8 @@ void ConvolutionBackwardDataImpl::exec_with_ncb_kern(const NCBKernParam& param)
                     p1g.filter_type.size();
             p1g.grad_extra_mem_size =
                     (group - 1) * p1g.filter_meta.icpg * p1g.grad_type.size();
-            if (p1g.filter_meta.format == Param::Format::NCHW) {
+            if (p1g.filter_meta.format == Param::Format::NCHW ||
+                p1g.filter_meta.format == Param::Format::NCHW44) {
                 istrd *= p1g.isz[0] * p1g.isz[1];
                 ostrd *= p1g.osz[0] * p1g.osz[1];
                 p1g.diff_extra_mem_size *= p1g.isz[0] * p1g.isz[1];
@@ -392,6 +392,7 @@ protected:
         FB_NAIVE = 1 << 0,
         FB_DIRECT,
         FB_MATMUL,
+        FB_MATMUL_NCHW44,
 #if MEGDNN_AARCH64 || MEGDNN_ARMV7
         ARM_COMMON_DIRECT_STRD1_DOT_INT8X8X32 = 1 << 8,
@@ -480,6 +481,7 @@ private:
     class AlgoNaive;
     class AlgoDirect;
     class AlgoMatrixMul;
+    class AlgoMatrixMulNCHW44;
     class AlgoPack;
     Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
@@ -463,6 +463,60 @@ TEST_F(FALLBACK, CONVOLUTION_BACKWARD_DATA) {
     }
 }
+TEST_F(FALLBACK, CONVOLUTION_BACKWARD_DATA_NCHW44) {
+    Checker<ConvolutionBackwardData> checker(handle());
+    using Param = ConvolutionBackwardData::Param;
+    Param param;
+    param.format = Param::Format::NCHW44;
+    auto run = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc, size_t fh,
+                   size_t fw, size_t stride, size_t padding, size_t dilate = 1,
+                   size_t group = 1) {
+        param.pad_h = param.pad_w = padding;
+        param.stride_h = param.stride_w = stride;
+        param.dilate_h = param.dilate_w = dilate;
+        TensorLayout diff =
+                TensorLayout{{n, oc / 4 * group, oh, ow, 4}, dtype::Float32()};
+        TensorLayout grad;
+        TensorLayout filter;
+        if (group == 1) {
+            param.sparse = Param::Sparse::DENSE;
+            filter = {{oc / 4, ic / 4, fh, fw, 4, 4}, dtype::Float32()};
+        } else {
+            param.sparse = Param::Sparse::GROUP;
+            filter = {{group, oc / 4, ic / 4, fh, fw, 4, 4}, dtype::Float32()};
+        }
+        {
+            auto opr = handle()->create_operator<ConvolutionBackwardData>();
+            opr->param() = param;
+            opr->deduce_layout(filter, diff, grad);
+        }
+        checker.set_param(param)
+                .set_dtype(0, dtype::Float32())
+                .set_dtype(1, dtype::Float32());
+        checker.exec(TensorLayoutArray{filter, diff, grad});
+    };
+    for (auto mode : {Param::Mode::CONVOLUTION, Param::Mode::CROSS_CORRELATION}) {
+        param.mode = mode;
+        run(1, 4, 2, 2, 4, 1, 1, 1, 0, 1, 1);
+        run(1, 4, 2, 2, 4, 3, 3, 1, 0, 1, 1);
+        run(1, 4, 2, 2, 4, 3, 3, 1, 1, 1, 1);
+        run(4, 16, 10, 13, 16, 1, 1, 1, 0, 1, 1);
+        run(4, 16, 10, 13, 16, 3, 3, 1, 0, 1, 1);
+        run(4, 16, 10, 13, 16, 3, 3, 1, 1, 1, 1);
+        run(4, 32, 11, 23, 32, 1, 1, 1, 0, 1, 4);
+        run(4, 16, 11, 23, 8, 3, 3, 1, 0, 1, 4);
+        run(4, 16, 11, 23, 8, 3, 3, 1, 1, 1, 4);
+        run(4, 16, 11, 23, 8, 3, 3, 2, 1, 1, 4);
+    }
+}
 TEST_F(FALLBACK, CONVOLUTION_BACKWARD_DATA_RECORD) {
     TaskRecordChecker<ConvolutionBackwardData> checker(1);
     using Param = ConvolutionBackwardData::Param;
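
The new test covers both DENSE (6-dim filter) and GROUP (7-dim filter) NCHW44 cases in both CONVOLUTION and CROSS_CORRELATION modes, using `deduce_layout` to fill in `grad` before handing the layouts to the checker. The grouped runs do not satisfy `AlgoMatrixMulNCHW44::usable` (which requires `group == 1`); they appear to rely on the per-group decomposition in `exec_with_ncb_kern`, which is why that path now also accepts `Format::NCHW44`.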
@@ -707,4 +761,73 @@ TEST_F(FALLBACK, CONVOLUTION_BACKWARD_DATA_NAIVE_ALGO) {
     }
 }
+#if MEGDNN_WITH_BENCHMARK
+TEST_F(FALLBACK, BENCHMARK_CONVOLUTION_BACKWARD_DATA_NCHW44) {
+    using Param = ConvolutionBackwardData::Param;
+    auto run = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc, size_t fh,
+                   size_t fw, size_t stride, size_t padding, size_t dilate = 1,
+                   size_t group = 1) {
+        Param param;
+        param.pad_h = param.pad_w = padding;
+        param.stride_h = param.stride_w = stride;
+        param.dilate_h = param.dilate_w = dilate;
+        TensorLayout diff_nchw44 =
+                TensorLayout{{n, oc / 4 * group, oh, ow, 4}, dtype::Float32()};
+        TensorLayout diff = TensorLayout{{n, oc * group, oh, ow}, dtype::Float32()};
+        TensorLayout grad;
+        TensorLayout grad_nchw44;
+        TensorLayout filter;
+        TensorLayout filter_nchw44;
+        if (group == 1) {
+            param.sparse = Param::Sparse::DENSE;
+            filter_nchw44 = {{oc / 4, ic / 4, fh, fw, 4, 4}, dtype::Float32()};
+            filter = {{oc, ic, fh, fw}, dtype::Float32()};
+        } else {
+            param.sparse = Param::Sparse::GROUP;
+            filter_nchw44 = {{group, oc / 4, ic / 4, fh, fw, 4, 4}, dtype::Float32()};
+            filter = {{group, oc, ic, fh, fw}, dtype::Float32()};
+        }
+        {
+            auto opr = handle()->create_operator<ConvolutionBackwardData>();
+            opr->param() = param;
+            opr->deduce_layout(filter, diff, grad);
+            opr->param().format = Param::Format::NCHW44;
+            opr->deduce_layout(filter_nchw44, diff_nchw44, grad_nchw44);
+        }
+        Benchmarker<ConvolutionBackwardData> benchmarker_fallback(handle());
+        size_t RUN = 50;
+        benchmarker_fallback.set_display(false)
+                .set_dtype(0, dtype::Float32{})
+                .set_dtype(1, dtype::Float32{})
+                .set_dtype(2, dtype::Float32{})
+                .set_times(RUN);
+        auto tnchw =
+                benchmarker_fallback.set_param(param)
+                        .exec(TensorLayoutArray{filter, diff, grad});
+        param.format = Param::Format::NCHW44;
+        auto tnchw44 =
+                benchmarker_fallback.set_param(param)
+                        .exec(TensorLayoutArray{filter_nchw44, diff_nchw44, grad_nchw44});
+        size_t IC = ic;
+        size_t FH = fh;
+        size_t FW = fw;
+        size_t total_flops = IC * diff.total_nr_elems() * FH * FW * 2;
+        printf("nchw_time: %.3f ms nchw_flops: %.3f Gflops\n", tnchw,
+               total_flops / (tnchw / RUN * 1e6));
+        printf("nchw44_time: %.3f ms nchw44_flops: %.3f Gflops\n", tnchw44,
+               total_flops / (tnchw44 / RUN * 1e6));
+        printf("speedup: %.3f\n", tnchw / tnchw44);
+    };
+    run(1, 16, 14, 14, 16, 3, 3, 1, 1, 1, 1);
+    run(1, 32, 28, 28, 16, 3, 3, 1, 1, 1, 1);
+    run(1, 48, 28, 28, 48, 2, 2, 1, 0, 1, 1);
+    run(1, 32, 26, 26, 32, 3, 3, 1, 0, 1, 1);
+    run(2, 32, 64, 64, 32, 3, 3, 1, 0, 1, 1);
+    run(2, 16, 112, 112, 16, 3, 3, 1, 0, 1, 1);
+}
+#endif
 // vim: syntax=cpp.doxygen
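
In the benchmark, `exec` is treated as returning the total time in milliseconds over the `RUN` iterations (hence the division by `RUN`), so the printed throughput is `total_flops / ((t / RUN) * 1e6)` GFLOPS with `total_flops = 2 * IC * FH * FW * diff.total_nr_elems()` (one multiply and one add per filter tap and per output-gradient element); the `speedup` line is simply the NCHW time divided by the NCHW44 time.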