Browse Source

pnnx do not fold aliased tensors with inplace op (#5455)

tags/20240820
nihui GitHub 2 years ago
parent
commit
ab088e05b8
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
1 changed files with 171 additions and 28 deletions
  1. +171
    -28
      tools/pnnx/src/pass_level0/shape_inference.cpp

+ 171
- 28
tools/pnnx/src/pass_level0/shape_inference.cpp View File

@@ -25,39 +25,98 @@

namespace pnnx {

static bool value_link_input(const torch::jit::Value* v, const std::vector<torch::jit::Value*>& inputs, bool ignore_aten_size)
static bool is_inplace_op(const std::string& optype)
{
if (ignore_aten_size)
return optype.size() > 2 && optype[optype.size() - 2] != '_' && optype[optype.size() - 1] == '_';
}

static bool is_alias_op(const std::string& optype)
{
return optype == "aten::slice" || optype == "aten::select" || optype == "aten::view";
}

static bool is_static_shape_foldable(const std::string& optype)
{
return optype == "aten::size"
|| optype == "aten::new_empty"
|| optype == "aten::new_full"
|| optype == "aten::new_ones"
|| optype == "aten::new_zeros"
|| optype == "aten::empty_like"
|| optype == "aten::full_like"
|| optype == "aten::ones_like"
|| optype == "aten::zeros_like"
|| optype == "aten::_shape_as_tensor";
}

static void build_value_link_input_map(const torch::jit::Node* node, const std::unordered_map<std::string, torch::jit::Value*>& value_alias_map, std::unordered_map<std::string, int>& value_link_input_map, bool ignore_aten_size)
{
std::string optype = node->kind().toDisplayString();

if (ignore_aten_size && is_static_shape_foldable(optype))
{
// any intermediate shape is constant with static input shape
std::string optype = v->node()->kind().toDisplayString();
if (optype == "aten::size"
|| optype == "aten::new_empty"
|| optype == "aten::new_full"
|| optype == "aten::new_ones"
|| optype == "aten::new_zeros"
|| optype == "aten::empty_like"
|| optype == "aten::full_like"
|| optype == "aten::ones_like"
|| optype == "aten::zeros_like"
|| optype == "aten::_shape_as_tensor")
return false;
return;
}

for (auto x : inputs)
for (size_t i = 0; i < node->outputs().size(); i++)
{
if (v == x)
return true;
auto out2 = node->outputs()[i];

std::string os = out2->debugName();

if (!os.empty() && value_link_input_map.find(os) != value_link_input_map.end())
continue;

auto tensor_type = out2->type()->cast<torch::jit::TensorType>();
if (tensor_type)
{
value_link_input_map[os] = 1;
}

for (size_t j = 0; j < out2->uses().size(); j++)
{
auto node2 = out2->uses()[j].user;

build_value_link_input_map(node2, value_alias_map, value_link_input_map, true);
}
}

for (size_t i = 0; i < v->node()->inputs().size(); i++)
if (is_inplace_op(optype) || is_alias_op(optype))
{
bool link = value_link_input(v->node()->inputs()[i], inputs, ignore_aten_size);
if (link)
return true;
}
// infect input0 and its alias
while (1)
{
auto in2 = node->inputs()[0];

return false;
std::string is = in2->debugName();

if (is.empty())
break;

if (value_alias_map.find(is) == value_alias_map.end())
break;

auto in3 = value_alias_map.at(is);

auto tensor_type = in3->type()->cast<torch::jit::TensorType>();
if (!tensor_type)
break;

is = in3->debugName();

if (value_link_input_map.find(is) != value_link_input_map.end())
break;

for (size_t j = 0; j < in3->uses().size(); j++)
{
auto node2 = in3->uses()[j].user;

build_value_link_input_map(node2, value_alias_map, value_link_input_map, true);
}

break;
}
}
}

static bool value_link_output(const torch::jit::Value* v, const std::vector<torch::jit::Value*>& outputs)
@@ -79,8 +138,7 @@ static bool value_link_output(const torch::jit::Value* v, const std::vector<torc
}

std::string op_type = node->kind().toDisplayString();
bool is_inplace_op = op_type.size() > 2 && op_type[op_type.size() - 2] != '_' && op_type[op_type.size() - 1] == '_';
if (is_inplace_op)
if (is_inplace_op(op_type))
{
// optimize me: track other inplace op inputs
return true;
@@ -154,6 +212,91 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::
inputs2.push_back(it);
}

// bookkeep foldable tensors
std::unordered_map<std::string, int> value_link_input_map;
{
// build value alias map for inplace op
std::unordered_map<std::string, torch::jit::Value*> value_alias_map;

for (const auto& n : graph->block()->nodes())
{
if (n->kind() == c10::prim::GetAttr)
continue;

if (n->kind() == c10::prim::Constant)
continue;

if (n->kind() == c10::prim::CallMethod)
continue;

std::string optype = n->kind().toDisplayString();

// fprintf(stderr, "optype = %s\n", optype.c_str());

if (!is_inplace_op(optype) && !is_alias_op(optype))
continue;

if (n->inputs().size() == 0)
continue;

if (n->outputs().size() == 0)
continue;

std::string is = n->input(0)->debugName();

if (is.empty())
continue;

for (size_t i = 0; i < n->outputs().size(); i++)
{
auto out2 = n->output(i);

auto tensor_type = out2->type()->cast<torch::jit::TensorType>();
if (!tensor_type)
continue;

std::string os = out2->debugName();

if (os.empty())
continue;

if (value_alias_map.find(is) == value_alias_map.end())
{
value_alias_map[os] = n->input(0);
}
else
{
value_alias_map[os] = value_alias_map[is];
}
}
}

// print value_alias_map
// for (const auto& x : value_alias_map)
// {
// fprintf(stderr, "alias %s -> %s\n", x.first.c_str(), x.second->debugName().c_str());
// }

bool ignore_aten_size = input_tensors2.empty();
for (size_t i = 1; i < graph->inputs().size(); i++)
{
auto in0 = graph->inputs()[i];

for (size_t j = 0; j < in0->uses().size(); j++)
{
auto node = in0->uses()[j].user;

build_value_link_input_map(node, value_alias_map, value_link_input_map, ignore_aten_size);
}
}

// print value_link_input_map
// for (const auto& x : value_link_input_map)
// {
// fprintf(stderr, "link_input %s %d\n", x.first.c_str(), x.second);
// }
}

StoreZipWriter zip;
zip.open(foldable_constants_zippath);

@@ -246,7 +389,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::
v->setType(c10::TensorType::create(t));

// check if value that does not depend on inputs
if (!value_link_input(v, g_inputs, true) && value_link_output(v, g_outputs))
if (value_link_input_map.find(v->debugName()) == value_link_input_map.end() && value_link_output(v, g_outputs))
{
// fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str());
foldable_constants.insert(v->debugName());
@@ -288,7 +431,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::
v->setType(finaltype);

// check if value that does not depend on inputs
if (!value_link_input(v, g_inputs, false) && value_link_output(v, g_outputs))
if (value_link_input_map.find(v->debugName()) == value_link_input_map.end() && value_link_output(v, g_outputs))
{
// fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str());
foldable_constants.insert(v->debugName());


Loading…
Cancel
Save