GitOrigin-RevId: f65feae5cc
tags/v1.7.0
@@ -588,7 +588,7 @@ if(MGE_WITH_CUDA)
     set(CMAKE_CUDA_FLAGS_MINSIZEREL "-Os")
     if(MSVC OR WIN32)
       set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xfatbin -compress-all")
-      set(CCBIN_FLAG "${CCBIN_FLAG} /wd4819 /wd4334 /wd4267 /wd4002 /wd4244 /wd4068 /std:c++14")
+      set(CCBIN_FLAG "${CCBIN_FLAG} /wd4819 /wd4334 /wd4267 /wd4002 /wd4244 /wd4068 /std:c++14 /bigobj")
       if(${CMAKE_BUILD_TYPE} STREQUAL "Debug")
         set(CCBIN_FLAG "${CCBIN_FLAG} -D_ITERATOR_DEBUG_LEVEL=2 -MTd")
       endif()
@@ -365,27 +365,22 @@ void aarch64::RelayoutForwardImpl::exec(
     relayout::TransposeParam trans_param;
     bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param, true);
     if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) {
-        auto sptr = static_cast<TransposeByte*>(src.raw_ptr),
-             dptr = static_cast<TransposeByte*>(dst.raw_ptr);
         MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<TransposeByte>(
-                trans_param.batch, trans_param.m, trans_param.n, sptr, dptr,
-                trans_param.stride_m));
+                trans_param.batch, trans_param.m, trans_param.n,
+                static_cast<TransposeByte*>(src.raw_ptr()),
+                static_cast<TransposeByte*>(dst.raw_ptr()), trans_param.stride_m));
         return;
     } else if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 2) {
-        auto sptr = static_cast<Transpose2Byte*>(src.raw_ptr),
-             dptr = static_cast<Transpose2Byte*>(dst.raw_ptr);
         MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<Transpose2Byte>(
-                trans_param.batch, trans_param.m, trans_param.n, sptr, dptr,
-                trans_param.stride_m));
+                trans_param.batch, trans_param.m, trans_param.n,
+                static_cast<Transpose2Byte*>(src.raw_ptr()),
+                static_cast<Transpose2Byte*>(dst.raw_ptr()), trans_param.stride_m));
         return;
     } else if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 4) {
-        auto sptr = static_cast<Transpose4Byte*>(src.raw_ptr),
-             dptr = static_cast<Transpose4Byte*>(dst.raw_ptr);
         MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<Transpose4Byte>(
-                trans_param.batch, trans_param.m, trans_param.n, sptr, dptr,
-                trans_param.stride_m));
+                trans_param.batch, trans_param.m, trans_param.n,
+                static_cast<Transpose4Byte*>(src.raw_ptr()),
+                static_cast<Transpose4Byte*>(dst.raw_ptr()), trans_param.stride_m));
         return;
     }
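The substance of this hunk, and of most of the hunks below, is one mechanical substitution: `TensorND::raw_ptr` is read through a `raw_ptr()` accessor instead of as a bare `void*` member, and the casts move into the dispatched call rather than into `sptr`/`dptr` locals computed before the kernel is queued. The apparent motivation is that an accessor can resolve the address when the queued kernel runs instead of freezing it at record time. A minimal sketch of the accessor shape being assumed here (illustrative stand-ins, not megdnn's real declarations):

#include <memory>

// Hypothetical stand-in: the address sits behind one extra indirection, so
// raw_ptr() yields the current address at call time instead of a value that
// was frozen when the kernel was recorded.
struct TensorNDSketch {
    std::shared_ptr<void*> storage;
    void* raw_ptr() const { return *storage; }  // accessor replaces the old field
};

// Call-site shape after the change: no sptr/dptr locals, the cast is inlined.
inline unsigned char* as_byte_ptr(const TensorNDSketch& t) {
    return static_cast<unsigned char*>(t.raw_ptr());
}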
@@ -358,11 +358,13 @@ void RotateImpl::exec(
         return fallback::RotateImpl::exec(src, dst, workspace);
     }
+    auto clockwise = param().clockwise;
     MEGDNN_DISPATCH_CPU_KERN_OPR({
         for (size_t i = 0; i < src.layout.shape[0]; ++i) {
             Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i);
             Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i);
-            rotate(src_mat, dst_mat, param().clockwise);
+            rotate(src_mat, dst_mat, clockwise);
         }
     });
 }
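Hoisting `param().clockwise` into a local before the dispatched block means the queued kernel only uses the copied flag and no longer has to read the operator's state through `this` when it eventually runs. A minimal sketch of the pattern with a toy dispatcher standing in for the kernel-dispatch macro (all names here are placeholders):

#include <functional>
#include <queue>

struct Param { bool clockwise; };

class RotateOpSketch {
public:
    // Toy stand-in for the (possibly deferred) CPU kernel dispatcher.
    std::queue<std::function<void()>>* dispatcher = nullptr;
    Param& param() { return param_; }

    void exec_old() {
        // Risky shape: the queued task reads param() off `this` when it runs,
        // so it depends on the operator still being alive and unchanged.
        dispatcher->push([this] { run(param().clockwise); });
    }

    void exec_new() {
        // Safer shape (what the diff does): read the flag now, capture the copy.
        auto clockwise = param().clockwise;
        dispatcher->push([clockwise] { run(clockwise); });
    }

private:
    static void run(bool /*clockwise*/) { /* rotate each image here */ }
    Param param_{true};
};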
@@ -205,16 +205,16 @@ void megdnn::aarch64::warp_perspective_cv_exec(
     megdnn_assert(
             ch == 1 || ch == 3 || ch == 2,
             "unsupported src channel: %zu, avaiable channel size: 1/2/3", ch);
-    const float* trans_ptr = trans.ptr<dt_float32>();
-    const int* midx_ptr = nullptr;
-    if (mat_idx.raw_ptr) {
-        megdnn_assert(mat_idx.layout.ndim == 1);
-        midx_ptr = mat_idx.ptr<int>();
-    }
     if (dst.layout.dtype.enumv() == DTypeEnum::Float32) {
 #define cb(_imode, _bmode, _ch) \
-    auto task = [src, trans_ptr, midx_ptr, dst, border_value, parallelism_batch]( \
+    auto task = [src, trans, mat_idx, dst, border_value, parallelism_batch]( \
                         size_t index, size_t) { \
+        const float* trans_ptr = trans.ptr<dt_float32>(); \
+        const int* midx_ptr = nullptr; \
+        if (mat_idx.raw_ptr()) { \
+            megdnn_assert(mat_idx.layout.ndim == 1); \
+            midx_ptr = mat_idx.ptr<int>(); \
+        } \
         size_t batch_id = index / parallelism_batch; \
         size_t task_id = index % parallelism_batch; \
         size_t src_id = batch_id; \
@@ -240,8 +240,14 @@ void megdnn::aarch64::warp_perspective_cv_exec(
 #undef cb
     } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) {
 #define cb(_imode, _bmode, _ch) \
-    auto task = [src, trans_ptr, midx_ptr, dst, border_value, parallelism_batch]( \
+    auto task = [src, trans, mat_idx, dst, border_value, parallelism_batch]( \
                         size_t index, size_t) { \
+        const float* trans_ptr = trans.ptr<dt_float32>(); \
+        const int* midx_ptr = nullptr; \
+        if (mat_idx.raw_ptr()) { \
+            megdnn_assert(mat_idx.layout.ndim == 1); \
+            midx_ptr = mat_idx.ptr<int>(); \
+        } \
         size_t batch_id = index / parallelism_batch; \
         size_t task_id = index % parallelism_batch; \
         size_t src_id = batch_id; \
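Both `cb` expansions now capture the `trans` and `mat_idx` tensors by value and derive `trans_ptr`/`midx_ptr` inside the task body, so each worker resolves the addresses at execution time rather than inheriting pointers computed on the dispatching thread. A condensed sketch of that capture pattern (toy tensor and task types, not the real megdnn ones):

#include <cstddef>
#include <functional>
#include <memory>
#include <vector>

// Placeholder tensor: raw_ptr() yields the current address of the buffer.
struct Tensor {
    std::shared_ptr<std::vector<float>> data;
    const float* raw_ptr() const { return data ? data->data() : nullptr; }
};

using Task = std::function<void(size_t)>;

Task make_task(Tensor trans, Tensor mat_idx, size_t parallelism_batch) {
    // Capture the tensors themselves; pointers are looked up per invocation.
    return [trans, mat_idx, parallelism_batch](size_t index) {
        const float* trans_ptr = trans.raw_ptr();
        const float* midx_ptr = nullptr;
        if (mat_idx.raw_ptr()) {  // optional index tensor, as in the diff
            midx_ptr = mat_idx.raw_ptr();
        }
        size_t batch_id = index / parallelism_batch;
        size_t task_id = index % parallelism_batch;
        (void)trans_ptr; (void)midx_ptr; (void)batch_id; (void)task_id;
        // ... per-batch warp would go here ...
    };
}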
@@ -531,10 +531,10 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoI8x8x16Stride2Filter2::
             megdnn_arm_common_conv_bias_int8816_kimpl,
             midout_iv("AlgoI8x8x16Stride2Filter2::dispatch_kerns"_hash)) {
         auto ncb_param = param;
-        ncb_param.src_ptr = param.src<void>(0, ncb_index.ndrange_id[0]);
-        ncb_param.dst_ptr = param.dst<void>(0, ncb_index.ndrange_id[0]);
-        ncb_param.filter_ptr = param.filter<void>(ncb_index.ndrange_id[0]);
-        ncb_param.bias_ptr = param.bias<void>(0, ncb_index.ndrange_id[0]);
+        ncb_param.src_ptr += param.src_offset(0, ncb_index.ndrange_id[0]);
+        ncb_param.dst_ptr += param.dst_offset(0, ncb_index.ndrange_id[0]);
+        ncb_param.filter_ptr += param.filter_offset(ncb_index.ndrange_id[0]);
+        ncb_param.bias_ptr += param.bias_offset(0, ncb_index.ndrange_id[0]);
         conv_bias::conv_int8x8x16_stride2_flt2(ncb_param);
     }
     MIDOUT_END();
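Rather than overwriting `src_ptr`/`dst_ptr`/`filter_ptr`/`bias_ptr` with addresses resolved on the dispatching thread, the per-batch kernel now advances the copied parameter's pointer handles by offsets and leaves the base address to be resolved when the handle is finally dereferenced. A rough sketch of that offset idiom with a hypothetical handle type (megdnn's actual `RefPtr` and `*_offset` helpers may differ):

#include <cstddef>
#include <cstdint>
#include <memory>

// Hypothetical pointer handle: shares the base address, carries a private offset.
struct PtrHandle {
    std::shared_ptr<void*> base;  // re-bindable base owned by the caller
    size_t offset = 0;

    PtrHandle& operator+=(size_t delta) {  // advance without resolving the base
        offset += delta;
        return *this;
    }
    void* get() const { return static_cast<uint8_t*>(*base) + offset; }
};

struct NCBParamSketch {
    PtrHandle src_ptr, dst_ptr;
    size_t batch_stride_in_bytes = 0;
    size_t src_offset(size_t batch_id) const { return batch_id * batch_stride_in_bytes; }
    size_t dst_offset(size_t batch_id) const { return batch_id * batch_stride_in_bytes; }
};

void dispatch_one_batch(const NCBParamSketch& param, size_t batch_id) {
    auto ncb_param = param;  // cheap copy, base pointer still shared
    ncb_param.src_ptr += param.src_offset(batch_id);
    ncb_param.dst_ptr += param.dst_offset(batch_id);
    // ... run the kernel on ncb_param; get() resolves the address only then ...
    (void)ncb_param;
}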
@@ -133,7 +133,8 @@ static void pack_weight(
     constexpr int pack_oc = 8;
     if (kern_param.bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS && oc % pack_oc != 0) {
         auto packed_bias = reinterpret_cast<int16_t*>(bundle.get(2));
-        memcpy(packed_bias, kern_param.bias_ptr, round_up(oc, 8) * sizeof(int16_t));
+        memcpy(packed_bias, kern_param.bias_ptr.get_ptr(),
+               round_up(oc, 8) * sizeof(int16_t));
     }
 }
@@ -1657,4 +1657,4 @@ void CvtColorImpl::exec(
 }  // namespace arm_common
 }  // namespace megdnn
-// vim: syntax=cpp.doxygen
+// vim: syntax=cpp.doxygen
@@ -220,9 +220,9 @@ void ElemwiseImpl::AlgoBinaryVecVec::exec(const KernParam& kern_param) const {
         run = OpCallerBinary<_op<_type, _type>, BcastType::VEC_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, \
                     src0.layout.total_nr_elems())); \
     } \
@@ -254,9 +254,9 @@ void ElemwiseImpl::AlgoBinaryVecScalar::exec(const KernParam& kern_param) const
                 _op<_type, _type>, BcastType::VEC_SCALAR>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr)[0], \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr())[0], \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, \
                     src0.layout.total_nr_elems())); \
     } \
@@ -280,9 +280,9 @@ void ElemwiseImpl::AlgoBinaryVecScalar::exec(const KernParam& kern_param) const
                 _op<_type, _type>, BcastType::SCALAR_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr)[0], \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr())[0], \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, \
                     src1.layout.total_nr_elems())); \
     } \
@@ -318,9 +318,9 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) cons
                 _op<_type, _type>, BcastType::VEC_BCAST101>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                     binfo.z)); \
     } \
@@ -347,9 +347,9 @@ void ElemwiseImpl::AlgoBinaryVecBcast101::exec(const KernParam& kern_param) cons
                 _op<_type, _type>, BcastType::BCAST101_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                     binfo.z)); \
     } \
@@ -384,9 +384,9 @@ void ElemwiseImpl::AlgoBinaryVecBcastX0X::exec(const KernParam& kern_param) cons
                 _op<_type, _type>, BcastType::VEC_BCASTX0X>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                     binfo.z)); \
     } \
@@ -413,9 +413,9 @@ void ElemwiseImpl::AlgoBinaryVecBcastX0X::exec(const KernParam& kern_param) cons
                 _op<_type, _type>, BcastType::BCASTX0X_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                     binfo.z)); \
     } \
@@ -450,9 +450,9 @@ void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) con
                 _op<_type, _type>, BcastType::VEC_BCAST111C>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                     binfo.z)); \
     } \
@@ -479,9 +479,9 @@ void ElemwiseImpl::AlgoBinaryVecBcast111C::exec(const KernParam& kern_param) con
                 _op<_type, _type>, BcastType::BCAST111C_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, binfo.x, binfo.y, \
                     binfo.z)); \
     } \
@@ -519,9 +519,9 @@ void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec(const KernParam& kern_param) co
                 _op<_type, _type>, BcastType::VEC_BCAST101xX>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, batch_size, binfo.x, \
                     binfo.y, binfo.z)); \
     } \
@@ -551,9 +551,9 @@ void ElemwiseImpl::AlgoBinaryVecBcast101xX::exec(const KernParam& kern_param) co
                 _op<_type, _type>, BcastType::BCAST101xX_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, dst.layout.dtype, batch_size, binfo.x, \
                     binfo.y, binfo.z)); \
     } \
@@ -79,10 +79,10 @@ void ElemwiseImpl::AlgoTernaryFma3VecVecVec::exec(const KernParam& kern_param) c
                 _op<_type, _type>, BcastType::VEC_VEC_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<const _type*>(src2.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<const _type*>(src2.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     src0.layout.total_nr_elems())); \
     } \
@@ -113,10 +113,10 @@ void ElemwiseImpl::AlgoTernaryFma3VecVecScalar::exec(
                 _op<_type, _type>, BcastType::VEC_VEC_SCALAR>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<const _type*>(src2.raw_ptr)[0], \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<const _type*>(src2.raw_ptr())[0], \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     src0.layout.total_nr_elems())); \
     } \
@@ -149,10 +149,10 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec(
                 _op<_type, _type>, BcastType::BCAST101_VEC_BCAST101>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<const _type*>(src2.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<const _type*>(src2.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     binfo.x, binfo.y, binfo.z)); \
     } \
@@ -187,11 +187,11 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec(
                 BcastType::BCAST111C_VEC_BCAST111C>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
                     is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z, \
-                    static_cast<const _type*>(src2.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                    static_cast<const _type*>(src2.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     binfo.x, binfo.y, binfo.z)); \
     } \
@@ -228,10 +228,10 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec(
                 BcastType::BCAST101xX_VEC_BCAST101xX>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<const _type*>(src2.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<const _type*>(src2.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     batch_size, binfo.x, binfo.y, binfo.z)); \
     } \
@@ -268,10 +268,10 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101xXVec::exec(
                 _op<_type, _type>, BcastType::VEC_BCAST101xX_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<const _type*>(src2.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<const _type*>(src2.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     batch_size, binfo.x, binfo.y, binfo.z)); \
     } \
@@ -306,10 +306,10 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec(
                 _op<_type, _type>, BcastType::VEC_BCAST101_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<const _type*>(src2.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<const _type*>(src2.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     binfo.x, binfo.y, binfo.z)); \
     } \
@@ -343,12 +343,12 @@ void ElemwiseImpl::AlgoTernaryFma3VecBcast111CVec::exec(
                 _op<_type, _type>, BcastType::VEC_BCAST111C_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
                     is_vector(src0.layout) ? 0 : src0.layout.stride[0] - binfo.z, \
-                    static_cast<const _type*>(src1.raw_ptr), \
-                    static_cast<const _type*>(src2.raw_ptr), \
+                    static_cast<const _type*>(src1.raw_ptr()), \
+                    static_cast<const _type*>(src2.raw_ptr()), \
                     is_vector(src2.layout) ? 0 : src2.layout.stride[0] - binfo.z, \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     binfo.x, binfo.y, binfo.z)); \
     } \
@@ -380,10 +380,10 @@ void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec(
                 _op<_type, _type>, BcastType::VEC_SCALAR_VEC>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr)[0], \
-                    static_cast<const _type*>(src2.raw_ptr), \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr())[0], \
+                    static_cast<const _type*>(src2.raw_ptr()), \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     src0.layout.total_nr_elems())); \
     } \
@@ -414,10 +414,10 @@ void ElemwiseImpl::AlgoTernaryFma3VecScalarScalar::exec(
                 _op<_type, _type>, BcastType::VEC_SCALAR_SCALAR>::run; \
         MEGDNN_DISPATCH_CPU_KERN( \
                 static_cast<naive::HandleImpl*>(kern_param.handle), \
-                run(static_cast<const _type*>(src0.raw_ptr), \
-                    static_cast<const _type*>(src1.raw_ptr)[0], \
-                    static_cast<const _type*>(src2.raw_ptr)[0], \
-                    static_cast<_type*>(dst.raw_ptr), src0.layout.dtype, \
+                run(static_cast<const _type*>(src0.raw_ptr()), \
+                    static_cast<const _type*>(src1.raw_ptr())[0], \
+                    static_cast<const _type*>(src2.raw_ptr())[0], \
+                    static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \
                     src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \
                     src0.layout.total_nr_elems())); \
     } \
@@ -76,8 +76,8 @@ void ElemwiseImpl::AlgoUnary::exec(const KernParam& kern_param) const {
                 size_t offset = task_id * nr_elems_per_thread; \
                 size_t nr_elems_thread = \
                         std::min(nr_elems - offset, nr_elems_per_thread); \
-                run(static_cast<const _type*>(src0.raw_ptr) + offset, \
-                    static_cast<_type*>(dst_tensor.raw_ptr) + offset, \
+                run(static_cast<const _type*>(src0.raw_ptr()) + offset, \
+                    static_cast<_type*>(dst_tensor.raw_ptr()) + offset, \
                     src0.layout.dtype, dst_tensor.layout.dtype, nr_elems_thread); \
             }; \
             MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
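Beyond the `raw_ptr()` parentheses, the `AlgoUnary` hunk above shows the per-thread partitioning used by the multi-threaded dispatch: each task owns the contiguous slice `[offset, offset + nr_elems_thread)` of the flattened tensor. A self-contained sketch of that slicing arithmetic (plain C++ with a serial loop standing in for the worker pool):

#include <algorithm>
#include <cstddef>

// Split a flat elementwise op into nr_threads contiguous slices, using the
// same offset arithmetic the AlgoUnary kernel applies per task.
void relu_parallel(const float* src, float* dst, size_t nr_elems, size_t nr_threads) {
    size_t nr_elems_per_thread = (nr_elems + nr_threads - 1) / nr_threads;
    for (size_t task_id = 0; task_id < nr_threads; ++task_id) {  // one iteration per worker
        size_t offset = task_id * nr_elems_per_thread;
        if (offset >= nr_elems)
            break;
        size_t nr_elems_thread = std::min(nr_elems - offset, nr_elems_per_thread);
        const float* s = src + offset;
        float* d = dst + offset;
        for (size_t i = 0; i < nr_elems_thread; ++i)
            d[i] = std::max(s[i], 0.f);
    }
}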
@@ -148,17 +148,17 @@ void ElemwiseMultiTypeImpl::neon_round_shr_saturate_bcast_scalar<int32_t>(
 template <typename ctype>
 void ElemwiseMultiTypeImpl::dispatch_round_shr_saturate_iXxi8xi8_bcast_scalar(
-        const ElemwiseOpParamN<2>& param, megdnn::dt_int8* dst) {
-    auto a_ptr = param[0].ptr<ctype>();
+        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
     auto k = param[1].ptr<dt_int8>()[0];
     size_t size = param.size;
-    MEGDNN_DISPATCH_CPU_KERN_OPR(
-            neon_round_shr_saturate_bcast_scalar(a_ptr, k, size, dst));
+    auto src = param[0];
+    MEGDNN_DISPATCH_CPU_KERN_OPR(neon_round_shr_saturate_bcast_scalar(
+            src.ptr<ctype>(), k, size, static_cast<dt_int8*>(dst.raw_ptr())));
 }
 
 void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8(
-        const ElemwiseOpParamN<2>& param, megdnn::dt_int8* dst) {
+        const ElemwiseOpParamN<2>& param, const TensorND& dst) {
     if (is_vector(param[0].layout) && is_broadcasted_scalar(param[1].layout)) {
         switch (param[0].layout.dtype.enumv()) {
 #define cb(t) \
@@ -282,7 +282,7 @@ void neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_int32(
 }
 
 bool ElemwiseMultiTypeImpl::dispatch_fuse_add_rmulh_rshr(
-        const ElemwiseOpParamN<6>& param, megdnn::dt_int8* dst) {
+        const ElemwiseOpParamN<6>& param, const TensorND& dst) {
     BroadcastChannelInfo binfo;
     if (is_vector(param[0].layout) &&
         is_broadcasted_channel_like(param[1].layout, binfo) &&
@@ -294,16 +294,18 @@ bool ElemwiseMultiTypeImpl::dispatch_fuse_add_rmulh_rshr(
         auto minv = param[4].ptr<dt_int8>()[0];
         auto maxv = param[5].ptr<dt_int8>()[0];
         switch (param[0].layout.dtype.enumv()) {
-#define DISPATCH(stype, suffix) \
-    case DTypeTrait<stype>::enumv: { \
-        auto x_ptr = param[0].ptr<DTypeTrait<stype>::ctype>(); \
-        auto b_ptr = param[1].ptr<DTypeTrait<stype>::ctype>(); \
-        auto M = param[2].ptr<DTypeTrait<stype>::ctype>()[0]; \
-        MEGDNN_DISPATCH_CPU_KERN_OPR( \
-                neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_##suffix( \
-                        binfo.x, binfo.y, binfo.z, x_ptr, b_ptr, M, offset, minv, \
-                        maxv, param.size, dst)); \
-        break; \
+#define DISPATCH(stype, suffix) \
+    case DTypeTrait<stype>::enumv: { \
+        auto M = param[2].ptr<DTypeTrait<stype>::ctype>()[0]; \
+        auto src0 = param[0]; \
+        auto src1 = param[1]; \
+        MEGDNN_DISPATCH_CPU_KERN_OPR( \
+                neon_fuse_add_rmulh_round_shr_saturate_bcast_1c11_##suffix( \
+                        binfo.x, binfo.y, binfo.z, \
+                        src0.ptr<DTypeTrait<stype>::ctype>(), \
+                        src1.ptr<DTypeTrait<stype>::ctype>(), M, offset, minv, maxv, \
+                        param.size, static_cast<dt_int8*>(dst.raw_ptr()))); \
+        break; \
     }
     DISPATCH(dtype::Int16, int16)
     DISPATCH(dtype::Int32, int32)
@@ -317,7 +319,7 @@ bool ElemwiseMultiTypeImpl::dispatch_fuse_add_rmulh_rshr(
 }
 
 void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8(
-        const ElemwiseOpParamN<6>& param, megdnn::dt_int8* dst) {
+        const ElemwiseOpParamN<6>& param, const TensorND& dst) {
     if (dispatch_fuse_add_rmulh_rshr(param, dst))
         return;
     fallback::ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8(
@@ -325,7 +327,7 @@ void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8(
 }
 
 void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8(
-        const ElemwiseOpParamN<6>& param, megdnn::dt_int8* dst) {
+        const ElemwiseOpParamN<6>& param, const TensorND& dst) {
     if (dispatch_fuse_add_rmulh_rshr(param, dst))
         return;
     fallback::ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8(
@@ -23,18 +23,18 @@ class ElemwiseMultiTypeImpl : public fallback::ElemwiseMultiTypeImpl {
     template <typename ctype>
     void dispatch_round_shr_saturate_iXxi8xi8_bcast_scalar(
-            const ElemwiseOpParamN<2>& param, megdnn::dt_int8* dst);
+            const ElemwiseOpParamN<2>& param, const TensorND& dst);
 
     bool dispatch_fuse_add_rmulh_rshr(
-            const ElemwiseOpParamN<6>& param, megdnn::dt_int8* dst);
+            const ElemwiseOpParamN<6>& param, const TensorND& dst);
 
 protected:
     void on_round_shr_saturate_iXxi8xi8(
-            const ElemwiseOpParamN<2>& param, dt_int8* dst) override;
+            const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
     void on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8(
-            const ElemwiseOpParamN<6>& param, dt_int8* dst) override;
+            const ElemwiseOpParamN<6>& param, const TensorND& dst) override;
     void on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8(
-            const ElemwiseOpParamN<6>& param, dt_int8* dst) override;
+            const ElemwiseOpParamN<6>& param, const TensorND& dst) override;
 
     void on_quantized_mode(
             const ElemwiseOpParamN<1>& param, const TensorND& dst,
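The whole ElemwiseMultiType family migrates from receiving a pre-resolved `dt_int8* dst` to receiving the destination `TensorND` and calling `dst.raw_ptr()` inside the dispatched kernel. A compact sketch of that kind of signature migration (hypothetical names that mirror the shape of the change, not the exact megdnn API):

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <vector>

using dt_int8 = int8_t;

// Placeholder tensor handle whose address is resolved on demand.
struct TensorNDSketch {
    std::shared_ptr<void*> storage;
    void* raw_ptr() const { return *storage; }
};

std::vector<std::function<void()>> g_kern_queue;  // stand-in for the CPU dispatcher

// Old shape: the caller resolves the pointer, the queued kernel captures it.
void on_op_old(dt_int8* dst, size_t size) {
    g_kern_queue.push_back([dst, size] { (void)dst; (void)size; /* write through dst */ });
}

// New shape (as in the diff): pass the tensor and resolve the pointer when the
// kernel actually runs, so later storage re-binding is still honored.
void on_op_new(const TensorNDSketch& dst, size_t size) {
    g_kern_queue.push_back([dst, size] {
        auto* out = static_cast<dt_int8*>(dst.raw_ptr());
        (void)out; (void)size; /* write through out */
    });
}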
@@ -117,27 +117,27 @@ void PoolingImpl::AlgoFilterxModexStride1::exec(const PoolingKernParam& param) c
     auto PW = param.padding[1];
     auto FH = param.filter[0];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(Pooler, NeonPooler, window, midout_type_id) \
     MIDOUT_BEGIN( \
             megdnn_arm_common_pooling, midout_iv(0), midout_iv(midout_type_id), \
             Pooler::MIDOUT_CASE_NUM, NeonPooler::MIDOUT_CASE_NUM, window) { \
         auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
                     src_dtype = param.src_type](size_t index, size_t) { \
             size_t n = index / C; \
             size_t c = index % C; \
             do_pooling_compact<Pooler MEGDNN_COMMA NeonPooler MEGDNN_COMMA window>( \
-                    static_cast<const typename Pooler::ctype*>(src_ptr) + \
-                            n * C * IH * IW + c * IH * IW, \
-                    static_cast<typename Pooler::ctype*>(dst_ptr) + n * C * OH * OW + \
-                            c * OH * OW, \
+                    static_cast<const typename Pooler::ctype*>(src_ptr.get_ptr()) + \
+                            n * C * IH * IW + c * IH * IW, \
+                    static_cast<typename Pooler::ctype*>(dst_ptr.get_ptr()) + \
+                            n * C * OH * OW + c * OH * OW, \
                     src_dtype, IH, IW, OH, OW, PH, PW); \
         }; \
         MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                 static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
     } \
     MIDOUT_END()
 
 #define DISPATCH_WINDOW(Pooler, NeonPooler, dtype, ctype, comp_type, midout_type_id) \
@@ -213,26 +213,26 @@ void PoolingImpl::AlgoFilter2ModexStride2::exec(const PoolingKernParam& param) c
     auto PH = param.padding[0];
     auto PW = param.padding[1];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(Pooler, mode, midout_type_id) \
     MIDOUT_BEGIN( \
             megdnn_arm_common_pooling, midout_iv(1), midout_iv(midout_type_id), \
             Pooler::MIDOUT_CASE_NUM) { \
         auto run = [C, IH, IW, OH, OW, PH, PW, src_ptr, dst_ptr, \
                     src_dtype = param.src_type](size_t index, size_t) { \
             size_t n = index / C; \
             size_t c = index % C; \
             do_pooling_2x2<Pooler MEGDNN_COMMA mode>( \
-                    static_cast<const typename Pooler::ctype*>(src_ptr) + \
-                            n * C * IH * IW + c * IH * IW, \
-                    static_cast<typename Pooler::ctype*>(dst_ptr) + n * C * OH * OW + \
-                            c * OH * OW, \
+                    static_cast<const typename Pooler::ctype*>(src_ptr.get_ptr()) + \
+                            n * C * IH * IW + c * IH * IW, \
+                    static_cast<typename Pooler::ctype*>(dst_ptr.get_ptr()) + \
+                            n * C * OH * OW + c * OH * OW, \
                     src_dtype, IH, IW, OH, OW, PH, PW); \
         }; \
         MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                 static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
     } \
     MIDOUT_END()
 
 #define DISPATCH_MODE(dtype, ctype, comp_type, midout_type_id) \
@@ -286,8 +286,8 @@ void PoolingImpl::AlgoFilter3MaxStride2::exec(const PoolingKernParam& param) con
     auto PH = param.padding[0];
     auto PW = param.padding[1];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(type, func, midout_type_id) \
     MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(2), midout_iv(midout_type_id)) { \
@@ -300,9 +300,11 @@ void PoolingImpl::AlgoFilter3MaxStride2::exec(const PoolingKernParam& param) con
             size_t n = index / C; \
             size_t c = index % C; \
             do_max_pooling_3x3_s2x2_##func##_NEON( \
-                    static_cast<const type*>(src_ptr) + n * C * IH * IW + c * IH * IW, \
-                    static_cast<type*>(dst_ptr) + n * C * OH * OW + c * OH * OW, IH, \
-                    IW, OH, OW, PH, PW, ws); \
+                    static_cast<const type*>(src_ptr.get_ptr()) + n * C * IH * IW + \
+                            c * IH * IW, \
+                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW + \
+                            c * OH * OW, \
+                    IH, IW, OH, OW, PH, PW, ws); \
         }; \
         MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                 static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
@@ -339,8 +341,8 @@ void PoolingImpl::AlgoFilter3AverageStride2::exec(const PoolingKernParam& param)
     auto PH = param.padding[0];
     auto PW = param.padding[1];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(type, MEGDNN_SIMD_WIDTH, midout_type_id) \
     MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(3), midout_iv(midout_type_id)) { \
@@ -353,9 +355,11 @@ void PoolingImpl::AlgoFilter3AverageStride2::exec(const PoolingKernParam& param)
             size_t n = index / C; \
             size_t c = index % C; \
             do_average_pooling_3x3_s2x2_NEON( \
-                    static_cast<const type*>(src_ptr) + n * C * IH * IW + c * IH * IW, \
-                    static_cast<type*>(dst_ptr) + n * C * OH * OW + c * OH * OW, IH, \
-                    IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \
+                    static_cast<const type*>(src_ptr.get_ptr()) + n * C * IH * IW + \
+                            c * IH * IW, \
+                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW + \
+                            c * OH * OW, \
+                    IH, IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \
         }; \
         MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                 static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
@@ -392,8 +396,8 @@ void PoolingImpl::AlgoFilter4MaxStride2::exec(const PoolingKernParam& param) con
     auto PH = param.padding[0];
     auto PW = param.padding[1];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(type, func, midout_type_id) \
     MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(4), midout_iv(midout_type_id)) { \
@@ -402,8 +406,10 @@ void PoolingImpl::AlgoFilter4MaxStride2::exec(const PoolingKernParam& param) con
             size_t n = index / C; \
             size_t c = index % C; \
             do_max_pooling_w4x4_s2x2_##func##_NEON( \
-                    static_cast<const type*>(src_ptr) + n * C * IH * IW + c * IH * IW, \
-                    static_cast<type*>(dst_ptr) + n * C * OH * OW + c * OH * OW, \
+                    static_cast<const type*>(src_ptr.get_ptr()) + n * C * IH * IW + \
+                            c * IH * IW, \
+                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW + \
+                            c * OH * OW, \
                     src_dtype, IH, IW, OH, OW, PH, PW); \
         }; \
         MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
@@ -446,8 +452,8 @@ void PoolingImpl::AlgoFilter5MaxStride2::exec(const PoolingKernParam& param) con
     auto PH = param.padding[0];
     auto PW = param.padding[1];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(dtype, type, midout_type_id, MEGDNN_SIMD_WIDTH) \
     MIDOUT_BEGIN(megdnn_arm_common_pooling, midout_iv(5), midout_iv(midout_type_id)) { \
@@ -460,9 +466,11 @@ void PoolingImpl::AlgoFilter5MaxStride2::exec(const PoolingKernParam& param) con
             size_t n = index / C; \
             size_t c = index % C; \
             do_max_pooling_w5x5_s2x2_NEON<dtype>( \
-                    static_cast<const type*>(src_ptr) + n * C * IH * IW + c * IH * IW, \
-                    static_cast<type*>(dst_ptr) + n * C * OH * OW + c * OH * OW, IH, \
-                    IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \
+                    static_cast<const type*>(src_ptr.get_ptr()) + n * C * IH * IW + \
+                            c * IH * IW, \
+                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW + \
+                            c * OH * OW, \
+                    IH, IW, OH, OW, PH, PW, ws, MEGDNN_SIMD_WIDTH); \
         }; \
         MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                 static_cast<::megdnn::naive::HandleImpl*>(param.handle), N* C, run); \
@@ -593,8 +601,8 @@ void PoolingImpl::AlgoFilter3ModexStridexNCHW44::exec(
     auto PW = param.padding[1];
     auto SW = param.stride[0];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(type, func, i, mode) \
     MIDOUT_BEGIN( \
@@ -608,9 +616,9 @@ void PoolingImpl::AlgoFilter3ModexStridexNCHW44::exec(
             size_t n = index / C; \
             size_t c = index % C; \
             do_##mode##_pooling_3x3_stride##i##_##func##_nchw44_NEON( \
-                    static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
-                            c * IH * IW * 4, \
-                    static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
+                    static_cast<const type*>(src_ptr.get_ptr()) + \
+                            n * C * IH * IW * 4 + c * IH * IW * 4, \
+                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW * 4 + \
                             c * OH * OW * 4, \
                     IH, IW, OH, OW, PH, PW, ws); \
         }; \
@@ -685,8 +693,8 @@ void PoolingImpl::AlgoFilter2ModexStridexNCHW44::exec(
     auto PW = param.padding[1];
     auto SW = param.stride[0];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(type, func, i, mode) \
     MIDOUT_BEGIN( \
@@ -700,9 +708,9 @@ void PoolingImpl::AlgoFilter2ModexStridexNCHW44::exec(
             size_t n = index / C; \
             size_t c = index % C; \
             do_##mode##_pooling_2x2_stride##i##_##func##_nchw44_NEON( \
-                    static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
-                            c * IH * IW * 4, \
-                    static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
+                    static_cast<const type*>(src_ptr.get_ptr()) + \
+                            n * C * IH * IW * 4 + c * IH * IW * 4, \
+                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW * 4 + \
                             c * OH * OW * 4, \
                     IH, IW, OH, OW, PH, PW, ws); \
         }; \
@@ -778,8 +786,8 @@ void PoolingImpl::AlgoFilter4ModexStridexNCHW44::exec(
     auto PW = param.padding[1];
     auto SW = param.stride[0];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(type, func, i, mode) \
     MIDOUT_BEGIN( \
@@ -793,9 +801,9 @@ void PoolingImpl::AlgoFilter4ModexStridexNCHW44::exec(
             size_t n = index / C; \
             size_t c = index % C; \
             do_##mode##_pooling_4x4_stride##i##_##func##_nchw44_NEON( \
-                    static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
-                            c * IH * IW * 4, \
-                    static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
+                    static_cast<const type*>(src_ptr.get_ptr()) + \
+                            n * C * IH * IW * 4 + c * IH * IW * 4, \
+                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW * 4 + \
                             c * OH * OW * 4, \
                     IH, IW, OH, OW, PH, PW, ws); \
         }; \
@@ -870,8 +878,8 @@ void PoolingImpl::AlgoFilter5ModexStridexNCHW44::exec(
     auto PW = param.padding[1];
     auto SW = param.stride[0];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(type, func, i, mode) \
     MIDOUT_BEGIN( \
@@ -885,9 +893,9 @@ void PoolingImpl::AlgoFilter5ModexStridexNCHW44::exec(
             size_t n = index / C; \
             size_t c = index % C; \
             do_##mode##_pooling_5x5_stride##i##_##func##_nchw44_NEON( \
-                    static_cast<const type*>(src_ptr) + n * C * IH * IW * 4 + \
-                            c * IH * IW * 4, \
-                    static_cast<type*>(dst_ptr) + n * C * OH * OW * 4 + \
+                    static_cast<const type*>(src_ptr.get_ptr()) + \
+                            n * C * IH * IW * 4 + c * IH * IW * 4, \
+                    static_cast<type*>(dst_ptr.get_ptr()) + n * C * OH * OW * 4 + \
                             c * OH * OW * 4, \
                     IH, IW, OH, OW, PH, PW, ws); \
         }; \
@@ -50,8 +50,8 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec(
     int sh = param.stride[0];
     int fh = param.filter[0];
-    void* src_ptr = param.src_ptr;
-    void* dst_ptr = param.dst_ptr;
+    auto src_ptr = param.src_ptr;
+    auto dst_ptr = param.dst_ptr;
 #define DISPATCH_FUNC(filter, stride, mode) \
     MIDOUT_BEGIN( \
@@ -60,9 +60,10 @@ void PoolingImpl::AlgoFp32ModexStridexNCHW44::exec(
         auto run = [ih, iw, oh, ow, ph, pw, src_ptr, dst_ptr](size_t index, size_t) { \
             const int c_idx = index; \
             pooling_fp32_nchw44<filter, stride, mode>( \
-                    static_cast<const float*>(src_ptr) + c_idx * ih * iw * 4, \
-                    static_cast<float*>(dst_ptr) + c_idx * oh * ow * 4, ih, iw, oh, \
-                    ow, ph, pw); \
+                    static_cast<const float*>(src_ptr.get_ptr()) + \
+                            c_idx * ih * iw * 4, \
+                    static_cast<float*>(dst_ptr.get_ptr()) + c_idx * oh * ow * 4, ih, \
+                    iw, oh, ow, ph, pw); \
         }; \
         MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \
                 static_cast<::megdnn::naive::HandleImpl*>(param.handle), n* ic, run); \
@@ -89,8 +89,8 @@ PoolingImpl::PoolingKernParam PoolingImpl::make_pooling_kern_param(
     PoolingKernParam ret;
     static_cast<PoolingKernSizeParam&>(ret) =
             make_pooling_kern_szie_param(opr, src.layout, dst.layout);
-    ret.src_ptr = src.raw_ptr;
-    ret.dst_ptr = dst.raw_ptr;
+    ret.src_ptr = src.get_ref_ptr();
+    ret.dst_ptr = dst.get_ref_ptr();
     ret.workspace_ptr = workspace.raw_ptr;
     ret.workspace_size = workspace.size;
     return ret;
| @@ -56,21 +56,21 @@ public: | |||||
| }; | }; | ||||
| struct PoolingKernParam : public PoolingKernSizeParam { | struct PoolingKernParam : public PoolingKernSizeParam { | ||||
| void* src_ptr; | |||||
| void* dst_ptr; | |||||
| RefPtr src_ptr; | |||||
| RefPtr dst_ptr; | |||||
| void* workspace_ptr; | void* workspace_ptr; | ||||
| size_t workspace_size; | size_t workspace_size; | ||||
| template <typename T> | template <typename T> | ||||
| const T* src() const { | const T* src() const { | ||||
| src_type.assert_is_compatible_ctype<T>(); | src_type.assert_is_compatible_ctype<T>(); | ||||
| return static_cast<const T*>(src_ptr); | |||||
| return static_cast<const T*>(src_ptr.get_ptr()); | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| T* dst() const { | T* dst() const { | ||||
| dst_type.assert_is_compatible_ctype<T>(); | dst_type.assert_is_compatible_ctype<T>(); | ||||
| return static_cast<T*>(dst_ptr); | |||||
| return static_cast<T*>(dst_ptr.get_ptr()); | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
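The pooling hunks above replace PoolingKernParam's raw void* members with RefPtr and route every access through get_ptr(). The RefPtr class itself is not part of this patch, so the sketch below only illustrates the interface these call sites rely on: construction from a raw pointer, get_ptr(), the typed ptr<T>() used by the reduce ops later in the patch, and the byte-offset increment required by incr_refp near the end. The shared pointer-to-pointer representation is an assumption, chosen so that a late rebind of the underlying buffer stays visible to every copy captured by value in the dispatched kernel lambdas; the real class may differ.

    #include <cstddef>
    #include <cstdint>
    #include <memory>

    // Sketch only: the actual megdnn::RefPtr may have a different layout and
    // member set; this mirrors just the calls made in the hunks of this patch.
    class RefPtr {
        std::shared_ptr<void*> m_ref;  // shared slot holding the base address
        size_t m_offset = 0;           // byte offset from that base address

    public:
        RefPtr() : m_ref(std::make_shared<void*>(nullptr)) {}
        explicit RefPtr(void* ptr) : m_ref(std::make_shared<void*>(ptr)) {}

        // raw address seen by kernels (base + offset)
        void* get_ptr() const { return static_cast<uint8_t*>(*m_ref) + m_offset; }

        // typed view, as used by the RefPtr-based reduce ops
        template <typename T>
        T* ptr() const {
            return static_cast<T*>(get_ptr());
        }

        // rebind the buffer; the change is visible to all copies sharing m_ref
        void reset(void* ptr) { *m_ref = ptr; }

        // advance by a byte offset, as incr_refp requires
        RefPtr& operator+=(size_t delta) {
            m_offset += delta;
            return *this;
        }
    };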
| @@ -816,8 +816,8 @@ void ReduceImpl::exec( | |||||
| MIDOUT_BEGIN( \ | MIDOUT_BEGIN( \ | ||||
| megdnn_arm_common_reduce, ctype, dtype, comp_type, midout_iv(1)) { \ | megdnn_arm_common_reduce, ctype, dtype, comp_type, midout_iv(1)) { \ | ||||
| MEGDNN_DISPATCH_CPU_KERN_OPR(do_reduce( \ | MEGDNN_DISPATCH_CPU_KERN_OPR(do_reduce( \ | ||||
| reinterpret_cast<ctype*>(src.raw_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst.raw_ptr), src_type, A, B, C)); \ | |||||
| reinterpret_cast<ctype*>(src.raw_ptr()), \ | |||||
| reinterpret_cast<ctype*>(dst.raw_ptr()), src_type, A, B, C)); \ | |||||
| execed = true; \ | execed = true; \ | ||||
| } \ | } \ | ||||
| MIDOUT_END(); \ | MIDOUT_END(); \ | ||||
| @@ -828,8 +828,8 @@ void ReduceImpl::exec( | |||||
| MIDOUT_BEGIN( \ | MIDOUT_BEGIN( \ | ||||
| megdnn_arm_common_reduce, ctype, dtype, comp_type, midout_iv(1)) { \ | megdnn_arm_common_reduce, ctype, dtype, comp_type, midout_iv(1)) { \ | ||||
| MEGDNN_DISPATCH_CPU_KERN_OPR(do_reduce( \ | MEGDNN_DISPATCH_CPU_KERN_OPR(do_reduce( \ | ||||
| reinterpret_cast<ctype*>(src.raw_ptr), \ | |||||
| reinterpret_cast<ctype*>(dst.raw_ptr), src_type, A, B, C)); \ | |||||
| reinterpret_cast<ctype*>(src.raw_ptr()), \ | |||||
| reinterpret_cast<ctype*>(dst.raw_ptr()), src_type, A, B, C)); \ | |||||
| execed = true; \ | execed = true; \ | ||||
| } \ | } \ | ||||
| MIDOUT_END(); \ | MIDOUT_END(); \ | ||||
| @@ -72,14 +72,14 @@ void resize_direct_nchwxx( | |||||
| void megdnn::arm_common::resize_direct_nearest_nchw44_fp32( | void megdnn::arm_common::resize_direct_nearest_nchw44_fp32( | ||||
| const ResizeImpl::KernParam<float>& kern_param) { | const ResizeImpl::KernParam<float>& kern_param) { | ||||
| resize_direct_nchwxx<float, InterpolationMode::INTER_NEAREST>( | resize_direct_nchwxx<float, InterpolationMode::INTER_NEAREST>( | ||||
| kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c / 4, | |||||
| kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4, | |||||
| kern_param.ih, kern_param.iw, kern_param.oh, kern_param.ow); | kern_param.ih, kern_param.iw, kern_param.oh, kern_param.ow); | ||||
| } | } | ||||
| void megdnn::arm_common::resize_direct_linear_nchw44_fp32( | void megdnn::arm_common::resize_direct_linear_nchw44_fp32( | ||||
| const ResizeImpl::KernParam<float>& kern_param) { | const ResizeImpl::KernParam<float>& kern_param) { | ||||
| resize_direct_nchwxx<float, InterpolationMode::INTER_LINEAR>( | resize_direct_nchwxx<float, InterpolationMode::INTER_LINEAR>( | ||||
| kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c / 4, | |||||
| kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4, | |||||
| kern_param.ih, kern_param.iw, kern_param.oh, kern_param.ow); | kern_param.ih, kern_param.iw, kern_param.oh, kern_param.ow); | ||||
| } | } | ||||
| @@ -87,8 +87,8 @@ void megdnn::arm_common::resize_direct_linear_nchw44_fp32( | |||||
| void megdnn::arm_common::resize_direct_nearest_nchw88_fp16( | void megdnn::arm_common::resize_direct_nearest_nchw88_fp16( | ||||
| const ResizeImpl::KernParam<dt_float16>& kern_param) { | const ResizeImpl::KernParam<dt_float16>& kern_param) { | ||||
| auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr); | |||||
| auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); | |||||
| auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr.get_ptr()); | |||||
| auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr.get_ptr()); | |||||
| resize_direct_nchwxx<__fp16, InterpolationMode::INTER_NEAREST>( | resize_direct_nchwxx<__fp16, InterpolationMode::INTER_NEAREST>( | ||||
| sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw, | sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw, | ||||
| kern_param.oh, kern_param.ow); | kern_param.oh, kern_param.ow); | ||||
| @@ -96,8 +96,8 @@ void megdnn::arm_common::resize_direct_nearest_nchw88_fp16( | |||||
| void megdnn::arm_common::resize_direct_linear_nchw88_fp16( | void megdnn::arm_common::resize_direct_linear_nchw88_fp16( | ||||
| const ResizeImpl::KernParam<dt_float16>& kern_param) { | const ResizeImpl::KernParam<dt_float16>& kern_param) { | ||||
| auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr); | |||||
| auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); | |||||
| auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr.get_ptr()); | |||||
| auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr.get_ptr()); | |||||
| resize_direct_nchwxx<__fp16, InterpolationMode::INTER_LINEAR>( | resize_direct_nchwxx<__fp16, InterpolationMode::INTER_LINEAR>( | ||||
| sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw, | sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw, | ||||
| kern_param.oh, kern_param.ow); | kern_param.oh, kern_param.ow); | ||||
| @@ -191,14 +191,14 @@ void nearest_upsample2_nchw( | |||||
| void megdnn::arm_common::resize_linear_upsample2_nchw_fp32( | void megdnn::arm_common::resize_linear_upsample2_nchw_fp32( | ||||
| const ResizeImpl::KernParam<float>& kern_param) { | const ResizeImpl::KernParam<float>& kern_param) { | ||||
| linear_upsample2_nchw( | linear_upsample2_nchw( | ||||
| kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c, | |||||
| kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c, | |||||
| kern_param.ih, kern_param.iw); | kern_param.ih, kern_param.iw); | ||||
| } | } | ||||
| void megdnn::arm_common::resize_nearest_upsample2_nchw_fp32( | void megdnn::arm_common::resize_nearest_upsample2_nchw_fp32( | ||||
| const ResizeImpl::KernParam<float>& kern_param) { | const ResizeImpl::KernParam<float>& kern_param) { | ||||
| nearest_upsample2_nchw( | nearest_upsample2_nchw( | ||||
| kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c, | |||||
| kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c, | |||||
| kern_param.ih, kern_param.iw); | kern_param.ih, kern_param.iw); | ||||
| } | } | ||||
| @@ -206,16 +206,16 @@ void megdnn::arm_common::resize_nearest_upsample2_nchw_fp32( | |||||
| void megdnn::arm_common::resize_linear_upsample2_nchw_fp16( | void megdnn::arm_common::resize_linear_upsample2_nchw_fp16( | ||||
| const ResizeImpl::KernParam<dt_float16>& kern_param) { | const ResizeImpl::KernParam<dt_float16>& kern_param) { | ||||
| auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr); | |||||
| auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); | |||||
| auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr.get_ptr()); | |||||
| auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr.get_ptr()); | |||||
| linear_upsample2_nchw( | linear_upsample2_nchw( | ||||
| sptr, dptr, kern_param.n * kern_param.c, kern_param.ih, kern_param.iw); | sptr, dptr, kern_param.n * kern_param.c, kern_param.ih, kern_param.iw); | ||||
| } | } | ||||
| void megdnn::arm_common::resize_nearest_upsample2_nchw_fp16( | void megdnn::arm_common::resize_nearest_upsample2_nchw_fp16( | ||||
| const ResizeImpl::KernParam<dt_float16>& kern_param) { | const ResizeImpl::KernParam<dt_float16>& kern_param) { | ||||
| auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr); | |||||
| auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); | |||||
| auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr.get_ptr()); | |||||
| auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr.get_ptr()); | |||||
| nearest_upsample2_nchw( | nearest_upsample2_nchw( | ||||
| sptr, dptr, kern_param.n * kern_param.c, kern_param.ih, kern_param.iw); | sptr, dptr, kern_param.n * kern_param.c, kern_param.ih, kern_param.iw); | ||||
| } | } | ||||
| @@ -158,14 +158,14 @@ void nearest_upsample2_nchwxx( | |||||
| void megdnn::arm_common::resize_linear_upsample2_nchw44_fp32( | void megdnn::arm_common::resize_linear_upsample2_nchw44_fp32( | ||||
| const ResizeImpl::KernParam<float>& kern_param) { | const ResizeImpl::KernParam<float>& kern_param) { | ||||
| linear_upsample2_nchwxx( | linear_upsample2_nchwxx( | ||||
| kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c / 4, | |||||
| kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4, | |||||
| kern_param.ih, kern_param.iw); | kern_param.ih, kern_param.iw); | ||||
| } | } | ||||
| void megdnn::arm_common::resize_nearest_upsample2_nchw44_fp32( | void megdnn::arm_common::resize_nearest_upsample2_nchw44_fp32( | ||||
| const ResizeImpl::KernParam<float>& kern_param) { | const ResizeImpl::KernParam<float>& kern_param) { | ||||
| nearest_upsample2_nchwxx( | nearest_upsample2_nchwxx( | ||||
| kern_param.sptr, kern_param.dptr, kern_param.n * kern_param.c / 4, | |||||
| kern_param.src(), kern_param.dst(), kern_param.n * kern_param.c / 4, | |||||
| kern_param.ih, kern_param.iw); | kern_param.ih, kern_param.iw); | ||||
| } | } | ||||
| @@ -173,16 +173,16 @@ void megdnn::arm_common::resize_nearest_upsample2_nchw44_fp32( | |||||
| void megdnn::arm_common::resize_linear_upsample2_nchw88_fp16( | void megdnn::arm_common::resize_linear_upsample2_nchw88_fp16( | ||||
| const ResizeImpl::KernParam<dt_float16>& kern_param) { | const ResizeImpl::KernParam<dt_float16>& kern_param) { | ||||
| auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr); | |||||
| auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); | |||||
| auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr.get_ptr()); | |||||
| auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr.get_ptr()); | |||||
| linear_upsample2_nchwxx( | linear_upsample2_nchwxx( | ||||
| sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw); | sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw); | ||||
| } | } | ||||
| void megdnn::arm_common::resize_nearest_upsample2_nchw88_fp16( | void megdnn::arm_common::resize_nearest_upsample2_nchw88_fp16( | ||||
| const ResizeImpl::KernParam<dt_float16>& kern_param) { | const ResizeImpl::KernParam<dt_float16>& kern_param) { | ||||
| auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr); | |||||
| auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr); | |||||
| auto sptr = reinterpret_cast<const __fp16*>(kern_param.sptr.get_ptr()); | |||||
| auto dptr = reinterpret_cast<__fp16*>(kern_param.dptr.get_ptr()); | |||||
| nearest_upsample2_nchwxx( | nearest_upsample2_nchwxx( | ||||
| sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw); | sptr, dptr, kern_param.n * kern_param.c / 8, kern_param.ih, kern_param.iw); | ||||
| } | } | ||||
| @@ -78,9 +78,9 @@ void SeparableFilterImpl::separable_filter_exec_8u( | |||||
| megdnn_assert(src.layout.dtype == dtype::Uint8()); | megdnn_assert(src.layout.dtype == dtype::Uint8()); | ||||
| Mat<float> kernel_column( | Mat<float> kernel_column( | ||||
| 1, filter_y.layout.shape[3], 1, static_cast<float*>(filter_y.raw_ptr)); | |||||
| 1, filter_y.layout.shape[3], 1, static_cast<float*>(filter_y.raw_ptr())); | |||||
| Mat<float> kernel_row( | Mat<float> kernel_row( | ||||
| 1, filter_x.layout.shape[3], 1, static_cast<float*>(filter_x.raw_ptr)); | |||||
| 1, filter_x.layout.shape[3], 1, static_cast<float*>(filter_x.raw_ptr())); | |||||
| size_t src_channels = src.layout.shape[3]; | size_t src_channels = src.layout.shape[3]; | ||||
| @@ -128,9 +128,9 @@ void SeparableFilterImpl::separable_filter_exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_in filter_x, _megdnn_tensor_in filter_y, | _megdnn_tensor_in src, _megdnn_tensor_in filter_x, _megdnn_tensor_in filter_y, | ||||
| _megdnn_tensor_out dst) { | _megdnn_tensor_out dst) { | ||||
| Mat<T> kernel_column( | Mat<T> kernel_column( | ||||
| 1, filter_y.layout.shape[3], 1, static_cast<T*>(filter_y.raw_ptr)); | |||||
| 1, filter_y.layout.shape[3], 1, static_cast<T*>(filter_y.raw_ptr())); | |||||
| Mat<T> kernel_row( | Mat<T> kernel_row( | ||||
| 1, filter_x.layout.shape[3], 1, static_cast<T*>(filter_x.raw_ptr)); | |||||
| 1, filter_x.layout.shape[3], 1, static_cast<T*>(filter_x.raw_ptr())); | |||||
| size_t src_channels = src.layout.shape[3]; | size_t src_channels = src.layout.shape[3]; | ||||
| T border_value[4] = {0, 0, 0, 0}; | T border_value[4] = {0, 0, 0, 0}; | ||||
| @@ -483,18 +483,18 @@ void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||||
| #undef DISPATCH_QUANTIZED | #undef DISPATCH_QUANTIZED | ||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| #define DISPATCH_FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ | |||||
| if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ | |||||
| dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \ | |||||
| MIDOUT_BEGIN(megdnn_arm_typecvt_float, midout_iv(_midout_iv)) { \ | |||||
| using _TypeCvter = FloatTypeCvter<_stype, _dtype>; \ | |||||
| MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \ | |||||
| reinterpret_cast<_stype*>(src.raw_ptr), \ | |||||
| reinterpret_cast<_dtype*>(dst.raw_ptr), src_dtype, dst_dtype, \ | |||||
| nr_elems)); \ | |||||
| execed = true; \ | |||||
| } \ | |||||
| MIDOUT_END(); \ | |||||
| #define DISPATCH_FLOAT(_stype_enumv, _stype, _dtype_enumv, _dtype, _midout_iv) \ | |||||
| if (src_dtype.enumv() == DTypeTrait<_stype_enumv>::enumv && \ | |||||
| dst_dtype.enumv() == DTypeTrait<_dtype_enumv>::enumv) { \ | |||||
| MIDOUT_BEGIN(megdnn_arm_typecvt_float, midout_iv(_midout_iv)) { \ | |||||
| using _TypeCvter = FloatTypeCvter<_stype, _dtype>; \ | |||||
| MEGDNN_DISPATCH_CPU_KERN_OPR(do_typecvt<_TypeCvter>( \ | |||||
| reinterpret_cast<_stype*>(src.raw_ptr()), \ | |||||
| reinterpret_cast<_dtype*>(dst.raw_ptr()), src_dtype, dst_dtype, \ | |||||
| nr_elems)); \ | |||||
| execed = true; \ | |||||
| } \ | |||||
| MIDOUT_END(); \ | |||||
| } | } | ||||
| DISPATCH_FLOAT(dt_float16, __fp16, float, float, 0); | DISPATCH_FLOAT(dt_float16, __fp16, float, float, 0); | ||||
| DISPATCH_FLOAT(float, float, dt_float16, __fp16, 1); | DISPATCH_FLOAT(float, float, dt_float16, __fp16, 1); | ||||
| @@ -167,21 +167,17 @@ void megdnn::arm_common::warp_perspective_cv_exec( | |||||
| megdnn_assert( | megdnn_assert( | ||||
| ch == 1 || ch == 3 || ch == 2, | ch == 1 || ch == 3 || ch == 2, | ||||
| "unsupported src channel: %zu, avaiable channel size: 1/2/3", ch); | "unsupported src channel: %zu, avaiable channel size: 1/2/3", ch); | ||||
| const float* trans_ptr = trans.ptr<dt_float32>(); | |||||
| const int* midx_ptr = nullptr; | |||||
| if (mat_idx.raw_ptr) { | |||||
| megdnn_assert(mat_idx.layout.ndim == 1); | |||||
| midx_ptr = mat_idx.ptr<int>(); | |||||
| } | |||||
| if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { | if (dst.layout.dtype.enumv() == DTypeEnum::Float32) { | ||||
| #define cb(_imode, _bmode, _ch) \ | #define cb(_imode, _bmode, _ch) \ | ||||
| auto task = [src, trans_ptr, midx_ptr, dst, border_value, parallelism_batch]( \ | |||||
| auto task = [src, trans, mat_idx, dst, border_value, parallelism_batch]( \ | |||||
| size_t index, size_t) { \ | size_t index, size_t) { \ | ||||
| size_t batch_id = index / parallelism_batch; \ | size_t batch_id = index / parallelism_batch; \ | ||||
| size_t task_id = index % parallelism_batch; \ | size_t task_id = index % parallelism_batch; \ | ||||
| size_t src_id = batch_id; \ | size_t src_id = batch_id; \ | ||||
| if (midx_ptr) { \ | |||||
| src_id = midx_ptr[batch_id]; \ | |||||
| const float* trans_ptr = trans.ptr<dt_float32>(); \ | |||||
| if (mat_idx.raw_ptr()) { \ | |||||
| megdnn_assert(mat_idx.layout.ndim == 1); \ | |||||
| src_id = mat_idx.ptr<int>()[batch_id]; \ | |||||
| megdnn_assert( \ | megdnn_assert( \ | ||||
| src_id < src.layout.shape[0], \ | src_id < src.layout.shape[0], \ | ||||
| "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", batch_id, \ | "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", batch_id, \ | ||||
| @@ -202,13 +198,15 @@ void megdnn::arm_common::warp_perspective_cv_exec( | |||||
| #undef cb | #undef cb | ||||
| } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) { | } else if (dst.layout.dtype.enumv() == DTypeEnum::Uint8) { | ||||
| #define cb(_imode, _bmode, _ch) \ | #define cb(_imode, _bmode, _ch) \ | ||||
| auto task = [src, trans_ptr, midx_ptr, dst, border_value, parallelism_batch]( \ | |||||
| auto task = [src, trans, mat_idx, dst, border_value, parallelism_batch]( \ | |||||
| size_t index, size_t) { \ | size_t index, size_t) { \ | ||||
| size_t batch_id = index / parallelism_batch; \ | size_t batch_id = index / parallelism_batch; \ | ||||
| size_t task_id = index % parallelism_batch; \ | size_t task_id = index % parallelism_batch; \ | ||||
| size_t src_id = batch_id; \ | size_t src_id = batch_id; \ | ||||
| if (midx_ptr) { \ | |||||
| src_id = midx_ptr[batch_id]; \ | |||||
| const float* trans_ptr = trans.ptr<dt_float32>(); \ | |||||
| if (mat_idx.raw_ptr()) { \ | |||||
| megdnn_assert(mat_idx.layout.ndim == 1); \ | |||||
| src_id = mat_idx.ptr<int>()[batch_id]; \ | |||||
| megdnn_assert( \ | megdnn_assert( \ | ||||
| src_id < src.layout.shape[0], \ | src_id < src.layout.shape[0], \ | ||||
| "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", batch_id, \ | "mat_idx out of bound: mat_idx[%zu]=%zu src_batch=%zu", batch_id, \ | ||||
| @@ -136,10 +136,10 @@ void armv7::RelayoutForwardImpl::exec( | |||||
| relayout::TransposeParam trans_param; | relayout::TransposeParam trans_param; | ||||
| bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param); | bool trans = relayout::is_transpose(src.layout, dst.layout, trans_param); | ||||
| if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) { | if (trans && trans_param.c == 1 && src0.layout.dtype.size() == 1) { | ||||
| auto sptr = static_cast<TransposeByte*>(src.raw_ptr), | |||||
| dptr = static_cast<TransposeByte*>(dst.raw_ptr); | |||||
| MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<TransposeByte>( | MEGDNN_DISPATCH_CPU_KERN_OPR(transpose_fallback::transpose<TransposeByte>( | ||||
| trans_param.batch, trans_param.m, trans_param.n, sptr, dptr)); | |||||
| trans_param.batch, trans_param.m, trans_param.n, | |||||
| static_cast<TransposeByte*>(src.raw_ptr()), | |||||
| static_cast<TransposeByte*>(dst.raw_ptr()))); | |||||
| return; | return; | ||||
| } | } | ||||
| exec_after_preprocess(src, dst, trans ? &trans_param : nullptr); | exec_after_preprocess(src, dst, trans ? &trans_param : nullptr); | ||||
| @@ -288,11 +288,13 @@ void RotateImpl::exec( | |||||
| return fallback::RotateImpl::exec(src, dst, workspace); | return fallback::RotateImpl::exec(src, dst, workspace); | ||||
| } | } | ||||
| auto clockwise = param().clockwise; | |||||
| MEGDNN_DISPATCH_CPU_KERN_OPR({ | MEGDNN_DISPATCH_CPU_KERN_OPR({ | ||||
| for (size_t i = 0; i < src.layout.shape[0]; ++i) { | for (size_t i = 0; i < src.layout.shape[0]; ++i) { | ||||
| Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i); | Mat<uchar> src_mat = TensorND2Mat<uchar>(src, i); | ||||
| Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i); | Mat<uchar> dst_mat = TensorND2Mat<uchar>(dst, i); | ||||
| rotate(src_mat, dst_mat, param().clockwise); | |||||
| rotate(src_mat, dst_mat, clockwise); | |||||
| } | } | ||||
| }); | }); | ||||
| } | } | ||||
| @@ -36,7 +36,7 @@ ChecksumForward::Result ChecksumForwardImpl::exec( | |||||
| megcoreComputingHandle_t comp_handle = handle()->megcore_computing_handle(); | megcoreComputingHandle_t comp_handle = handle()->megcore_computing_handle(); | ||||
| megcoreGetDeviceHandle(comp_handle, &dev_handle); | megcoreGetDeviceHandle(comp_handle, &dev_handle); | ||||
| megcoreMemcpy( | megcoreMemcpy( | ||||
| comp_handle, cpu_data.data(), data.raw_ptr, cpu_data.size(), | |||||
| comp_handle, cpu_data.data(), data.raw_ptr(), cpu_data.size(), | |||||
| megcoreMemcpyDeviceToHost); | megcoreMemcpyDeviceToHost); | ||||
| megcoreSynchronize(comp_handle); | megcoreSynchronize(comp_handle); | ||||
| @@ -62,7 +62,7 @@ ChecksumForward::Result ChecksumForwardImpl::exec( | |||||
| check_exec(data.layout, workspace.size); | check_exec(data.layout, workspace.size); | ||||
| auto queue = cnrt_queue(handle()); | auto queue = cnrt_queue(handle()); | ||||
| auto ptr = static_cast<uint8_t*>(data.raw_ptr); | |||||
| auto ptr = static_cast<uint8_t*>(data.raw_ptr()); | |||||
| size_t size_all = data.layout.shape[0], size_ints = size_all / sizeof(uint32_t); | size_t size_all = data.layout.shape[0], size_ints = size_all / sizeof(uint32_t); | ||||
| auto last_val_size = std::min<size_t>(size_all, 4); | auto last_val_size = std::min<size_t>(size_all, 4); | ||||
| cnrt_check(cnrtMemcpyAsync( | cnrt_check(cnrtMemcpyAsync( | ||||
| @@ -72,7 +72,7 @@ ChecksumForward::Result ChecksumForwardImpl::exec( | |||||
| auto&& device_info = current_device_info(); | auto&& device_info = current_device_info(); | ||||
| bang_c_wrapper( | bang_c_wrapper( | ||||
| reinterpret_cast<uint32_t*>(workspace.raw_ptr), | reinterpret_cast<uint32_t*>(workspace.raw_ptr), | ||||
| static_cast<uint32_t*>(data.raw_ptr), size_ints, queue, | |||||
| static_cast<uint32_t*>(data.raw_ptr()), size_ints, queue, | |||||
| device_info.core_version); | device_info.core_version); | ||||
| cnrt_check(cnrtMemcpyAsync( | cnrt_check(cnrtMemcpyAsync( | ||||
| &result.checksum, workspace.raw_ptr, sizeof(result.checksum), queue, | &result.checksum, workspace.raw_ptr, sizeof(result.checksum), queue, | ||||
| @@ -38,10 +38,9 @@ void ConcatSplitBase::check_layout_common( | |||||
| megdnn_assert_eq_size_t(src.ndim, ndim); | megdnn_assert_eq_size_t(src.ndim, ndim); | ||||
| } | } | ||||
| // ensure param().axis is correct | // ensure param().axis is correct | ||||
| auto errmsg = "param().axis=" + std::to_string(param().axis) + | |||||
| ", ndim=" + std::to_string(ndim); | |||||
| MEGDNN_MARK_USED_VAR(errmsg); | |||||
| megdnn_assert(param().axis < static_cast<int32_t>(ndim), "%s", errmsg.c_str()); | |||||
| megdnn_assert( | |||||
| param().axis < static_cast<int32_t>(ndim), "param().axis=%u, ndim=%zu", | |||||
| param().axis, ndim); | |||||
| // ensure shape size for each axis is correct | // ensure shape size for each axis is correct | ||||
| for (size_t i = 0; i < ndim; ++i) { | for (size_t i = 0; i < ndim; ++i) { | ||||
| if (i == static_cast<size_t>(param().axis)) { | if (i == static_cast<size_t>(param().axis)) { | ||||
| @@ -24,28 +24,24 @@ void ElemwiseMultiTypeImplHelper::exec( | |||||
| _megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) { | _megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) { | ||||
| switch (m_param.mode) { | switch (m_param.mode) { | ||||
| case Mode::FUSE_MUL_ADD3_INT16x32x32x32: | case Mode::FUSE_MUL_ADD3_INT16x32x32x32: | ||||
| on_fuse_mul_add3_int16x32x32x32( | |||||
| make_elemwise_op_param<3>(src, dst), dst.ptr<dt_int32>()); | |||||
| on_fuse_mul_add3_int16x32x32x32(make_elemwise_op_param<3>(src, dst), dst); | |||||
| break; | break; | ||||
| case Mode::FUSE_MUL_ADD3_IXxF32xF32xI8: | case Mode::FUSE_MUL_ADD3_IXxF32xF32xI8: | ||||
| on_fuse_mul_add3_iXxf32xf32xi8( | |||||
| make_elemwise_op_param<3>(src, dst), dst.ptr<dt_int8>()); | |||||
| on_fuse_mul_add3_iXxf32xf32xi8(make_elemwise_op_param<3>(src, dst), dst); | |||||
| break; | break; | ||||
| case Mode::ROUND_SHR_SATURATE_IXxI8xI8: | case Mode::ROUND_SHR_SATURATE_IXxI8xI8: | ||||
| on_round_shr_saturate_iXxi8xi8( | |||||
| make_elemwise_op_param<2>(src, dst), dst.ptr<dt_int8>()); | |||||
| on_round_shr_saturate_iXxi8xi8(make_elemwise_op_param<2>(src, dst), dst); | |||||
| break; | break; | ||||
| case Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8: | case Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8: | ||||
| on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | ||||
| make_elemwise_op_param<6>(src, dst), dst.ptr<dt_int8>()); | |||||
| make_elemwise_op_param<6>(src, dst), dst); | |||||
| break; | break; | ||||
| case Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8: | case Mode::FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8: | ||||
| on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | ||||
| make_elemwise_op_param<6>(src, dst), dst.ptr<dt_int8>()); | |||||
| make_elemwise_op_param<6>(src, dst), dst); | |||||
| break; | break; | ||||
| case Mode::ROUND_SHR_SATURATE_IXxI8xI16: | case Mode::ROUND_SHR_SATURATE_IXxI8xI16: | ||||
| on_round_shr_saturate_iXxi8xi16( | |||||
| make_elemwise_op_param<2>(src, dst), dst.ptr<dt_int16>()); | |||||
| on_round_shr_saturate_iXxi8xi16(make_elemwise_op_param<2>(src, dst), dst); | |||||
| break; | break; | ||||
| ON_QUANTIZED_MODE(RELU, 1); | ON_QUANTIZED_MODE(RELU, 1); | ||||
| ON_QUANTIZED_MODE(ABS, 1); | ON_QUANTIZED_MODE(ABS, 1); | ||||
| @@ -33,22 +33,22 @@ class ElemwiseMultiTypeImplHelper : public ElemwiseMultiType, | |||||
| protected: | protected: | ||||
| virtual void on_fuse_mul_add3_int16x32x32x32( | virtual void on_fuse_mul_add3_int16x32x32x32( | ||||
| const ElemwiseOpParamN<3>& param, dt_int32* dst) = 0; | |||||
| const ElemwiseOpParamN<3>& param, const TensorND& dst) = 0; | |||||
| virtual void on_fuse_mul_add3_iXxf32xf32xi8( | virtual void on_fuse_mul_add3_iXxf32xf32xi8( | ||||
| const ElemwiseOpParamN<3>& param, dt_int8* dst) = 0; | |||||
| const ElemwiseOpParamN<3>& param, const TensorND& dst) = 0; | |||||
| virtual void on_round_shr_saturate_iXxi8xi8( | virtual void on_round_shr_saturate_iXxi8xi8( | ||||
| const ElemwiseOpParamN<2>& param, dt_int8* dst) = 0; | |||||
| const ElemwiseOpParamN<2>& param, const TensorND& dst) = 0; | |||||
| virtual void on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | virtual void on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | ||||
| const ElemwiseOpParamN<6>& param, dt_int8* dst) = 0; | |||||
| const ElemwiseOpParamN<6>& param, const TensorND& dst) = 0; | |||||
| virtual void on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | virtual void on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | ||||
| const ElemwiseOpParamN<6>& param, dt_int8* dst) = 0; | |||||
| const ElemwiseOpParamN<6>& param, const TensorND& dst) = 0; | |||||
| virtual void on_round_shr_saturate_iXxi8xi16( | virtual void on_round_shr_saturate_iXxi8xi16( | ||||
| const ElemwiseOpParamN<2>& param, dt_int16* dst) = 0; | |||||
| const ElemwiseOpParamN<2>& param, const TensorND& dst) = 0; | |||||
| virtual void on_quantized_mode( | virtual void on_quantized_mode( | ||||
| const ElemwiseOpParamN<1>& param, const TensorND& dst, | const ElemwiseOpParamN<1>& param, const TensorND& dst, | ||||
| @@ -29,9 +29,9 @@ template <int N, int OC> | |||||
| void local_xcorr_tpl(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET; | void local_xcorr_tpl(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET; | ||||
| template <int N, int OC> | template <int N, int OC> | ||||
| void local_xcorr_tpl(const LocalKParam& kparam) { | void local_xcorr_tpl(const LocalKParam& kparam) { | ||||
| const float* src = static_cast<const float*>(kparam.src); | |||||
| const float* filter = static_cast<const float*>(kparam.filter); | |||||
| float* dst = static_cast<float*>(kparam.dst); | |||||
| const float* src = static_cast<const float*>(kparam.src.get_ptr()); | |||||
| const float* filter = static_cast<const float*>(kparam.filter.get_ptr()); | |||||
| float* dst = static_cast<float*>(kparam.dst.get_ptr()); | |||||
| float* workspace = static_cast<float*>(kparam.workspace); | float* workspace = static_cast<float*>(kparam.workspace); | ||||
| const int IC = kparam.ic, IH = kparam.ih, IW = kparam.iw, OH = kparam.oh, | const int IC = kparam.ic, IH = kparam.ih, IW = kparam.iw, OH = kparam.oh, | ||||
| OW = kparam.ow, FH = kparam.fh, FW = kparam.fw; | OW = kparam.ow, FH = kparam.fh, FW = kparam.fw; | ||||
| @@ -191,9 +191,9 @@ template <int N, int OC> | |||||
| void local_conv_tpl(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET; | void local_conv_tpl(const LocalKParam& kparam) MEGDNN_SIMD_ATTRIBUTE_TARGET; | ||||
| template <int N, int OC> | template <int N, int OC> | ||||
| void local_conv_tpl(const LocalKParam& kparam) { | void local_conv_tpl(const LocalKParam& kparam) { | ||||
| const float* src = static_cast<const float*>(kparam.src); | |||||
| const float* filter = static_cast<const float*>(kparam.filter); | |||||
| float* dst = static_cast<float*>(kparam.dst); | |||||
| const float* src = static_cast<const float*>(kparam.src.get_ptr()); | |||||
| const float* filter = static_cast<const float*>(kparam.filter.get_ptr()); | |||||
| float* dst = static_cast<float*>(kparam.dst.get_ptr()); | |||||
| float* workspace = static_cast<float*>(kparam.workspace); | float* workspace = static_cast<float*>(kparam.workspace); | ||||
| const int IC = kparam.ic, IH = kparam.ih, IW = kparam.iw, OH = kparam.oh, | const int IC = kparam.ic, IH = kparam.ih, IW = kparam.iw, OH = kparam.oh, | ||||
| OW = kparam.ow, FH = kparam.fh, FW = kparam.fw; | OW = kparam.ow, FH = kparam.fh, FW = kparam.fw; | ||||
| @@ -11,9 +11,7 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "megdnn/dtype.h" | #include "megdnn/dtype.h" | ||||
| #if MEGDNN_CC_HOST | |||||
| #include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
| #endif | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace reduce { | namespace reduce { | ||||
| @@ -24,16 +22,14 @@ struct SumOp { | |||||
| const wtype INIT; | const wtype INIT; | ||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| RefPtr src; | |||||
| RefPtr dst; | |||||
| const size_t B; | const size_t B; | ||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| return lhs + rhs; | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE SumOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; } | |||||
| void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
| static wtype apply(wtype lhs, wtype rhs) { return lhs + rhs; } | |||||
| SumOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
| : INIT(wtype(0)), src(src), dst(dst), B(B) {} | : INIT(wtype(0)), src(src), dst(dst), B(B) {} | ||||
| }; | }; | ||||
| @@ -43,18 +39,16 @@ struct MeanOp { | |||||
| const wtype INIT; | const wtype INIT; | ||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| RefPtr src; | |||||
| RefPtr dst; | |||||
| const size_t B; | const size_t B; | ||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { | |||||
| dst[idx] = val / static_cast<wtype>(B); | |||||
| } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| return lhs + rhs; | |||||
| wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; } | |||||
| void write(uint32_t idx, wtype val) { | |||||
| dst.ptr<dst_ctype>()[idx] = val / static_cast<wtype>(B); | |||||
| } | } | ||||
| MEGDNN_HOST MEGDNN_DEVICE MeanOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| static wtype apply(wtype lhs, wtype rhs) { return lhs + rhs; } | |||||
| MeanOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
| : INIT(wtype(0)), src(src), dst(dst), B(B) {} | : INIT(wtype(0)), src(src), dst(dst), B(B) {} | ||||
| }; | }; | ||||
| @@ -64,18 +58,17 @@ struct SumSqrOp { | |||||
| const wtype INIT; | const wtype INIT; | ||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| RefPtr src; | |||||
| RefPtr dst; | |||||
| const size_t B; | const size_t B; | ||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { | |||||
| return static_cast<wtype>(src[idx]) * static_cast<wtype>(src[idx]); | |||||
| wtype read(uint32_t idx) { | |||||
| return static_cast<wtype>(src.ptr<src_ctype>()[idx]) * | |||||
| static_cast<wtype>(src.ptr<src_ctype>()[idx]); | |||||
| } | } | ||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| return lhs + rhs; | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE SumSqrOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
| static wtype apply(wtype lhs, wtype rhs) { return lhs + rhs; } | |||||
| SumSqrOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
| : INIT(wtype(0)), src(src), dst(dst), B(B) {} | : INIT(wtype(0)), src(src), dst(dst), B(B) {} | ||||
| }; | }; | ||||
| @@ -84,16 +77,14 @@ struct ProdOp { | |||||
| typedef wtype_ wtype; | typedef wtype_ wtype; | ||||
| const wtype INIT; | const wtype INIT; | ||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| RefPtr src; | |||||
| RefPtr dst; | |||||
| const size_t B; | const size_t B; | ||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| return lhs * rhs; | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE ProdOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; } | |||||
| void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
| static wtype apply(wtype lhs, wtype rhs) { return lhs * rhs; } | |||||
| ProdOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
| : INIT(wtype(1)), src(src), dst(dst), B(B) {} | : INIT(wtype(1)), src(src), dst(dst), B(B) {} | ||||
| }; | }; | ||||
| @@ -102,20 +93,14 @@ struct MinOp { | |||||
| typedef wtype_ wtype; | typedef wtype_ wtype; | ||||
| const wtype INIT; | const wtype INIT; | ||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| RefPtr src; | |||||
| RefPtr dst; | |||||
| const size_t B; | const size_t B; | ||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| #if defined(__CUDA_ARCH__) | |||||
| return lhs < rhs ? lhs : rhs; | |||||
| #else | |||||
| return std::min(lhs, rhs); | |||||
| #endif | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE MinOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; } | |||||
| void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
| static wtype apply(wtype lhs, wtype rhs) { return std::min(lhs, rhs); } | |||||
| MinOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
| : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {} | : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {} | ||||
| }; | }; | ||||
| @@ -124,20 +109,16 @@ struct MinOp<src_ctype, dst_ctype, dt_float32> { | |||||
| typedef dt_float32 wtype; | typedef dt_float32 wtype; | ||||
| const wtype INIT; | const wtype INIT; | ||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| RefPtr src; | |||||
| RefPtr dst; | |||||
| const size_t B; | const size_t B; | ||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| #if defined(__CUDA_ARCH__) | |||||
| return (isnan(lhs) || lhs < rhs) ? lhs : rhs; | |||||
| #else | |||||
| wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; } | |||||
| void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
| static wtype apply(wtype lhs, wtype rhs) { | |||||
| return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs; | return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs; | ||||
| #endif | |||||
| } | } | ||||
| MEGDNN_HOST MEGDNN_DEVICE MinOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| MinOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
| : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {} | : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {} | ||||
| }; | }; | ||||
| @@ -146,20 +127,14 @@ struct MaxOp { | |||||
| typedef wtype_ wtype; | typedef wtype_ wtype; | ||||
| const wtype INIT; | const wtype INIT; | ||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| RefPtr src; | |||||
| RefPtr dst; | |||||
| const size_t B; | const size_t B; | ||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| #if defined(__CUDA_ARCH__) | |||||
| return lhs > rhs ? lhs : rhs; | |||||
| #else | |||||
| return std::max(lhs, rhs); | |||||
| #endif | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE MaxOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; } | |||||
| void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
| static wtype apply(wtype lhs, wtype rhs) { return std::max(lhs, rhs); } | |||||
| MaxOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
| : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | ||||
| }; | }; | ||||
| @@ -168,20 +143,16 @@ struct MaxOp<src_ctype, dst_ctype, dt_float32> { | |||||
| typedef dt_float32 wtype; | typedef dt_float32 wtype; | ||||
| const wtype INIT; | const wtype INIT; | ||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| RefPtr src; | |||||
| RefPtr dst; | |||||
| const size_t B; | const size_t B; | ||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| #if defined(__CUDA_ARCH__) | |||||
| return (isnan(lhs) || lhs > rhs) ? lhs : rhs; | |||||
| #else | |||||
| wtype read(uint32_t idx) { return src.ptr<src_ctype>()[idx]; } | |||||
| void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
| static wtype apply(wtype lhs, wtype rhs) { | |||||
| return (std::isnan(lhs) || lhs > rhs) ? lhs : rhs; | return (std::isnan(lhs) || lhs > rhs) ? lhs : rhs; | ||||
| #endif | |||||
| } | } | ||||
| MEGDNN_HOST MEGDNN_DEVICE MaxOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| MaxOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
| : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | ||||
| }; | }; | ||||
| @@ -190,28 +161,19 @@ struct CheckNonFiniteOp { | |||||
| typedef wtype_ wtype; | typedef wtype_ wtype; | ||||
| const wtype INIT; | const wtype INIT; | ||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| RefPtr src; | |||||
| RefPtr dst; | |||||
| const size_t B; | const size_t B; | ||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { | |||||
| #if defined(__CUDA_ARCH__) | |||||
| return !isfinite(src[idx]); | |||||
| #else | |||||
| return !std::isfinite(src[idx]); | |||||
| #endif | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| return lhs | rhs; | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| wtype read(uint32_t idx) { return !std::isfinite(src.ptr<src_ctype>()[idx]); } | |||||
| void write(uint32_t idx, wtype val) { dst.ptr<dst_ctype>()[idx] = val; } | |||||
| static wtype apply(wtype lhs, wtype rhs) { return lhs | rhs; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE | |||||
| CheckNonFiniteOp(const RefPtr& src, const RefPtr& dst, size_t B) | |||||
| : INIT(wtype(0)), src(src), dst(dst), B(B) {} | : INIT(wtype(0)), src(src), dst(dst), B(B) {} | ||||
| }; | }; | ||||
| #if MEGDNN_CC_HOST | |||||
| void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis); | void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis); | ||||
| #endif | |||||
| } // namespace reduce | } // namespace reduce | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
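With the host-side reduce ops now carrying RefPtr members, a caller constructs them from the tensors' reference pointers rather than from raw typed pointers. The dispatch code is not part of this hunk, so the following is a hypothetical call site: it assumes TensorND::get_ref_ptr() as already used in make_pooling_kern_param above, the megdnn namespace for RefPtr, and the usual (A, B, C) decomposition from get_ABC(), reducing over the middle axis.

    #include "src/common/reduce_helper.h"

    // Hypothetical host-side use of the RefPtr-based SumOp; the function name
    // and arguments are illustrative, only the op interface comes from the patch.
    void sum_over_B(
            const megdnn::RefPtr& src, const megdnn::RefPtr& dst, size_t A,
            size_t B, size_t C) {
        using Op = megdnn::reduce::SumOp<float, float, float>;
        Op op(src, dst, B);
        for (size_t a = 0; a < A; ++a)
            for (size_t c = 0; c < C; ++c) {
                Op::wtype acc = op.INIT;
                for (size_t b = 0; b < B; ++b)
                    acc = Op::apply(acc, op.read((a * B + b) * C + c));
                op.write(a * C + c, acc);
            }
    }

A call site would then pass src.get_ref_ptr() and dst.get_ref_ptr() after computing A, B and C with get_ABC().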
| @@ -0,0 +1,222 @@ | |||||
| /** | |||||
| * \file dnn/src/common/reduce_helper_device.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megdnn/dtype.h" | |||||
| #if MEGDNN_CC_HOST | |||||
| #include "megdnn/basic_types.h" | |||||
| #endif | |||||
| namespace megdnn { | |||||
| namespace device_reduce { | |||||
| template <typename src_ctype, typename dst_ctype, typename wtype_> | |||||
| struct SumOp { | |||||
| typedef wtype_ wtype; | |||||
| const wtype INIT; | |||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| const size_t B; | |||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| return lhs + rhs; | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE SumOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| : INIT(wtype(0)), src(src), dst(dst), B(B) {} | |||||
| }; | |||||
| template <typename src_ctype, typename dst_ctype, typename wtype_> | |||||
| struct MeanOp { | |||||
| typedef wtype_ wtype; | |||||
| const wtype INIT; | |||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| const size_t B; | |||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { | |||||
| dst[idx] = val / static_cast<wtype>(B); | |||||
| } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| return lhs + rhs; | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE MeanOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| : INIT(wtype(0)), src(src), dst(dst), B(B) {} | |||||
| }; | |||||
| template <typename src_ctype, typename dst_ctype, typename wtype_> | |||||
| struct SumSqrOp { | |||||
| typedef wtype_ wtype; | |||||
| const wtype INIT; | |||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| const size_t B; | |||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { | |||||
| return static_cast<wtype>(src[idx]) * static_cast<wtype>(src[idx]); | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| return lhs + rhs; | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE SumSqrOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| : INIT(wtype(0)), src(src), dst(dst), B(B) {} | |||||
| }; | |||||
| template <typename src_ctype, typename dst_ctype, typename wtype_> | |||||
| struct ProdOp { | |||||
| typedef wtype_ wtype; | |||||
| const wtype INIT; | |||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| const size_t B; | |||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| return lhs * rhs; | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE ProdOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| : INIT(wtype(1)), src(src), dst(dst), B(B) {} | |||||
| }; | |||||
| template <typename src_ctype, typename dst_ctype, typename wtype_> | |||||
| struct MinOp { | |||||
| typedef wtype_ wtype; | |||||
| const wtype INIT; | |||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| const size_t B; | |||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| #if defined(__CUDA_ARCH__) | |||||
| return lhs < rhs ? lhs : rhs; | |||||
| #else | |||||
| return std::min(lhs, rhs); | |||||
| #endif | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE MinOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {} | |||||
| }; | |||||
| template <typename src_ctype, typename dst_ctype> | |||||
| struct MinOp<src_ctype, dst_ctype, dt_float32> { | |||||
| typedef dt_float32 wtype; | |||||
| const wtype INIT; | |||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| const size_t B; | |||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| #if defined(__CUDA_ARCH__) | |||||
| return (isnan(lhs) || lhs < rhs) ? lhs : rhs; | |||||
| #else | |||||
| return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs; | |||||
| #endif | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE MinOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {} | |||||
| }; | |||||
| template <typename src_ctype, typename dst_ctype, typename wtype_> | |||||
| struct MaxOp { | |||||
| typedef wtype_ wtype; | |||||
| const wtype INIT; | |||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| const size_t B; | |||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| #if defined(__CUDA_ARCH__) | |||||
| return lhs > rhs ? lhs : rhs; | |||||
| #else | |||||
| return std::max(lhs, rhs); | |||||
| #endif | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE MaxOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | |||||
| }; | |||||
| template <typename src_ctype, typename dst_ctype> | |||||
| struct MaxOp<src_ctype, dst_ctype, dt_float32> { | |||||
| typedef dt_float32 wtype; | |||||
| const wtype INIT; | |||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| const size_t B; | |||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| #if defined(__CUDA_ARCH__) | |||||
| return (isnan(lhs) || lhs > rhs) ? lhs : rhs; | |||||
| #else | |||||
| return (std::isnan(lhs) || lhs > rhs) ? lhs : rhs; | |||||
| #endif | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE MaxOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | |||||
| }; | |||||
| template <typename src_ctype, typename dst_ctype, typename wtype_> | |||||
| struct CheckNonFiniteOp { | |||||
| typedef wtype_ wtype; | |||||
| const wtype INIT; | |||||
| src_ctype* src; | |||||
| dst_ctype* dst; | |||||
| const size_t B; | |||||
| MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { | |||||
| #if defined(__CUDA_ARCH__) | |||||
| return !isfinite(src[idx]); | |||||
| #else | |||||
| return !std::isfinite(src[idx]); | |||||
| #endif | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; } | |||||
| static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
| return lhs | rhs; | |||||
| } | |||||
| MEGDNN_HOST MEGDNN_DEVICE CheckNonFiniteOp(src_ctype* src, dst_ctype* dst, size_t B) | |||||
| : INIT(wtype(0)), src(src), dst(dst), B(B) {} | |||||
| }; | |||||
| } // namespace device_reduce | |||||
| namespace reduce { | |||||
| #if MEGDNN_CC_HOST | |||||
| void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, size_t axis); | |||||
| #endif | |||||
| } // namespace reduce | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
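The new reduce_helper_device.h keeps the original raw-pointer ops together with their MEGDNN_HOST MEGDNN_DEVICE qualifiers, so CUDA callers (see the argmxx include switch further down) are untouched by the RefPtr migration of the host header. The raw-pointer form also keeps the ops trivially copyable, which is what lets them be passed by value straight into a kernel. The kernel below is only an illustration of that property written against this interface; it is not the project's actual launcher, which lives in src/cuda/reduce_helper.cuh.

    // Illustration only: one thread per (a, c) output element, reducing over B.
    // Any op from megdnn::device_reduce (e.g. SumOp<float, float, float>) fits.
    template <typename Op>
    __global__ void naive_reduce_kernel(Op op, uint32_t A, uint32_t B, uint32_t C) {
        uint32_t out = blockIdx.x * blockDim.x + threadIdx.x;
        if (out >= A * C)
            return;
        uint32_t a = out / C, c = out % C;
        typename Op::wtype acc = op.INIT;
        for (uint32_t b = 0; b < B; ++b)
            acc = Op::apply(acc, op.read((a * B + b) * C + c));
        op.write(a * C + c, acc);
    }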
| @@ -362,6 +362,10 @@ static inline void copy_plane_in_bytes( | |||||
| megcoreDeviceHandle_t get_device_handle(Handle* handle); | megcoreDeviceHandle_t get_device_handle(Handle* handle); | ||||
| static inline void incr_refp(RefPtr& ptr, ptrdiff_t delta) { | |||||
| ptr += (size_t)delta; | |||||
| } | |||||
| static inline void incr_voidp(void*& ptr, ptrdiff_t delta) { | static inline void incr_voidp(void*& ptr, ptrdiff_t delta) { | ||||
| ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(ptr) + delta); | ptr = reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(ptr) + delta); | ||||
| } | } | ||||
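incr_refp is the RefPtr counterpart of incr_voidp: it advances the reference pointer by a byte offset while leaving the shared base address alone. A hedged illustration of the usual pattern follows; the helper name and its parameters are hypothetical, only incr_refp itself comes from the hunk above, and the megdnn namespace for it is assumed.

    // Hypothetical helper: step per-batch pointers by the batch stride in bytes,
    // the same pattern incr_voidp supported for raw pointers.
    void advance_one_batch(
            megdnn::RefPtr& src, megdnn::RefPtr& dst,
            const megdnn::TensorLayout& src_layout,
            const megdnn::TensorLayout& dst_layout) {
        megdnn::incr_refp(src, src_layout.stride[0] * src_layout.dtype.size());
        megdnn::incr_refp(dst, dst_layout.stride[0] * dst_layout.dtype.size());
    }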
| @@ -674,7 +678,8 @@ struct CompTypeCvter { | |||||
| comp.layout.dtype.enumv() != DTypeTrait<CompType>::enumv) { | comp.layout.dtype.enumv() != DTypeTrait<CompType>::enumv) { | ||||
| comp.layout.dtype = CompType(); | comp.layout.dtype = CompType(); | ||||
| comp.layout.init_contiguous_stride(); | comp.layout.init_contiguous_stride(); | ||||
| comp.raw_ptr = m_workspace_bundle->get(m_workspace_idx++); | |||||
| comp = TensorND{ | |||||
| m_workspace_bundle->get(m_workspace_idx++), comp.layout}; | |||||
| if (src.layout.ndim) { | if (src.layout.ndim) { | ||||
| m_cvt_opr->exec(src, comp); | m_cvt_opr->exec(src, comp); | ||||
| } | } | ||||
| @@ -699,7 +704,7 @@ struct CompTypeCvter { | |||||
| * \brief get TensorND raw_ptr+low_byte pointer. | * \brief get TensorND raw_ptr+low_byte pointer. | ||||
| */ | */ | ||||
| inline dt_byte* get_low_ptr(const TensorND* tensor) { | inline dt_byte* get_low_ptr(const TensorND* tensor) { | ||||
| return static_cast<dt_byte*>(tensor->raw_ptr) + tensor->layout.span().low_byte; | |||||
| return static_cast<dt_byte*>(tensor->raw_ptr()) + tensor->layout.span().low_byte; | |||||
| } | } | ||||
| /*! | /*! | ||||
| @@ -11,7 +11,7 @@ | |||||
| #include "src/cuda/argmxx/opr_impl.h" | #include "src/cuda/argmxx/opr_impl.h" | ||||
| #include "src/common/argmxx_helper.h" | #include "src/common/argmxx_helper.h" | ||||
| #include "src/common/reduce_helper.h" | |||||
| #include "src/common/reduce_helper_device.h" | |||||
| #include "src/cuda/reduce_helper.cuh" | #include "src/cuda/reduce_helper.cuh" | ||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| @@ -117,32 +117,34 @@ void BNForwardImpl::exec( | |||||
| #if CUDNN_VERSION >= 7410 | #if CUDNN_VERSION >= 7410 | ||||
| cudnn_check(cudnnBatchNormalizationForwardTrainingEx( | cudnn_check(cudnnBatchNormalizationForwardTrainingEx( | ||||
| handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, | handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, | ||||
| &beta, // one & zero | |||||
| tensor_desc.xy_desc.desc, src.raw_ptr, // xDesc & x | |||||
| nullptr, nullptr, // zDesc & z | |||||
| tensor_desc.xy_desc.desc, dst.raw_ptr, // yDesc & y | |||||
| tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||||
| bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, mean.raw_ptr, | |||||
| variance.raw_ptr, m_param.epsilon, batch_mean.raw_ptr, | |||||
| batch_inv_variance.raw_ptr, nullptr, workspace.raw_ptr, | |||||
| workspace.size, reserve.raw_ptr, reserve.layout.access_bytes())); | |||||
| &beta, // one & zero | |||||
| tensor_desc.xy_desc.desc, src.raw_ptr(), // xDesc & x | |||||
| nullptr, nullptr, // zDesc & z | |||||
| tensor_desc.xy_desc.desc, dst.raw_ptr(), // yDesc & y | |||||
| tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||||
| bn_scale.raw_ptr(), bn_bias.raw_ptr(), m_param.avg_factor, | |||||
| mean.raw_ptr(), variance.raw_ptr(), m_param.epsilon, | |||||
| batch_mean.raw_ptr(), batch_inv_variance.raw_ptr(), nullptr, | |||||
| workspace.raw_ptr, workspace.size, reserve.raw_ptr(), | |||||
| reserve.layout.access_bytes())); | |||||
| #else | #else | ||||
| cudnn_check(cudnnBatchNormalizationForwardTraining( | cudnn_check(cudnnBatchNormalizationForwardTraining( | ||||
| handle, tensor_desc.bn_mode, &alpha, &beta, | handle, tensor_desc.bn_mode, &alpha, &beta, | ||||
| tensor_desc.xy_desc.desc, src.raw_ptr, // xDesc & x | |||||
| tensor_desc.xy_desc.desc, dst.raw_ptr, // yDesc & y | |||||
| tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||||
| bn_scale.raw_ptr, bn_bias.raw_ptr, m_param.avg_factor, mean.raw_ptr, | |||||
| variance.raw_ptr, m_param.epsilon, batch_mean.raw_ptr, | |||||
| batch_inv_variance.raw_ptr)); | |||||
| tensor_desc.xy_desc.desc, src.raw_ptr(), // xDesc & x | |||||
| tensor_desc.xy_desc.desc, dst.raw_ptr(), // yDesc & y | |||||
| tensor_desc.param_desc.desc, // bnScaleBiasMeanVarDesc | |||||
| bn_scale.raw_ptr(), bn_bias.raw_ptr(), m_param.avg_factor, | |||||
| mean.raw_ptr(), variance.raw_ptr(), m_param.epsilon, | |||||
| batch_mean.raw_ptr(), batch_inv_variance.raw_ptr())); | |||||
| #endif // CUDNN_VERSION >= 7410 | #endif // CUDNN_VERSION >= 7410 | ||||
| break; | break; | ||||
| case param::BN::FwdMode::INFERENCE: | case param::BN::FwdMode::INFERENCE: | ||||
| cudnn_check(cudnnBatchNormalizationForwardInference( | cudnn_check(cudnnBatchNormalizationForwardInference( | ||||
| handle, tensor_desc.bn_mode, &alpha, &beta, | handle, tensor_desc.bn_mode, &alpha, &beta, | ||||
| tensor_desc.xy_desc.desc, src.raw_ptr, tensor_desc.xy_desc.desc, | |||||
| dst.raw_ptr, tensor_desc.param_desc.desc, bn_scale.raw_ptr, | |||||
| bn_bias.raw_ptr, mean.raw_ptr, variance.raw_ptr, m_param.epsilon)); | |||||
| tensor_desc.xy_desc.desc, src.raw_ptr(), tensor_desc.xy_desc.desc, | |||||
| dst.raw_ptr(), tensor_desc.param_desc.desc, bn_scale.raw_ptr(), | |||||
| bn_bias.raw_ptr(), mean.raw_ptr(), variance.raw_ptr(), | |||||
| m_param.epsilon)); | |||||
| break; | break; | ||||
| default: | default: | ||||
| megdnn_throw("Unknown forward mode type of batch normalization."); | megdnn_throw("Unknown forward mode type of batch normalization."); | ||||
| @@ -198,27 +200,27 @@ void BNBackwardImpl::exec( | |||||
| cudnn_check(cudnnBatchNormalizationBackwardEx( | cudnn_check(cudnnBatchNormalizationBackwardEx( | ||||
| handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, &beta, &alpha, | handle, tensor_desc.bn_mode, CUDNN_BATCHNORM_OPS_BN, &alpha, &beta, &alpha, | ||||
| &beta, tensor_desc.xy_desc.desc, | &beta, tensor_desc.xy_desc.desc, | ||||
| x.raw_ptr, // xDesc & x | |||||
| nullptr, nullptr, // yDesc & y | |||||
| tensor_desc.xy_desc.desc, dy.raw_ptr, // dyDesc & dy | |||||
| nullptr, nullptr, // dzDesc & dz | |||||
| tensor_desc.xy_desc.desc, dx.raw_ptr, // dxDesc & dx | |||||
| tensor_desc.param_desc.desc, bn_scale.raw_ptr, // bnScale | |||||
| nullptr, // bnBias | |||||
| d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, // dScale, dBias | |||||
| m_param.epsilon, saved_batch_mean.raw_ptr, saved_batch_inv_variance.raw_ptr, | |||||
| nullptr, workspace.raw_ptr, workspace.size, reserve.raw_ptr, | |||||
| reserve.layout.access_bytes())); | |||||
| x.raw_ptr(), // xDesc & x | |||||
| nullptr, nullptr, // yDesc & y | |||||
| tensor_desc.xy_desc.desc, dy.raw_ptr(), // dyDesc & dy | |||||
| nullptr, nullptr, // dzDesc & dz | |||||
| tensor_desc.xy_desc.desc, dx.raw_ptr(), // dxDesc & dx | |||||
| tensor_desc.param_desc.desc, bn_scale.raw_ptr(), // bnScale | |||||
| nullptr, // bnBias | |||||
| d_bn_scale.raw_ptr(), d_bn_bias.raw_ptr(), // dScale, dBias | |||||
| m_param.epsilon, saved_batch_mean.raw_ptr(), | |||||
| saved_batch_inv_variance.raw_ptr(), nullptr, workspace.raw_ptr, | |||||
| workspace.size, reserve.raw_ptr(), reserve.layout.access_bytes())); | |||||
| #else | #else | ||||
| cudnn_check(cudnnBatchNormalizationBackward( | cudnn_check(cudnnBatchNormalizationBackward( | ||||
| handle, tensor_desc.bn_mode, &alpha, &beta, &alpha, &beta, | handle, tensor_desc.bn_mode, &alpha, &beta, &alpha, &beta, | ||||
| tensor_desc.xy_desc.desc, x.raw_ptr, // xDesc & x | |||||
| tensor_desc.xy_desc.desc, dy.raw_ptr, // dyDesc & dy | |||||
| tensor_desc.xy_desc.desc, dx.raw_ptr, // dxDesc & dx | |||||
| tensor_desc.param_desc.desc, bn_scale.raw_ptr, // bnScale | |||||
| d_bn_scale.raw_ptr, d_bn_bias.raw_ptr, // dScale, dBias | |||||
| m_param.epsilon, saved_batch_mean.raw_ptr, | |||||
| saved_batch_inv_variance.raw_ptr)); | |||||
| tensor_desc.xy_desc.desc, x.raw_ptr(), // xDesc & x | |||||
| tensor_desc.xy_desc.desc, dy.raw_ptr(), // dyDesc & dy | |||||
| tensor_desc.xy_desc.desc, dx.raw_ptr(), // dxDesc & dx | |||||
| tensor_desc.param_desc.desc, bn_scale.raw_ptr(), // bnScale | |||||
| d_bn_scale.raw_ptr(), d_bn_bias.raw_ptr(), // dScale, dBias | |||||
| m_param.epsilon, saved_batch_mean.raw_ptr(), | |||||
| saved_batch_inv_variance.raw_ptr())); | |||||
| #endif | #endif | ||||
| } | } | ||||
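Editor's note — both batch-norm hunks above apply the same mechanical change that runs through the rest of this diff: the raw device pointer of a `TensorND` is no longer read from a public `raw_ptr` member but through a `raw_ptr()` accessor, with the storage held behind a rebindable `RefPtr` handle (the diff also shows `reset_ptr`, `get_ref_ptr` and `RefPtr`). A minimal standalone model of that shape; the `RefPtr`/`TensorND` below are simplified stand-ins for illustration, not the MegDNN definitions, and the shared-handle semantics is an assumption inferred from how the rest of the diff uses the handle.

    #include <cassert>
    #include <memory>

    // Simplified stand-in: a shared, rebindable pointer handle.
    class RefPtr {
        std::shared_ptr<void*> m_ptr;
    public:
        RefPtr() : m_ptr(std::make_shared<void*>(nullptr)) {}
        void* get_ptr() const { return *m_ptr; }
        void reset(void* p) { *m_ptr = p; }   // retargets every holder of this handle
    };

    // Simplified stand-in for a tensor: layout omitted, only the pointer handle kept.
    struct TensorND {
        RefPtr ref;
        void* raw_ptr() const { return ref.get_ptr(); }   // accessor replaces the old member
        void reset_ptr(void* p) { ref.reset(p); }
        const RefPtr& get_ref_ptr() const { return ref; }
    };

    int main() {
        int storage[4] = {0, 1, 2, 3};
        TensorND t;
        t.reset_ptr(storage);               // was: t.raw_ptr = storage;
        assert(t.raw_ptr() == storage);     // was: t.raw_ptr
        // A copy shares the same handle, so rebinding one is visible through both.
        TensorND view = t;
        view.reset_ptr(storage + 2);
        assert(t.raw_ptr() == storage + 2);
        return 0;
    }

Under that assumption, the accessor change is more than renaming: code that used to copy a bare `void*` can now observe later rebindings of the same tensor storage.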
| @@ -80,9 +80,9 @@ void BatchedMatrixMulForwardImpl::AlgoBruteForce::exec(const ExecArgs& args) con | |||||
| rep(n, N) { | rep(n, N) { | ||||
| TensorND A_, B_, C_; | TensorND A_, B_, C_; | ||||
| auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) { | auto tensor_n_from_batch = [n](const TensorND& in, TensorND& out) { | ||||
| out.raw_ptr = static_cast<void*>( | |||||
| static_cast<dt_byte*>(in.raw_ptr) + | |||||
| n * in.layout.stride[0] * in.layout.dtype.size()); | |||||
| out.reset_ptr(static_cast<void*>( | |||||
| static_cast<dt_byte*>(in.raw_ptr()) + | |||||
| n * in.layout.stride[0] * in.layout.dtype.size())); | |||||
| out.layout = in.layout.remove_axis(0); | out.layout = in.layout.remove_axis(0); | ||||
| }; | }; | ||||
| tensor_n_from_batch(args.tensor_a, A_); | tensor_n_from_batch(args.tensor_a, A_); | ||||
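Editor's note — the brute-force batched-matmul hunk only swaps the member write for `reset_ptr`; the address arithmetic is unchanged: batch `n` starts `n * stride[0] * dtype.size()` bytes past the base pointer, and the sliced layout drops axis 0. A standalone sketch of that offset computation (plain structs with illustrative names, not the MegDNN types):

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    struct Slice { void* ptr; std::size_t rows, cols; };

    // View of batch `n` inside a (batch, rows, cols) buffer.
    // stride0 is the element stride of axis 0 (== rows * cols when contiguous).
    Slice batch_view(void* base, std::size_t n, std::size_t stride0,
                     std::size_t dtype_size, std::size_t rows, std::size_t cols) {
        auto* bytes = static_cast<std::uint8_t*>(base);
        return {bytes + n * stride0 * dtype_size, rows, cols};
    }

    int main() {
        float buf[2 * 3 * 4] = {};                               // batch=2, rows=3, cols=4
        Slice s1 = batch_view(buf, 1, 3 * 4, sizeof(float), 3, 4);
        assert(s1.ptr == buf + 12);                              // second batch starts 12 floats in
        return 0;
    }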
| @@ -76,13 +76,13 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const { | |||||
| static_cast<void*>(workspace.raw_ptr + 2 * batch * sizeof(uintptr_t))); | static_cast<void*>(workspace.raw_ptr + 2 * batch * sizeof(uintptr_t))); | ||||
| arange<uintptr_t>( | arange<uintptr_t>( | ||||
| As, reinterpret_cast<uintptr_t>(args.tensor_a.raw_ptr), | |||||
| As, reinterpret_cast<uintptr_t>(args.tensor_a.raw_ptr()), | |||||
| args.layout_a.stride[0] * dtype.size(), batch, stream); | args.layout_a.stride[0] * dtype.size(), batch, stream); | ||||
| arange<uintptr_t>( | arange<uintptr_t>( | ||||
| Bs, reinterpret_cast<uintptr_t>(args.tensor_b.raw_ptr), | |||||
| Bs, reinterpret_cast<uintptr_t>(args.tensor_b.raw_ptr()), | |||||
| args.layout_b.stride[0] * dtype.size(), batch, stream); | args.layout_b.stride[0] * dtype.size(), batch, stream); | ||||
| arange<uintptr_t>( | arange<uintptr_t>( | ||||
| Cs, reinterpret_cast<uintptr_t>(args.tensor_c.raw_ptr), | |||||
| Cs, reinterpret_cast<uintptr_t>(args.tensor_c.raw_ptr()), | |||||
| args.layout_c.stride[0] * dtype.size(), batch, stream); | args.layout_c.stride[0] * dtype.size(), batch, stream); | ||||
| auto io32_c32 = [&]() { | auto io32_c32 = [&]() { | ||||
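Editor's note — the cuBLAS path above fills three device-side arrays of operand addresses (one entry per batch for A, B and C) as an arithmetic progression of `uintptr_t` values, base plus `i * stride[0] * dtype.size()` bytes. The `arange` in the diff runs on the CUDA stream; the host-side equivalent of the same arithmetic is sketched below with a hypothetical helper name.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Fill a vector with base, base+step, base+2*step, ... (byte-valued pointer progression).
    std::vector<std::uintptr_t> pointer_arange(const void* base, std::size_t step_bytes,
                                               std::size_t batch) {
        std::vector<std::uintptr_t> out(batch);
        auto start = reinterpret_cast<std::uintptr_t>(base);
        for (std::size_t i = 0; i < batch; ++i)
            out[i] = start + i * step_bytes;
        return out;
    }

    int main() {
        float a[4 * 8] = {};                                   // batch=4, 8 elements per matrix
        auto ptrs = pointer_arange(a, 8 * sizeof(float), 4);   // step = stride[0] * dtype.size()
        assert(ptrs[3] == reinterpret_cast<std::uintptr_t>(a) + 3 * 8 * sizeof(float));
        return 0;
    }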
| @@ -62,10 +62,10 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const | |||||
| "workspace bundle size should be 1(ws_algo)"); | "workspace bundle size should be 1(ws_algo)"); | ||||
| cublas_check(cublasLtMatmul( | cublas_check(cublasLtMatmul( | ||||
| cublasLt_handle, desc.matmul_desc, one_half, | cublasLt_handle, desc.matmul_desc, one_half, | ||||
| static_cast<const __half*>(args.tensor_b.raw_ptr), desc.layout_b, | |||||
| static_cast<const __half*>(args.tensor_a.raw_ptr), desc.layout_a, | |||||
| zero_half, static_cast<const __half*>(args.tensor_c.raw_ptr), | |||||
| desc.layout_c, static_cast<__half*>(args.tensor_c.raw_ptr), | |||||
| static_cast<const __half*>(args.tensor_b.raw_ptr()), desc.layout_b, | |||||
| static_cast<const __half*>(args.tensor_a.raw_ptr()), desc.layout_a, | |||||
| zero_half, static_cast<const __half*>(args.tensor_c.raw_ptr()), | |||||
| desc.layout_c, static_cast<__half*>(args.tensor_c.raw_ptr()), | |||||
| desc.layout_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0), stream)); | desc.layout_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0), stream)); | ||||
| }; | }; | ||||
| auto batched_sgemm = [&]() { | auto batched_sgemm = [&]() { | ||||
| @@ -77,7 +77,7 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const | |||||
| auto dev_a = (desc.dt_a == CUDA_R_16F) | auto dev_a = (desc.dt_a == CUDA_R_16F) | ||||
| ? static_cast<void*>(args.tensor_a.ptr<dt_float16>()) | ? static_cast<void*>(args.tensor_a.ptr<dt_float16>()) | ||||
| : static_cast<void*>(args.tensor_a.ptr<dt_float32>()); | : static_cast<void*>(args.tensor_a.ptr<dt_float32>()); | ||||
| auto dev_c = static_cast<void*>(args.tensor_c.raw_ptr); | |||||
| auto dev_c = static_cast<void*>(args.tensor_c.raw_ptr()); | |||||
| megdnn_assert( | megdnn_assert( | ||||
| ws_bundle.nr_workspace() == 1, | ws_bundle.nr_workspace() == 1, | ||||
| "workspace bundle size should be 1(ws_algo)"); | "workspace bundle size should be 1(ws_algo)"); | ||||
| @@ -104,14 +104,14 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const | |||||
| transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, &pm, | transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, &pm, | ||||
| sizeof(pm))); | sizeof(pm))); | ||||
| cublas_check(cublasLtMatrixTransform( | cublas_check(cublasLtMatrixTransform( | ||||
| cublasLt_handle, transform_desc, one, args.tensor_b.raw_ptr, | |||||
| cublasLt_handle, transform_desc, one, args.tensor_b.raw_ptr(), | |||||
| desc.layout_b, zero, nullptr, nullptr, ws_b, desc.layout_trans_b, | desc.layout_b, zero, nullptr, nullptr, ws_b, desc.layout_trans_b, | ||||
| stream)); | stream)); | ||||
| cublas_check(cublasLtMatrixTransformDescSetAttribute( | cublas_check(cublasLtMatrixTransformDescSetAttribute( | ||||
| transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &trans_a, | transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &trans_a, | ||||
| sizeof(trans_a))); | sizeof(trans_a))); | ||||
| cublas_check(cublasLtMatrixTransform( | cublas_check(cublasLtMatrixTransform( | ||||
| cublasLt_handle, transform_desc, one, args.tensor_a.raw_ptr, | |||||
| cublasLt_handle, transform_desc, one, args.tensor_a.raw_ptr(), | |||||
| desc.layout_a, zero, nullptr, nullptr, ws_a, desc.layout_trans_a, | desc.layout_a, zero, nullptr, nullptr, ws_a, desc.layout_trans_a, | ||||
| stream)); | stream)); | ||||
| cublas_check(cublasLtMatmul( | cublas_check(cublasLtMatmul( | ||||
| @@ -124,7 +124,7 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const | |||||
| sizeof(trans_c))); | sizeof(trans_c))); | ||||
| cublas_check(cublasLtMatrixTransform( | cublas_check(cublasLtMatrixTransform( | ||||
| cublasLt_handle, transform_desc, one, ws_c, desc.layout_trans_c, zero, | cublasLt_handle, transform_desc, one, ws_c, desc.layout_trans_c, zero, | ||||
| nullptr, nullptr, args.tensor_c.raw_ptr, desc.layout_c, stream)); | |||||
| nullptr, nullptr, args.tensor_c.raw_ptr(), desc.layout_c, stream)); | |||||
| cublas_check(cublasLtMatrixTransformDescDestroy(transform_desc)); | cublas_check(cublasLtMatrixTransformDescDestroy(transform_desc)); | ||||
| }; | }; | ||||
| @@ -8,7 +8,7 @@ | |||||
| * software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #include "src/common/reduce_helper.h" | |||||
| #include "src/common/reduce_helper_device.h" | |||||
| #include "megdnn/dtype.h" | #include "megdnn/dtype.h" | ||||
| #include "src/cuda/reduce_helper.cuh" | #include "src/cuda/reduce_helper.cuh" | ||||
| @@ -18,7 +18,9 @@ namespace cuda { | |||||
| #define COMMA , | #define COMMA , | ||||
| INST_REDUCE(reduce::CheckNonFiniteOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, false); | |||||
| INST_REDUCE( | |||||
| device_reduce::CheckNonFiniteOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, | |||||
| false); | |||||
| #undef COMMA | #undef COMMA | ||||
| } // namespace cuda | } // namespace cuda | ||||
| @@ -15,12 +15,12 @@ | |||||
| #include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| #include "src/common/reduce_helper.h" | |||||
| #include "src/common/reduce_helper_device.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| using reduce::CheckNonFiniteOp; | |||||
| using device_reduce::CheckNonFiniteOp; | |||||
| size_t CheckNonFiniteImpl::get_workspace_in_bytes( | size_t CheckNonFiniteImpl::get_workspace_in_bytes( | ||||
| const TensorLayout& src, const TensorLayout& dst) { | const TensorLayout& src, const TensorLayout& dst) { | ||||
| @@ -45,7 +45,7 @@ ChecksumForward::Result ChecksumForwardImpl::exec( | |||||
| check_exec(data.layout, workspace.size); | check_exec(data.layout, workspace.size); | ||||
| auto stream = cuda_stream(handle()); | auto stream = cuda_stream(handle()); | ||||
| auto ptr = static_cast<uint8_t*>(data.raw_ptr); | |||||
| auto ptr = static_cast<uint8_t*>(data.raw_ptr()); | |||||
| size_t size_all = data.layout.shape[0], size_ints = size_all / sizeof(uint32_t); | size_t size_all = data.layout.shape[0], size_ints = size_all / sizeof(uint32_t); | ||||
| auto last_val_size = std::min<size_t>(size_all, 4); | auto last_val_size = std::min<size_t>(size_all, 4); | ||||
| cuda_check(cudaMemcpyAsync( | cuda_check(cudaMemcpyAsync( | ||||
| @@ -54,7 +54,7 @@ ChecksumForward::Result ChecksumForwardImpl::exec( | |||||
| if (size_ints) { | if (size_ints) { | ||||
| checksum::calc( | checksum::calc( | ||||
| static_cast<uint32_t*>(wbundle.get(1)), | static_cast<uint32_t*>(wbundle.get(1)), | ||||
| static_cast<uint32_t*>(data.raw_ptr), | |||||
| static_cast<uint32_t*>(data.raw_ptr()), | |||||
| static_cast<uint32_t*>(wbundle.get(0)), size_ints, stream); | static_cast<uint32_t*>(wbundle.get(0)), size_ints, stream); | ||||
| cuda_check(cudaMemcpyAsync( | cuda_check(cudaMemcpyAsync( | ||||
| &result.checksum, wbundle.get(1), sizeof(result.checksum), | &result.checksum, wbundle.get(1), sizeof(result.checksum), | ||||
| @@ -135,9 +135,9 @@ size_t ConvBiasForwardImpl::AlgoBatchedMatmul::get_workspace_in_bytes( | |||||
| void ConvBiasForwardImpl::AlgoBatchedMatmul::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoBatchedMatmul::exec(const ExecArgs& args) const { | ||||
| auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
| auto conv_dst_tensor = *args.dst_tensor; | |||||
| TensorND conv_dst_tensor = *args.dst_tensor; | |||||
| if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
| conv_dst_tensor.raw_ptr = bundle.get(1); | |||||
| conv_dst_tensor = TensorND{bundle.get(1), args.dst_tensor->layout}; | |||||
| conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
| args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
| args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
| @@ -150,9 +150,9 @@ void ConvBiasForwardImpl::AlgoBatchedMatmul::exec(const ExecArgs& args) const { | |||||
| { | { | ||||
| auto config = prepare_sub_opr(args); | auto config = prepare_sub_opr(args); | ||||
| TensorND A{args.filter_tensor->raw_ptr, config.first[0]}, | |||||
| B{args.src_tensor->raw_ptr, config.first[1]}, | |||||
| C{args.dst_tensor->raw_ptr, config.first[2]}; | |||||
| TensorND A{args.filter_tensor->raw_ptr(), config.first[0]}, | |||||
| B{args.src_tensor->raw_ptr(), config.first[1]}, | |||||
| C{args.dst_tensor->raw_ptr(), config.first[2]}; | |||||
| config.second->exec(A, B, C, bundle.get_workspace(0)); | config.second->exec(A, B, C, bundle.get_workspace(0)); | ||||
| } | } | ||||
| handle_bias_and_nonlinear( | handle_bias_and_nonlinear( | ||||
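Editor's note — several ConvBias hunks in this diff share one staging pattern: when the destination dtype differs from the bias dtype, the convolution result is first written to a workspace buffer (a `TensorND` built from `bundle.get(i)` and the dst layout, with the dtype cleared so it can be re-deduced), and only `handle_bias_and_nonlinear` writes into the real dst. A schematic sketch of that buffer selection, with a hypothetical enum standing in for MegDNN dtypes:

    #include <cassert>

    enum class DType { Unknown, Float32, QuantizedS8 };

    struct Buffer { void* ptr; DType dtype; };

    // Choose where the raw convolution output goes: straight to dst when the dtypes
    // already match, otherwise to a workspace buffer whose dtype is re-deduced later.
    Buffer pick_conv_output(Buffer dst, DType bias_dtype, void* workspace) {
        if (dst.dtype == bias_dtype)
            return dst;                            // no staging needed
        return {workspace, DType::Unknown};        // stage in workspace, dtype deduced afterwards
    }

    int main() {
        char ws[64];
        Buffer dst{nullptr, DType::QuantizedS8};
        Buffer out = pick_conv_output(dst, DType::Float32, ws);
        assert(out.ptr == ws && out.dtype == DType::Unknown);
        return 0;
    }

The diff's change here is purely constructional: instead of patching `conv_dst_tensor.raw_ptr`, the staged tensor is rebuilt as `TensorND{bundle.get(i), layout}` so the new pointer-handle invariant is kept.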
| @@ -52,9 +52,9 @@ size_t ConvBiasForwardImpl::AlgoChanwise::get_workspace_in_bytes( | |||||
| void ConvBiasForwardImpl::AlgoChanwise::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoChanwise::exec(const ExecArgs& args) const { | ||||
| WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; | WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; | ||||
| auto conv_dst_tensor = *args.dst_tensor; | |||||
| TensorND conv_dst_tensor = *args.dst_tensor; | |||||
| if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
| conv_dst_tensor.raw_ptr = bundle.get(0); | |||||
| conv_dst_tensor = TensorND{bundle.get(0), args.dst_tensor->layout}; | |||||
| conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
| args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
| args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
| @@ -74,9 +74,9 @@ void ConvBiasForwardImpl::AlgoChanwise::exec(const ExecArgs& args) const { | |||||
| #if CUDA_VERSION >= 9000 | #if CUDA_VERSION >= 9000 | ||||
| if (is_compute_capability_required(5, 3)) { | if (is_compute_capability_required(5, 3)) { | ||||
| chanwise::run_fwd( | chanwise::run_fwd( | ||||
| static_cast<half*>(conv_dst_tensor.raw_ptr), | |||||
| static_cast<half*>(args.src_tensor->raw_ptr), | |||||
| static_cast<half*>(args.filter_tensor->raw_ptr), kparam, | |||||
| static_cast<half*>(conv_dst_tensor.raw_ptr()), | |||||
| static_cast<half*>(args.src_tensor->raw_ptr()), | |||||
| static_cast<half*>(args.filter_tensor->raw_ptr()), kparam, | |||||
| stream); | stream); | ||||
| } else { | } else { | ||||
| chanwise::run_fwd( | chanwise::run_fwd( | ||||
| @@ -50,9 +50,9 @@ size_t ConvBiasForwardImpl::AlgoChanwise8x8x32::get_workspace_in_bytes( | |||||
| void ConvBiasForwardImpl::AlgoChanwise8x8x32::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoChanwise8x8x32::exec(const ExecArgs& args) const { | ||||
| WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; | WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; | ||||
| auto conv_dst_tensor = *args.dst_tensor; | |||||
| TensorND conv_dst_tensor = *args.dst_tensor; | |||||
| if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
| conv_dst_tensor.raw_ptr = bundle.get(0); | |||||
| conv_dst_tensor = TensorND{bundle.get(0), args.dst_tensor->layout}; | |||||
| conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
| args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
| args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
| @@ -65,9 +65,9 @@ size_t ConvBiasForwardImpl::AlgoChanwiseSmall::get_workspace_in_bytes( | |||||
| void ConvBiasForwardImpl::AlgoChanwiseSmall::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoChanwiseSmall::exec(const ExecArgs& args) const { | ||||
| WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; | WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; | ||||
| auto conv_dst_tensor = *args.dst_tensor; | |||||
| TensorND conv_dst_tensor = *args.dst_tensor; | |||||
| if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
| conv_dst_tensor.raw_ptr = bundle.get(0); | |||||
| conv_dst_tensor = TensorND{bundle.get(0), conv_dst_tensor.layout}; | |||||
| conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
| args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
| args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
| @@ -85,9 +85,9 @@ void ConvBiasForwardImpl::AlgoChanwiseSmall::exec(const ExecArgs& args) const { | |||||
| #if CUDA_VERSION >= 9000 | #if CUDA_VERSION >= 9000 | ||||
| case DTypeEnum::Float16: | case DTypeEnum::Float16: | ||||
| chanwise::run_fwd_small( | chanwise::run_fwd_small( | ||||
| static_cast<half*>(conv_dst_tensor.raw_ptr), | |||||
| static_cast<half*>(args.src_tensor->raw_ptr), | |||||
| static_cast<half*>(args.filter_tensor->raw_ptr), kparam, | |||||
| static_cast<half*>(conv_dst_tensor.raw_ptr()), | |||||
| static_cast<half*>(args.src_tensor->raw_ptr()), | |||||
| static_cast<half*>(args.filter_tensor->raw_ptr()), kparam, | |||||
| stream); | stream); | ||||
| break; | break; | ||||
| #endif | #endif | ||||
| @@ -100,9 +100,9 @@ size_t ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_in_bytes( | |||||
| void ConvBiasForwardImpl::AlgoCUDNNConv::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoCUDNNConv::exec(const ExecArgs& args) const { | ||||
| auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
| auto conv_dst_tensor = *args.dst_tensor; | |||||
| TensorND conv_dst_tensor = *args.dst_tensor; | |||||
| if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
| conv_dst_tensor.raw_ptr = bundle.get(1); | |||||
| conv_dst_tensor = TensorND{bundle.get(1), args.dst_tensor->layout}; | |||||
| conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
| args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
| args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
| @@ -120,10 +120,10 @@ void ConvBiasForwardImpl::AlgoCUDNNConv::exec(const ExecArgs& args) const { | |||||
| float alpha = 1.0f, beta = 0.0f; | float alpha = 1.0f, beta = 0.0f; | ||||
| auto status = cudnnConvolutionForward( | auto status = cudnnConvolutionForward( | ||||
| conv_args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | conv_args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | ||||
| conv_args.src_tensor->raw_ptr, D.filter_desc.desc, | |||||
| conv_args.filter_tensor->raw_ptr, D.conv_desc.conv_desc, m_cudnn_enum, | |||||
| conv_args.src_tensor->raw_ptr(), D.filter_desc.desc, | |||||
| conv_args.filter_tensor->raw_ptr(), D.conv_desc.conv_desc, m_cudnn_enum, | |||||
| conv_workspace.raw_ptr, conv_workspace.size, &beta, D.dst_desc.desc, | conv_workspace.raw_ptr, conv_workspace.size, &beta, D.dst_desc.desc, | ||||
| conv_args.dst_tensor->raw_ptr); | |||||
| conv_args.dst_tensor->raw_ptr()); | |||||
| megdnn_assert( | megdnn_assert( | ||||
| status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s", | status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s", | ||||
| cudnnGetErrorString(status), conv_args.to_string().c_str()); | cudnnGetErrorString(status), conv_args.to_string().c_str()); | ||||
| @@ -231,7 +231,7 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec( | |||||
| auto workspace_ptr = args.workspace.raw_ptr; | auto workspace_ptr = args.workspace.raw_ptr; | ||||
| auto workspace_size = args.workspace.size; | auto workspace_size = args.workspace.size; | ||||
| auto bias_ptr = args.bias_tensor->raw_ptr; | |||||
| auto bias_ptr = args.bias_tensor->raw_ptr(); | |||||
| if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() && | if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() && | ||||
| args.src_layout->dtype.category() != DTypeCategory::FLOAT) { | args.src_layout->dtype.category() != DTypeCategory::FLOAT) { | ||||
| auto cvt = args.handle->create_operator<TypeCvt>(); | auto cvt = args.handle->create_operator<TypeCvt>(); | ||||
| @@ -242,7 +242,7 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec( | |||||
| auto bias_size_in_bytes = float_bias_layout.span().dist_byte(); | auto bias_size_in_bytes = float_bias_layout.span().dist_byte(); | ||||
| megdnn_assert(args.workspace.size >= bias_size_in_bytes); | megdnn_assert(args.workspace.size >= bias_size_in_bytes); | ||||
| cvt->exec( | cvt->exec( | ||||
| {args.bias_tensor->raw_ptr, converted_bias_layout}, | |||||
| {args.bias_tensor->raw_ptr(), converted_bias_layout}, | |||||
| TensorND{workspace_ptr, float_bias_layout}); | TensorND{workspace_ptr, float_bias_layout}); | ||||
| bias_ptr = workspace_ptr; | bias_ptr = workspace_ptr; | ||||
| @@ -254,19 +254,19 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec( | |||||
| if (args.z_layout->ndim == 0) { | if (args.z_layout->ndim == 0) { | ||||
| status = cudnnConvolutionBiasActivationForward( | status = cudnnConvolutionBiasActivationForward( | ||||
| args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | ||||
| args.src_tensor->raw_ptr, D.filter_desc.desc, | |||||
| args.filter_tensor->raw_ptr, D.conv_desc.conv_desc, m_cudnn_enum, | |||||
| args.src_tensor->raw_ptr(), D.filter_desc.desc, | |||||
| args.filter_tensor->raw_ptr(), D.conv_desc.conv_desc, m_cudnn_enum, | |||||
| workspace_ptr, workspace_size, &beta, D.dst_desc.desc, | workspace_ptr, workspace_size, &beta, D.dst_desc.desc, | ||||
| args.dst_tensor->raw_ptr, D.bias_desc.desc, bias_ptr, | |||||
| D.conv_desc.act_desc, D.dst_desc.desc, args.dst_tensor->raw_ptr); | |||||
| args.dst_tensor->raw_ptr(), D.bias_desc.desc, bias_ptr, | |||||
| D.conv_desc.act_desc, D.dst_desc.desc, args.dst_tensor->raw_ptr()); | |||||
| } else { | } else { | ||||
| status = cudnnConvolutionBiasActivationForward( | status = cudnnConvolutionBiasActivationForward( | ||||
| args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | ||||
| args.src_tensor->raw_ptr, D.filter_desc.desc, | |||||
| args.filter_tensor->raw_ptr, D.conv_desc.conv_desc, m_cudnn_enum, | |||||
| args.src_tensor->raw_ptr(), D.filter_desc.desc, | |||||
| args.filter_tensor->raw_ptr(), D.conv_desc.conv_desc, m_cudnn_enum, | |||||
| workspace_ptr, workspace_size, &beta, D.z_desc.desc, | workspace_ptr, workspace_size, &beta, D.z_desc.desc, | ||||
| args.z_tensor->raw_ptr, D.bias_desc.desc, bias_ptr, | |||||
| D.conv_desc.act_desc, D.dst_desc.desc, args.dst_tensor->raw_ptr); | |||||
| args.z_tensor->raw_ptr(), D.bias_desc.desc, bias_ptr, | |||||
| D.conv_desc.act_desc, D.dst_desc.desc, args.dst_tensor->raw_ptr()); | |||||
| } | } | ||||
| megdnn_assert( | megdnn_assert( | ||||
| @@ -142,9 +142,10 @@ size_t ConvBiasForwardImpl::AlgoGroupConvGeneral::get_workspace_in_bytes( | |||||
| void ConvBiasForwardImpl::AlgoGroupConvGeneral::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoGroupConvGeneral::exec(const ExecArgs& args) const { | ||||
| auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
| auto conv_dst_tensor = *args.dst_tensor; | |||||
| TensorND conv_dst_tensor = *args.dst_tensor; | |||||
| if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
| conv_dst_tensor.raw_ptr = bundle.get(bundle.nr_workspace() - 1); | |||||
| conv_dst_tensor = TensorND{ | |||||
| bundle.get(bundle.nr_workspace() - 1), args.dst_tensor->layout}; | |||||
| conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
| args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
| args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
| @@ -156,11 +157,11 @@ void ConvBiasForwardImpl::AlgoGroupConvGeneral::exec(const ExecArgs& args) const | |||||
| sub_args.dst_layout = &conv_dst_tensor.layout; | sub_args.dst_layout = &conv_dst_tensor.layout; | ||||
| auto config = prepare_sub_opr(sub_args); | auto config = prepare_sub_opr(sub_args); | ||||
| TensorND tsrc{args.src_tensor->raw_ptr, config.first[0]}; | |||||
| TensorND tfilter{args.filter_tensor->raw_ptr, config.first[1]}; | |||||
| TensorND tbias{args.bias_tensor->raw_ptr, config.first[2]}; | |||||
| TensorND tz{args.z_tensor->raw_ptr, config.first[3]}; | |||||
| TensorND tdst{conv_dst_tensor.raw_ptr, config.first[4]}; | |||||
| TensorND tsrc{args.src_tensor->raw_ptr(), config.first[0]}; | |||||
| TensorND tfilter{args.filter_tensor->raw_ptr(), config.first[1]}; | |||||
| TensorND tbias{args.bias_tensor->raw_ptr(), config.first[2]}; | |||||
| TensorND tz{args.z_tensor->raw_ptr(), config.first[3]}; | |||||
| TensorND tdst{conv_dst_tensor.raw_ptr(), config.first[4]}; | |||||
| size_t c_pos; | size_t c_pos; | ||||
| if (args.filter_meta.format == Param::Format::NCHW || | if (args.filter_meta.format == Param::Format::NCHW || | ||||
| @@ -187,9 +188,9 @@ void ConvBiasForwardImpl::AlgoGroupConvGeneral::exec(const ExecArgs& args) const | |||||
| for (uint32_t g = 0; g < grp; ++g) { | for (uint32_t g = 0; g < grp; ++g) { | ||||
| config.second->exec( | config.second->exec( | ||||
| tsrc, tfilter, tbias, tz, tdst, nullptr, bundle.get_workspace(0)); | tsrc, tfilter, tbias, tz, tdst, nullptr, bundle.get_workspace(0)); | ||||
| incr_voidp(tsrc.raw_ptr, strd_src); | |||||
| incr_voidp(tdst.raw_ptr, strd_dst); | |||||
| incr_voidp(tfilter.raw_ptr, strd_flt); | |||||
| incr_refp(tsrc.get_ref_ptr(), strd_src); | |||||
| incr_refp(tdst.get_ref_ptr(), strd_dst); | |||||
| incr_refp(tfilter.get_ref_ptr(), strd_flt); | |||||
| } | } | ||||
| } | } | ||||
| handle_bias_and_nonlinear( | handle_bias_and_nonlinear( | ||||
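Editor's note — in the grouped-convolution loop above, the per-group sub-tensors are advanced by a fixed byte stride after each group; the diff switches that advance from mutating a raw `void*` (`incr_voidp`) to bumping the shared `RefPtr` handle (`incr_refp(tsrc.get_ref_ptr(), strd_src)`), so the rebinding stays visible to anything else holding the same handle. A standalone sketch of stepping a pointer handle across groups, reusing the simplified `RefPtr` idea from the earlier sketch (names are illustrative, not the MegDNN API):

    #include <cassert>
    #include <cstdint>
    #include <memory>

    class RefPtr {
        std::shared_ptr<void*> m_ptr{std::make_shared<void*>(nullptr)};
    public:
        void reset(void* p) { *m_ptr = p; }
        void* get() const { return *m_ptr; }
    };

    // Advance the handle by `stride_bytes` — the per-group jump through the weight buffer.
    void incr_refp(RefPtr& ref, std::ptrdiff_t stride_bytes) {
        ref.reset(static_cast<std::uint8_t*>(ref.get()) + stride_bytes);
    }

    int main() {
        float filters[3 * 8] = {};      // 3 groups, 8 weights per group
        RefPtr f;
        f.reset(filters);
        for (int g = 0; g < 3; ++g) {
            // ... run the single-group convolution with f.get() here ...
            incr_refp(f, 8 * sizeof(float));
        }
        assert(f.get() == reinterpret_cast<std::uint8_t*>(filters) + 3 * 8 * sizeof(float));
        return 0;
    }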
| @@ -189,19 +189,19 @@ SmallVector<size_t> matmul_get_workspace_bundle(const BiasForwardSizeArgs& args) | |||||
| } | } | ||||
| void flip_filter( | void flip_filter( | ||||
| const BiasForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr) { | |||||
| const BiasForwardSizeArgs& args, const Workspace& workspace, RefPtr& ref_ptr) { | |||||
| auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
| megdnn_assert(fm.group == 1 && fm.spatial_ndim == 2); | megdnn_assert(fm.group == 1 && fm.spatial_ndim == 2); | ||||
| auto OC = fm.ocpg, IC = fm.icpg, FH = fm.spatial[0], FW = fm.spatial[1]; | auto OC = fm.ocpg, IC = fm.icpg, FH = fm.spatial[0], FW = fm.spatial[1]; | ||||
| auto dtype = fm.dtype; | auto dtype = fm.dtype; | ||||
| megdnn_assert(workspace.size >= dtype.size() * OC * IC * FH * FW); | megdnn_assert(workspace.size >= dtype.size() * OC * IC * FH * FW); | ||||
| TensorND src{raw_ptr, {{OC, IC, FH, FW}, dtype}}, | |||||
| TensorND src{{{OC, IC, FH, FW}, dtype}, ref_ptr}, | |||||
| dst{workspace.raw_ptr + (FH * FW - 1) * dtype.size(), src.layout}; | dst{workspace.raw_ptr + (FH * FW - 1) * dtype.size(), src.layout}; | ||||
| dst.layout.stride[2] = -dst.layout.stride[2]; | dst.layout.stride[2] = -dst.layout.stride[2]; | ||||
| dst.layout.stride[3] = -dst.layout.stride[3]; | dst.layout.stride[3] = -dst.layout.stride[3]; | ||||
| args.handle->relayout_opr()->exec(src, dst); | args.handle->relayout_opr()->exec(src, dst); | ||||
| raw_ptr = workspace.raw_ptr; | |||||
| ref_ptr.reset(workspace.raw_ptr); | |||||
| } | } | ||||
| } // namespace conv_bias | } // namespace conv_bias | ||||
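Editor's note — `flip_filter` realizes the spatial 180° flip as a relayout: the destination base is placed at the last spatial element of the workspace (`(FH*FW - 1) * dtype.size()` bytes in) and the two spatial strides are negated, so a plain copy walks the output backwards; the signature change in this hunk only affects how the resulting pointer is handed back (through a `RefPtr` instead of a `void*&`). A CPU-only sketch of the negative-stride flip itself, with a hypothetical helper in place of the MegDNN relayout operator:

    #include <cassert>
    #include <cstddef>

    // Copy an FH x FW filter into dst so that dst reads as the 180°-flipped filter.
    // dst_base points at the *last* element slot; spatial strides are negative (in elements).
    void flip_copy(const float* src, float* dst_base, std::ptrdiff_t FH, std::ptrdiff_t FW) {
        const std::ptrdiff_t stride_h = -FW, stride_w = -1;   // negated spatial strides
        for (std::ptrdiff_t h = 0; h < FH; ++h)
            for (std::ptrdiff_t w = 0; w < FW; ++w)
                dst_base[h * stride_h + w * stride_w] = src[h * FW + w];
    }

    int main() {
        const float src[2 * 3] = {1, 2, 3, 4, 5, 6};
        float dst[2 * 3] = {};
        flip_copy(src, dst + (2 * 3 - 1), 2, 3);   // base at the last slot, like (FH*FW-1)*dtype.size()
        assert(dst[0] == 6 && dst[5] == 1);        // flipped filter is the reversed sequence
        return 0;
    }

The follow-on assignment in the diff (`ref_ptr.reset(workspace.raw_ptr)`) then retargets the caller's filter tensor at the flipped copy living in workspace memory.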
| @@ -58,7 +58,7 @@ SmallVector<size_t> matmul_get_workspace_bundle(const BiasForwardSizeArgs& args) | |||||
| * change \p raw_ptr to workspace. | * change \p raw_ptr to workspace. | ||||
| */ | */ | ||||
| void flip_filter( | void flip_filter( | ||||
| const BiasForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr); | |||||
| const BiasForwardSizeArgs& args, const Workspace& workspace, RefPtr& ref_ptr); | |||||
| struct CUDNNForwardDescs { | struct CUDNNForwardDescs { | ||||
| TensorDesc src_desc, dst_desc, bias_desc, z_desc; | TensorDesc src_desc, dst_desc, bias_desc, z_desc; | ||||
| @@ -39,7 +39,7 @@ SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGem | |||||
| void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::exec_preprocess( | void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::exec_preprocess( | ||||
| const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
| megdnn_assert(args.preprocessed_filter->tensors.size() == 1); | megdnn_assert(args.preprocessed_filter->tensors.size() == 1); | ||||
| void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
| void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
| reorder_filter(args, filter_ptr); | reorder_filter(args, filter_ptr); | ||||
| } | } | ||||
| @@ -48,12 +48,12 @@ std::tuple<void*, void*> ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm | |||||
| void* filter_ptr = nullptr; | void* filter_ptr = nullptr; | ||||
| if (args.preprocessed_filter) { | if (args.preprocessed_filter) { | ||||
| megdnn_assert(args.preprocessed_filter->tensors.size() == 1); | megdnn_assert(args.preprocessed_filter->tensors.size() == 1); | ||||
| filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
| filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
| } else { | } else { | ||||
| filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | ||||
| reorder_filter(args, filter_ptr); | reorder_filter(args, filter_ptr); | ||||
| } | } | ||||
| void* bias_ptr = args.bias_tensor->raw_ptr; | |||||
| void* bias_ptr = args.bias_tensor->raw_ptr(); | |||||
| return {filter_ptr, bias_ptr}; | return {filter_ptr, bias_ptr}; | ||||
| } | } | ||||
| @@ -39,7 +39,7 @@ SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm: | |||||
| void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::exec_preprocess( | void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::exec_preprocess( | ||||
| const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
| megdnn_assert(args.preprocessed_filter->tensors.size() == 1); | megdnn_assert(args.preprocessed_filter->tensors.size() == 1); | ||||
| void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
| void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
| reorder_filter(args, m_algo_param.access_size, filter_ptr); | reorder_filter(args, m_algo_param.access_size, filter_ptr); | ||||
| } | } | ||||
| @@ -48,12 +48,12 @@ std::tuple<void*, void*> ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm:: | |||||
| void* filter_ptr = nullptr; | void* filter_ptr = nullptr; | ||||
| if (args.preprocessed_filter) { | if (args.preprocessed_filter) { | ||||
| megdnn_assert(args.preprocessed_filter->tensors.size() == 1); | megdnn_assert(args.preprocessed_filter->tensors.size() == 1); | ||||
| filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
| filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
| } else { | } else { | ||||
| filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | ||||
| reorder_filter(args, m_algo_param.access_size, filter_ptr); | reorder_filter(args, m_algo_param.access_size, filter_ptr); | ||||
| } | } | ||||
| void* bias_ptr = args.bias_tensor->raw_ptr; | |||||
| void* bias_ptr = args.bias_tensor->raw_ptr(); | |||||
| return {filter_ptr, bias_ptr}; | return {filter_ptr, bias_ptr}; | ||||
| } | } | ||||
| @@ -103,7 +103,7 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec( | |||||
| std::tie(filter_ptr, bias_ptr) = prepare_filter_bias(args); | std::tie(filter_ptr, bias_ptr) = prepare_filter_bias(args); | ||||
| if (args.z_layout->ndim > 0) | if (args.z_layout->ndim > 0) | ||||
| z_ptr = args.z_tensor->raw_ptr; | |||||
| z_ptr = args.z_tensor->raw_ptr(); | |||||
| // \note these constants of cutlass epilogue will be passed to method | // \note these constants of cutlass epilogue will be passed to method | ||||
| // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | ||||
| @@ -131,8 +131,8 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec( | |||||
| use_conv_filter_unity_opt, without_shared_load); | use_conv_filter_unity_opt, without_shared_load); | ||||
| execute_cutlass_conv_op( | execute_cutlass_conv_op( | ||||
| op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, z_ptr, | |||||
| args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, | |||||
| op, args.src_tensor->raw_ptr(), filter_ptr, bias_ptr, z_ptr, | |||||
| args.dst_tensor->raw_ptr(), nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, | |||||
| pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, | pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, | ||||
| &dst_scale, stream, &src_zero); | &dst_scale, stream, &src_zero); | ||||
| @@ -159,7 +159,7 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::reorder_filter( | |||||
| // filter: KCRS64 => CRSK64 and reorder oc | // filter: KCRS64 => CRSK64 and reorder oc | ||||
| cutlass_wrapper::reorder_ncxhwx_imma_filter<4, 64>( | cutlass_wrapper::reorder_ncxhwx_imma_filter<4, 64>( | ||||
| reinterpret_cast<int8_t*>(reordered_filter), | reinterpret_cast<int8_t*>(reordered_filter), | ||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh, fw, | |||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr()), co, ci, fh, fw, | |||||
| true, stream); | true, stream); | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -115,7 +115,7 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec( | |||||
| std::tie(filter_ptr, bias_ptr) = prepare_filter_bias(args); | std::tie(filter_ptr, bias_ptr) = prepare_filter_bias(args); | ||||
| if (args.z_layout->ndim > 0) | if (args.z_layout->ndim > 0) | ||||
| z_ptr = args.z_tensor->raw_ptr; | |||||
| z_ptr = args.z_tensor->raw_ptr(); | |||||
| // \note these constants of cutlass epilogue will be passed to method | // \note these constants of cutlass epilogue will be passed to method | ||||
| // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | ||||
| @@ -151,8 +151,8 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec( | |||||
| use_conv_filter_unity_opt, without_shared_load); | use_conv_filter_unity_opt, without_shared_load); | ||||
| execute_cutlass_conv_op( | execute_cutlass_conv_op( | ||||
| op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, z_ptr, | |||||
| args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, | |||||
| op, args.src_tensor->raw_ptr(), filter_ptr, bias_ptr, z_ptr, | |||||
| args.dst_tensor->raw_ptr(), nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, | |||||
| pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, | pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, | ||||
| &dst_scale, stream, &src_zero); | &dst_scale, stream, &src_zero); | ||||
| @@ -188,7 +188,7 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter( | |||||
| cutlass_wrapper::reorder_nhwc_imma_filter<4>( | cutlass_wrapper::reorder_nhwc_imma_filter<4>( | ||||
| reinterpret_cast<int8_t*>(reordered_filter), | reinterpret_cast<int8_t*>(reordered_filter), | ||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh, fw, | |||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr()), co, ci, fh, fw, | |||||
| trans_oc, alignbits, oc_iterleaved, stream); | trans_oc, alignbits, oc_iterleaved, stream); | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -158,18 +158,15 @@ void ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmReorderFilter::exec( | |||||
| UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); | UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); | ||||
| // reorder filter | // reorder filter | ||||
| { | { | ||||
| TensorLayout in = *(args.filter_layout); | |||||
| TensorLayout out = {{ci / 16, 4, fh, fw, co, 4}, in.dtype}; | |||||
| TensorLayout out = { | |||||
| {ci / 16, 4, fh, fw, co, 4}, args.filter_tensor->layout.dtype}; | |||||
| out.stride[0] = 16 * co * fh * fw; | out.stride[0] = 16 * co * fh * fw; | ||||
| out.stride[1] = 4; | out.stride[1] = 4; | ||||
| out.stride[2] = fw * co * 16; | out.stride[2] = fw * co * 16; | ||||
| out.stride[3] = co * 16; | out.stride[3] = co * 16; | ||||
| out.stride[4] = 16; | out.stride[4] = 16; | ||||
| out.stride[5] = 1; | out.stride[5] = 1; | ||||
| TensorND ts_in, ts_out; | |||||
| ts_in.layout = in, ts_out.layout = out; | |||||
| ts_in.raw_ptr = args.filter_tensor->raw_ptr, | |||||
| ts_out.raw_ptr = args.workspace.raw_ptr; | |||||
| TensorND ts_in = *args.filter_tensor, ts_out{args.workspace.raw_ptr, out}; | |||||
| args.opr->handle()->create_operator<RelayoutForward>()->exec(ts_in, ts_out); | args.opr->handle()->create_operator<RelayoutForward>()->exec(ts_in, ts_out); | ||||
| } | } | ||||
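Editor's note — the two reorder-filter hunks build the interleaved target layout by hand (shape `{ci/16, 4, fh, fw, co, 4}` with explicit strides) and then let a single `RelayoutForward` perform the permutation; this diff only tightens how the source/destination `TensorND`s are constructed. The essential trick is that a destination layout whose strides do not match its contiguous order directly describes the permuted memory order. A small 2-D illustration of the same idea (a transpose expressed purely through destination strides; the helper is hypothetical, not the MegDNN relayout):

    #include <cassert>
    #include <cstddef>

    // Generic strided copy: walk the logical (d0, d1) index space and write through
    // per-tensor strides, so the destination layout alone decides the memory order.
    void relayout_2d(const float* src, const std::ptrdiff_t src_stride[2],
                     float* dst, const std::ptrdiff_t dst_stride[2],
                     std::ptrdiff_t d0, std::ptrdiff_t d1) {
        for (std::ptrdiff_t i = 0; i < d0; ++i)
            for (std::ptrdiff_t j = 0; j < d1; ++j)
                dst[i * dst_stride[0] + j * dst_stride[1]] =
                        src[i * src_stride[0] + j * src_stride[1]];
    }

    int main() {
        const float src[2 * 3] = {1, 2, 3, 4, 5, 6};   // logical shape (2, 3), row-major
        float dst[3 * 2] = {};
        const std::ptrdiff_t src_stride[2] = {3, 1};   // contiguous source
        const std::ptrdiff_t dst_stride[2] = {1, 2};   // transposed destination strides
        relayout_2d(src, src_stride, dst, dst_stride, 2, 3);
        // dst now holds the transpose laid out contiguously as (3, 2).
        assert(dst[0] == 1 && dst[1] == 4 && dst[2] == 2 && dst[5] == 6);
        return 0;
    }

The same mechanism scales to the 6-D filter case above: once the strides encode "CRSK-interleaved", one relayout call replaces a hand-written reorder kernel.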
| @@ -160,18 +160,15 @@ void ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth::exec( | |||||
| UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); | UNPACK_CONV_BIAS_CHWN4_PARAM(*(args.src_layout), fm, *(args.dst_layout), param); | ||||
| // reorder filter | // reorder filter | ||||
| { | { | ||||
| TensorLayout in = *(args.filter_layout); | |||||
| TensorLayout out = {{ci / 16, 4, fh, fw, co, 4}, in.dtype}; | |||||
| TensorLayout out = { | |||||
| {ci / 16, 4, fh, fw, co, 4}, args.filter_tensor->layout.dtype}; | |||||
| out.stride[0] = 16 * co * fh * fw; | out.stride[0] = 16 * co * fh * fw; | ||||
| out.stride[1] = 4; | out.stride[1] = 4; | ||||
| out.stride[2] = fw * co * 16; | out.stride[2] = fw * co * 16; | ||||
| out.stride[3] = co * 16; | out.stride[3] = co * 16; | ||||
| out.stride[4] = 16; | out.stride[4] = 16; | ||||
| out.stride[5] = 1; | out.stride[5] = 1; | ||||
| TensorND ts_in, ts_out; | |||||
| ts_in.layout = in, ts_out.layout = out; | |||||
| ts_in.raw_ptr = args.filter_tensor->raw_ptr, | |||||
| ts_out.raw_ptr = args.workspace.raw_ptr; | |||||
| TensorND ts_in = *args.filter_tensor, ts_out{args.workspace.raw_ptr, out}; | |||||
| args.opr->handle()->create_operator<RelayoutForward>()->exec(ts_in, ts_out); | args.opr->handle()->create_operator<RelayoutForward>()->exec(ts_in, ts_out); | ||||
| } | } | ||||
| @@ -125,11 +125,11 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
| filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr); | filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr); | ||||
| // filter: KCRS32 => CRSK32 and reorder oc | // filter: KCRS32 => CRSK32 and reorder oc | ||||
| cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( | cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( | ||||
| filter_ptr, reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, | |||||
| ci, fh, fw, trans_oc, stream); | |||||
| filter_ptr, reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr()), | |||||
| co, ci, fh, fw, trans_oc, stream); | |||||
| } else { | } else { | ||||
| filter_ptr = | |||||
| reinterpret_cast<int8_t*>(args.preprocessed_filter->tensors[0].raw_ptr); | |||||
| filter_ptr = reinterpret_cast<int8_t*>( | |||||
| args.preprocessed_filter->tensors[0].raw_ptr()); | |||||
| } | } | ||||
| float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
| @@ -157,9 +157,9 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( | |||||
| use_conv_filter_unity_opt, without_shared_load); | use_conv_filter_unity_opt, without_shared_load); | ||||
| execute_cutlass_conv_op( | execute_cutlass_conv_op( | ||||
| op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr, | |||||
| z_dev_ptr, args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, ho, | |||||
| wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, | |||||
| op, args.src_tensor->raw_ptr(), filter_ptr, args.bias_tensor->raw_ptr(), | |||||
| z_dev_ptr, args.dst_tensor->raw_ptr(), nullptr, n, hi, wi, ci, co, fh, fw, | |||||
| ho, wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, | |||||
| &threshold, &dst_scale, stream); | &threshold, &dst_scale, stream); | ||||
| after_kernel_launch(); | after_kernel_launch(); | ||||
| @@ -204,8 +204,8 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec_preprocess( | |||||
| cudaStream_t stream = cuda_stream(args.opr->handle()); | cudaStream_t stream = cuda_stream(args.opr->handle()); | ||||
| // filter: KCRS32 => CRSK32 and reorder oc | // filter: KCRS32 => CRSK32 and reorder oc | ||||
| cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( | cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>( | ||||
| reinterpret_cast<int8_t*>(args.preprocessed_filter->tensors[0].raw_ptr), | |||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh, fw, | |||||
| reinterpret_cast<int8_t*>(args.preprocessed_filter->tensors[0].raw_ptr()), | |||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr()), co, ci, fh, fw, | |||||
| trans_oc, stream); | trans_oc, stream); | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -155,16 +155,13 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
| src.init_contiguous_stride(); | src.init_contiguous_stride(); | ||||
| TensorLayout dst = src; | TensorLayout dst = src; | ||||
| dst.stride[0] = 1, dst.stride[1] = dst[0]; | dst.stride[0] = 1, dst.stride[1] = dst[0]; | ||||
| TensorND ts_src, ts_dst; | |||||
| ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||||
| ts_src.layout = src; | |||||
| ts_dst.raw_ptr = args.workspace.raw_ptr; | |||||
| ts_dst.layout = dst; | |||||
| TensorND ts_src{args.filter_tensor->raw_ptr(), src}, | |||||
| ts_dst{args.workspace.raw_ptr, dst}; | |||||
| auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
| transpose->exec(ts_src, ts_dst); | transpose->exec(ts_src, ts_dst); | ||||
| } else { | } else { | ||||
| filter_ptr = | |||||
| reinterpret_cast<int8_t*>(args.preprocessed_filter->tensors[0].raw_ptr); | |||||
| filter_ptr = reinterpret_cast<int8_t*>( | |||||
| args.preprocessed_filter->tensors[0].raw_ptr()); | |||||
| } | } | ||||
| float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | float src_scale = args.src_layout->dtype.param<dtype::QuantizedS8>().scale, | ||||
| @@ -190,7 +187,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
| float delta = 0.f; | float delta = 0.f; | ||||
| void* z_ptr = nullptr; | void* z_ptr = nullptr; | ||||
| if (args.z_layout->ndim > 0) { | if (args.z_layout->ndim > 0) { | ||||
| z_ptr = args.z_tensor->raw_ptr; | |||||
| z_ptr = args.z_tensor->raw_ptr(); | |||||
| gamma = 1.f; | gamma = 1.f; | ||||
| if (args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) { | if (args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) { | ||||
| megdnn_assert( | megdnn_assert( | ||||
| @@ -213,10 +210,10 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( | |||||
| use_conv_filter_unity_opt, without_shared_load); | use_conv_filter_unity_opt, without_shared_load); | ||||
| execute_cutlass_conv_op( | execute_cutlass_conv_op( | ||||
| op, args.src_tensor->raw_ptr, filter_ptr, args.bias_tensor->raw_ptr, z_ptr, | |||||
| args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, | |||||
| pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, | |||||
| &dst_scale, stream); | |||||
| op, args.src_tensor->raw_ptr(), filter_ptr, args.bias_tensor->raw_ptr(), | |||||
| z_ptr, args.dst_tensor->raw_ptr(), nullptr, n, hi, wi, ci, co, fh, fw, ho, | |||||
| wo, ph, pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, | |||||
| &threshold, &dst_scale, stream); | |||||
| after_kernel_launch(); | after_kernel_launch(); | ||||
| } | } | ||||
| @@ -261,11 +258,8 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec_preprocess( | |||||
| src.init_contiguous_stride(); | src.init_contiguous_stride(); | ||||
| TensorLayout dst = src; | TensorLayout dst = src; | ||||
| dst.stride[0] = 1, dst.stride[1] = dst[0]; | dst.stride[0] = 1, dst.stride[1] = dst[0]; | ||||
| TensorND ts_src, ts_dst; | |||||
| ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||||
| ts_src.layout = src; | |||||
| ts_dst.raw_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
| ts_dst.layout = dst; | |||||
| TensorND ts_src{args.filter_tensor->raw_ptr(), src}, | |||||
| ts_dst{args.preprocessed_filter->tensors[0].raw_ptr(), dst}; | |||||
| auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
| transpose->exec(ts_src, ts_dst); | transpose->exec(ts_src, ts_dst); | ||||
| } | } | ||||
| @@ -96,11 +96,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::exec( | |||||
| src.init_contiguous_stride(); | src.init_contiguous_stride(); | ||||
| TensorLayout dst = src; | TensorLayout dst = src; | ||||
| dst.stride[0] = 1, dst.stride[1] = dst[0]; | dst.stride[0] = 1, dst.stride[1] = dst[0]; | ||||
| TensorND ts_src, ts_dst; | |||||
| ts_src.raw_ptr = args.src_tensor->raw_ptr; | |||||
| ts_src.layout = src; | |||||
| ts_dst.raw_ptr = ws_src; | |||||
| ts_dst.layout = dst; | |||||
| TensorND ts_src{args.src_tensor->raw_ptr(), src}, ts_dst{ws_src, dst}; | |||||
| auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
| transpose->exec(ts_src, ts_dst); | transpose->exec(ts_src, ts_dst); | ||||
| } | } | ||||
| @@ -111,11 +107,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::exec( | |||||
| src.init_contiguous_stride(); | src.init_contiguous_stride(); | ||||
| TensorLayout dst = src; | TensorLayout dst = src; | ||||
| dst.stride[0] = 1, dst.stride[1] = dst[0]; | dst.stride[0] = 1, dst.stride[1] = dst[0]; | ||||
| TensorND ts_src, ts_dst; | |||||
| ts_src.raw_ptr = args.filter_tensor->raw_ptr; | |||||
| ts_src.layout = src; | |||||
| ts_dst.raw_ptr = ws_filter; | |||||
| ts_dst.layout = dst; | |||||
| TensorND ts_src{args.filter_tensor->raw_ptr(), src}, ts_dst{ws_filter, dst}; | |||||
| auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
| transpose->exec(ts_src, ts_dst); | transpose->exec(ts_src, ts_dst); | ||||
| } | } | ||||
| @@ -142,11 +134,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::exec( | |||||
| src.init_contiguous_stride(); | src.init_contiguous_stride(); | ||||
| TensorLayout dst = src; | TensorLayout dst = src; | ||||
| dst.stride[0] = 1, dst.stride[1] = dst[0]; | dst.stride[0] = 1, dst.stride[1] = dst[0]; | ||||
| TensorND ts_src, ts_dst; | |||||
| ts_src.raw_ptr = args.z_tensor->raw_ptr; | |||||
| ts_src.layout = src; | |||||
| ts_dst.raw_ptr = ws_z; | |||||
| ts_dst.layout = dst; | |||||
| TensorND ts_src{args.z_tensor->raw_ptr(), src}, ts_dst{ws_z, dst}; | |||||
| auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
| transpose->exec(ts_src, ts_dst); | transpose->exec(ts_src, ts_dst); | ||||
| z_dev_ptr = reinterpret_cast<int8_t*>(ws_z); | z_dev_ptr = reinterpret_cast<int8_t*>(ws_z); | ||||
| @@ -168,11 +156,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::exec( | |||||
| src.init_contiguous_stride(); | src.init_contiguous_stride(); | ||||
| TensorLayout dst = src; | TensorLayout dst = src; | ||||
| dst.stride[0] = 1, dst.stride[1] = dst[0]; | dst.stride[0] = 1, dst.stride[1] = dst[0]; | ||||
| TensorND ts_src, ts_dst; | |||||
| ts_src.raw_ptr = ws_dst; | |||||
| ts_src.layout = src; | |||||
| ts_dst.raw_ptr = args.dst_tensor->raw_ptr; | |||||
| ts_dst.layout = dst; | |||||
| TensorND ts_src{ws_dst, src}, ts_dst{args.dst_tensor->raw_ptr(), dst}; | |||||
| auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
| transpose->exec(ts_src, ts_dst); | transpose->exec(ts_src, ts_dst); | ||||
| } | } | ||||
| @@ -114,7 +114,7 @@ SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm:: | |||||
| void ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::exec_preprocess( | void ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::exec_preprocess( | ||||
| const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
| void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
| void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
| reorder_filter(args, m_algo_param.access_size, filter_ptr); | reorder_filter(args, m_algo_param.access_size, filter_ptr); | ||||
| } | } | ||||
| @@ -189,15 +189,15 @@ void ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::exec( | |||||
| void* z_ptr = nullptr; | void* z_ptr = nullptr; | ||||
| if (args.preprocessed_filter) { | if (args.preprocessed_filter) { | ||||
| filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
| filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
| } else { | } else { | ||||
| filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | ||||
| reorder_filter(args, m_algo_param.access_size, filter_ptr); | reorder_filter(args, m_algo_param.access_size, filter_ptr); | ||||
| } | } | ||||
| bias_ptr = args.bias_tensor->raw_ptr; | |||||
| bias_ptr = args.bias_tensor->raw_ptr(); | |||||
| if (args.z_layout->ndim > 0) | if (args.z_layout->ndim > 0) | ||||
| z_ptr = args.z_tensor->raw_ptr; | |||||
| z_ptr = args.z_tensor->raw_ptr(); | |||||
| // \note these constants of cutlass epilogue will be passed to method | // \note these constants of cutlass epilogue will be passed to method | ||||
| // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | // `execute_cutlass_conv_op` by pointer and interpreted as ElementCompute*, | ||||
| @@ -233,8 +233,8 @@ void ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::exec( | |||||
| use_conv_filter_unity_opt, without_shared_load); | use_conv_filter_unity_opt, without_shared_load); | ||||
| execute_cutlass_conv_op( | execute_cutlass_conv_op( | ||||
| op, args.src_tensor->raw_ptr, filter_ptr, bias_ptr, z_ptr, | |||||
| args.dst_tensor->raw_ptr, nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, | |||||
| op, args.src_tensor->raw_ptr(), filter_ptr, bias_ptr, z_ptr, | |||||
| args.dst_tensor->raw_ptr(), nullptr, n, hi, wi, ci, co, fh, fw, ho, wo, ph, | |||||
| pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, | pw, sh, sw, dh, dw, &alpha, &beta, &gamma, &delta, &theta, &threshold, | ||||
| &dst_scale, stream); | &dst_scale, stream); | ||||
| @@ -272,7 +272,7 @@ void ConvBiasForwardImpl::AlgoInt8NHWCIMMAImplicitGemm::reorder_filter( | |||||
| cutlass_wrapper::reorder_nhwc_imma_filter<8>( | cutlass_wrapper::reorder_nhwc_imma_filter<8>( | ||||
| reinterpret_cast<int8_t*>(reordered_filter), | reinterpret_cast<int8_t*>(reordered_filter), | ||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh, fw, | |||||
| reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr()), co, ci, fh, fw, | |||||
| trans_oc, alignbits, oc_iterleaved, stream); | trans_oc, alignbits, oc_iterleaved, stream); | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -52,8 +52,8 @@ SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGe | |||||
| void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::exec_preprocess( | void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::exec_preprocess( | ||||
| const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
| megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | ||||
| void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
| void* bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr; | |||||
| void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
| void* bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr(); | |||||
| void* reduce_filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | void* reduce_filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | ||||
| void* reduce_workspace = reinterpret_cast<void*>( | void* reduce_workspace = reinterpret_cast<void*>( | ||||
| args.workspace.raw_ptr + args.bias_layout->span().dist_byte()); | args.workspace.raw_ptr + args.bias_layout->span().dist_byte()); | ||||
| @@ -67,8 +67,8 @@ std::tuple<void*, void*> ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGem | |||||
| void* bias_ptr = nullptr; | void* bias_ptr = nullptr; | ||||
| if (args.preprocessed_filter) { | if (args.preprocessed_filter) { | ||||
| megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | ||||
| filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
| bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr; | |||||
| filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
| bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr(); | |||||
| return {filter_ptr, bias_ptr}; | return {filter_ptr, bias_ptr}; | ||||
| } else { | } else { | ||||
| filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | ||||
| @@ -130,7 +130,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias( | |||||
| int src_zero_point = | int src_zero_point = | ||||
| args.src_tensor->layout.dtype.param<dtype::Quantized4Asymm>().zero_point; | args.src_tensor->layout.dtype.param<dtype::Quantized4Asymm>().zero_point; | ||||
| do_dispatch_reduce_filter_and_update_bias_4bit<true>( | do_dispatch_reduce_filter_and_update_bias_4bit<true>( | ||||
| reinterpret_cast<uint8_t*>(args.filter_tensor->raw_ptr), | |||||
| reinterpret_cast<uint8_t*>(args.filter_tensor->raw_ptr()), | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), co, ci * fh * fw / 8, | args.bias_tensor->compatible_ptr<int32_t>(), co, ci * fh * fw / 8, | ||||
| reinterpret_cast<int32_t*>(updated_bias), | reinterpret_cast<int32_t*>(updated_bias), | ||||
| reinterpret_cast<int32_t*>(reduce_workspace), src_zero_point, stream); | reinterpret_cast<int32_t*>(reduce_workspace), src_zero_point, stream); | ||||
| @@ -52,8 +52,8 @@ SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm | |||||
| void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::exec_preprocess( | void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::exec_preprocess( | ||||
| const ExecArgs& args) const { | const ExecArgs& args) const { | ||||
| megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | ||||
| void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
| void* bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr; | |||||
| void* filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
| void* bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr(); | |||||
| void* reduce_filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | void* reduce_filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | ||||
| void* reduce_workspace = reinterpret_cast<void*>( | void* reduce_workspace = reinterpret_cast<void*>( | ||||
| args.workspace.raw_ptr + args.bias_layout->span().dist_byte()); | args.workspace.raw_ptr + args.bias_layout->span().dist_byte()); | ||||
| @@ -67,8 +67,8 @@ std::tuple<void*, void*> ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm: | |||||
| void* bias_ptr = nullptr; | void* bias_ptr = nullptr; | ||||
| if (args.preprocessed_filter) { | if (args.preprocessed_filter) { | ||||
| megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | megdnn_assert(args.preprocessed_filter->tensors.size() == 2); | ||||
| filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr; | |||||
| bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr; | |||||
| filter_ptr = args.preprocessed_filter->tensors[0].raw_ptr(); | |||||
| bias_ptr = args.preprocessed_filter->tensors[1].raw_ptr(); | |||||
| return {filter_ptr, bias_ptr}; | return {filter_ptr, bias_ptr}; | ||||
| } else { | } else { | ||||
| filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | filter_ptr = reinterpret_cast<void*>(args.workspace.raw_ptr); | ||||
| @@ -146,7 +146,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::update_bias( | |||||
| int src_zero_point = | int src_zero_point = | ||||
| args.src_tensor->layout.dtype.param<dtype::Quantized4Asymm>().zero_point; | args.src_tensor->layout.dtype.param<dtype::Quantized4Asymm>().zero_point; | ||||
| do_dispatch_reduce_filter_and_update_bias_4bit<true>( | do_dispatch_reduce_filter_and_update_bias_4bit<true>( | ||||
| reinterpret_cast<uint8_t*>(args.filter_tensor->raw_ptr), | |||||
| reinterpret_cast<uint8_t*>(args.filter_tensor->raw_ptr()), | |||||
| args.bias_tensor->compatible_ptr<int32_t>(), co, ci * fh * fw / 8, | args.bias_tensor->compatible_ptr<int32_t>(), co, ci * fh * fw / 8, | ||||
| reinterpret_cast<int32_t*>(updated_bias), | reinterpret_cast<int32_t*>(updated_bias), | ||||
| reinterpret_cast<int32_t*>(reduce_workspace), src_zero_point, stream); | reinterpret_cast<int32_t*>(reduce_workspace), src_zero_point, stream); | ||||
| @@ -40,9 +40,9 @@ size_t ConvBiasForwardImpl::AlgoInplaceMatmul::get_workspace_in_bytes( | |||||
| void ConvBiasForwardImpl::AlgoInplaceMatmul::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoInplaceMatmul::exec(const ExecArgs& args) const { | ||||
| WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; | WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}}; | ||||
| auto conv_dst_tensor = *args.dst_tensor; | |||||
| TensorND conv_dst_tensor = *args.dst_tensor; | |||||
| if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
| conv_dst_tensor.raw_ptr = bundle.get(0); | |||||
| conv_dst_tensor = TensorND{bundle.get(0), args.dst_tensor->layout}; | |||||
| conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
| args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
| args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
| @@ -115,9 +115,10 @@ size_t ConvBiasForwardImpl::AlgoMatmul::get_workspace_in_bytes( | |||||
| void ConvBiasForwardImpl::AlgoMatmul::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoMatmul::exec(const ExecArgs& args) const { | ||||
| auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
| auto conv_dst_tensor = *args.dst_tensor; | |||||
| TensorND conv_dst_tensor = *args.dst_tensor; | |||||
| if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
| conv_dst_tensor.raw_ptr = bundle.get(bundle.nr_workspace() - 1); | |||||
| conv_dst_tensor = TensorND{ | |||||
| bundle.get(bundle.nr_workspace() - 1), args.dst_tensor->layout}; | |||||
| conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
| args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
| args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
| @@ -168,7 +169,7 @@ void ConvBiasForwardImpl::AlgoMatmul::exec_internal( | |||||
| C(dst_t, config.first[2]); | C(dst_t, config.first[2]); | ||||
| size_t matmul_ws_idx = 2; | size_t matmul_ws_idx = 2; | ||||
| if (fm.should_flip) { | if (fm.should_flip) { | ||||
| conv_bias::flip_filter(args, bundle.get_workspace(2), A.raw_ptr); | |||||
| conv_bias::flip_filter(args, bundle.get_workspace(2), A.get_ref_ptr()); | |||||
| matmul_ws_idx = 3; | matmul_ws_idx = 3; | ||||
| } | } | ||||
| @@ -128,12 +128,10 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal(const ExecArgs& args) | |||||
| auto bundle = get_bundle<format>(args); | auto bundle = get_bundle<format>(args); | ||||
| bundle.set(args.workspace.raw_ptr); | bundle.set(args.workspace.raw_ptr); | ||||
| TensorND src_tensor, dst_tensor, filter_tensor; | |||||
| if (format == Param::Format::NHWC) { | |||||
| src_tensor = *args.src_tensor; | |||||
| dst_tensor = *args.dst_tensor; | |||||
| filter_tensor = *args.filter_tensor; | |||||
| } else { | |||||
| TensorND src_tensor = *args.src_tensor; | |||||
| TensorND dst_tensor = *args.dst_tensor; | |||||
| TensorND filter_tensor = *args.filter_tensor; | |||||
| if (format == Param::Format::NCHW4) { | |||||
| // NCHW4 | // NCHW4 | ||||
| auto to_nhwc = [](const TensorLayout& layout, void* raw_ptr) -> TensorND { | auto to_nhwc = [](const TensorLayout& layout, void* raw_ptr) -> TensorND { | ||||
| return {raw_ptr, | return {raw_ptr, | ||||
| @@ -147,7 +145,7 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal(const ExecArgs& args) | |||||
| auto N = src.layout[0], C = src.layout[1] * 4, H = src.layout[2], | auto N = src.layout[0], C = src.layout[1] * 4, H = src.layout[2], | ||||
| W = src.layout[3]; | W = src.layout[3]; | ||||
| args.handle->relayout_opr()->exec( | args.handle->relayout_opr()->exec( | ||||
| {src.raw_ptr, | |||||
| {src.raw_ptr(), | |||||
| TensorLayout{ | TensorLayout{ | ||||
| {N, H, W, C / 4, 4}, | {N, H, W, C / 4, 4}, | ||||
| {src.layout.stride[0], src.layout.stride[2], | {src.layout.stride[0], src.layout.stride[2], | ||||
| @@ -156,8 +154,8 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal(const ExecArgs& args) | |||||
| src.layout.dtype}}, | src.layout.dtype}}, | ||||
| {dst_ptr, TensorLayout{{N, H, W, C / 4, 4}, src.layout.dtype}}); | {dst_ptr, TensorLayout{{N, H, W, C / 4, 4}, src.layout.dtype}}); | ||||
| }; | }; | ||||
| relayout(*args.src_tensor, src_tensor.raw_ptr); | |||||
| relayout(*args.filter_tensor, filter_tensor.raw_ptr); | |||||
| relayout(*args.src_tensor, src_tensor.raw_ptr()); | |||||
| relayout(*args.filter_tensor, filter_tensor.raw_ptr()); | |||||
| } | } | ||||
| size_t N, IH, IW, IC; | size_t N, IH, IW, IC; | ||||
| @@ -193,7 +191,7 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal(const ExecArgs& args) | |||||
| // copy (OC, FH*FW*IC) to (OC, FH*FW*IC) with stride=LD | // copy (OC, FH*FW*IC) to (OC, FH*FW*IC) with stride=LD | ||||
| inp1 = static_cast<int8_t*>(bundle.get(1)); | inp1 = static_cast<int8_t*>(bundle.get(1)); | ||||
| cuda_check(cudaMemcpy2DAsync( | cuda_check(cudaMemcpy2DAsync( | ||||
| inp1, LD * sizeof(int8_t), filter_tensor.raw_ptr, | |||||
| inp1, LD * sizeof(int8_t), filter_tensor.raw_ptr(), | |||||
| FH * FW * IC * sizeof(int8_t), FH * FW * IC * sizeof(int8_t), OC, | FH * FW * IC * sizeof(int8_t), FH * FW * IC * sizeof(int8_t), OC, | ||||
| cudaMemcpyDeviceToDevice, stream)); | cudaMemcpyDeviceToDevice, stream)); | ||||
| inp1_stride = LD; | inp1_stride = LD; | ||||
| @@ -222,12 +220,13 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal(const ExecArgs& args) | |||||
| void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec(const ExecArgs& args) const { | void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec(const ExecArgs& args) const { | ||||
| ExecArgs conv_args = args; | ExecArgs conv_args = args; | ||||
| auto conv_dst_tensor = *args.dst_tensor; | |||||
| TensorND conv_dst_tensor = *args.dst_tensor; | |||||
| if (args.filter_meta.format == Param::Format::NHWC) { | if (args.filter_meta.format == Param::Format::NHWC) { | ||||
| auto bundle = get_bundle<Param::Format::NHWC>(args); | auto bundle = get_bundle<Param::Format::NHWC>(args); | ||||
| bundle.set(args.workspace.raw_ptr); | bundle.set(args.workspace.raw_ptr); | ||||
| if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
| conv_dst_tensor.raw_ptr = bundle.get(bundle.nr_workspace() - 1); | |||||
| conv_dst_tensor = TensorND{ | |||||
| bundle.get(bundle.nr_workspace() - 1), args.dst_tensor->layout}; | |||||
| conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
| args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
| args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
| @@ -239,7 +238,8 @@ void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec(const ExecArgs& args) const { | |||||
| auto bundle = get_bundle<Param::Format::NCHW4>(args); | auto bundle = get_bundle<Param::Format::NCHW4>(args); | ||||
| bundle.set(args.workspace.raw_ptr); | bundle.set(args.workspace.raw_ptr); | ||||
| if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
| conv_dst_tensor.raw_ptr = bundle.get(bundle.nr_workspace() - 1); | |||||
| conv_dst_tensor = TensorND{ | |||||
| bundle.get(bundle.nr_workspace() - 1), args.dst_tensor->layout}; | |||||
| conv_dst_tensor.layout.dtype = DType(); | conv_dst_tensor.layout.dtype = DType(); | ||||
| args.opr->check_or_deduce_dtype_fwd( | args.opr->check_or_deduce_dtype_fwd( | ||||
| args.src_layout->dtype, args.filter_layout->dtype, | args.src_layout->dtype, args.filter_layout->dtype, | ||||
| @@ -131,26 +131,26 @@ void ConvBiasForwardImpl::AlgoQUInt4x4x32WMMA::exec(const ExecArgs& args) const | |||||
| auto&& stream = cuda_stream(handle); | auto&& stream = cuda_stream(handle); | ||||
| // zp filter | // zp filter | ||||
| do_dispatch_reduce_with_scale_filter_4bit<false>( | do_dispatch_reduce_with_scale_filter_4bit<false>( | ||||
| static_cast<uint8_t*>(args.filter_tensor->raw_ptr), -zp_data, OC, | |||||
| static_cast<uint8_t*>(args.filter_tensor->raw_ptr()), -zp_data, OC, | |||||
| FH * FW * IC / 8, ws_zp_filter.ptr<int32_t>(), stream); | FH * FW * IC / 8, ws_zp_filter.ptr<int32_t>(), stream); | ||||
| // zp data | // zp data | ||||
| do_dispatch_reduce_with_scale_data_u4( | do_dispatch_reduce_with_scale_data_u4( | ||||
| ws_zp_data.ptr<int32_t>(), static_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||||
| N, IH, IW, OH, OW, PH, PW, FH, FW, SH, SW, IC, -zp_filter, | |||||
| static_cast<uint8_t>(zp_data), stream); | |||||
| ws_zp_data.ptr<int32_t>(), | |||||
| static_cast<uint8_t*>(args.src_tensor->raw_ptr()), N, IH, IW, OH, OW, PH, | |||||
| PW, FH, FW, SH, SW, IC, -zp_filter, static_cast<uint8_t>(zp_data), stream); | |||||
| // do conv | // do conv | ||||
| if (use_kernel_fhxfw(args)) { | if (use_kernel_fhxfw(args)) { | ||||
| wmma_conv_integer_subbyte::_do_wmma_conv_integer_subbyte_fhxfw( | wmma_conv_integer_subbyte::_do_wmma_conv_integer_subbyte_fhxfw( | ||||
| static_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||||
| static_cast<uint8_t*>(args.filter_tensor->raw_ptr), | |||||
| static_cast<uint8_t*>(args.src_tensor->raw_ptr()), | |||||
| static_cast<uint8_t*>(args.filter_tensor->raw_ptr()), | |||||
| args.dst_tensor->compatible_ptr<int32_t>(), N, IH, IW, OH, OW, PH, PW, | args.dst_tensor->compatible_ptr<int32_t>(), N, IH, IW, OH, OW, PH, PW, | ||||
| IC, OC, FH, FW, SH, SW, static_cast<uint8_t>(zp_data), stream); | IC, OC, FH, FW, SH, SW, static_cast<uint8_t>(zp_data), stream); | ||||
| } else { | } else { | ||||
| auto&& ws_relayout_filter = ws_bundle.get_workspace(2); | auto&& ws_relayout_filter = ws_bundle.get_workspace(2); | ||||
| wmma_conv_integer_subbyte::_do_wmma_conv_integer_subbyte_1xfw( | wmma_conv_integer_subbyte::_do_wmma_conv_integer_subbyte_1xfw( | ||||
| static_cast<uint8_t*>(args.src_tensor->raw_ptr), | |||||
| static_cast<uint8_t*>(args.filter_tensor->raw_ptr), | |||||
| static_cast<uint8_t*>(args.src_tensor->raw_ptr()), | |||||
| static_cast<uint8_t*>(args.filter_tensor->raw_ptr()), | |||||
| args.dst_tensor->compatible_ptr<int32_t>(), | args.dst_tensor->compatible_ptr<int32_t>(), | ||||
| ws_relayout_filter.ptr<uint8_t>(), N, IH, IW, OH, OW, PH, PW, IC, OC, | ws_relayout_filter.ptr<uint8_t>(), N, IH, IW, OH, OW, PH, PW, IC, OC, | ||||
| FH, FW, SH, SW, static_cast<uint8_t>(zp_data), stream); | FH, FW, SH, SW, static_cast<uint8_t>(zp_data), stream); | ||||
| @@ -60,9 +60,9 @@ void ConvolutionBackwardDataImpl::AlgoChanwise::exec(const ExecArgs& args) const | |||||
| #if CUDA_VERSION >= 9000 | #if CUDA_VERSION >= 9000 | ||||
| if (is_compute_capability_required(5, 3)) { | if (is_compute_capability_required(5, 3)) { | ||||
| return chanwise::run_bwd_data( | return chanwise::run_bwd_data( | ||||
| static_cast<__half*>(args.grad_tensor->raw_ptr), | |||||
| static_cast<__half*>(args.diff_tensor->raw_ptr), | |||||
| static_cast<__half*>(args.filter_tensor->raw_ptr), kparam, | |||||
| static_cast<__half*>(args.grad_tensor->raw_ptr()), | |||||
| static_cast<__half*>(args.diff_tensor->raw_ptr()), | |||||
| static_cast<__half*>(args.filter_tensor->raw_ptr()), kparam, | |||||
| stream); | stream); | ||||
| } else { | } else { | ||||
| return chanwise::run_bwd_data( | return chanwise::run_bwd_data( | ||||
| @@ -68,9 +68,9 @@ void ConvolutionBackwardDataImpl::AlgoChanwiseSmall::exec(const ExecArgs& args) | |||||
| #if CUDA_VERSION >= 9000 | #if CUDA_VERSION >= 9000 | ||||
| case DTypeEnum::Float16: | case DTypeEnum::Float16: | ||||
| return chanwise::run_bwd_data_small( | return chanwise::run_bwd_data_small( | ||||
| static_cast<half*>(args.grad_tensor->raw_ptr), | |||||
| static_cast<half*>(args.diff_tensor->raw_ptr), | |||||
| static_cast<half*>(args.filter_tensor->raw_ptr), kparam, stream); | |||||
| static_cast<half*>(args.grad_tensor->raw_ptr()), | |||||
| static_cast<half*>(args.diff_tensor->raw_ptr()), | |||||
| static_cast<half*>(args.filter_tensor->raw_ptr()), kparam, stream); | |||||
| #endif | #endif | ||||
| default: | default: | ||||
| break; | break; | ||||
| @@ -71,9 +71,10 @@ void ConvolutionBackwardDataImpl::AlgoCUDNN::exec(const ExecArgs& args) const { | |||||
| float alpha = 1.0f, beta = 0.0f; | float alpha = 1.0f, beta = 0.0f; | ||||
| auto status = cudnnConvolutionBackwardData( | auto status = cudnnConvolutionBackwardData( | ||||
| args.handle->cudnn_handle(), &alpha, D.filter_desc.desc, | args.handle->cudnn_handle(), &alpha, D.filter_desc.desc, | ||||
| args.filter_tensor->raw_ptr, D.diff_desc.desc, args.diff_tensor->raw_ptr, | |||||
| D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, | |||||
| &beta, D.grad_desc.desc, args.grad_tensor->raw_ptr); | |||||
| args.filter_tensor->raw_ptr(), D.diff_desc.desc, | |||||
| args.diff_tensor->raw_ptr(), D.conv_desc.desc, m_cudnn_enum, | |||||
| args.workspace.raw_ptr, args.workspace.size, &beta, D.grad_desc.desc, | |||||
| args.grad_tensor->raw_ptr()); | |||||
| megdnn_assert( | megdnn_assert( | ||||
| status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", | status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", | ||||
| cudnnGetErrorString(status), args.to_string().c_str()); | cudnnGetErrorString(status), args.to_string().c_str()); | ||||
| @@ -103,9 +103,9 @@ void ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::exec( | |||||
| auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
| { | { | ||||
| auto config = prepare_sub_opr(args); | auto config = prepare_sub_opr(args); | ||||
| TensorND tfilter{args.filter_tensor->raw_ptr, config.first[0]}; | |||||
| TensorND tdiff{args.diff_tensor->raw_ptr, config.first[1]}; | |||||
| TensorND tgrad{args.grad_tensor->raw_ptr, config.first[2]}; | |||||
| TensorND tfilter{args.filter_tensor->raw_ptr(), config.first[0]}; | |||||
| TensorND tdiff{args.diff_tensor->raw_ptr(), config.first[1]}; | |||||
| TensorND tgrad{args.grad_tensor->raw_ptr(), config.first[2]}; | |||||
| size_t c_pos = 1; | size_t c_pos = 1; | ||||
| @@ -121,9 +121,9 @@ void ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::exec( | |||||
| auto grp = args.filter_meta.group; | auto grp = args.filter_meta.group; | ||||
| for (uint32_t g = 0; g < grp; ++g) { | for (uint32_t g = 0; g < grp; ++g) { | ||||
| config.second->exec(tfilter, tdiff, tgrad, bundle.get_workspace(0)); | config.second->exec(tfilter, tdiff, tgrad, bundle.get_workspace(0)); | ||||
| incr_voidp(tfilter.raw_ptr, strd_flt); | |||||
| incr_voidp(tdiff.raw_ptr, strd_diff); | |||||
| incr_voidp(tgrad.raw_ptr, strd_grad); | |||||
| incr_refp(tfilter.get_ref_ptr(), strd_flt); | |||||
| incr_refp(tdiff.get_ref_ptr(), strd_diff); | |||||
| incr_refp(tgrad.get_ref_ptr(), strd_grad); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
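In the grouped algorithms above, the per-group pointer bump changes from `incr_voidp(t.raw_ptr, stride)` to `incr_refp(t.get_ref_ptr(), stride)`: instead of reassigning a raw pointer member, the shared reference is advanced by a byte offset so the sub-operator's tensors follow along. A hedged sketch of what such a helper might do; the name `incr_refp_sketch` and the simplified `RefPtrSketch` are assumptions drawn from this diff, not the actual MegDNN implementation.

    #include <cstddef>
    #include <cstdint>

    struct RefPtrSketch {  // simplified stand-in
        void* ptr = nullptr;
        void* get_ptr() const { return ptr; }
        void reset(void* p) { ptr = p; }
    };

    // Assumed shape of incr_refp: advance the referenced buffer by `delta` bytes.
    inline void incr_refp_sketch(RefPtrSketch& ref, std::ptrdiff_t delta) {
        ref.reset(static_cast<std::uint8_t*>(ref.get_ptr()) + delta);
    }

    int main() {
        unsigned char buf[32] = {};
        RefPtrSketch r;
        r.reset(buf);
        incr_refp_sketch(r, 8);  // r now refers to buf + 8, as in the per-group loop
        return r.get_ptr() == buf + 8 ? 0 : 1;
    }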
| @@ -140,7 +140,8 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | |||||
| auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
| relayout->exec( | relayout->exec( | ||||
| {args.filter_tensor->raw_ptr, exec_src}, {inner_filter_ptr, exec_dst}); | |||||
| {args.filter_tensor->raw_ptr(), exec_src}, | |||||
| {inner_filter_ptr, exec_dst}); | |||||
| } | } | ||||
| { | { | ||||
| inner_diff_ptr = reinterpret_cast<int8_t*>(bundle.get(1)); | inner_diff_ptr = reinterpret_cast<int8_t*>(bundle.get(1)); | ||||
| @@ -152,7 +153,7 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | |||||
| auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
| relayout->exec( | relayout->exec( | ||||
| {args.diff_tensor->raw_ptr, exec_src}, {inner_diff_ptr, exec_dst}); | |||||
| {args.diff_tensor->raw_ptr(), exec_src}, {inner_diff_ptr, exec_dst}); | |||||
| } | } | ||||
| int8_t* inner_grad_ptr = reinterpret_cast<int8_t*>(bundle.get(2)); | int8_t* inner_grad_ptr = reinterpret_cast<int8_t*>(bundle.get(2)); | ||||
| @@ -196,7 +197,7 @@ void ConvolutionBackwardDataImpl::AlgoInt8NCHWDotProdImplicitGemm::exec( | |||||
| auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>(); | ||||
| relayout->exec( | relayout->exec( | ||||
| {inner_grad_ptr, exec_src}, {args.grad_tensor->raw_ptr, exec_dst}); | |||||
| {inner_grad_ptr, exec_src}, {args.grad_tensor->raw_ptr(), exec_dst}); | |||||
| } | } | ||||
| } | } | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -143,7 +143,7 @@ void ConvolutionBackwardDataImpl::AlgoMatmul::exec_internal(const ExecArgs& args | |||||
| TensorND A(args.filter_tensor->ptr<T>(), Al), B(col, Bl), C(diff_t, Cl); | TensorND A(args.filter_tensor->ptr<T>(), Al), B(col, Bl), C(diff_t, Cl); | ||||
| if (fm.should_flip) { | if (fm.should_flip) { | ||||
| convolution::flip_filter( | convolution::flip_filter( | ||||
| args.as_fwd_args(), wbundle.get_workspace(2), A.raw_ptr); | |||||
| args.as_fwd_args(), wbundle.get_workspace(2), A.get_ref_ptr()); | |||||
| config.second->exec(A, C, B, wbundle.get_workspace(3)); | config.second->exec(A, C, B, wbundle.get_workspace(3)); | ||||
| } else { | } else { | ||||
| config.second->exec(A, C, B, wbundle.get_workspace(2)); | config.second->exec(A, C, B, wbundle.get_workspace(2)); | ||||
| @@ -50,9 +50,9 @@ void ConvolutionBackwardFilterImpl::AlgoChanwise::exec(const ExecArgs& args) con | |||||
| #if CUDA_VERSION >= 9000 | #if CUDA_VERSION >= 9000 | ||||
| if (is_compute_capability_required(5, 3)) { | if (is_compute_capability_required(5, 3)) { | ||||
| return chanwise::run_bwd_filter( | return chanwise::run_bwd_filter( | ||||
| static_cast<__half*>(args.grad_tensor->raw_ptr), | |||||
| static_cast<__half*>(args.src_tensor->raw_ptr), | |||||
| static_cast<__half*>(args.diff_tensor->raw_ptr), kparam, | |||||
| static_cast<__half*>(args.grad_tensor->raw_ptr()), | |||||
| static_cast<__half*>(args.src_tensor->raw_ptr()), | |||||
| static_cast<__half*>(args.diff_tensor->raw_ptr()), kparam, | |||||
| stream); | stream); | ||||
| } else { | } else { | ||||
| return chanwise::run_bwd_filter( | return chanwise::run_bwd_filter( | ||||
| @@ -71,9 +71,9 @@ void ConvolutionBackwardFilterImpl::AlgoCUDNN::exec(const ExecArgs& args) const | |||||
| float alpha = 1.0f, beta = 0.0f; | float alpha = 1.0f, beta = 0.0f; | ||||
| auto status = cudnnConvolutionBackwardFilter( | auto status = cudnnConvolutionBackwardFilter( | ||||
| args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | ||||
| args.src_tensor->raw_ptr, D.diff_desc.desc, args.diff_tensor->raw_ptr, | |||||
| args.src_tensor->raw_ptr(), D.diff_desc.desc, args.diff_tensor->raw_ptr(), | |||||
| D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, | D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, | ||||
| &beta, D.grad_desc.desc, args.grad_tensor->raw_ptr); | |||||
| &beta, D.grad_desc.desc, args.grad_tensor->raw_ptr()); | |||||
| megdnn_assert( | megdnn_assert( | ||||
| status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", | status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", | ||||
| cudnnGetErrorString(status), args.to_string().c_str()); | cudnnGetErrorString(status), args.to_string().c_str()); | ||||
| @@ -101,9 +101,9 @@ void ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::exec( | |||||
| { | { | ||||
| auto config = prepare_sub_opr(args); | auto config = prepare_sub_opr(args); | ||||
| TensorND tsrc{args.src_tensor->raw_ptr, config.first[0]}; | |||||
| TensorND tdiff{args.diff_tensor->raw_ptr, config.first[1]}; | |||||
| TensorND tgrad{args.grad_tensor->raw_ptr, config.first[2]}; | |||||
| TensorND tsrc{args.src_tensor->raw_ptr(), config.first[0]}; | |||||
| TensorND tdiff{args.diff_tensor->raw_ptr(), config.first[1]}; | |||||
| TensorND tgrad{args.grad_tensor->raw_ptr(), config.first[2]}; | |||||
| size_t c_pos = 1; | size_t c_pos = 1; | ||||
| @@ -118,9 +118,9 @@ void ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::exec( | |||||
| auto grp = fm.group; | auto grp = fm.group; | ||||
| for (uint32_t g = 0; g < grp; ++g) { | for (uint32_t g = 0; g < grp; ++g) { | ||||
| config.second->exec(tsrc, tdiff, tgrad, bundle.get_workspace(0)); | config.second->exec(tsrc, tdiff, tgrad, bundle.get_workspace(0)); | ||||
| incr_voidp(tsrc.raw_ptr, strd_src); | |||||
| incr_voidp(tdiff.raw_ptr, strd_diff); | |||||
| incr_voidp(tgrad.raw_ptr, strd_grad); | |||||
| incr_refp(tsrc.get_ref_ptr(), strd_src); | |||||
| incr_refp(tdiff.get_ref_ptr(), strd_diff); | |||||
| incr_refp(tgrad.get_ref_ptr(), strd_grad); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -133,7 +133,7 @@ void ConvolutionBackwardFilterImpl::AlgoMatmul::exec_internal(const ExecArgs& ar | |||||
| froml.stride[0] = args.diff_layout->stride[0]; | froml.stride[0] = args.diff_layout->stride[0]; | ||||
| tol.stride[0] = 1; | tol.stride[0] = 1; | ||||
| tol.stride[1] = N; | tol.stride[1] = N; | ||||
| TensorND from(args.diff_tensor->ptr<T>(), froml), to(diff_t, tol); | |||||
| TensorND from(args.diff_tensor->raw_ptr(), froml), to(diff_t, tol); | |||||
| args.handle->relayout_opr()->exec(from, to); | args.handle->relayout_opr()->exec(from, to); | ||||
| } | } | ||||
| { | { | ||||
| @@ -149,13 +149,13 @@ void ConvolutionBackwardFilterImpl::AlgoMatmul::exec_internal(const ExecArgs& ar | |||||
| Cl({OC, OH * OW * N}, typename DTypeTrait<T>::dtype()); | Cl({OC, OH * OW * N}, typename DTypeTrait<T>::dtype()); | ||||
| TensorND A(args.grad_tensor->ptr<T>(), Al), B(col, Bl), C(diff_t, Cl); | TensorND A(args.grad_tensor->ptr<T>(), Al), B(col, Bl), C(diff_t, Cl); | ||||
| if (fm.should_flip) { | if (fm.should_flip) { | ||||
| A.raw_ptr = wbundle.get(2); | |||||
| A.reset_ptr(wbundle.get(2)); | |||||
| config.second->exec(C, B, A, wbundle.get_workspace(3)); | config.second->exec(C, B, A, wbundle.get_workspace(3)); | ||||
| convolution::flip_filter( | convolution::flip_filter( | ||||
| args.as_fwd_args(), | args.as_fwd_args(), | ||||
| {static_cast<dt_byte*>(args.grad_tensor->raw_ptr), | |||||
| {static_cast<dt_byte*>(args.grad_tensor->raw_ptr()), | |||||
| wbundle.get_size(2)}, | wbundle.get_size(2)}, | ||||
| A.raw_ptr); | |||||
| A.get_ref_ptr()); | |||||
| } else { | } else { | ||||
| config.second->exec(C, B, A, wbundle.get_workspace(2)); | config.second->exec(C, B, A, wbundle.get_workspace(2)); | ||||
| } | } | ||||
| @@ -68,19 +68,19 @@ SmallVector<size_t> convolution::matmul_get_workspace_bundle( | |||||
| } | } | ||||
| void convolution::flip_filter( | void convolution::flip_filter( | ||||
| const ForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr) { | |||||
| const ForwardSizeArgs& args, const Workspace& workspace, RefPtr& ref_ptr) { | |||||
| auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
| megdnn_assert(fm.group == 1 && fm.spatial_ndim == 2); | megdnn_assert(fm.group == 1 && fm.spatial_ndim == 2); | ||||
| auto OC = fm.ocpg, IC = fm.icpg, FH = fm.spatial[0], FW = fm.spatial[1]; | auto OC = fm.ocpg, IC = fm.icpg, FH = fm.spatial[0], FW = fm.spatial[1]; | ||||
| auto dtype = fm.dtype; | auto dtype = fm.dtype; | ||||
| megdnn_assert(workspace.size >= dtype.size() * OC * IC * FH * FW); | megdnn_assert(workspace.size >= dtype.size() * OC * IC * FH * FW); | ||||
| TensorND src{raw_ptr, {{OC, IC, FH, FW}, dtype}}, | |||||
| TensorND src{{{OC, IC, FH, FW}, dtype}, ref_ptr}, | |||||
| dst{workspace.raw_ptr + (FH * FW - 1) * dtype.size(), src.layout}; | dst{workspace.raw_ptr + (FH * FW - 1) * dtype.size(), src.layout}; | ||||
| dst.layout.stride[2] = -dst.layout.stride[2]; | dst.layout.stride[2] = -dst.layout.stride[2]; | ||||
| dst.layout.stride[3] = -dst.layout.stride[3]; | dst.layout.stride[3] = -dst.layout.stride[3]; | ||||
| args.handle->relayout_opr()->exec(src, dst); | args.handle->relayout_opr()->exec(src, dst); | ||||
| raw_ptr = workspace.raw_ptr; | |||||
| ref_ptr.reset(workspace.raw_ptr); | |||||
| } | } | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
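`flip_filter` used to take `void*& raw_ptr` and overwrite it with the workspace address after writing the flipped filter there; it now takes a `RefPtr&` and calls `ref_ptr.reset(workspace.raw_ptr)`, so a caller that built its `TensorND` from the same reference picks up the new buffer automatically. A condensed sketch of that control flow with simplified stand-in types; only the reset-at-the-end contract is taken from the diff, the rest is assumed.

    #include <cstddef>
    #include <cstdint>

    struct RefPtrSketch {  // simplified stand-in
        void* ptr = nullptr;
        void* get_ptr() const { return ptr; }
        void reset(void* p) { ptr = p; }
    };

    struct WorkspaceSketch {
        std::uint8_t* raw_ptr = nullptr;
        std::size_t size = 0;
    };

    // Assumed essence of the new contract: relayout the filter into the
    // workspace with reversed spatial strides, then repoint the shared ref.
    void flip_filter_sketch(RefPtrSketch& filter_ref, const WorkspaceSketch& ws) {
        // ... copy from filter_ref.get_ptr() into ws.raw_ptr with negated strides ...
        filter_ref.reset(ws.raw_ptr);  // caller's tensor now reads from the workspace
    }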
| @@ -85,7 +85,7 @@ struct CUDNNBwdFilterDescs { | |||||
| * change \p raw_ptr to workspace. | * change \p raw_ptr to workspace. | ||||
| */ | */ | ||||
| void flip_filter( | void flip_filter( | ||||
| const ForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr); | |||||
| const ForwardSizeArgs& args, const Workspace& workspace, RefPtr& raw_ptr); | |||||
| } // namespace convolution | } // namespace convolution | ||||
| } // namespace cuda | } // namespace cuda | ||||
| @@ -55,9 +55,10 @@ void Convolution3DBackwardDataImpl::AlgoCUDNN::exec(const ExecArgs& args) const | |||||
| float alpha = 1.0f, beta = 0.0f; | float alpha = 1.0f, beta = 0.0f; | ||||
| auto status = cudnnConvolutionBackwardData( | auto status = cudnnConvolutionBackwardData( | ||||
| args.handle->cudnn_handle(), &alpha, D.filter_desc.desc, | args.handle->cudnn_handle(), &alpha, D.filter_desc.desc, | ||||
| args.filter_tensor->raw_ptr, D.diff_desc.desc, args.diff_tensor->raw_ptr, | |||||
| D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, | |||||
| &beta, D.grad_desc.desc, args.grad_tensor->raw_ptr); | |||||
| args.filter_tensor->raw_ptr(), D.diff_desc.desc, | |||||
| args.diff_tensor->raw_ptr(), D.conv_desc.desc, m_cudnn_enum, | |||||
| args.workspace.raw_ptr, args.workspace.size, &beta, D.grad_desc.desc, | |||||
| args.grad_tensor->raw_ptr()); | |||||
| megdnn_assert( | megdnn_assert( | ||||
| status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", | status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", | ||||
| cudnnGetErrorString(status), args.to_string().c_str()); | cudnnGetErrorString(status), args.to_string().c_str()); | ||||
| @@ -96,9 +96,9 @@ void Convolution3DBackwardDataImpl::AlgoGroupConvGeneral::exec( | |||||
| auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
| { | { | ||||
| auto config = prepare_sub_opr(args); | auto config = prepare_sub_opr(args); | ||||
| TensorND tfilter{args.filter_tensor->raw_ptr, config.first[0]}; | |||||
| TensorND tdiff{args.diff_tensor->raw_ptr, config.first[1]}; | |||||
| TensorND tgrad{args.grad_tensor->raw_ptr, config.first[2]}; | |||||
| TensorND tfilter{args.filter_tensor->raw_ptr(), config.first[0]}; | |||||
| TensorND tdiff{args.diff_tensor->raw_ptr(), config.first[1]}; | |||||
| TensorND tgrad{args.grad_tensor->raw_ptr(), config.first[2]}; | |||||
| size_t c_pos = 1; | size_t c_pos = 1; | ||||
| auto grp = args.filter_meta.group; | auto grp = args.filter_meta.group; | ||||
| @@ -114,9 +114,9 @@ void Convolution3DBackwardDataImpl::AlgoGroupConvGeneral::exec( | |||||
| for (uint32_t g = 0; g < grp; ++g) { | for (uint32_t g = 0; g < grp; ++g) { | ||||
| config.second->exec(tfilter, tdiff, tgrad, bundle.get_workspace(0)); | config.second->exec(tfilter, tdiff, tgrad, bundle.get_workspace(0)); | ||||
| incr_voidp(tfilter.raw_ptr, strd_flt); | |||||
| incr_voidp(tdiff.raw_ptr, strd_diff); | |||||
| incr_voidp(tgrad.raw_ptr, strd_grad); | |||||
| incr_refp(tfilter.get_ref_ptr(), strd_flt); | |||||
| incr_refp(tdiff.get_ref_ptr(), strd_diff); | |||||
| incr_refp(tgrad.get_ref_ptr(), strd_grad); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -56,9 +56,9 @@ void Convolution3DBackwardFilterImpl::AlgoCUDNN::exec(const ExecArgs& args) cons | |||||
| float alpha = 1.0f, beta = 0.0f; | float alpha = 1.0f, beta = 0.0f; | ||||
| auto status = cudnnConvolutionBackwardFilter( | auto status = cudnnConvolutionBackwardFilter( | ||||
| args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | ||||
| args.src_tensor->raw_ptr, D.diff_desc.desc, args.diff_tensor->raw_ptr, | |||||
| args.src_tensor->raw_ptr(), D.diff_desc.desc, args.diff_tensor->raw_ptr(), | |||||
| D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, | D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, | ||||
| &beta, D.grad_desc.desc, args.grad_tensor->raw_ptr); | |||||
| &beta, D.grad_desc.desc, args.grad_tensor->raw_ptr()); | |||||
| megdnn_assert( | megdnn_assert( | ||||
| status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", | status == CUDNN_STATUS_SUCCESS, "conv bwd_data failed: %s; info: %s", | ||||
| cudnnGetErrorString(status), args.to_string().c_str()); | cudnnGetErrorString(status), args.to_string().c_str()); | ||||
| @@ -98,9 +98,9 @@ void Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral::exec( | |||||
| auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
| { | { | ||||
| auto config = prepare_sub_opr(args); | auto config = prepare_sub_opr(args); | ||||
| TensorND tsrc{args.src_tensor->raw_ptr, config.first[0]}; | |||||
| TensorND tdiff{args.diff_tensor->raw_ptr, config.first[1]}; | |||||
| TensorND tgrad{args.grad_tensor->raw_ptr, config.first[2]}; | |||||
| TensorND tsrc{args.src_tensor->raw_ptr(), config.first[0]}; | |||||
| TensorND tdiff{args.diff_tensor->raw_ptr(), config.first[1]}; | |||||
| TensorND tgrad{args.grad_tensor->raw_ptr(), config.first[2]}; | |||||
| size_t c_pos = 1; | size_t c_pos = 1; | ||||
| auto grp = args.grad_filter_meta.group; | auto grp = args.grad_filter_meta.group; | ||||
| @@ -116,9 +116,9 @@ void Convolution3DBackwardFilterImpl::AlgoGroupConvGeneral::exec( | |||||
| for (uint32_t g = 0; g < grp; ++g) { | for (uint32_t g = 0; g < grp; ++g) { | ||||
| config.second->exec(tsrc, tdiff, tgrad, bundle.get_workspace(0)); | config.second->exec(tsrc, tdiff, tgrad, bundle.get_workspace(0)); | ||||
| incr_voidp(tsrc.raw_ptr, strd_src); | |||||
| incr_voidp(tdiff.raw_ptr, strd_diff); | |||||
| incr_voidp(tgrad.raw_ptr, strd_grad); | |||||
| incr_refp(tsrc.get_ref_ptr(), strd_src); | |||||
| incr_refp(tdiff.get_ref_ptr(), strd_diff); | |||||
| incr_refp(tgrad.get_ref_ptr(), strd_grad); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -54,17 +54,17 @@ size_t Convolution3DForwardImpl::Algo1x1x1::get_workspace_in_bytes( | |||||
| void Convolution3DForwardImpl::Algo1x1x1::exec(const ExecArgs& args) const { | void Convolution3DForwardImpl::Algo1x1x1::exec(const ExecArgs& args) const { | ||||
| TensorND A, B, C; | TensorND A, B, C; | ||||
| extract_matmul_layouts(args, A.layout, B.layout, C.layout); | extract_matmul_layouts(args, A.layout, B.layout, C.layout); | ||||
| A.raw_ptr = args.filter_tensor->raw_ptr; | |||||
| B.raw_ptr = args.src_tensor->raw_ptr; | |||||
| C.raw_ptr = args.dst_tensor->raw_ptr; | |||||
| A.reset_ptr(args.filter_tensor->raw_ptr()); | |||||
| B.reset_ptr(args.src_tensor->raw_ptr()); | |||||
| C.reset_ptr(args.dst_tensor->raw_ptr()); | |||||
| size_t batch = args.src_layout->shape[0]; | size_t batch = args.src_layout->shape[0]; | ||||
| auto mm = args.handle->matmul_opr(); | auto mm = args.handle->matmul_opr(); | ||||
| auto strd_B = args.src_layout->stride[0] * args.src_layout->dtype.size(), | auto strd_B = args.src_layout->stride[0] * args.src_layout->dtype.size(), | ||||
| strd_C = args.dst_layout->stride[0] * args.dst_layout->dtype.size(); | strd_C = args.dst_layout->stride[0] * args.dst_layout->dtype.size(); | ||||
| for (size_t i = 0; i < batch; ++i) { | for (size_t i = 0; i < batch; ++i) { | ||||
| mm->exec(A, B, C, args.workspace); | mm->exec(A, B, C, args.workspace); | ||||
| incr_voidp(B.raw_ptr, strd_B); | |||||
| incr_voidp(C.raw_ptr, strd_C); | |||||
| incr_refp(B.get_ref_ptr(), strd_B); | |||||
| incr_refp(C.get_ref_ptr(), strd_C); | |||||
| } | } | ||||
| } | } | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -53,9 +53,10 @@ void Convolution3DForwardImpl::AlgoCUDNN::exec(const ExecArgs& args) const { | |||||
| float alpha = 1.0f, beta = 0.0f; | float alpha = 1.0f, beta = 0.0f; | ||||
| auto status = cudnnConvolutionForward( | auto status = cudnnConvolutionForward( | ||||
| args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | args.handle->cudnn_handle(), &alpha, D.src_desc.desc, | ||||
| args.src_tensor->raw_ptr, D.filter_desc.desc, args.filter_tensor->raw_ptr, | |||||
| D.conv_desc.desc, m_cudnn_enum, args.workspace.raw_ptr, args.workspace.size, | |||||
| &beta, D.dst_desc.desc, args.dst_tensor->raw_ptr); | |||||
| args.src_tensor->raw_ptr(), D.filter_desc.desc, | |||||
| args.filter_tensor->raw_ptr(), D.conv_desc.desc, m_cudnn_enum, | |||||
| args.workspace.raw_ptr, args.workspace.size, &beta, D.dst_desc.desc, | |||||
| args.dst_tensor->raw_ptr()); | |||||
| megdnn_assert( | megdnn_assert( | ||||
| status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s", | status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s", | ||||
| cudnnGetErrorString(status), args.to_string().c_str()); | cudnnGetErrorString(status), args.to_string().c_str()); | ||||
| @@ -103,9 +103,9 @@ void Convolution3DForwardImpl::AlgoGroupConvGeneral::exec(const ExecArgs& args) | |||||
| auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | ||||
| { | { | ||||
| auto config = prepare_sub_opr(args); | auto config = prepare_sub_opr(args); | ||||
| TensorND tsrc{args.src_tensor->raw_ptr, config.first[0]}; | |||||
| TensorND tfilter{args.filter_tensor->raw_ptr, config.first[1]}; | |||||
| TensorND tdst{args.dst_tensor->raw_ptr, config.first[2]}; | |||||
| TensorND tsrc{args.src_tensor->raw_ptr(), config.first[0]}; | |||||
| TensorND tfilter{args.filter_tensor->raw_ptr(), config.first[1]}; | |||||
| TensorND tdst{args.dst_tensor->raw_ptr(), config.first[2]}; | |||||
| size_t c_pos; | size_t c_pos; | ||||
| if (args.filter_meta.format == Param::Format::NCDHW) { | if (args.filter_meta.format == Param::Format::NCDHW) { | ||||
| @@ -127,9 +127,9 @@ void Convolution3DForwardImpl::AlgoGroupConvGeneral::exec(const ExecArgs& args) | |||||
| for (uint32_t g = 0; g < grp; ++g) { | for (uint32_t g = 0; g < grp; ++g) { | ||||
| config.second->exec(tsrc, tfilter, tdst, bundle.get_workspace(0)); | config.second->exec(tsrc, tfilter, tdst, bundle.get_workspace(0)); | ||||
| incr_voidp(tsrc.raw_ptr, strd_src); | |||||
| incr_voidp(tdst.raw_ptr, strd_dst); | |||||
| incr_voidp(tfilter.raw_ptr, strd_flt); | |||||
| incr_refp(tsrc.get_ref_ptr(), strd_src); | |||||
| incr_refp(tdst.get_ref_ptr(), strd_dst); | |||||
| incr_refp(tfilter.get_ref_ptr(), strd_flt); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -35,20 +35,20 @@ bool convolution3d::is_cudnn_supported(const ForwardSizeArgs& args) { | |||||
| } | } | ||||
| void convolution3d::flip_filter( | void convolution3d::flip_filter( | ||||
| const ForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr) { | |||||
| const ForwardSizeArgs& args, const Workspace& workspace, RefPtr& ref_ptr) { | |||||
| auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
| megdnn_assert(fm.group == 1 && fm.spatial_ndim == 3); | megdnn_assert(fm.group == 1 && fm.spatial_ndim == 3); | ||||
| auto OC = fm.ocpg, IC = fm.icpg, FD = fm.spatial[0], FH = fm.spatial[1], | auto OC = fm.ocpg, IC = fm.icpg, FD = fm.spatial[0], FH = fm.spatial[1], | ||||
| FW = fm.spatial[2]; | FW = fm.spatial[2]; | ||||
| auto dtype = DType::from_enum(fm.dtype_enum); | auto dtype = DType::from_enum(fm.dtype_enum); | ||||
| megdnn_assert(workspace.size >= dtype.size() * OC * IC * FD * FH * FW); | megdnn_assert(workspace.size >= dtype.size() * OC * IC * FD * FH * FW); | ||||
| TensorND src{raw_ptr, {{OC, IC, FD, FH, FW}, dtype}}, | |||||
| TensorND src{{{OC, IC, FD, FH, FW}, dtype}, ref_ptr}, | |||||
| dst{workspace.raw_ptr + (FD * FH * FW - 1) * dtype.size(), src.layout}; | dst{workspace.raw_ptr + (FD * FH * FW - 1) * dtype.size(), src.layout}; | ||||
| dst.layout.stride[2] = -dst.layout.stride[2]; | dst.layout.stride[2] = -dst.layout.stride[2]; | ||||
| dst.layout.stride[3] = -dst.layout.stride[3]; | dst.layout.stride[3] = -dst.layout.stride[3]; | ||||
| dst.layout.stride[4] = -dst.layout.stride[4]; | dst.layout.stride[4] = -dst.layout.stride[4]; | ||||
| args.handle->relayout_opr()->exec(src, dst); | args.handle->relayout_opr()->exec(src, dst); | ||||
| raw_ptr = workspace.raw_ptr; | |||||
| ref_ptr.reset(workspace.raw_ptr); | |||||
| } | } | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -84,7 +84,7 @@ struct CUDNNBwdFilterDescs { | |||||
| * change \p raw_ptr to workspace. | * change \p raw_ptr to workspace. | ||||
| */ | */ | ||||
| void flip_filter( | void flip_filter( | ||||
| const ForwardSizeArgs& args, const Workspace& workspace, void*& raw_ptr); | |||||
| const ForwardSizeArgs& args, const Workspace& workspace, RefPtr& raw_ptr); | |||||
| inline bool cudnn_get_convolution_fwd_algo_helper( | inline bool cudnn_get_convolution_fwd_algo_helper( | ||||
| cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc, | cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc, | ||||
| @@ -169,10 +169,10 @@ void ConvPoolingForwardImpl::exec( | |||||
| nonlineMode = IDENTITY; | nonlineMode = IDENTITY; | ||||
| } | } | ||||
| float *src_ptr = static_cast<float*>(src.raw_ptr), | |||||
| *filter_ptr = static_cast<float*>(filter.raw_ptr), | |||||
| *bias_ptr = static_cast<float*>(bias.raw_ptr), | |||||
| *dst_ptr = static_cast<float*>(dst.raw_ptr); | |||||
| float *src_ptr = static_cast<float*>(src.raw_ptr()), | |||||
| *filter_ptr = static_cast<float*>(filter.raw_ptr()), | |||||
| *bias_ptr = static_cast<float*>(bias.raw_ptr()), | |||||
| *dst_ptr = static_cast<float*>(dst.raw_ptr()); | |||||
| switch (this->param().method) { | switch (this->param().method) { | ||||
| case Param::Method::WITH_SHARED_MEM: | case Param::Method::WITH_SHARED_MEM: | ||||
| @@ -12,7 +12,7 @@ | |||||
| #include "./opr_impl.h" | #include "./opr_impl.h" | ||||
| #include "./kern.cuh" | #include "./kern.cuh" | ||||
| #include "src/common/reduce_helper.h" | |||||
| #include "src/common/reduce_helper_device.h" | |||||
| #include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -58,7 +58,7 @@ void DctChannelSelectForwardImpl::exec( | |||||
| megdnn_assert( | megdnn_assert( | ||||
| param().format == Param::Format::NCHW4, "qint8 only support nchw4"); | param().format == Param::Format::NCHW4, "qint8 only support nchw4"); | ||||
| dct::call_kern_dct<dct_block, dct::DctLayoutFormat::NCHW4>( | dct::call_kern_dct<dct_block, dct::DctLayoutFormat::NCHW4>( | ||||
| src.ptr<uint8_t>(), (int8_t*)dst.raw_ptr, in, ic, ih, iw, oc, | |||||
| src.ptr<uint8_t>(), (int8_t*)dst.raw_ptr(), in, ic, ih, iw, oc, | |||||
| with_fix_32_mask, mask_offset_ptr, mask_val_ptr, stream, error_info, | with_fix_32_mask, mask_offset_ptr, mask_val_ptr, stream, error_info, | ||||
| m_error_tracker, | m_error_tracker, | ||||
| dst.layout.dtype.param<::megdnn::dtype::QuantizedS8>().scale); | dst.layout.dtype.param<::megdnn::dtype::QuantizedS8>().scale); | ||||
| @@ -227,7 +227,7 @@ INST(dt_quint8); | |||||
| template <int ndim> | template <int ndim> | ||||
| void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init( | void ParamElemVisitor4bitBase<ndim, BCAST_OTHER>::host_init( | ||||
| const TensorND& rv, int /*grid_size*/, int /*block_size*/) { | const TensorND& rv, int /*grid_size*/, int /*block_size*/) { | ||||
| m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr); | |||||
| m_ptr = reinterpret_cast<Storage*>(rv.raw_ptr()); | |||||
| ptrdiff_t min_stride = std::numeric_limits<ptrdiff_t>::max(); | ptrdiff_t min_stride = std::numeric_limits<ptrdiff_t>::max(); | ||||
| for (size_t i = 0; i < rv.layout.ndim; ++i) { | for (size_t i = 0; i < rv.layout.ndim; ++i) { | ||||
| m_stride[i] = rv.layout.stride[i]; | m_stride[i] = rv.layout.stride[i]; | ||||
| @@ -21,31 +21,31 @@ using namespace megdnn; | |||||
| using namespace cuda; | using namespace cuda; | ||||
| void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32( | void ElemwiseMultiTypeImpl::on_fuse_mul_add3_int16x32x32x32( | ||||
| const ElemwiseOpParamN<3>& param, dt_int32* dst) { | |||||
| const ElemwiseOpParamN<3>& param, const TensorND& dst) { | |||||
| BroadcastChannelInfo binfo0, binfo1; | BroadcastChannelInfo binfo0, binfo1; | ||||
| if (is_vector(param[0].layout) && | if (is_vector(param[0].layout) && | ||||
| is_broadcasted_channel_like(param[1].layout, binfo0) && | is_broadcasted_channel_like(param[1].layout, binfo0) && | ||||
| is_broadcasted_channel_like(param[2].layout, binfo1) && binfo0 == binfo1) { | is_broadcasted_channel_like(param[2].layout, binfo1) && binfo0 == binfo1) { | ||||
| elemwise_multi_type::fma3_int16x32x32x32_1c1( | elemwise_multi_type::fma3_int16x32x32x32_1c1( | ||||
| param, dst, cuda_stream(this->handle())); | |||||
| param, dst.ptr<dt_int32>(), cuda_stream(this->handle())); | |||||
| return; | return; | ||||
| } | } | ||||
| megdnn_throw("unsupported fma3 int16x32x32x32 layout"); | megdnn_throw("unsupported fma3 int16x32x32x32 layout"); | ||||
| } | } | ||||
| void ElemwiseMultiTypeImpl::on_fuse_mul_add3_iXxf32xf32xi8( | void ElemwiseMultiTypeImpl::on_fuse_mul_add3_iXxf32xf32xi8( | ||||
| const ElemwiseOpParamN<3>& param, dt_int8* dst) { | |||||
| const ElemwiseOpParamN<3>& param, const TensorND& dst) { | |||||
| Broadcast1xInfo binfo0, binfo1; | Broadcast1xInfo binfo0, binfo1; | ||||
| auto p1 = param[1].ptr<float>(), p2 = param[2].ptr<float>(); | auto p1 = param[1].ptr<float>(), p2 = param[2].ptr<float>(); | ||||
| auto stream = cuda_stream(this->handle()); | auto stream = cuda_stream(this->handle()); | ||||
| if (is_vector(param[0].layout) && is_broadcasted_1x(param[1].layout, binfo0) && | if (is_vector(param[0].layout) && is_broadcasted_1x(param[1].layout, binfo0) && | ||||
| is_broadcasted_1x(param[2].layout, binfo1) && binfo0 == binfo1) { | is_broadcasted_1x(param[2].layout, binfo1) && binfo0 == binfo1) { | ||||
| switch (param[0].layout.dtype.enumv()) { | switch (param[0].layout.dtype.enumv()) { | ||||
| #define cb(t) \ | |||||
| case DTypeTrait<t>::enumv: \ | |||||
| elemwise_multi_type::fma3_iXxf32xf32xi8_bcast_1x( \ | |||||
| param[0].ptr<DTypeTrait<t>::ctype>(), p1, p2, dst, binfo0.x, binfo0.y, \ | |||||
| stream); \ | |||||
| #define cb(t) \ | |||||
| case DTypeTrait<t>::enumv: \ | |||||
| elemwise_multi_type::fma3_iXxf32xf32xi8_bcast_1x( \ | |||||
| param[0].ptr<DTypeTrait<t>::ctype>(), p1, p2, dst.ptr<dt_int8>(), \ | |||||
| binfo0.x, binfo0.y, stream); \ | |||||
| return; | return; | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) | ||||
| #undef cb | #undef cb | ||||
| @@ -58,14 +58,14 @@ void ElemwiseMultiTypeImpl::on_fuse_mul_add3_iXxf32xf32xi8( | |||||
| } | } | ||||
| void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8( | void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8( | ||||
| const ElemwiseOpParamN<2>& param, dt_int8* dst) { | |||||
| const ElemwiseOpParamN<2>& param, const TensorND& dst) { | |||||
| auto stream = cuda_stream(this->handle()); | auto stream = cuda_stream(this->handle()); | ||||
| if (is_vector(param[0].layout) && is_broadcasted_scalar(param[1].layout)) { | if (is_vector(param[0].layout) && is_broadcasted_scalar(param[1].layout)) { | ||||
| switch (param[0].layout.dtype.enumv()) { | switch (param[0].layout.dtype.enumv()) { | ||||
| #define DISPATCH(t) \ | |||||
| case DTypeTrait<t>::enumv: \ | |||||
| elemwise_multi_type::round_shr_saturate_iXxi8xiX_scalar< \ | |||||
| DTypeTrait<t>::ctype, dt_int8>(param, dst, stream); \ | |||||
| #define DISPATCH(t) \ | |||||
| case DTypeTrait<t>::enumv: \ | |||||
| elemwise_multi_type::round_shr_saturate_iXxi8xiX_scalar< \ | |||||
| DTypeTrait<t>::ctype, dt_int8>(param, dst.ptr<dt_int8>(), stream); \ | |||||
| return; | return; | ||||
| DISPATCH(::megdnn::dtype::Int32) | DISPATCH(::megdnn::dtype::Int32) | ||||
| DISPATCH(::megdnn::dtype::Int16) | DISPATCH(::megdnn::dtype::Int16) | ||||
| @@ -85,7 +85,7 @@ void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi8( | |||||
| } | } | ||||
| void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | ||||
| const ElemwiseOpParamN<6>& param, dt_int8* dst) { | |||||
| const ElemwiseOpParamN<6>& param, const TensorND& dst) { | |||||
| auto stream = cuda_stream(this->handle()); | auto stream = cuda_stream(this->handle()); | ||||
| BroadcastChannelInfo info; | BroadcastChannelInfo info; | ||||
| if (is_vector(param[0].layout) && | if (is_vector(param[0].layout) && | ||||
| @@ -95,7 +95,7 @@ void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | |||||
| is_broadcasted_scalar(param[4].layout) && | is_broadcasted_scalar(param[4].layout) && | ||||
| is_broadcasted_scalar(param[5].layout)) { | is_broadcasted_scalar(param[5].layout)) { | ||||
| elemwise_multi_type::fuse_add_rmulh_round_shr_saturate_bcast_1c11<dt_int16>( | elemwise_multi_type::fuse_add_rmulh_round_shr_saturate_bcast_1c11<dt_int16>( | ||||
| param, dst, stream); | |||||
| param, dst.ptr<dt_int8>(), stream); | |||||
| return; | return; | ||||
| } | } | ||||
| megdnn_throw( | megdnn_throw( | ||||
| @@ -106,7 +106,7 @@ void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | |||||
| } | } | ||||
| void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | ||||
| const ElemwiseOpParamN<6>& param, dt_int8* dst) { | |||||
| const ElemwiseOpParamN<6>& param, const TensorND& dst) { | |||||
| auto stream = cuda_stream(this->handle()); | auto stream = cuda_stream(this->handle()); | ||||
| BroadcastChannelInfo info; | BroadcastChannelInfo info; | ||||
| if (is_vector(param[0].layout) && | if (is_vector(param[0].layout) && | ||||
| @@ -116,7 +116,7 @@ void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | |||||
| is_broadcasted_scalar(param[4].layout) && | is_broadcasted_scalar(param[4].layout) && | ||||
| is_broadcasted_scalar(param[5].layout)) { | is_broadcasted_scalar(param[5].layout)) { | ||||
| elemwise_multi_type::fuse_add_rmulh_round_shr_saturate_bcast_1c11<dt_int32>( | elemwise_multi_type::fuse_add_rmulh_round_shr_saturate_bcast_1c11<dt_int32>( | ||||
| param, dst, stream); | |||||
| param, dst.ptr<dt_int8>(), stream); | |||||
| return; | return; | ||||
| } | } | ||||
| megdnn_throw( | megdnn_throw( | ||||
| @@ -127,14 +127,14 @@ void ElemwiseMultiTypeImpl::on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | |||||
| } | } | ||||
| void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi16( | void ElemwiseMultiTypeImpl::on_round_shr_saturate_iXxi8xi16( | ||||
| const ElemwiseOpParamN<2>& param, dt_int16* dst) { | |||||
| const ElemwiseOpParamN<2>& param, const TensorND& dst) { | |||||
| auto stream = cuda_stream(this->handle()); | auto stream = cuda_stream(this->handle()); | ||||
| if (is_vector(param[0].layout) && is_broadcasted_scalar(param[1].layout)) { | if (is_vector(param[0].layout) && is_broadcasted_scalar(param[1].layout)) { | ||||
| switch (param[0].layout.dtype.enumv()) { | switch (param[0].layout.dtype.enumv()) { | ||||
| #define DISPATCH(t) \ | |||||
| case DTypeTrait<t>::enumv: \ | |||||
| elemwise_multi_type::round_shr_saturate_iXxi8xiX_scalar< \ | |||||
| DTypeTrait<t>::ctype, dt_int16>(param, dst, stream); \ | |||||
| #define DISPATCH(t) \ | |||||
| case DTypeTrait<t>::enumv: \ | |||||
| elemwise_multi_type::round_shr_saturate_iXxi8xiX_scalar< \ | |||||
| DTypeTrait<t>::ctype, dt_int16>(param, dst.ptr<dt_int16>(), stream); \ | |||||
| return; | return; | ||||
| DISPATCH(::megdnn::dtype::Int32) | DISPATCH(::megdnn::dtype::Int32) | ||||
| DISPATCH(::megdnn::dtype::Int16) | DISPATCH(::megdnn::dtype::Int16) | ||||
| @@ -227,22 +227,22 @@ IMPL_MODE_DISPATCHER(2, dt_quint4, dt_qint32); | |||||
| #undef _cb_dispatch_mode | #undef _cb_dispatch_mode | ||||
| #define _cb_dispatch_mode(_m) \ | |||||
| case param::Elemwise::Mode::_m: \ | |||||
| do { \ | |||||
| using KernImpl = ElemwiseKern< \ | |||||
| megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m, float>; \ | |||||
| using Op = kern_ops_quantized::QuantizedMultiTypeOp< \ | |||||
| arity, src_ctype, dst_ctype, KernImpl>; \ | |||||
| using dst_storage = typename VectTypeTrait<dst_ctype>::Storage; \ | |||||
| dst_storage* dst = reinterpret_cast<dst_storage*>(dst_tensor.raw_ptr); \ | |||||
| Op op(src_params, dst, dst_param); \ | |||||
| ElemwiseOpParamN<1> param_dst; \ | |||||
| param_dst[0] = dst_tensor; \ | |||||
| param_dst.init_from_given_tensor(); \ | |||||
| run_elemwise<Op, src_ctype, dst_ctype, arity>( \ | |||||
| param, param_dst, stream, op); \ | |||||
| return; \ | |||||
| #define _cb_dispatch_mode(_m) \ | |||||
| case param::Elemwise::Mode::_m: \ | |||||
| do { \ | |||||
| using KernImpl = ElemwiseKern< \ | |||||
| megcorePlatformCUDA, param_enumv::Elemwise::Mode::_m, float>; \ | |||||
| using Op = kern_ops_quantized::QuantizedMultiTypeOp< \ | |||||
| arity, src_ctype, dst_ctype, KernImpl>; \ | |||||
| using dst_storage = typename VectTypeTrait<dst_ctype>::Storage; \ | |||||
| dst_storage* dst = reinterpret_cast<dst_storage*>(dst_tensor.raw_ptr()); \ | |||||
| Op op(src_params, dst, dst_param); \ | |||||
| ElemwiseOpParamN<1> param_dst; \ | |||||
| param_dst[0] = dst_tensor; \ | |||||
| param_dst.init_from_given_tensor(); \ | |||||
| run_elemwise<Op, src_ctype, dst_ctype, arity>( \ | |||||
| param, param_dst, stream, op); \ | |||||
| return; \ | |||||
| } while (0); | } while (0); | ||||
| #define FOREACH(cb) \ | #define FOREACH(cb) \ | ||||
| @@ -18,22 +18,22 @@ namespace cuda { | |||||
| class ElemwiseMultiTypeImpl final : public ElemwiseMultiTypeImplHelper { | class ElemwiseMultiTypeImpl final : public ElemwiseMultiTypeImplHelper { | ||||
| void on_fuse_mul_add3_int16x32x32x32( | void on_fuse_mul_add3_int16x32x32x32( | ||||
| const ElemwiseOpParamN<3>& param, dt_int32* dst) override; | |||||
| const ElemwiseOpParamN<3>& param, const TensorND& dst) override; | |||||
| void on_fuse_mul_add3_iXxf32xf32xi8( | void on_fuse_mul_add3_iXxf32xf32xi8( | ||||
| const ElemwiseOpParamN<3>& param, dt_int8* dst) override; | |||||
| const ElemwiseOpParamN<3>& param, const TensorND& dst) override; | |||||
| void on_round_shr_saturate_iXxi8xi8( | void on_round_shr_saturate_iXxi8xi8( | ||||
| const ElemwiseOpParamN<2>& param, dt_int8* dst) override; | |||||
| const ElemwiseOpParamN<2>& param, const TensorND& dst) override; | |||||
| void on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | void on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8( | ||||
| const ElemwiseOpParamN<6>& param, dt_int8* dst) override; | |||||
| const ElemwiseOpParamN<6>& param, const TensorND& dst) override; | |||||
| void on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | void on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8( | ||||
| const ElemwiseOpParamN<6>& param, dt_int8* dst) override; | |||||
| const ElemwiseOpParamN<6>& param, const TensorND& dst) override; | |||||
| void on_round_shr_saturate_iXxi8xi16( | void on_round_shr_saturate_iXxi8xi16( | ||||
| const ElemwiseOpParamN<2>& param, dt_int16* dst) override; | |||||
| const ElemwiseOpParamN<2>& param, const TensorND& dst) override; | |||||
| void on_quantized_mode( | void on_quantized_mode( | ||||
| const ElemwiseOpParamN<1>& param, const TensorND& dst, | const ElemwiseOpParamN<1>& param, const TensorND& dst, | ||||
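These header hunks show the `ElemwiseMultiTypeImpl` hooks switching from typed destination pointers (`dt_int8* dst`, `dt_int32* dst`, ...) to `const TensorND& dst`; the corresponding implementation hunks above then pull the typed pointer out with `dst.ptr<T>()` just before launching the kernel. A hedged before/after sketch of one such hook; the stand-in `TensorNDSketch` and the function names are illustrative only.

    #include <cstdint>

    struct TensorNDSketch {  // simplified stand-in for megdnn::TensorND
        void* raw = nullptr;
        template <typename T>
        T* ptr() const { return static_cast<T*>(raw); }
    };

    // old-style hook: the caller already stripped the tensor down to a raw pointer
    void on_round_shr_saturate_old(std::int8_t* dst) {
        (void)dst;  // launch kernel on dst
    }

    // new-style hook: the tensor travels intact; the typed pointer is taken late
    void on_round_shr_saturate_new(const TensorNDSketch& dst) {
        std::int8_t* out = dst.ptr<std::int8_t>();
        (void)out;  // launch kernel on out
    }

Keeping the `TensorND` intact until the last moment means the dtype, layout, and reference-counted storage all stay available inside the hook instead of being flattened away at the interface.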
@@ -32,11 +32,6 @@ std::unique_ptr<LocalForward> get_opr(Handle* handle, param::Convolution param)
     return std::move(opr);
 }
-template <typename T>
-void incr_ptr(T*& dst, ptrdiff_t delta) {
-    dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
-}
 TensorLayout prepare_src_dst(const TensorLayout& input, size_t g) {
     TensorLayout ret = input;
     megdnn_assert(ret[1] % g == 0);
@@ -84,18 +79,20 @@ void GroupLocalForwardImpl::exec(
             SH, SW, stream);
     } else {
         auto&& opr = get_opr(handle, param());
-        TensorND src_g = {src.raw_ptr, prepare_src_dst(src.layout, G)};
-        TensorND dst_g = {dst.raw_ptr, prepare_src_dst(dst.layout, G)};
-        TensorND filter_g = {filter.raw_ptr, prepare_filter(filter.layout)};
+        TensorND src_g = {src.raw_ptr(), prepare_src_dst(src.layout, G)};
+        TensorND dst_g = {dst.raw_ptr(), prepare_src_dst(dst.layout, G)};
+        TensorND filter_g = {filter.raw_ptr(), prepare_filter(filter.layout)};
         for (size_t g = 0; g < G; ++g) {
             opr->exec(src_g, filter_g, dst_g, workspace);
-            incr_ptr(
-                    src_g.raw_ptr, src_g.layout.stride[1] * src_g.layout.shape[1] *
-                            src_g.layout.dtype.size());
-            incr_ptr(
-                    dst_g.raw_ptr, dst_g.layout.stride[1] * dst_g.layout.shape[1] *
-                            dst_g.layout.dtype.size());
-            incr_ptr(filter_g.raw_ptr, filter_g.layout.span().dist_byte());
+            incr_refp(
+                    src_g.get_ref_ptr(), src_g.layout.stride[1] *
+                                                 src_g.layout.shape[1] *
+                                                 src_g.layout.dtype.size());
+            incr_refp(
+                    dst_g.get_ref_ptr(), dst_g.layout.stride[1] *
+                                                 dst_g.layout.shape[1] *
+                                                 dst_g.layout.dtype.size());
+            incr_refp(filter_g.get_ref_ptr(), filter_g.layout.span().dist_byte());
         }
     }
 }
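In the GroupLocalForwardImpl::exec loop above, the removed incr_ptr helper bumped a cached raw pointer with reinterpret_cast arithmetic; the replacement advances each tensor's reference pointer (incr_refp on get_ref_ptr()) by the same byte deltas. Here is a small self-contained model of that per-group walk; GroupView and advance_bytes are invented stand-ins for the real reference-pointer machinery, kept only to show the offset bookkeeping.

// Stand-in model: each group view keeps a byte offset next to the base storage,
// and the loop advances the offset instead of mutating a typed pointer in place.
#include <cstddef>
#include <cstdio>
#include <vector>

struct GroupView {
    unsigned char* base;    // start of the whole tensor's storage
    std::ptrdiff_t offset;  // byte offset of the current group
    float* data() const { return reinterpret_cast<float*>(base + offset); }
};

void advance_bytes(GroupView& v, std::ptrdiff_t delta) { v.offset += delta; }

int main() {
    constexpr std::size_t G = 3, per_group = 4;  // 3 groups, 4 floats per group
    std::vector<float> src(G * per_group, 1.f), dst(G * per_group, 0.f);
    GroupView src_g{reinterpret_cast<unsigned char*>(src.data()), 0};
    GroupView dst_g{reinterpret_cast<unsigned char*>(dst.data()), 0};
    for (std::size_t g = 0; g < G; ++g) {
        for (std::size_t i = 0; i < per_group; ++i)  // the per-group "operator"
            dst_g.data()[i] = src_g.data()[i] * float(g + 1);
        // byte delta = elements in one group * element size, as in the exec loop
        advance_bytes(src_g, per_group * sizeof(float));
        advance_bytes(dst_g, per_group * sizeof(float));
    }
    std::printf("%g %g %g\n", dst[0], dst[4], dst[8]);  // 1 2 3
    return 0;
}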
@@ -106,7 +106,7 @@ void LocalShareBackwardDataImpl::AlgoBatchedMatMul::exec(const ExecArgs& args) c
     B1.stride[4] = wo;
     B1.stride[5] = 1;
     B1.stride[6] = co * ho * wo;
-    TensorND ts_B1{args.diff_tensor->raw_ptr, B1};
+    TensorND ts_B1{args.diff_tensor->raw_ptr(), B1};
     TensorLayout B2{
             {groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, dtype::Float32()};
     B2.init_contiguous_stride();
@@ -122,7 +122,7 @@ void LocalShareBackwardDataImpl::AlgoBatchedMatMul::exec(const ExecArgs& args) c
     TensorLayout C{
             {groups * sgh * sgw, icpg * fh * fw, ho / sgh * wo / sgw * n},
             dtype::Float32()};
-    TensorND ts_A{args.filter_tensor->raw_ptr, A};
+    TensorND ts_A{args.filter_tensor->raw_ptr(), A};
     TensorND ts_B{ws_pretranspose, B};
     TensorND ts_C{ws_col2im, C};
     Workspace ws_wrapper;
@@ -113,7 +113,7 @@ void LocalShareBackwardFilterImpl::AlgoBatchedMatMul::exec(const ExecArgs& args)
     B1.stride[4] = co * ho * wo;
     B1.stride[5] = wo;
     B1.stride[6] = 1;
-    TensorND ts_B1{args.diff_tensor->raw_ptr, B1};
+    TensorND ts_B1{args.diff_tensor->raw_ptr(), B1};
     TensorLayout B2{
             {groups * sgh * sgw, ocpg, ho / sgh * wo / sgw * n}, dtype::Float32()};
     B2.init_contiguous_stride();
@@ -133,7 +133,7 @@ void LocalShareBackwardFilterImpl::AlgoBatchedMatMul::exec(const ExecArgs& args)
     TensorLayout C{{groups * sgh * sgw, icpg * fh * fw, ocpg}, dtype::Float32()};
     TensorND ts_A{ws_im2col, A};
     TensorND ts_B{ws_pretranspose, B};
-    TensorND ts_C{args.grad_tensor->raw_ptr, C};
+    TensorND ts_C{args.grad_tensor->raw_ptr(), C};
     Workspace ws_wrapper;
     ws_wrapper.raw_ptr = reinterpret_cast<dt_byte*>(ws_matmul);
     ws_wrapper.size = ws.get_size(2);
@@ -100,7 +100,7 @@ void LocalShareForwardImpl::AlgoBatchedMatMul::exec(const ExecArgs& args) const
     TensorLayout C{
             {groups * sgh * sgw, ho / sgh * wo / sgw * n, ocpg}, dtype::Float32()};
     TensorND ts_A{ws_im2col, A};
-    TensorND ts_B{args.filter_tensor->raw_ptr, B};
+    TensorND ts_B{args.filter_tensor->raw_ptr(), B};
     TensorND ts_C{ws_posttranspose, C};
     Workspace ws_wrapper;
     ws_wrapper.raw_ptr = reinterpret_cast<dt_byte*>(ws_matmul);
@@ -119,7 +119,7 @@ void LocalShareForwardImpl::AlgoBatchedMatMul::exec(const ExecArgs& args) const
     C1.stride[6] = ocpg;
     TensorLayout C2 = args.dst_layout;
     TensorND ts_C1{ws_posttranspose, C1};
-    TensorND ts_C2{args.dst_tensor->raw_ptr, C2};
+    TensorND ts_C2{args.dst_tensor->raw_ptr(), C2};
     auto&& relayout_opr = args.opr->handle()->create_operator<Relayout>();
     relayout_opr->exec(ts_C1, ts_C2);
 }
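The local_share hunks above, and the cuDNN/cuBLAS call sites below, are one mechanical migration: TensorND::raw_ptr is no longer read as a data member but invoked as an accessor, raw_ptr(). One plausible reading is that the address is now looked up through a re-bindable storage reference each time it is needed; the sketch below models that idea with hypothetical RefStorage/TensorHandle types, which are not the megdnn classes.

// Stand-in model: the handle exposes its address through a call, so storage can
// be re-bound after the handle is constructed and callers still see the current
// pointer at execution time.
#include <cstdio>
#include <memory>
#include <vector>

struct RefStorage {                                  // models a re-bindable storage slot
    std::shared_ptr<void*> slot = std::make_shared<void*>(nullptr);
    void reset(void* p) { *slot = p; }
};

struct TensorHandle {
    RefStorage storage;
    void* raw_ptr() const { return *storage.slot; }  // resolved on every call
};

int main() {
    std::vector<float> buf_a(4, 1.f), buf_b(4, 2.f);
    TensorHandle t;
    t.storage.reset(buf_a.data());
    std::printf("first bind : %p\n", t.raw_ptr());
    t.storage.reset(buf_b.data());                   // storage re-bound later on
    std::printf("second bind: %p\n", t.raw_ptr());
    return 0;
}

Whatever the exact mechanism inside megdnn, the visible effect on these hunks is only that every call site gained a pair of parentheses.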
@@ -29,7 +29,7 @@ void LRNForwardImpl::exec(
     float alpha = 1.0f, beta = 0.0f;
     cudnn_check(cudnnLRNCrossChannelForward(
             handle, lrn_desc.desc, CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, src_desc.desc,
-            src.raw_ptr, &beta, dst_desc.desc, dst.raw_ptr));
+            src.raw_ptr(), &beta, dst_desc.desc, dst.raw_ptr()));
 }
 void LRNBackwardImpl::setup_descs(
@@ -51,8 +51,8 @@ void LRNBackwardImpl::exec(
     float alpha = 1.0f, beta = 0.0f;
     cudnn_check(cudnnLRNCrossChannelBackward(
             handle, lrn_desc.desc, CUDNN_LRN_CROSS_CHANNEL_DIM1, &alpha, dst_desc.desc,
-            dst.raw_ptr, diff_desc.desc, diff.raw_ptr, src_desc.desc, src.raw_ptr,
-            &beta, grad_desc.desc, grad.raw_ptr));
+            dst.raw_ptr(), diff_desc.desc, diff.raw_ptr(), src_desc.desc, src.raw_ptr(),
+            &beta, grad_desc.desc, grad.raw_ptr()));
 }
 } // namespace cuda
@@ -37,11 +37,11 @@ void MatrixInverseImpl::exec(
     auto stream = handle->stream();
     batched_matrix_mul::arange<uintptr_t>(
             reinterpret_cast<uintptr_t*>(psrc_batch),
-            reinterpret_cast<uintptr_t>(src.raw_ptr), n * n * sizeof(float), batch,
+            reinterpret_cast<uintptr_t>(src.raw_ptr()), n * n * sizeof(float), batch,
             stream);
     batched_matrix_mul::arange<uintptr_t>(
             reinterpret_cast<uintptr_t*>(pdst_batch),
-            reinterpret_cast<uintptr_t>(dst.raw_ptr), n * n * sizeof(float), batch,
+            reinterpret_cast<uintptr_t>(dst.raw_ptr()), n * n * sizeof(float), batch,
             stream);
     cublas_check(cublasSmatinvBatched(
             handle->cublas_handle(), n, psrc_batch, n, pdst_batch, n, info, batch));
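MatrixInverseImpl::exec above uses arange<uintptr_t> to fill psrc_batch and pdst_batch with one device address per n x n matrix before handing both tables to cublasSmatinvBatched. The following host-side sketch reproduces the same addressing arithmetic in plain C++ (the real code builds the table on the GPU; the variable names here are only illustrative).

// Stand-in model: entry i of the pointer table is base + i * n * n * sizeof(float),
// i.e. the start of the i-th matrix in a densely packed batch.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const std::size_t batch = 4, n = 3;
    std::vector<float> src(batch * n * n, 0.f);  // batch of n x n matrices, packed
    std::vector<float*> psrc_batch(batch);
    const auto base = reinterpret_cast<std::uintptr_t>(src.data());
    for (std::size_t i = 0; i < batch; ++i)      // "arange" over addresses
        psrc_batch[i] = reinterpret_cast<float*>(base + i * n * n * sizeof(float));
    for (std::size_t i = 0; i < batch; ++i)
        std::printf("matrix %zu starts %td floats into the buffer\n", i,
                    psrc_batch[i] - src.data());
    return 0;
}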