| @@ -17,6 +17,7 @@ | |||
| #include <string.h> | |||
| #include <map> | |||
| #include <set> | |||
| #include <string> | |||
| #include <vector> | |||
| @@ -439,6 +440,8 @@ class PyTorchNode | |||
| { | |||
| public: | |||
| std::string op; | |||
| std::string name; | |||
| std::vector<std::string> inputs; | |||
| std::vector<std::string> outputs; | |||
| std::vector<std::string> args; | |||
| }; | |||
| @@ -563,6 +566,8 @@ static int read_code(const std::string& code, std::vector<PyTorchNode>& nodes) | |||
| { | |||
| nodes.clear(); | |||
| int internal_unknown = 0; | |||
| // read code line | |||
| bool forward_input = false; | |||
| @@ -595,8 +600,24 @@ static int read_code(const std::string& code, std::vector<PyTorchNode>& nodes) | |||
| if (nscan != 1) | |||
| continue; | |||
| fprintf(stderr, "netinput = %s\n", netinput); | |||
| // fprintf(stderr, "netinput = %s\n", netinput); | |||
| forward_input = false; | |||
| PyTorchNode n; | |||
| n.op = "Input"; | |||
| n.outputs.push_back(netinput); | |||
| { | |||
| // assign default unknown name | |||
| char unknownname[256]; | |||
| sprintf(unknownname, "unknownncnn_%d", internal_unknown); | |||
| n.name = unknownname; | |||
| internal_unknown++; | |||
| } | |||
| nodes.push_back(n); | |||
| } | |||
| else if (strstr(line, " = torch.")) | |||
| { | |||
| @@ -627,6 +648,17 @@ static int read_code(const std::string& code, std::vector<PyTorchNode>& nodes) | |||
| n.op = op; | |||
| n.outputs = parse_op_output_list(outputs); | |||
| n.args = parse_op_arg_list(op_args.c_str()); | |||
| { | |||
| // assign default unknown name | |||
| char unknownname[256]; | |||
| sprintf(unknownname, "unknownncnn_%d", internal_unknown); | |||
| n.name = unknownname; | |||
| internal_unknown++; | |||
| } | |||
| nodes.push_back(n); | |||
| } | |||
| else if (strstr(line, "return torch.")) | |||
| @@ -653,8 +685,19 @@ static int read_code(const std::string& code, std::vector<PyTorchNode>& nodes) | |||
| PyTorchNode n; | |||
| n.op = op; | |||
| n.outputs = parse_op_output_list("ncnnoutput_0"); | |||
| n.outputs.push_back("ncnnoutput_0"); | |||
| n.args = parse_op_arg_list(op_args.c_str()); | |||
| { | |||
| // assign default unknown name | |||
| char unknownname[256]; | |||
| sprintf(unknownname, "unknownncnn_%d", internal_unknown); | |||
| n.name = unknownname; | |||
| internal_unknown++; | |||
| } | |||
| nodes.push_back(n); | |||
| } | |||
| } | |||
| @@ -665,6 +708,8 @@ static int read_code(const std::string& code, std::vector<PyTorchNode>& nodes) | |||
| int main(int argc, char** argv) | |||
| { | |||
| const char* ptpath = "model_base.pt"; | |||
| const char* ncnn_prototxt = "ncnn.param"; | |||
| const char* ncnn_modelbin = "ncnn.bin"; | |||
| std::string code; | |||
| std::string model_json; | |||
| @@ -682,26 +727,153 @@ int main(int argc, char** argv) | |||
| std::vector<PyTorchNode> nodes; | |||
| read_code(code, nodes); | |||
| FILE* pp = fopen(ncnn_prototxt, "wb"); | |||
| FILE* bp = fopen(ncnn_modelbin, "wb"); | |||
| // magic | |||
| fprintf(pp, "7767517\n"); | |||
| const int node_count = nodes.size(); | |||
| fprintf(stderr, "node_count = %d\n", node_count); | |||
| // node reference | |||
| std::map<std::string, int> node_reference; | |||
| // weight node | |||
| std::vector<int> weight_nodes; | |||
| // global definition line | |||
| // [layer count] [blob count] | |||
| std::set<std::string> blob_names; | |||
| for (int i=0; i<node_count; i++) | |||
| { | |||
| PyTorchNode& n = nodes[i]; | |||
| for (int j=0; j<(int)n.outputs.size(); j++) | |||
| { | |||
| blob_names.insert(n.outputs[j]); | |||
| } | |||
| // distinguish weights and inputs | |||
| std::vector<std::string> op_arg_list = n.args; | |||
| std::vector<std::string> input_list; | |||
| std::vector<std::string> arg_list; | |||
| for (int i=0; i<(int)op_arg_list.size(); i++) | |||
| { | |||
| const std::string& arg = op_arg_list[i]; | |||
| if (blob_names.find(arg) == blob_names.end()) | |||
| { | |||
| arg_list.push_back(arg); | |||
| } | |||
| else | |||
| { | |||
| input_list.push_back(arg); | |||
| if (node_reference.find(arg) == node_reference.end()) | |||
| { | |||
| node_reference[arg] = 1; | |||
| } | |||
| else | |||
| { | |||
| node_reference[arg] = node_reference[arg] + 1; | |||
| } | |||
| } | |||
| } | |||
| n.inputs = input_list; | |||
| n.args = arg_list; | |||
| } | |||
| // remove node_reference entry with reference equals to one | |||
| int splitncnn_blob_count = 0; | |||
| std::map<std::string, int>::iterator it = node_reference.begin(); | |||
| while (it != node_reference.end()) | |||
| { | |||
| if (it->second == 1) | |||
| { | |||
| node_reference.erase(it++); | |||
| } | |||
| else | |||
| { | |||
| splitncnn_blob_count += it->second; | |||
| // fprintf(stderr, "%s %d\n", it->first.c_str(), it->second); | |||
| ++it; | |||
| } | |||
| } | |||
| fprintf(pp, "%lu %lu\n", node_count + node_reference.size(), blob_names.size() + splitncnn_blob_count); | |||
| int internal_split = 0; | |||
| for (int i=0; i<node_count; i++) | |||
| { | |||
| const PyTorchNode& n = nodes[i]; | |||
| fprintf(stderr, "op = %s\n", n.op.c_str()); | |||
| fprintf(pp, "%-24s", n.op.c_str()); | |||
| fprintf(pp, " %-24s %d %d", n.name.c_str(), (int)n.inputs.size(), (int)n.outputs.size()); | |||
| for (int j=0; j<(int)n.inputs.size(); j++) | |||
| { | |||
| std::string input_name = n.inputs[j]; | |||
| if (node_reference.find(input_name) != node_reference.end()) | |||
| { | |||
| int refidx = node_reference[input_name] - 1; | |||
| node_reference[input_name] = refidx; | |||
| char splitsuffix[256]; | |||
| sprintf(splitsuffix, "_splitncnn_%d", refidx); | |||
| input_name = input_name + splitsuffix; | |||
| } | |||
| fprintf(pp, " %s", input_name.c_str()); | |||
| } | |||
| for (int j=0; j<(int)n.outputs.size(); j++) | |||
| { | |||
| fprintf(stderr, "output = %s\n", n.outputs[j].c_str()); | |||
| fprintf(pp, " %s", n.outputs[j].c_str()); | |||
| } | |||
| // TODO op specific params | |||
| { | |||
| for (int j=0; j<(int)n.args.size(); j++) | |||
| { | |||
| fprintf(pp, " %s", n.args[j].c_str()); | |||
| } | |||
| } | |||
| for (int j=0; j<(int)n.args.size(); j++) | |||
| fprintf(pp, "\n"); | |||
| for (int j=0; j<(int)n.outputs.size(); j++) | |||
| { | |||
| fprintf(stderr, "arg = %s\n", n.args[j].c_str()); | |||
| const std::string& output_name = n.outputs[j]; | |||
| if (node_reference.find(output_name) != node_reference.end()) | |||
| { | |||
| int refcount = node_reference[output_name]; | |||
| if (refcount > 1) | |||
| { | |||
| char splitname[256]; | |||
| sprintf(splitname, "splitncnn_%d", internal_split); | |||
| fprintf(pp, "%-24s %-24s %d %d", "Split", splitname, 1, refcount); | |||
| fprintf(pp, " %s", output_name.c_str()); | |||
| for (int k=0; k<refcount; k++) | |||
| { | |||
| fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k); | |||
| } | |||
| fprintf(pp, "\n"); | |||
| internal_split++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| fclose(pp); | |||
| fclose(bp); | |||
| return 0; | |||
| } | |||