|
|
|
@@ -124,6 +124,74 @@ STATUS TransOpInsertPass::FindOutTransType() { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
void TransOpInsertPass::TransformAttrByAxes(int *origin_attr, int *axes, int element_size) { |
|
|
|
if (origin_attr == nullptr || axes == nullptr || element_size == 0) { |
|
|
|
MS_LOG(INFO) << "Attr data is from other nodes."; |
|
|
|
return; |
|
|
|
} |
|
|
|
auto axis_map = GetNc2NhAxisMap(); |
|
|
|
std::vector<int> cur_attr; |
|
|
|
for (int dim = 0; dim < 4; ++dim) { |
|
|
|
for (int index = 0; index < element_size; ++index) { |
|
|
|
int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + 4 : axes[index]]; |
|
|
|
if (nhwc_dim == dim || (nhwc_dim + 4) == dim) { |
|
|
|
cur_attr.push_back(origin_attr[index]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
for (int index = 0; index < element_size; ++index) { |
|
|
|
origin_attr[index] = cur_attr[index]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
STATUS TransOpInsertPass::ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) { |
|
|
|
if (node == nullptr && node->primitive == nullptr) { |
|
|
|
MS_LOG(ERROR) << "node or primitive null"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
auto type = node->primitive->value.type; |
|
|
|
if (type == PrimitiveType_StridedSlice) { |
|
|
|
// onnx input size is equal to 5 always. |
|
|
|
if (node->inputIndex.size() == 5) { |
|
|
|
for (int index = 1; index < 5; ++index) { |
|
|
|
if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) { |
|
|
|
MS_LOG(INFO) << "Here don't consider input is from other nodes."; |
|
|
|
return RET_NOT_SUPPORT; |
|
|
|
} |
|
|
|
} |
|
|
|
int element_num = graph->allTensors[node->inputIndex[1]]->dims[0]; |
|
|
|
auto axes = graph->allTensors[node->inputIndex[3]]->data; |
|
|
|
for (int index = 1; index < 5; ++index) { |
|
|
|
TransformAttrByAxes(reinterpret_cast<int *>(graph->allTensors[node->inputIndex[index]]->data.data()), |
|
|
|
reinterpret_cast<int *>(axes.data()), element_num); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
if (type == PrimitiveType_Slice) { |
|
|
|
auto attr = node->primitive->value.AsSlice(); |
|
|
|
if (attr == nullptr) { |
|
|
|
MS_LOG(ERROR) << "node->primitive->value.AsSlice() is nullptr."; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
// transform attr |
|
|
|
attr->format = schema::Format_NHWC; |
|
|
|
if (attr->begin.empty() || attr->size.empty()) { |
|
|
|
MS_LOG(INFO) << "Here don't consider these attr are from other nodes."; |
|
|
|
return RET_NOT_SUPPORT; |
|
|
|
} |
|
|
|
int element_num = attr->begin.size(); |
|
|
|
if (attr->axes.empty()) { |
|
|
|
for (int index = 0; index < element_num; ++index) { |
|
|
|
attr->axes.push_back(index); |
|
|
|
} |
|
|
|
} |
|
|
|
TransformAttrByAxes(attr->begin.data(), attr->axes.data(), element_num); |
|
|
|
TransformAttrByAxes(attr->size.data(), attr->axes.data(), element_num); |
|
|
|
TransformAttrByAxes(attr->axes.data(), attr->axes.data(), element_num); |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) { |
|
|
|
if (node == nullptr && node->primitive == nullptr) { |
|
|
|
MS_LOG(ERROR) << "node or primitive null"; |
|
|
|
@@ -152,19 +220,6 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni |
|
|
|
} |
|
|
|
node->primitive->value.AsConcat()->axis = axis_map[origin_axis]; |
|
|
|
} |
|
|
|
if (type == PrimitiveType_StridedSlice) { |
|
|
|
auto attr = node->primitive->value.AsStridedSlice(); |
|
|
|
if (attr == nullptr) { |
|
|
|
MS_LOG(ERROR) << "node->primitive->value.AsStridedSlice() is nullptr"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
auto origin_begin = attr->begin; |
|
|
|
attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]}; |
|
|
|
auto origin_end = attr->end; |
|
|
|
attr->end = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]}; |
|
|
|
auto origin_stride = attr->stride; |
|
|
|
attr->stride = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]}; |
|
|
|
} |
|
|
|
if (type == PrimitiveType_Split) { |
|
|
|
auto origin_axis = node->primitive->value.AsSplit()->splitDim; |
|
|
|
auto axis_map = GetNc2NhAxisMap(); |
|
|
|
@@ -199,20 +254,8 @@ STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::uni |
|
|
|
} |
|
|
|
node->primitive->value.AsCrop()->offsets = offsets; |
|
|
|
} |
|
|
|
if (type == PrimitiveType_Slice) { |
|
|
|
auto attr = node->primitive->value.AsSlice(); |
|
|
|
if (attr == nullptr) { |
|
|
|
MS_LOG(ERROR) << "node->primitive->value.AsSlice() is nullptr"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
auto origin_begin = attr->begin; |
|
|
|
attr->begin = {origin_begin[NCHW_N], origin_begin[NCHW_H], origin_begin[NCHW_W], origin_begin[NCHW_C]}; |
|
|
|
auto origin_end = attr->axes; |
|
|
|
if (origin_end.size() >= 4) { |
|
|
|
attr->axes = {origin_end[NCHW_N], origin_end[NCHW_H], origin_end[NCHW_W], origin_end[NCHW_C]}; |
|
|
|
} |
|
|
|
auto origin_stride = attr->size; |
|
|
|
attr->size = {origin_stride[NCHW_N], origin_stride[NCHW_H], origin_stride[NCHW_W], origin_stride[NCHW_C]}; |
|
|
|
if (type == PrimitiveType_Slice || type == PrimitiveType_StridedSlice) { |
|
|
|
return ChangeOpAttrForSlice(graph, node); |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
@@ -245,7 +288,7 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) { |
|
|
|
} |
|
|
|
ret = ChangeOpAxis(graph, node); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "ChangeOpAxis error"; |
|
|
|
MS_LOG(INFO) << "no need to ChangeOpAxis"; |
|
|
|
return ret; |
|
|
|
} |
|
|
|
has_insert_nodes.push_back(node.get()); |
|
|
|
@@ -257,6 +300,10 @@ STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) { |
|
|
|
MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed"; |
|
|
|
return status; |
|
|
|
} |
|
|
|
if ((*iter)->primitive->value.type == schema::PrimitiveType_StridedSlice || |
|
|
|
(*iter)->primitive->value.type == schema::PrimitiveType_Slice) { |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
auto output_tensor_size = (*iter)->outputIndex.size(); |
|
|
|
for (size_t i = 0; i < output_tensor_size; i++) { |
|
|
|
|