| @@ -25,39 +25,98 @@ | |||
| namespace pnnx { | |||
| static bool value_link_input(const torch::jit::Value* v, const std::vector<torch::jit::Value*>& inputs, bool ignore_aten_size) | |||
| static bool is_inplace_op(const std::string& optype) | |||
| { | |||
| if (ignore_aten_size) | |||
| return optype.size() > 2 && optype[optype.size() - 2] != '_' && optype[optype.size() - 1] == '_'; | |||
| } | |||
| static bool is_alias_op(const std::string& optype) | |||
| { | |||
| return optype == "aten::slice" || optype == "aten::select" || optype == "aten::view"; | |||
| } | |||
| static bool is_static_shape_foldable(const std::string& optype) | |||
| { | |||
| return optype == "aten::size" | |||
| || optype == "aten::new_empty" | |||
| || optype == "aten::new_full" | |||
| || optype == "aten::new_ones" | |||
| || optype == "aten::new_zeros" | |||
| || optype == "aten::empty_like" | |||
| || optype == "aten::full_like" | |||
| || optype == "aten::ones_like" | |||
| || optype == "aten::zeros_like" | |||
| || optype == "aten::_shape_as_tensor"; | |||
| } | |||
| static void build_value_link_input_map(const torch::jit::Node* node, const std::unordered_map<std::string, torch::jit::Value*>& value_alias_map, std::unordered_map<std::string, int>& value_link_input_map, bool ignore_aten_size) | |||
| { | |||
| std::string optype = node->kind().toDisplayString(); | |||
| if (ignore_aten_size && is_static_shape_foldable(optype)) | |||
| { | |||
| // any intermediate shape is constant with static input shape | |||
| std::string optype = v->node()->kind().toDisplayString(); | |||
| if (optype == "aten::size" | |||
| || optype == "aten::new_empty" | |||
| || optype == "aten::new_full" | |||
| || optype == "aten::new_ones" | |||
| || optype == "aten::new_zeros" | |||
| || optype == "aten::empty_like" | |||
| || optype == "aten::full_like" | |||
| || optype == "aten::ones_like" | |||
| || optype == "aten::zeros_like" | |||
| || optype == "aten::_shape_as_tensor") | |||
| return false; | |||
| return; | |||
| } | |||
| for (auto x : inputs) | |||
| for (size_t i = 0; i < node->outputs().size(); i++) | |||
| { | |||
| if (v == x) | |||
| return true; | |||
| auto out2 = node->outputs()[i]; | |||
| std::string os = out2->debugName(); | |||
| if (!os.empty() && value_link_input_map.find(os) != value_link_input_map.end()) | |||
| continue; | |||
| auto tensor_type = out2->type()->cast<torch::jit::TensorType>(); | |||
| if (tensor_type) | |||
| { | |||
| value_link_input_map[os] = 1; | |||
| } | |||
| for (size_t j = 0; j < out2->uses().size(); j++) | |||
| { | |||
| auto node2 = out2->uses()[j].user; | |||
| build_value_link_input_map(node2, value_alias_map, value_link_input_map, true); | |||
| } | |||
| } | |||
| for (size_t i = 0; i < v->node()->inputs().size(); i++) | |||
| if (is_inplace_op(optype) || is_alias_op(optype)) | |||
| { | |||
| bool link = value_link_input(v->node()->inputs()[i], inputs, ignore_aten_size); | |||
| if (link) | |||
| return true; | |||
| } | |||
| // infect input0 and its alias | |||
| while (1) | |||
| { | |||
| auto in2 = node->inputs()[0]; | |||
| return false; | |||
| std::string is = in2->debugName(); | |||
| if (is.empty()) | |||
| break; | |||
| if (value_alias_map.find(is) == value_alias_map.end()) | |||
| break; | |||
| auto in3 = value_alias_map.at(is); | |||
| auto tensor_type = in3->type()->cast<torch::jit::TensorType>(); | |||
| if (!tensor_type) | |||
| break; | |||
| is = in3->debugName(); | |||
| if (value_link_input_map.find(is) != value_link_input_map.end()) | |||
| break; | |||
| for (size_t j = 0; j < in3->uses().size(); j++) | |||
| { | |||
| auto node2 = in3->uses()[j].user; | |||
| build_value_link_input_map(node2, value_alias_map, value_link_input_map, true); | |||
| } | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| static bool value_link_output(const torch::jit::Value* v, const std::vector<torch::jit::Value*>& outputs) | |||
| @@ -79,8 +138,7 @@ static bool value_link_output(const torch::jit::Value* v, const std::vector<torc | |||
| } | |||
| std::string op_type = node->kind().toDisplayString(); | |||
| bool is_inplace_op = op_type.size() > 2 && op_type[op_type.size() - 2] != '_' && op_type[op_type.size() - 1] == '_'; | |||
| if (is_inplace_op) | |||
| if (is_inplace_op(op_type)) | |||
| { | |||
| // optimize me: track other inplace op inputs | |||
| return true; | |||
| @@ -154,6 +212,91 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit:: | |||
| inputs2.push_back(it); | |||
| } | |||
| // bookkeep foldable tensors | |||
| std::unordered_map<std::string, int> value_link_input_map; | |||
| { | |||
| // build value alias map for inplace op | |||
| std::unordered_map<std::string, torch::jit::Value*> value_alias_map; | |||
| for (const auto& n : graph->block()->nodes()) | |||
| { | |||
| if (n->kind() == c10::prim::GetAttr) | |||
| continue; | |||
| if (n->kind() == c10::prim::Constant) | |||
| continue; | |||
| if (n->kind() == c10::prim::CallMethod) | |||
| continue; | |||
| std::string optype = n->kind().toDisplayString(); | |||
| // fprintf(stderr, "optype = %s\n", optype.c_str()); | |||
| if (!is_inplace_op(optype) && !is_alias_op(optype)) | |||
| continue; | |||
| if (n->inputs().size() == 0) | |||
| continue; | |||
| if (n->outputs().size() == 0) | |||
| continue; | |||
| std::string is = n->input(0)->debugName(); | |||
| if (is.empty()) | |||
| continue; | |||
| for (size_t i = 0; i < n->outputs().size(); i++) | |||
| { | |||
| auto out2 = n->output(i); | |||
| auto tensor_type = out2->type()->cast<torch::jit::TensorType>(); | |||
| if (!tensor_type) | |||
| continue; | |||
| std::string os = out2->debugName(); | |||
| if (os.empty()) | |||
| continue; | |||
| if (value_alias_map.find(is) == value_alias_map.end()) | |||
| { | |||
| value_alias_map[os] = n->input(0); | |||
| } | |||
| else | |||
| { | |||
| value_alias_map[os] = value_alias_map[is]; | |||
| } | |||
| } | |||
| } | |||
| // print value_alias_map | |||
| // for (const auto& x : value_alias_map) | |||
| // { | |||
| // fprintf(stderr, "alias %s -> %s\n", x.first.c_str(), x.second->debugName().c_str()); | |||
| // } | |||
| bool ignore_aten_size = input_tensors2.empty(); | |||
| for (size_t i = 1; i < graph->inputs().size(); i++) | |||
| { | |||
| auto in0 = graph->inputs()[i]; | |||
| for (size_t j = 0; j < in0->uses().size(); j++) | |||
| { | |||
| auto node = in0->uses()[j].user; | |||
| build_value_link_input_map(node, value_alias_map, value_link_input_map, ignore_aten_size); | |||
| } | |||
| } | |||
| // print value_link_input_map | |||
| // for (const auto& x : value_link_input_map) | |||
| // { | |||
| // fprintf(stderr, "link_input %s %d\n", x.first.c_str(), x.second); | |||
| // } | |||
| } | |||
| StoreZipWriter zip; | |||
| zip.open(foldable_constants_zippath); | |||
| @@ -246,7 +389,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit:: | |||
| v->setType(c10::TensorType::create(t)); | |||
| // check if value that does not depend on inputs | |||
| if (!value_link_input(v, g_inputs, true) && value_link_output(v, g_outputs)) | |||
| if (value_link_input_map.find(v->debugName()) == value_link_input_map.end() && value_link_output(v, g_outputs)) | |||
| { | |||
| // fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str()); | |||
| foldable_constants.insert(v->debugName()); | |||
| @@ -288,7 +431,7 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit:: | |||
| v->setType(finaltype); | |||
| // check if value that does not depend on inputs | |||
| if (!value_link_input(v, g_inputs, false) && value_link_output(v, g_outputs)) | |||
| if (value_link_input_map.find(v->debugName()) == value_link_input_map.end() && value_link_output(v, g_outputs)) | |||
| { | |||
| // fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str()); | |||
| foldable_constants.insert(v->debugName()); | |||