@@ -392,8 +392,6 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
         TensorLayout result{dtype, format};
         result.ndim = tshape.ndim;
         for (size_t i = 0; i < tshape.ndim; i++) {
-            megdnn_throw_if(!tshape.shape[i], tensor_reshape_error,
-                            megdnn_mangle("target shape is 0"));
             result.shape[i] = tshape.shape[i];
             result.stride[i] = (tshape.shape[i] == 1);
         }
@@ -409,8 +407,6 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
     for (size_t i = 0; i < tshape.ndim; ++i) {
         int target_idx = tshape.ndim - i - 1;
         int cur_idx = ndim - i - 1;
-        megdnn_throw_if(!tshape.shape[target_idx], tensor_reshape_error,
-                        megdnn_mangle("target shape is 0"));
         size_t cur_shape = (cur_idx >= 0 ? shape[cur_idx] : 1),
                cur_stride = (cur_idx >= 0 ? stride[cur_idx] : 0);
         if (tshape.shape[target_idx] != cur_shape) {
@@ -434,10 +430,16 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
 bool TensorLayout::try_reshape(TensorLayout& result,
                                const TensorShape& tshp) const {
     megdnn_assert(tshp.ndim);
+
+    bool is_empty_shape = false;
     for (size_t i = 0; i < tshp.ndim; ++i) {
-        megdnn_throw_if(!tshp.shape[i], tensor_reshape_error,
-                        megdnn_mangle(ssprintf("bad target tshp: %s",
-                                               tshp.to_string().c_str())));
+        if (!tshp.shape[i]) {
+            megdnn_throw_if(!format.is_default(), tensor_reshape_error,
+                            megdnn_mangle(ssprintf("bad target tshp: %s",
+                                                   tshp.to_string().c_str())));
+            is_empty_shape = true;
+            break;
+        }
     }
 
     megdnn_throw_if(
@@ -454,6 +456,11 @@ bool TensorLayout::try_reshape(TensorLayout& result,
     result.format = this->format;
     result.TensorShape::operator=(tshp);
 
+    if (is_empty_shape) {
+        result.init_contiguous_stride();
+        return true;
+    }
+
     size_t sdim = 0, prod = 1, cont_sdim = 0;
     for (size_t i = 0; i < tshp.ndim; ++i) {
         megdnn_assert(cont_sdim < cont.ndim);
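
The two groups of hunks above relax the layout-level checks in megdnn: a zero extent in the target shape no longer throws unconditionally, and try_reshape short-circuits to a contiguous stride for an empty target (still rejecting it for non-default formats). Below is a minimal sketch of the behaviour this is expected to enable, assuming megdnn's public TensorLayout/TensorShape API; the helper name empty_shape_layout_demo and the concrete shapes are illustrative, not part of the patch.

// Sketch only (not part of the patch): expected TensorLayout behaviour once
// the zero-extent checks above are relaxed. Assumes megdnn's public API.
#include <cassert>
#include "megdnn/basic_types.h"
#include "megdnn/dtype.h"

void empty_shape_layout_demo() {
    using namespace megdnn;

    // A layout with zero elements: axis 0 has extent 0.
    TensorLayout src{TensorShape{0, 2, 3}, dtype::Float32()};

    // Reshaping one empty shape into another now succeeds and yields a
    // contiguous (still empty) layout instead of raising tensor_reshape_error.
    TensorLayout reshaped = src.reshape(TensorShape{6, 0});
    assert(reshaped.is_empty() && reshaped.total_nr_elems() == 0);

    // Broadcasting a size-1 axis to extent 0 is accepted as well; the
    // broadcast axis ends up with shape 0 and stride 0.
    TensorLayout ones{TensorShape{1, 2, 3}, dtype::Float32()};
    TensorLayout bcast = ones.broadcast(TensorShape{0, 2, 3});
    assert(bcast.shape[0] == 0 && bcast.stride[0] == 0);
}
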
@@ -237,7 +237,8 @@ void GetVarShape::record_execute_deps(ExecDependencyArray& deps) {
 
 void ReshapeBrdcastHelper::reshapebrdcast_init(VarNode *inp, VarNode *tshp) {
     add_input({inp, tshp});
-    add_output(None)->dtype(inp->dtype());
+    add_output(None)->dtype(inp->dtype())
+        .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     if (reshapebrdcast_output_shape_need_input_shape())
         outshape_by_symvar_enable(1, 1);
     else
@@ -340,6 +341,14 @@ void ReshapeBrdcastHelper::init_output_static_infer_desc() {
                            infer_value});
 }
 
+ReshapeBrdcastHelper::NodeProp*
+ReshapeBrdcastHelper::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 // f}}}
 
 /* f{{{ ======================= Reshape ======================= */
@@ -394,7 +403,7 @@ Maybe<TensorLayout> Reshape::reshapebrdcast_get_dest_layout(
     }
     auto tot_nr_elem = src.total_nr_elems();
     actual_tshape.shape[unspec] = 0;
-    mgb_throw_if(tot_nr_elem % rem_nr_elem, TensorReshapeError,
+    mgb_throw_if(!rem_nr_elem || tot_nr_elem % rem_nr_elem, TensorReshapeError,
                  "could not reshape: src=%s tshape=%s unspec_axis=%zd",
                  static_cast<const TensorShape&>(src).to_string().c_str(),
                  actual_tshape.to_string().c_str(),
@@ -484,6 +493,17 @@ void AxisManipOprBase::init_output_static_infer_desc() {
             {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value});
 }
 
+AxisManipOprBase::NodeProp* AxisManipOprBase::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
+void AxisManipOprBase::axis_manip_init(VarNode* inp) {
+    add_input({inp});
+    add_output(None)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+}
+
 // f}}}
@@ -504,8 +524,7 @@ Dimshuffle::Dimshuffle(VarNode *inp, const std::vector<int> &pattern,
         mgb_throw_if(i < -1 || i >= int(ndim), GraphError,
                 "bad Dimshuffle pattern");
     }
-    add_input({inp});
-    add_output(None);
+    axis_manip_init(inp);
     add_equivalence_component<PODHash<int>>(m_pattern.data(), m_pattern.size());
 }
@@ -587,8 +606,7 @@ AxisAddRemove::AxisAddRemove(
 {
     mgb_throw_if(desc.empty(), GraphError,
             "desc for AxisAddRemove could not be empty");
-    add_input({inp});
-    add_output(None)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+    axis_manip_init(inp);
     add_equivalence_component<PODHash<AxisDesc>>(m_desc.data(), m_desc.size());
 }
@@ -631,13 +649,6 @@ TensorLayout AxisAddRemove::axis_manip_get_output_layout(
     return layout;
 }
 
-AxisAddRemove::NodeProp* AxisAddRemove::do_make_node_prop() const {
-    auto ret = Super::do_make_node_prop();
-    ret->add_dep_type_existing_var(input(0),
-                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
-    return ret;
-}
-
 #ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(AxisAddRemove) {
     MGB_MARK_USED_VAR(wrt_idx);
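
Taken together, the hunks above move the empty-tensor opt-in that AxisAddRemove previously carried on its own into the shared helpers ReshapeBrdcastHelper and AxisManipOprBase, which is why the AxisAddRemove-specific override is deleted in the hunk just above. Schematically, an operator that tolerates empty tensors needs the pair of hooks sketched below; this is an illustrative fragment rather than compilable standalone code, and EmptyAwareOpr is a hypothetical class name standing in for the helpers in this patch.

// Illustrative fragment (EmptyAwareOpr is hypothetical; the real logic now
// lives in ReshapeBrdcastHelper / AxisManipOprBase as shown above).
//
// (1) The output var may legally carry an empty, i.e. 0-element, shape.
void EmptyAwareOpr::init(VarNode* inp) {
    add_input({inp});
    add_output(None)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
}

// (2) The operator is still executed when its input value is empty, rather
//     than the graph rejecting the empty value dependency.
EmptyAwareOpr::NodeProp* EmptyAwareOpr::do_make_node_prop() const {
    auto ret = Super::do_make_node_prop();
    ret->add_dep_type_existing_var(input(0),
                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
    return ret;
}
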
@@ -92,6 +92,7 @@ MGB_DEFINE_CLS_WITH_SUPER(ReshapeBrdcastHelper,
     void scn_do_execute() override final;
     void add_input_layout_constraint() override final;
     void init_output_static_infer_desc() override;
+    NodeProp* do_make_node_prop() const override;
 
     protected:
         using Super::Super;
@@ -199,11 +200,14 @@ MGB_DEFINE_CLS_WITH_SUPER(AxisManipOprBase,
     void mem_plan_fwd_in2out_readonly() override final;
     void scn_do_execute() override final;
     void init_output_static_infer_desc() override final;
+    NodeProp* do_make_node_prop() const override;
 
     protected:
         using Super::Super;
 
         virtual TensorLayout axis_manip_get_output_layout(
                 const TensorLayout &inp_layout) const = 0;
+
+        void axis_manip_init(VarNode* inp);
 };
 
 }
@@ -319,8 +323,6 @@ MGB_DEFINE_OPR_CLASS(AxisAddRemove, intl::AxisManipOprBase) // {
         TensorLayout axis_manip_get_output_layout(
                 const TensorLayout &inp_layout) const override;
 
-        NodeProp* do_make_node_prop() const override;
-
 };
 
 namespace intl {
@@ -17,6 +17,7 @@
 #include "megbrain/opr/io.h"
 #include "megbrain/opr/blas.h"
 #include "megbrain/opr/utility.h"
+#include "megbrain/opr/misc.h"
 #include "megbrain/utils/arith_helper.h"
 
 using namespace mgb;
@@ -138,7 +139,7 @@ TEST(TestTensorManip, Reshape) {
         auto &&dep_map = opr0_reshp.node()->owner_opr()->node_prop().dep_map();
         using DT = cg::OperatorNodeBase::NodeProp::DepType;
         ASSERT_EQ(2u, dep_map.size());
-        ASSERT_EQ(DT::DEV_VALUE, dep_map.at(op->input(0)));
+        ASSERT_EQ(DT::DEV_VALUE | DT::VALUE_ALLOW_EMPTY, dep_map.at(op->input(0)));
         ASSERT_EQ(DT::HOST_VALUE, dep_map.at(op->input(1)));
     }
@@ -318,6 +319,39 @@ TEST(TestTensorManip, ReshapeInferShapeForDynamicInput) {
     run({23, 12, 5});
 }
 
+TEST(TestTensorManip, ReshapeEmptyShape) {
+    HostTensorGenerator<> gen;
+    constexpr size_t x_length = 233;
+    auto host_x = gen({x_length}),
+         host_v = gen({2, 3, 3, 3});
+    for (size_t i = 0; i < x_length; ++ i) {
+        host_x->ptr<float>()[i] = 1.f;
+    }
+    constexpr auto INVALID_AXIS = opr::Reshape::Param::INVALID_AXIS;
+    for (auto unspec_axis: {INVALID_AXIS, 0, 1, 3}) {
+        auto graph = ComputingGraph::make();
+        graph->options().graph_opt_level = 0;
+        TensorShape tshape{2, 3, 3, 3};
+        auto zero_axis = unspec_axis;
+        if (unspec_axis == INVALID_AXIS) {
+            tshape[zero_axis = 2] = 0;
+        }
+        using CondTakeMode = opr::CondTake::Param::Mode;
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+             x_empty = opr::CondTake::make(x, x, {CondTakeMode::EQ, 0.f})[0],
+             v = opr::Host2DeviceCopy::make(*graph, host_v),
+             x_reshape = opr::Reshape::make(x_empty, tshape, {unspec_axis}),
+             y = opr::Concat::make({x_reshape, v}, zero_axis);
+        HostTensorND host_empty, host_y;
+        auto func = graph->compile({
+                make_callback_copy(x_reshape, host_empty),
+                make_callback_copy(y, host_y)});
+        func->execute().wait();
+        ASSERT_TRUE(host_empty.layout().is_empty());
+        MGB_ASSERT_TENSOR_EQ(*host_v, host_y);
+    }
+}
+
 TEST(TestTensorManip, ReshapeWithNegativeUnspec) {
     HostTensorGenerator<> gen;
     auto host_x = gen({4, 8});
@@ -365,6 +399,26 @@ TEST(TestTensorManip, Broadcast) {
     }
 }
 
+TEST(TestTensorManip, BroadcastEmptyShape) {
+    HostTensorGenerator<> gen;
+    for (auto&& arg:
+            {std::make_pair(TensorShape{1}, TensorShape{0}),
+             {{1, 2, 3}, {0, 2, 3}},
+             {{2, 3}, {1, 0, 2, 3}},
+             {{1, 0, 2, 3}, {4, 0, 2, 3}},
+             {{0, 1, 2, 3}, {3, 0, 4, 2, 3}}}) {
+        auto host_x = gen(arg.first);
+        auto graph = ComputingGraph::make();
+        graph->options().graph_opt_level = 0;
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+             y = opr::Broadcast::make(x, arg.second);
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute();
+        ASSERT_TRUE(host_y.shape().eq_shape(arg.second));
+    }
+}
+
 TEST(TestTensorManip, Dimshuffle) {
     HostTensorGenerator<> gen;
     constexpr size_t S0 = 8, S1 = 3;
@@ -395,6 +449,34 @@ TEST(TestTensorManip, Dimshuffle) {
     }
 }
 
+TEST(TestTensorManip, DimshuffleEmptyShape) {
+    HostTensorGenerator<> gen;
+    for (auto&& arg:
+            {std::make_pair(
+                    TensorShape{3, 0},
+                    std::vector<int>{1, -1, 0, -1}),
+             {{3, 1, 0, 4}, {-1, 3, -1, 0, 2}},
+             {{2, 0, 3, 0}, {1, 0, 2, 3}}}) {
+        auto host_x = gen(arg.first);
+        auto graph = ComputingGraph::make();
+        graph->options().graph_opt_level = 0;
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+             y = opr::Dimshuffle::make(x, arg.second);
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute();
+        auto&& y_shape = host_y.shape();
+        for(size_t idx = 0; idx < arg.second.size(); ++ idx) {
+            auto elem = arg.second[idx];
+            if (elem == -1) {
+                ASSERT_EQ(y_shape[idx], 1u);
+            } else {
+                ASSERT_EQ(arg.first[elem], y_shape[idx]);
+            }
+        }
+    }
+}
+
 TEST(TestTensorManip, DimshuffleCombined) {
     using Checker = AutoOprChecker<1, 1>;
     constexpr int RED0 = 2, RED1 = 3;