| @@ -432,6 +432,84 @@ static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map<std::s | |||||
| } | } | ||||
| } | } | ||||
| static void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, onnx::TensorProto>& binaryop_weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count, std::vector<std::string>& reduced_binaryop_weights) | |||||
| { | |||||
| int node_count = mutable_graph->node_size(); | |||||
| for (int i = 0; i < node_count; i++) | |||||
| { | |||||
| onnx::NodeProto* node = mutable_graph->mutable_node(i); | |||||
| // Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1) | |||||
| if (node->op_type() == "ShuffleChannel") | |||||
| { | |||||
| // reverse = 1 | |||||
| int reverse = get_node_attr_i(*node, "reverse"); | |||||
| if (reverse != 1) | |||||
| continue; | |||||
| if (i + 2 >= node_count) | |||||
| continue; | |||||
| onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); | |||||
| onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); | |||||
| if (node2->op_type() != "Gather" || node3->op_type() != "Gather") | |||||
| continue; | |||||
| if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0)) | |||||
| continue; | |||||
| // axis = 0 | |||||
| int gather2_axis = get_node_attr_i(*node2, "axis"); | |||||
| if (gather2_axis != 0) | |||||
| continue; | |||||
| // indices = 0 | |||||
| if (weights.find(node2->input(1)) == weights.end()) | |||||
| continue; | |||||
| std::vector<int> gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]); | |||||
| if (gather2_indices.size() != 1 || gather2_indices[0] != 0) | |||||
| continue; | |||||
| // axis = 0 | |||||
| int gather3_axis = get_node_attr_i(*node3, "axis"); | |||||
| if (gather3_axis != 0) | |||||
| continue; | |||||
| // indices = 1 | |||||
| if (weights.find(node3->input(1)) == weights.end()) | |||||
| continue; | |||||
| std::vector<int> gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]); | |||||
| if (gather3_indices.size() != 1 || gather3_indices[0] != 1) | |||||
| continue; | |||||
| // reduce | |||||
| node2->set_op_type("noop_reducedncnn"); | |||||
| node_reference[node->output(0)] -= 1; | |||||
| node_reference.erase(node_reference.find(node2->output(0))); | |||||
| blob_names.erase(node2->output(0)); | |||||
| node3->set_op_type("Split"); | |||||
| node3->clear_input(); | |||||
| node3->add_input(node->output(0)); | |||||
| node3->add_output(node3->output(0)); | |||||
| node3->set_output(0, node2->output(0)); | |||||
| node3->clear_attribute(); | |||||
| onnx::AttributeProto* attr_axis = node3->add_attribute(); | |||||
| attr_axis->set_name("axis"); | |||||
| attr_axis->set_i(1); | |||||
| reduced_node_count += 1; | |||||
| i += 1; | |||||
| } | |||||
| } | |||||
| } | |||||
| static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, onnx::TensorProto>& binaryop_weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count, std::vector<std::string>& reduced_binaryop_weights) | static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, onnx::TensorProto>& binaryop_weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count, std::vector<std::string>& reduced_binaryop_weights) | ||||
| { | { | ||||
| int node_count = mutable_graph->node_size(); | int node_count = mutable_graph->node_size(); | ||||
| @@ -1383,6 +1461,7 @@ int main(int argc, char** argv) | |||||
| std::vector<std::string> reduced_binaryop_weights; | std::vector<std::string> reduced_binaryop_weights; | ||||
| fuse_matmul(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | fuse_matmul(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | ||||
| fuse_shufflechannel(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | fuse_shufflechannel(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | ||||
| fuse_shufflechannel_split(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | |||||
| fuse_hardsigmoid(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | fuse_hardsigmoid(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | ||||
| fuse_hardswish(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | fuse_hardswish(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | ||||
| fuse_swish(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | fuse_swish(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | ||||