diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 9da88a390b..233e704721 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -34,7 +34,6 @@ #include "tools/converter/quant_param_holder.h" #include "tools/optimizer/common/gllo_utils.h" #include "tools/converter/quantizer/bitpacking.h" -#include "src/tensor.h" #include "src/common/utils.h" #include "ops/op_utils.h" #include "tools/common/graph_util.h" @@ -80,159 +79,8 @@ std::list GetOrderedCNodes(const FuncGraphPtr fg) { } return cnodes; } -STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, ShapeVector *shape_vector, size_t *offset) { - auto data_type = tensor_info->data_type(); - if (data_type != kObjectTypeString) { - MS_LOG(ERROR) << "This function only used for string tensor."; - return RET_ERROR; - } - shape_vector->clear(); - auto tensor_data = reinterpret_cast(tensor_info->data_c()); - std::string shape_str; - std::string shape_size_str; - *offset = 0; - size_t cnt = 0; - for (; *offset < tensor_info->Size(); (*offset)++) { - if (tensor_data[*offset] == ',') { - (*offset)++; - break; - } - shape_size_str.push_back(tensor_data[*offset]); - } - if (*offset == 0) { - MS_LOG(ERROR) << "string tensor's dim size not found."; - return RET_ERROR; - } - size_t shape_size = std::stoi(shape_size_str); - for (; *offset < tensor_info->Size(); (*offset)++) { - if (tensor_data[*offset] == ',') { - cnt++; - shape_vector->push_back(std::stoi(shape_str)); - shape_str.clear(); - } else { - shape_str.push_back(tensor_data[*offset]); - } - if (cnt == shape_size) { - (*offset)++; - break; - } - } - if (shape_vector->empty()) { - MS_LOG(ERROR) << "string tensor's shape shouldn't be empty."; - return RET_ERROR; - } - return RET_OK; -} -schema::Format GetFormatByFmk(int32_t fmk_type) { - switch (fmk_type) { - case converter::FmkType_ONNX: - case lite::converter::FmkType_CAFFE: - case lite::converter::FmkType_MS: - return schema::Format_NCHW; - case lite::converter::FmkType_TF: - case lite::converter::FmkType_TFLITE: - return schema::Format_NHWC; - default: - MS_LOG(ERROR) << "don't support current fmk: " + fmk_type; - return static_cast(fmk_type); - } -} - -STATUS GetDataTypeAndShape(const ParameterPtr ¶m_node, TypeId *data_type, ShapeVector *shape_vector) { - auto abstract_base = param_node->abstract(); - if (abstract_base == nullptr) { - MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name(); - return RET_PARAM_INVALID; - } - if (!utils::isa(abstract_base)) { - MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name(); - return RET_INPUT_TENSOR_ERROR; - } - auto abstract_tensor = utils::cast(abstract_base); - auto typePtr = abstract_tensor->element()->GetTypeTrack(); - MS_ASSERT(typePtr != nullptr); - *data_type = typePtr->type_id(); - if (!utils::isa(abstract_tensor->BuildShape())) { - MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name(); - return RET_PARAM_INVALID; - } - *shape_vector = utils::cast(abstract_tensor->BuildShape())->shape(); - return RET_OK; -} } // namespace -void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { - bool has_make_tuple = false; - std::vector inputs; - inputs.clear(); - - inputs.emplace_back(cnode->input(0)); - for (size_t i = 1; i < cnode->inputs().size(); ++i) { - AnfNodePtr input_node = cnode->input(i); - if (!input_node->isa()) { - 
inputs.emplace_back(cnode->input(i)); - continue; - } - auto make_tuple_node = utils::cast(input_node); - auto value_node = make_tuple_node->input(0)->cast(); - if (value_node == nullptr) { - MS_LOG(ERROR) << "value node is invalid."; - return; - } - if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, prim::kPrimMakeTuple) || - opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) { - has_make_tuple = true; - for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) { - inputs.emplace_back(make_tuple_node->input(j)); - } - } else { - inputs.emplace_back(cnode->input(i)); - } - } - if (has_make_tuple) { - cnode->set_inputs(inputs); - } -} - -void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) { - bool has_depend = false; - std::vector inputs; - inputs.clear(); - - inputs.emplace_back(cnode->input(0)); - for (size_t i = 1; i < cnode->inputs().size(); ++i) { - AnfNodePtr inputNode = cnode->input(i); - if (!inputNode->isa()) { - inputs.emplace_back(cnode->input(i)); - continue; - } - auto depend_node = utils::cast(inputNode); - auto value_node = depend_node->input(0)->cast(); - if (value_node == nullptr) { - MS_LOG(ERROR) << "value node is invalid."; - return; - } - if (value_node->value() != nullptr && opt::CheckPrimitiveType(depend_node, prim::kPrimDepend)) { - has_depend = true; - bool mask_out = (depend_node->inputs().size() == 3); - for (size_t j = 1; j < depend_node->inputs().size(); ++j) { - AnfNodePtr depend_input_node = depend_node->input(j); - if (depend_input_node->isa()) { - inputs.emplace_back(depend_input_node); - if (mask_out) { - break; - } - } - } - } else { - inputs.emplace_back(cnode->input(i)); - } - } - if (has_depend) { - cnode->set_inputs(inputs); - } -} - int AnfExporter::SetQuantOutputTensorType(const std::unique_ptr &meta_graph, const std::shared_ptr &primitive, const std::unique_ptr &dst_node) { @@ -653,283 +501,58 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr &input_anode, return RET_OK; } -int AnfExporter::ConvertInputParameter(const std::shared_ptr &input_anode, - const std::shared_ptr &primitive_c, +int AnfExporter::ConvertInputParameter(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive, const std::unique_ptr &meta_graphT, - schema::CNodeT *output_cnode) { - auto param_node = input_anode->cast(); + schema::CNodeT *op_node) { + auto param_node = cnode->input(index)->cast(); + MS_ASSERT(param_node != nullptr); std::string input_name = param_node->fullname_with_scope(); if (node_id_map_.find(input_name) != node_id_map_.end()) { - output_cnode->inputIndex.emplace_back(node_id_map_[param_node->name()]); + op_node->inputIndex.emplace_back(node_id_map_[param_node->name()]); return RET_OK; } - auto schema_tensor = std::make_unique(); - schema_tensor->format = GetFormatByFmk(meta_graphT->fmkType); - if (schema_tensor->format != schema::Format_NHWC && schema_tensor->format != schema::Format_NCHW) { - MS_LOG(ERROR) << "schema tensor format is wrong, " << schema_tensor->format; + DataInfo data_info; + if (FetchDataFromParameterNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info) != + RET_OK) { + MS_LOG(ERROR) << "parse const node failed."; return RET_ERROR; } - - // attr weightFormat is only used by conv-like ops' second input - if (output_cnode->inputIndex.size() == 1 && primitive_c->GetAttr(opt::kWeightFormat) != nullptr) { - schema_tensor->format = static_cast(GetValue(primitive_c->GetAttr(opt::kWeightFormat))); - } + auto schema_tensor = std::make_unique(); + 
schema_tensor->format = static_cast(data_info.format_); schema_tensor->name = param_node->name(); - ShapeVector shape_vector; - TypeId data_type; - auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector); - if (status != RET_OK) { - MS_LOG(ERROR) << "get data type and shape from param node failed."; - return RET_ERROR; - } - schema_tensor->dataType = data_type; - auto tensor_info = std::dynamic_pointer_cast(param_node->default_param()); - size_t offset = 0; - if (!shape_vector.empty() && schema_tensor->dataType == kObjectTypeString) { - status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset); - if (status != RET_OK) { - MS_LOG(ERROR) << "get shape vector from string tensor failed."; - return RET_ERROR; - } - } - std::vector dims; - (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims), - [](const int64_t &value) { return static_cast(value); }); - schema_tensor->dims = dims; - if (tensor_info != nullptr && tensor_info->Size() != 0) { - if (schema_tensor->dataType == kObjectTypeTensorType && shape_vector.empty() && - meta_graphT->fmkType == converter::FmkType_ONNX) { - schema_tensor->data.resize(0); - } else { - schema_tensor->data.resize(tensor_info->Size() - offset); - if (EOK != memcpy_s(schema_tensor->data.data(), schema_tensor->data.size(), - static_cast(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) { - MS_LOG(ERROR) << "memcpy_s failed."; - return RET_ERROR; - } - } - } - schema_tensor->name = input_name; - QuantParamHolderPtr quant_param_holder = primitive_c->GetAttr("quant_params") == nullptr - ? nullptr - : primitive_c->GetAttr("quant_params")->cast(); - if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() && - schema_tensor->dataType == kNumberTypeInt8) { - schema_tensor->enableHuffmanCode = true; - } + schema_tensor->dims = data_info.shape_; + schema_tensor->dataType = data_info.data_type_; + schema_tensor->data = data_info.data_; + schema_tensor->enableHuffmanCode = data_info.enable_huffman_code_; + node_id_map_[input_name] = meta_graphT->allTensors.size(); - output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); + op_node->inputIndex.emplace_back(meta_graphT->allTensors.size()); meta_graphT->allTensors.emplace_back(std::move(schema_tensor)); return RET_OK; } -int AnfExporter::ProcessTensor(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, - const std::shared_ptr &value, const std::shared_ptr &primitive, - schema::CNodeT *output_cnode, const std::unique_ptr &meta_graphT) { - auto valueAbstract = value_node->abstract(); - auto abstract_tensor = utils::cast(valueAbstract); - if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) { - MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr"; - return RET_ERROR; - } - auto typePtr = abstract_tensor->element()->GetTypeTrack(); - (*schema_tensor)->dataType = typePtr->type_id(); - auto shape_vector = utils::cast(abstract_tensor->BuildShape())->shape(); - std::vector dims; - (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims), - [](const int64_t &value) { return static_cast(value); }); - (*schema_tensor)->dims = dims; - if (train_flag_ && (*schema_tensor)->dims.empty()) (*schema_tensor)->dims = {1}; - (*schema_tensor)->nodeType = NodeType_ValueNode; - auto data = value->cast(); - (*schema_tensor)->data.resize(data->Size()); - (*schema_tensor)->format = schema::Format_NHWC; - - (*schema_tensor)->format = 
GetFormatByFmk(meta_graphT->fmkType); - if ((*schema_tensor)->format != schema::Format_NHWC && (*schema_tensor)->format != schema::Format_NCHW) { - MS_LOG(ERROR) << "schema tensor format is wrong, " << (*schema_tensor)->format; - return RET_ERROR; - } - - // process weight tensor - if (data->Size() > 0) { - if (memcpy_s((*schema_tensor)->data.data(), (*schema_tensor)->data.size(), data->data_c(), data->Size()) != EOK) { - MS_LOG(ERROR) << "memcpy_s error."; - return RET_ERROR; - } - - if (primitive->GetAttr(opt::kWeightFormat) != nullptr) { - (*schema_tensor)->format = static_cast(GetValue(primitive->GetAttr(opt::kWeightFormat))); - } - } - - node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size(); - output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(*schema_tensor)); - return RET_OK; -} -int AnfExporter::ProcessInt32OrInt64Imm(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, - const std::shared_ptr &value, schema::CNodeT *output_cnode, - const std::unique_ptr &meta_graphT) { - int ret; - // data of int64 is converted to int32 here. - (*schema_tensor)->dataType = kNumberTypeInt32; - (*schema_tensor)->dims = {1}; - (*schema_tensor)->nodeType = NodeType_ValueNode; - int real_data = opt::CastToInt(value).front(); - (*schema_tensor)->data.resize(sizeof(int32_t)); - ret = memcpy_s((*schema_tensor)->data.data(), sizeof(int32_t), &real_data, sizeof(int32_t)); - if (ret != EOK) { - MS_LOG(ERROR) << "memcpy_s error."; - return RET_ERROR; - } - node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size(); - output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(*schema_tensor)); - return ret; -} -void AnfExporter::ProcessBoolImm(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, - const std::shared_ptr &value, schema::CNodeT *output_cnode, - const std::unique_ptr &meta_graphT) { - auto valueAbstract = value_node->abstract(); - auto abstractScalar = utils::cast(valueAbstract); - auto typePtr = abstractScalar->GetTypeTrack(); - (*schema_tensor)->dataType = typePtr->type_id(); - (*schema_tensor)->dims = {1}; - (*schema_tensor)->nodeType = NodeType_ValueNode; - auto data = value->cast(); - (*schema_tensor)->data.emplace_back(data->value()); - node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size(); - output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(*schema_tensor)); -} -int AnfExporter::ProcessNumber(const ValueNodePtr &value_node, schema::TensorT *schema_tensor, - schema::CNodeT *output_cnode, const std::unique_ptr &meta_graphT) { - auto data = value_node->value()->cast(); - schema_tensor->data.resize(sizeof(int)); - int number_type = data->number_type(); - if (EOK != ::memcpy_s(schema_tensor->data.data(), sizeof(int), &number_type, sizeof(int))) { - MS_LOG(ERROR) << "memcpy_s failed"; - return RET_MEMORY_FAILED; - } - schema_tensor->dataType = kNumberTypeInt32; - schema_tensor->dims = {1}; - schema_tensor->nodeType = NodeType_ValueNode; - node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size(); - output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(schema_tensor); - return RET_OK; -} -void AnfExporter::ProcessInt(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, - schema::CNodeT *output_cnode, const 
std::unique_ptr &meta_graphT) { - (*schema_tensor)->dataType = kNumberTypeInt32; - (*schema_tensor)->dims = {1}; - (*schema_tensor)->nodeType = NodeType_ValueNode; - (*schema_tensor)->data.emplace_back(kNumberTypeInt32); - node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size(); - output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(*schema_tensor)); -} -int AnfExporter::ProcessValueSequence(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, - const std::shared_ptr &value, schema::CNodeT *output_cnode, - const std::unique_ptr &meta_graphT) { - int ret = RET_OK; - auto valueAbstract = value_node->abstract(); - auto abstractSequnce = utils::cast(valueAbstract); - if (abstractSequnce->isa()) { - auto abstractTuple = utils::cast(valueAbstract); - auto x_shape_data = abstractTuple->elements(); - std::vector shape; - for (std::size_t i = 0; i < abstractTuple->size(); ++i) { - auto value_track = x_shape_data[i]->GetValueTrack(); - MS_ASSERT(value_track != nullptr); - if (value_track->isa()) { - shape.push_back((GetValue(value_track))); - } else if (value_track->isa()) { - shape.push_back((GetValue(value_track))); - } else { - MS_LOG(ERROR) << "Value type is ValueSequence is not integer, it is " << value_track->ToString() << "."; - return RET_ERROR; - } - } - (*schema_tensor)->dataType = kNumberTypeInt32; - (*schema_tensor)->dims = {static_cast(shape.size())}; - (*schema_tensor)->nodeType = NodeType_ValueNode; - (*schema_tensor)->data.resize(shape.size() * sizeof(int)); - if (!shape.empty()) { - if (EOK != memcpy_s((*schema_tensor)->data.data(), shape.size() * sizeof(int32_t), shape.data(), - shape.size() * sizeof(int32_t))) { - MS_LOG(ERROR) << "memcpy_s data into schema_tensor failed."; - return RET_MEMORY_FAILED; - } - } - node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size(); - output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(*schema_tensor)); - } - return ret; -} - -int AnfExporter::ProcessTensorInfo(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, - const std::shared_ptr &value, schema::CNodeT *output_cnode, - const std::unique_ptr &meta_graphT) { - auto tensor_info = std::dynamic_pointer_cast(value); - if (tensor_info == nullptr) { - MS_LOG(ERROR) << "Input value is not a tensor"; - return RET_INPUT_PARAM_INVALID; - } - auto ret = UpdateTensorTFromTensorInfo(tensor_info, schema_tensor); - if (ret != RET_OK) { - MS_LOG(ERROR) << "UpdateTensorTFromTensorInfo failed"; - return ret; - } - if (train_flag_ && (*schema_tensor)->dims.empty()) { - (*schema_tensor)->dims = {1}; - } - - node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size(); - output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); - meta_graphT->allTensors.emplace_back(std::move(*schema_tensor)); - return ret; -} - -int AnfExporter::ConvertInputValueNode(const std::shared_ptr &input_anode, - const std::shared_ptr &primitive, +int AnfExporter::ConvertInputValueNode(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive, const std::unique_ptr &meta_graphT, - schema::CNodeT *output_cnode) { - auto value_node = input_anode->cast(); - auto schema_tensor = std::make_unique(); - auto value = value_node->value(); - int ret = RET_OK; - - if (train_flag_) { - schema_tensor->name = value_node->fullname_with_scope(); - } - if (value->isa()) { - ret = ProcessTensor(value_node, 
&schema_tensor, value, primitive, output_cnode, meta_graphT); - } else if (value->isa() || value->isa()) { - ret = ProcessInt32OrInt64Imm(value_node, &schema_tensor, value, output_cnode, meta_graphT); - } else if (value->isa()) { - ProcessBoolImm(value_node, &schema_tensor, value, output_cnode, meta_graphT); - } else if (value->isa()) { - ProcessInt(value_node, &schema_tensor, output_cnode, meta_graphT); - } else if (value->isa()) { - ret = ProcessValueSequence(value_node, &schema_tensor, value, output_cnode, meta_graphT); - } else if (value->isa()) { - ret = ProcessNumber(value_node, schema_tensor.release(), output_cnode, meta_graphT); - } else if (value->isa()) { - ret = ProcessTensorInfo(value_node, &schema_tensor, value, output_cnode, meta_graphT); - } else if (value->isa()) { - MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is func_graph"; + schema::CNodeT *op_node) { + DataInfo data_info; + auto status = FetchDataFromValueNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info); + if (status == RET_NO_CHANGE) { return RET_OK; - } else if (value->isa()) { - MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is Monad"; - return RET_OK; - } else { - MS_LOG(ERROR) << "Not support value type , need add support."; - return RET_ERROR; } - return ret; + if (status != RET_OK) { + MS_LOG(ERROR) << "parse value node failed."; + return status; + } + auto schema_tensor = std::make_unique(); + schema_tensor->name = cnode->input(index)->fullname_with_scope(); + schema_tensor->format = static_cast(data_info.format_); + schema_tensor->dataType = data_info.data_type_; + schema_tensor->dims = data_info.shape_; + schema_tensor->data = data_info.data_; + node_id_map_[cnode->input(index)->fullname_with_scope()] = meta_graphT->allTensors.size(); + op_node->inputIndex.emplace_back(meta_graphT->allTensors.size()); + meta_graphT->allTensors.emplace_back(std::move(schema_tensor)); + return RET_OK; } int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, @@ -954,7 +577,7 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptrisa()) { - auto ret = ConvertInputParameter(input_node, primitive_c, meta_graphT, fb_node); + auto ret = ConvertInputParameter(cnode, i, primitive_c, meta_graphT, fb_node); if (ret != RET_OK) { MS_LOG(ERROR) << "ConvertInputParameter failed"; return ret; @@ -963,7 +586,7 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptrisa()) { - auto ret = ConvertInputValueNode(input_node, primitive_c, meta_graphT, fb_node); + auto ret = ConvertInputValueNode(cnode, i, primitive_c, meta_graphT, fb_node); if (ret != RET_OK) { MS_LOG(ERROR) << "ConvertInputValueNode failed"; return RET_ERROR; diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.h b/mindspore/lite/tools/anf_exporter/anf_exporter.h index 87f697df1a..ed20096e78 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.h +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.h @@ -24,8 +24,10 @@ #include "schema/inner/model_generated.h" #include "ops/primitive_c.h" #include "ir/func_graph.h" +#include "tools/anf_exporter/fetch_content.h" #include "tools/converter/converter_context.h" #include "tools/converter/converter_flags.h" +#include "tools/optimizer/common/gllo_utils.h" using mindspore::ops::PrimitiveC; @@ -44,35 +46,14 @@ class AnfExporter { schema::CNodeT *fb_node); int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, schema::CNodeT 
*fb_node); - static void RemoveIfMakeTuple(const CNodePtr &cnode); - static void RemoveIfDepend(const CNodePtr &cnode); protected: int ConvertInputCNode(const std::shared_ptr &input_anode, schema::CNodeT *output_cnode); int ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode); - int ConvertInputParameter(const std::shared_ptr &input_anode, const std::shared_ptr &primitive, - const std::unique_ptr &meta_graphT, schema::CNodeT *output_cnode); - int ConvertInputValueNode(const std::shared_ptr &input_anode, const std::shared_ptr &primitive, - const std::unique_ptr &meta_graphT, schema::CNodeT *output_cnode); - int ProcessTensor(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, - const std::shared_ptr &value, const std::shared_ptr &primitive, - schema::CNodeT *output_cnode, const std::unique_ptr &meta_graphT); - int ProcessInt32OrInt64Imm(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, - const std::shared_ptr &value, schema::CNodeT *output_cnode, - const std::unique_ptr &meta_graphT); - void ProcessBoolImm(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, - const std::shared_ptr &value, schema::CNodeT *output_cnode, - const std::unique_ptr &meta_graphT); - void ProcessInt(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, - schema::CNodeT *output_cnode, const std::unique_ptr &meta_graphT); - int ProcessNumber(const ValueNodePtr &value_node, schema::TensorT *schema_tensor, schema::CNodeT *output_cnode, - const std::unique_ptr &meta_graphT); - int ProcessValueSequence(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, - const std::shared_ptr &value, schema::CNodeT *output_cnode, - const std::unique_ptr &meta_graphT); - int ProcessTensorInfo(const ValueNodePtr &value_node, std::unique_ptr *schema_tensor, - const std::shared_ptr &value, schema::CNodeT *output_cnode, - const std::unique_ptr &meta_graphT); + int ConvertInputParameter(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive, + const std::unique_ptr &meta_graphT, schema::CNodeT *op_node); + int ConvertInputValueNode(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive, + const std::unique_ptr &meta_graphT, schema::CNodeT *op_node); int SetGraphInputIndex(const std::unique_ptr &meta_graphT, const size_t &subgraph_index); int SetGraphoutputIndex(const CNodePtr &cnode, size_t subgraph_index, const std::unique_ptr &meta_graphT, schema::CNodeT *return_node); diff --git a/mindspore/lite/tools/anf_exporter/fetch_content.cc b/mindspore/lite/tools/anf_exporter/fetch_content.cc new file mode 100644 index 0000000000..fd86870372 --- /dev/null +++ b/mindspore/lite/tools/anf_exporter/fetch_content.cc @@ -0,0 +1,437 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "tools/anf_exporter/fetch_content.h"
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "tools/converter/quant_param_holder.h"
+#include "tools/optimizer/common/gllo_utils.h"
+#include "utils/check_convert_utils.h"
+
+namespace mindspore {
+namespace lite {
+namespace {
+constexpr size_t kTensorListMinSize = 3 * sizeof(int32_t);
+static const std::unordered_map<int, int> TypeToTypeMap = {
+  {kNumberTypeInt, kNumberTypeInt32}, {kNumberTypeUInt, kNumberTypeUInt32}, {kNumberTypeFloat, kNumberTypeFloat32}};
+STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, ShapeVector *shape_vector,
+                                      size_t *offset) {
+  MS_ASSERT(tensor_info != nullptr && shape_vector != nullptr && offset != nullptr);
+  auto data_type = tensor_info->data_type();
+  if (data_type != kObjectTypeString) {
+    MS_LOG(ERROR) << "This function only used for string tensor.";
+    return RET_ERROR;
+  }
+  shape_vector->clear();
+  auto tensor_data = reinterpret_cast<const char *>(tensor_info->data_c());
+  std::string shape_str;
+  std::string shape_size_str;
+  *offset = 0;
+  size_t cnt = 0;
+  for (; *offset < tensor_info->Size(); (*offset)++) {
+    if (tensor_data[*offset] == ',') {
+      (*offset)++;
+      break;
+    }
+    shape_size_str.push_back(tensor_data[*offset]);
+  }
+  if (*offset == 0) {
+    MS_LOG(ERROR) << "string tensor's dim size not found.";
+    return RET_ERROR;
+  }
+  size_t shape_size = std::stoi(shape_size_str);
+  for (; *offset < tensor_info->Size(); (*offset)++) {
+    if (tensor_data[*offset] == ',') {
+      cnt++;
+      shape_vector->push_back(std::stoi(shape_str));
+      shape_str.clear();
+    } else {
+      shape_str.push_back(tensor_data[*offset]);
+    }
+    if (cnt == shape_size) {
+      (*offset)++;
+      break;
+    }
+  }
+  if (shape_vector->empty()) {
+    MS_LOG(ERROR) << "string tensor's shape shouldn't be empty.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+int GetFormatByFmk(int32_t fmk_type) {
+  switch (fmk_type) {
+    case converter::FmkType_ONNX:
+    case lite::converter::FmkType_CAFFE:
+    case lite::converter::FmkType_MS:
+      return mindspore::NCHW;
+    case lite::converter::FmkType_TF:
+    case lite::converter::FmkType_TFLITE:
+      return mindspore::NHWC;
+    default:
+      return -1;
+  }
+}
+
+STATUS GetDataTypeAndShape(const ParameterPtr &param_node, TypeId *data_type, ShapeVector *shape_vector) {
+  MS_ASSERT(param_node != nullptr && data_type != nullptr && shape_vector != nullptr);
+  auto abstract_base = param_node->abstract();
+  if (abstract_base == nullptr) {
+    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
+    return RET_PARAM_INVALID;
+  }
+  if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
+    MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
+    return RET_INPUT_TENSOR_ERROR;
+  }
+  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
+  auto typePtr = abstract_tensor->element()->GetTypeTrack();
+  MS_ASSERT(typePtr != nullptr);
+  *data_type = typePtr->type_id();
+  if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
+    MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
+    return RET_PARAM_INVALID;
+  }
+  *shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
+  return RET_OK;
+}
+
+int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, converter::FmkType fmk_type,
+                         bool train_flag, DataInfo *data_info) {
+  MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
+  auto valueAbstract = value_node->abstract();
+  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
+  if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
+    MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
+    return RET_ERROR;
+  }
+  auto typePtr = abstract_tensor->element()->GetTypeTrack();
+  data_info->data_type_ = typePtr->type_id();
+  auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
+  std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
+  data_info->shape_ = dims;
+  if (train_flag && dims.empty()) {
+    data_info->shape_ = {1};
+  }
+  auto value = value_node->value();
+  MS_ASSERT(value != nullptr);
+  auto data = value->cast<tensor::TensorPtr>();
+  data_info->data_.resize(data->Size());
+  data_info->format_ = GetFormatByFmk(fmk_type);
+  if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
+    MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
+    return RET_ERROR;
+  }
+
+  // process weight tensor
+  if (data->Size() > 0 && memcpy_s(data_info->data_.data(), data->Size(), data->data_c(), data->Size()) != EOK) {
+    MS_LOG(ERROR) << "memcpy_s error.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+
+int FetchFromInt32OrInt64ImmValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive,
+                                  DataInfo *data_info) {
+  MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
+  // data of int64 is converted to int32 here.
+  data_info->data_type_ = kNumberTypeInt32;
+  data_info->shape_ = {1};
+  data_info->data_.resize(sizeof(int32_t));
+  auto value = value_node->value();
+  MS_ASSERT(value != nullptr);
+  int real_data = opt::CastToInt(value).front();
+  if (memcpy_s(data_info->data_.data(), sizeof(int32_t), &real_data, sizeof(int32_t)) != EOK) {
+    MS_LOG(ERROR) << "memcpy_s failed";
+    return RET_MEMORY_FAILED;
+  }
+  return RET_OK;
+}
+
+int FetchFromBoolImmValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
+  MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
+  data_info->data_type_ = kNumberTypeBool;
+  data_info->shape_ = {1};
+  data_info->data_.resize(sizeof(bool));
+  auto value = value_node->value();
+  MS_ASSERT(value != nullptr);
+  auto data = value->cast<BoolImmPtr>();
+  auto data_value = data->value();
+  if (memcpy_s(data_info->data_.data(), sizeof(bool), &data_value, sizeof(bool)) != EOK) {
+    MS_LOG(ERROR) << "memcpy_s failed";
+    return RET_MEMORY_FAILED;
+  }
+  return RET_OK;
+}
+
+int FetchFromNumberValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
+  MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
+  data_info->data_type_ = kNumberTypeInt32;
+  data_info->shape_ = {1};
+  data_info->data_.resize(sizeof(int));
+  auto data = value_node->value()->cast<NumberPtr>();
+  int number_type = data->number_type();
+  if (TypeToTypeMap.find(number_type) != TypeToTypeMap.end()) {
+    number_type = TypeToTypeMap.at(number_type);
+  }
+  if (memcpy_s(data_info->data_.data(), sizeof(int), &number_type, sizeof(int)) != EOK) {
+    MS_LOG(ERROR) << "memcpy_s failed";
+    return RET_MEMORY_FAILED;
+  }
+  return RET_OK;
+}
+
+int FetchFromSequenceValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
+  MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
+  auto value = value_node->value();
+  MS_ASSERT(value != nullptr);
+  std::vector<int32_t> shape;
+  auto value_seq = value->cast<ValueSequeuePtr>();
+  MS_ASSERT(value_seq != nullptr);
+  if (!value_seq->value().empty()) {
+    if (value_seq->value().front()->type()->number_type() == kNumberTypeInt32 ||
+        value_seq->value().front()->type()->number_type() == kNumberTypeInt) {
+      shape = GetValue<std::vector<int32_t>>(value);
+    } else if (value_seq->value().front()->type()->number_type() == kNumberTypeInt64) {
+      auto origin_value = GetValue<std::vector<int64_t>>(value);
+      std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(shape),
+                     [](int64_t val) { return static_cast<int32_t>(val); });
+    } else {
+      MS_LOG(ERROR) << "Value type in ValueSequence is not integer.";
+      return RET_ERROR;
+    }
+  }
+  data_info->data_type_ = kNumberTypeInt32;
+  data_info->shape_ = {static_cast<int32_t>(shape.size())};
+  data_info->data_.resize(shape.size() * sizeof(int));
+  if (!shape.empty() && memcpy_s(data_info->data_.data(), shape.size() * sizeof(int32_t), shape.data(),
+                                 shape.size() * sizeof(int32_t)) != EOK) {
+    MS_LOG(ERROR) << "memcpy_s data into schema_tensor failed.";
+    return RET_ERROR;
+  }
+  return RET_OK;
+}
+}  // namespace
+
+int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
+                               DataInfo *data_info) {
+  MS_ASSERT(cnode != nullptr && data_info != nullptr);
+  auto param_node = cnode->input(index)->cast<ParameterPtr>();
+  data_info->format_ = GetFormatByFmk(fmk_type);
+  if (data_info->format_ < 0) {
+    MS_LOG(ERROR) << "don't support current fmk: " << fmk_type;
+    return lite::RET_ERROR;
+  }
+  if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
+    MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
+    return RET_ERROR;
+  }
+
+  // attr weightFormat is only used by conv-like ops' second input
+  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) {
+    data_info->format_ = GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat));
+  }
+  ShapeVector shape_vector;
+  TypeId data_type;
+  auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "get data type and shape from param node failed.";
+    return RET_ERROR;
+  }
+  data_info->data_type_ = data_type;
+  auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
+  size_t offset = 0;
+  if (!shape_vector.empty() && data_type == kObjectTypeString) {
+    status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "get shape vector from string tensor failed.";
+      return RET_ERROR;
+    }
+  }
+  std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
+  data_info->shape_ = dims;
+  if (tensor_info != nullptr && tensor_info->Size() != 0) {
+    if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
+      data_info->data_.resize(tensor_info->Size() - offset);
+      if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(),
+                          static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
+        MS_LOG(ERROR) << "memcpy_s failed.";
+        return RET_ERROR;
+      }
+    }
+  }
+  QuantParamHolderPtr quant_param_holder =
+    prim->GetAttr("quant_params") == nullptr ? nullptr : prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
+  if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() && data_type == kNumberTypeInt8) {
+    data_info->enable_huffman_code_ = true;
+  }
+  data_info->node_type_ = NodeType_ValueNode;
+  return RET_OK;
+}
+
+int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
+                           DataInfo *data_info) {
+  MS_ASSERT(cnode != nullptr && data_info != nullptr);
+  auto value_node = cnode->input(index)->cast<ValueNodePtr>();
+  auto value = value_node->value();
+  int ret = RET_OK;
+  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+  MS_ASSERT(prim != nullptr);
+  if (value->isa<tensor::Tensor>()) {
+    ret = FetchFromTensorValue(value_node, prim, fmk_type, train_flag, data_info);
+    if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) {
+      data_info->format_ = GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat));
+    }
+  } else if (value->isa<Int32Imm>() || value->isa<Int64Imm>()) {
+    ret = FetchFromInt32OrInt64ImmValue(value_node, prim, data_info);
+  } else if (value->isa<BoolImm>()) {
+    ret = FetchFromBoolImmValue(value_node, prim, data_info);
+  } else if (value->isa<ValueSequeue>()) {
+    ret = FetchFromSequenceValue(value_node, prim, data_info);
+  } else if (value->isa<Number>()) {
+    ret = FetchFromNumberValue(value_node, prim, data_info);
+  } else if (value->isa<FuncGraph>()) {
+    MS_LOG(INFO) << "op name:" << value_node->fullname_with_scope() << " input is func_graph";
+    return RET_NO_CHANGE;
+  } else if (value->isa<Monad>()) {
+    MS_LOG(INFO) << "op name:" << value_node->fullname_with_scope() << " input is Monad";
+    return RET_NO_CHANGE;
+  } else {
+    MS_LOG(ERROR) << "Unsupported value type, need to add support.";
+    return RET_ERROR;
+  }
+  data_info->node_type_ = NodeType_ValueNode;
+  return ret;
+}
+
+int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
+                       DataInfo *data_info) {
+  MS_ASSERT(cnode != nullptr && data_info != nullptr);
+  auto abstract = opt::GetCNodeInputAbstract(cnode, index);
+  if (abstract == nullptr) {
+    MS_LOG(ERROR) << "Abstract cnode is nullptr.";
+    return RET_ERROR;
+  }
+  if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
+    MS_LOG(ERROR) << "Abstract should be abstract tensor.";
+    return RET_ERROR;
+  }
+  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
+  auto type_ptr = abstract_tensor->element()->GetTypeTrack();
+  MS_ASSERT(type_ptr != nullptr);
+  if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
+    MS_LOG(ERROR) << "Shape of Abstract should be ShapePtr.";
+    return RET_ERROR;
+  }
+  auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
+  std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
+  data_info->format_ = mindspore::NHWC;
+  data_info->data_type_ = type_ptr->type_id();
+  data_info->shape_ = dims;
+  data_info->node_type_ = NodeType_CNode;
+  if (type_ptr->type_id() == kObjectTypeTensorType) {
+    auto tensor_info = abstract_tensor->GetValueTrack();
+    if (tensor_info == nullptr || !utils::isa<tensor::TensorPtr>(tensor_info)) {
+      MS_LOG(ERROR) << "tensor info is invalid.";
+      return RET_ERROR;
+    }
+    auto tensor_value = tensor_info->cast<tensor::TensorPtr>();
+    if (tensor_value->Size() >= kTensorListMinSize) {
+      data_info->data_.resize(tensor_value->Size());
+      if (memcpy_s(data_info->data_.data(), tensor_value->Size(), tensor_value->data_c(), tensor_value->Size()) !=
+          EOK) {
+        MS_LOG(ERROR) << "memcpy data failed.";
+        return RET_ERROR;
+      }
+    }
+  }
+  return RET_OK;
+}
+
+void RemoveIfDepend(const CNodePtr &cnode) {
+  bool has_depend = false;
+  std::vector<AnfNodePtr> inputs;
+  inputs.clear();
+
+  inputs.emplace_back(cnode->input(0));
+  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
+    AnfNodePtr inputNode = cnode->input(i);
+    if (!inputNode->isa<CNode>()) {
+      inputs.emplace_back(cnode->input(i));
+      continue;
+    }
+    auto depend_node = utils::cast<CNodePtr>(inputNode);
+    auto value_node = depend_node->input(0)->cast<ValueNodePtr>();
+    if (value_node == nullptr) {
+      MS_LOG(ERROR) << "value node is invalid.";
+      return;
+    }
+    if (value_node->value() != nullptr && opt::CheckPrimitiveType(depend_node, prim::kPrimDepend)) {
+      has_depend = true;
+      bool mask_out = (depend_node->inputs().size() == 3);
+      for (size_t j = 1; j < depend_node->inputs().size(); ++j) {
+        AnfNodePtr depend_input_node = depend_node->input(j);
+        if (depend_input_node->isa<CNode>()) {
+          inputs.emplace_back(depend_input_node);
+          if (mask_out) {
+            break;
+          }
+        }
+      }
+    } else {
+      inputs.emplace_back(cnode->input(i));
+    }
+  }
+  if (has_depend) {
+    cnode->set_inputs(inputs);
+  }
+}
+
+void RemoveIfMakeTuple(const CNodePtr &cnode) {
+  bool has_make_tuple = false;
+  std::vector<AnfNodePtr> inputs;
+  inputs.clear();
+
+  inputs.emplace_back(cnode->input(0));
+  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
+    AnfNodePtr input_node = cnode->input(i);
+    if (!input_node->isa<CNode>()) {
+      inputs.emplace_back(cnode->input(i));
+      continue;
+    }
+    auto make_tuple_node = utils::cast<CNodePtr>(input_node);
+    auto value_node = make_tuple_node->input(0)->cast<ValueNodePtr>();
+    if (value_node == nullptr) {
+      MS_LOG(ERROR) << "value node is invalid.";
+      return;
+    }
+    if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, prim::kPrimMakeTuple) ||
+                                           opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) {
+      has_make_tuple = true;
+      for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) {
+        inputs.emplace_back(make_tuple_node->input(j));
+      }
+    } else {
+      inputs.emplace_back(cnode->input(i));
+    }
+  }
+  if (has_make_tuple) {
+    cnode->set_inputs(inputs);
+  }
+}
+}  // namespace lite
+}  // namespace mindspore
diff --git a/mindspore/lite/tools/anf_exporter/fetch_content.h b/mindspore/lite/tools/anf_exporter/fetch_content.h
new file mode 100644
index 0000000000..ebddb18077
--- /dev/null
+++ b/mindspore/lite/tools/anf_exporter/fetch_content.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_TOOLS_ANF_EXPORTER_FETCH_CONTENT_H_
+#define MINDSPORE_LITE_TOOLS_ANF_EXPORTER_FETCH_CONTENT_H_
+
+#include <string>
+#include <vector>
+#include "ir/primitive.h"
+#include "ir/func_graph.h"
+#include "src/common/utils.h"
+#include "tools/converter/converter_flags.h"
+
+namespace mindspore {
+namespace lite {
+struct DataInfo {
+  bool enable_huffman_code_;
+  int format_;
+  int data_type_;
+  int node_type_;
+  std::vector<int32_t> shape_;
+  std::vector<uint8_t> data_;
+  DataInfo() : enable_huffman_code_(false), format_(0), data_type_(0), node_type_(0) {}
+};
+int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
+                               DataInfo *data_info);
+int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
+                           DataInfo *data_info);
+int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
+                       DataInfo *data_info);
+void RemoveIfDepend(const CNodePtr &cnode);
+
+void RemoveIfMakeTuple(const CNodePtr &cnode);
+}  // namespace lite
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_TOOLS_ANF_EXPORTER_FETCH_CONTENT_H_
diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
index 0f23eb3003..21499888da 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -168,9 +168,7 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converte
   auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
   const_fold_pm->AddPass(std::make_shared());
   if (!config->trainModel) {
-    auto inne_context_ptr = std::make_shared<lite::InnerContext>();
-    inne_context_ptr->Init();
-    const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(inne_context_ptr));
+    const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
   }
   auto update_conv2d_param_pass = std::make_shared<opt::UpdateConv2DParamPass>();
   update_conv2d_param_pass->SetFmkType(config->fmk);
diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc
index 965027fa79..d612ffddb8 100644
--- a/mindspore/lite/tools/converter/graphdef_transform.cc
+++ b/mindspore/lite/tools/converter/graphdef_transform.cc
@@ -21,15 +21,10 @@
 #include "src/common/log_adapter.h"
 #include "tools/converter/converter_flags.h"
 #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
-#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
 #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
 #include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h"
-#include "tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h"
 #include "tools/converter/legacy_optimizer/graph/infershape_pass.h"
 #include "tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h"
-#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
-#include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h"
-#include "tools/converter/legacy_optimizer/graph/global_format_transform_pass.h"
 #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
 #include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
 #include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
@@ -129,7 +124,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
     auto old_nodes = GetGraphNodes();
     Optimizer format_trans_optimizer;
     if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) {
-
format_trans_optimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass()); format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes)); } diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt index 6ac4737aba..2d64466616 100755 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt @@ -1,12 +1,8 @@ file(GLOB FUSION_SRC ${CMAKE_CURRENT_SOURCE_DIR}/fusion_pattern.cc ${CMAKE_CURRENT_SOURCE_DIR}/fusion_pass.cc - ${CMAKE_CURRENT_SOURCE_DIR}/matmul_biasadd_fusion_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/mul_add_fusion_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast_fusion_pass.cc - ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_fold_fusion_pass.cc - ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_fusion_pass.cc - ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_transpose_fusion_pass.cc ) set_property(SOURCE ${FUSION_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE) add_library(fusion_mid OBJECT ${FUSION_SRC}) diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.cc deleted file mode 100644 index c2026ff8cb..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.cc +++ /dev/null @@ -1,125 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include -#include -#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h" -#include "src/common/log_adapter.h" -#include "tools/common/graph_util.h" -#include "include/errorcode.h" -#include "schema/inner/model_generated.h" - -namespace mindspore { -namespace { -std::vector nchw2nhwc_perm = {0, 2, 3, 1}; -std::vector nhwc2nchw_perm = {0, 3, 1, 2}; -} // namespace -namespace lite { -#define kFormatTransMatchPathLen2 2 -#define kFormatTransMatchPathLen3 3 - -STATUS FormatTransFusionPass::DefinePattern() { - // nchw2nhwc + nhwc2nchw || nhwc2nchw + nchw2nhwc - { - auto transpose1 = std::make_shared(); - transpose1->id = kFormatTransTranspose1; - transpose1->types = {PrimitiveType_Transpose}; - auto transpose2 = std::make_shared(); - transpose2->id = kFormatTransTranspose2; - transpose2->types = {PrimitiveType_Transpose}; - - transpose2->left = transpose1; - auto pattern = std::make_unique(kNc2NhAndNh2NcFusionPattern); - if (pattern == nullptr) { - MS_LOG(ERROR) << "new " << kNc2NhAndNh2NcFusionPattern << "failed"; - return RET_ERROR; - } - pattern->AddPatternOp(transpose1); - pattern->AddPatternOp(transpose2); - pattern->Finish(); - this->patterns.emplace_back(pattern.release()); - } - // nhwc2nchw + QuantDtypeCast + nchw2nhwc || nchw2nhwc + QuantDtypeCast + nhwc2nchw - { - auto transpose1 = std::make_shared(); - transpose1->id = kFormatTransTranspose1; - transpose1->types = {PrimitiveType_Transpose}; - auto passOp = std::make_shared(); - passOp->id = kFormatTransPassOp; - passOp->types = {PrimitiveType_QuantDTypeCast}; - auto transpose2 = std::make_shared(); - transpose2->id = kFormatTransTranspose2; - transpose2->types = {PrimitiveType_Transpose}; - - passOp->left = transpose2; - transpose1->left = passOp; - auto pattern = std::make_unique(kNh2NcAndNc2NhPassFusionPattern); - if (pattern == nullptr) { - MS_LOG(ERROR) << "new " << kNh2NcAndNc2NhPassFusionPattern << " failed"; - return RET_ERROR; - } - pattern->AddPatternOp(transpose1); - pattern->AddPatternOp(passOp); - pattern->AddPatternOp(transpose2); - pattern->Finish(); - this->patterns.emplace_back(pattern.release()); - } - return RET_OK; -} - -STATUS FormatTransFusionPass::Run(schema::MetaGraphT *graph) { return FusionPass::Run(graph); } - -STATUS FormatTransFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, - std::unordered_map> &matchedPath) { - MS_ASSERT(graph != nullptr); - if (matchedPath.size() != kFormatTransMatchPathLen2 && matchedPath.size() != kFormatTransMatchPathLen3) { - MS_LOG(ERROR) << "schema::Format-Transform-Fusion should have " << kFormatTransMatchPathLen2 << " or " - << kFormatTransMatchPathLen3 << " NodeIndex in matchedPair"; - return RET_PARAM_INVALID; - } - - std::shared_ptr srcPath = matchedPath[kFormatTransTranspose1]; - std::shared_ptr dstPath = matchedPath[kFormatTransTranspose2]; - if (srcPath == nullptr || dstPath == nullptr) { - MS_LOG(ERROR) << "srcPath or dstPath is failed to get"; - return RET_ERROR; - } - auto &srcNode = graph->nodes.at(srcPath->nodeIdx); - auto &dstNode = graph->nodes.at(dstPath->nodeIdx); - MS_ASSERT(srcNode != nullptr); - MS_ASSERT(dstNode != nullptr); - auto src_perm = GetTransposePerm(graph, srcNode); - auto dst_perm = GetTransposePerm(graph, dstNode); - bool isNc2NhAndNh2Nc = src_perm == nchw2nhwc_perm && dst_perm == nhwc2nchw_perm; - bool isNh2NcAndNc2Nh = src_perm == nhwc2nchw_perm && dst_perm == nchw2nhwc_perm; - if (isNc2NhAndNh2Nc || isNh2NcAndNc2Nh) { - auto status = IsolateOneWayNode(graph, 
srcPath->nodeIdx); - if (status != RET_OK) { - MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << srcNode->name << ", error: " << status; - return status; - } - status = IsolateOneWayNode(graph, dstPath->nodeIdx); - if (status != RET_OK) { - MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name << ", error: " << status; - return status; - } - } - return RET_OK; -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h deleted file mode 100644 index 4ac7b779ba..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_PREDICT_FORMAT_TRANS_FUSION_PASS_H -#define MINDSPORE_PREDICT_FORMAT_TRANS_FUSION_PASS_H - -#include -#include -#include -#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" - -namespace mindspore { -namespace lite { -constexpr const char *kFormatTransTranspose1 = "FormatTransTransposeOp1"; -constexpr const char *kFormatTransTranspose2 = "FormatTransTransposeOp2"; -constexpr const char *kFormatTransPassOp = "FormatTransPassOp"; -constexpr const char *kNc2NhAndNh2NcFusionPattern = "Nc2NhAndNh2NcFusionPattern"; -constexpr const char *kNh2NcAndNc2NhPassFusionPattern = "Nh2NcAndNc2NhPassFusionPattern"; - -class FormatTransFusionPass : public FusionPass { - public: - FormatTransFusionPass() = default; - - ~FormatTransFusionPass() override = default; - - STATUS DefinePattern() override; - - STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, - std::unordered_map> &matchedPath) override; - - STATUS Run(schema::MetaGraphT *graph) override; -}; -} // namespace lite -} // namespace mindspore - -#endif // MINDSPORE_PREDICT_FORMAT_TRANS_FUSION_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt index d50b798d1a..a54ade170f 100755 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt @@ -1,17 +1,12 @@ file(GLOB GRAPH_PASS - ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_pass.cc - ${CMAKE_CURRENT_SOURCE_DIR}/trans_format_insert_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc - ${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/dropout_node_remove_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_convert_scale_pass.cc - ${CMAKE_CURRENT_SOURCE_DIR}/trans_format_remove_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/infershape_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/tensor_quant_pass.cc 
${CMAKE_CURRENT_SOURCE_DIR}/infer_quant_param_pass.cc - ${CMAKE_CURRENT_SOURCE_DIR}/global_format_transform_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/switch_pass.cc diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc deleted file mode 100644 index 4924086495..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc +++ /dev/null @@ -1,461 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" -#include "tools/common/node_util.h" -#include "src/common/log_adapter.h" -#include "src/common/common.h" -#include "src/common/utils.h" - -namespace mindspore { -namespace lite { -#define kMinInputNum 1 -#define kOutputNum 1 - -STATUS FormatTransPass::Run(schema::MetaGraphT *graph) { - MS_ASSERT(graph != nullptr); - auto status = DoModelInputFormatTrans(graph); - if (status != RET_OK) { - MS_LOG(ERROR) << "DoModelInputFormatTrans failed : " << status; - return status; - } - status = DoNodeInoutFormatTrans(graph); - if (status != RET_OK) { - MS_LOG(ERROR) << "DoNodeInoutFormatTrans failed : " << status; - return status; - } - return RET_OK; -} - -STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *before_node_type, - FormatTransNodeType *after_node_type) { - if (fmk_type_ == converter::FmkType_TFLITE) { // inference by nhwc - if (!IsContain(GetNchwOpList(), GetCNodeTType(node))) { - return RET_NO_CHANGE; - } - *before_node_type = kNHWC2NCHW; - *after_node_type = kNCHW2NHWC; - return RET_OK; - } else if (fmk_type_ == converter::FmkType_CAFFE || fmk_type_ == converter::FmkType_MS || - fmk_type_ == converter::FmkType_ONNX) { - if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) { - return RET_NO_CHANGE; - } - *before_node_type = kNCHW2NHWC; - *after_node_type = kNHWC2NCHW; - return RET_OK; - } else if (fmk_type_ == converter::FmkType_TF) { - if (IsContain(GetNhwcOpList(), GetCNodeTType(node)) && GetFormat(node) == schema::Format_NCHW) { - *before_node_type = kNCHW2NHWC; - *after_node_type = kNHWC2NCHW; - return RET_OK; - } - if (IsContain(GetNchwOpList(), GetCNodeTType(node))) { - *before_node_type = kNHWC2NCHW; - *after_node_type = kNCHW2NHWC; - return RET_OK; - } - return RET_NO_CHANGE; - } - MS_LOG(ERROR) << "Unsupported fmk: " << fmk_type_; - return RET_ERROR; -} - -STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) { - if (fmk_type_ == converter::FmkType_TF || fmk_type_ == converter::FmkType_TFLITE) { - return RET_OK; - } - MS_ASSERT(graph != nullptr); - // insert trans node in model input tensor - if (graph->nodes.empty()) { - return RET_OK; - } - // onnx input format may be nhwc - if 
(fmk_type_ == converter::FmkType_ONNX && graph->inputIndex.size() == 1) { - auto &input_tensor = graph->allTensors.at(graph->inputIndex[0]); - auto &input_dims = input_tensor->dims; - if (input_dims.size() == 4 && input_dims[3] != -1 && input_dims[1] == -1) { - return RET_OK; - } - } - auto graph_input_idxes = graph->inputIndex; - for (size_t i = 0; i < graph_input_idxes.size(); i++) { - bool transed = false; - auto input_idx = graph_input_idxes.at(i); - auto &tensor = graph->allTensors.at(input_idx); - if (tensor->dims.size() != kNCHWDimNumber) { - continue; - } - - for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { - for (size_t input_index_idx = 0; input_index_idx < (*iter)->inputIndex.size(); input_index_idx++) { - if ((*iter)->inputIndex.at(input_index_idx) == input_idx) { - STATUS status = RET_OK; - iter = InsertFormatTransNode(graph, iter, kBefore, input_index_idx, kNHWC2NCHW, &status); - if (status != RET_OK) { - MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << (*iter)->name << " failed"; - return status; - } - // set first tensor format to nhwc - auto &trans_node = *(iter - 1); - MS_ASSERT(trans_node != nullptr); - MS_ASSERT(trans_node->inputIndex.size() == 1); - auto &graph_in_tensor = graph->allTensors.at(trans_node->inputIndex.front()); - graph_in_tensor->format = schema::Format::Format_NHWC; - // assume parser not reformat shape - auto old_dims = graph_in_tensor->dims; - if (!transed) { - graph_in_tensor->dims = {old_dims[NCHW_N], old_dims[NCHW_H], old_dims[NCHW_W], old_dims[NCHW_C]}; - transed = true; - } - } - } - } - } - return RET_OK; -} - -// inference needed inputFormat: -// conv deconv depth dedepth -// fp32 NCHW NCHW NCHW NCHW -// uint8 NCHW ? NCHW ? -STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { - MS_ASSERT(graph != nullptr); - // insert before and after the op cal by nchw/nc4hw4 - for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { - FormatTransNodeType before_node_type = kNCHW2NHWC; - FormatTransNodeType after_node_type = kNHWC2NCHW; - STATUS status = RET_OK; - status = GetInsertFormatTrans(**iter, &before_node_type, &after_node_type); - if (status == RET_NO_CHANGE) { - continue; - } - if (status != RET_OK) { - return status; - } - auto &node = *iter; - auto nodeName = node->name; - if (node->inputIndex.size() < kMinInputNum) { - MS_LOG(ERROR) << "Op should have " << kMinInputNum << " input tensor at least"; - return RET_ERROR; - } - if (node->outputIndex.size() < kOutputNum) { - MS_LOG(ERROR) << "Op should have " << kOutputNum << " output tensor"; - return RET_ERROR; - } - void *attr = node->primitive->value.value; - if (node->primitive->value.type == schema::PrimitiveType_SpaceToDepth) { - reinterpret_cast(attr)->format = schema::Format_NHWC; - } - if (node->primitive->value.type == schema::PrimitiveType_DepthToSpace) { - reinterpret_cast(attr)->format = schema::Format_NHWC; - } - auto spec_insert_indexes = GetExtNhwcIndexes(); - auto op_type = GetCNodeTType(**iter); - if (spec_insert_indexes.find(op_type) != spec_insert_indexes.end()) { - for (auto insert_index : spec_insert_indexes[op_type]) { - iter = InsertFormatTransNode(graph, iter, kBefore, insert_index, before_node_type, &status); - if (status != RET_OK) { - MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed"; - return RET_ERROR; - } - } - } else if (IsContain(GetNhwcAllInputOpList(), op_type)) { - auto input_size = node->inputIndex.size(); - if (GetCNodeTType(**iter) == schema::PrimitiveType_ResizeGrad) { 
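As context for the pass deleted above: `DoModelInputFormatTrans` rewrote every 4-D NCHW graph input to NHWC and fronted each consumer with an NHWC2NCHW transpose. A minimal standalone sketch of the shape rewrite, using hypothetical index constants in place of the converter's `NCHW_N`/`NCHW_H`/`NCHW_W`/`NCHW_C`:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for the converter's NCHW index constants.
constexpr size_t kN = 0, kC = 1, kH = 2, kW = 3;

// Reorder a 4-D NCHW shape into NHWC, mirroring what the deleted pass did to
// the graph-input tensor feeding the freshly inserted NHWC2NCHW transpose.
std::vector<int32_t> Nchw2NhwcShape(const std::vector<int32_t> &nchw) {
  if (nchw.size() != 4) {
    return nchw;  // the pass skipped inputs that were not 4-D
  }
  return {nchw[kN], nchw[kH], nchw[kW], nchw[kC]};
}
```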
- if ((**iter).primitive->value.AsResizeGrad()->method == schema::ResizeMethod_NEAREST) { - input_size = 1; - } - } - for (size_t i = 0; i < input_size; i++) { - iter = InsertFormatTransNode(graph, iter, kBefore, i, before_node_type, &status); - if (status != RET_OK) { - MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed"; - return RET_ERROR; - } - } - } else { - iter = InsertFormatTransNode(graph, iter, kBefore, 0, before_node_type, &status); - } - iter = InsertFormatTransNode(graph, iter, kAfter, 0, after_node_type, &status); - if (status != RET_OK) { - MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; - return RET_ERROR; - } - } - return RET_OK; -} - -NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter exist_node_iter, InsertPlace place, - size_t inout_idx, FormatTransNodeType node_type, STATUS *error_code) { - MS_ASSERT((*exist_node_iter) != nullptr); - MS_ASSERT(graph != nullptr); - auto exist_node_name = (*exist_node_iter)->name; - std::string tile_name; - if (place == kBefore) { - tile_name = exist_node_name + "_pre"; - } else { - tile_name = exist_node_name + "_post"; - } - auto trans_node = std::make_unique(); - trans_node->primitive = std::make_unique(); - trans_node->primitive->value.type = schema::PrimitiveType_Transpose; - auto perm_tensor = std::make_unique(); - perm_tensor->dataType = kNumberTypeInt32; - perm_tensor->dims = {4}; - std::vector perm; - if (node_type == kNCHW2NHWC) { - trans_node->name = "nchw2nhwc_" + tile_name + std::to_string(id_++); - perm = {0, 2, 3, 1}; - } else { - trans_node->name = "nhwc2nchw_" + tile_name + std::to_string(id_++); - perm = {0, 3, 1, 2}; - } - size_t bytes = perm.size() * sizeof(int); - perm_tensor->data.resize(bytes); - if (memcpy_s(perm_tensor->data.data(), bytes, perm.data(), bytes) != EOK) { - MS_LOG(ERROR) << "memcpy data failed."; - } - perm_tensor->name = trans_node->name + "_perm"; - - OpDefCopyer transpose_op_copyer = [](CNodeT *in_op_def) -> std::unique_ptr { - auto new_op_def = std::make_unique(); - if (new_op_def == nullptr) { - MS_LOG(ERROR) << "new CNodeT failed"; - return nullptr; - } - new_op_def->name = in_op_def->name; - new_op_def->quantType = in_op_def->quantType; - new_op_def->primitive = std::make_unique(); - if (new_op_def->primitive == nullptr) { - MS_LOG(ERROR) << "new PrimitiveT failed"; - return nullptr; - } - new_op_def->primitive->value.type = schema::PrimitiveType_Transpose; - return new_op_def; - }; - int insert_num = 0; - auto iter = InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num, - transpose_op_copyer); - size_t index = graph->allTensors.size(); - graph->allTensors.push_back(std::move(perm_tensor)); - for (int i = insert_num; i > 0; --i) { - (*(iter - i))->inputIndex.push_back(index); - } - return iter; -} - -int FormatTransPass::GetFormat(const schema::CNodeT &node) { - switch (node.primitive->value.type) { - case schema::PrimitiveType_Conv2DFusion: - return node.primitive->value.AsConv2DFusion()->format; - case schema::PrimitiveType_Conv2dTransposeFusion: - return node.primitive->value.AsConv2dTransposeFusion()->format; - case schema::PrimitiveType_AvgPoolFusion: - return node.primitive->value.AsAvgPoolFusion()->format; - case schema::PrimitiveType_MaxPoolFusion: - return node.primitive->value.AsMaxPoolFusion()->format; - default: - return schema::Format_NHWC; - } -} - -STATUS FormatTransPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr &node) { - 
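`InsertFormatTransNode` above materialized each inserted transpose as a `Transpose` op plus a 4-element int32 perm tensor. A self-contained illustration of the two perms it wrote and the shape rule they imply (plain `std::vector` arithmetic, not MindSpore API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// The two perm tensors the deleted InsertFormatTransNode wrote out.
const std::vector<int32_t> kNchw2NhwcPerm = {0, 2, 3, 1};
const std::vector<int32_t> kNhwc2NchwPerm = {0, 3, 1, 2};

// Shape rule of a transpose: out[i] = in[perm[i]].
std::vector<int32_t> ApplyPerm(const std::vector<int32_t> &in,
                               const std::vector<int32_t> &perm) {
  std::vector<int32_t> out(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    out[i] = in[perm[i]];
  }
  return out;
}

int main() {
  const std::vector<int32_t> nchw = {1, 3, 224, 224};
  const auto nhwc = ApplyPerm(nchw, kNchw2NhwcPerm);  // {1, 224, 224, 3}
  assert(ApplyPerm(nhwc, kNhwc2NchwPerm) == nchw);    // the perms are inverses
  return 0;
}
```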
MS_ASSERT(node->primitive != nullptr); - auto type = node->primitive->value.type; - auto input1_ndim = graph->allTensors.at(node->inputIndex[0])->dims.size(); - if (input1_ndim != 4) { - if (node->inputIndex.size() > 1) { - auto input2_ndim = graph->allTensors.at(node->inputIndex[1])->dims.size(); - if (input2_ndim != 4 && input2_ndim != 0) { - MS_LOG(ERROR) << "change op axis only support 4 dims"; - return RET_NOT_SUPPORT; - } - } else { - MS_LOG(DEBUG) << "change op axis only support 4 dims"; - return RET_NOT_SUPPORT; - } - } - if (type == schema::PrimitiveType_Concat) { - MS_ASSERT(node->primitive->value.AsConcat() != nullptr); - auto origin_axis = node->primitive->value.AsConcat()->axis; - auto axis_map = GetNc2NhAxisMap(); - if (node->primitive->value.AsConcat() == nullptr) { - MS_LOG(ERROR) << "node->primitive->value.AsConcat() is nullptr"; - return RET_NULL_PTR; - } - node->primitive->value.AsConcat()->axis = axis_map[origin_axis < 0 ? origin_axis + 4 : origin_axis]; - } - if (type == schema::PrimitiveType_Split) { - MS_ASSERT(node->primitive->value.AsSplit() != nullptr); - auto origin_axis = node->primitive->value.AsSplit()->axis; - auto axis_map = GetNc2NhAxisMap(); - if (node->primitive->value.AsSplit() == nullptr) { - MS_LOG(ERROR) << "node->primitive->value.AsSplit() is nullptr"; - return RET_NULL_PTR; - } - node->primitive->value.AsSplit()->axis = axis_map[origin_axis]; - } - if (type == schema::PrimitiveType_Crop) { - MS_ASSERT(node->primitive->value.AsCrop() != nullptr); - auto origin_axis = node->primitive->value.AsCrop()->axis; - auto offsets = node->primitive->value.AsCrop()->offsets; - auto axis_map = GetNc2NhAxisMap(); - if (node->primitive->value.AsCrop() == nullptr) { - MS_LOG(ERROR) << "node->primitive->value.AsCrop() is nullptr"; - return RET_NULL_PTR; - } - // nchw->nhwc,offsets need pad 0; - if (axis_map[origin_axis] == 0) { - offsets = {offsets[0], offsets[2], offsets[3], offsets[1]}; - } else if (axis_map[origin_axis] == 1 || axis_map[origin_axis] == 2) { - // orgin_axis = 2 or orgin_axis = 3 - offsets.push_back(0); - } else if (axis_map[origin_axis] == -1) { - // origin_axis = 1 - offsets = {offsets[1], offsets[2], offsets[0]}; - } else { - // axis error - MS_LOG(ERROR) << "Crop error"; - return RET_ERROR; - } - node->primitive->value.AsCrop()->axis = axis_map[origin_axis]; - node->primitive->value.AsCrop()->offsets = offsets; - } - if (type == schema::PrimitiveType_SliceFusion || type == schema::PrimitiveType_StridedSlice) { - return ChangeOpSliceAndStridedSlice(graph, node); - } - return RET_OK; -} - -void FormatTransPass::TransformAttrByAxes(int *origin_attr, int *axes, int element_size) { - if (origin_attr == nullptr || axes == nullptr || element_size == 0) { - return; - } - auto axis_map = GetNc2NhAxisMap(); - std::vector cur_attr; - for (int dim = 0; dim < 4; ++dim) { - for (int index = 0; index < element_size; ++index) { - int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + 4 : axes[index]]; - if (nhwc_dim == dim || (nhwc_dim + 4) == dim) { - cur_attr.push_back(origin_attr[index]); - } - } - } - for (int index = 0; index < element_size; ++index) { - origin_attr[index] = cur_attr[index]; - } -} - -void FormatTransPass::TransformOpAxisAttr(int *origin_axis, int element_size) { - if (origin_axis == nullptr || element_size == 0) { - return; - } - auto axis_map = GetNc2NhAxisMap(); - std::vector new_axis; - for (int i = 0; i < element_size; ++i) { - int axis = axis_map[origin_axis[i]]; - axis = axis < 0 ? 
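The `ChangeOpAxis` logic above rewrites scalar axis attributes (Concat, Split, Crop) through an NCHW-to-NHWC axis map. Below is a sketch of that remap with negative-axis normalization, assuming the map `{0->0, 1->3, 2->1, 3->2}` that appears later in `format_utils.cc`; the legacy `GetNc2NhAxisMap` used by the deleted pass may differ (the Crop branch above treats a mapped value of -1 as the channel axis):

```cpp
#include <array>

// NCHW axis -> NHWC axis: N stays at 0, C moves to 3, H to 1, W to 2.
constexpr std::array<int, 4> kNc2NhAxis = {0, 3, 1, 2};

// Remap a single axis attribute (Concat/Split style) for a 4-D tensor,
// normalizing negative axes the way the deleted ChangeOpAxis did.
int RemapAxisNc2Nh(int axis) {
  if (axis < 0) {
    axis += 4;
  }
  return (axis >= 0 && axis < 4) ? kNc2NhAxis[axis] : axis;
}
```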
axis + 4 : axis; - new_axis.push_back(axis); - } - std::sort(new_axis.begin(), new_axis.end()); - for (int i = 0; i < element_size; ++i) { - origin_axis[i] = new_axis[i]; - } -} - -STATUS FormatTransPass::ChangeOpSlice(schema::MetaGraphT *graph, const std::unique_ptr &node) { - auto attr = node->primitive->value.AsSliceFusion(); - if (attr == nullptr) { - MS_LOG(ERROR) << "node->primitive->value.AsSliceFusion() is nullptr."; - return RET_NULL_PTR; - } - // transform attr - if (node->inputIndex.size() < 2) { - MS_LOG(ERROR) << "slice input is error"; - return RET_ERROR; - } - for (size_t index = 1; index < node->inputIndex.size(); ++index) { - if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) { - return RET_NOT_SUPPORT; - } - } - int element_num = graph->allTensors[node->inputIndex[1]]->dims[0]; - std::vector axes; - auto axes_attr = attr->axes; - if (axes_attr.empty()) { - for (int index = 0; index < element_num; ++index) { - axes.push_back(index); - } - } else { - std::transform(axes_attr.begin(), axes_attr.end(), std::back_inserter(axes), - [](int64_t val) { return static_cast(val); }); - } - for (size_t index = 1; index < node->inputIndex.size(); ++index) { - TransformAttrByAxes(reinterpret_cast(graph->allTensors[node->inputIndex[index]]->data.data()), - reinterpret_cast(axes.data()), element_num); - } - TransformOpAxisAttr(axes.data(), element_num); - attr->axes.clear(); - for (int i = 0; i < element_num; ++i) { - attr->axes.push_back(static_cast(axes[i])); - } - return RET_OK; -} - -STATUS FormatTransPass::ChangeOpStridedSlice(schema::MetaGraphT *graph, const std::unique_ptr &node) { - // onnx input size is equal to 5 always. - if (node->inputIndex.size() != 5) { - return RET_NOT_SUPPORT; - } - if (node->inputIndex.size() == 5) { - for (int index = 1; index < 5; ++index) { - if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) { - return RET_NOT_SUPPORT; - } - } - int element_num = graph->allTensors[node->inputIndex[1]]->dims[0]; - auto axes = graph->allTensors[node->inputIndex[3]]->data; - for (int index = 1; index < 5; ++index) { - if (index == 3) { - continue; - } - TransformAttrByAxes(reinterpret_cast(graph->allTensors[node->inputIndex[index]]->data.data()), - reinterpret_cast(axes.data()), element_num); - } - TransformOpAxisAttr(reinterpret_cast(graph->allTensors[node->inputIndex[3]]->data.data()), element_num); - } - return RET_OK; -} - -STATUS FormatTransPass::ChangeOpSliceAndStridedSlice(schema::MetaGraphT *graph, - const std::unique_ptr &node) { - auto type = node->primitive->value.type; - if (type == schema::PrimitiveType_StridedSlice) { - return ChangeOpStridedSlice(graph, node); - } - if (type == schema::PrimitiveType_SliceFusion) { - return ChangeOpSlice(graph, node); - } - return RET_ERROR; -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h deleted file mode 100644 index afeef24aa8..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
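Together, the deleted `TransformAttrByAxes` and `TransformOpAxisAttr` reorder a slice's per-axis attributes so they follow ascending NHWC axis order. A behavior-preserving standalone sketch, under the assumption that every axis appears at most once:

```cpp
#include <algorithm>
#include <numeric>
#include <utility>
#include <vector>

constexpr int kNc2Nh[4] = {0, 3, 1, 2};  // NCHW axis -> NHWC axis

// Reorder one per-axis attribute (e.g. a slice's begin values) so its entries
// follow ascending NHWC axis order -- the combined effect of the deleted
// TransformAttrByAxes and TransformOpAxisAttr.
void ReorderAttrToNhwc(std::vector<int> *axes, std::vector<int> *values) {
  const size_t n = axes->size();
  auto nhwc_axis = [&](size_t i) {
    const int a = (*axes)[i];
    return kNc2Nh[a < 0 ? a + 4 : a];
  };
  std::vector<size_t> order(n);
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](size_t l, size_t r) { return nhwc_axis(l) < nhwc_axis(r); });
  std::vector<int> new_axes(n), new_values(n);
  for (size_t i = 0; i < n; ++i) {
    new_axes[i] = nhwc_axis(order[i]);
    new_values[i] = (*values)[order[i]];
  }
  *axes = std::move(new_axes);
  *values = std::move(new_values);
}
```

For a full-rank slice with NCHW axes `{0,1,2,3}`, the begins `{bN,bC,bH,bW}` come back as `{bN,bH,bW,bC}` with axes `{0,1,2,3}` again.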
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H -#define MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H - -#include -#include "tools/converter/optimizer.h" -#include "tools/common/graph_util.h" -#include "tools/converter/converter_flags.h" - -namespace mindspore { -namespace lite { -enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW, kNONE }; - -class FormatTransPass : public GraphPass { - public: - FormatTransPass() : id_(0) {} - - ~FormatTransPass() override = default; - - STATUS Run(schema::MetaGraphT *graph) override; - - void set_quant_type(QuantType quant_type) { this->quant_type_ = quant_type; } - - void set_fmk_type(converter::FmkType fmk_type) { this->fmk_type_ = fmk_type; } - - protected: - NodeIter InsertFormatTransNode(schema::MetaGraphT *in_op_def, NodeIter exist_node_iter, InsertPlace place, - size_t inout_idx, FormatTransNodeType node_type, STATUS *error_code); - - STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr &node); - - private: - STATUS DoModelInputFormatTrans(schema::MetaGraphT *graph); - - STATUS DoNodeInoutFormatTrans(schema::MetaGraphT *graph); - - void TransformAttrByAxes(int *origin_attr, int *axes, int element_size); - - void TransformOpAxisAttr(int *origin_axis, int element_size); - - STATUS ChangeOpSlice(schema::MetaGraphT *graph, const std::unique_ptr &node); - - STATUS ChangeOpStridedSlice(schema::MetaGraphT *graph, const std::unique_ptr &node); - - STATUS ChangeOpSliceAndStridedSlice(schema::MetaGraphT *graph, const std::unique_ptr &node); - - int GetFormat(const schema::CNodeT &); - - STATUS GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *before_node_type, - FormatTransNodeType *after_node_type); - - protected: - size_t id_ = 0; - converter::FmkType fmk_type_ = converter::FmkType_TF; - - private: - QuantType quant_type_ = QuantType_QUANT_NONE; -}; -} // namespace lite -} // namespace mindspore -#endif // MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc deleted file mode 100644 index 85813c12b4..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc +++ /dev/null @@ -1,223 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "tools/converter/legacy_optimizer/graph/global_format_transform_pass.h" -#include -#include "third_party/securec/include/securec.h" -#include "src/common/log_adapter.h" -#include "src/common/utils.h" -#include "tools/common/graph_util.h" -#include "tools/common/node_util.h" -#include "include/errorcode.h" -#include "schema/inner/model_generated.h" - -namespace mindspore { -namespace { -std::vector nchw2nhwc_perm = {0, 2, 3, 1}; -std::vector nhwc2nchw_perm = {0, 3, 1, 2}; -} // namespace -namespace lite { - -STATUS GlobalFormatTransformPass::Run(MetaGraphT *graph) { - MS_ASSERT(graph != nullptr); - std::set need_del_nodes; - std::set need_trans_format_nodes; - for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { - auto &node = *iter; - auto type = node->primitive->value.type; - if (type != PrimitiveType_Transpose) { - continue; - } - if (GetTransposePerm(graph, node) != nchw2nhwc_perm) { - continue; - } - std::vector pre_nh2nc_nodes; - std::vector pre_not_trans_nodes; - auto status = FindPreNh2NcNodes(graph, iter - graph->nodes.begin(), &pre_nh2nc_nodes, &pre_not_trans_nodes); - if (status != RET_OK) { - MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status; - return status; - } - std::copy(pre_nh2nc_nodes.begin(), pre_nh2nc_nodes.end(), std::inserter(need_del_nodes, need_del_nodes.end())); - std::copy(pre_not_trans_nodes.begin(), pre_not_trans_nodes.end(), - std::inserter(need_trans_format_nodes, need_trans_format_nodes.end())); - if (!pre_nh2nc_nodes.empty()) { - need_del_nodes.insert(iter - graph->nodes.begin()); - } - } - if (need_del_nodes.empty()) { - return RET_OK; - } - for (auto del_node_index : need_del_nodes) { - auto node_name = graph->nodes.at(del_node_index)->name; - auto status = IsolateOneWayNode(graph, del_node_index); - if (status != RET_OK) { - MS_LOG(ERROR) << "Isolate Node failed, node: " << node_name << ", error: " << status; - return status; - } - } - - auto status = TransWeightToNhwc(graph, need_trans_format_nodes); - if (status != RET_OK) { - MS_LOG(ERROR) << "trans weight to nhwc failed"; - return status; - } - return RET_OK; -} - -STATUS ConvertNcTensor2Nh(TensorT *tensor, const std::vector &pad_dims) { - if (pad_dims.size() != 4) { - MS_LOG(ERROR) << "pad dims error"; - return RET_ERROR; - } - auto batch = pad_dims[NCHW_N]; - auto channel = pad_dims[NCHW_C]; - auto area = pad_dims[NCHW_H] * pad_dims[NCHW_W]; - auto size = batch * channel * area; - auto new_nhwc_data = new (std::nothrow) float[size]; - if (new_nhwc_data == nullptr) { - MS_LOG(ERROR) << "create new nhwc data failed"; - delete[] new_nhwc_data; - return RET_ERROR; - } - if (memset_s(new_nhwc_data, sizeof(float) * size, 0, sizeof(float) * size) != EOK) { - MS_LOG(ERROR) << "create new nhwc data failed"; - delete[] new_nhwc_data; - return RET_ERROR; - } - auto nchw_data = reinterpret_cast(tensor->data.data()); - // nchw to nhwc - for (auto i = 0; i < batch; i++) { - float *src_batch = nchw_data + i * channel * area; - float *dst_batch = new_nhwc_data + i * channel * area; - for (int j = 0; j < area; ++j) { - float *src_area = src_batch + i; - float *dst_area = dst_batch + i * channel; - for (int k = 0; k < channel; ++k) { - dst_area[k] = src_area[k * area]; - } - } - } - if (memcpy_s(nchw_data, tensor->data.size(), new_nhwc_data, sizeof(float) * size) != EOK) { - MS_LOG(ERROR) << "memcpy_s failed"; - delete[] new_nhwc_data; - return RET_ERROR; - } - delete[] new_nhwc_data; - return RET_OK; -} - -STATUS GlobalFormatTransformPass::TransWeightToNhwc(MetaGraphT 
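`ConvertNcTensor2Nh` above relays constant float data from NCHW out to NHWC. Note that its inner loops index with `i` (the batch counter) where the spatial counter `j` appears to have been intended; a standalone version with the index arithmetic spelled out:

```cpp
#include <vector>

// Plain NCHW -> NHWC relayout for float data, with the index arithmetic
// written out: src index = b*C*HW + k*HW + j, dst index = b*HW*C + j*C + k.
std::vector<float> NchwToNhwc(const std::vector<float> &src,
                              int n, int c, int h, int w) {
  std::vector<float> dst(src.size());
  const int area = h * w;
  for (int b = 0; b < n; ++b) {
    const float *src_batch = src.data() + b * c * area;
    float *dst_batch = dst.data() + b * c * area;
    for (int j = 0; j < area; ++j) {   // spatial position
      for (int k = 0; k < c; ++k) {    // channel
        dst_batch[j * c + k] = src_batch[k * area + j];
      }
    }
  }
  return dst;
}
```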
*graph, const std::set &pre_not_trans_nodes) { - MS_ASSERT(graph != nullptr); - if (pre_not_trans_nodes.empty()) { - return RET_OK; - } - for (auto index : pre_not_trans_nodes) { - auto &cur_node = graph->nodes.at(index); - // need change axis from nchw to nhwc like concat,slice - auto ret = ChangeOpAxis(graph, cur_node); - if (ret != RET_OK) { - MS_LOG(ERROR) << "ChangeOpAxis error"; - return ret; - } - auto node_input_indexs = cur_node->inputIndex; - for (auto input_index : node_input_indexs) { - // weight data need trans nhwc layerout - if (!IsContain(graph->inputIndex, input_index) && - graph->allTensors.at(input_index)->nodeType == NodeType_ValueNode) { - auto &weight_tensor = graph->allTensors.at(input_index); - auto origin_dims = weight_tensor->dims; - weight_tensor->format = Format_NHWC; - if (origin_dims.size() > 4) { - MS_LOG(ERROR) << "tensor origin tensor size error"; - return RET_ERROR; - } - if (origin_dims.empty()) { - continue; - } - auto pad_dims = origin_dims; - if (origin_dims.size() == 1) { - pad_dims = {1, 1, 1, origin_dims[0]}; - } else if (origin_dims.size() == 2) { - pad_dims = {1, 1, origin_dims[0], origin_dims[1]}; - } else if (origin_dims.size() == 3) { - pad_dims = {1, origin_dims[0], origin_dims[1], origin_dims[2]}; - } - if (ConvertNcTensor2Nh(weight_tensor.get(), pad_dims) != RET_OK) { - MS_LOG(ERROR) << "Convert nchw to nhwc failed"; - return RET_ERROR; - } - weight_tensor->dims = {pad_dims[NCHW_N], pad_dims[NCHW_H], pad_dims[NCHW_W], pad_dims[NCHW_C]}; - } - } - } - return RET_OK; -} - -STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc2nh_index, - std::vector *pre_nh2nc_nodes, - std::vector *pre_not_trans_nodes) { - MS_ASSERT(graph != nullptr); - std::vector bfs_queue = {nc2nh_index}; - // find pre node nh2nc start nodes - while (!bfs_queue.empty()) { - auto cur_node_index = bfs_queue.back(); - auto &cur_node = graph->nodes.at(cur_node_index); - bfs_queue.pop_back(); - auto input_node_indexes = GetInputNodeIdx(*graph, *cur_node); - for (auto input_node_index : input_node_indexes) { - MS_ASSERT(graph->nodes.size() > input_node_index); - auto &pre_node = graph->nodes.at(input_node_index); - MS_ASSERT(pre_node != nullptr); - auto node_type = pre_node->primitive->value.type; - if (node_type == schema::PrimitiveType_Transpose && GetTransposePerm(graph, pre_node) == nhwc2nchw_perm) { - if (!IsContain(*pre_nh2nc_nodes, input_node_index)) { - pre_nh2nc_nodes->emplace_back(input_node_index); - } - } else if (IsContain(GetInsertOpList(), node_type)) { - if (!IsContain(bfs_queue, input_node_index)) { - bfs_queue.emplace_back(input_node_index); - } - auto pre_node_output_indexs = GetOutputNodeIdx(*graph, *pre_node); - if (pre_node_output_indexs.size() != 1) { - if (node_type == schema::PrimitiveType_Activation || node_type == schema::PrimitiveType_Concat) { - pre_nh2nc_nodes->clear(); - pre_not_trans_nodes->clear(); - return RET_OK; - } - for (auto pre_node_output_index : pre_node_output_indexs) { - MS_ASSERT(graph->nodes.size() > pre_node_output_index); - if (graph->nodes.at(pre_node_output_index)->primitive->value.type == schema::PrimitiveType_PadFusion) { - pre_nh2nc_nodes->clear(); - pre_not_trans_nodes->clear(); - return RET_OK; - } - } - } - } else { - pre_nh2nc_nodes->clear(); - pre_not_trans_nodes->clear(); - return RET_OK; - } - if (!IsContain(*pre_not_trans_nodes, cur_node_index) && cur_node_index != nc2nh_index) { - pre_not_trans_nodes->emplace_back(cur_node_index); - } - } - } - return RET_OK; -} -} // namespace lite -} // 
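Before relayout, `TransWeightToNhwc` above pads low-rank weight shapes up to 4-D so they can be treated as NCHW. The padding rule, extracted as a plain helper:

```cpp
#include <vector>

// How the deleted TransWeightToNhwc padded low-rank weight shapes to 4-D
// (interpreted as NCHW) before converting the data to NHWC.
std::vector<int> PadDimsTo4D(const std::vector<int> &dims) {
  switch (dims.size()) {
    case 1:  return {1, 1, 1, dims[0]};
    case 2:  return {1, 1, dims[0], dims[1]};
    case 3:  return {1, dims[0], dims[1], dims[2]};
    default: return dims;  // already 4-D; ranks above 4 were rejected earlier
  }
}
```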
namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.h deleted file mode 100644 index 68f7161397..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H -#define MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H - -#include -#include -#include -#include -#include -#include -#include "tools/common/graph_util.h" -#include "tools/converter/optimizer.h" -#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" - -using mindspore::schema::TensorT; -namespace mindspore { -namespace lite { -class GlobalFormatTransformPass : public FormatTransPass { - public: - GlobalFormatTransformPass() = default; - - ~GlobalFormatTransformPass() override = default; - - STATUS Run(MetaGraphT *graph) override; - - protected: - STATUS TransWeightToNhwc(MetaGraphT *graph, const std::set &pre_not_trans_nodes); - - STATUS FindPreNh2NcNodes(MetaGraphT *graph, size_t nc2nh_index, std::vector *to_do_insert_nodes, - std::vector *pre_not_trans_nodes); -}; -} // namespace lite -} // namespace mindspore -#endif // MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc deleted file mode 100644 index 482f1f65f8..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc +++ /dev/null @@ -1,193 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include -#include -#include -#include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h" -#include "tools/common/node_util.h" -#include "src/common/log_adapter.h" -#include "src/common/utils.h" - -namespace mindspore { -namespace { -std::vector nchw2nhwc_perm = {0, 2, 3, 1}; -std::vector nhwc2nchw_perm = {0, 3, 1, 2}; -} // namespace -namespace lite { -bool IsInOutCanFusion(schema::MetaGraphT *graph, const std::vector &node_indexes, size_t *has_trans_count, - FormatTransNodeType *trans_type) { - for (auto input_node_index : node_indexes) { - MS_ASSERT(graph->nodes.size() > input_node_index); - auto &pre_node = graph->nodes.at(input_node_index); - MS_ASSERT(pre_node != nullptr); - MS_ASSERT(pre_node->primitive != nullptr); - MS_ASSERT(pre_node->primitive->value != nullptr); - if (*trans_type == kNONE) { - if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) { - auto perm = GetTransposePerm(graph, pre_node); - if (perm == nchw2nhwc_perm) { - *trans_type = kNCHW2NHWC; - } else if (perm == nhwc2nchw_perm) { - *trans_type = kNHWC2NCHW; - } else { - return false; - } - (*has_trans_count)++; - } - } else { - if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) { - auto cur_type = kNONE; - auto perm = GetTransposePerm(graph, pre_node); - if (perm == nchw2nhwc_perm) { - cur_type = kNCHW2NHWC; - } else if (perm == nhwc2nchw_perm) { - cur_type = kNHWC2NCHW; - } else { - return false; - } - if (*trans_type != cur_type) { - return false; - } else { - (*has_trans_count)++; - } - } - } - } - return true; -} -bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr &node) { - MS_ASSERT(graph != nullptr); - MS_ASSERT(node != nullptr); - auto input_node_indexes = GetInputNodeIdx(*graph, *node); - pre_type_ = kNONE; - size_t has_trans_count = 0; - if (!IsInOutCanFusion(graph, input_node_indexes, &has_trans_count, &pre_type_)) { - return false; - } - auto output_node_indexes = GetOutputNodeIdx(*graph, *node); - post_type_ = kNONE; - if (!IsInOutCanFusion(graph, output_node_indexes, &has_trans_count, &post_type_)) { - return false; - } - if (pre_type_ == kNONE && post_type_ == kNONE) { - return false; - } - auto output_size = output_node_indexes.empty() ? 1 : output_node_indexes.size(); - auto total_node_count = input_node_indexes.size() + output_size; - size_t half_count = total_node_count / 2; - if (GetCNodeTType(*node) == schema::PrimitiveType_Activation) { - MS_ASSERT(node != nullptr); - MS_ASSERT(node->primitive != nullptr); - MS_ASSERT(node->primitive->value != nullptr); - MS_ASSERT(node->primitive->value.AsActivation() != nullptr); - if (node->primitive->value.AsActivation() != nullptr && - node->primitive->value.AsActivation()->activation_type == schema::ActivationType_LEAKY_RELU) { - return has_trans_count >= half_count; - } - } - if (GetCNodeTType(*node) == schema::PrimitiveType_Split) { - return has_trans_count >= half_count; - } - return has_trans_count > half_count; -} -STATUS TransOpInsertPass::FindOutTransType() { - pre_insert_trans_type_ = kNHWC2NCHW; - post_insert_trans_type_ = kNHWC2NCHW; - if (pre_type_ == kNONE && post_type_ != kNONE) { - pre_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC; - post_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; - } else if (pre_type_ != kNONE && post_type_ == kNONE) { - pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; - post_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? 
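`CanFusion` above is a voting heuristic: hoisting transposes across an op pays off only when enough of its neighbors are already transposes of one consistent orientation. The threshold logic, restated standalone (LeakyReLU activations and Split relax strict majority to at-least-half):

```cpp
#include <cstddef>

// Threshold rule behind the deleted TransOpInsertPass::CanFusion: hoist
// transposes across an op only when (strictly) more than half of its
// neighbors are already transposes of one consistent orientation; LeakyReLU
// activations and Split relax this to at-least-half.
bool WorthHoisting(size_t trans_neighbors, size_t in_degree, size_t out_degree,
                   bool relaxed) {
  const size_t total = in_degree + (out_degree == 0 ? 1 : out_degree);
  const size_t half = total / 2;
  return relaxed ? trans_neighbors >= half : trans_neighbors > half;
}
```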
kNHWC2NCHW : kNCHW2NHWC; - } else if (pre_type_ == kNONE && post_type_ == kNONE) { - MS_ASSERT(false); - } else { - if (pre_type_ == post_type_) { - MS_LOG(ERROR) << "Unknown error"; - return RET_ERROR; - } - pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; - post_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; - } - return RET_OK; -} - -STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) { - MS_ASSERT(graph != nullptr); - bool changed = true; - int run_counts = 0; - std::vector has_insert_nodes; - while (changed && run_counts < 10) { - changed = false; - for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { - auto &node = *iter; - if (node == nullptr || node->primitive == nullptr) { - MS_LOG(ERROR) << "node or primitive null"; - return RET_NULL_PTR; - } - auto type = node->primitive->value.type; - if (IsContain(has_insert_nodes, node.get()) || !IsContain(GetInsertOpList(), type)) { - continue; - } - auto node_name = node->name; - if (!CanFusion(graph, node)) { - continue; - } - auto ret = FindOutTransType(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "FindOutTransType error"; - return ret; - } - ret = ChangeOpAxis(graph, node); - if (ret == RET_NOT_SUPPORT) { - MS_LOG(INFO) << "not support to ChangeOpAxis"; - return RET_OK; - } else if (ret != RET_OK) { - MS_LOG(INFO) << "no need to ChangeOpAxis"; - return ret; - } - has_insert_nodes.push_back(node.get()); - STATUS status = RET_OK; - auto input_tensor_size = (*iter)->inputIndex.size(); - for (size_t i = 0; i < input_tensor_size; i++) { - auto &input_tensor = graph->allTensors.at((*iter)->inputIndex[i]); - if (input_tensor->nodeType == NodeType_ValueNode && input_tensor->dims.size() < 4) { - continue; - } - iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type_, &status); - if (status != RET_OK) { - MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed"; - return status; - } - } - auto output_tensor_size = (*iter)->outputIndex.size(); - for (size_t i = 0; i < output_tensor_size; i++) { - iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type_, &status); - if (status != RET_OK) { - MS_LOG(ERROR) << "Insert" << post_insert_trans_type_ << "Node before " << (*iter)->name << " failed"; - return status; - } - } - changed = true; - } - run_counts++; - } - return RET_OK; -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h deleted file mode 100644 index a9e785e369..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
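`FindOutTransType` above then picks which transposes to insert: cancel the orientation observed on one side of the op and restore it on the other, so the surrounding graph keeps seeing the original layout. A standalone restatement of that pairing rule:

```cpp
enum TransType { kTransNone, kTransNchw2Nhwc, kTransNhwc2Nchw };

TransType Inverse(TransType t) {
  return t == kTransNhwc2Nchw ? kTransNchw2Nhwc : kTransNhwc2Nchw;
}

// Pairing rule of the deleted FindOutTransType: cancel the orientation seen
// on one side, restore it on the other.
void PickInsertTypes(TransType pre, TransType post,
                     TransType *ins_pre, TransType *ins_post) {
  if (pre != kTransNone) {          // producers feed us through `pre` transposes
    *ins_pre = Inverse(pre);
    *ins_post = pre;
  } else if (post != kTransNone) {  // consumers expect `post` transposes
    *ins_pre = post;
    *ins_post = Inverse(post);
  }
}
```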
- */ - -#ifndef MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H -#define MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H - -#include -#include -#include "tools/common/graph_util.h" -#include "tools/converter/converter_flags.h" -#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" - -namespace mindspore { -namespace lite { -class TransOpInsertPass : public FormatTransPass { - public: - TransOpInsertPass() : FormatTransPass() {} - - ~TransOpInsertPass() override = default; - - STATUS Run(schema::MetaGraphT *graph) override; - - private: - bool CanFusion(schema::MetaGraphT *graph, const std::unique_ptr &node); - - STATUS FindOutTransType(); - - private: - FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW; - FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW; - FormatTransNodeType pre_type_ = kNONE; - std::vector pre_perm_; - FormatTransNodeType post_type_ = kNONE; - std::vector post_perm_; -}; -} // namespace lite -} // namespace mindspore - -#endif // MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.cc deleted file mode 100644 index a0246a6dc6..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.cc +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h" -#include -#include "src/common/log_adapter.h" -#include "include/errorcode.h" -#include "tools/common/graph_util.h" -#include "src/tensor.h" - -using mindspore::lite::Tensor; -namespace mindspore { -namespace { -std::vector nchw2nhwc_perm = {0, 2, 3, 1}; -std::vector nhwc2nchw_perm = {0, 3, 1, 2}; -} // namespace -namespace lite { -STATUS TransOpRemovePass::Run(MetaGraphT *graph) { - MS_ASSERT(graph != nullptr); - for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { - auto &node = *iter; - auto type = node->primitive->value.type; - auto perm = GetTransposePerm(graph, node); - if (type == schema::PrimitiveType_Transpose && (perm == nchw2nhwc_perm || perm == nhwc2nchw_perm)) { - auto &input_tensor = graph->allTensors.at(node->inputIndex.at(0)); - // less than 4 dims can delete - if (!input_tensor->dims.empty() && input_tensor->dims.size() < 4) { - auto status = IsolateOneWayNode(graph, node.get(), true); - if (status != RET_OK) { - MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << node->name.c_str() << ", error: " << status; - return status; - } - } - } - } - return RET_OK; -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h deleted file mode 100644 index 30b8e20e27..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2020-2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
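`TransOpRemovePass::Run` above deletes a transpose when its perm is one of the two NHWC/NCHW layout perms but the input is known to have fewer than four dimensions, where such a perm cannot be meaningful. The predicate in isolation:

```cpp
#include <vector>

// Condition under which the deleted TransOpRemovePass dropped a transpose: a
// 4-element NHWC<->NCHW perm cannot be meaningful once the input is known to
// have fewer than four dimensions.
bool TransposeIsRemovable(const std::vector<int> &perm,
                          const std::vector<int> &input_dims) {
  const std::vector<int> nchw2nhwc = {0, 2, 3, 1};
  const std::vector<int> nhwc2nchw = {0, 3, 1, 2};
  return (perm == nchw2nhwc || perm == nhwc2nchw) &&
         !input_dims.empty() && input_dims.size() < 4;
}
```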
- */ - -#ifndef MINDSPORE_PREDICT_TRANS_FORMAT_REMOVE_PASS_H -#define MINDSPORE_PREDICT_TRANS_FORMAT_REMOVE_PASS_H - -#include -#include -#include -#include -#include "tools/common/graph_util.h" -#include "tools/converter/optimizer.h" - -using mindspore::schema::TensorT; -namespace mindspore { -namespace lite { -class TransOpRemovePass : public GraphPass { - public: - TransOpRemovePass() = default; - - ~TransOpRemovePass() = default; - - STATUS Run(MetaGraphT *graph) override; -}; -} // namespace lite -} // namespace mindspore -#endif // MINDSPORE_PREDICT_TRANS_FORMAT_REMOVE_PASS_H diff --git a/mindspore/lite/tools/optimizer/common/format_utils.cc b/mindspore/lite/tools/optimizer/common/format_utils.cc index fad9587447..443875247c 100644 --- a/mindspore/lite/tools/optimizer/common/format_utils.cc +++ b/mindspore/lite/tools/optimizer/common/format_utils.cc @@ -28,14 +28,15 @@ #include "ops/concat.h" #include "ops/crop.h" #include "ops/depth_to_space.h" +#include "ops/fused_batch_norm.h" #include "ops/fusion/activation.h" #include "ops/fusion/add_fusion.h" -#include "ops/fused_batch_norm.h" #include "ops/fusion/avg_pool_fusion.h" #include "ops/fusion/conv2d_backprop_input_fusion.h" #include "ops/fusion/conv2d_backprop_filter_fusion.h" #include "ops/fusion/conv2d_fusion.h" #include "ops/fusion/conv2d_transpose_fusion.h" +#include "ops/fusion/div_fusion.h" #include "ops/fusion/max_pool_fusion.h" #include "ops/fusion/mul_fusion.h" #include "ops/fusion/pow_fusion.h" @@ -61,6 +62,7 @@ #include "ops/space_to_depth.h" #include "ops/split.h" #include "ops/strided_slice.h" +#include "tools/anf_exporter/fetch_content.h" namespace mindspore { namespace opt { @@ -96,9 +98,9 @@ static const std::unordered_map> NCHWOpMap = {{ // a certain op whose input's format is not fixed. 
static const std::vector DynamicFormatOpList = { - ops::kNameEltwise, ops::kNameActivation, ops::kNameConcat, ops::kNamePowFusion, ops::kNameStridedSlice, - ops::kNameAddFusion, ops::kNameAddN, ops::kNameSplit, ops::kNameSliceFusion, ops::kNameCrop, - ops::kNameMulFusion, ops::kNameMaximum, ops::kNameActivationGrad, ops::kNameQuantDTypeCast}; + ops::kNameEltwise, ops::kNameActivation, ops::kNameConcat, ops::kNameDivFusion, ops::kNamePowFusion, + ops::kNameStridedSlice, ops::kNameAddFusion, ops::kNameAddN, ops::kNameSplit, ops::kNameSliceFusion, + ops::kNameCrop, ops::kNameMulFusion, ops::kNameMaximum, ops::kNameActivationGrad, ops::kNameQuantDTypeCast}; static const std::unordered_map NC2NHAxisMap = {{0, 0}, {1, 3}, {2, 1}, {3, 2}}; @@ -120,33 +122,34 @@ Format GetFormat(const CNodePtr &cnode) { return format; } -STATUS GetTransposePerm(const AnfNodePtr &perm_node, std::vector *perm) { +STATUS GetTransposePerm(const CNodePtr &cnode, std::vector *perm) { MS_ASSERT(perm_node != nullptr); - if (!utils::isa(perm_node)) { - return lite::RET_OK; + if (cnode->size() != 3) { + MS_LOG(ERROR) << "transpose op input size must be three."; + return lite::RET_ERROR; } - auto perm_param = perm_node->cast(); - if (!perm_param->has_default() || perm_param->default_param() == nullptr) { + if (utils::isa(cnode->input(2))) { return lite::RET_OK; } - auto tensor_info = perm_param->default_param()->cast(); - if (tensor_info == nullptr) { - MS_LOG(ERROR) << "default param is not a tensor."; - return lite::RET_ERROR; + lite::DataInfo data_info; + int status; + if (utils::isa(cnode->input(2))) { + status = lite::FetchDataFromParameterNode(cnode, 2, lite::converter::FmkType_MS, false, &data_info); + } else { + status = lite::FetchDataFromValueNode(cnode, 2, lite::converter::FmkType_MS, false, &data_info); } - if (tensor_info->data_type() != kNumberTypeInt && tensor_info->data_type() != kNumberTypeInt32) { - MS_LOG(ERROR) << "data type is error, which is " << tensor_info->data_type(); + if (status != lite::RET_OK) { + MS_LOG(ERROR) << "fetch transpose perm data failed."; return lite::RET_ERROR; } - auto tensor_shape = tensor_info->shape(); - if (tensor_shape.empty()) { - return lite::RET_OK; - } - if (tensor_shape.size() > 1) { + if ((data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32) || + data_info.shape_.size() != 1) { + MS_LOG(ERROR) << "transpose perm data is invalid."; return lite::RET_ERROR; } - perm->resize(tensor_shape[0]); - if (memcpy_s(perm->data(), tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != EOK) { + perm->resize(data_info.shape_[0]); + if (!data_info.data_.empty() && + memcpy_s(perm->data(), data_info.data_.size(), data_info.data_.data(), data_info.data_.size()) != EOK) { MS_LOG(ERROR) << "memcpy data failed."; return lite::RET_ERROR; } diff --git a/mindspore/lite/tools/optimizer/common/format_utils.h b/mindspore/lite/tools/optimizer/common/format_utils.h index e1f9119172..98e2b3efa0 100644 --- a/mindspore/lite/tools/optimizer/common/format_utils.h +++ b/mindspore/lite/tools/optimizer/common/format_utils.h @@ -38,7 +38,7 @@ const std::unordered_map> &GetNCHWOpMap(); const std::unordered_map &GetNC2NHAxisMap(); const std::vector &GetDynamicFormatOpList(); Format GetFormat(const CNodePtr &cnode); -STATUS GetTransposePerm(const AnfNodePtr &perm_node, std::vector *perm); +STATUS GetTransposePerm(const CNodePtr &cnode, std::vector *perm); void RemoveIfMonad(const CNodePtr &cnode); bool IsMonadNode(const AnfNodePtr &node); } // namespace opt diff --git 
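The rewritten `GetTransposePerm` above now pulls the perm through the `DataInfo` fetch helpers and validates it before copying. A mock of that validation; `DataInfo` here is a simplified stand-in for the struct in `tools/anf_exporter/fetch_content.h`, and the `kNumberTypeInt32` value is illustrative only:

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Simplified stand-in for the DataInfo struct in fetch_content.h.
struct DataInfo {
  int data_type;
  std::vector<int64_t> shape;
  std::vector<uint8_t> data;
};

constexpr int kNumberTypeInt32 = 34;  // illustrative value only

// Checks the rewritten GetTransposePerm applies to the fetched constant: the
// perm must be a rank-1 int32 tensor whose byte size matches its shape.
bool ExtractPerm(const DataInfo &info, std::vector<int> *perm) {
  if (info.data_type != kNumberTypeInt32 || info.shape.size() != 1) {
    return false;  // "transpose perm data is invalid."
  }
  perm->resize(info.shape[0]);
  if (!info.data.empty()) {
    if (info.data.size() != perm->size() * sizeof(int)) {
      return false;
    }
    std::memcpy(perm->data(), info.data.data(), info.data.size());
  }
  return true;
}
```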
a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index 36c6afd7fa..1cbe39d645 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -15,12 +15,13 @@ */ #include "tools/optimizer/fusion/constant_folding_fusion.h" +#include #include #include #include +#include "tools/anf_exporter/fetch_content.h" #include "tools/converter/quant_param_holder.h" #include "tools/optimizer/common/gllo_utils.h" -#include "tools/anf_exporter/anf_exporter.h" #include "tools/common/node_util.h" #include "tools/common/tensor_util.h" #include "src/common/common.h" @@ -36,47 +37,79 @@ using mindspore::lite::Tensor; namespace mindspore::opt { namespace { constexpr size_t INITIAL_SIZE = 1024; -std::vector GetCNodeInputTensors(const CNodePtr &CNode) { +void FreeTensors(std::vector *input_tensor, std::vector *output_tensor) { + if (input_tensor != nullptr) { + for (auto &i : *input_tensor) { + delete i; + i = nullptr; + } + } + if (output_tensor != nullptr) { + for (auto &i : *output_tensor) { + delete i; + i = nullptr; + } + } +} + +std::vector GetCNodeInputTensors(const CNodePtr &cnode, lite::converter::FmkType fmk_type) { MS_ASSERT(CNode != nullptr); - auto tmp_meta_graph = std::make_unique(); - auto tmp_fb_node = std::make_unique(); - lite::AnfExporter anfExporter; - anfExporter.SetOpInputNode(CNode, tmp_meta_graph, tmp_fb_node.get()); - std::vector input_tensors; - for (auto input_index : tmp_fb_node->inputIndex) { - auto tensorT = tmp_meta_graph->allTensors.at(input_index).get(); - auto tensor_shape = tensorT->dims; - auto lite_tensor = new (std::nothrow) Tensor( - TypeId(tensorT->dataType), tensor_shape, tensorT->format, - lite::TensorCategory(tensorT->nodeType, tensorT->dims.size(), TypeId(tensorT->dataType), tensorT->data.size())); - if (lite_tensor == nullptr) { - MS_LOG(ERROR) << "lite tensor is nullptr"; - return input_tensors; + std::vector tensors; + for (size_t i = 1; i < cnode->size(); ++i) { + int status; + lite::DataInfo data_info; + if (utils::isa(cnode->input(i))) { + if (!cnode->input(i)->cast()->has_default()) { + FreeTensors(&tensors, nullptr); + return {}; + } + status = lite::FetchDataFromParameterNode(cnode, i, fmk_type, false, &data_info); + } else if (utils::isa(cnode->input(i))) { + status = lite::FetchDataFromValueNode(cnode, i, fmk_type, false, &data_info); + } else { + MS_LOG(ERROR) << "input node is not const node."; + FreeTensors(&tensors, nullptr); + return {}; } - auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t); - // when tensorT as graph input - if (lite_tensor_size <= 0) { - delete lite_tensor; - return input_tensors; + if (status == lite::RET_NO_CHANGE) { + continue; + } + if (status != lite::RET_OK) { + MS_LOG(ERROR) << "parser const data failed."; + FreeTensors(&tensors, nullptr); + return {}; + } + if (data_info.shape_.empty() && data_info.data_.empty()) { + FreeTensors(&tensors, nullptr); + MS_LOG(DEBUG) << "input node is graph input."; + return {}; } - auto tensor_data = new (std::nothrow) uint8_t[lite_tensor_size / sizeof(char)]; + auto tensor = new (std::nothrow) + Tensor(TypeId(data_info.data_type_), data_info.shape_, schema::Format(data_info.format_), + lite::TensorCategory(0, data_info.shape_.size(), TypeId(data_info.data_type_), data_info.data_.size())); + if (tensor == nullptr) { + MS_LOG(ERROR) << "new a tensor is nullptr."; + FreeTensors(&tensors, nullptr); + return {}; + } + 
if (data_info.data_.empty()) { + tensors.emplace_back(tensor); + continue; + } + auto tensor_data = tensor->MutableData(); if (tensor_data == nullptr) { - MS_LOG(ERROR) << "tensor_data is nullptr"; - delete lite_tensor; - return input_tensors; + MS_LOG(ERROR) << "malloc data failed."; + FreeTensors(&tensors, nullptr); + return {}; } - auto ret = memcpy_s(tensor_data, lite_tensor_size, tensorT->data.data(), lite_tensor_size); - if (ret != EOK) { - delete lite_tensor; - delete[](tensor_data); - MS_LOG(ERROR) << "memcpy error: " << ret; - lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED); + if (memcpy_s(tensor_data, data_info.data_.size(), data_info.data_.data(), data_info.data_.size()) != EOK) { + MS_LOG(ERROR) << "memcpy data failed."; + FreeTensors(&tensors, nullptr); return {}; } - lite_tensor->set_data(tensor_data); - input_tensors.emplace_back(lite_tensor); + tensors.emplace_back(tensor); } - return input_tensors; + return tensors; } ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { @@ -229,21 +262,6 @@ lite::STATUS CopyQuantParams(const CNodePtr &cnode, const std::vector } return lite::RET_OK; } - -void FreeTensors(std::vector *input_tensor, std::vector *output_tensor) { - if (input_tensor != nullptr) { - for (auto &i : *input_tensor) { - delete i; - i = nullptr; - } - } - if (output_tensor != nullptr) { - for (auto &i : *output_tensor) { - delete i; - i = nullptr; - } - } -} } // namespace const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, @@ -263,9 +281,8 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An continue; } auto input_cnode = input_node->cast(); - auto input_tensors = GetCNodeInputTensors(input_cnode); - if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) { - FreeTensors(&input_tensors, nullptr); + auto input_tensors = GetCNodeInputTensors(input_cnode, fmk_type_); + if (input_tensors.empty()) { continue; } changed = true; @@ -279,7 +296,7 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An FreeTensors(&input_tensors, &output_tensors); return nullptr; } - auto lite_kernel = GetLiteKernel(input_tensors, &output_tensors, input_cnode, context.get()); + auto lite_kernel = GetLiteKernel(input_tensors, &output_tensors, input_cnode, context_.get()); if (lite_kernel == nullptr) { FreeTensors(&input_tensors, &output_tensors); MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr"; diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h index 7429100f3f..ac875fe0d8 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.h @@ -24,18 +24,23 @@ #include "src/lite_kernel.h" #include "nnacl/op_base.h" #include "backend/optimizer/common/optimizer.h" +#include "tools/converter/converter_flags.h" namespace mindspore { namespace opt { class ConstFoldPass : public PatternProcessPass { public: - explicit ConstFoldPass(std::shared_ptr context_ptr = nullptr, bool multigraph = true) - : PatternProcessPass("constfold_pass", multigraph), context(std::move(context_ptr)) {} + explicit ConstFoldPass(lite::converter::FmkType fmk_type = lite::converter::FmkType_MS, bool multigraph = true) + : PatternProcessPass("constfold_pass", multigraph), fmk_type_(fmk_type) { + context_ = std::make_shared(); + 
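The rewritten `GetCNodeInputTensors` above owns raw `lite::Tensor*` pointers, so every early return has to pass through `FreeTensors`. For contrast, here is the same append-and-copy step under `unique_ptr` ownership, with a mock tensor type (the real `lite::Tensor` constructor takes type/shape/format/category and is managed manually to match the existing kernel API):

```cpp
#include <cstring>
#include <memory>
#include <vector>

// Mock tensor standing in for lite::Tensor.
struct MockTensor {
  std::vector<int> shape;
  std::vector<uint8_t> data;
};

// Same append-and-copy step as above, but under unique_ptr ownership: any
// early `return false` releases everything collected so far automatically.
bool AppendConstInput(const std::vector<int> &shape,
                      const std::vector<uint8_t> &raw,
                      std::vector<std::unique_ptr<MockTensor>> *tensors) {
  auto tensor = std::make_unique<MockTensor>();
  tensor->shape = shape;
  if (!raw.empty()) {
    tensor->data.resize(raw.size());
    std::memcpy(tensor->data.data(), raw.data(), raw.size());
  }
  tensors->emplace_back(std::move(tensor));
  return true;
}
```

The production code keeps raw pointers because downstream kernel scheduling expects `lite::Tensor *`; the `FreeTensors` helper is the price of that interface.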
context_->Init(); + } ~ConstFoldPass() override = default; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; private: - std::shared_ptr context; + lite::converter::FmkType fmk_type_{lite::converter::FmkType_MS}; + std::shared_ptr context_{nullptr}; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/graph/node_infershape.cc b/mindspore/lite/tools/optimizer/graph/node_infershape.cc index ba0ecfe868..658cccbdd4 100644 --- a/mindspore/lite/tools/optimizer/graph/node_infershape.cc +++ b/mindspore/lite/tools/optimizer/graph/node_infershape.cc @@ -18,9 +18,9 @@ #include #include #include -#include "tools/anf_exporter/anf_exporter.h" #include "tools/common/node_util.h" #include "tools/common/tensor_util.h" +#include "src/common/utils.h" #include "src/ops/populate/populate_register.h" #include "src/ops/ops_utils.h" #include "src/runtime/infer_manager.h" @@ -67,8 +67,8 @@ bool DuceInferFlag(const CNodePtr &cnode, const std::vector &inp } } auto origin_inputs = cnode->inputs(); - lite::AnfExporter::RemoveIfDepend(cnode); - lite::AnfExporter::RemoveIfMakeTuple(cnode); + lite::RemoveIfDepend(cnode); + lite::RemoveIfMakeTuple(cnode); for (size_t i = 1; i < cnode->size(); ++i) { if (!utils::isa(cnode->input(i))) { continue; @@ -241,8 +241,8 @@ STATUS NodeInferShape::GetCNodeInputTensors(const CNodePtr &cnode, std::vectorinputs(); - lite::AnfExporter::RemoveIfDepend(cnode); - lite::AnfExporter::RemoveIfMakeTuple(cnode); + lite::RemoveIfDepend(cnode); + lite::RemoveIfMakeTuple(cnode); RemoveIfMonad(cnode); std::vector const_inputs; if (GetCNodeConstInput(cnode, &const_inputs) != lite::RET_OK) { @@ -288,28 +288,29 @@ STATUS NodeInferShape::GetCNodeInputTensors(const CNodePtr &cnode, std::vector *const_ms_inputs) { - MS_ASSERT(cnode != nullptr); - auto origin_inputs = cnode->inputs(); - std::vector const_inputs; - for (auto &input : origin_inputs) { - if (utils::isa(input)) { + MS_ASSERT(cnode != nullptr && const_ms_inputs != nullptr); + std::vector data_infos; + for (size_t i = 1; i < cnode->size(); ++i) { + if (utils::isa(cnode->input(i))) { continue; } - const_inputs.push_back(input); - } - cnode->set_inputs(const_inputs); - auto meta_graph = std::make_unique(); - meta_graph->fmkType = fmk_type_; - auto fb_node = std::make_unique(); - lite::AnfExporter anf_exporter; - anf_exporter.set_train_flag(train_flag_); - auto status = anf_exporter.SetOpInputNode(cnode, meta_graph, fb_node.get()); - cnode->set_inputs(origin_inputs); - if (status != lite::RET_OK) { - MS_LOG(ERROR) << "get const inputs failed."; - return status; + STATUS status; + lite::DataInfo data_info; + if (utils::isa(cnode->input(i))) { + status = lite::FetchDataFromParameterNode(cnode, i, fmk_type_, train_flag_, &data_info); + } else { + status = lite::FetchDataFromValueNode(cnode, i, fmk_type_, train_flag_, &data_info); + } + if (status == lite::RET_NO_CHANGE) { + continue; + } + if (status != lite::RET_OK) { + MS_LOG(ERROR) << "fetch const input data failed."; + return status; + } + data_infos.emplace_back(data_info); } - return ConvertToLiteTensor(meta_graph, fb_node->inputIndex, const_ms_inputs); + return ConvertToLiteTensor(data_infos, const_ms_inputs); } STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vector *var_ms_inputs) { @@ -319,29 +320,16 @@ STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vector(cnode->input(i))) { continue; } - auto abstract = GetCNodeInputAbstract(cnode, i); - if (abstract == nullptr) 
{ - MS_LOG(ERROR) << "Abstract cnode is nullptr."; - return lite::RET_ERROR; - } - if (!utils::isa(abstract)) { - MS_LOG(ERROR) << "Abstract should be anstract tensor."; - return lite::RET_ERROR; - } - auto abstract_tensor = utils::cast(abstract); - auto type_ptr = abstract_tensor->element()->GetTypeTrack(); - MS_ASSERT(typePtr != nullptr); - if (!utils::isa(abstract_tensor->BuildShape())) { - MS_LOG(ERROR) << "Shape of Abstract should be ShapePtr."; + lite::DataInfo data_info; + if (lite::FetchDataFromCNode(cnode, i, fmk_type_, train_flag_, &data_info) != lite::RET_OK) { + MS_LOG(ERROR) << "parse cnode failed."; return lite::RET_ERROR; } - auto shape_vector = utils::cast(abstract_tensor->BuildShape())->shape(); - std::vector dims(shape_vector.begin(), shape_vector.end()); lite::Tensor *tensor = nullptr; - if (type_ptr->type_id() == kObjectTypeTensorType) { - tensor = GetCNodeTensorListVarInput(dims, abstract_tensor); + if (data_info.data_type_ == kObjectTypeTensorType) { + tensor = GetCNodeTensorListVarInput(data_info); } else { - tensor = new (std::nothrow) lite::Tensor(TypeId(type_ptr->type_id()), dims); + tensor = new (std::nothrow) lite::Tensor(TypeId(data_info.data_type_), data_info.shape_); } if (tensor == nullptr) { MS_LOG(ERROR) << "new a lite tensor failed"; @@ -352,27 +340,16 @@ STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vector shape, - const abstract::AbstractTensorPtr &abstract_tensor) { - MS_ASSERT(abstract_tensor != nullptr); - auto tensor_list = new (std::nothrow) lite::TensorList(shape, {}); +lite::Tensor *NodeInferShape::GetCNodeTensorListVarInput(const lite::DataInfo &data_info) { + auto tensor_list = new (std::nothrow) lite::TensorList(data_info.shape_, {}); if (tensor_list == nullptr) { MS_LOG(ERROR) << "new a lite tensor list failed"; return nullptr; } - auto tensor_info = abstract_tensor->GetValueTrack(); - if (tensor_info == nullptr || !utils::isa(tensor_info)) { - delete tensor_list; - MS_LOG(ERROR) << "nsor list abstract is invalid."; - return nullptr; - } - auto tensor_value = tensor_info->cast(); - if (tensor_value->data_c() == nullptr) { - delete tensor_list; - MS_LOG(ERROR) << "cannot get tensor list abstract's info."; - return nullptr; + if (data_info.data_.empty()) { + return tensor_list; } - auto status = tensor_list->Decode(static_cast(tensor_value->data_c())); + auto status = tensor_list->Decode(reinterpret_cast(data_info.data_.data())); if (status != lite::RET_OK) { delete tensor_list; MS_LOG(ERROR) << "decode tensor list failed."; @@ -384,41 +361,78 @@ lite::Tensor *NodeInferShape::GetCNodeTensorListVarInput(std::vector shape, STATUS NodeInferShape::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector *outputs) { MS_ASSERT(cnode != nullptr); MS_ASSERT(outputs != nullptr); - auto meta_graph = std::make_unique(); - meta_graph->fmkType = fmk_type_; - auto fb_node = std::make_unique(); - lite::AnfExporter anf_exporter; - anf_exporter.set_train_flag(train_flag_); - anf_exporter.SetOpOutputNode(cnode, meta_graph, fb_node.get()); - return ConvertToLiteTensor(meta_graph, fb_node->outputIndex, outputs); + std::vector data_infos; + if (utils::isa(cnode->abstract())) { + auto tuple = std::reinterpret_pointer_cast(cnode->abstract()); + if (tuple == nullptr) { + MS_LOG(ERROR) << "tuple is nullptr"; + return lite::RET_ERROR; + } + auto elements = tuple->elements(); + for (size_t i = 0; i < elements.size(); i++) { + lite::DataInfo data_info; + data_info.node_type_ = lite::NodeType_CNode; + if (train_flag_) { + 
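The rewritten const-input collection above hinges on the fetch helpers' tri-state result: `RET_NO_CHANGE` skips variable inputs, any other failure aborts. The control flow, reduced to a standalone skeleton with mock types:

```cpp
#include <vector>

enum Status { kOk, kNoChange, kError };

struct MockDataInfo { /* type, shape, raw bytes in the real struct */ };

// Tri-state fetch loop of the rewritten GetCNodeConstInput: RET_NO_CHANGE
// marks a variable input to skip, any other failure aborts the collection.
Status CollectConstInputs(const std::vector<int> &inputs,
                          Status (*fetch)(int, MockDataInfo *),
                          std::vector<MockDataInfo> *out) {
  for (int input : inputs) {
    MockDataInfo info;
    const Status s = fetch(input, &info);
    if (s == kNoChange) {
      continue;  // variable input: collected by GetCNodeVarInput instead
    }
    if (s != kOk) {
      return kError;  // "fetch const input data failed."
    }
    out->push_back(info);
  }
  return kOk;
}
```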
data_infos.emplace_back(data_info); + if (CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || CheckPrimitiveType(cnode, prim::kPrimAdam)) { + break; + } + } else { + if (!utils::isa(elements[i])) { + MS_LOG(ERROR) << "abstract is not AbstractTensor"; + return lite::RET_ERROR; + } + auto type = kNumberTypeFloat32; + if (utils::isa(elements[i])) { + auto abstract_tensor = utils::cast(elements[i]); + auto typePtr = abstract_tensor->element()->GetTypeTrack(); + type = typePtr->type_id(); + } + data_info.data_type_ = type; + data_infos.emplace_back(data_info); + if (CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || + CheckPrimitiveType(cnode, prim::kPrimFusedBatchNorm)) { + break; + } + } + } + } else { + lite::DataInfo data_info; + auto type = kNumberTypeFloat32; + if (utils::isa(cnode->abstract())) { + auto abstract_tensor = utils::cast(cnode->abstract()); + auto typePtr = abstract_tensor->element()->GetTypeTrack(); + type = typePtr->type_id(); + } + data_info.data_type_ = type; + data_info.node_type_ = lite::NodeType_CNode; + data_infos.emplace_back(data_info); + } + return ConvertToLiteTensor(data_infos, outputs); } -STATUS NodeInferShape::ConvertToLiteTensor(const std::unique_ptr &meta_graph, - const std::vector &tensor_indexes, +STATUS NodeInferShape::ConvertToLiteTensor(const std::vector &data_infos, std::vector *tensors) { - MS_ASSERT(meta_graph != nullptr); MS_ASSERT(tensors != nullptr); - for (auto index : tensor_indexes) { - auto tensor_t = meta_graph->allTensors.at(index).get(); - auto tensor_shape = tensor_t->dims; - auto tensor_category = lite::TensorCategory(tensor_t->nodeType, tensor_t->dims.size(), TypeId(tensor_t->dataType), - tensor_t->data.size()); + for (auto &data_info : data_infos) { + auto tensor_category = lite::TensorCategory(lite::NodeType(data_info.node_type_), data_info.shape_.size(), + TypeId(data_info.data_type_), data_info.data_.size()); lite::Tensor *tensor = nullptr; - if (tensor_t->dataType != kObjectTypeTensorType) { - tensor = - new (std::nothrow) lite::Tensor(TypeId(tensor_t->dataType), tensor_shape, tensor_t->format, tensor_category); + if (data_info.data_type_ != kObjectTypeTensorType) { + tensor = new (std::nothrow) lite::Tensor(TypeId(data_info.data_type_), data_info.shape_, + (schema::Format)data_info.format_, tensor_category); } else { - tensor = new (std::nothrow) lite::TensorList(tensor_shape, std::vector(), tensor_category); + tensor = new (std::nothrow) lite::TensorList(data_info.shape_, std::vector(), tensor_category); } if (tensor == nullptr) { MS_LOG(ERROR) << "new a lite tensor failed"; return lite::RET_ERROR; } - auto tensor_size = tensor_t->data.size() * sizeof(char); + auto tensor_size = data_info.data_.size(); if (tensor_size > 0) { - if (tensor_t->dataType == kObjectTypeTensorType) { + if (data_info.data_type_ == kObjectTypeTensorType) { auto tensor_list = reinterpret_cast(tensor); - if (tensor_list->Decode(reinterpret_cast(tensor_t->data.data())) != RET_OK) { + if (tensor_list->Decode(reinterpret_cast(data_info.data_.data())) != RET_OK) { MS_LOG(ERROR) << "Decode tensorlist data failed"; return RET_ERROR; } @@ -429,7 +443,7 @@ STATUS NodeInferShape::ConvertToLiteTensor(const std::unique_ptrdata.data(), tensor_size) != EOK) { + if (memcpy_s(tensor_data, tensor_size, data_info.data_.data(), tensor_size) != EOK) { delete tensor; delete[](tensor_data); MS_LOG(ERROR) << "memcpy error: "; diff --git a/mindspore/lite/tools/optimizer/graph/node_infershape.h b/mindspore/lite/tools/optimizer/graph/node_infershape.h index 
856d1da3b1..8879ac5b81 100644
--- a/mindspore/lite/tools/optimizer/graph/node_infershape.h
+++ b/mindspore/lite/tools/optimizer/graph/node_infershape.h
@@ -22,6 +22,7 @@
 #include
 #include "schema/inner/model_generated.h"
 #include "src/tensor.h"
+#include "tools/anf_exporter/fetch_content.h"
 #include "tools/converter/converter_flags.h"
 #include "tools/optimizer/common/format_utils.h"
@@ -44,10 +45,9 @@ class NodeInferShape {
   STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *inputs);
   STATUS GetCNodeConstInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *const_ms_inputs);
   STATUS GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *var_ms_inputs);
-  lite::Tensor *GetCNodeTensorListVarInput(std::vector<int> shape, const abstract::AbstractTensorPtr &abstract_tensor);
+  lite::Tensor *GetCNodeTensorListVarInput(const lite::DataInfo &data_info);
   STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *outputs);
-  STATUS ConvertToLiteTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
-                             const std::vector<uint32_t> &tensor_indexes, std::vector<lite::Tensor *> *tensors);
+  STATUS ConvertToLiteTensor(const std::vector<lite::DataInfo> &data_infos, std::vector<lite::Tensor *> *tensors);
   STATUS SetCNodeAbstract(const std::shared_ptr<CNode> &cnode, const std::vector<lite::Tensor *> &outputs);
   abstract::AbstractBasePtr ConvertLiteTensorToAbstract(lite::Tensor *tensor);
   abstract::AbstractBasePtr ConvertTensorListToAbstract(lite::Tensor *tensor);
diff --git a/mindspore/lite/tools/optimizer/graph/transpose_strategy.cc b/mindspore/lite/tools/optimizer/graph/transpose_strategy.cc
index 3a134d4f55..af3bfdd1a9 100644
--- a/mindspore/lite/tools/optimizer/graph/transpose_strategy.cc
+++ b/mindspore/lite/tools/optimizer/graph/transpose_strategy.cc
@@ -67,7 +67,7 @@ AnfNodePtr TransposeStrategy::TransposePairFuseWhenInsert(const FuncGraphPtr &fu
     MS_LOG(ERROR) << "input node is invalid.";
     return nullptr;
   }
-  if (GetTransposePerm(input_cnode->input(kTransposePerm), &trans_perm) != lite::RET_OK) {
+  if (GetTransposePerm(input_cnode, &trans_perm) != lite::RET_OK) {
    MS_LOG(ERROR) << "transpose perm get failed.";
     return nullptr;
   }
@@ -142,8 +142,40 @@ bool TransposeStrategy::CanFusionIfInsert(const FuncGraphPtr &func_graph, const
   return can_insert;
 }
 
+bool TransposeStrategy::CanChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
+  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
+  auto shape = node_infer_shape_.GetInputShape(cnode, 1);
+  if (shape.size() != 4) {
+    if (cnode->size() > 2) {
+      shape = node_infer_shape_.GetInputShape(cnode, 2);
+      if (shape.size() != 4 && !shape.empty()) {
+        return false;
+      }
+    } else {
+      return false;
+    }
+  }
+  if (CheckPrimitiveType(cnode, prim::kPrimConcat) || CheckPrimitiveType(cnode, prim::kPrimSplit)) {
+    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
+    if (prim->GetAttr(ops::kAxis) == nullptr) {
+      return false;
+    }
+  }
+  if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion) || CheckPrimitiveType(cnode, prim::kPrimStridedSlice)) {
+    for (size_t i = 2; i < cnode->size(); ++i) {
+      if (utils::isa<CNodePtr>(cnode->input(i))) {
+        return false;
+      }
+    }
+    if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice) && cnode->size() != kOnnxStridedSlice) {
+      return false;
+    }
+  }
+  return true;
+}
+
 STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
-  MS_ASSERT(cnode != nullptr);
+  MS_ASSERT(func_graph != nullptr && cnode != nullptr);
   auto shape = node_infer_shape_.GetInputShape(cnode, 1);
   if (shape.size() != 4) {
     if (cnode->size() > 2) {
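
The axis checks above guard ChangeOpAxis, whose core job is remapping a dimension index between layouts. As a rough standalone illustration (RemapAxisNchwToNhwc and kMap are names invented for this sketch, not part of MindSpore):

#include <iostream>

// Map a dimension index in NCHW order to the index of the same dimension in
// NHWC order: N stays 0, C moves from 1 to 3, H from 2 to 1, W from 3 to 2.
int RemapAxisNchwToNhwc(int axis) {
  static const int kMap[4] = {0, 3, 1, 2};
  if (axis < 0) axis += 4;  // normalize negative axes first
  return kMap[axis];
}

int main() {
  std::cout << RemapAxisNchwToNhwc(1) << "\n";   // channel axis: 1 -> 3
  std::cout << RemapAxisNchwToNhwc(-1) << "\n";  // W axis: -1 -> 3 -> 2
  return 0;
}
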
@@ -180,6 +212,7 @@ STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNo
     } else {
       offsets.push_back(0);
     }
+    crop_prim->set_axis(new_axis);
     crop_prim->set_offsets(offsets);
   }
   if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion)) {
@@ -231,7 +264,7 @@ bool TransposeStrategy::IsInOutCanFuison(const FuncGraphPtr &func_graph, const s
   if (cnode == nullptr) {
     return false;
   }
-  if (GetTransposePerm(cnode->input(kTransposePerm), &perm) != lite::RET_OK) {
+  if (GetTransposePerm(cnode, &perm) != lite::RET_OK) {
     return false;
   }
   if (perm == NH2NC) {
diff --git a/mindspore/lite/tools/optimizer/graph/transpose_strategy.h b/mindspore/lite/tools/optimizer/graph/transpose_strategy.h
index b462d1c577..e4d5865507 100644
--- a/mindspore/lite/tools/optimizer/graph/transpose_strategy.h
+++ b/mindspore/lite/tools/optimizer/graph/transpose_strategy.h
@@ -44,6 +44,7 @@ class TransposeStrategy {
   bool CanFusionIfInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_info,
                          TransTypePair *trans_insert_info);
   STATUS ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
+  bool CanChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
 
  private:
   STATUS TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool before, size_t index);
diff --git a/mindspore/lite/tools/optimizer/graph/unify_format_pass.cc b/mindspore/lite/tools/optimizer/graph/unify_format_pass.cc
index be410284f7..dd81a9490e 100644
--- a/mindspore/lite/tools/optimizer/graph/unify_format_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/unify_format_pass.cc
@@ -15,12 +15,14 @@
  */
 #include "tools/optimizer/graph/unify_format_pass.h"
+#include <queue>
+#include <set>
 #include
 #include
 #include "ops/op_utils.h"
 #include "src/common/common.h"
 #include "src/common/utils.h"
-#include "tools/anf_exporter/anf_exporter.h"
+#include "tools/common/tensor_util.h"
 using mindspore::lite::NCHW_SHAPE;
 
 namespace mindspore {
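
Both TransTransFusion and IsInOutCanFuison lean on the fact that the NH2NC and NC2NH permutations are inverses, so a back-to-back pair is a no-op. A compilable sanity check of that identity, assuming the transpose op's usual gather convention out[i] = in[perm[i]] (Compose is a helper written only for this sketch):

#include <cassert>
#include <vector>

// Compose(first, second)[i] applies `first` then `second` under the gather
// convention out[i] = in[perm[i]].
std::vector<int> Compose(const std::vector<int> &first, const std::vector<int> &second) {
  std::vector<int> result(second.size());
  for (size_t i = 0; i < second.size(); ++i) {
    result[i] = first[second[i]];
  }
  return result;
}

int main() {
  const std::vector<int> kNH2NC = {0, 3, 1, 2};  // NHWC -> NCHW
  const std::vector<int> kNC2NH = {0, 2, 3, 1};  // NCHW -> NHWC
  assert((Compose(kNH2NC, kNC2NH) == std::vector<int>{0, 1, 2, 3}));
  assert((Compose(kNC2NH, kNH2NC) == std::vector<int>{0, 1, 2, 3}));
  return 0;
}
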
@@ -37,6 +39,173 @@ bool IsSpecialType(const CNodePtr &cnode) {
   }
   return false;
 }
+
+STATUS FindAreaSurroundedByTranspose(const FuncGraphPtr &func_graph, const CNodePtr &root_node,
+                                     std::set<CNodePtr> *in_nodes, std::set<CNodePtr> *out_nodes,
+                                     std::set<CNodePtr> *middle_nodes) {
+  MS_ASSERT(func_graph != nullptr && root_node != nullptr);
+  MS_ASSERT(in_nodes != nullptr && out_nodes != nullptr && middle_nodes != nullptr);
+  std::queue<CNodePtr> queue_nodes;
+  queue_nodes.push(root_node);
+  std::queue<bool> is_pre_nodes;
+  is_pre_nodes.push(true);
+  while (!queue_nodes.empty()) {
+    auto cur_node = queue_nodes.front();
+    auto is_pre_node = is_pre_nodes.front();
+    queue_nodes.pop();
+    is_pre_nodes.pop();
+    if (CheckPrimitiveType(cur_node, prim::kPrimTranspose)) {
+      if (is_pre_node) {
+        in_nodes->insert(cur_node);
+      } else {
+        out_nodes->insert(cur_node);
+        continue;
+      }
+    }
+    if (middle_nodes->find(cur_node) != middle_nodes->end()) {
+      continue;
+    }
+    if (in_nodes->find(cur_node) == in_nodes->end()) {
+      middle_nodes->insert(cur_node);
+      // insert pre nodes.
+      auto origin_inputs = cur_node->inputs();
+      lite::RemoveIfDepend(cur_node);
+      for (size_t i = 1; i < cur_node->size(); ++i) {
+        if (!utils::isa<CNodePtr>(cur_node->input(i))) {
+          continue;
+        }
+        auto cur_node_input = cur_node->input(i)->cast<CNodePtr>();
+        if (middle_nodes->find(cur_node_input) != middle_nodes->end() ||
+            in_nodes->find(cur_node_input) != in_nodes->end()) {
+          continue;
+        }
+        queue_nodes.push(cur_node_input);
+        is_pre_nodes.push(true);
+      }
+      if (CheckIsAllInputsParam(cur_node)) {
+        in_nodes->insert(cur_node);
+      }
+      cur_node->set_inputs(origin_inputs);
+    }
+    // insert post nodes
+    auto cur_node_users = func_graph->manager()->node_users()[cur_node];
+    for (auto &cur_node_user : cur_node_users) {
+      if (!utils::isa<CNodePtr>(cur_node_user.first)) {
+        MS_LOG(ERROR) << "post node is not cnode.";
+        return lite::RET_ERROR;
+      }
+      auto cur_node_post = cur_node_user.first->cast<CNodePtr>();
+      if (middle_nodes->find(cur_node_post) != middle_nodes->end() ||
+          out_nodes->find(cur_node_post) != out_nodes->end()) {
+        continue;
+      }
+      queue_nodes.push(cur_node_post);
+      is_pre_nodes.push(false);
+    }
+    if (cur_node_users.empty()) {
+      out_nodes->insert(cur_node);
+    }
+  }
+  return lite::RET_OK;
+}
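
The area search is easier to follow on a toy DAG. A minimal sketch of the same paired-queue BFS, with Node and FindArea as stand-ins for the real ANF node classes; membership is re-checked on pop, which keeps the walk finite on a DAG:

#include <queue>
#include <set>
#include <utility>
#include <vector>

struct Node {
  std::vector<Node *> inputs;
  std::vector<Node *> users;
  bool is_transpose = false;
};

// Nodes reached by walking inputs are enqueued with is_pre = true, nodes
// reached by walking users with is_pre = false, so boundary transposes sort
// themselves into the in/out fences of the area.
void FindArea(Node *root, std::set<Node *> *in, std::set<Node *> *out, std::set<Node *> *mid) {
  std::queue<std::pair<Node *, bool>> queue;
  queue.push({root, true});
  while (!queue.empty()) {
    auto [cur, is_pre] = queue.front();
    queue.pop();
    if (cur->is_transpose) {
      if (is_pre) {
        in->insert(cur);
      } else {
        out->insert(cur);
        continue;  // an out-fence transpose ends this branch of the walk
      }
    }
    if (mid->count(cur) != 0) {
      continue;
    }
    if (in->count(cur) == 0) {
      mid->insert(cur);
      for (auto *pre : cur->inputs) {
        queue.push({pre, true});
      }
    }
    for (auto *user : cur->users) {
      queue.push({user, false});
    }
  }
}
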
+
+bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set<CNodePtr> &in_nodes,
+                                 const std::set<CNodePtr> &out_nodes, const std::set<CNodePtr> &middle_nodes) {
+  MS_ASSERT(func_graph != nullptr);
+  for (auto &in_cnode : in_nodes) {
+    std::vector<int> perm;
+    if (!CheckPrimitiveType(in_cnode, prim::kPrimTranspose) || GetTransposePerm(in_cnode, &perm) != lite::RET_OK ||
+        perm != NH2NC) {
+      return false;
+    }
+  }
+  for (auto &out_cnode : out_nodes) {
+    std::vector<int> perm;
+    if (!CheckPrimitiveType(out_cnode, prim::kPrimTranspose) || GetTransposePerm(out_cnode, &perm) != lite::RET_OK ||
+        perm != NC2NH) {
+      return false;
+    }
+  }
+  auto &dynamic_ops = GetDynamicFormatOpList();
+  TransposeStrategy transpose_strategy;
+  for (auto &middle_cnode : middle_nodes) {
+    if (IsSpecialType(middle_cnode)) {
+      continue;
+    }
+    auto middle_node_prim = GetValueNode<PrimitivePtr>(middle_cnode->input(0));
+    if (!lite::IsContain(dynamic_ops, middle_node_prim->name()) ||
+        !transpose_strategy.CanChangeOpAxis(func_graph, middle_cnode)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+void ConvertNcTensor2Nh(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type,
+                        bool train_flag) {
+  MS_ASSERT(cnode != nullptr);
+  if (utils::isa<CNodePtr>(cnode->input(index))) {
+    return;
+  }
+  lite::DataInfo data_info;
+  int status;
+  if (utils::isa<Parameter>(cnode->input(index))) {
+    status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info);
+  } else {
+    status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info);
+  }
+  if (status != lite::RET_OK) {
+    return;
+  }
+  if (data_info.shape_.empty() ||
+      (data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) {
+    return;
+  }
+  std::vector<int> new_shape;
+  if (data_info.shape_.size() == 1) {
+    new_shape = {1, 1, 1, data_info.shape_[0]};
+  } else if (data_info.shape_.size() == 2) {
+    new_shape = {1, 1, data_info.shape_[0], data_info.shape_[1]};
+  } else if (data_info.shape_.size() == 3) {
+    new_shape = {1, data_info.shape_[0], data_info.shape_[1], data_info.shape_[2]};
+  } else {
+    // 4D const data is already laid out as NCHW; keep its shape and only relayout below.
+    new_shape = data_info.shape_;
+  }
+  auto size = data_info.data_.size() / sizeof(float);
+  std::vector<float> new_data(size);
+  auto new_data_ptr = static_cast<float *>(new_data.data());
+  auto nchw_data = reinterpret_cast<float *>(data_info.data_.data());
+  // nchw to nhwc
+  auto batch = new_shape[lite::NCHW_N];
+  auto channel = new_shape[lite::NCHW_C];
+  auto area = new_shape[lite::NCHW_H] * new_shape[lite::NCHW_W];
+  for (auto i = 0; i < batch; i++) {
+    float *src_batch = nchw_data + i * channel * area;
+    float *dst_batch = new_data_ptr + i * channel * area;
+    for (int j = 0; j < area; ++j) {
+      float *src_area = src_batch + j;
+      float *dst_area = dst_batch + j * channel;
+      for (int k = 0; k < channel; ++k) {
+        dst_area[k] = src_area[k * area];
+      }
+    }
+  }
+  auto param_node = func_graph->add_parameter();
+  param_node->set_name(cnode->input(index)->fullname_with_scope());
+  std::vector<int64_t> shape_vec{new_shape[0], new_shape[2], new_shape[3], new_shape[1]};
+  auto tensor_info = lite::CreateTensorInfo(new_data.data(), size * sizeof(float), shape_vec, kNumberTypeFloat32);
+  if (tensor_info == nullptr) {
+    MS_LOG(ERROR) << "Create tensor info failed";
+    return;
+  }
+  status = lite::InitParameterFromTensorInfo(param_node, tensor_info);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "init parameter from tensor info failed";
+    return;
+  }
+  auto tr = func_graph->manager()->Transact();
+  tr.SetEdge(cnode, index, param_node);
+  tr.Commit();
+  return;
+}
 }  // namespace
 
 void UnifyFormatPass::GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info) {
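
The relayout arithmetic in ConvertNcTensor2Nh can be checked in isolation; this miniature (all names local to the example) moves a 1x2x2x2 NCHW buffer to NHWC and asserts the expected interleaving:

#include <cassert>
#include <vector>

// Element (n, c, h, w) moves from offset n*C*HW + c*HW + hw in NCHW to
// offset n*C*HW + hw*C + c in NHWC.
int main() {
  const int N = 1, C = 2, H = 2, W = 2, area = H * W;
  std::vector<float> nchw = {0, 1, 2, 3, 10, 11, 12, 13};  // two 2x2 channel planes
  std::vector<float> nhwc(nchw.size());
  for (int n = 0; n < N; ++n) {
    const float *src_batch = nchw.data() + n * C * area;
    float *dst_batch = nhwc.data() + n * C * area;
    for (int j = 0; j < area; ++j) {  // j walks the spatial positions
      const float *src_area = src_batch + j;
      float *dst_area = dst_batch + j * C;
      for (int k = 0; k < C; ++k) {
        dst_area[k] = src_area[k * area];
      }
    }
  }
  // channels are interleaved per spatial position after the relayout
  assert((nhwc == std::vector<float>{0, 10, 1, 11, 2, 12, 3, 13}));
  return 0;
}
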
@@ -79,7 +248,7 @@ bool UnifyFormatPass::TransTransFusion(const FuncGraphPtr &func_graph, const CNo
     return false;
   }
   std::vector<int> post_perm;
-  if (GetTransposePerm(cnode->input(2), &post_perm) != lite::RET_OK) {
+  if (GetTransposePerm(cnode, &post_perm) != lite::RET_OK) {
     MS_LOG(ERROR) << "get transpose perm failed.";
     return false;
   }
@@ -89,7 +258,7 @@ bool UnifyFormatPass::TransTransFusion(const FuncGraphPtr &func_graph, const CNo
   if (pre_cnode == nullptr) {
     return false;
   }
-  if (GetTransposePerm(pre_cnode->input(2), &pre_perm) != lite::RET_OK) {
+  if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) {
     MS_LOG(ERROR) << "get transpose perm failed.";
     return false;
   }
@@ -106,7 +275,7 @@ STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, cons
     return lite::RET_OK;
   }
   std::vector<int> cur_perm;
-  if (GetTransposePerm(cnode->input(2), &cur_perm) != lite::RET_OK) {
+  if (GetTransposePerm(cnode, &cur_perm) != lite::RET_OK) {
     MS_LOG(ERROR) << "get transpose perm failed.";
     return lite::RET_ERROR;
   }
@@ -116,7 +285,7 @@ STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, cons
   if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) {
     std::vector<int> post_trans_perm;
     auto post_trans_node = post_node->cast<CNodePtr>();
-    if (GetTransposePerm(post_trans_node->input(2), &post_trans_perm) != lite::RET_OK) {
+    if (GetTransposePerm(post_trans_node, &post_trans_perm) != lite::RET_OK) {
       MS_LOG(ERROR) << "get post transpose node perm failed.";
       return lite::RET_ERROR;
     }
@@ -218,7 +387,7 @@ STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const
   MS_ASSERT(trans_insert_info != nullptr);
   TransTypePair trans_info;
   auto origin_inputs = cnode->inputs();
-  lite::AnfExporter::RemoveIfMakeTuple(cnode);
+  lite::RemoveIfMakeTuple(cnode);
   RemoveIfMonad(cnode);
   if (!transpose_strategy_.CanFusionIfInsert(func_graph, cnode, &trans_info, trans_insert_info)) {
     cnode->set_inputs(origin_inputs);
@@ -366,8 +535,8 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN
   prim->AddAttr(kTransDone, MakeValue(true));
   TransTypePair trans_info;
   GetTransNodeFormatType(cnode, &trans_info);
-  if (!need_reset_ && (trans_info.pre_ == kNONE || trans_info.post_ == kNONE)) {
-    if (TransTransFusion(func_graph, cnode)) {
+  if (trans_info.pre_ == kNONE || trans_info.post_ == kNONE) {
+    if (!need_reset_ && TransTransFusion(func_graph, cnode)) {
       return lite::RET_OK;
     }
     std::unordered_map match;
@@ -401,6 +570,65 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN
   return lite::RET_OK;
 }
 
+STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
+                                             std::set<CNodePtr> *visit_transposes) {
+  MS_ASSERT(func_graph != nullptr && cnode != nullptr && visit_transposes != nullptr);
+  auto manager = func_graph->manager();
+  MS_ASSERT(manager != nullptr);
+  std::set<CNodePtr> middle_nodes;
+  std::set<CNodePtr> in_nodes;
+  std::set<CNodePtr> out_nodes;
+  auto status = FindAreaSurroundedByTranspose(func_graph, cnode, &in_nodes, &out_nodes, &middle_nodes);
+  if (status != lite::RET_OK) {
+    MS_LOG(ERROR) << "find an area surrounded by transpose failed.";
+    return status;
+  }
+  for (auto &in_cnode : in_nodes) {
+    if (CheckPrimitiveType(in_cnode, prim::kPrimTranspose)) {
+      visit_transposes->insert(in_cnode);
+    }
+  }
+  if (!JudgeCanOptimizerForMultiOp(func_graph, in_nodes, out_nodes, middle_nodes)) {
+    return lite::RET_NO_CHANGE;
+  }
+  auto node_list = TopoSort(func_graph->get_return());
+  std::vector<CNodePtr> middle_ops_vec;
+  for (auto &node : node_list) {
+    if (!utils::isa<CNodePtr>(node)) {
+      continue;
+    }
+    if (middle_nodes.find(node->cast<CNodePtr>()) != middle_nodes.end()) {
+      middle_ops_vec.push_back(node->cast<CNodePtr>());
+      middle_nodes.erase(node->cast<CNodePtr>());
+    }
+  }
+  for (auto &in_cnode : in_nodes) {
+    manager->Replace(in_cnode, in_cnode->input(1));
+  }
+  for (auto &out_cnode : out_nodes) {
+    manager->Replace(out_cnode, out_cnode->input(1));
+  }
+  for (auto &middle_cnode : middle_ops_vec) {
+    if (IsSpecialType(middle_cnode)) {
+      continue;
+    }
+    for (size_t i = 1; i < middle_cnode->size(); ++i) {
+      ConvertNcTensor2Nh(func_graph, middle_cnode, i, fmk_type_, train_flag_);
+    }
+    status = transpose_strategy_.ChangeOpAxis(func_graph, middle_cnode);
+    if (status != lite::RET_OK) {
+      MS_LOG(ERROR) << "change op attr failed.";
+      return lite::RET_ERROR;
+    }
+    status = node_infer_shape_.InferShape(middle_cnode);
+    if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) {
+      MS_LOG(ERROR) << "infer shape failed.";
+      return lite::RET_ERROR;
+    }
+  }
+  return lite::RET_OK;
+}
+
 void UnifyFormatPass::PreProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                              std::unordered_map *match) {
   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
@@ -482,8 +710,8 @@ void UnifyFormatPass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPt
   MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
   auto return_node = sub_graph->get_return();
   auto origin_input = return_node->inputs();
-  lite::AnfExporter::RemoveIfDepend(return_node);
-  lite::AnfExporter::RemoveIfMakeTuple(return_node);
+  lite::RemoveIfDepend(return_node);
+  lite::RemoveIfMakeTuple(return_node);
   for (size_t i = 1; i < return_node->size(); ++i) {
     if (!CheckPrimitiveType(return_node->input(i), prim::kPrimTranspose)) {
       continue;
     }
@@ -511,8 +739,8 @@ void UnifyFormatPass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraph
   MS_ASSERT(cnode != nullptr && sub_graph != nullptr);
   auto return_node = sub_graph->get_return();
   auto origin_inputs = return_node->inputs();
-  lite::AnfExporter::RemoveIfDepend(return_node);
-  lite::AnfExporter::RemoveIfMakeTuple(return_node);
+  lite::RemoveIfDepend(return_node);
+  lite::RemoveIfMakeTuple(return_node);
   AbstractBasePtrList abstract_list;
   bool infer_done = true;
   for (size_t i = 1; i < return_node->size(); ++i) {
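
One detail in HandleGraphMultiNode worth spelling out: the middle ops are discovered by BFS but rewritten in TopoSort order, so every node is transformed after its producers. A small sketch of that ordering step (TopoFilter is a name invented here):

#include <iostream>
#include <set>
#include <vector>

// Keep only the members of `middle` while preserving the graph's topological
// order, mirroring the node_list/middle_nodes loop above.
std::vector<int> TopoFilter(const std::vector<int> &topo_order, std::set<int> middle) {
  std::vector<int> ordered;
  for (int node : topo_order) {
    if (middle.erase(node) > 0) {
      ordered.push_back(node);  // position comes from topo_order, not the set
    }
  }
  return ordered;
}

int main() {
  for (int n : TopoFilter({1, 2, 3, 4, 5}, {4, 2})) {
    std::cout << n << ' ';  // prints "2 4": graph order, not set order
  }
  return 0;
}
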
@@ -679,6 +907,49 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap
   return true;
 }
 
+bool UnifyFormatPass::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph) {
+  MS_ASSERT(func_graph != nullptr);
+  auto manager = Manage(func_graph, true);
+  if (manager == nullptr) {
+    MS_LOG(ERROR) << "manager is nullptr.";
+    return false;
+  }
+  auto node_list = TopoSort(func_graph->get_return());
+  std::set<CNodePtr> visit_transposes;
+  for (auto &node : node_list) {
+    if (!utils::isa<CNodePtr>(node)) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    if (IsSpecialType(cnode) || visit_transposes.find(cnode) != visit_transposes.end()) {
+      continue;
+    }
+    if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
+      auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
+      if (sub_func_graph == nullptr) {
+        return false;
+      }
+      (void)DecreaseTransposeForMultiOp(sub_func_graph);
+      sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2));
+      if (sub_func_graph == nullptr) {
+        return false;
+      }
+      (void)DecreaseTransposeForMultiOp(sub_func_graph);
+    }
+    std::vector<int> perm;
+    if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK ||
+        perm != NH2NC) {
+      continue;
+    }
+    auto status = HandleGraphMultiNode(func_graph, cnode, &visit_transposes);
+    if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
+      MS_LOG(ERROR) << "global optimizer failed.";
+      return false;
+    }
+  }
+  return true;
+}
+
 bool UnifyFormatPass::ResetFuncGraph(const FuncGraphPtr &func_graph) {
   MS_ASSERT(func_graph != nullptr);
   auto manager = Manage(func_graph, true);
@@ -774,11 +1045,17 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) {
     MS_LOG(ERROR) << "run framework transpose unify failed.";
     return false;
   }
-  // if input's format of a certain op can be NHWC, can try transform this op to decrease the number of transpose op.
+  // if the input format of a certain op can be NHWC, try transforming this op to decrease the number of transpose ops.
   if (!DecreaseTransposeForSingleOp(func_graph)) {
     MS_LOG(ERROR) << "run local trans insert optimizer failed.";
     return false;
   }
+  // if the input formats of several ops surrounded only by transpose ops can all be NHWC,
+  // these transpose ops can be deleted and the middle ops transformed at the same time.
+ if (!DecreaseTransposeForMultiOp(func_graph)) { + MS_LOG(ERROR) << "run global trans insert optimizer failed."; + return false; + } return true; } } // namespace opt diff --git a/mindspore/lite/tools/optimizer/graph/unify_format_pass.h b/mindspore/lite/tools/optimizer/graph/unify_format_pass.h index 54f03f41eb..450090ee00 100644 --- a/mindspore/lite/tools/optimizer/graph/unify_format_pass.h +++ b/mindspore/lite/tools/optimizer/graph/unify_format_pass.h @@ -48,6 +48,7 @@ class UnifyFormatPass : public Pass { bool ResetFuncGraph(const FuncGraphPtr &func_graph); bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph); bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph); + bool DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph); bool TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode); STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode); STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector perm, bool before, @@ -55,6 +56,8 @@ class UnifyFormatPass : public Pass { void GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info); STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode); + STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, + std::set *visit_transposes); STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector &perm); STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info); STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector &perm);
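
Taken together, Run() now chains three stages and stops at the first failure. A skeletal sketch of that control flow, with stub stages named after the log messages rather than the real pass internals:

#include <iostream>

// Each stub stands in for one pass over the graph; returning false aborts Run.
bool UnifyFrameworkTranspose() { return true; }       // per-framework format unification
bool DecreaseTransposeForSingleOp() { return true; }  // local, one op at a time
bool DecreaseTransposeForMultiOp() { return true; }   // global, transpose-fenced areas

bool Run() {
  if (!UnifyFrameworkTranspose()) {
    std::cerr << "run framework transpose unify failed.\n";
    return false;
  }
  if (!DecreaseTransposeForSingleOp()) {
    std::cerr << "run local trans insert optimizer failed.\n";
    return false;
  }
  if (!DecreaseTransposeForMultiOp()) {
    std::cerr << "run global trans insert optimizer failed.\n";
    return false;
  }
  return true;
}

int main() { return Run() ? 0 : 1; }
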