|
|
|
@@ -21,8 +21,14 @@ |
|
|
|
|
|
|
|
namespace pnnx { |
|
|
|
|
|
|
|
static bool value_link_input(const torch::jit::Value* v, const std::vector<torch::jit::Value*>& inputs) |
|
|
|
static bool value_link_input(const torch::jit::Value* v, const std::vector<torch::jit::Value*>& inputs, bool ignore_aten_size) |
|
|
|
{ |
|
|
|
if (ignore_aten_size && v->node()->kind().toDisplayString() == std::string("aten::size")) |
|
|
|
{ |
|
|
|
// any intermediate shape is constant with static input shape |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
for (auto x : inputs) |
|
|
|
{ |
|
|
|
if (v == x) |
|
|
|
@@ -31,7 +37,7 @@ static bool value_link_input(const torch::jit::Value* v, const std::vector<torch |
|
|
|
|
|
|
|
for (size_t i = 0; i < v->node()->inputs().size(); i++) |
|
|
|
{ |
|
|
|
bool link = value_link_input(v->node()->inputs()[i], inputs); |
|
|
|
bool link = value_link_input(v->node()->inputs()[i], inputs, ignore_aten_size); |
|
|
|
if (link) |
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -183,7 +189,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit:: |
|
|
|
v->setType(c10::TensorType::create(t)); |
|
|
|
|
|
|
|
// check if value that does not depend on inputs |
|
|
|
if (!value_link_input(v, g_inputs) && value_link_output(v, g_outputs)) |
|
|
|
if (!value_link_input(v, g_inputs, true) && value_link_output(v, g_outputs)) |
|
|
|
{ |
|
|
|
output_tensors[v] = t; |
|
|
|
} |
|
|
|
@@ -221,7 +227,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit:: |
|
|
|
v->setType(finaltype); |
|
|
|
|
|
|
|
// check if value that does not depend on inputs |
|
|
|
if (!value_link_input(v, g_inputs) && value_link_output(v, g_outputs)) |
|
|
|
if (!value_link_input(v, g_inputs, false) && value_link_output(v, g_outputs)) |
|
|
|
{ |
|
|
|
output_tensors[v] = t; |
|
|
|
} |
|
|
|
|