| @@ -59,7 +59,7 @@ set(pnnx_pass_level1_SRCS | |||
| pass_level1/nn_MaxPool1d.cpp | |||
| pass_level1/nn_MaxPool2d.cpp | |||
| pass_level1/nn_MaxPool3d.cpp | |||
| pass_level1/nn_maxunpool2d.cpp | |||
| #pass_level1/nn_maxunpool2d.cpp | |||
| pass_level1/nn_Mish.cpp | |||
| pass_level1/nn_MultiheadAttention.cpp | |||
| pass_level1/nn_PixelShuffle.cpp | |||
| @@ -186,6 +186,7 @@ set(pnnx_pass_level2_SRCS | |||
| pass_level2/torch_clone.cpp | |||
| pass_level2/torch_dequantize.cpp | |||
| pass_level2/torch_flatten.cpp | |||
| pass_level2/torch_flip.cpp | |||
| pass_level2/torch_logsumexp.cpp | |||
| pass_level2/torch_matmul.cpp | |||
| pass_level2/torch_mean.cpp | |||
| @@ -193,6 +194,7 @@ set(pnnx_pass_level2_SRCS | |||
| pass_level2/torch_normal.cpp | |||
| pass_level2/torch_prod.cpp | |||
| pass_level2/torch_quantize_per_tensor.cpp | |||
| pass_level2/torch_randn.cpp | |||
| pass_level2/torch_roll.cpp | |||
| pass_level2/torch_split.cpp | |||
| pass_level2/torch_squeeze.cpp | |||
| @@ -200,6 +202,7 @@ set(pnnx_pass_level2_SRCS | |||
| pass_level2/torch_sum.cpp | |||
| pass_level2/torch_permute.cpp | |||
| pass_level2/torch_transpose.cpp | |||
| pass_level2/torch_unbind.cpp | |||
| pass_level2/torch_unsqueeze.cpp | |||
| pass_level2/torch_var.cpp | |||
| pass_level2/torch_zeros.cpp | |||
| @@ -210,11 +213,11 @@ set(pnnx_pass_level2_SRCS | |||
| set(pnnx_pass_level3_SRCS | |||
| pass_level3/assign_unique_name.cpp | |||
| pass_level3/eliminate_noop_math.cpp | |||
| pass_level3/eliminate_tuple_pair.cpp | |||
| pass_level3/expand_quantization_modules.cpp | |||
| pass_level3/fuse_attribute_expression.cpp | |||
| pass_level3/fuse_cat_stack_tensors.cpp | |||
| pass_level3/fuse_chunk_split_unpack.cpp | |||
| pass_level3/fuse_chunk_split_unbind_unpack.cpp | |||
| pass_level3/fuse_expression.cpp | |||
| pass_level3/fuse_index_expression.cpp | |||
| pass_level3/fuse_rnn_unpack.cpp | |||
| @@ -235,6 +238,7 @@ set(pnnx_pass_level5_SRCS | |||
| pass_level5/eliminate_slice.cpp | |||
| pass_level5/eliminate_view_reshape.cpp | |||
| pass_level5/eval_expression.cpp | |||
| pass_level5/fold_constants.cpp | |||
| pass_level5/fuse_channel_shuffle.cpp | |||
| pass_level5/fuse_constant_expression.cpp | |||
| pass_level5/fuse_conv1d_batchnorm1d.cpp | |||
| @@ -1343,7 +1343,14 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) | |||
| fprintf(pyfp, ","); | |||
| } | |||
| fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); | |||
| if (attr.type == 1 || attr.type == 2 || attr.type == 3) | |||
| { | |||
| fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); | |||
| } | |||
| else | |||
| { | |||
| fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); | |||
| } | |||
| } | |||
| } | |||
| @@ -1373,11 +1380,11 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) | |||
| if (is_running_mean_var) | |||
| { | |||
| fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), key.c_str(), sanitize_identifier(op->name).c_str(), key.c_str()); | |||
| fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_tensor(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str()); | |||
| } | |||
| else | |||
| { | |||
| fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), key.c_str(), sanitize_identifier(op->name).c_str(), key.c_str()); | |||
| fprintf(pyfp, " self.%s_%s = self.load_pnnx_bin_as_parameter(archive, '%s.%s', (", sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str(), op->name.c_str(), key.c_str()); | |||
| } | |||
| for (size_t i = 0; i < attr.shape.size(); i++) | |||
| @@ -1387,7 +1394,14 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) | |||
| fprintf(pyfp, ","); | |||
| } | |||
| fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); | |||
| if (attr.type == 1 || attr.type == 2 || attr.type == 3) | |||
| { | |||
| fprintf(pyfp, "), '%s')\n", type_to_numpy_string(attr.type)); | |||
| } | |||
| else | |||
| { | |||
| fprintf(pyfp, "), '%s', requires_grad=False)\n", type_to_numpy_string(attr.type)); | |||
| } | |||
| } | |||
| fprintf(pyfp, " archive.close()\n"); | |||
| @@ -1452,7 +1466,7 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) | |||
| else if (op->type == "pnnx.Attribute") | |||
| { | |||
| const std::string& key = op->attrs.begin()->first; | |||
| fprintf(pyfp, "v_%s = self.%s_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->name).c_str(), key.c_str()); | |||
| fprintf(pyfp, "v_%s = self.%s_%s\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->name).c_str(), sanitize_identifier(key).c_str()); | |||
| } | |||
| else if (op->type == "Tensor.slice") | |||
| { | |||
| @@ -1463,8 +1477,16 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) | |||
| else if (op->type == "Tensor.index") | |||
| { | |||
| // index expr | |||
| std::string index_expr = make_index_expression(op); | |||
| fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str()); | |||
| if (op->inputs.size() == 2) | |||
| { | |||
| std::string expanded_expr = expand_expression(op->inputs[1]->producer); | |||
| fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), expanded_expr.c_str()); | |||
| } | |||
| else | |||
| { | |||
| std::string index_expr = make_index_expression(op); | |||
| fprintf(pyfp, "v_%s = v_%s[%s]\n", sanitize_identifier(op->outputs[0]->name).c_str(), sanitize_identifier(op->inputs[0]->name).c_str(), index_expr.c_str()); | |||
| } | |||
| } | |||
| else if (op->type == "Tensor.view" || op->type == "Tensor.reshape") | |||
| { | |||
| @@ -279,9 +279,6 @@ int main(int argc, char** argv) | |||
| fprintf(stderr, "\n"); | |||
| } | |||
| // at::AutoNonVariableTypeMode nonVarTypeModeGuard(true); | |||
| // torch::autograd::AutoGradMode guard(false); | |||
| for (auto m : customop_modules) | |||
| { | |||
| fprintf(stderr, "load custom module %s\n", m.c_str()); | |||
| @@ -339,7 +336,8 @@ int main(int argc, char** argv) | |||
| fprintf(stderr, "############# pass_level0\n"); | |||
| pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators); | |||
| std::map<std::string, pnnx::Attribute> foldable_constants; | |||
| pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators, ptpath, foldable_constants); | |||
| // g->dump(); | |||
| @@ -373,7 +371,7 @@ int main(int argc, char** argv) | |||
| { | |||
| fprintf(stderr, "############# pass_level5\n"); | |||
| pnnx::pass_level5(pnnx_graph); | |||
| pnnx::pass_level5(pnnx_graph, foldable_constants); | |||
| } | |||
| pnnx_graph.save(pnnxparampath, pnnxbinpath); | |||
| @@ -20,7 +20,7 @@ | |||
| namespace pnnx { | |||
| void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators) | |||
| void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, std::map<std::string, Attribute>& foldable_constants) | |||
| { | |||
| inline_block(g, module_operators); | |||
| @@ -28,7 +28,7 @@ void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Grap | |||
| if (!input_tensors.empty()) | |||
| { | |||
| shape_inference(mod, g, input_tensors, input_tensors2); | |||
| shape_inference(mod, g, input_tensors, input_tensors2, module_operators, ptpath, foldable_constants); | |||
| } | |||
| } | |||
| @@ -16,10 +16,11 @@ | |||
| #define PNNX_PASS_LEVEL0_H | |||
| #include <torch/script.h> | |||
| #include "ir.h" | |||
| namespace pnnx { | |||
| void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators); | |||
| void pass_level0(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& g, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, std::map<std::string, Attribute>& foldable_constants); | |||
| } // namespace pnnx | |||
| @@ -13,74 +13,233 @@ | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "shape_inference.h" | |||
| #include <unordered_set> | |||
| #include "pass_level0/constant_unpooling.h" | |||
| #include "pass_level0/inline_block.h" | |||
| #include "pass_level0/shape_inference.h" | |||
| namespace pnnx { | |||
| void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2) | |||
// Returns true when `v` is itself one of `inputs`, or when any input of the node that
// produces `v` (followed recursively through v->node()->inputs()) is one of `inputs`.
// Returns false when no such path is found.
// NOTE(review): the recursion keeps no visited set, so a producer reachable through
// several paths is walked once per path.
| static bool value_link_input(const torch::jit::Value* v, const std::vector<torch::jit::Value*>& inputs) | |||
| { | |||
// NOTE(review): the next three lines reference `graph`, `n` and `values`, none of which
// are declared in this function; they look like removed-side diff residue from the
// previous shape_inference body - confirm and drop when restoring the real source.
| // collect all intermediate output tensors | |||
| std::vector<torch::jit::Value*> values; | |||
| for (const auto& n : graph->nodes()) | |||
// Direct hit: `v` is one of the candidate values.
| for (auto x : inputs) | |||
| { | |||
// NOTE(review): the next five lines (loop over n->outputs()) also look like diff residue - confirm.
| for (const auto& on : n->outputs()) | |||
| { | |||
| auto tensor_type = on->type()->cast<torch::jit::TensorType>(); | |||
| if (!tensor_type) | |||
| continue; | |||
| if (v == x) | |||
| return true; | |||
| } | |||
// Otherwise walk upstream: ask the same question of every input of the producing node.
| for (size_t i = 0; i < v->node()->inputs().size(); i++) | |||
| { | |||
| bool link = value_link_input(v->node()->inputs()[i], inputs); | |||
| if (link) | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
// Returns true when `v` is itself one of `outputs`, or when any output of any node that
// uses `v` (followed recursively through v->uses()) is one of `outputs`.
// Returns false when no such path is found.
// NOTE(review): the recursion keeps no visited set, so a consumer reachable through
// several paths is walked once per path.
| static bool value_link_output(const torch::jit::Value* v, const std::vector<torch::jit::Value*>& outputs) | |||
| { | |||
// Direct hit: `v` is one of the candidate values.
| for (auto x : outputs) | |||
| { | |||
| if (v == x) | |||
| return true; | |||
| } | |||
// NOTE(review): `values` and `on` are not declared in this function; the next line looks
// like removed-side diff residue from the previous shape_inference body - confirm.
| values.push_back(on); | |||
// Otherwise walk downstream: for every user node of `v`, test each of that node's outputs.
| for (size_t i = 0; i < v->uses().size(); i++) | |||
| { | |||
| auto node = v->uses()[i].user; | |||
| for (auto x : node->outputs()) | |||
| { | |||
| bool link = value_link_output(x, outputs); | |||
| if (link) | |||
| return true; | |||
| } | |||
| } | |||
// NOTE(review): `graph` is not declared in this function; the next two lines look like
// removed-side diff residue as well - confirm.
| // set new graph output | |||
| auto old_output = graph->outputs()[0]; | |||
| return false; | |||
| } | |||
| torch::jit::Node* new_return_node = graph->createTuple(at::ArrayRef<torch::jit::Value*>(values)); | |||
| void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, std::map<std::string, Attribute>& foldable_constants) | |||
| { | |||
| // collect all intermediate output tensors | |||
| std::vector<std::unordered_set<std::string> > more_value_names; | |||
| std::vector<std::vector<torch::jit::Value*> > more_values; | |||
| { | |||
| std::unordered_set<std::string> value_names; | |||
| std::vector<torch::jit::Value*> values; | |||
| for (const auto& n : graph->nodes()) | |||
| { | |||
| for (const auto& v : n->outputs()) | |||
| { | |||
| auto tensor_type = v->type()->cast<torch::jit::TensorType>(); | |||
| if (!tensor_type) | |||
| continue; | |||
| graph->appendNode(new_return_node); | |||
| value_names.insert(v->debugName()); | |||
| values.push_back(v); | |||
| } | |||
| graph->eraseOutput(0); | |||
| graph->registerOutput(new_return_node->outputs()[0]); | |||
| // too many intermediate blobs in one inference results oom | |||
| if (value_names.size() >= 1000) | |||
| { | |||
| more_value_names.push_back(value_names); | |||
| value_names.clear(); | |||
| more_values.push_back(values); | |||
| values.clear(); | |||
| } | |||
| } | |||
| if (value_names.size() > 0) | |||
| { | |||
| more_value_names.push_back(value_names); | |||
| more_values.push_back(values); | |||
| } | |||
| } | |||
| // collect graph inputs outputs | |||
| std::vector<torch::jit::Value*> g_inputs; | |||
| for (size_t i = 1; i < graph->inputs().size(); i++) | |||
| { | |||
| g_inputs.push_back(graph->inputs()[i]); | |||
| } | |||
| std::vector<torch::jit::Value*> g_outputs; | |||
| for (size_t i = 0; i < graph->outputs().size(); i++) | |||
| { | |||
| g_outputs.push_back(graph->outputs()[i]); | |||
| } | |||
| // inference for all tensors | |||
| std::vector<torch::jit::IValue> inputs; | |||
| for (size_t i = 0; i < input_tensors.size(); i++) | |||
| { | |||
| const at::Tensor& it = input_tensors[i]; | |||
| inputs.push_back(it); | |||
| graph->inputs()[1 + i]->setType(c10::TensorType::create(it)); | |||
| } | |||
| auto outputs = mod.copy().forward(inputs).toTuple(); | |||
| std::vector<torch::jit::IValue> inputs2; | |||
| for (size_t i = 0; i < input_tensors2.size(); i++) | |||
| { | |||
| const at::Tensor& it = input_tensors2[i]; | |||
| inputs2.push_back(it); | |||
| } | |||
| if (input_tensors2.empty()) | |||
| std::map<torch::jit::Value*, at::Tensor> output_tensors; | |||
| for (size_t p = 0; p < more_value_names.size(); p++) | |||
| { | |||
| // assign shape info | |||
| int index = 0; | |||
| for (auto e : outputs->elements()) | |||
| std::unordered_set<std::string>& value_names = more_value_names[p]; | |||
| std::vector<torch::jit::Value*>& values = more_values[p]; | |||
| // auto mod2 = mod.deepcopy(); | |||
| torch::jit::Module mod2 = torch::jit::load(ptpath); | |||
| mod2.eval(); | |||
| auto graph2 = mod2.get_method("forward").graph(); | |||
| inline_block(graph2, module_operators); | |||
| constant_unpooling(graph2); | |||
| std::vector<torch::jit::Value*> values2; | |||
| for (auto n : graph2->nodes()) | |||
| { | |||
| values[index]->setType(c10::TensorType::create(e.toTensor())); | |||
| for (const auto& v : n->outputs()) | |||
| { | |||
| auto tensor_type = v->type()->cast<torch::jit::TensorType>(); | |||
| if (!tensor_type) | |||
| continue; | |||
| index++; | |||
| if (value_names.find(v->debugName()) != value_names.end()) | |||
| { | |||
| values2.push_back(v); | |||
| fprintf(stderr, "%s ", v->debugName().c_str()); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| else | |||
| { | |||
| std::vector<torch::jit::IValue> inputs2; | |||
| for (size_t i = 0; i < input_tensors2.size(); i++) | |||
| fprintf(stderr, "\n----------------\n\n"); | |||
| // set new graph output | |||
| torch::jit::Node* new_return_node = graph2->createTuple(at::ArrayRef<torch::jit::Value*>(values2)); | |||
| graph2->appendNode(new_return_node); | |||
| graph2->eraseOutput(0); | |||
| graph2->registerOutput(new_return_node->outputs()[0]); | |||
| // inference for all tensors | |||
| auto outputs = mod2.copy().forward(inputs).toTuple(); | |||
| if (input_tensors2.empty()) | |||
| { | |||
| const at::Tensor& it = input_tensors2[i]; | |||
| // assign shape info | |||
| for (size_t i = 0; i < values2.size(); i++) | |||
| { | |||
| auto v = values[i]; | |||
| auto t = outputs->elements()[i].toTensor(); | |||
| v->setType(c10::TensorType::create(t)); | |||
| inputs2.push_back(it); | |||
| graph->inputs()[1 + i]->setType(c10::TensorType::create(it)); | |||
| // check if value that does not depend on inputs | |||
| if (!value_link_input(v, g_inputs) && value_link_output(v, g_outputs)) | |||
| { | |||
| output_tensors[v] = t; | |||
| } | |||
| } | |||
| } | |||
| else | |||
| { | |||
| // assign dynamic shape info | |||
| auto outputs2 = mod2.copy().forward(inputs2).toTuple(); | |||
| fprintf(stderr, "assign dynamic shape info\n"); | |||
| auto outputs2 = mod.copy().forward(inputs2).toTuple(); | |||
| for (size_t i = 0; i < values2.size(); i++) | |||
| { | |||
| auto v = values[i]; | |||
| auto t = outputs->elements()[i].toTensor(); | |||
| auto t2 = outputs2->elements()[i].toTensor(); | |||
| fprintf(stderr, "assign dynamic shape info\n"); | |||
| auto type1 = c10::TensorType::create(t); | |||
| auto type2 = c10::TensorType::create(t2); | |||
| std::vector<c10::ShapeSymbol> sizes1 = type1->symbolic_sizes().sizes().value(); | |||
| std::vector<c10::ShapeSymbol> sizes2 = type2->symbolic_sizes().sizes().value(); | |||
| for (size_t i = 0; i < sizes1.size(); i++) | |||
| { | |||
| if (sizes1[i] == sizes2[i]) | |||
| continue; | |||
| sizes1[i] = c10::ShapeSymbol::fromStaticSize(-1); | |||
| } | |||
| auto finaltype = type1->withSymbolicShapes(c10::SymbolicShape(sizes1)); | |||
| v->setType(finaltype); | |||
| // check if value that does not depend on inputs | |||
| if (!value_link_input(v, g_inputs) && value_link_output(v, g_outputs)) | |||
| { | |||
| output_tensors[v] = t; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| if (input_tensors2.empty()) | |||
| { | |||
| for (size_t i = 0; i < input_tensors.size(); i++) | |||
| { | |||
| auto type = c10::TensorType::create(input_tensors[i]); | |||
| // assign dynamic shape info | |||
| graph->inputs()[1 + i]->setType(type); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| for (size_t i = 0; i < input_tensors.size(); i++) | |||
| { | |||
| auto type1 = c10::TensorType::create(input_tensors[i]); | |||
| @@ -101,38 +260,34 @@ void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit:: | |||
| graph->inputs()[1 + i]->setType(finaltype); | |||
| } | |||
| } | |||
| int index = 0; | |||
| for (auto e : outputs->elements()) | |||
| { | |||
| auto type1 = c10::TensorType::create(e.toTensor()); | |||
| auto type2 = c10::TensorType::create(outputs2->elements()[index].toTensor()); | |||
| std::vector<c10::ShapeSymbol> sizes1 = type1->symbolic_sizes().sizes().value(); | |||
| std::vector<c10::ShapeSymbol> sizes2 = type2->symbolic_sizes().sizes().value(); | |||
| for (auto xx : output_tensors) | |||
| { | |||
| auto v = xx.first; | |||
| auto tensor = xx.second; | |||
| for (size_t i = 0; i < sizes1.size(); i++) | |||
| bool link_to_output = false; | |||
| for (size_t i = 0; i < v->uses().size(); i++) | |||
| { | |||
| auto node = v->uses()[i].user; | |||
| for (auto x : node->outputs()) | |||
| { | |||
| if (sizes1[i] == sizes2[i]) | |||
| continue; | |||
| sizes1[i] = c10::ShapeSymbol::fromStaticSize(-1); | |||
| if (output_tensors.find(x) == output_tensors.end()) | |||
| { | |||
| link_to_output = true; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| auto finaltype = type1->withSymbolicShapes(c10::SymbolicShape(sizes1)); | |||
| values[index]->setType(finaltype); | |||
| index++; | |||
| const int ndim = (int)tensor.dim(); | |||
| if (link_to_output && ndim > 0) | |||
| { | |||
| fprintf(stderr, "foldable_constant %s\n", v->debugName().c_str()); | |||
| foldable_constants[v->debugName()] = Attribute(tensor); | |||
| } | |||
| } | |||
| // restore old graph output | |||
| graph->eraseOutput(0); | |||
| graph->registerOutput(old_output); | |||
| new_return_node->removeAllInputs(); | |||
| new_return_node->destroy(); | |||
| } | |||
| } // namespace pnnx | |||
| @@ -13,9 +13,11 @@ | |||
| // specific language governing permissions and limitations under the License. | |||
| #include <torch/script.h> | |||
| #include <map> | |||
| #include "ir.h" | |||
| namespace pnnx { | |||
| void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2); | |||
| void shape_inference(const torch::jit::Module& mod, std::shared_ptr<torch::jit::Graph>& graph, const std::vector<at::Tensor>& input_tensors, const std::vector<at::Tensor>& input_tensors2, const std::vector<std::string>& module_operators, const std::string& ptpath, std::map<std::string, Attribute>& foldable_constants); | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,41 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
// Graph rewriter pass for aten::flip.
// The match pattern below describes one aten::flip operator taking two operands
// (`input`, `dims`) and producing one operand (`out`); type_str() names the result "torch.flip".
// NOTE(review): presumably GraphRewriterPass replaces the matched subgraph with a single
// operator of type type_str() - confirm against pass_level2.h.
| class torch_flip : public GraphRewriterPass | |||
| { | |||
| public: | |||
// Pattern text in PNNXIR form. NOTE(review): the leading "7767517" is presumably the
// param-file magic number and "4 3" the operator/operand counts (the pattern lists four
// operators and three operands) - confirm. Column spacing appears collapsed in this view.
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 dims | |||
| aten::flip op_0 2 1 input dims out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
// Type string associated with this pass.
| const char* type_str() const | |||
| { | |||
| return "torch.flip"; | |||
| } | |||
| }; | |||
// Registers the pass globally. NOTE(review): the second argument (20) is presumably a
// priority/ordering value - confirm against the macro definition.
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_flip, 20) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,44 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
// Graph rewriter pass for aten::randn.
// The match pattern below describes one aten::randn operator fed by a `size` input and four
// prim::Constant operands (`dtype`, `layout`, `device`, `requires_grad`, each matched with
// value=*), producing `out`; type_str() names the result "torch.randn".
// NOTE(review): presumably GraphRewriterPass replaces the matched subgraph with a single
// operator of type type_str() - confirm against pass_level2.h.
// NOTE(review): the pattern fixes exactly five aten::randn arguments; whether every supported
// TorchScript version emits that arity is not visible here - verify.
| class torch_randn : public GraphRewriterPass | |||
| { | |||
| public: | |||
// Pattern text in PNNXIR form. NOTE(review): "7 6" presumably gives the operator/operand
// counts (the pattern lists seven operators and six operands) - confirm. Column spacing
// appears collapsed in this view.
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 7 6 | |||
| pnnx.Input input_0 0 1 size | |||
| prim::Constant op_0 0 1 dtype value=* | |||
| prim::Constant op_1 0 1 layout value=* | |||
| prim::Constant op_2 0 1 device value=* | |||
| prim::Constant op_3 0 1 requires_grad value=* | |||
| aten::randn op_4 5 1 size dtype layout device requires_grad out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
// Type string associated with this pass.
| const char* type_str() const | |||
| { | |||
| return "torch.randn"; | |||
| } | |||
| }; | |||
// Registers the pass globally. NOTE(review): the second argument (20) is presumably a
// priority/ordering value - confirm against the macro definition.
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_randn, 20) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,41 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
// Graph rewriter pass for aten::unbind.
// The match pattern below describes one aten::unbind operator taking two operands
// (`input`, `dim`) and producing one operand (`out`); type_str() names the result "torch.unbind".
// NOTE(review): presumably GraphRewriterPass replaces the matched subgraph with a single
// operator of type type_str() - confirm against pass_level2.h.
| class torch_unbind : public GraphRewriterPass | |||
| { | |||
| public: | |||
// Pattern text in PNNXIR form. NOTE(review): "4 3" presumably gives the operator/operand
// counts (the pattern lists four operators and three operands) - confirm. Column spacing
// appears collapsed in this view.
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 4 3 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 dim | |||
| aten::unbind op_0 2 1 input dim out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
// Type string associated with this pass.
| const char* type_str() const | |||
| { | |||
| return "torch.unbind"; | |||
| } | |||
| }; | |||
// Registers the pass globally. NOTE(review): the second argument (20) is presumably a
// priority/ordering value - confirm against the macro definition.
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_unbind, 20) | |||
| } // namespace pnnx | |||
| @@ -15,11 +15,11 @@ | |||
| #include "pass_level3.h" | |||
| #include "pass_level3/assign_unique_name.h" | |||
| #include "pass_level3/eliminate_noop_math.h" | |||
| #include "pass_level3/eliminate_tuple_pair.h" | |||
| #include "pass_level3/expand_quantization_modules.h" | |||
| #include "pass_level3/fuse_attribute_expression.h" | |||
| #include "pass_level3/fuse_cat_stack_tensors.h" | |||
| #include "pass_level3/fuse_chunk_split_unpack.h" | |||
| #include "pass_level3/fuse_chunk_split_unbind_unpack.h" | |||
| #include "pass_level3/fuse_expression.h" | |||
| #include "pass_level3/fuse_index_expression.h" | |||
| #include "pass_level3/fuse_rnn_unpack.h" | |||
| @@ -39,14 +39,12 @@ void pass_level3(Graph& g) | |||
| fuse_cat_stack_tensors(g); | |||
| fuse_chunk_split_unpack(g); | |||
| fuse_chunk_split_unbind_unpack(g); | |||
| fuse_rnn_unpack(g); | |||
| expand_quantization_modules(g); | |||
| fuse_attribute_expression(g); | |||
| eliminate_tuple_pair(g); | |||
| rename_F_conv_transposend(g); | |||
| @@ -55,6 +53,8 @@ void pass_level3(Graph& g) | |||
| rename_F_dropoutnd(g); | |||
| eliminate_noop_math(g); | |||
| fuse_expression(g); | |||
| fuse_index_expression(g); | |||
| @@ -45,33 +45,6 @@ void assign_unique_name(Graph& graph) | |||
| } | |||
| } | |||
| } | |||
| // assign unique name for all operands | |||
| { | |||
| std::unordered_set<std::string> names; | |||
| int make_unique_index = 0; | |||
| for (size_t i = 0; i < graph.operands.size(); i++) | |||
| { | |||
| Operand* operand = graph.operands[i]; | |||
| const std::string& name = operand->name; | |||
| if (names.find(name) == names.end()) | |||
| { | |||
| names.insert(name); | |||
| } | |||
| else | |||
| { | |||
| // duplicated found | |||
| std::string new_name = std::string("pnnx_unique_") + std::to_string(make_unique_index); | |||
| fprintf(stderr, "assign unique operand name %s to %s\n", new_name.c_str(), name.c_str()); | |||
| operand->name = new_name; | |||
| names.insert(new_name); | |||
| make_unique_index++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,297 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "eliminate_noop_math.h" | |||
| #include <algorithm> | |||
| #include "pass_level2.h" | |||
| #include "pass_level4/dead_code_elimination.h" | |||
| namespace pnnx { | |||
| static bool constant_is_all_constant(const Operator* op_constant, float vf, int vi) | |||
| { | |||
| const Parameter& param = op_constant->params.at("value"); | |||
| if (param.type == 2) | |||
| { | |||
| if (param.i != vi) | |||
| return false; | |||
| } | |||
| else if (param.type == 3) | |||
| { | |||
| if (param.f != vf) | |||
| return false; | |||
| } | |||
| else | |||
| { | |||
| // unsupported data type | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| static bool attribute_is_all_constant(const Operator* op_attr, float vf, int vi) | |||
| { | |||
| const Attribute& attr = op_attr->attrs.begin()->second; | |||
| if (attr.shape.empty()) | |||
| { | |||
| fprintf(stderr, "shape empty!\n"); | |||
| return false; | |||
| } | |||
| int size = attr.shape[0]; | |||
| for (size_t i = 1; i < attr.shape.size(); i++) | |||
| { | |||
| size *= attr.shape[i]; | |||
| } | |||
| if (attr.type == 1) | |||
| { | |||
| const float* p = (const float*)attr.data.data(); | |||
| for (int i = 0; i < size; i++) | |||
| { | |||
| if (p[i] != vf) | |||
| return false; | |||
| } | |||
| } | |||
| else if (attr.type == 2) | |||
| { | |||
| const double* p = (const double*)attr.data.data(); | |||
| for (int i = 0; i < size; i++) | |||
| { | |||
| if (p[i] != vf) | |||
| return false; | |||
| } | |||
| } | |||
| else if (attr.type == 4) | |||
| { | |||
| const int* p = (const int*)attr.data.data(); | |||
| for (int i = 0; i < size; i++) | |||
| { | |||
| if (p[i] != vi) | |||
| return false; | |||
| } | |||
| } | |||
| else if (attr.type == 5) | |||
| { | |||
| const int64_t* p = (const int64_t*)attr.data.data(); | |||
| for (int i = 0; i < size; i++) | |||
| { | |||
| if (p[i] != vi) | |||
| return false; | |||
| } | |||
| } | |||
| else if (attr.type == 7) | |||
| { | |||
| const signed char* p = (const signed char*)attr.data.data(); | |||
| for (int i = 0; i < size; i++) | |||
| { | |||
| if (p[i] != vi) | |||
| return false; | |||
| } | |||
| } | |||
| else if (attr.type == 8) | |||
| { | |||
| const unsigned char* p = (const unsigned char*)attr.data.data(); | |||
| for (int i = 0; i < size; i++) | |||
| { | |||
| if (p[i] != vi) | |||
| return false; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| // unsupported data type | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| static bool operator_is_all_constant(const Operator* op, float vf, int vi) | |||
| { | |||
| if (op->type == "pnnx.Attribute") | |||
| return attribute_is_all_constant(op, vf, vi); | |||
| if (op->type == "prim::Constant") | |||
| return constant_is_all_constant(op, vf, vi); | |||
| return false; | |||
| } | |||
| void eliminate_noop_math(Graph& graph) | |||
| { | |||
| for (;;) | |||
| { | |||
| bool need_eliminate = false; | |||
| // build expression via reverse order | |||
| for (int i = (int)graph.ops.size() - 1; i >= 0; i--) | |||
| { | |||
| Operator* op = graph.ops[i]; | |||
| int identity_input_id = 0; | |||
| if (op->type == "aten::add" || op->type == "aten::add_") | |||
| { | |||
| Operator* op0 = op->inputs[0]->producer; | |||
| Operator* op1 = op->inputs[1]->producer; | |||
| Operator* op2 = op->inputs[2]->producer; | |||
| if (operator_is_all_constant(op1, 0.f, 0)) | |||
| { | |||
| // x <= a + 0 * c | |||
| need_eliminate = true; | |||
| identity_input_id = 0; | |||
| } | |||
| else if (operator_is_all_constant(op2, 0.f, 0)) | |||
| { | |||
| // x <= a + b * 0 | |||
| need_eliminate = true; | |||
| identity_input_id = 0; | |||
| } | |||
| else if (operator_is_all_constant(op0, 0.f, 0) && operator_is_all_constant(op0, 1.f, 1)) | |||
| { | |||
| // x <= 0 + b * 1 | |||
| need_eliminate = true; | |||
| identity_input_id = 1; | |||
| } | |||
| } | |||
| if (op->type == "aten::sub") | |||
| { | |||
| Operator* op1 = op->inputs[1]->producer; | |||
| Operator* op2 = op->inputs[2]->producer; | |||
| if (operator_is_all_constant(op1, 0.f, 0)) | |||
| { | |||
| // x <= a - 0 * c | |||
| need_eliminate = true; | |||
| identity_input_id = 0; | |||
| } | |||
| else if (operator_is_all_constant(op2, 0.f, 0)) | |||
| { | |||
| // x <= a - b * 0 | |||
| need_eliminate = true; | |||
| identity_input_id = 0; | |||
| } | |||
| } | |||
| if (op->type == "aten::rsub") | |||
| { | |||
| Operator* op0 = op->inputs[0]->producer; | |||
| Operator* op1 = op->inputs[1]->producer; | |||
| Operator* op2 = op->inputs[2]->producer; | |||
| if (operator_is_all_constant(op0, 0.f, 0) && operator_is_all_constant(op2, 1.f, 1)) | |||
| { | |||
| // x <= b * 1 - 0 | |||
| need_eliminate = true; | |||
| identity_input_id = 1; | |||
| } | |||
| else if (operator_is_all_constant(op0, 0.f, 0) && operator_is_all_constant(op1, 1.f, 1)) | |||
| { | |||
| // x <= 1 * c - 0 | |||
| need_eliminate = true; | |||
| identity_input_id = 2; | |||
| } | |||
| } | |||
| if (op->type == "aten::mul") | |||
| { | |||
| Operator* op0 = op->inputs[0]->producer; | |||
| Operator* op1 = op->inputs[1]->producer; | |||
| if (operator_is_all_constant(op0, 1.f, 1)) | |||
| { | |||
| // x <= 1 * b | |||
| need_eliminate = true; | |||
| identity_input_id = 1; | |||
| } | |||
| if (operator_is_all_constant(op1, 1.f, 1)) | |||
| { | |||
| // x <= a * 1 | |||
| need_eliminate = true; | |||
| identity_input_id = 0; | |||
| } | |||
| } | |||
| if (op->type == "aten::div" || op->type == "aten::div_") | |||
| { | |||
| Operator* op1 = op->inputs[1]->producer; | |||
| if (operator_is_all_constant(op1, 1.f, 1)) | |||
| { | |||
| // x <= a / 1 | |||
| need_eliminate = true; | |||
| identity_input_id = 0; | |||
| } | |||
| } | |||
| if (op->type == "aten::pow") | |||
| { | |||
| Operator* op1 = op->inputs[1]->producer; | |||
| if (operator_is_all_constant(op1, 1.f, 1)) | |||
| { | |||
| // x <= x ^ 1 | |||
| need_eliminate = true; | |||
| identity_input_id = 0; | |||
| } | |||
| } | |||
| if (!need_eliminate) | |||
| continue; | |||
| fprintf(stderr, "eliminate_noop_math %s %s\n", op->type.c_str(), op->name.c_str()); | |||
| for (auto& x : op->inputs) | |||
| { | |||
| x->remove_consumer(op); | |||
| } | |||
| Operand* math_out = op->outputs[0]; | |||
| for (auto& x : math_out->consumers) | |||
| { | |||
| for (size_t j = 0; j < x->inputs.size(); j++) | |||
| { | |||
| if (x->inputs[j] == math_out) | |||
| x->inputs[j] = op->inputs[identity_input_id]; | |||
| } | |||
| op->inputs[identity_input_id]->consumers.push_back(x); | |||
| } | |||
| math_out->producer = 0; | |||
| math_out->consumers.clear(); | |||
| graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), math_out)); | |||
| delete math_out; | |||
| op->inputs.clear(); | |||
| op->outputs.clear(); | |||
| graph.ops.erase(graph.ops.begin() + i); | |||
| delete op; | |||
| break; | |||
| } | |||
| if (!need_eliminate) | |||
| break; | |||
| } | |||
| // dce | |||
| dead_code_elimination(graph); | |||
| } | |||
| } // namespace pnnx | |||
| @@ -1,6 +1,6 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| @@ -16,6 +16,6 @@ | |||
| namespace pnnx { | |||
| void fuse_chunk_split_unpack(Graph& graph); | |||
| void eliminate_noop_math(Graph& graph); | |||
| } // namespace pnnx | |||
| @@ -1,197 +0,0 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "fuse_attribute_expression.h" | |||
| #include <math.h> | |||
| #include <algorithm> | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| void fuse_attribute_expression(Graph& graph) | |||
| { | |||
| while (1) | |||
| { | |||
| bool matched = false; | |||
| for (size_t i = 0; i < graph.ops.size(); i++) | |||
| { | |||
| Operator* op = graph.ops[i]; | |||
| if (op->type != "pnnx.Attribute") | |||
| continue; | |||
| if (op->outputs.size() != 1) | |||
| continue; | |||
| if (op->outputs[0]->consumers.size() != 1) | |||
| continue; | |||
| Operator* op2 = op->outputs[0]->consumers[0]; | |||
| Operator* op3 = 0; | |||
| Operator* op4 = 0; | |||
| float y = 0.f; | |||
| float z = 0.f; | |||
| if (op2->type == "aten::add" || op2->type == "aten::sub") | |||
| { | |||
| if (op2->inputs[0] != op->outputs[0]) | |||
| continue; | |||
| op3 = op2->inputs[1]->producer; | |||
| if (op3->type != "prim::Constant") | |||
| continue; | |||
| if (op3->params["value"].type == 2) | |||
| { | |||
| y = op3->params["value"].i; | |||
| } | |||
| else if (op3->params["value"].type == 3) | |||
| { | |||
| y = op3->params["value"].f; | |||
| } | |||
| else | |||
| { | |||
| // not a scalar | |||
| continue; | |||
| } | |||
| op4 = op2->inputs[2]->producer; | |||
| if (op4->type != "prim::Constant") | |||
| continue; | |||
| if (op4->params["value"].type == 2) | |||
| { | |||
| z = op4->params["value"].i; | |||
| } | |||
| else if (op4->params["value"].type == 3) | |||
| { | |||
| z = op4->params["value"].f; | |||
| } | |||
| else | |||
| { | |||
| // not a scalar | |||
| continue; | |||
| } | |||
| } | |||
| else if (op2->type == "aten::mul" || op2->type == "aten::div" || op2->type == "aten::pow") | |||
| { | |||
| if (op2->inputs[0] != op->outputs[0]) | |||
| continue; | |||
| op3 = op2->inputs[1]->producer; | |||
| if (op3->type != "prim::Constant") | |||
| continue; | |||
| if (op3->params["value"].type == 2) | |||
| { | |||
| y = op3->params["value"].i; | |||
| } | |||
| else if (op3->params["value"].type == 3) | |||
| { | |||
| y = op3->params["value"].f; | |||
| } | |||
| else | |||
| { | |||
| // not a scalar | |||
| continue; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| // todo more operator type | |||
| continue; | |||
| } | |||
| matched = true; | |||
| // apply mul | |||
| { | |||
| auto it = op->attrs.begin(); | |||
| std::string attr_key = it->first; | |||
| const Attribute& attr = it->second; | |||
| float* weight = (float*)attr.data.data(); | |||
| const int weight_size = attr.data.size() / sizeof(float); | |||
| if (op2->type == "aten::add") | |||
| { | |||
| for (int i = 0; i < weight_size; i++) | |||
| weight[i] += y * z; | |||
| } | |||
| else if (op2->type == "aten::sub") | |||
| { | |||
| for (int i = 0; i < weight_size; i++) | |||
| weight[i] -= y * z; | |||
| } | |||
| else if (op2->type == "aten::mul") | |||
| { | |||
| for (int i = 0; i < weight_size; i++) | |||
| weight[i] *= y; | |||
| } | |||
| else if (op2->type == "aten::div") | |||
| { | |||
| for (int i = 0; i < weight_size; i++) | |||
| weight[i] /= y; | |||
| } | |||
| else if (op2->type == "aten::pow") | |||
| { | |||
| for (int i = 0; i < weight_size; i++) | |||
| weight[i] = (float)pow(weight[i], y); | |||
| } | |||
| op->attrs[attr_key] = attr; | |||
| } | |||
| op2->outputs[0]->producer = op; | |||
| for (auto& x : op2->inputs) | |||
| { | |||
| x->producer = 0; | |||
| x->remove_consumer(op2); | |||
| } | |||
| op->outputs = op2->outputs; | |||
| op2->inputs.clear(); | |||
| op2->outputs.clear(); | |||
| graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op2)); | |||
| delete op2; | |||
| if (op3 && op3->outputs[0]->consumers.empty()) | |||
| { | |||
| graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op3)); | |||
| delete op3; | |||
| } | |||
| if (op4 && op4->outputs[0]->consumers.empty()) | |||
| { | |||
| graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), op4)); | |||
| delete op4; | |||
| } | |||
| break; | |||
| } | |||
| if (!matched) | |||
| break; | |||
| } | |||
| } | |||
| } // namespace pnnx | |||
| @@ -12,13 +12,13 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "fuse_chunk_split_unpack.h" | |||
| #include "fuse_chunk_split_unbind_unpack.h" | |||
| #include <algorithm> | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| void fuse_chunk_split_unpack(Graph& graph) | |||
| void fuse_chunk_split_unbind_unpack(Graph& graph) | |||
| { | |||
| while (1) | |||
| { | |||
| @@ -28,7 +28,7 @@ void fuse_chunk_split_unpack(Graph& graph) | |||
| { | |||
| Operator* op = graph.ops[i]; | |||
| if (op->type != "torch.chunk" && op->type != "torch.split") | |||
| if (op->type != "torch.chunk" && op->type != "torch.split" && op->type != "torch.unbind") | |||
| continue; | |||
| if (op->outputs.size() != 1) | |||
| @@ -16,6 +16,6 @@ | |||
| namespace pnnx { | |||
| void fuse_attribute_expression(Graph& graph); | |||
| void fuse_chunk_split_unbind_unpack(Graph& graph); | |||
| } // namespace pnnx | |||
| @@ -18,13 +18,13 @@ | |||
| namespace pnnx { | |||
| static bool operand_maybe_tensor(Operand* operand) | |||
| static bool operand_maybe_tensor(const Operand* operand) | |||
| { | |||
| Operator* op = operand->producer; | |||
| const Operator* op = operand->producer; | |||
| if (op->type == "prim::Constant") | |||
| { | |||
| const Parameter& param = op->params["value"]; | |||
| const Parameter& param = op->params.at("value"); | |||
| if (param.type == 0 || param.type == 1 || param.type == 2 || param.type == 3 || param.type == 4) | |||
| { | |||
| return false; | |||
| @@ -83,9 +83,25 @@ static bool operand_maybe_tensor(Operand* operand) | |||
| return true; | |||
| } | |||
| static bool operand_is_foldable(const Operand* operand) | |||
| { | |||
| const Operator* op = operand->producer; | |||
| if (op->type == "pnnx.Input") | |||
| return false; | |||
| for (auto x : op->inputs) | |||
| { | |||
| if (!operand_is_foldable(x)) | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, std::vector<Operand*>& inputs, bool checksubgraph = true) | |||
| { | |||
| // fprintf(stderr, "fuse_expression %s\n", operand->name.c_str()); | |||
| // fprintf(stderr, "fuse_expression %s\n", operand->name.c_str()); | |||
| Operator* op = operand->producer; | |||
| @@ -164,6 +180,28 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s | |||
| } | |||
| } | |||
| } | |||
| else if (checksubgraph && operand_maybe_tensor(operand) && operand_is_foldable(operand)) | |||
| { | |||
| // fprintf(stderr, "operand_is_foldable %s\n", operand->name.c_str()); | |||
| auto it = std::find(inputs.begin(), inputs.end(), operand); | |||
| if (it == inputs.end()) | |||
| { | |||
| // tensor | |||
| char tmp[32]; | |||
| sprintf(tmp, "@%d", (int)inputs.size()); | |||
| expr += tmp; | |||
| inputs.push_back(operand); | |||
| } | |||
| else | |||
| { | |||
| // tensor | |||
| char tmp[32]; | |||
| sprintf(tmp, "@%d", (int)(it - inputs.begin())); | |||
| expr += tmp; | |||
| } | |||
| } | |||
| else if (op->type == "prim::NumToTensor") | |||
| { | |||
| fuse_expression(graph, op->inputs[0], expr, inputs); | |||
| @@ -26,7 +26,7 @@ void pass_level4(Graph& g) | |||
| dead_code_elimination(g); | |||
| canonicalize(g); | |||
| //canonicalize(g); | |||
| } | |||
| } // namespace pnnx | |||
| @@ -14,6 +14,7 @@ | |||
| #include "pass_level5.h" | |||
| #include "pass_level5/fold_constants.h" | |||
| #include "pass_level5/eliminate_dropout.h" | |||
| #include "pass_level5/eliminate_slice.h" | |||
| #include "pass_level5/eliminate_view_reshape.h" | |||
| @@ -29,10 +30,11 @@ | |||
| #include "pass_level5/fuse_slice_indices.h" | |||
| #include "pass_level4/dead_code_elimination.h" | |||
| #include "pass_level4/canonicalize.h" | |||
| #include "pass_level3/fuse_index_expression.h" | |||
| namespace pnnx { | |||
| void pass_level5(Graph& g) | |||
| void pass_level5(Graph& g, const std::map<std::string, Attribute>& foldable_constants) | |||
| { | |||
| eval_expression(g); | |||
| @@ -60,6 +62,10 @@ void pass_level5(Graph& g) | |||
| fuse_channel_shuffle(g); | |||
| fold_constants(g, foldable_constants); | |||
| fuse_index_expression(g); | |||
| dead_code_elimination(g); | |||
| canonicalize(g); | |||
| @@ -19,7 +19,7 @@ | |||
| namespace pnnx { | |||
| void pass_level5(Graph& g); | |||
| void pass_level5(Graph& g, const std::map<std::string, Attribute>& foldable_constants); | |||
| } // namespace pnnx | |||
| @@ -40,24 +40,24 @@ void eliminate_dropout(Graph& graph) | |||
| x->remove_consumer(op); | |||
| } | |||
| Operand* slice_out = op->outputs[0]; | |||
| Operand* dropout_out = op->outputs[0]; | |||
| for (auto& x : slice_out->consumers) | |||
| for (auto& x : dropout_out->consumers) | |||
| { | |||
| for (size_t j = 0; j < x->inputs.size(); j++) | |||
| { | |||
| if (x->inputs[j] == slice_out) | |||
| if (x->inputs[j] == dropout_out) | |||
| x->inputs[j] = op->inputs[0]; | |||
| } | |||
| op->inputs[0]->consumers.push_back(x); | |||
| } | |||
| slice_out->producer = 0; | |||
| slice_out->consumers.clear(); | |||
| dropout_out->producer = 0; | |||
| dropout_out->consumers.clear(); | |||
| graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), slice_out)); | |||
| delete slice_out; | |||
| graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), dropout_out)); | |||
| delete dropout_out; | |||
| op->inputs.clear(); | |||
| op->outputs.clear(); | |||
| @@ -0,0 +1,50 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "fold_constants.h" | |||
| #include <unordered_set> | |||
| #include "pass_level4/dead_code_elimination.h" | |||
| namespace pnnx { | |||
| void fold_constants(Graph& graph, const std::map<std::string, Attribute>& foldable_constants) | |||
| { | |||
| for (size_t i = 0; i < graph.operands.size(); i++) | |||
| { | |||
| Operand* operand = graph.operands[i]; | |||
| const std::string& name = operand->name; | |||
| if (foldable_constants.find(name) == foldable_constants.end()) | |||
| continue; | |||
| Operator* op = operand->producer; | |||
| if (op->type == "pnnx.Attribute") | |||
| continue; | |||
| // replace producer with attribute | |||
| Operator* op_new = graph.new_operator_before("pnnx.Attribute", std::string("pnnx_fold_") + name, op); | |||
| op_new->attrs[std::string("pnnx_fold_") + name] = foldable_constants.at(name); | |||
| op_new->outputs.push_back(operand); | |||
| operand->producer = op_new; | |||
| op->outputs.clear(); | |||
| } | |||
| // dce | |||
| dead_code_elimination(graph); | |||
| } | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,21 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "ir.h" | |||
| namespace pnnx { | |||
| void fold_constants(Graph& graph, const std::map<std::string, Attribute>& foldable_constants); | |||
| } // namespace pnnx | |||
| @@ -181,6 +181,7 @@ pnnx_add_test(torch_split) | |||
| pnnx_add_test(torch_squeeze) | |||
| pnnx_add_test(torch_stack) | |||
| pnnx_add_test(torch_transpose) | |||
| pnnx_add_test(torch_unbind) | |||
| pnnx_add_test(torch_unsqueeze) | |||
| pnnx_add_test(mobilenet_v2) | |||
| @@ -192,6 +193,8 @@ pnnx_add_test(squeezenet1_1) | |||
| # TODO enable end2end quantization model test | |||
| #pnnx_add_test(quantization_shufflenet_v2_x1_0) | |||
| pnnx_add_test(pnnx_eliminate_noop_math) | |||
| pnnx_add_test(pnnx_fold_constant) | |||
| pnnx_add_test(pnnx_fuse_conv1d_batchnorm1d) | |||
| pnnx_add_test(pnnx_fuse_conv2d_batchnorm2d) | |||
| pnnx_add_test(pnnx_fuse_convtranspose1d_batchnorm1d) | |||
| @@ -39,15 +39,15 @@ def test(): | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod.save("test_Tensor_slice.pt") | |||
| mod.save("test_Tensor_index.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_Tensor_slice.pt inputshape=[3,6],[5,9,2],[2,4,5,10]") | |||
| os.system("../src/pnnx test_Tensor_index.pt inputshape=[3,6],[5,9,2],[2,4,5,10]") | |||
| # pnnx inference | |||
| import test_Tensor_slice_pnnx | |||
| b = test_Tensor_slice_pnnx.test_inference() | |||
| import test_Tensor_index_pnnx | |||
| b = test_Tensor_index_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| @@ -0,0 +1,62 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class Model(nn.Module):
    """Module whose forward pass only applies arithmetic with neutral operands."""

    def __init__(self):
        super().__init__()

        # w0 is all zeros, w1 and w2 are all ones
        self.w0 = nn.Parameter(torch.zeros(1, 12, 52))
        self.w1 = nn.Parameter(torch.ones(1, 12, 52))
        self.w2 = nn.Parameter(torch.ones(1, 12, 52))

    def forward(self, x):
        # scalar operands: add zero, multiply / divide by one
        x = x + 0
        x = x * 1 / 1
        x = 0 + 1 * x

        # tensor operands: add zeros * ones, multiply by ones
        x = x + self.w0 * self.w1
        x = x * self.w2
        return x
def test():
    """Trace Model, convert it with the pnnx executable and compare outputs.

    Returns True when the tensor returned by the converted module's
    test_inference() equals the eager output exactly.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 12, 52)

    expected = net(x)

    # export torchscript
    traced = torch.jit.trace(net, x)
    traced.save("test_pnnx_eliminate_noop_math.pt")

    # torchscript to pnnx
    import os
    os.system("../src/pnnx test_pnnx_eliminate_noop_math.pt inputshape=[1,12,52]")

    # pnnx inference
    # NOTE(review): presumably this module is written by the pnnx run above
    import test_pnnx_eliminate_noop_math_pnnx
    actual = test_pnnx_eliminate_noop_math_pnnx.test_inference()

    return torch.equal(expected, actual)
if __name__ == "__main__":
    # exit status 0 when test() succeeds, 1 otherwise
    exit(0 if test() else 1)
| @@ -0,0 +1,60 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class Model(nn.Module):
    """Module that combines an input with an expression built only from parameters."""

    def __init__(self):
        super().__init__()

        # w2 has a trailing dimension of 1, the others match the input shape
        self.w0 = nn.Parameter(torch.rand(1, 12, 52))
        self.w1 = nn.Parameter(torch.rand(1, 12, 52))
        self.w2 = nn.Parameter(torch.rand(1, 12, 1))
        self.w3 = nn.Parameter(torch.rand(1, 12, 52))

    def forward(self, x):
        # sub-expression that depends on parameters and literals only
        bias = (self.w0 + self.w1 + 0.22) + self.w2 * 0.1
        x = x + bias - self.w3 / 2
        return x
def test():
    """Trace Model, convert it with the pnnx executable and compare outputs.

    Returns True when the tensor returned by the converted module's
    test_inference() equals the eager output exactly.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 12, 52)

    expected = net(x)

    # export torchscript
    traced = torch.jit.trace(net, x)
    traced.save("test_pnnx_fold_constant.pt")

    # torchscript to pnnx
    import os
    os.system("../src/pnnx test_pnnx_fold_constant.pt inputshape=[1,12,52]")

    # pnnx inference
    # NOTE(review): presumably this module is written by the pnnx run above
    import test_pnnx_fold_constant_pnnx
    actual = test_pnnx_fold_constant_pnnx.test_inference()

    return torch.equal(expected, actual)
if __name__ == "__main__":
    # exit status 0 when test() succeeds, 1 otherwise
    exit(0 if test() else 1)
| @@ -0,0 +1,61 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class Model(nn.Module):
    """Module that splits three inputs with torch.unbind along different dims."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y, z):
        # unbind along dim 1, dim 2 and dim 0 respectively
        x0, x1, x2 = torch.unbind(x, dim=1)
        y0, y1, y2, y3, y4, y5, y6, y7, y8 = torch.unbind(y, dim=2)
        z0, z1, z2, z3 = torch.unbind(z, dim=0)

        # x2 is not part of the result
        return x0, x1, y0, y1, y2, y3, y4, y5, y6, y7, y8, z0, z1, z2, z3
def test():
    """Trace Model, convert it with the pnnx executable and compare outputs.

    Returns True only when the converted module's test_inference() yields the
    same number of tensors as the eager model and every pair is exactly equal.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 3, 16)
    y = torch.rand(1, 5, 9, 11)
    z = torch.rand(4, 8, 5, 9, 10)

    a = net(x, y, z)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z))
    mod.save("test_torch_unbind.pt")

    # torchscript to pnnx
    import os
    os.system("../src/pnnx test_torch_unbind.pt inputshape=[1,3,16],[1,5,9,11],[4,8,5,9,10]")

    # pnnx inference
    # NOTE(review): presumably this module is written by the pnnx run above
    import test_torch_unbind_pnnx
    b = test_torch_unbind_pnnx.test_inference()

    # zip() stops at the shorter sequence, so a missing output must be
    # rejected explicitly
    if len(a) != len(b):
        return False

    for a0, b0 in zip(a, b):
        if not torch.equal(a0, b0):
            return False
    return True
if __name__ == "__main__":
    # exit status 0 when test() succeeds, 1 otherwise
    exit(0 if test() else 1)