GitOrigin-RevId: a802a14e8d
tags/v0.4.0
| @@ -15,18 +15,16 @@ | |||
| using namespace megdnn; | |||
| void ParamPackConcatSplitBase::check_exec(const TensorLayout& concated, | |||
| const TensorLayout& table, | |||
| const TensorLayout& offsets, | |||
| const TensorLayout& parts) { | |||
| megdnn_assert(table.dtype == dtype::Int32{}, "bad dtype: %s", | |||
| table.dtype.name()); | |||
| megdnn_assert(concated.ndim == 1 && table.ndim == 1 && parts.ndim == 1 && | |||
| concated.stride[0] == 1 && table.stride[0] == 1 && | |||
| megdnn_assert(offsets.dtype == dtype::Int32{}, "bad dtype: %s", | |||
| offsets.dtype.name()); | |||
| megdnn_assert(concated.ndim == 1 && offsets.ndim == 1 && parts.ndim == 1 && | |||
| concated.stride[0] == 1 && offsets.stride[0] == 1 && | |||
| parts.stride[0] == 1, | |||
| "bad layout: concated=%s table=%s parts=%s", | |||
| concated.to_string().c_str(), table.to_string().c_str(), | |||
| "bad layout: concated=%s offsets=%s parts=%s", | |||
| concated.to_string().c_str(), offsets.to_string().c_str(), | |||
| parts.to_string().c_str()); | |||
| megdnn_assert(table.shape[0] == concated.shape[0] * 2, | |||
| "concated=%zu table=%zu", concated.shape[0], table.shape[0]); | |||
| } | |||
| std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets( | |||
| @@ -46,11 +44,13 @@ std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets( | |||
| return v + ((alignment - mod) & (alignment - 1)); | |||
| }; | |||
| std::vector<dt_int32> offsets(shapes.size()); | |||
| std::vector<dt_int32> offsets(shapes.size() << 1); | |||
| size_t offset = 0; | |||
| for (size_t i = 0; i < shapes.size(); i++) { | |||
| offsets[i] = offset; | |||
| offset = get_aligned(offset) + shapes[i].total_nr_elems(); | |||
| offset = get_aligned(offset); | |||
| offsets[i * 2] = offset; | |||
| offset += shapes[i].total_nr_elems(); | |||
| offsets[i * 2 + 1] = offset; | |||
| } | |||
| return offsets; | |||
| } | |||
| @@ -24,7 +24,7 @@ size_t ParamPackConcatImpl::get_workspace_in_bytes(const TensorShapeArray& srcs, | |||
| template <typename T> | |||
| void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, | |||
| _megdnn_tensor_in table, | |||
| _megdnn_tensor_in offsets, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) { | |||
| size_t inp_size = srcs.layout.shape[0], | |||
| @@ -35,25 +35,25 @@ void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, | |||
| megdnn_assert_internal(src_cpu); | |||
| auto src_gpu = reinterpret_cast<const T**>(workspace.raw_ptr); | |||
| auto table_outer_gpu = table.ptr<int32_t>(), | |||
| table_inner_gpu = table_outer_gpu + out_size; | |||
| auto offsets_gpu = offsets.ptr<int32_t>(); | |||
| cuda_check(cudaMemcpyAsync(src_gpu, src_cpu, sizeof(const T*) * inp_size, | |||
| cudaMemcpyHostToDevice, stream)); | |||
| param_pack::concat_proxy<T>(src_gpu, dst.ptr<T>(), out_size, | |||
| table_outer_gpu, table_inner_gpu, stream); | |||
| param_pack::concat_proxy<T>(src_gpu, dst.ptr<T>(), inp_size, out_size, | |||
| offsets_gpu, stream); | |||
| } | |||
| void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table, | |||
| void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, | |||
| _megdnn_tensor_in offsets, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(dst.layout, table.layout, srcs.layout); | |||
| #define cb(DType) \ | |||
| if (dst.layout.dtype == DType()) { \ | |||
| using ctype = typename DTypeTrait<DType>::ctype; \ | |||
| exec_internal<ctype>(srcs, table, dst, workspace); \ | |||
| return; \ | |||
| check_exec(dst.layout, offsets.layout, srcs.layout); | |||
| #define cb(DType) \ | |||
| if (dst.layout.dtype == DType()) { \ | |||
| using ctype = typename DTypeTrait<DType>::ctype; \ | |||
| exec_internal<ctype>(srcs, offsets, dst, workspace); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
| megdnn_throw("bad type"); | |||
| @@ -19,17 +19,24 @@ namespace param_pack { | |||
| template <typename T> | |||
| __global__ void concat_kernel(const T** srcs, T* dst, | |||
| const int32_t* table_outer, | |||
| const int32_t* table_inner, | |||
| const int32_t* offsets, | |||
| size_t srcs_size, | |||
| size_t total_size) { | |||
| size_t addr = threadIdx.x + blockIdx.x * blockDim.x; | |||
| if (addr < total_size) { | |||
| int32_t i = table_outer[addr]; | |||
| int32_t idx = table_inner[addr]; | |||
| if (idx != -1) | |||
| dst[addr] = srcs[i][idx]; | |||
| else | |||
| size_t l = 0, r = srcs_size - 1, mid; | |||
| while (l < r) { | |||
| mid = (l + r) >> 1; | |||
| if (offsets[(mid << 1) + 1] > addr) { | |||
| r = mid; | |||
| } else { | |||
| l = mid + 1; | |||
| } | |||
| } | |||
| if (addr < offsets[l << 1]) | |||
| dst[addr] = 0; | |||
| else | |||
| dst[addr] = srcs[l][addr - offsets[l << 1]]; | |||
| } | |||
| } | |||
| @@ -59,20 +66,20 @@ void split_proxy(const T* src, T** dsts, size_t total_size, | |||
| } | |||
| template <typename T> | |||
| void concat_proxy(const T** srcs, T* dst, size_t total_size, | |||
| const int32_t* table_outer, | |||
| const int32_t* table_inner, cudaStream_t stream) { | |||
| void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size, | |||
| const int32_t* offsets, | |||
| cudaStream_t stream) { | |||
| size_t NR_BLOCKS = DIVUP(total_size, NR_THREADS); | |||
| concat_kernel<<<NR_BLOCKS, NR_THREADS, 0, stream>>>( | |||
| srcs, dst, table_outer, table_inner, total_size); | |||
| srcs, dst, offsets, srcs_size, total_size); | |||
| after_kernel_launch(); | |||
| } | |||
| #define INST(T) \ | |||
| template void concat_proxy<T>(const T**, T*, size_t, \ | |||
| const int32_t*, const int32_t*, \ | |||
| template void concat_proxy<T>(const T**, T*, size_t, size_t, \ | |||
| const int32_t*, \ | |||
| cudaStream_t); \ | |||
| template void split_proxy<T>(const T*, T**, size_t, \ | |||
| template void split_proxy<T>(const T*, T**, size_t, \ | |||
| const int32_t*, const int32_t*, \ | |||
| cudaStream_t); | |||
| #define cb(DType) INST(typename DTypeTrait<DType>::ctype) | |||
| @@ -25,9 +25,8 @@ void split_proxy(const T* src, T** dsts, size_t total_size, | |||
| cudaStream_t stream); | |||
| template <typename T> | |||
| void concat_proxy(const T** srcs, T* dst, size_t total_size, | |||
| const int32_t* table_outer, | |||
| const int32_t* table_inner, cudaStream_t stream); | |||
| void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size, | |||
| const int32_t* offsets, cudaStream_t stream); | |||
| } // namespace param_pack | |||
| } // namespace cuda | |||
| @@ -54,38 +54,40 @@ void ParamPackSplitImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in table, | |||
| } | |||
| template <typename T> | |||
| void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, int32_t* table, | |||
| void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, | |||
| int32_t* offsets, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace) { | |||
| size_t out_size = dst.layout.total_nr_elems(); | |||
| auto srcs_ptr = static_cast<const T**>(srcs.raw_ptr); | |||
| auto dst_ptr = dst.ptr<T>(); | |||
| auto table_outer = table, table_inner = table_outer + out_size; | |||
| for (size_t j = 0; j < out_size; j++) { | |||
| int32_t i = table_outer[j]; | |||
| int32_t idx = table_inner[j]; | |||
| if (idx != -1) | |||
| dst_ptr[j] = srcs_ptr[i][idx]; | |||
| else | |||
| dst_ptr[j] = 0; | |||
| int32_t last_pos = 0; | |||
| for (size_t i = 0; i < srcs.layout[0]; i++) { | |||
| int32_t begin = offsets[i * 2], end = offsets[i * 2 + 1]; | |||
| while (last_pos < begin) { | |||
| dst_ptr[last_pos] = 0; | |||
| last_pos++; | |||
| } | |||
| for (int32_t j = 0; j < end - begin; j++) { | |||
| dst_ptr[begin + j] = srcs_ptr[i][j]; | |||
| } | |||
| last_pos = end; | |||
| } | |||
| } | |||
| void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table, | |||
| void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, | |||
| _megdnn_tensor_in offsets, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(dst.layout, table.layout, srcs.layout); | |||
| auto table_ptr = table.ptr<int32_t>(); | |||
| check_exec(dst.layout, offsets.layout, srcs.layout); | |||
| auto offsets_ptr = offsets.ptr<int32_t>(); | |||
| #define cb(DType) \ | |||
| if (dst.layout.dtype == DType()) { \ | |||
| using ctype = typename DTypeTrait<DType>::ctype; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| exec_internal<ctype>(srcs, table_ptr, dst, workspace)); \ | |||
| return; \ | |||
| #define cb(DType) \ | |||
| if (dst.layout.dtype == DType()) { \ | |||
| using ctype = typename DTypeTrait<DType>::ctype; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| exec_internal<ctype>(srcs, offsets_ptr, dst, workspace)); \ | |||
| return; \ | |||
| } | |||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||
| megdnn_throw("bad type"); | |||
| @@ -1339,8 +1339,10 @@ void Concat::init_output_comp_node() { | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackConcat); | |||
| ParamPackConcat::ParamPackConcat(VarNodeArray& inp, VarNode* table, | |||
| const std::vector<dt_int32> offsets_val, | |||
| const OperatorNodeConfig& config) | |||
| : Super(inp[0]->owner_graph(), config, "ParamPackConcat", inp) { | |||
| : Super(inp[0]->owner_graph(), config, "ParamPackConcat", inp), | |||
| m_offsets(offsets_val) { | |||
| CompNode cn = inp[0]->comp_node(); | |||
| add_input({inp[0]}); | |||
| for (size_t i = 1; i < inp.size(); i++) { | |||
| @@ -1361,14 +1363,16 @@ void ParamPackConcat::add_input_layout_constraint(){ | |||
| } | |||
| } | |||
| SymbolVar ParamPackConcat::make(const SmallVector<SymbolVar> &inp, | |||
| const SymbolVar &table, const OperatorNodeConfig& config) { | |||
| SymbolVar ParamPackConcat::make(const SmallVector<SymbolVar>& inp, | |||
| const SymbolVar& offsets, | |||
| const std::vector<dt_int32> offsets_val, | |||
| const OperatorNodeConfig& config) { | |||
| VarNodeArray array(inp.size()); | |||
| for (size_t i = 0; i < inp.size(); i++) { | |||
| array[i] = inp[i].node(); | |||
| } | |||
| return inp.front(). | |||
| insert_single_output_opr<ParamPackConcat>(array, table.node(), config); | |||
| return inp.front().insert_single_output_opr<ParamPackConcat>( | |||
| array, offsets.node(), offsets_val, config); | |||
| } | |||
| void ParamPackConcat::scn_do_execute() { | |||
| @@ -1379,13 +1383,13 @@ void ParamPackConcat::scn_do_execute() { | |||
| for (size_t i = 0; i < inputs.size() - 1; i++) { | |||
| ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr; | |||
| } | |||
| auto table = inputs.back()->dev_tensor().as_megdnn(); | |||
| auto offsets = inputs.back()->dev_tensor().as_megdnn(); | |||
| megdnn::TensorND srcs( | |||
| ptr, megdnn::TensorLayout({inputs.size() - 1}, dtype::Int32())); | |||
| auto&& dst = output(0)->dev_tensor().as_megdnn(); | |||
| m_opr->exec(srcs, table, dst, get_megdnn_workspace_from_var(output(1))); | |||
| m_opr->exec(srcs, offsets, dst, get_megdnn_workspace_from_var(output(1))); | |||
| } | |||
| void ParamPackConcat::init_output_dtype() { | |||
| @@ -1396,8 +1400,8 @@ void ParamPackConcat::init_output_static_infer_desc(){ | |||
| using namespace cg::static_infer; | |||
| auto &&mgr = owner_graph()->static_infer_manager(); | |||
| auto infer_out = [](TensorShape &dest, const InpVal &inp) { | |||
| dest = {inp.val.back().shape().total_nr_elems()/2}; | |||
| auto infer_out = [this](TensorShape &dest, const InpVal &inp) { | |||
| dest = {m_offsets.back()}; | |||
| return true; | |||
| }; | |||
| DepVal shp_deps; | |||
| @@ -1480,10 +1484,10 @@ void ParamPackSplit::init_output_dtype() { | |||
| } | |||
| void ParamPackSplit::mem_plan_fwd_in2out_readonly() { | |||
| mgb_assert(m_offsets.size() == output().size()); | |||
| mgb_assert(m_offsets.size() == output().size() * 2); | |||
| for (size_t i = 0; i < output().size(); i++) { | |||
| auto layout = output(i)->layout(); | |||
| auto spec = SubTensorSpec::make_from_offset_elem(layout, m_offsets[i]); | |||
| auto spec = SubTensorSpec::make_from_offset_elem(layout, m_offsets[i * 2]); | |||
| m_mem_fwd_success[i] = output(i)->set_fwd_in2out_readonly( | |||
| input(0), spec); | |||
| mgb_assert(m_mem_fwd_success[i]); | |||
| @@ -1524,7 +1528,7 @@ MGB_IMPL_OPR_GRAD(ParamPackSplit) { | |||
| } | |||
| return ParamPackConcat::make( | |||
| grad, opr.input(1), | |||
| grad, opr.input(1), opr.get_offsets(), | |||
| OperatorNodeConfig{}.follow_comp_node(opr.input(0))) | |||
| .node(); | |||
| } | |||
| @@ -31,31 +31,6 @@ namespace serialization { | |||
| struct OprMaker<opr::GetVarShape, 0>: | |||
| public OprMakerVariadic<opr::GetVarShape>{}; | |||
| template<> | |||
| struct OprLoadDumpImpl<opr::ParamPackConcat, 0> | |||
| { | |||
| using ParamPackConcat = opr::ParamPackConcat; | |||
| using Param = opr::ParamPackConcat::Param; | |||
| static void dump(OprDumpContext &ctx, | |||
| const cg::OperatorNodeBase &opr_) { | |||
| auto &&opr = opr_.cast_final_safe<ParamPackConcat>(); | |||
| ctx.write_param<Param>(opr.param()); | |||
| } | |||
| static cg::OperatorNodeBase* load( | |||
| OprLoadContext &ctx, const cg::VarNodeArray &inputs, | |||
| const OperatorNodeConfig &config) { | |||
| auto param = ctx.read_param<Param>(); | |||
| mgb_assert(!inputs.empty()); | |||
| SymbolVarArray ivar{inputs.size() - 1}; | |||
| for (size_t i = 0; i < inputs.size() - 1; ++ i) | |||
| ivar[i] = inputs[i]; | |||
| return ParamPackConcat::make(ivar, inputs.back(), | |||
| param, config).node()->owner_opr(); | |||
| } | |||
| }; | |||
| template<> | |||
| struct OprLoadDumpImpl<opr::Split, 0> { | |||
| using Split = opr::Split; | |||
| @@ -151,7 +126,6 @@ namespace opr { | |||
| MGB_SEREG_OPR(Dimshuffle, 1); | |||
| MGB_SEREG_OPR(AxisAddRemove, 1); | |||
| MGB_SEREG_OPR(Concat, 0); | |||
| MGB_SEREG_OPR(ParamPackConcat, 0); | |||
| using GetVarShapeV1 = opr::GetVarShape; | |||
| MGB_SEREG_OPR(GetVarShapeV1, 0); | |||
| using ReshapeV1 = opr::Reshape; | |||
| @@ -193,6 +167,22 @@ namespace opr { | |||
| } | |||
| MGB_REG_OPR_SHALLOW_COPY(ParamPackSplit, opr_shallow_copy_param_pack_split); | |||
| cg::OperatorNodeBase* opr_shallow_copy_param_pack_concat( | |||
| const serialization::OprShallowCopyContext &ctx, | |||
| const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, | |||
| const OperatorNodeConfig &config){ | |||
| auto &&opr = opr_.cast_final_safe<ParamPackConcat>(); | |||
| auto &&offsets = opr.get_offsets(); | |||
| SymbolVarArray ivar{inputs.size() - 1}; | |||
| for (size_t i = 0; i < inputs.size() - 1; ++i) | |||
| ivar[i] = inputs[i]; | |||
| return ParamPackConcat::make(ivar, inputs.back(), offsets, config). | |||
| node()->owner_opr(); | |||
| } | |||
| MGB_REG_OPR_SHALLOW_COPY(ParamPackConcat, opr_shallow_copy_param_pack_concat); | |||
| MGB_SEREG_OPR(RelayoutFormat, 1); | |||
| MGB_SEREG_OPR(WinogradFilterPreprocess, 1); | |||
| } // namespace opr | |||
| @@ -539,6 +539,7 @@ MGB_DEFINE_OPR_CLASS(Concat, cg::SingleCNOutshapePureByInshapeOprBase) // { | |||
| MGB_DEFINE_OPR_CLASS(ParamPackConcat, cg::SingleCNOperatorNodeBase) // { | |||
| //! input pointer buffer | |||
| SmallVector<void*> m_inp_ptr; | |||
| std::vector<dt_int32> m_offsets; | |||
| intl::UniqPtrWithCN<megdnn::ParamPackConcat> m_opr; | |||
| void add_input_layout_constraint() override; | |||
| @@ -554,15 +555,23 @@ public: | |||
| return {}; | |||
| } | |||
| ParamPackConcat(VarNodeArray &inp, VarNode *table, | |||
| const OperatorNodeConfig &config); | |||
| static SymbolVar make(const SmallVector<SymbolVar> &inp, | |||
| const SymbolVar &table, const OperatorNodeConfig &config = {}); | |||
| ParamPackConcat(VarNodeArray& inp, VarNode* offsets, | |||
| const std::vector<dt_int32> offsets_val, | |||
| const OperatorNodeConfig& config); | |||
| static SymbolVar make(const SmallVector<SymbolVar>& inp, | |||
| const SymbolVar& offsets, | |||
| const std::vector<dt_int32> offsets_val, | |||
| const OperatorNodeConfig& config = {}); | |||
| static SymbolVar make(const SmallVector<SymbolVar>& inp, | |||
| const SymbolVar& offsets, | |||
| const std::vector<dt_int32> offsets_val, const Param&, | |||
| const OperatorNodeConfig& config) { | |||
| return make(inp, offsets, offsets_val, config); | |||
| } | |||
| static SymbolVar make(const SmallVector<SymbolVar> &inp, | |||
| const SymbolVar &table, const Param &, | |||
| const OperatorNodeConfig &config) { | |||
| return make(inp, table, config); | |||
| const std::vector<dt_int32>& get_offsets() const { | |||
| return m_offsets; | |||
| } | |||
| }; | |||
| @@ -1906,7 +1906,7 @@ void test_param_pack_concat(const TensorShapeArray &shapes, DType type){ | |||
| memcpy(host_table->raw_ptr(), host_table_gen.data(), size * 8); | |||
| auto table = opr::Host2DeviceCopy::make(*graph, host_table); | |||
| auto z = opr::ParamPackConcat::make(srcs, table); | |||
| auto z = opr::ParamPackConcat::make(srcs, table, host_table_gen); | |||
| HostTensorND host_z; | |||
| auto func = graph->compile({make_callback_copy(z, host_z)}); | |||