diff --git a/tools/pnnx/src/pass_level0/shape_inference.cpp b/tools/pnnx/src/pass_level0/shape_inference.cpp index e7c2d759e..a273dd79d 100644 --- a/tools/pnnx/src/pass_level0/shape_inference.cpp +++ b/tools/pnnx/src/pass_level0/shape_inference.cpp @@ -25,39 +25,98 @@ namespace pnnx { -static bool value_link_input(const torch::jit::Value* v, const std::vector& inputs, bool ignore_aten_size) +static bool is_inplace_op(const std::string& optype) { - if (ignore_aten_size) + return optype.size() > 2 && optype[optype.size() - 2] != '_' && optype[optype.size() - 1] == '_'; +} + +static bool is_alias_op(const std::string& optype) +{ + return optype == "aten::slice" || optype == "aten::select" || optype == "aten::view"; +} + +static bool is_static_shape_foldable(const std::string& optype) +{ + return optype == "aten::size" + || optype == "aten::new_empty" + || optype == "aten::new_full" + || optype == "aten::new_ones" + || optype == "aten::new_zeros" + || optype == "aten::empty_like" + || optype == "aten::full_like" + || optype == "aten::ones_like" + || optype == "aten::zeros_like" + || optype == "aten::_shape_as_tensor"; +} + +static void build_value_link_input_map(const torch::jit::Node* node, const std::unordered_map& value_alias_map, std::unordered_map& value_link_input_map, bool ignore_aten_size) +{ + std::string optype = node->kind().toDisplayString(); + + if (ignore_aten_size && is_static_shape_foldable(optype)) { - // any intermediate shape is constant with static input shape - std::string optype = v->node()->kind().toDisplayString(); - if (optype == "aten::size" - || optype == "aten::new_empty" - || optype == "aten::new_full" - || optype == "aten::new_ones" - || optype == "aten::new_zeros" - || optype == "aten::empty_like" - || optype == "aten::full_like" - || optype == "aten::ones_like" - || optype == "aten::zeros_like" - || optype == "aten::_shape_as_tensor") - return false; + return; } - for (auto x : inputs) + for (size_t i = 0; i < node->outputs().size(); i++) { - if (v == x) - return true; + auto out2 = node->outputs()[i]; + + std::string os = out2->debugName(); + + if (!os.empty() && value_link_input_map.find(os) != value_link_input_map.end()) + continue; + + auto tensor_type = out2->type()->cast(); + if (tensor_type) + { + value_link_input_map[os] = 1; + } + + for (size_t j = 0; j < out2->uses().size(); j++) + { + auto node2 = out2->uses()[j].user; + + build_value_link_input_map(node2, value_alias_map, value_link_input_map, true); + } } - for (size_t i = 0; i < v->node()->inputs().size(); i++) + if (is_inplace_op(optype) || is_alias_op(optype)) { - bool link = value_link_input(v->node()->inputs()[i], inputs, ignore_aten_size); - if (link) - return true; - } + // infect input0 and its alias + while (1) + { + auto in2 = node->inputs()[0]; - return false; + std::string is = in2->debugName(); + + if (is.empty()) + break; + + if (value_alias_map.find(is) == value_alias_map.end()) + break; + + auto in3 = value_alias_map.at(is); + + auto tensor_type = in3->type()->cast(); + if (!tensor_type) + break; + + is = in3->debugName(); + + if (value_link_input_map.find(is) != value_link_input_map.end()) + break; + + for (size_t j = 0; j < in3->uses().size(); j++) + { + auto node2 = in3->uses()[j].user; + + build_value_link_input_map(node2, value_alias_map, value_link_input_map, true); + } + + break; + } + } } static bool value_link_output(const torch::jit::Value* v, const std::vector& outputs) @@ -79,8 +138,7 @@ static bool value_link_output(const torch::jit::Value* v, const std::vectorkind().toDisplayString(); - bool is_inplace_op = op_type.size() > 2 && op_type[op_type.size() - 2] != '_' && op_type[op_type.size() - 1] == '_'; - if (is_inplace_op) + if (is_inplace_op(op_type)) { // optimize me: track other inplace op inputs return true; @@ -154,6 +212,91 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr value_link_input_map; + { + // build value alias map for inplace op + std::unordered_map value_alias_map; + + for (const auto& n : graph->block()->nodes()) + { + if (n->kind() == c10::prim::GetAttr) + continue; + + if (n->kind() == c10::prim::Constant) + continue; + + if (n->kind() == c10::prim::CallMethod) + continue; + + std::string optype = n->kind().toDisplayString(); + + // fprintf(stderr, "optype = %s\n", optype.c_str()); + + if (!is_inplace_op(optype) && !is_alias_op(optype)) + continue; + + if (n->inputs().size() == 0) + continue; + + if (n->outputs().size() == 0) + continue; + + std::string is = n->input(0)->debugName(); + + if (is.empty()) + continue; + + for (size_t i = 0; i < n->outputs().size(); i++) + { + auto out2 = n->output(i); + + auto tensor_type = out2->type()->cast(); + if (!tensor_type) + continue; + + std::string os = out2->debugName(); + + if (os.empty()) + continue; + + if (value_alias_map.find(is) == value_alias_map.end()) + { + value_alias_map[os] = n->input(0); + } + else + { + value_alias_map[os] = value_alias_map[is]; + } + } + } + + // print value_alias_map + // for (const auto& x : value_alias_map) + // { + // fprintf(stderr, "alias %s -> %s\n", x.first.c_str(), x.second->debugName().c_str()); + // } + + bool ignore_aten_size = input_tensors2.empty(); + for (size_t i = 1; i < graph->inputs().size(); i++) + { + auto in0 = graph->inputs()[i]; + + for (size_t j = 0; j < in0->uses().size(); j++) + { + auto node = in0->uses()[j].user; + + build_value_link_input_map(node, value_alias_map, value_link_input_map, ignore_aten_size); + } + } + + // print value_link_input_map + // for (const auto& x : value_link_input_map) + // { + // fprintf(stderr, "link_input %s %d\n", x.first.c_str(), x.second); + // } + } + StoreZipWriter zip; zip.open(foldable_constants_zippath); @@ -246,7 +389,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptrsetType(c10::TensorType::create(t)); // check if value that does not depend on inputs - if (!value_link_input(v, g_inputs, true) && value_link_output(v, g_outputs)) + if (value_link_input_map.find(v->debugName()) == value_link_input_map.end() && value_link_output(v, g_outputs)) { // fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str()); foldable_constants.insert(v->debugName()); @@ -288,7 +431,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptrsetType(finaltype); // check if value that does not depend on inputs - if (!value_link_input(v, g_inputs, false) && value_link_output(v, g_outputs)) + if (value_link_input_map.find(v->debugName()) == value_link_input_map.end() && value_link_output(v, g_outputs)) { // fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str()); foldable_constants.insert(v->debugName());