| @@ -432,6 +432,84 @@ static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map<std::s | |||
| } | |||
| } | |||
| static void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, onnx::TensorProto>& binaryop_weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count, std::vector<std::string>& reduced_binaryop_weights) | |||
| { | |||
| int node_count = mutable_graph->node_size(); | |||
| for (int i = 0; i < node_count; i++) | |||
| { | |||
| onnx::NodeProto* node = mutable_graph->mutable_node(i); | |||
| // Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1) | |||
| if (node->op_type() == "ShuffleChannel") | |||
| { | |||
| // reverse = 1 | |||
| int reverse = get_node_attr_i(*node, "reverse"); | |||
| if (reverse != 1) | |||
| continue; | |||
| if (i + 2 >= node_count) | |||
| continue; | |||
| onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); | |||
| onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); | |||
| if (node2->op_type() != "Gather" || node3->op_type() != "Gather") | |||
| continue; | |||
| if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0)) | |||
| continue; | |||
| // axis = 0 | |||
| int gather2_axis = get_node_attr_i(*node2, "axis"); | |||
| if (gather2_axis != 0) | |||
| continue; | |||
| // indices = 0 | |||
| if (weights.find(node2->input(1)) == weights.end()) | |||
| continue; | |||
| std::vector<int> gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]); | |||
| if (gather2_indices.size() != 1 || gather2_indices[0] != 0) | |||
| continue; | |||
| // axis = 0 | |||
| int gather3_axis = get_node_attr_i(*node3, "axis"); | |||
| if (gather3_axis != 0) | |||
| continue; | |||
| // indices = 1 | |||
| if (weights.find(node3->input(1)) == weights.end()) | |||
| continue; | |||
| std::vector<int> gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]); | |||
| if (gather3_indices.size() != 1 || gather3_indices[0] != 1) | |||
| continue; | |||
| // reduce | |||
| node2->set_op_type("noop_reducedncnn"); | |||
| node_reference[node->output(0)] -= 1; | |||
| node_reference.erase(node_reference.find(node2->output(0))); | |||
| blob_names.erase(node2->output(0)); | |||
| node3->set_op_type("Split"); | |||
| node3->clear_input(); | |||
| node3->add_input(node->output(0)); | |||
| node3->add_output(node3->output(0)); | |||
| node3->set_output(0, node2->output(0)); | |||
| node3->clear_attribute(); | |||
| onnx::AttributeProto* attr_axis = node3->add_attribute(); | |||
| attr_axis->set_name("axis"); | |||
| attr_axis->set_i(1); | |||
| reduced_node_count += 1; | |||
| i += 1; | |||
| } | |||
| } | |||
| } | |||
| static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, onnx::TensorProto>& binaryop_weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count, std::vector<std::string>& reduced_binaryop_weights) | |||
| { | |||
| int node_count = mutable_graph->node_size(); | |||
| @@ -1383,6 +1461,7 @@ int main(int argc, char** argv) | |||
| std::vector<std::string> reduced_binaryop_weights; | |||
| fuse_matmul(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | |||
| fuse_shufflechannel(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | |||
| fuse_shufflechannel_split(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | |||
| fuse_hardsigmoid(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | |||
| fuse_hardswish(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | |||
| fuse_swish(mutable_graph, weights, binaryop_weights, node_reference, blob_names, reduced_node_count, reduced_binaryop_weights); | |||