GitOrigin-RevId: bc56f09037
tags/v0.4.0
| @@ -469,22 +469,23 @@ using Split = SplitForward; | |||||
| * large number of inputs and can handle alignment requirements. Axis is also | * large number of inputs and can handle alignment requirements. Axis is also | ||||
| * not supported. | * not supported. | ||||
| * | * | ||||
| * The table can be generated by gen_table(). The \p srcs in ParamPackSplit and | |||||
| * The offsets can be generated by gen_offsets(). The \p srcs in ParamPackSplit and | |||||
| * \p dsts in ParamPackConcat must be on CPU, and must remain valid until the | * \p dsts in ParamPackConcat must be on CPU, and must remain valid until the | ||||
| * execution stream is synchronized. | * execution stream is synchronized. | ||||
| */ | */ | ||||
| class ParamPackConcatSplitBase : public OperatorBase { | class ParamPackConcatSplitBase : public OperatorBase { | ||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& concated, const TensorLayout& table, | |||||
| void check_exec(const TensorLayout& concated, const TensorLayout& offsets, | |||||
| const TensorLayout& parts); | const TensorLayout& parts); | ||||
| public: | public: | ||||
| using Param = megdnn::param::Empty; | using Param = megdnn::param::Empty; | ||||
| ParamPackConcatSplitBase(Handle* handle) : OperatorBase(handle) {} | ParamPackConcatSplitBase(Handle* handle) : OperatorBase(handle) {} | ||||
| //! generate table to be used with ParamPackConcat and ParamPackSplit | |||||
| static std::vector<dt_int32> gen_table(const TensorShapeArray& shapes, | |||||
| size_t alignment, size_t dtype_size); | |||||
| //! generate offsets to be used with ParamPackConcat and ParamPackSplit | |||||
| static std::vector<dt_int32> gen_offsets(const TensorShapeArray& shapes, | |||||
| size_t alignment, | |||||
| size_t dtype_size); | |||||
| }; | }; | ||||
| /** | /** | ||||
| @@ -29,7 +29,7 @@ void ParamPackConcatSplitBase::check_exec(const TensorLayout& concated, | |||||
| "concated=%zu table=%zu", concated.shape[0], table.shape[0]); | "concated=%zu table=%zu", concated.shape[0], table.shape[0]); | ||||
| } | } | ||||
| std::vector<dt_int32> ParamPackConcatSplitBase::gen_table( | |||||
| std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets( | |||||
| const TensorShapeArray& shapes, size_t alignment, size_t dtype_size) { | const TensorShapeArray& shapes, size_t alignment, size_t dtype_size) { | ||||
| megdnn_assert(alignment && (alignment & (alignment - 1)) == 0, | megdnn_assert(alignment && (alignment & (alignment - 1)) == 0, | ||||
| "alignment must be power of 2: %zu", alignment); | "alignment must be power of 2: %zu", alignment); | ||||
| @@ -46,30 +46,13 @@ std::vector<dt_int32> ParamPackConcatSplitBase::gen_table( | |||||
| return v + ((alignment - mod) & (alignment - 1)); | return v + ((alignment - mod) & (alignment - 1)); | ||||
| }; | }; | ||||
| std::vector<dt_int32> offsets(shapes.size()); | |||||
| size_t offset = 0; | size_t offset = 0; | ||||
| for (auto&& i : shapes) { | |||||
| offset = get_aligned(offset) + i.total_nr_elems(); | |||||
| for (size_t i = 0; i < shapes.size(); i++) { | |||||
| offsets[i] = offset; | |||||
| offset = get_aligned(offset) + shapes[i].total_nr_elems(); | |||||
| } | } | ||||
| std::vector<dt_int32> table(offset * 2); | |||||
| auto outer_table = table.data(), inner_table = outer_table + offset; | |||||
| offset = 0; | |||||
| for (size_t i = 0; i < shapes.size(); ++i) { | |||||
| auto aligned = get_aligned(offset); | |||||
| for (size_t j = offset; j < aligned; ++j) { | |||||
| inner_table[j] = outer_table[j] = -1; | |||||
| } | |||||
| offset = aligned; | |||||
| auto cur_size = shapes[i].total_nr_elems(); | |||||
| for (size_t j = 0; j < cur_size; ++j) { | |||||
| outer_table[offset + j] = i; | |||||
| inner_table[offset + j] = j; | |||||
| } | |||||
| offset += cur_size; | |||||
| } | |||||
| megdnn_assert(offset * 2 == table.size()); | |||||
| return table; | |||||
| return offsets; | |||||
| } | } | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -112,8 +112,8 @@ void test_param_pack_split(Handle* handle, const TensorShapeArray& shapes, | |||||
| std::vector<int32_t> table = | std::vector<int32_t> table = | ||||
| create_table<T>(shapes, handle->alignment_requirement()); | create_table<T>(shapes, handle->alignment_requirement()); | ||||
| ASSERT_EQ(table, | ASSERT_EQ(table, | ||||
| ParamPackSplit::gen_table(shapes, handle->alignment_requirement(), | |||||
| sizeof(T))); | |||||
| ParamPackSplit::gen_offsets( | |||||
| shapes, handle->alignment_requirement(), sizeof(T))); | |||||
| size_t pack_size = table.size() / 2; | size_t pack_size = table.size() / 2; | ||||
| int32_t* table_gpu = create_device_data<int32_t>(handle, table.data(), | int32_t* table_gpu = create_device_data<int32_t>(handle, table.data(), | ||||
| table.size()); | table.size()); | ||||
| @@ -47,19 +47,19 @@ SymbolVarArray _Opr::param_pack_split( | |||||
| shapearr[i] = npy::vec2shape(shapes[i]); | shapearr[i] = npy::vec2shape(shapes[i]); | ||||
| } | } | ||||
| auto cn = src.node()->comp_node(); | |||||
| auto table_val = megdnn::ParamPackSplit::gen_offsets( | |||||
| shapearr, cn.get_mem_addr_alignment(), src.dtype().size()); | |||||
| if (!table.node()) { | if (!table.node()) { | ||||
| auto cn = src.node()->comp_node(); | |||||
| if (config.has_comp_node_set()) { | if (config.has_comp_node_set()) { | ||||
| cn = config.get_single_comp_node(); | cn = config.get_single_comp_node(); | ||||
| } | } | ||||
| auto table_val = megdnn::ParamPackSplit::gen_table( | |||||
| shapearr, cn.get_mem_addr_alignment(), src.dtype().size()); | |||||
| HostTensorND hv{cn, TensorShape{table_val.size()}, dtype::Int32{}}; | |||||
| HostTensorND hv{cn, TensorShape{{table_val.size()}}, dtype::Int32{}}; | |||||
| memcpy(hv.raw_ptr(), table_val.data(), table_val.size() * sizeof(int)); | memcpy(hv.raw_ptr(), table_val.data(), table_val.size() * sizeof(int)); | ||||
| table = opr::ImmutableTensor::make(*src.node()->owner_graph(), hv); | table = opr::ImmutableTensor::make(*src.node()->owner_graph(), hv); | ||||
| } | } | ||||
| return mgb::opr::ParamPackSplit::make(src, table, shapearr, config); | |||||
| return mgb::opr::ParamPackSplit::make(src, table, table_val, shapearr, config); | |||||
| } | } | ||||
| #if MGB_ENABLE_OPR_MM | #if MGB_ENABLE_OPR_MM | ||||
| @@ -1430,20 +1430,22 @@ void ParamPackConcat::on_output_comp_node_stream_changed(){ | |||||
| /* f{{{ ======================= ParamPackSplit ======================= */ | /* f{{{ ======================= ParamPackSplit ======================= */ | ||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackSplit); | MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackSplit); | ||||
| ParamPackSplit::ParamPackSplit(VarNode* src, VarNode* table, | |||||
| TensorShapeArray& shapes, const OperatorNodeConfig& config) | |||||
| : Super{src->owner_graph(), config, "ParamPackSplit", {src, table}}, | |||||
| m_shapes(shapes){ | |||||
| mgb_assert(src->comp_node() == table->comp_node()); | |||||
| ParamPackSplit::ParamPackSplit(VarNode* src, VarNode* offsets, | |||||
| const std::vector<dt_int32> offsets_val, | |||||
| TensorShapeArray& shapes, | |||||
| const OperatorNodeConfig& config) | |||||
| : Super{src->owner_graph(), config, "ParamPackSplit", {src, offsets}}, | |||||
| m_shapes(shapes), m_offsets(offsets_val) { | |||||
| mgb_assert(src->comp_node() == offsets->comp_node()); | |||||
| add_input({src}); | add_input({src}); | ||||
| add_input({table}); | |||||
| add_input({offsets}); | |||||
| m_mem_fwd_success.resize(m_shapes.size()); | |||||
| for (size_t i = 0; i < shapes.size(); i++) { | for (size_t i = 0; i < shapes.size(); i++) { | ||||
| mgb_assert(shapes[i].total_nr_elems(), "empty param is not allowed!"); | mgb_assert(shapes[i].total_nr_elems(), "empty param is not allowed!"); | ||||
| add_output(ssprintf("param_pack_o%zu", i))->dtype(src->dtype()); | |||||
| add_output(ssprintf("param_pack_o%zu", i)) | |||||
| ->dtype(src->dtype()).shape(shapes[i]); | |||||
| } | } | ||||
| cg::add_workspace_output(this); | |||||
| } | } | ||||
| void ParamPackSplit::add_input_layout_constraint(){ | void ParamPackSplit::add_input_layout_constraint(){ | ||||
| @@ -1451,17 +1453,19 @@ void ParamPackSplit::add_input_layout_constraint(){ | |||||
| } | } | ||||
| SymbolVarArray ParamPackSplit::make(const SymbolVar& src, | SymbolVarArray ParamPackSplit::make(const SymbolVar& src, | ||||
| const SymbolVar& table, | |||||
| const SymbolVar& offsets, | |||||
| const std::vector<dt_int32> offsets_val, | |||||
| TensorShapeArray shapes, | TensorShapeArray shapes, | ||||
| const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
| auto&& out = src.node() | auto&& out = src.node() | ||||
| ->owner_graph() | ->owner_graph() | ||||
| ->insert_opr(std::make_unique<ParamPackSplit>( | ->insert_opr(std::make_unique<ParamPackSplit>( | ||||
| src.node(), table.node(), shapes, config)) | |||||
| src.node(), offsets.node(), offsets_val, | |||||
| shapes, config)) | |||||
| ->output(); | ->output(); | ||||
| SymbolVarArray ret; | SymbolVarArray ret; | ||||
| ret.resize(out.size() - 1); // do not return workspace | |||||
| ret.resize(out.size()); | |||||
| for (size_t i = 0; i < ret.size(); ++i) { | for (size_t i = 0; i < ret.size(); ++i) { | ||||
| ret[i] = out[i]; | ret[i] = out[i]; | ||||
| } | } | ||||
| @@ -1469,41 +1473,25 @@ SymbolVarArray ParamPackSplit::make(const SymbolVar& src, | |||||
| } | } | ||||
| void ParamPackSplit::scn_do_execute() { | void ParamPackSplit::scn_do_execute() { | ||||
| mgb_assert(m_opr.comp_node() == comp_node()); | |||||
| megdnn::TensorND src = input(0)->dev_tensor().as_megdnn(), | |||||
| table = input(1)->dev_tensor().as_megdnn(); | |||||
| auto outputs = output(); | |||||
| m_inp_ptr.resize(outputs.size() - 1); | |||||
| auto ptr = m_inp_ptr.data(); | |||||
| for (size_t i = 0; i < outputs.size() - 1; i++) { | |||||
| ptr[i] = outputs[i]->dev_tensor().as_megdnn().raw_ptr; | |||||
| } | |||||
| megdnn::TensorND dsts( | |||||
| ptr, megdnn::TensorLayout({outputs.size() - 1}, dtype::Int32())); | |||||
| m_opr->exec(src, table, dsts, | |||||
| get_megdnn_workspace_from_var(outputs.back())); | |||||
| } | |||||
| void ParamPackSplit::on_output_comp_node_stream_changed() { | |||||
| Super::on_output_comp_node_stream_changed(); | |||||
| init_megdnn_opr(); | |||||
| } | |||||
| void ParamPackSplit::init_megdnn_opr(){ | |||||
| m_opr = intl::create_megdnn_opr<megdnn::ParamPackSplit>(comp_node()); | |||||
| } | } | ||||
| void ParamPackSplit::init_output_dtype() { | void ParamPackSplit::init_output_dtype() { | ||||
| // already initialized in constructor | // already initialized in constructor | ||||
| } | } | ||||
| void ParamPackSplit::mem_plan_fwd_in2out_readonly() { | |||||
| mgb_assert(m_offsets.size() == output().size()); | |||||
| for (size_t i = 0; i < output().size(); i++) { | |||||
| auto layout = output(i)->layout(); | |||||
| auto spec = SubTensorSpec::make_from_offset_elem(layout, m_offsets[i]); | |||||
| m_mem_fwd_success[i] = output(i)->set_fwd_in2out_readonly( | |||||
| input(0), spec); | |||||
| mgb_assert(m_mem_fwd_success[i]); | |||||
| } | |||||
| } | |||||
| bool ParamPackSplit::infer_shape(size_t index, TensorShape& dest, | bool ParamPackSplit::infer_shape(size_t index, TensorShape& dest, | ||||
| const cg::static_infer::InpVal& inp) { | const cg::static_infer::InpVal& inp) { | ||||
| if (!m_opr.get()){ | |||||
| init_megdnn_opr(); | |||||
| } | |||||
| dest = m_shapes[index]; | dest = m_shapes[index]; | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -1515,33 +1503,19 @@ void ParamPackSplit::init_output_static_infer_desc() { | |||||
| DepVal shp_deps{{input(0), DepType::SHAPE}, {input(1), DepType::SHAPE}}; | DepVal shp_deps{{input(0), DepType::SHAPE}, {input(1), DepType::SHAPE}}; | ||||
| auto infer_wk = [this](TensorShape &dst, const InpVal &inp){ | |||||
| dst.ndim = 1; | |||||
| if(!m_opr.get()){ | |||||
| init_megdnn_opr(); | |||||
| } | |||||
| dst.shape[0] = m_opr->get_workspace_in_bytes( | |||||
| inp.val.at(0).shape(), inp.val.at(1).shape(), m_shapes); | |||||
| return true; | |||||
| }; | |||||
| for (size_t i = 0; i < output().size() - 1; i++) { | |||||
| for (size_t i = 0; i < output().size(); i++) { | |||||
| auto ov = output(i); | auto ov = output(i); | ||||
| mgr.register_shape_infer( | mgr.register_shape_infer( | ||||
| ov, {SourceType::DEP, shp_deps, | ov, {SourceType::DEP, shp_deps, | ||||
| std::bind(&ParamPackSplit::infer_shape, this, i, _1, _2)}); | std::bind(&ParamPackSplit::infer_shape, this, i, _1, _2)}); | ||||
| } | } | ||||
| mgr.register_shape_infer( | |||||
| output().back(), {SourceType::DEP, shp_deps, infer_wk}); | |||||
| } | } | ||||
| MGB_IMPL_OPR_GRAD(ParamPackSplit) { | MGB_IMPL_OPR_GRAD(ParamPackSplit) { | ||||
| mgb_assert(out_grad.size() == opr.output().size()); | mgb_assert(out_grad.size() == opr.output().size()); | ||||
| SmallVector<SymbolVar> grad; | SmallVector<SymbolVar> grad; | ||||
| // last var is workspace, ignore it | // last var is workspace, ignore it | ||||
| for (size_t i = 0; i < out_grad.size() - 1; ++i) { | |||||
| for (size_t i = 0; i < out_grad.size(); ++i) { | |||||
| auto gval = out_grad[i]; | auto gval = out_grad[i]; | ||||
| if (!gval) { | if (!gval) { | ||||
| gval = SymbolVar{opr.output(i)}.fill_retain_dtype(0).node(); | gval = SymbolVar{opr.output(i)}.fill_retain_dtype(0).node(); | ||||
| @@ -185,9 +185,10 @@ namespace opr { | |||||
| const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, | const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, | ||||
| const OperatorNodeConfig &config){ | const OperatorNodeConfig &config){ | ||||
| auto &&opr = opr_.cast_final_safe<ParamPackSplit>(); | auto &&opr = opr_.cast_final_safe<ParamPackSplit>(); | ||||
| auto &&offsets = opr.get_offsets(); | |||||
| auto &&shape = opr.get_output_shapes(); | auto &&shape = opr.get_output_shapes(); | ||||
| return ParamPackSplit::make(inputs[0], inputs[1], shape, config).at(0). | |||||
| return ParamPackSplit::make(inputs[0], inputs[1], offsets, shape, config).at(0). | |||||
| node()->owner_opr(); | node()->owner_opr(); | ||||
| } | } | ||||
| @@ -570,31 +570,31 @@ public: | |||||
| * \brief Opr used to split parameter | * \brief Opr used to split parameter | ||||
| */ | */ | ||||
| MGB_DEFINE_OPR_CLASS(ParamPackSplit, cg::SingleCNOperatorNodeBase) // { | MGB_DEFINE_OPR_CLASS(ParamPackSplit, cg::SingleCNOperatorNodeBase) // { | ||||
| //! input pointer buffer | |||||
| SmallVector<void*> m_inp_ptr; | |||||
| intl::UniqPtrWithCN<megdnn::ParamPackSplit> m_opr; | |||||
| TensorShapeArray m_shapes; | TensorShapeArray m_shapes; | ||||
| std::vector<dt_int32> m_offsets; | |||||
| std::vector<bool> m_mem_fwd_success; | |||||
| void scn_do_execute() override; | void scn_do_execute() override; | ||||
| void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
| void on_output_comp_node_stream_changed() override; | |||||
| bool infer_shape(size_t index, TensorShape &dest, | bool infer_shape(size_t index, TensorShape &dest, | ||||
| const cg::static_infer::InpVal &inp); | const cg::static_infer::InpVal &inp); | ||||
| void init_output_dtype() override; | void init_output_dtype() override; | ||||
| void mem_plan_fwd_in2out_readonly() override; | |||||
| void add_input_layout_constraint() override; | void add_input_layout_constraint() override; | ||||
| void init_megdnn_opr(); | |||||
| public: | public: | ||||
| ParamPackSplit(VarNode* src, VarNode* table, TensorShapeArray& shapes, | |||||
| const OperatorNodeConfig &config); | |||||
| ParamPackSplit(VarNode* src, VarNode* offsets, | |||||
| const std::vector<dt_int32> offsets_val, | |||||
| TensorShapeArray& shapes, const OperatorNodeConfig& config); | |||||
| static SymbolVarArray make(const SymbolVar& src, const SymbolVar& offsets, | |||||
| const std::vector<dt_int32> offsets_val, | |||||
| TensorShapeArray shapes, | |||||
| const OperatorNodeConfig& config = {}); | |||||
| static SymbolVarArray make(const SymbolVar &src, const SymbolVar &table, | |||||
| TensorShapeArray shapes, const OperatorNodeConfig &config = {}); | |||||
| const std::vector<dt_int32>& get_offsets() const { | |||||
| return m_offsets; | |||||
| } | |||||
| const TensorShapeArray& get_output_shapes() const { | const TensorShapeArray& get_output_shapes() const { | ||||
| return m_shapes; | return m_shapes; | ||||
| @@ -1898,7 +1898,7 @@ void test_param_pack_concat(const TensorShapeArray &shapes, DType type){ | |||||
| srcs.push_back(nd); | srcs.push_back(nd); | ||||
| } | } | ||||
| auto host_table_gen = megdnn::ParamPackSplit::gen_table(shapes, | |||||
| auto host_table_gen = megdnn::ParamPackSplit::gen_offsets(shapes, | |||||
| cn.get_mem_addr_alignment(), 4); | cn.get_mem_addr_alignment(), 4); | ||||
| ASSERT_EQ(host_table_gen.size(), size * 2); | ASSERT_EQ(host_table_gen.size(), size * 2); | ||||
| auto host_table = std::make_shared<HostTensorND>(); | auto host_table = std::make_shared<HostTensorND>(); | ||||
| @@ -1944,7 +1944,7 @@ void test_param_pack_split(const TensorShapeArray& shapes) { | |||||
| auto make_graph = [&](const typename Checker::SymInpArray& inputs) -> | auto make_graph = [&](const typename Checker::SymInpArray& inputs) -> | ||||
| typename Checker::SymOutArray { | typename Checker::SymOutArray { | ||||
| auto table_val = megdnn::ParamPackSplit::gen_table( | |||||
| auto table_val = megdnn::ParamPackSplit::gen_offsets( | |||||
| shapes, cn.get_mem_addr_alignment(), 4); | shapes, cn.get_mem_addr_alignment(), 4); | ||||
| HostTensorND table; | HostTensorND table; | ||||
| std::copy_n(table_val.data(), table_val.size(), | std::copy_n(table_val.data(), table_val.size(), | ||||
| @@ -1954,7 +1954,8 @@ void test_param_pack_split(const TensorShapeArray& shapes) { | |||||
| .ptr<dt_int32>()); | .ptr<dt_int32>()); | ||||
| auto sym_table = opr::SharedDeviceTensor::make( | auto sym_table = opr::SharedDeviceTensor::make( | ||||
| *inputs[0].node()->owner_graph(), table); | *inputs[0].node()->owner_graph(), table); | ||||
| auto out = opr::ParamPackSplit::make(inputs[0], sym_table, shapes); | |||||
| auto out = opr::ParamPackSplit::make(inputs[0], sym_table, table_val, | |||||
| shapes); | |||||
| mgb_assert(out.size() == nr_out); | mgb_assert(out.size() == nr_out); | ||||
| typename Checker::SymOutArray ret; | typename Checker::SymOutArray ret; | ||||
| for (size_t i = 0; i < nr_out; ++i) { | for (size_t i = 0; i < nr_out; ++i) { | ||||