| @@ -21,6 +21,7 @@ using namespace megdnn; | |||||
| using namespace fallback; | using namespace fallback; | ||||
| MIDOUT_DECL(megdnn_fallback_conv) | MIDOUT_DECL(megdnn_fallback_conv) | ||||
| MIDOUT_DECL(megdnn_fallback_deconv) | |||||
| namespace { | namespace { | ||||
| @@ -459,6 +460,70 @@ SmallVector<ConvolutionImpl::NCBKern> ConvolutionImpl::AlgoDefault::get_kimpl( | |||||
| MIDOUT_END(); | MIDOUT_END(); | ||||
| } | } | ||||
| /////////////////////////// ConvolutionBackwardData ///////////////////// | |||||
| /* ===================== naive algo ===================== */ | |||||
| bool ConvolutionBackwardDataImpl::AlgoNaive::usable( | |||||
| ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | |||||
| bool ret = false; | |||||
| #define cb(dt) ret |= (param.diff_type.enumv() == DTypeTrait<dt>::enumv); | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | |||||
| #undef cb | |||||
| #define cb(dt_src, dt_dst) \ | |||||
| ret |= (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv && \ | |||||
| param.filter_type.enumv() == DTypeTrait<dt_src>::enumv && \ | |||||
| param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) | |||||
| cb(dtype::Int8, dtype::Int32); | |||||
| cb(dtype::Quantized8Asymm, dtype::QuantizedS32); | |||||
| cb(dtype::QuantizedS8, dtype::QuantizedS32); | |||||
| #undef cb | |||||
| return ret; | |||||
| } | |||||
| size_t ConvolutionBackwardDataImpl::AlgoNaive::get_workspace( | |||||
| ConvolutionBackwardDataImpl*, const NCBKernSizeParam&) const { | |||||
| return 0; | |||||
| } | |||||
| ConvolutionBackwardDataImpl::ncb_kern_t | |||||
| ConvolutionBackwardDataImpl::AlgoNaive::dispatch_kern( | |||||
| ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | |||||
| #define cb(_dt) \ | |||||
| do { \ | |||||
| if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||||
| MIDOUT_BEGIN(megdnn_fallback_deconv, \ | |||||
| midout_iv(DTypeTrait<_dt>::enumv)) { \ | |||||
| using ctype = DTypeTrait<_dt>::ctype; \ | |||||
| return kern_naive<ctype, ctype, ctype>; \ | |||||
| } \ | |||||
| MIDOUT_END(); \ | |||||
| } \ | |||||
| } while (0); | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | |||||
| #undef cb | |||||
| #define cb(dt_src, dt_dst) \ | |||||
| do { \ | |||||
| if (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv && \ | |||||
| param.filter_type.enumv() == DTypeTrait<dt_src>::enumv && \ | |||||
| param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) { \ | |||||
| MIDOUT_BEGIN(megdnn_fallback_deconv, \ | |||||
| midout_iv(DTypeTrait<_dt>::enumv)) { \ | |||||
| return kern_naive<DTypeTrait<dt_src>::ctype, \ | |||||
| DTypeTrait<dt_src>::ctype, \ | |||||
| DTypeTrait<dt_dst>::ctype>; \ | |||||
| } \ | |||||
| MIDOUT_END(); \ | |||||
| } \ | |||||
| } while (0) | |||||
| cb(dtype::Int8, dtype::Int32); | |||||
| cb(dtype::Quantized8Asymm, dtype::QuantizedS32); | |||||
| cb(dtype::QuantizedS8, dtype::QuantizedS32); | |||||
| megdnn_throw("unsupported data type on ConvolutionBackwardData"); | |||||
| #undef cb | |||||
| } | |||||
| /* ===================== direct algo ===================== */ | /* ===================== direct algo ===================== */ | ||||
| bool ConvolutionBackwardDataImpl::AlgoDirect::usable( | bool ConvolutionBackwardDataImpl::AlgoDirect::usable( | ||||
| @@ -474,7 +539,7 @@ bool ConvolutionBackwardDataImpl::AlgoDirect::usable( | |||||
| size_t ConvolutionBackwardDataImpl::AlgoDirect::get_workspace( | size_t ConvolutionBackwardDataImpl::AlgoDirect::get_workspace( | ||||
| ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ||||
| MIDOUT_BEGIN(megdnn_fallback_conv, | |||||
| MIDOUT_BEGIN(megdnn_fallback_deconv, | |||||
| midout_iv("AlgoDirect::get_workspace"_hash)) { | midout_iv("AlgoDirect::get_workspace"_hash)) { | ||||
| auto FH = param.filter_meta.spatial[0], | auto FH = param.filter_meta.spatial[0], | ||||
| FW = param.filter_meta.spatial[1]; | FW = param.filter_meta.spatial[1]; | ||||
| @@ -511,7 +576,7 @@ bool ConvolutionBackwardDataImpl::AlgoMatrixMul::usable( | |||||
| size_t ConvolutionBackwardDataImpl::AlgoMatrixMul::get_workspace( | size_t ConvolutionBackwardDataImpl::AlgoMatrixMul::get_workspace( | ||||
| ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ||||
| MIDOUT_BEGIN(megdnn_fallback_conv, | |||||
| MIDOUT_BEGIN(megdnn_fallback_deconv, | |||||
| midout_iv("AlgoMatrixMul::get_workspace"_hash)) { | midout_iv("AlgoMatrixMul::get_workspace"_hash)) { | ||||
| return get_bundle(param).total_size_in_bytes(); | return get_bundle(param).total_size_in_bytes(); | ||||
| } | } | ||||
| @@ -522,33 +587,33 @@ size_t ConvolutionBackwardDataImpl::AlgoMatrixMul::get_workspace( | |||||
| ConvolutionBackwardDataImpl::ncb_kern_t | ConvolutionBackwardDataImpl::ncb_kern_t | ||||
| ConvolutionBackwardDataImpl::AlgoMatrixMul::dispatch_kern( | ConvolutionBackwardDataImpl::AlgoMatrixMul::dispatch_kern( | ||||
| ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { | ||||
| #define cb(dt, midout_tag) \ | |||||
| do { \ | |||||
| if (param.filter_type.enumv() == DTypeTrait<dt>::enumv) { \ | |||||
| MIDOUT_BEGIN(megdnn_fallback_conv, midout_iv(midout_tag)) { \ | |||||
| using ctype = DTypeTrait<dt>::ctype; \ | |||||
| return kern_matmul<ctype, ctype, ctype>; \ | |||||
| } \ | |||||
| MIDOUT_END(); \ | |||||
| } \ | |||||
| #define cb(dt, midout_tag) \ | |||||
| do { \ | |||||
| if (param.filter_type.enumv() == DTypeTrait<dt>::enumv) { \ | |||||
| MIDOUT_BEGIN(megdnn_fallback_deconv, midout_iv(midout_tag)) { \ | |||||
| using ctype = DTypeTrait<dt>::ctype; \ | |||||
| return kern_matmul<ctype, ctype, ctype>; \ | |||||
| } \ | |||||
| MIDOUT_END(); \ | |||||
| } \ | |||||
| } while (0); | } while (0); | ||||
| cb(dtype::Float32, "FLOAT"_hash); | cb(dtype::Float32, "FLOAT"_hash); | ||||
| MEGDNN_INC_FLOAT16(cb(dtype::Float16, "FLOAT16"_hash)); | MEGDNN_INC_FLOAT16(cb(dtype::Float16, "FLOAT16"_hash)); | ||||
| MEGDNN_INC_FLOAT16(cb(dtype::BFloat16, "BFLOAT16"_hash)); | MEGDNN_INC_FLOAT16(cb(dtype::BFloat16, "BFLOAT16"_hash)); | ||||
| #undef cb | #undef cb | ||||
| #define cb(dt_src, dt_dst, midout_tag) \ | |||||
| do { \ | |||||
| if (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv && \ | |||||
| param.filter_type.enumv() == DTypeTrait<dt_src>::enumv && \ | |||||
| param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) { \ | |||||
| MIDOUT_BEGIN(megdnn_fallback_conv, midout_iv(midout_tag)) { \ | |||||
| return kern_matmul<DTypeTrait<dt_src>::ctype, \ | |||||
| DTypeTrait<dt_src>::ctype, \ | |||||
| DTypeTrait<dt_dst>::ctype>; \ | |||||
| } \ | |||||
| MIDOUT_END(); \ | |||||
| } \ | |||||
| #define cb(dt_src, dt_dst, midout_tag) \ | |||||
| do { \ | |||||
| if (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv && \ | |||||
| param.filter_type.enumv() == DTypeTrait<dt_src>::enumv && \ | |||||
| param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) { \ | |||||
| MIDOUT_BEGIN(megdnn_fallback_deconv, midout_iv(midout_tag)) { \ | |||||
| return kern_matmul<DTypeTrait<dt_src>::ctype, \ | |||||
| DTypeTrait<dt_src>::ctype, \ | |||||
| DTypeTrait<dt_dst>::ctype>; \ | |||||
| } \ | |||||
| MIDOUT_END(); \ | |||||
| } \ | |||||
| } while (0) | } while (0) | ||||
| cb(dtype::Int8, dtype::Int32, "INT8x8x32"_hash); | cb(dtype::Int8, dtype::Int32, "INT8x8x32"_hash); | ||||
| cb(dtype::QuantizedS8, dtype::QuantizedS32, "QINT8x8x32"_hash); | cb(dtype::QuantizedS8, dtype::QuantizedS32, "QINT8x8x32"_hash); | ||||
| @@ -557,4 +622,9 @@ ConvolutionBackwardDataImpl::AlgoMatrixMul::dispatch_kern( | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| bool ConvolutionBackwardDataImpl::AlgoMatrixMul::is_preferred( | |||||
| const NCBKernSizeParam& param) const { | |||||
| return is_matrix_mul_preferred(param); | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -156,6 +156,20 @@ private: | |||||
| ConvBiasImpl::AlgoBase* m_algorithm; | ConvBiasImpl::AlgoBase* m_algorithm; | ||||
| }; | }; | ||||
| ////////////////////////// convolutionbackwarddata //////////////////////// | |||||
| class ConvolutionBackwardDataImpl::AlgoNaive final : public AlgoBase { | |||||
| public: | |||||
| bool is_reproducible() const override { return true; } | |||||
| const char* name() const override { return "DeconvNaive"; } | |||||
| bool usable(ConvolutionBackwardDataImpl* opr, | |||||
| const NCBKernSizeParam& param) const override; | |||||
| size_t get_workspace(ConvolutionBackwardDataImpl*, | |||||
| const NCBKernSizeParam& param) const override; | |||||
| ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, | |||||
| const NCBKernSizeParam&) const override; | |||||
| bool is_naive() const override { return true; } | |||||
| }; | |||||
| class ConvolutionBackwardDataImpl::AlgoDirect final : public AlgoBase { | class ConvolutionBackwardDataImpl::AlgoDirect final : public AlgoBase { | ||||
| public: | public: | ||||
| bool is_reproducible() const override { return true; } | bool is_reproducible() const override { return true; } | ||||
| @@ -178,6 +192,7 @@ public: | |||||
| const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
| ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, | ncb_kern_t dispatch_kern(ConvolutionBackwardDataImpl*, | ||||
| const NCBKernSizeParam&) const override; | const NCBKernSizeParam&) const override; | ||||
| bool is_preferred(const NCBKernSizeParam& param) const override; | |||||
| }; | }; | ||||
| } // namespace fallback | } // namespace fallback | ||||
| @@ -31,12 +31,6 @@ using namespace megdnn; | |||||
| using namespace fallback; | using namespace fallback; | ||||
| namespace { | namespace { | ||||
| class NaiveConvolutionBackwardData final | |||||
| : public megdnn::ConvolutionBackwardData::Algorithm { | |||||
| bool is_reproducible() const override { return true; } | |||||
| const char* name() const override { return "NCBD"; } | |||||
| }; | |||||
| NaiveConvolutionBackwardData naive_conv_backward_data; | |||||
| template <typename T> | template <typename T> | ||||
| void incr_ptr(T*& dst, ptrdiff_t delta) { | void incr_ptr(T*& dst, ptrdiff_t delta) { | ||||
| @@ -407,11 +401,25 @@ ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const { | |||||
| /* ===================== ConvolutionBackwardData ===================== */ | /* ===================== ConvolutionBackwardData ===================== */ | ||||
| struct ConvolutionBackwardDataImpl::AlgoPack { | |||||
| AlgoDirect direct; | |||||
| AlgoMatrixMul matmul; | |||||
| class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { | |||||
| AlgoNaive algo_naive; | |||||
| AlgoDirect algo_direct; | |||||
| AlgoMatrixMul algo_matmul; | |||||
| public: | |||||
| AlgoPack() { | |||||
| all_algos.emplace_back(&algo_matmul); | |||||
| all_algos.emplace_back(&algo_direct); | |||||
| all_algos.emplace_back(&algo_naive); | |||||
| } | |||||
| SmallVector<AlgoBase*> all_algos; | |||||
| }; | }; | ||||
| ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack; | |||||
| SmallVector<ConvolutionBackwardDataImpl::AlgoBase*> | |||||
| ConvolutionBackwardDataImpl::algo_pack() { | |||||
| static AlgoPack sl_algo_pack; | |||||
| return sl_algo_pack.all_algos; | |||||
| } | |||||
| void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter, | void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter, | ||||
| _megdnn_tensor_in diff, | _megdnn_tensor_in diff, | ||||
| @@ -539,7 +547,7 @@ void ConvolutionBackwardDataImpl::exec_with_ncb_kern( | |||||
| p1g.filter_meta.group = 1; | p1g.filter_meta.group = 1; | ||||
| auto algo = get_algorithm(p1g); | auto algo = get_algorithm(p1g); | ||||
| auto kptr = ncb_1g_dispatch_kern(algo, p1g); | auto kptr = ncb_1g_dispatch_kern(algo, p1g); | ||||
| if (algo == &naive_conv_backward_data || group == 1) { | |||||
| if (group == 1 || static_cast<AlgoBase*>(algo)->is_naive()) { | |||||
| auto run = [kptr, param]() { kptr(param); }; | auto run = [kptr, param]() { kptr(param); }; | ||||
| static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run); | static_cast<naive::HandleImpl*>(handle())->dispatch_kern(run); | ||||
| } else { | } else { | ||||
| @@ -625,7 +633,6 @@ size_t ConvolutionBackwardDataImpl::ncb_1g_get_workspace( | |||||
| if (algo->handle_type() == Handle::HandleType::FALLBACK) { | if (algo->handle_type() == Handle::HandleType::FALLBACK) { | ||||
| return static_cast<AlgoBase*>(algo)->get_workspace(this, param); | return static_cast<AlgoBase*>(algo)->get_workspace(this, param); | ||||
| } | } | ||||
| megdnn_assert(algo == &naive_conv_backward_data); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| @@ -638,36 +645,6 @@ ConvolutionBackwardDataImpl::ncb_1g_dispatch_kern( | |||||
| return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param); | return static_cast<AlgoBase*>(algo)->dispatch_kern(this, param); | ||||
| } | } | ||||
| if (algo == &naive_conv_backward_data) { | |||||
| #define cb(_dt) \ | |||||
| do { \ | |||||
| if (param.filter_type.enumv() == DTypeTrait<_dt>::enumv) { \ | |||||
| MIDOUT_BEGIN(megdnn_fb_convbwd_float, \ | |||||
| midout_iv(DTypeTrait<_dt>::enumv)) { \ | |||||
| using ctype = DTypeTrait<_dt>::ctype; \ | |||||
| return kern_naive<ctype, ctype, ctype>; \ | |||||
| } \ | |||||
| MIDOUT_END(); \ | |||||
| } \ | |||||
| } while (0); | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb); | |||||
| #undef cb | |||||
| #define cb(dt_src, dt_dst) \ | |||||
| do { \ | |||||
| if (param.diff_type.enumv() == DTypeTrait<dt_src>::enumv && \ | |||||
| param.filter_type.enumv() == DTypeTrait<dt_src>::enumv && \ | |||||
| param.grad_type.enumv() == DTypeTrait<dt_dst>::enumv) { \ | |||||
| return kern_naive<DTypeTrait<dt_src>::ctype, \ | |||||
| DTypeTrait<dt_src>::ctype, \ | |||||
| DTypeTrait<dt_dst>::ctype>; \ | |||||
| } \ | |||||
| } while (0); | |||||
| cb(dtype::Int8, dtype::Int32) cb(dtype::Quantized8Asymm, | |||||
| dtype::QuantizedS32) | |||||
| cb(dtype::QuantizedS8, dtype::QuantizedS32) megdnn_throw( | |||||
| "unsupported data type on ConvolutionBackwardData"); | |||||
| #undef cb | |||||
| } | |||||
| megdnn_throw( | megdnn_throw( | ||||
| megdnn_mangle("no suitable ConvolutionBackwardData algorithm")); | megdnn_mangle("no suitable ConvolutionBackwardData algorithm")); | ||||
| } | } | ||||
| @@ -686,34 +663,17 @@ std::vector<ConvolutionBackwardDataImpl::Algorithm*> | |||||
| ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms( | ConvolutionBackwardDataImpl::ncb_1g_get_all_algorithms( | ||||
| const NCBKernSizeParam& param) { | const NCBKernSizeParam& param) { | ||||
| std::vector<Algorithm*> ret; | std::vector<Algorithm*> ret; | ||||
| ret.reserve(2); | |||||
| ret.push_back(&naive_conv_backward_data); | |||||
| // insert from lowest to highest preference | |||||
| AlgoBase* cand[2] = {nullptr}; | |||||
| if (param.filter_meta.group == 1 && param.filter_meta.dilation[0] == 1 && | |||||
| param.filter_meta.dilation[1] == 1) { | |||||
| // we currently only have non-dilated algos | |||||
| if (param.filter_type.enumv() == DTypeEnum::Float32) { | |||||
| if (is_matrix_mul_preferred(param)) { | |||||
| cand[0] = &sm_algo_pack.direct; | |||||
| cand[1] = &sm_algo_pack.matmul; | |||||
| std::vector<Algorithm*> prefer_algos; | |||||
| for (auto&& i : algo_pack()) { | |||||
| if (i->usable(this, param)) { | |||||
| if (i->is_preferred(param)) { | |||||
| prefer_algos.push_back(i); | |||||
| } else { | } else { | ||||
| cand[0] = &sm_algo_pack.matmul; | |||||
| cand[1] = &sm_algo_pack.direct; | |||||
| ret.push_back(i); | |||||
| } | } | ||||
| } else { | |||||
| cand[0] = &sm_algo_pack.matmul; | |||||
| } | |||||
| } | |||||
| for (auto i : cand) { | |||||
| if (i && i->usable(this, param)) { | |||||
| ret.push_back(i); | |||||
| } | } | ||||
| } | } | ||||
| std::reverse(ret.begin(), ret.end()); | |||||
| ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end()); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -373,7 +373,7 @@ public: | |||||
| }; | }; | ||||
| protected: | protected: | ||||
| typedef void (*ncb_kern_t)(const NCBKernParam& param); | |||||
| using ncb_kern_t = thin_function<void(const NCBKernParam& param)>; | |||||
| //! default impl calls ncb_1g_dispatch_kern() | //! default impl calls ncb_1g_dispatch_kern() | ||||
| virtual void exec_with_ncb_kern(const NCBKernParam& param); | virtual void exec_with_ncb_kern(const NCBKernParam& param); | ||||
| @@ -428,9 +428,18 @@ protected: | |||||
| bool reproducible = true) const { | bool reproducible = true) const { | ||||
| return (!reproducible || is_reproducible()) && usable(opr, param); | return (!reproducible || is_reproducible()) && usable(opr, param); | ||||
| } | } | ||||
| virtual bool is_preferred(const NCBKernSizeParam&) const { | |||||
| return false; | |||||
| } | |||||
| //! if the algo is naive, it will not split by group | |||||
| virtual bool is_naive() const { return false; } | |||||
| }; | }; | ||||
| static bool is_matrix_mul_preferred(const NCBKernSizeParam& param); | static bool is_matrix_mul_preferred(const NCBKernSizeParam& param); | ||||
| /** | |||||
| * \brief get all the algorithm for the opr. | |||||
| */ | |||||
| virtual SmallVector<AlgoBase*> algo_pack(); | |||||
| private: | private: | ||||
| NCBKernSizeParam m_prev_selected_algo_sizep; | NCBKernSizeParam m_prev_selected_algo_sizep; | ||||
| @@ -448,11 +457,10 @@ private: | |||||
| _megdnn_tensor_out grad, | _megdnn_tensor_out grad, | ||||
| _megdnn_workspace workspace); | _megdnn_workspace workspace); | ||||
| class AlgoNaive; | |||||
| class AlgoDirect; | class AlgoDirect; | ||||
| class AlgoMatrixMul; | class AlgoMatrixMul; | ||||
| struct AlgoPack; | |||||
| static AlgoPack sm_algo_pack; | |||||
| class AlgoPack; | |||||
| }; | }; | ||||
| } // namespace fallback | } // namespace fallback | ||||
| @@ -9,6 +9,7 @@ | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "megdnn/dtype.h" | |||||
| #include "test/fallback/fixture.h" | #include "test/fallback/fixture.h" | ||||
| #include "test/common/benchmarker.h" | #include "test/common/benchmarker.h" | ||||
| @@ -614,4 +615,53 @@ TEST_F(FALLBACK, CONVOLUTION_BACKWARD_DATA_QUINT8) { | |||||
| } | } | ||||
| } | } | ||||
| TEST_F(FALLBACK, CONVOLUTION_BACKWARD_DATA_NAIVE_ALGO) { | |||||
| Checker<ConvolutionBackwardData> checker(handle()); | |||||
| checker.set_before_exec_callback( | |||||
| AlgoChecker<ConvolutionBackwardData>("DeconvNaive")); | |||||
| using Param = ConvolutionBackwardData::Param; | |||||
| Param param; | |||||
| auto run = [&](size_t n, size_t ic, size_t oh, size_t ow, size_t oc, | |||||
| size_t fh, size_t fw, size_t stride, size_t padding, | |||||
| size_t dilate = 1, size_t group = 1) { | |||||
| param.pad_h = param.pad_w = padding; | |||||
| param.stride_h = param.stride_w = stride; | |||||
| param.dilate_h = param.dilate_w = dilate; | |||||
| TensorLayout diff = | |||||
| TensorLayout{{n, oc * group, oh, ow}, dtype::Float32()}; | |||||
| TensorLayout grad; | |||||
| TensorLayout filter; | |||||
| if (group == 1) { | |||||
| param.sparse = Param::Sparse::DENSE; | |||||
| filter = {{oc, ic, fh, fw}, dtype::Float32()}; | |||||
| } else { | |||||
| param.sparse = Param::Sparse::GROUP; | |||||
| filter = {{group, oc, ic, fh, fw}, dtype::Float32()}; | |||||
| } | |||||
| // TensorLayout grad; | |||||
| { | |||||
| auto opr = handle()->create_operator<ConvolutionBackwardData>(); | |||||
| opr->param() = param; | |||||
| opr->deduce_layout(filter, diff, grad); | |||||
| } | |||||
| checker.set_param(param); | |||||
| checker.exec(TensorLayoutArray{filter, diff, grad}); | |||||
| }; | |||||
| for (auto mode : | |||||
| {Param::Mode::CONVOLUTION, Param::Mode::CROSS_CORRELATION}) { | |||||
| param.mode = mode; | |||||
| run(4, 3, 10, 13, 5, 1, 1, 1, 0, 1, 1); | |||||
| run(5, 5, 24, 43, 11, 9, 3, 3, 12, 1, 2); | |||||
| run(4, 3, 10, 45, 2, 1, 1, 1, 0, 4, 3); | |||||
| run(2, 3, 9, 12, 2, 4, 6, 1, 0, 1, 2); | |||||
| run(3, 4, 17, 32, 2, 3, 2, 5, 4, 4, 3); | |||||
| run(5, 5, 24, 43, 11, 9, 3, 3, 12, 2, 2); | |||||
| run(2, 3, 20, 33, 3, 5, 7, 4, 15, 2, 3); | |||||
| run(4, 4, 6, 7, 9, 3, 2, 2, 1, 3, 2); | |||||
| } | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||