From 95eb6ae380d383d29ad68622cf2be8d5d8f01a3c Mon Sep 17 00:00:00 2001
From: Megvii Engine Team <megengine@megvii.com>
Date: Mon, 31 Aug 2020 17:37:16 +0800
Subject: [PATCH] feat(mgb/opr): let more ops support empty IO

GitOrigin-RevId: 84dddb4b23638b29950e438bba2af8b5fd5166fa
---
 dnn/src/common/basic_types.cpp              | 21 ++++--
 src/opr/impl/tensor_manip.cpp               | 37 +++++----
 src/opr/include/megbrain/opr/tensor_manip.h |  6 +-
 src/opr/test/tensor_manip.cpp               | 84 ++++++++++++++++++++-
 4 files changed, 125 insertions(+), 23 deletions(-)

diff --git a/dnn/src/common/basic_types.cpp b/dnn/src/common/basic_types.cpp
index e9b90d0e..74624414 100644
--- a/dnn/src/common/basic_types.cpp
+++ b/dnn/src/common/basic_types.cpp
@@ -392,8 +392,6 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
         TensorLayout result{dtype, format};
         result.ndim = tshape.ndim;
         for (size_t i = 0; i < tshape.ndim; i++) {
-            megdnn_throw_if(!tshape.shape[i], tensor_reshape_error,
-                            megdnn_mangle("target shape is 0"));
             result.shape[i] = tshape.shape[i];
             result.stride[i] = (tshape.shape[i] == 1);
         }
@@ -409,8 +407,6 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
     for (size_t i = 0; i < tshape.ndim; ++i) {
         int target_idx = tshape.ndim - i - 1;
         int cur_idx = ndim - i - 1;
-        megdnn_throw_if(!tshape.shape[target_idx], tensor_reshape_error,
-                        megdnn_mangle("target shape is 0"));
         size_t cur_shape = (cur_idx >= 0 ? shape[cur_idx] : 1),
                cur_stride = (cur_idx >= 0 ? stride[cur_idx] : 0);
         if (tshape.shape[target_idx] != cur_shape) {
@@ -434,10 +430,16 @@ bool TensorLayout::try_reshape(TensorLayout& result,
 bool TensorLayout::try_reshape(TensorLayout& result,
                                const TensorShape& tshp) const {
     megdnn_assert(tshp.ndim);
+
+    bool is_empty_shape = false;
     for (size_t i = 0; i < tshp.ndim; ++i) {
-        megdnn_throw_if(!tshp.shape[i], tensor_reshape_error,
-                        megdnn_mangle(ssprintf("bad target tshp: %s",
-                                               tshp.to_string().c_str())));
+        if (!tshp.shape[i]) {
+            megdnn_throw_if(!format.is_default(), tensor_reshape_error,
+                            megdnn_mangle(ssprintf("bad target tshp: %s",
+                                                   tshp.to_string().c_str())));
+            is_empty_shape = true;
+            break;
+        }
     }
 
     megdnn_throw_if(
@@ -454,6 +456,11 @@ bool TensorLayout::try_reshape(TensorLayout& result,
     result.format = this->format;
     result.TensorShape::operator=(tshp);
 
+    if (is_empty_shape) {
+        result.init_contiguous_stride();
+        return true;
+    }
+
     size_t sdim = 0, prod = 1, cont_sdim = 0;
     for (size_t i = 0; i < tshp.ndim; ++i) {
         megdnn_assert(cont_sdim < cont.ndim);
diff --git a/src/opr/impl/tensor_manip.cpp b/src/opr/impl/tensor_manip.cpp
index 862e44ce..d49bd82c 100644
--- a/src/opr/impl/tensor_manip.cpp
+++ b/src/opr/impl/tensor_manip.cpp
@@ -237,7 +237,8 @@ void GetVarShape::record_execute_deps(ExecDependencyArray& deps) {
 
 void ReshapeBrdcastHelper::reshapebrdcast_init(VarNode *inp, VarNode *tshp) {
     add_input({inp, tshp});
-    add_output(None)->dtype(inp->dtype());
+    add_output(None)->dtype(inp->dtype())
+        .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     if (reshapebrdcast_output_shape_need_input_shape())
         outshape_by_symvar_enable(1, 1);
     else
@@ -340,6 +341,14 @@ void ReshapeBrdcastHelper::init_output_static_infer_desc() {
                 infer_value});
 }
 
+ReshapeBrdcastHelper::NodeProp*
+ReshapeBrdcastHelper::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+            NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 // f}}}
 
 /* f{{{ ======================= Reshape ======================= */
@@ -394,7 +403,7 @@ Maybe<TensorLayout> Reshape::reshapebrdcast_get_dest_layout(
     }
     auto tot_nr_elem = src.total_nr_elems();
     actual_tshape.shape[unspec] = 0;
-    mgb_throw_if(tot_nr_elem % rem_nr_elem, TensorReshapeError,
+    mgb_throw_if(!rem_nr_elem || tot_nr_elem % rem_nr_elem, TensorReshapeError,
                  "could not reshape: src=%s tshape=%s unspec_axis=%zd",
                  static_cast<const TensorShape&>(src).to_string().c_str(),
                  actual_tshape.to_string().c_str(),
@@ -484,6 +493,17 @@ void AxisManipOprBase::init_output_static_infer_desc() {
             {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value});
 }
 
+AxisManipOprBase::NodeProp* AxisManipOprBase::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+            NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
+void AxisManipOprBase::axis_manip_init(VarNode* inp) {
+    add_input({inp});
+    add_output(None)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+}
 
 // f}}}
 
@@ -504,8 +524,7 @@ Dimshuffle::Dimshuffle(VarNode *inp, const std::vector<int> &pattern,
         mgb_throw_if(i < -1 || i >= int(ndim), GraphError,
                 "bad Dimshuffle pattern");
     }
-    add_input({inp});
-    add_output(None);
+    axis_manip_init(inp);
     add_equivalence_component<PODHash<int>>(m_pattern.data(),
             m_pattern.size());
 }
@@ -587,8 +606,7 @@ AxisAddRemove::AxisAddRemove(
 {
     mgb_throw_if(desc.empty(), GraphError,
             "desc for AxisAddRemove could not be empty");
-    add_input({inp});
-    add_output(None)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+    axis_manip_init(inp);
     add_equivalence_component<PODHash<AxisDesc>>(m_desc.data(),
             m_desc.size());
 }
@@ -631,13 +649,6 @@ TensorLayout AxisAddRemove::axis_manip_get_output_layout(
     return layout;
 }
 
-AxisAddRemove::NodeProp* AxisAddRemove::do_make_node_prop() const {
-    auto ret = Super::do_make_node_prop();
-    ret->add_dep_type_existing_var(input(0),
-            NodeProp::DepType::VALUE_ALLOW_EMPTY);
-    return ret;
-}
-
 #ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(AxisAddRemove) {
     MGB_MARK_USED_VAR(wrt_idx);
diff --git a/src/opr/include/megbrain/opr/tensor_manip.h b/src/opr/include/megbrain/opr/tensor_manip.h
index 345c7f6d..d030e239 100644
--- a/src/opr/include/megbrain/opr/tensor_manip.h
+++ b/src/opr/include/megbrain/opr/tensor_manip.h
@@ -92,6 +92,7 @@ MGB_DEFINE_CLS_WITH_SUPER(ReshapeBrdcastHelper,
         void scn_do_execute() override final;
         void add_input_layout_constraint() override final;
         void init_output_static_infer_desc() override;
+        NodeProp* do_make_node_prop() const override;
 
     protected:
         using Super::Super;
@@ -199,11 +200,14 @@ MGB_DEFINE_CLS_WITH_SUPER(AxisManipOprBase,
         void mem_plan_fwd_in2out_readonly() override final;
         void scn_do_execute() override final;
         void init_output_static_infer_desc() override final;
+        NodeProp* do_make_node_prop() const override;
 
     protected:
         using Super::Super;
         virtual TensorLayout axis_manip_get_output_layout(
                 const TensorLayout &inp_layout) const = 0;
+
+        void axis_manip_init(VarNode* inp);
 };
 
 }
@@ -319,8 +323,6 @@ MGB_DEFINE_OPR_CLASS(AxisAddRemove, intl::AxisManipOprBase) // {
 
         TensorLayout axis_manip_get_output_layout(
                 const TensorLayout &inp_layout) const override;
-
-        NodeProp* do_make_node_prop() const override;
 };
 
 namespace intl {
diff --git a/src/opr/test/tensor_manip.cpp b/src/opr/test/tensor_manip.cpp
index 11dd9727..3fd5dd1b 100644
--- a/src/opr/test/tensor_manip.cpp
+++ b/src/opr/test/tensor_manip.cpp
@@ -17,6 +17,7 @@
 #include "megbrain/opr/io.h"
 #include "megbrain/opr/blas.h"
 #include "megbrain/opr/utility.h"
+#include "megbrain/opr/misc.h"
 #include "megbrain/utils/arith_helper.h"
 
 using namespace mgb;
@@ -138,7 +139,7 @@ TEST(TestTensorManip, Reshape) {
         auto &&dep_map = opr0_reshp.node()->owner_opr()->node_prop().dep_map();
         using DT = cg::OperatorNodeBase::NodeProp::DepType;
         ASSERT_EQ(2u, dep_map.size());
-        ASSERT_EQ(DT::DEV_VALUE, dep_map.at(op->input(0)));
+        ASSERT_EQ(DT::DEV_VALUE | DT::VALUE_ALLOW_EMPTY, dep_map.at(op->input(0)));
         ASSERT_EQ(DT::HOST_VALUE, dep_map.at(op->input(1)));
     }
 
@@ -318,6 +319,39 @@ TEST(TestTensorManip, ReshapeInferShapeForDynamicInput) {
     run({23, 12, 5});
 }
 
+TEST(TestTensorManip, ReshapeEmptyShape) {
+    HostTensorGenerator<> gen;
+    constexpr size_t x_length = 233;
+    auto host_x = gen({x_length}),
+         host_v = gen({2, 3, 3, 3});
+    for (size_t i = 0; i < x_length; ++ i) {
+        host_x->ptr<float>()[i] = 1.f;
+    }
+    constexpr auto INVALID_AXIS = opr::Reshape::Param::INVALID_AXIS;
+    for (auto unspec_axis: {INVALID_AXIS, 0, 1, 3}) {
+        auto graph = ComputingGraph::make();
+        graph->options().graph_opt_level = 0;
+        TensorShape tshape{2, 3, 3, 3};
+        auto zero_axis = unspec_axis;
+        if (unspec_axis == INVALID_AXIS) {
+            tshape[zero_axis = 2] = 0;
+        }
+        using CondTakeMode = opr::CondTake::Param::Mode;
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+             x_empty = opr::CondTake::make(x, x, {CondTakeMode::EQ, 0.f})[0],
+             v = opr::Host2DeviceCopy::make(*graph, host_v),
+             x_reshape = opr::Reshape::make(x_empty, tshape, {unspec_axis}),
+             y = opr::Concat::make({x_reshape, v}, zero_axis);
+        HostTensorND host_empty, host_y;
+        auto func = graph->compile({
+                make_callback_copy(x_reshape, host_empty),
+                make_callback_copy(y, host_y)});
+        func->execute().wait();
+        ASSERT_TRUE(host_empty.layout().is_empty());
+        MGB_ASSERT_TENSOR_EQ(*host_v, host_y);
+    }
+}
+
 TEST(TestTensorManip, ReshapeWithNegativeUnspec) {
     HostTensorGenerator<> gen;
     auto host_x = gen({4, 8});
@@ -365,6 +399,26 @@ TEST(TestTensorManip, Broadcast) {
     }
 }
 
+TEST(TestTensorManip, BroadcastEmptyShape) {
+    HostTensorGenerator<> gen;
+    for (auto&& arg:
+            {std::make_pair(TensorShape{1}, TensorShape{0}),
+             {{1, 2, 3}, {0, 2, 3}},
+             {{2, 3}, {1, 0, 2, 3}},
+             {{1, 0, 2, 3}, {4, 0, 2, 3}},
+             {{0, 1, 2, 3}, {3, 0, 4, 2, 3}}}) {
+        auto host_x = gen(arg.first);
+        auto graph = ComputingGraph::make();
+        graph->options().graph_opt_level = 0;
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+             y = opr::Broadcast::make(x, arg.second);
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute();
+        ASSERT_TRUE(host_y.shape().eq_shape(arg.second));
+    }
+}
+
 TEST(TestTensorManip, Dimshuffle) {
     HostTensorGenerator<> gen;
     constexpr size_t S0 = 8, S1 = 3;
@@ -395,6 +449,34 @@ TEST(TestTensorManip, Dimshuffle) {
     }
 }
 
+TEST(TestTensorManip, DimshuffleEmptyShape) {
+    HostTensorGenerator<> gen;
+    for (auto&& arg:
+            {std::make_pair(
+                     TensorShape{3, 0},
+                     std::vector<int>{1, -1, 0, -1}),
+             {{3, 1, 0, 4}, {-1, 3, -1, 0, 2}},
+             {{2, 0, 3, 0}, {1, 0, 2, 3}}}) {
+        auto host_x = gen(arg.first);
+        auto graph = ComputingGraph::make();
+        graph->options().graph_opt_level = 0;
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+             y = opr::Dimshuffle::make(x, arg.second);
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute();
+        auto&& y_shape = host_y.shape();
+        for(size_t idx = 0; idx < arg.second.size(); ++ idx) {
+            auto elem = arg.second[idx];
+            if (elem == -1) {
+                ASSERT_EQ(y_shape[idx], 1u);
+            } else {
+                ASSERT_EQ(arg.first[elem], y_shape[idx]);
+            }
+        }
+    }
+}
+
 TEST(TestTensorManip, DimshuffleCombined) {
     using Checker = AutoOprChecker<1, 1>;
     constexpr int RED0 = 2, RED1 = 3;
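
Note on the behavior change (an illustrative sketch, not part of the diff
above): TensorLayout::broadcast() and TensorLayout::try_reshape() no longer
reject target shapes that contain a 0 axis. try_reshape() now short-circuits
to an empty contiguous layout (default format only), and the graph-level ops
gain VarNode::Flag::ALLOW_EMPTY_SHAPE on their outputs plus
DepType::VALUE_ALLOW_EMPTY on their inputs, so empty tensors can flow through
Reshape, Broadcast, Dimshuffle and AxisAddRemove. A minimal sketch at the
MegDNN layer, assuming only the public TensorLayout API declared in
dnn/include/megdnn/basic_types.h:

    #include "megdnn/basic_types.h"

    using namespace megdnn;

    int main() {
        // broadcasting {1, 2} to {0, 2}: the checks removed above used to
        // throw tensor_reshape_error ("target shape is 0") for the 0 axis;
        // now the broadcast axis simply gets stride 0 and the result is
        // an empty layout
        TensorLayout src({1, 2}, dtype::Float32());
        TensorLayout dst = src.broadcast({0, 2});
        return dst.is_empty() ? 0 : 1;  // exits 0: dst is empty
    }

Reshape between empty layouts behaves analogously: for example,
TensorLayout({0, 4}, dtype::Float32()).reshape({2, 0, 2}) now yields an empty
layout with contiguous strides instead of throwing, which is the path the
ReshapeEmptyShape test exercises through the graph API before feeding the
result into Concat.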