|
|
|
@@ -21,148 +21,120 @@ |
|
|
|
#include "tools/common/converter_op_utils.h" |
|
|
|
#include "tools/common/node_util.h" |
|
|
|
#include "utils/log_adapter.h" |
|
|
|
#include "src/common/common.h" |
|
|
|
#include "src/common/utils.h" |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace lite { |
|
|
|
#define kMinInputNum 1 |
|
|
|
#define kOutputNum 1 |
|
|
|
|
|
|
|
STATUS EltwiseFormatTransPass::Run(schema::MetaGraphT *graph) { |
|
|
|
if (fmkType == converter::FmkType_TF) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
MS_ASSERT(graph != nullptr); |
|
|
|
auto status = DoModelInputFormatTrans(graph); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "DoModelInputFormatTrans failed : " << status; |
|
|
|
return status; |
|
|
|
} |
|
|
|
status = DoNodeInoutFormatTrans(graph); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "DoNodeInoutFormatTrans failed : " << status; |
|
|
|
return status; |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS EltwiseFormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) { |
|
|
|
if (fmkType == converter::FmkType_TF || fmkType == converter::FmkType_TFLITE) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
MS_ASSERT(graph != nullptr); |
|
|
|
// insert trans node in model input tensor |
|
|
|
if (graph->nodes.empty()) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
auto graphInputIdxes = graph->inputIndex; |
|
|
|
for (size_t i = 0; i < graphInputIdxes.size(); i++) { |
|
|
|
auto inputIdx = graphInputIdxes.at(i); |
|
|
|
MS_ASSERT(inputIdx < subGraph->allTensors.size()); |
|
|
|
auto &tensor = graph->allTensors.at(inputIdx); |
|
|
|
if (tensor->dims.size() != kNCHWDimNumber) { |
|
|
|
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { |
|
|
|
auto &node = *iter; |
|
|
|
if (node->primitive->value.type != PrimitiveType_Eltwise) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { |
|
|
|
auto &node = *iter; |
|
|
|
for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) { |
|
|
|
if (node->inputIndex.at(inputIndexIdx) == inputIdx) { |
|
|
|
STATUS status = RET_OK; |
|
|
|
iter = InsertFormatTransNode(graph, iter, kBefore, inputIndexIdx, kNHWC2NCHW, &status); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << (*iter)->name << " failed"; |
|
|
|
return status; |
|
|
|
auto node_name = node->name; |
|
|
|
auto input_node_indexes = GetInputNodeIdx(*graph, *node); |
|
|
|
auto pre_type = schema::PrimitiveType_NONE; |
|
|
|
size_t has_trans_count = 0; |
|
|
|
auto can_fusion = true; |
|
|
|
for (auto input_node_index : input_node_indexes) { |
|
|
|
MS_ASSERT(graph->nodes.size() > input_node_index); |
|
|
|
auto &pre_node = graph->nodes.at(input_node_index); |
|
|
|
MS_ASSERT(pre_node != nullptr); |
|
|
|
if (pre_type == schema::PrimitiveType_NONE) { |
|
|
|
if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || |
|
|
|
pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { |
|
|
|
pre_type = pre_node->primitive->value.type; |
|
|
|
has_trans_count++; |
|
|
|
} |
|
|
|
} else { |
|
|
|
if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || |
|
|
|
pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { |
|
|
|
if (pre_type != pre_node->primitive->value.type) { |
|
|
|
can_fusion = false; |
|
|
|
break; |
|
|
|
} else { |
|
|
|
has_trans_count++; |
|
|
|
} |
|
|
|
// set first tensor format to nhwc |
|
|
|
auto &transNode = *(iter - 1); |
|
|
|
MS_ASSERT(transNode != nullptr); |
|
|
|
MS_ASSERT(transNode->inputIndex.size() == 1); |
|
|
|
MS_ASSERT(subGraph->allTensors.size() > transNode->inputIndex.front()); |
|
|
|
auto &graphInTensor = graph->allTensors.at(transNode->inputIndex.front()); |
|
|
|
graphInTensor->format = schema::Format_NHWC; |
|
|
|
// assume parser not reformat shape |
|
|
|
auto oldDims = graphInTensor->dims; |
|
|
|
graphInTensor->dims = {oldDims[NCHW_N], oldDims[NCHW_H], oldDims[NCHW_W], oldDims[NCHW_C]}; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
// inference needed inputFormat: |
|
|
|
// conv deconv depth dedepth |
|
|
|
// fp32 NCHW NCHW NCHW NCHW |
|
|
|
// uint8 NCHW ? NCHW ? |
|
|
|
STATUS EltwiseFormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { |
|
|
|
MS_ASSERT(graph != nullptr); |
|
|
|
// insert before and after the op cal by nchw/nc4hw4 |
|
|
|
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { |
|
|
|
FormatTransNodeType beforeNodeType, afterNodeType; |
|
|
|
if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc |
|
|
|
// if (quantType == QuantType_AwareTrainning) { // awaretrainning op use |
|
|
|
// nhwc |
|
|
|
// if (IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only |
|
|
|
// support nhwc |
|
|
|
// continue; |
|
|
|
// } |
|
|
|
// if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { |
|
|
|
// continue; |
|
|
|
// } |
|
|
|
// } else { |
|
|
|
// if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { |
|
|
|
if (!can_fusion) { |
|
|
|
continue; |
|
|
|
// } |
|
|
|
// } |
|
|
|
// beforeNodeType = kNCHW2NHWC; |
|
|
|
// afterNodeType = kNHWC2NCHW; |
|
|
|
} else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw |
|
|
|
// if (quantType == QuantType_AwareTrainning) { // awaretrainning op use nhwc |
|
|
|
// if (!IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only support nhwc |
|
|
|
// continue; |
|
|
|
// } |
|
|
|
// } else { |
|
|
|
// continue; |
|
|
|
// } |
|
|
|
if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
beforeNodeType = kNCHW2NHWC; |
|
|
|
afterNodeType = kNHWC2NCHW; |
|
|
|
} else if (fmkType == converter::FmkType_MS) { |
|
|
|
if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto output_node_indexes = GetOutputNodeIdx(*graph, *node); |
|
|
|
auto post_type = schema::PrimitiveType_NONE; |
|
|
|
for (auto output_node_index : output_node_indexes) { |
|
|
|
MS_ASSERT(graph->nodes.size() > output_node_index); |
|
|
|
auto &post_node = graph->nodes.at(output_node_index); |
|
|
|
MS_ASSERT(post_node != nullptr); |
|
|
|
if (post_type == schema::PrimitiveType_NONE) { |
|
|
|
if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || |
|
|
|
post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { |
|
|
|
post_type = post_node->primitive->value.type; |
|
|
|
has_trans_count++; |
|
|
|
} |
|
|
|
} else { |
|
|
|
if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || |
|
|
|
post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { |
|
|
|
if (post_type != post_node->primitive->value.type) { |
|
|
|
can_fusion = false; |
|
|
|
break; |
|
|
|
} else { |
|
|
|
has_trans_count++; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
beforeNodeType = kNCHW2NHWC; |
|
|
|
afterNodeType = kNHWC2NCHW; |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Unsupported fmk: " << fmkType; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
auto &node = *iter; |
|
|
|
auto nodeName = node->name; |
|
|
|
if (node->inputIndex.size() < kMinInputNum) { |
|
|
|
MS_LOG(ERROR) << "Op should have " << kMinInputNum << " input tensor at least"; |
|
|
|
return RET_ERROR; |
|
|
|
if (!can_fusion) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto total_node_count = input_node_indexes.size() + output_node_indexes.size(); |
|
|
|
size_t half_count = total_node_count / 2; |
|
|
|
if (total_node_count % 2 == 0) { |
|
|
|
can_fusion = has_trans_count > half_count; |
|
|
|
} else { |
|
|
|
can_fusion = has_trans_count >= half_count; |
|
|
|
} |
|
|
|
if (node->outputIndex.size() != kOutputNum) { |
|
|
|
MS_LOG(ERROR) << "Op should have " << kOutputNum << " output tensor"; |
|
|
|
return RET_ERROR; |
|
|
|
if (!can_fusion) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
STATUS status; |
|
|
|
iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << nodeName << "failed"; |
|
|
|
return RET_ERROR; |
|
|
|
FormatTransNodeType pre_insert_trans_type = kNHWC2NCHW; |
|
|
|
FormatTransNodeType post_insert_trans_type = kNHWC2NCHW; |
|
|
|
if (pre_type == PrimitiveType_NONE && post_type != PrimitiveType_NONE) { |
|
|
|
pre_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC; |
|
|
|
post_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; |
|
|
|
} else if (pre_type != PrimitiveType_NONE && post_type == PrimitiveType_NONE) { |
|
|
|
pre_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; |
|
|
|
post_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC; |
|
|
|
} else if (pre_type == PrimitiveType_NONE && post_type == PrimitiveType_NONE) { |
|
|
|
continue; |
|
|
|
} else { |
|
|
|
if (pre_type == post_type) { |
|
|
|
MS_LOG(ERROR) << "Unknow error"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
pre_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; |
|
|
|
post_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; |
|
|
|
} |
|
|
|
|
|
|
|
iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; |
|
|
|
return RET_ERROR; |
|
|
|
STATUS status = RET_OK; |
|
|
|
auto input_tensor_size = (*iter)->inputIndex.size(); |
|
|
|
for (auto i = 0; i < input_tensor_size; i++) { |
|
|
|
iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type, &status); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Insert" << pre_insert_trans_type << "before " << (*iter)->name << " failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |
|
|
|
auto output_tensor_size = (*iter)->outputIndex.size(); |
|
|
|
for (auto i = 0; i < output_tensor_size; i++) { |
|
|
|
iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type, &status); |
|
|
|
if (status != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Insert" << post_insert_trans_type << "Node before " << (*iter)->name << " failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
@@ -195,6 +167,5 @@ NodeIter EltwiseFormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph |
|
|
|
void EltwiseFormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } |
|
|
|
|
|
|
|
void EltwiseFormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkType; } |
|
|
|
|
|
|
|
} // namespace lite |
|
|
|
} // namespace mindspore |