| @@ -76,6 +76,7 @@ | |||
| #include "layer/shufflechannel.h" | |||
| #include "layer/slice.h" | |||
| #include "layer/softmax.h" | |||
| #include "layer/split.h" | |||
| #include "layer/squeeze.h" | |||
| #include "layer/threshold.h" | |||
| #include "layer/unaryop.h" | |||
| @@ -1885,6 +1886,12 @@ int NetOptimize::fuse_memorydata_binaryop() | |||
| ncnn::MemoryData* memorydata = (ncnn::MemoryData*)layers[i]; | |||
| ncnn::BinaryOp* binaryop = (ncnn::BinaryOp*)layers[j]; | |||
| if (memorydata->w != 1 || memorydata->h != 0 || memorydata->c != 0) | |||
| { | |||
| // not a scalar | |||
| continue; | |||
| } | |||
| int memorydata_index = 1; | |||
| if (binaryop->bottoms[0] == top_blob_index) | |||
| @@ -1915,12 +1922,109 @@ int NetOptimize::fuse_memorydata_binaryop() | |||
| } | |||
| } | |||
| float scalar = memorydata->data[0]; | |||
| binaryop->with_scalar = 1; | |||
| binaryop->b = scalar; | |||
| fprintf(stderr, "fuse_memorydata_binaryop %s %s\n", memorydata->name.c_str(), binaryop->name.c_str()); | |||
| binaryop->bottoms.erase(binaryop->bottoms.begin() + memorydata_index); | |||
| memorydata->type = "ncnnfused"; | |||
| } | |||
| for (int i = 0; i < layer_count; i++) | |||
| { | |||
| if (layers[i]->type != "MemoryData") | |||
| continue; | |||
| // MemoryData - Split - BinaryOp | |||
| int top_blob_index = layers[i]->tops[0]; | |||
| int j0 = i + 1; | |||
| for (; j0 < layer_count; j0++) | |||
| { | |||
| if (layers[j0]->type != "Split") | |||
| continue; | |||
| if (layers[j0]->bottoms.size() != 1) | |||
| continue; | |||
| if (layers[j0]->bottoms[0] == top_blob_index) | |||
| break; | |||
| } | |||
| if (j0 == layer_count) | |||
| continue; | |||
| int split_top_blob_index = -1; | |||
| int j1 = j0 + 1; | |||
| for (; j1 < layer_count; j1++) | |||
| { | |||
| if (layers[j1]->type != "BinaryOp") | |||
| continue; | |||
| if (layers[j1]->bottoms.size() != 2) | |||
| continue; | |||
| for (int k = 0; k < (int)layers[j0]->tops.size(); k++) | |||
| { | |||
| if (layers[j1]->bottoms[0] == layers[j0]->tops[k] || layers[j1]->bottoms[1] == layers[j0]->tops[k]) | |||
| { | |||
| split_top_blob_index = k; | |||
| break; | |||
| } | |||
| } | |||
| if (split_top_blob_index != -1) | |||
| break; | |||
| } | |||
| if (j1 == layer_count) | |||
| continue; | |||
| // fuse MemoryData - Split - BinaryOp to BinaryOp | |||
| ncnn::MemoryData* memorydata = (ncnn::MemoryData*)layers[i]; | |||
| ncnn::Split* split = (ncnn::Split*)layers[j0]; | |||
| ncnn::BinaryOp* binaryop = (ncnn::BinaryOp*)layers[j1]; | |||
| if (memorydata->w != 1 || memorydata->h != 0 || memorydata->c != 0) | |||
| { | |||
| // not a scalar | |||
| continue; | |||
| } | |||
| int memorydata_index = 1; | |||
| if (binaryop->bottoms[0] == split->tops[split_top_blob_index]) | |||
| { | |||
| int op_type = binaryop->op_type; | |||
| if (op_type == ncnn::BinaryOp::Operation_ADD | |||
| || op_type == ncnn::BinaryOp::Operation_MUL | |||
| || op_type == ncnn::BinaryOp::Operation_MAX | |||
| || op_type == ncnn::BinaryOp::Operation_MIN) | |||
| { | |||
| memorydata_index = 0; | |||
| } | |||
| else if (op_type == ncnn::BinaryOp::Operation_SUB) | |||
| { | |||
| binaryop->op_type = ncnn::BinaryOp::Operation_RSUB; | |||
| memorydata_index = 0; | |||
| } | |||
| else if (op_type == ncnn::BinaryOp::Operation_DIV) | |||
| { | |||
| binaryop->op_type = ncnn::BinaryOp::Operation_RDIV; | |||
| memorydata_index = 0; | |||
| } | |||
| else | |||
| { | |||
| // non interchangeable binaryop | |||
| continue; | |||
| } | |||
| } | |||
| float scalar = memorydata->data[0]; | |||
| binaryop->with_scalar = 1; | |||
| @@ -1929,7 +2033,14 @@ int NetOptimize::fuse_memorydata_binaryop() | |||
| fprintf(stderr, "fuse_memorydata_binaryop %s %s\n", memorydata->name.c_str(), binaryop->name.c_str()); | |||
| binaryop->bottoms.erase(binaryop->bottoms.begin() + memorydata_index); | |||
| memorydata->type = "ncnnfused"; | |||
| split->tops.erase(split->tops.begin() + split_top_blob_index); | |||
| if (split->tops.empty()) | |||
| { | |||
| split->type = "ncnnfused"; | |||
| memorydata->type = "ncnnfused"; | |||
| } | |||
| i--; | |||
| } | |||
| return 0; | |||