From: @zhengjun10 Reviewed-by: @HilbertDavid,@hangangqiang Signed-off-by: @HilbertDavidtags/v1.1.0
| @@ -601,6 +601,138 @@ STATUS ValidateFileStr(const std::string &modelFile, std::string fileType) { | |||
| } | |||
| } | |||
| void TransformAttrByAxes(int *origin_attr, int *axes, int element_size) { | |||
| if (origin_attr == nullptr || axes == nullptr || element_size == 0) { | |||
| MS_LOG(INFO) << "Attr data is from other nodes."; | |||
| return; | |||
| } | |||
| auto axis_map = GetNc2NhAxisMap(); | |||
| std::vector<int> cur_attr; | |||
| for (int dim = 0; dim < 4; ++dim) { | |||
| for (int index = 0; index < element_size; ++index) { | |||
| int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + 4 : axes[index]]; | |||
| if (nhwc_dim == dim || (nhwc_dim + 4) == dim) { | |||
| cur_attr.push_back(origin_attr[index]); | |||
| } | |||
| } | |||
| } | |||
| for (int index = 0; index < element_size; ++index) { | |||
| origin_attr[index] = cur_attr[index]; | |||
| } | |||
| } | |||
| STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node) { | |||
| auto type = node->primitive->value.type; | |||
| if (type == schema::PrimitiveType_StridedSlice) { | |||
| // onnx input size is equal to 5 always. | |||
| if (node->inputIndex.size() == 5) { | |||
| for (int index = 1; index < 5; ++index) { | |||
| if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) { | |||
| MS_LOG(INFO) << "Here don't consider input is from other nodes."; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| } | |||
| int element_num = graph->allTensors[node->inputIndex[1]]->dims[0]; | |||
| auto axes = graph->allTensors[node->inputIndex[3]]->data; | |||
| for (int index = 1; index < 5; ++index) { | |||
| TransformAttrByAxes(reinterpret_cast<int *>(graph->allTensors[node->inputIndex[index]]->data.data()), | |||
| reinterpret_cast<int *>(axes.data()), element_num); | |||
| } | |||
| } | |||
| } | |||
| if (type == schema::PrimitiveType_Slice) { | |||
| auto attr = node->primitive->value.AsSlice(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "node->primitive->value.AsSlice() is nullptr."; | |||
| return RET_NULL_PTR; | |||
| } | |||
| // transform attr | |||
| attr->format = schema::Format_NHWC; | |||
| if (attr->begin.empty() || attr->size.empty()) { | |||
| MS_LOG(INFO) << "Here don't consider these attr are from other nodes."; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| int element_num = attr->begin.size(); | |||
| if (attr->axes.empty()) { | |||
| for (int index = 0; index < element_num; ++index) { | |||
| attr->axes.push_back(index); | |||
| } | |||
| } | |||
| TransformAttrByAxes(attr->begin.data(), attr->axes.data(), element_num); | |||
| TransformAttrByAxes(attr->size.data(), attr->axes.data(), element_num); | |||
| TransformAttrByAxes(attr->axes.data(), attr->axes.data(), element_num); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node) { | |||
| MS_ASSERT(node->primitive->value != nullptr); | |||
| auto type = node->primitive->value.type; | |||
| auto input1_ndim = graph->allTensors.at(node->inputIndex[0])->dims.size(); | |||
| if (input1_ndim != 4 && input1_ndim != 0) { | |||
| if (node->inputIndex.size() > 1) { | |||
| auto input2_ndim = graph->allTensors.at(node->inputIndex[1])->dims.size(); | |||
| if (input2_ndim != 4 && input2_ndim != 0) { | |||
| MS_LOG(ERROR) << "change op axis only support 4 dims"; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "change op axis only support 4 dims"; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| } | |||
| if (type == schema::PrimitiveType_Concat) { | |||
| MS_ASSERT(node->primitive->value.AsConcat() != nullptr); | |||
| auto origin_axis = node->primitive->value.AsConcat()->axis; | |||
| auto axis_map = GetNc2NhAxisMap(); | |||
| if (node->primitive->value.AsConcat() == nullptr) { | |||
| MS_LOG(ERROR) << "node->primitive->value.AsConcat() is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| node->primitive->value.AsConcat()->axis = axis_map[origin_axis]; | |||
| } | |||
| if (type == schema::PrimitiveType_Split) { | |||
| MS_ASSERT(node->primitive->value.AsSplit() != nullptr); | |||
| auto origin_axis = node->primitive->value.AsSplit()->splitDim; | |||
| auto axis_map = GetNc2NhAxisMap(); | |||
| if (node->primitive->value.AsSplit() == nullptr) { | |||
| MS_LOG(ERROR) << "node->primitive->value.AsSplit() is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis]; | |||
| } | |||
| if (type == schema::PrimitiveType_Crop) { | |||
| MS_ASSERT(node->primitive->value.AsCrop() != nullptr); | |||
| auto origin_axis = node->primitive->value.AsCrop()->axis; | |||
| auto offsets = node->primitive->value.AsCrop()->offsets; | |||
| auto axis_map = GetNc2NhAxisMap(); | |||
| if (node->primitive->value.AsCrop() == nullptr) { | |||
| MS_LOG(ERROR) << "node->primitive->value.AsCrop() is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| node->primitive->value.AsCrop()->axis = axis_map[origin_axis]; | |||
| // nchw->nhwc,offsets need pad 0; | |||
| if (axis_map[origin_axis] == 0) { | |||
| offsets = {offsets[0], offsets[2], offsets[3], offsets[1]}; | |||
| } else if (axis_map[origin_axis] == 1 || axis_map[origin_axis] == 2) { | |||
| // orgin_axis = 2 or orgin_axis = 3 | |||
| offsets.push_back(0); | |||
| } else if (axis_map[origin_axis] == -1) { | |||
| // origin_axis = 1 | |||
| offsets = {offsets[1], offsets[2], offsets[0]}; | |||
| } else { | |||
| // axis error | |||
| MS_LOG(ERROR) << "Crop error"; | |||
| return RET_ERROR; | |||
| } | |||
| node->primitive->value.AsCrop()->offsets = offsets; | |||
| } | |||
| if (type == schema::PrimitiveType_Slice || type == schema::PrimitiveType_StridedSlice) { | |||
| return ChangeOpAttrForSlice(graph, node); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| std::string GetModelName(const std::string &modelFile) { | |||
| std::string modelName = modelFile; | |||
| modelName = modelName.substr(modelName.find_last_of('/') + 1); | |||
| @@ -86,6 +86,13 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz | |||
| std::unique_ptr<schema::CNodeT> toAddNode, STATUS *errorCode, OpDefCopyer opDefCopyer); | |||
| STATUS ValidateFileStr(const std::string &modelFile, std::string fileType); | |||
| void TransformAttrByAxes(int *origin_attr, int *axes, int element_size); | |||
| STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node); | |||
| STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node); | |||
| std::string GetModelName(const std::string &modelFile); | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -139,7 +139,8 @@ static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveT | |||
| static const std::vector<schema::PrimitiveType> needInsertOpList = { | |||
| schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, | |||
| schema::PrimitiveType_Power, schema::PrimitiveType_StridedSlice, schema::PrimitiveType_Add, | |||
| schema::PrimitiveType_Split, schema::PrimitiveType_Slice, schema::PrimitiveType_Crop}; | |||
| schema::PrimitiveType_Split, schema::PrimitiveType_Slice, schema::PrimitiveType_Crop, | |||
| schema::PrimitiveType_Mul, schema::PrimitiveType_Maximum}; | |||
| static const std::unordered_map<int, int> nc2NhAxisMap = {{0, 0}, {1, -1}, {2, 1}, {3, 2}}; | |||
| @@ -28,6 +28,7 @@ | |||
| #include "tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/global_format_transform_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h" | |||
| @@ -114,6 +115,10 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| if (ctx.trainModel == false && ctx.fmk != converter::FmkType_ONNX) { | |||
| formatTransOptimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| } | |||
| status = formatTransOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE && status != RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "Run formatTransOptimizer graphPasses Failed"; | |||
| @@ -12,6 +12,7 @@ file(GLOB GRAPH_PASS | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/infershape_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/tensor_quant_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/infer_quant_param_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/global_format_transform_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc | |||
| ) | |||
| set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) | |||
| @@ -0,0 +1,197 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/legacy_optimizer/graph/global_format_transform_pass.h" | |||
| #include <algorithm> | |||
| #include "third_party/securec/include/securec.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/common/utils.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/common/node_util.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| STATUS GlobalFormatTransformPass::Run(MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| std::set<size_t> need_del_nodes; | |||
| std::set<size_t> need_trans_format_nodes; | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto &node = *iter; | |||
| auto type = node->primitive->value.type; | |||
| if (type != schema::PrimitiveType_Nchw2Nhwc) { | |||
| continue; | |||
| } | |||
| std::vector<size_t> pre_nh2nc_nodes; | |||
| std::vector<size_t> pre_not_trans_nodes; | |||
| auto status = FindPreNh2NcNodes(graph, iter - graph->nodes.begin(), &pre_nh2nc_nodes, &pre_not_trans_nodes); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status; | |||
| return status; | |||
| } | |||
| std::copy(pre_nh2nc_nodes.begin(), pre_nh2nc_nodes.end(), std::inserter(need_del_nodes, need_del_nodes.end())); | |||
| std::copy(pre_not_trans_nodes.begin(), pre_not_trans_nodes.end(), | |||
| std::inserter(need_trans_format_nodes, need_trans_format_nodes.end())); | |||
| if (!pre_nh2nc_nodes.empty()) { | |||
| need_del_nodes.insert(iter - graph->nodes.begin()); | |||
| } | |||
| } | |||
| if (need_del_nodes.empty()) { | |||
| return RET_OK; | |||
| } | |||
| for (auto del_node_index : need_del_nodes) { | |||
| auto node_name = graph->nodes.at(del_node_index)->name; | |||
| auto status = IsolateOneWayNode(graph, del_node_index); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Isolate Node failed, node: " << node_name << ", error: " << status; | |||
| return status; | |||
| } | |||
| } | |||
| auto status = TransWeightToNhwc(graph, need_trans_format_nodes); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "trans weight to nhwc failed"; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS ConvertNcTensor2Nh(TensorT *tensor, const std::vector<int> &pad_dims) { | |||
| if (pad_dims.size() != 4) { | |||
| MS_LOG(ERROR) << "pad dims error"; | |||
| return RET_ERROR; | |||
| } | |||
| auto batch = pad_dims[NCHW_N]; | |||
| auto channel = pad_dims[NCHW_C]; | |||
| auto area = pad_dims[NCHW_H] * pad_dims[NCHW_W]; | |||
| auto size = batch * channel * area; | |||
| auto new_nhwc_data = new (std::nothrow) float[size]; | |||
| if (new_nhwc_data == nullptr) { | |||
| MS_LOG(ERROR) << "create new nhwc data failed"; | |||
| delete[] new_nhwc_data; | |||
| return RET_ERROR; | |||
| } | |||
| memset(new_nhwc_data, 0, sizeof(float) * size); | |||
| auto nchw_data = reinterpret_cast<float *>(tensor->data.data()); | |||
| // nchw to nhwc | |||
| for (auto i = 0; i < batch; i++) { | |||
| float *src_batch = nchw_data + i * channel * area; | |||
| float *dst_batch = new_nhwc_data + i * channel * area; | |||
| for (int j = 0; j < area; ++j) { | |||
| float *src_area = src_batch + i; | |||
| float *dst_area = dst_batch + i * channel; | |||
| for (int k = 0; k < channel; ++k) { | |||
| dst_area[k] = src_area[k * area]; | |||
| } | |||
| } | |||
| } | |||
| memcpy(nchw_data, new_nhwc_data, sizeof(float) * size); | |||
| delete[] new_nhwc_data; | |||
| return RET_OK; | |||
| } | |||
| STATUS GlobalFormatTransformPass::TransWeightToNhwc(MetaGraphT *graph, const std::set<size_t> &pre_not_trans_nodes) { | |||
| if (pre_not_trans_nodes.empty()) { | |||
| return RET_OK; | |||
| } | |||
| for (auto index : pre_not_trans_nodes) { | |||
| auto &cur_node = graph->nodes.at(index); | |||
| // need change axis from nchw to nhwc like concat,slice | |||
| auto ret = ChangeOpAxis(graph, cur_node); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ChangeOpAxis error"; | |||
| return ret; | |||
| } | |||
| auto node_input_indexs = cur_node->inputIndex; | |||
| for (auto input_index : node_input_indexs) { | |||
| // weight data need trans nhwc layerout | |||
| if (!IsContain(graph->inputIndex, input_index) && | |||
| graph->allTensors.at(input_index)->nodeType == NodeType_ValueNode) { | |||
| auto &weight_tensor = graph->allTensors.at(input_index); | |||
| auto origin_dims = weight_tensor->dims; | |||
| weight_tensor->format = Format_NHWC; | |||
| if (origin_dims.size() > 4) { | |||
| MS_LOG(ERROR) << "tensor origin tensor size error"; | |||
| return RET_ERROR; | |||
| } | |||
| if (origin_dims.size() == 0) { | |||
| continue; | |||
| } | |||
| auto pad_dims = origin_dims; | |||
| if (origin_dims.size() == 1) { | |||
| pad_dims = {1, 1, 1, origin_dims[0]}; | |||
| } else if (origin_dims.size() == 2) { | |||
| pad_dims = {1, 1, origin_dims[0], origin_dims[1]}; | |||
| } else if (origin_dims.size() == 3) { | |||
| pad_dims = {1, origin_dims[0], origin_dims[1], origin_dims[2]}; | |||
| } | |||
| if (ConvertNcTensor2Nh(weight_tensor.get(), pad_dims) != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert nchw to nhwc failed"; | |||
| return RET_ERROR; | |||
| } | |||
| weight_tensor->dims = {pad_dims[NCHW_N], pad_dims[NCHW_H], pad_dims[NCHW_W], pad_dims[NCHW_C]}; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc2nh_index, | |||
| std::vector<size_t> *pre_nh2nc_nodes, | |||
| std::vector<size_t> *pre_not_trans_nodes) { | |||
| MS_ASSERT(graph != nullptr); | |||
| std::vector<size_t> bfs_queue = {nc2nh_index}; | |||
| // find pre node nh2nc start nodes | |||
| while (!bfs_queue.empty()) { | |||
| auto cur_node_index = bfs_queue.back(); | |||
| auto &cur_node = graph->nodes.at(cur_node_index); | |||
| bfs_queue.pop_back(); | |||
| auto input_node_indexes = GetInputNodeIdx(*graph, *cur_node); | |||
| for (auto input_node_index : input_node_indexes) { | |||
| MS_ASSERT(graph->nodes.size() > input_node_index); | |||
| auto &pre_node = graph->nodes.at(input_node_index); | |||
| MS_ASSERT(pre_node != nullptr); | |||
| auto node_type = pre_node->primitive->value.type; | |||
| if (node_type == schema::PrimitiveType_Nhwc2Nchw) { | |||
| if (!IsContain(*pre_nh2nc_nodes, input_node_index)) { | |||
| pre_nh2nc_nodes->emplace_back(input_node_index); | |||
| } | |||
| } else if (IsContain(GetInsertOpList(), node_type)) { | |||
| if (!IsContain(bfs_queue, input_node_index)) { | |||
| bfs_queue.emplace_back(input_node_index); | |||
| } | |||
| // todo multi output,other edge need insert nh2nc node | |||
| auto pre_node_output_indexs = GetOutputNodeIdx(*graph, *pre_node); | |||
| if ((pre_node_output_indexs.size() != 1) && (node_type == schema::PrimitiveType_Activation)) { | |||
| pre_nh2nc_nodes->clear(); | |||
| pre_not_trans_nodes->clear(); | |||
| return RET_OK; | |||
| } | |||
| } else { | |||
| pre_nh2nc_nodes->clear(); | |||
| pre_not_trans_nodes->clear(); | |||
| return RET_OK; | |||
| } | |||
| if (!IsContain(*pre_not_trans_nodes, cur_node_index) && cur_node_index != nc2nh_index) { | |||
| pre_not_trans_nodes->emplace_back(cur_node_index); | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H | |||
| #define MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H | |||
| #include <unordered_map> | |||
| #include <set> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/converter/optimizer.h" | |||
| using mindspore::schema::TensorT; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class GlobalFormatTransformPass : public GraphPass { | |||
| public: | |||
| GlobalFormatTransformPass() = default; | |||
| ~GlobalFormatTransformPass() = default; | |||
| STATUS Run(MetaGraphT *graph) override; | |||
| protected: | |||
| STATUS TransWeightToNhwc(MetaGraphT *graph, const std::set<size_t> &pre_not_trans_nodes); | |||
| STATUS FindPreNh2NcNodes(MetaGraphT *graph, size_t nc2nh_index, std::vector<size_t> *to_do_insert_nodes, | |||
| std::vector<size_t> *pre_not_trans_nodes); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H | |||
| @@ -127,146 +127,6 @@ STATUS TransOpInsertPass::FindOutTransType() { | |||
| return RET_OK; | |||
| } | |||
| void TransOpInsertPass::TransformAttrByAxes(int *origin_attr, int *axes, int element_size) { | |||
| if (origin_attr == nullptr || axes == nullptr || element_size == 0) { | |||
| MS_LOG(INFO) << "Attr data is from other nodes."; | |||
| return; | |||
| } | |||
| auto axis_map = GetNc2NhAxisMap(); | |||
| std::vector<int> cur_attr; | |||
| for (int dim = 0; dim < 4; ++dim) { | |||
| for (int index = 0; index < element_size; ++index) { | |||
| int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + 4 : axes[index]]; | |||
| if (nhwc_dim == dim || (nhwc_dim + 4) == dim) { | |||
| cur_attr.push_back(origin_attr[index]); | |||
| } | |||
| } | |||
| } | |||
| for (int index = 0; index < element_size; ++index) { | |||
| origin_attr[index] = cur_attr[index]; | |||
| } | |||
| } | |||
| STATUS TransOpInsertPass::ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) { | |||
| if (node == nullptr && node->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "node or primitive null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto type = node->primitive->value.type; | |||
| if (type == PrimitiveType_StridedSlice) { | |||
| // onnx input size is equal to 5 always. | |||
| if (node->inputIndex.size() == 5) { | |||
| for (int index = 1; index < 5; ++index) { | |||
| if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) { | |||
| MS_LOG(INFO) << "Here don't consider input is from other nodes."; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| } | |||
| int element_num = graph->allTensors[node->inputIndex[1]]->dims[0]; | |||
| auto axes = graph->allTensors[node->inputIndex[3]]->data; | |||
| for (int index = 1; index < 5; ++index) { | |||
| TransformAttrByAxes(reinterpret_cast<int *>(graph->allTensors[node->inputIndex[index]]->data.data()), | |||
| reinterpret_cast<int *>(axes.data()), element_num); | |||
| } | |||
| } | |||
| } | |||
| if (type == PrimitiveType_Slice) { | |||
| auto attr = node->primitive->value.AsSlice(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "node->primitive->value.AsSlice() is nullptr."; | |||
| return RET_NULL_PTR; | |||
| } | |||
| // transform attr | |||
| attr->format = schema::Format_NHWC; | |||
| if (attr->begin.empty() || attr->size.empty()) { | |||
| MS_LOG(INFO) << "Here don't consider these attr are from other nodes."; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| int element_num = attr->begin.size(); | |||
| if (attr->axes.empty()) { | |||
| for (int index = 0; index < element_num; ++index) { | |||
| attr->axes.push_back(index); | |||
| } | |||
| } | |||
| TransformAttrByAxes(attr->begin.data(), attr->axes.data(), element_num); | |||
| TransformAttrByAxes(attr->size.data(), attr->axes.data(), element_num); | |||
| TransformAttrByAxes(attr->axes.data(), attr->axes.data(), element_num); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS TransOpInsertPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) { | |||
| if (node == nullptr && node->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "node or primitive null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| MS_ASSERT(node->primitive->value != nullptr); | |||
| auto type = node->primitive->value.type; | |||
| auto input1_ndim = graph->allTensors.at(node->inputIndex[0])->dims.size(); | |||
| if (input1_ndim != 4) { | |||
| if (node->inputIndex.size() > 1) { | |||
| auto input2_ndim = graph->allTensors.at(node->inputIndex[1])->dims.size(); | |||
| if (input2_ndim != 4 && input2_ndim != 0) { | |||
| MS_LOG(ERROR) << "change op axis only support 4 dims"; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| } else { | |||
| MS_LOG(ERROR) << "change op axis only support 4 dims"; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| } | |||
| if (type == PrimitiveType_Concat) { | |||
| MS_ASSERT(node->primitive->value.AsConcat() != nullptr); | |||
| auto origin_axis = node->primitive->value.AsConcat()->axis; | |||
| auto axis_map = GetNc2NhAxisMap(); | |||
| if (node->primitive->value.AsConcat() == nullptr) { | |||
| MS_LOG(ERROR) << "node->primitive->value.AsConcat() is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| node->primitive->value.AsConcat()->axis = axis_map[origin_axis]; | |||
| } | |||
| if (type == PrimitiveType_Split) { | |||
| MS_ASSERT(node->primitive->value.AsSplit() != nullptr); | |||
| auto origin_axis = node->primitive->value.AsSplit()->splitDim; | |||
| auto axis_map = GetNc2NhAxisMap(); | |||
| if (node->primitive->value.AsSplit() == nullptr) { | |||
| MS_LOG(ERROR) << "node->primitive->value.AsSplit() is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| node->primitive->value.AsSplit()->splitDim = axis_map[origin_axis]; | |||
| } | |||
| if (type == PrimitiveType_Crop) { | |||
| MS_ASSERT(node->primitive->value.AsCrop() != nullptr); | |||
| auto origin_axis = node->primitive->value.AsCrop()->axis; | |||
| auto offsets = node->primitive->value.AsCrop()->offsets; | |||
| auto axis_map = GetNc2NhAxisMap(); | |||
| if (node->primitive->value.AsCrop() == nullptr) { | |||
| MS_LOG(ERROR) << "node->primitive->value.AsCrop() is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| node->primitive->value.AsCrop()->axis = axis_map[origin_axis]; | |||
| // nchw->nhwc,offsets need pad 0; | |||
| if (axis_map[origin_axis] == 0) { | |||
| offsets = {offsets[0], offsets[2], offsets[3], offsets[1]}; | |||
| } else if (axis_map[origin_axis] == 1 || axis_map[origin_axis] == 2) { | |||
| // orgin_axis = 2 or orgin_axis = 3 | |||
| offsets.push_back(0); | |||
| } else if (axis_map[origin_axis] == -1) { | |||
| // origin_axis = 1 | |||
| offsets = {offsets[1], offsets[2], offsets[0]}; | |||
| } else { | |||
| // axis error | |||
| MS_LOG(ERROR) << "Crop error"; | |||
| return RET_ERROR; | |||
| } | |||
| node->primitive->value.AsCrop()->offsets = offsets; | |||
| } | |||
| if (type == PrimitiveType_Slice || type == PrimitiveType_StridedSlice) { | |||
| return ChangeOpAttrForSlice(graph, node); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| bool changed = true; | |||
| @@ -41,8 +41,6 @@ class TransOpInsertPass : public FormatTransPass { | |||
| STATUS ChangeOpAttrForSlice(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node); | |||
| STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node); | |||
| private: | |||
| FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW; | |||
| FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW; | |||