|
|
|
@@ -144,21 +144,35 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec( |
|
|
|
midout_iv(Mode::_mode), _type_midout_id) { \ |
|
|
|
thin_function<void( \ |
|
|
|
const _type*, const _type*, const _type*, _type*, DType, DType, \ |
|
|
|
DType, DType, size_t, size_t, size_t)> \ |
|
|
|
DType, DType, size_t, size_t, size_t, size_t)> \ |
|
|
|
run = OpCallerTernary< \ |
|
|
|
_op<_type, _type>, BcastType::BCAST101_VEC_BCAST101>::run; \ |
|
|
|
MEGDNN_DISPATCH_CPU_KERN( \ |
|
|
|
static_cast<naive::HandleImpl*>(kern_param.handle), \ |
|
|
|
run(static_cast<const _type*>(src0.raw_ptr()), \ |
|
|
|
static_cast<const _type*>(src1.raw_ptr()), \ |
|
|
|
static_cast<const _type*>(src2.raw_ptr()), \ |
|
|
|
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ |
|
|
|
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ |
|
|
|
binfo.x, binfo.y, binfo.z)); \ |
|
|
|
auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \ |
|
|
|
binfo, dst, run](size_t task_id, size_t) { \ |
|
|
|
size_t offset = task_id * nr_channels_per_thread; \ |
|
|
|
size_t nr_channels_thread = \ |
|
|
|
std::min(nr_channels - offset, nr_channels_per_thread); \ |
|
|
|
run(static_cast<const _type*>(src0.raw_ptr()) + offset, \ |
|
|
|
static_cast<const _type*>(src1.raw_ptr()) + offset * binfo.z, \ |
|
|
|
static_cast<const _type*>(src2.raw_ptr()) + offset, \ |
|
|
|
static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \ |
|
|
|
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ |
|
|
|
dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \ |
|
|
|
binfo.y * binfo.z); \ |
|
|
|
}; \ |
|
|
|
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ |
|
|
|
static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \ |
|
|
|
kernel); \ |
|
|
|
} \ |
|
|
|
MIDOUT_END(); \ |
|
|
|
return |
|
|
|
|
|
|
|
size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle) |
|
|
|
->megcore_dispatcher() |
|
|
|
->nr_threads(); |
|
|
|
|
|
|
|
size_t nr_channels = binfo.y; |
|
|
|
size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads; |
|
|
|
auto&& dst = *(kern_param.m_dst); |
|
|
|
DISPATCH_TYPE("AlgoTernaryFma3Bcast101VecBcast101::exec"_hash); |
|
|
|
#undef DISPATCH_TERNARY |
|
|
|
@@ -181,23 +195,39 @@ void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec( |
|
|
|
midout_iv(Mode::_mode), _type_midout_id) { \ |
|
|
|
thin_function<void( \ |
|
|
|
const _type*, const _type*, size_t, const _type*, _type*, DType, \ |
|
|
|
DType, DType, DType, size_t, size_t, size_t)> \ |
|
|
|
DType, DType, DType, size_t, size_t, size_t, size_t)> \ |
|
|
|
run = OpCallerTernary< \ |
|
|
|
_op<_type, _type>, \ |
|
|
|
BcastType::BCAST111C_VEC_BCAST111C>::run; \ |
|
|
|
MEGDNN_DISPATCH_CPU_KERN( \ |
|
|
|
static_cast<naive::HandleImpl*>(kern_param.handle), \ |
|
|
|
run(static_cast<const _type*>(src0.raw_ptr()), \ |
|
|
|
static_cast<const _type*>(src1.raw_ptr()), \ |
|
|
|
is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z, \ |
|
|
|
static_cast<const _type*>(src2.raw_ptr()), \ |
|
|
|
static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype, \ |
|
|
|
src1.layout.dtype, src2.layout.dtype, dst.layout.dtype, \ |
|
|
|
binfo.x, binfo.y, binfo.z)); \ |
|
|
|
auto kernel = [nr_channels, nr_channels_per_thread, src0, src1, src2, \ |
|
|
|
binfo, dst, run](size_t task_id, size_t) { \ |
|
|
|
size_t offset = task_id * nr_channels_per_thread; \ |
|
|
|
size_t nr_channels_thread = \ |
|
|
|
std::min(nr_channels - offset, nr_channels_per_thread); \ |
|
|
|
size_t src1_offset = \ |
|
|
|
is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z; \ |
|
|
|
run(static_cast<const _type*>(src0.raw_ptr()), \ |
|
|
|
static_cast<const _type*>(src1.raw_ptr()) + \ |
|
|
|
offset * (binfo.z + src1_offset), \ |
|
|
|
src1_offset, static_cast<const _type*>(src2.raw_ptr()), \ |
|
|
|
static_cast<_type*>(dst.raw_ptr()) + offset * binfo.z, \ |
|
|
|
src0.layout.dtype, src1.layout.dtype, src2.layout.dtype, \ |
|
|
|
dst.layout.dtype, binfo.x, nr_channels_thread, binfo.z, \ |
|
|
|
binfo.y * binfo.z); \ |
|
|
|
}; \ |
|
|
|
MEGDNN_DISPATCH_MULTI_THREAD_CPU_KERN( \ |
|
|
|
static_cast<naive::HandleImpl*>(kern_param.handle), nr_threads, \ |
|
|
|
kernel); \ |
|
|
|
} \ |
|
|
|
MIDOUT_END(); \ |
|
|
|
return |
|
|
|
|
|
|
|
size_t nr_threads = static_cast<naive::HandleImpl*>(kern_param.handle) |
|
|
|
->megcore_dispatcher() |
|
|
|
->nr_threads(); |
|
|
|
|
|
|
|
size_t nr_channels = binfo.y; |
|
|
|
size_t nr_channels_per_thread = (nr_channels + nr_threads - 1) / nr_threads; |
|
|
|
auto&& dst = *(kern_param.m_dst); |
|
|
|
DISPATCH_TYPE("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash); |
|
|
|
#undef DISPATCH_TERNARY |
|
|
|
|