Browse Source

catch more foldable constant in static shape inference mode (#3738)

tags/20220701
nihui GitHub 4 years ago
parent
commit
f9c4e90962
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 20 additions and 4 deletions
  1. +10
    -4
      tools/pnnx/src/pass_level0/shape_inference.cpp
  2. +2
    -0
      tools/pnnx/src/pass_level5/eliminate_dropout.cpp
  3. +2
    -0
      tools/pnnx/src/pass_level5/eliminate_noop_expression.cpp
  4. +2
    -0
      tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp
  5. +2
    -0
      tools/pnnx/src/pass_level5/eliminate_slice.cpp
  6. +2
    -0
      tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp

+ 10
- 4
tools/pnnx/src/pass_level0/shape_inference.cpp View File

@@ -21,8 +21,14 @@

namespace pnnx {

static bool value_link_input(const torch::jit::Value* v, const std::vector<torch::jit::Value*>& inputs)
static bool value_link_input(const torch::jit::Value* v, const std::vector<torch::jit::Value*>& inputs, bool ignore_aten_size)
{
if (ignore_aten_size && v->node()->kind().toDisplayString() == std::string("aten::size"))
{
// any intermediate shape is constant with static input shape
return false;
}

for (auto x : inputs)
{
if (v == x)
@@ -31,7 +37,7 @@ static bool value_link_input(const torch::jit::Value* v, const std::vector<torch

for (size_t i = 0; i < v->node()->inputs().size(); i++)
{
bool link = value_link_input(v->node()->inputs()[i], inputs);
bool link = value_link_input(v->node()->inputs()[i], inputs, ignore_aten_size);
if (link)
return true;
}
@@ -183,7 +189,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::
v->setType(c10::TensorType::create(t));

// check if value that does not depend on inputs
if (!value_link_input(v, g_inputs) && value_link_output(v, g_outputs))
if (!value_link_input(v, g_inputs, true) && value_link_output(v, g_outputs))
{
output_tensors[v] = t;
}
@@ -221,7 +227,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::
v->setType(finaltype);

// check if value that does not depend on inputs
if (!value_link_input(v, g_inputs) && value_link_output(v, g_outputs))
if (!value_link_input(v, g_inputs, false) && value_link_output(v, g_outputs))
{
output_tensors[v] = t;
}


+ 2
- 0
tools/pnnx/src/pass_level5/eliminate_dropout.cpp View File

@@ -53,6 +53,8 @@ void eliminate_dropout(Graph& graph)
op->inputs[0]->consumers.push_back(x);
}

op->inputs[0]->name = dropout_out->name;

dropout_out->producer = 0;
dropout_out->consumers.clear();



+ 2
- 0
tools/pnnx/src/pass_level5/eliminate_noop_expression.cpp View File

@@ -60,6 +60,8 @@ void eliminate_noop_expression(Graph& graph)
op->inputs[0]->consumers.push_back(x);
}

op->inputs[0]->name = expr_out->name;

expr_out->producer = 0;
expr_out->consumers.clear();



+ 2
- 0
tools/pnnx/src/pass_level5/eliminate_noop_pad.cpp View File

@@ -68,6 +68,8 @@ void eliminate_noop_pad(Graph& graph)
op->inputs[0]->consumers.push_back(x);
}

op->inputs[0]->name = pad_out->name;

pad_out->producer = 0;
pad_out->consumers.clear();



+ 2
- 0
tools/pnnx/src/pass_level5/eliminate_slice.cpp View File

@@ -62,6 +62,8 @@ void eliminate_slice(Graph& graph)
op->inputs[0]->consumers.push_back(x);
}

op->inputs[0]->name = slice_out->name;

slice_out->producer = 0;
slice_out->consumers.clear();



+ 2
- 0
tools/pnnx/src/pass_level5/eliminate_view_reshape.cpp View File

@@ -60,6 +60,8 @@ void eliminate_view_reshape(Graph& graph)
op->inputs[0]->consumers.push_back(x);
}

op->inputs[0]->name = op_out->name;

op_out->producer = 0;
op_out->consumers.clear();



Loading…
Cancel
Save