@@ -16,6 +16,7 @@
#include <string>
#include <memory>
#include <vector>
#include <utility>
#include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h"
#include "tools/common/converter_op_utils.h"

@@ -117,48 +118,86 @@ STATUS TransOpInsertPass::FindOutTransType() {
  return RET_OK;
}

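// Helper added by this change: after format-transpose nodes are placed around an op,
// its axis-style attributes still refer to NCHW positions. Remap them to NHWC via
// GetNc2NhAxisMap(): Concat.axis, StridedSlice.begin/end/stride and Split.splitDim.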
STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
  if (node == nullptr || node->primitive == nullptr) {
    MS_LOG(ERROR) << "node or primitive null";
    return RET_NULL_PTR;
  }
  auto type = node->primitive->value.type;
  if (graph->allTensors.at(node->inputIndex[0])->dims.size() != 4) {
    MS_LOG(ERROR) << "change op axis only support 4 dims";
    return RET_NOT_SUPPORT;
  }
  if (type == PrimitiveType_Concat) {
    auto origin_axis = node->primitive->value.AsConcat()->axis;
    auto axis_map = GetNc2NhAxisMap();
    node->primitive->value.AsConcat()->axis = axis_map[origin_axis];
  }
  if (type == PrimitiveType_StridedSlice) {
    auto attr = node->primitive->value.AsStridedSlice();
    auto origin_begin = attr->begin;
    attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]};
    auto origin_end = attr->end;
    attr->end = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]};
    auto origin_stride = attr->stride;
    attr->stride = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]};
  }
  if (type == PrimitiveType_Split) {
    auto origin_axis = node->primitive->value.AsSplit()->splitDim;
    auto axis_map = GetNc2NhAxisMap();
    node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis];
  }
  return RET_OK;
}

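// Run(): for each op in GetInsertOpList() that can be fused with surrounding transposes
// (CanFusion), rewrite its axis attributes via ChangeOpAxis and insert a format-transpose
// node before every input tensor and after every output tensor.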
STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
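  // Rescan the graph until a pass makes no change (capped at 10 iterations);
  // has_insert_nodes keeps already-processed nodes from being handled again.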
  bool changed = true;
  int run_counts = 0;
  std::vector<CNodeT *> has_insert_nodes;
  while (changed && run_counts < 10) {
    changed = false;
    for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
      auto &node = *iter;
      auto type = node->primitive->value.type;
      if (IsContain(has_insert_nodes, node.get()) || !IsContain(GetInsertOpList(), type)) {
        continue;
      }
      auto node_name = node->name;
      if (!CanFusion(graph, node)) {
        continue;
      }
      auto ret = FindOutTransType();
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "FindOutTransType error";
        return ret;
      }
      ret = ChangeOpAxis(graph, node);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "ChangeOpAxis error";
        return ret;
      }
      has_insert_nodes.push_back(node.get());
      STATUS status = RET_OK;
      auto input_tensor_size = (*iter)->inputIndex.size();
      for (size_t i = 0; i < input_tensor_size; i++) {
        iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type_, &status);
        if (status != RET_OK) {
          MS_LOG(ERROR) << "Insert " << pre_insert_trans_type_ << " node before " << (*iter)->name << " failed";
          return status;
        }
      }
      auto output_tensor_size = (*iter)->outputIndex.size();
      for (size_t i = 0; i < output_tensor_size; i++) {
        iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type_, &status);
        if (status != RET_OK) {
          MS_LOG(ERROR) << "Insert " << post_insert_trans_type_ << " node after " << (*iter)->name << " failed";
          return status;
        }
      }
      changed = true;
    }
    run_counts++;
  }
  return RET_OK;
}