
feat(mgb/opr): let more ops support empty IO

Megvii Engine Team · 5 years ago
GitOrigin-RevId: 84dddb4b23
tags/v1.0.0
parent commit: 95eb6ae380
4 changed files with 125 additions and 23 deletions:

  1. dnn/src/common/basic_types.cpp               +14  -7
  2. src/opr/impl/tensor_manip.cpp                +24  -13
  3. src/opr/include/megbrain/opr/tensor_manip.h  +4   -2
  4. src/opr/test/tensor_manip.cpp                +83  -1

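Note: an "empty" tensor here means a tensor whose shape contains a zero
extent, so it holds no elements at all. A minimal standalone sketch of the
idea (plain C++, not MegEngine code; numel is a hypothetical stand-in for
TensorShape::total_nr_elems()):

    #include <cstddef>
    #include <vector>

    // The element count is the product of all extents, so a single
    // zero extent makes the whole tensor empty.
    static std::size_t numel(const std::vector<std::size_t>& shape) {
        std::size_t n = 1;
        for (std::size_t s : shape)
            n *= s;
        return n;
    }

    int main() {
        std::vector<std::size_t> shape{2, 0, 3, 3};
        return numel(shape) == 0 ? 0 : 1;  // empty: no data to allocate
    }

Before this commit such shapes were rejected by the reshape/broadcast
checks; after it, Reshape, Broadcast, Dimshuffle and AxisAddRemove accept
empty tensors on both input and output.
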
dnn/src/common/basic_types.cpp  (+14, -7)

@@ -392,8 +392,6 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
         TensorLayout result{dtype, format};
         result.ndim = tshape.ndim;
         for (size_t i = 0; i < tshape.ndim; i++) {
-            megdnn_throw_if(!tshape.shape[i], tensor_reshape_error,
-                            megdnn_mangle("target shape is 0"));
             result.shape[i] = tshape.shape[i];
             result.stride[i] = (tshape.shape[i] == 1);
         }
@@ -409,8 +407,6 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
     for (size_t i = 0; i < tshape.ndim; ++i) {
         int target_idx = tshape.ndim - i - 1;
         int cur_idx = ndim - i - 1;
-        megdnn_throw_if(!tshape.shape[target_idx], tensor_reshape_error,
-                        megdnn_mangle("target shape is 0"));
         size_t cur_shape = (cur_idx >= 0 ? shape[cur_idx] : 1),
                cur_stride = (cur_idx >= 0 ? stride[cur_idx] : 0);
         if (tshape.shape[target_idx] != cur_shape) {
@@ -434,10 +430,16 @@ TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const {
 bool TensorLayout::try_reshape(TensorLayout& result,
                                const TensorShape& tshp) const {
     megdnn_assert(tshp.ndim);
+
+    bool is_empty_shape = false;
     for (size_t i = 0; i < tshp.ndim; ++i) {
-        megdnn_throw_if(!tshp.shape[i], tensor_reshape_error,
-                        megdnn_mangle(ssprintf("bad target tshp: %s",
-                                               tshp.to_string().c_str())));
+        if (!tshp.shape[i]) {
+            megdnn_throw_if(!format.is_default(), tensor_reshape_error,
+                            megdnn_mangle(ssprintf("bad target tshp: %s",
+                                                   tshp.to_string().c_str())));
+            is_empty_shape = true;
+            break;
+        }
     }

     megdnn_throw_if(
@@ -454,6 +456,11 @@ bool TensorLayout::try_reshape(TensorLayout& result,
     result.format = this->format;
     result.TensorShape::operator=(tshp);

+    if (is_empty_shape) {
+        result.init_contiguous_stride();
+        return true;
+    }
+
     size_t sdim = 0, prod = 1, cont_sdim = 0;
     for (size_t i = 0; i < tshp.ndim; ++i) {
         megdnn_assert(cont_sdim < cont.ndim);

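The try_reshape() change scans the target shape once: if any extent is zero
(only allowed for the default format), it skips the usual stride-matching
pass and simply gives the result contiguous strides. A rough model of that
control flow (plain C++ sketch with simplified types, not the real
TensorLayout):

    #include <cstddef>
    #include <vector>

    struct Layout {
        std::vector<std::size_t> shape, stride;
        // simplified init_contiguous_stride(): row-major, right to left
        void init_contiguous_stride() {
            stride.assign(shape.size(), 1);
            for (int i = int(shape.size()) - 2; i >= 0; --i)
                stride[i] = stride[i + 1] * shape[i + 1];
        }
    };

    // mirrors the new early exit: an empty target shape always succeeds
    bool try_reshape_sketch(Layout& result,
                            const std::vector<std::size_t>& tshp) {
        bool is_empty_shape = false;
        for (std::size_t s : tshp)
            if (!s) { is_empty_shape = true; break; }
        result.shape = tshp;
        if (is_empty_shape) {
            result.init_contiguous_stride();
            return true;
        }
        // ... the real code's stride-compatibility pass runs here ...
        return false;  // placeholder for the non-empty path
    }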

src/opr/impl/tensor_manip.cpp  (+24, -13)

@@ -237,7 +237,8 @@ void GetVarShape::record_execute_deps(ExecDependencyArray& deps) {

 void ReshapeBrdcastHelper::reshapebrdcast_init(VarNode *inp, VarNode *tshp) {
     add_input({inp, tshp});
-    add_output(None)->dtype(inp->dtype());
+    add_output(None)->dtype(inp->dtype())
+                    .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     if (reshapebrdcast_output_shape_need_input_shape())
         outshape_by_symvar_enable(1, 1);
     else
@@ -340,6 +341,14 @@ void ReshapeBrdcastHelper::init_output_static_infer_desc() {
                 infer_value});
 }

+ReshapeBrdcastHelper::NodeProp*
+ReshapeBrdcastHelper::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+            NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 // f}}}

 /* f{{{ ======================= Reshape ======================= */
@@ -394,7 +403,7 @@ Maybe<TensorLayout> Reshape::reshapebrdcast_get_dest_layout(
     }
     auto tot_nr_elem = src.total_nr_elems();
     actual_tshape.shape[unspec] = 0;
-    mgb_throw_if(tot_nr_elem % rem_nr_elem, TensorReshapeError,
+    mgb_throw_if(!rem_nr_elem || tot_nr_elem % rem_nr_elem, TensorReshapeError,
                  "could not reshape: src=%s tshape=%s unspec_axis=%zd",
                  static_cast<const TensorShape&>(src).to_string().c_str(),
                  actual_tshape.to_string().c_str(),
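Context for the guard above: Reshape infers an unspecified axis by checking
tot_nr_elem % rem_nr_elem and then dividing; once empty tensors are legal,
rem_nr_elem (the product of the specified target extents) can itself be
zero, and the old expression would divide by zero. A minimal illustration
of the arithmetic (plain C++, hypothetical values):

    #include <cassert>
    #include <cstddef>

    int main() {
        // empty source reshaped to {2, 3, ?, 3}: 0 % 18 == 0, so the
        // unspecified axis is inferred as 0 and the reshape succeeds
        std::size_t tot_nr_elem = 0;
        std::size_t rem_nr_elem = 2 * 3 * 3;
        assert(tot_nr_elem % rem_nr_elem == 0);

        // but a zero among the *specified* extents, e.g. {0, 3, ?, 3},
        // makes rem_nr_elem == 0; the old check would be a division by
        // zero, while the new `!rem_nr_elem || ...` raises a reshape
        // error before dividing
        std::size_t rem_zero = 0 * 3 * 3;
        return rem_zero == 0 ? 0 : 1;  // the new guard fires here
    }
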
@@ -484,6 +493,17 @@ void AxisManipOprBase::init_output_static_infer_desc() {
             {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value});
 }

+AxisManipOprBase::NodeProp* AxisManipOprBase::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+            NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
+void AxisManipOprBase::axis_manip_init(VarNode* inp) {
+    add_input({inp});
+    add_output(None)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+}
+
 // f}}}

@@ -504,8 +524,7 @@ Dimshuffle::Dimshuffle(VarNode *inp, const std::vector<int> &pattern,
         mgb_throw_if(i < -1 || i >= int(ndim), GraphError,
                      "bad Dimshuffle pattern");
     }
-    add_input({inp});
-    add_output(None);
+    axis_manip_init(inp);
     add_equivalence_component<PODHash<int>>(m_pattern.data(), m_pattern.size());
 }

@@ -587,8 +606,7 @@ AxisAddRemove::AxisAddRemove(
 {
     mgb_throw_if(desc.empty(), GraphError,
                  "desc for AxisAddRemove could not be empty");
-    add_input({inp});
-    add_output(None)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
+    axis_manip_init(inp);
     add_equivalence_component<PODHash<AxisDesc>>(m_desc.data(), m_desc.size());
 }

@@ -631,13 +649,6 @@ TensorLayout AxisAddRemove::axis_manip_get_output_layout(
     return layout;
 }

-AxisAddRemove::NodeProp* AxisAddRemove::do_make_node_prop() const {
-    auto ret = Super::do_make_node_prop();
-    ret->add_dep_type_existing_var(input(0),
-            NodeProp::DepType::VALUE_ALLOW_EMPTY);
-    return ret;
-}
-
 #ifdef MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(AxisAddRemove) {
     MGB_MARK_USED_VAR(wrt_idx);

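Two mechanisms carry the feature through this file: the output var gets
VarNode::Flag::ALLOW_EMPTY_SHAPE (the graph may materialize a zero-sized
output), and the input dependency gets NodeProp::DepType::VALUE_ALLOW_EMPTY
(execution may proceed even though the input holds no data). A usage sketch
in the spirit of the new tests (unverified, assumes the headers and helpers
used in src/opr/test/tensor_manip.cpp):

    auto graph = ComputingGraph::make();
    HostTensorGenerator<> gen;
    auto host_x = gen({0, 2, 3});  // empty: first extent is zero
    auto x = opr::Host2DeviceCopy::make(*graph, host_x),
         y = opr::Reshape::make(x, TensorShape{2, 0, 3});
    HostTensorND host_y;
    auto func = graph->compile({make_callback_copy(y, host_y)});
    func->execute();
    // host_y.layout().is_empty() now holds; no input data was read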

src/opr/include/megbrain/opr/tensor_manip.h  (+4, -2)

@@ -92,6 +92,7 @@ MGB_DEFINE_CLS_WITH_SUPER(ReshapeBrdcastHelper,
     void scn_do_execute() override final;
     void add_input_layout_constraint() override final;
     void init_output_static_infer_desc() override;
+    NodeProp* do_make_node_prop() const override;

     protected:
         using Super::Super;
@@ -199,11 +200,14 @@ MGB_DEFINE_CLS_WITH_SUPER(AxisManipOprBase,
     void mem_plan_fwd_in2out_readonly() override final;
     void scn_do_execute() override final;
     void init_output_static_infer_desc() override final;
+    NodeProp* do_make_node_prop() const override;

     protected:
         using Super::Super;
         virtual TensorLayout axis_manip_get_output_layout(
                 const TensorLayout &inp_layout) const = 0;
+
+        void axis_manip_init(VarNode* inp);
 };

 }
@@ -319,8 +323,6 @@ MGB_DEFINE_OPR_CLASS(AxisAddRemove, intl::AxisManipOprBase) // {

         TensorLayout axis_manip_get_output_layout(
                 const TensorLayout &inp_layout) const override;
-
-        NodeProp* do_make_node_prop() const override;
 };

 namespace intl {

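The header diff records the shape of the refactor: do_make_node_prop() and
the new axis_manip_init() move from AxisAddRemove up into AxisManipOprBase,
so Dimshuffle inherits the empty-IO behavior instead of duplicating it.
Reduced to plain C++ (hypothetical names, no MegEngine types):

    #include <memory>

    struct NodeProp { bool value_allow_empty = false; };

    // the base class now owns the shared empty-IO setup once...
    struct AxisManipBase {
        virtual ~AxisManipBase() = default;
        virtual NodeProp* make_node_prop() const {
            auto prop = new NodeProp;
            prop->value_allow_empty = true;  // inputs may be empty
            return prop;
        }
    };

    // ...and both concrete ops simply inherit it
    struct Dimshuffle : AxisManipBase {};
    struct AxisAddRemove : AxisManipBase {};

    int main() {
        std::unique_ptr<NodeProp> p{Dimshuffle{}.make_node_prop()};
        return p->value_allow_empty ? 0 : 1;
    }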

src/opr/test/tensor_manip.cpp  (+83, -1)

@@ -17,6 +17,7 @@
 #include "megbrain/opr/io.h"
 #include "megbrain/opr/blas.h"
 #include "megbrain/opr/utility.h"
+#include "megbrain/opr/misc.h"
 #include "megbrain/utils/arith_helper.h"

 using namespace mgb;
@@ -138,7 +139,7 @@ TEST(TestTensorManip, Reshape) {
         auto &&dep_map = opr0_reshp.node()->owner_opr()->node_prop().dep_map();
         using DT = cg::OperatorNodeBase::NodeProp::DepType;
         ASSERT_EQ(2u, dep_map.size());
-        ASSERT_EQ(DT::DEV_VALUE, dep_map.at(op->input(0)));
+        ASSERT_EQ(DT::DEV_VALUE | DT::VALUE_ALLOW_EMPTY, dep_map.at(op->input(0)));
         ASSERT_EQ(DT::HOST_VALUE, dep_map.at(op->input(1)));
     }

@@ -318,6 +319,39 @@ TEST(TestTensorManip, ReshapeInferShapeForDynamicInput) {
     run({23, 12, 5});
 }

+TEST(TestTensorManip, ReshapeEmptyShape) {
+    HostTensorGenerator<> gen;
+    constexpr size_t x_length = 233;
+    auto host_x = gen({x_length}),
+         host_v = gen({2, 3, 3, 3});
+    for (size_t i = 0; i < x_length; ++ i) {
+        host_x->ptr<float>()[i] = 1.f;
+    }
+    constexpr auto INVALID_AXIS = opr::Reshape::Param::INVALID_AXIS;
+    for (auto unspec_axis: {INVALID_AXIS, 0, 1, 3}) {
+        auto graph = ComputingGraph::make();
+        graph->options().graph_opt_level = 0;
+        TensorShape tshape{2, 3, 3, 3};
+        auto zero_axis = unspec_axis;
+        if (unspec_axis == INVALID_AXIS) {
+            tshape[zero_axis = 2] = 0;
+        }
+        using CondTakeMode = opr::CondTake::Param::Mode;
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+             x_empty = opr::CondTake::make(x, x, {CondTakeMode::EQ, 0.f})[0],
+             v = opr::Host2DeviceCopy::make(*graph, host_v),
+             x_reshape = opr::Reshape::make(x_empty, tshape, {unspec_axis}),
+             y = opr::Concat::make({x_reshape, v}, zero_axis);
+        HostTensorND host_empty, host_y;
+        auto func = graph->compile({
+                make_callback_copy(x_reshape, host_empty),
+                make_callback_copy(y, host_y)});
+        func->execute().wait();
+        ASSERT_TRUE(host_empty.layout().is_empty());
+        MGB_ASSERT_TENSOR_EQ(*host_v, host_y);
+    }
+}
+
 TEST(TestTensorManip, ReshapeWithNegativeUnspec) {
     HostTensorGenerator<> gen;
     auto host_x = gen({4, 8});
@@ -365,6 +399,26 @@ TEST(TestTensorManip, Broadcast) {
     }
 }

+TEST(TestTensorManip, BroadcastEmptyShape) {
+    HostTensorGenerator<> gen;
+    for (auto&& arg:
+            {std::make_pair(TensorShape{1}, TensorShape{0}),
+             {{1, 2, 3}, {0, 2, 3}},
+             {{2, 3}, {1, 0, 2, 3}},
+             {{1, 0, 2, 3}, {4, 0, 2, 3}},
+             {{0, 1, 2, 3}, {3, 0, 4, 2, 3}}}) {
+        auto host_x = gen(arg.first);
+        auto graph = ComputingGraph::make();
+        graph->options().graph_opt_level = 0;
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+             y = opr::Broadcast::make(x, arg.second);
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute();
+        ASSERT_TRUE(host_y.shape().eq_shape(arg.second));
+    }
+}
+
 TEST(TestTensorManip, Dimshuffle) {
     HostTensorGenerator<> gen;
     constexpr size_t S0 = 8, S1 = 3;
@@ -395,6 +449,34 @@ TEST(TestTensorManip, Dimshuffle) {
     }
 }

+TEST(TestTensorManip, DimshuffleEmptyShape) {
+    HostTensorGenerator<> gen;
+    for (auto&& arg:
+            {std::make_pair(
+                    TensorShape{3, 0},
+                    std::vector<int>{1, -1, 0, -1}),
+             {{3, 1, 0, 4}, {-1, 3, -1, 0, 2}},
+             {{2, 0, 3, 0}, {1, 0, 2, 3}}}) {
+        auto host_x = gen(arg.first);
+        auto graph = ComputingGraph::make();
+        graph->options().graph_opt_level = 0;
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x),
+             y = opr::Dimshuffle::make(x, arg.second);
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute();
+        auto&& y_shape = host_y.shape();
+        for (size_t idx = 0; idx < arg.second.size(); ++ idx) {
+            auto elem = arg.second[idx];
+            if (elem == -1) {
+                ASSERT_EQ(y_shape[idx], 1u);
+            } else {
+                ASSERT_EQ(arg.first[elem], y_shape[idx]);
+            }
+        }
+    }
+}
+
 TEST(TestTensorManip, DimshuffleCombined) {
     using Checker = AutoOprChecker<1, 1>;
     constexpr int RED0 = 2, RED1 = 3;

