@@ -34,7 +34,6 @@
#include "tools/converter/quant_param_holder.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/quantizer/bitpacking.h"
#include "src/tensor.h"
#include "src/common/utils.h"
#include "ops/op_utils.h"
#include "tools/common/graph_util.h"
@@ -80,159 +79,8 @@ std::list<CNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
  }
  return cnodes;
}
STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, ShapeVector *shape_vector, size_t *offset) {
  auto data_type = tensor_info->data_type();
  if (data_type != kObjectTypeString) {
    MS_LOG(ERROR) << "This function is only used for string tensors.";
    return RET_ERROR;
  }
  shape_vector->clear();
  auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
  std::string shape_str;
  std::string shape_size_str;
  *offset = 0;
  size_t cnt = 0;
  for (; *offset < tensor_info->Size(); (*offset)++) {
    if (tensor_data[*offset] == ',') {
      (*offset)++;
      break;
    }
    shape_size_str.push_back(tensor_data[*offset]);
  }
  if (*offset == 0) {
    MS_LOG(ERROR) << "string tensor's dim size not found.";
    return RET_ERROR;
  }
  size_t shape_size = std::stoi(shape_size_str);
  for (; *offset < tensor_info->Size(); (*offset)++) {
    if (tensor_data[*offset] == ',') {
      cnt++;
      shape_vector->push_back(std::stoi(shape_str));
      shape_str.clear();
    } else {
      shape_str.push_back(tensor_data[*offset]);
    }
    if (cnt == shape_size) {
      (*offset)++;
      break;
    }
  }
  if (shape_vector->empty()) {
    MS_LOG(ERROR) << "string tensor's shape shouldn't be empty.";
    return RET_ERROR;
  }
  return RET_OK;
}
schema::Format GetFormatByFmk(int32_t fmk_type) {
  switch (fmk_type) {
    case converter::FmkType_ONNX:
    case lite::converter::FmkType_CAFFE:
    case lite::converter::FmkType_MS:
      return schema::Format_NCHW;
    case lite::converter::FmkType_TF:
    case lite::converter::FmkType_TFLITE:
      return schema::Format_NHWC;
    default:
      MS_LOG(ERROR) << "don't support current fmk: " << fmk_type;
      return static_cast<schema::Format>(fmk_type);
  }
}
STATUS GetDataTypeAndShape(const ParameterPtr &param_node, TypeId *data_type, ShapeVector *shape_vector) {
  auto abstract_base = param_node->abstract();
  if (abstract_base == nullptr) {
    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
    return RET_PARAM_INVALID;
  }
  if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
    MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
    return RET_INPUT_TENSOR_ERROR;
  }
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
  auto typePtr = abstract_tensor->element()->GetTypeTrack();
  MS_ASSERT(typePtr != nullptr);
  *data_type = typePtr->type_id();
  if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
    MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
    return RET_PARAM_INVALID;
  }
  *shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
  return RET_OK;
}
}  // namespace
void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) {
  bool has_make_tuple = false;
  std::vector<AnfNodePtr> inputs;
  inputs.clear();
  inputs.emplace_back(cnode->input(0));
  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
    AnfNodePtr input_node = cnode->input(i);
    if (!input_node->isa<CNode>()) {
      inputs.emplace_back(cnode->input(i));
      continue;
    }
    auto make_tuple_node = utils::cast<CNodePtr>(input_node);
    auto value_node = make_tuple_node->input(0)->cast<ValueNodePtr>();
    if (value_node == nullptr) {
      MS_LOG(ERROR) << "value node is invalid.";
      return;
    }
    if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, prim::kPrimMakeTuple) ||
                                           opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) {
      has_make_tuple = true;
      for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) {
        inputs.emplace_back(make_tuple_node->input(j));
      }
    } else {
      inputs.emplace_back(cnode->input(i));
    }
  }
  if (has_make_tuple) {
    cnode->set_inputs(inputs);
  }
}
void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) {
  bool has_depend = false;
  std::vector<AnfNodePtr> inputs;
  inputs.clear();
  inputs.emplace_back(cnode->input(0));
  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
    AnfNodePtr inputNode = cnode->input(i);
    if (!inputNode->isa<CNode>()) {
      inputs.emplace_back(cnode->input(i));
      continue;
    }
    auto depend_node = utils::cast<CNodePtr>(inputNode);
    auto value_node = depend_node->input(0)->cast<ValueNodePtr>();
    if (value_node == nullptr) {
      MS_LOG(ERROR) << "value node is invalid.";
      return;
    }
    if (value_node->value() != nullptr && opt::CheckPrimitiveType(depend_node, prim::kPrimDepend)) {
      has_depend = true;
      bool mask_out = (depend_node->inputs().size() == 3);
      for (size_t j = 1; j < depend_node->inputs().size(); ++j) {
        AnfNodePtr depend_input_node = depend_node->input(j);
        if (depend_input_node->isa<CNode>()) {
          inputs.emplace_back(depend_input_node);
          if (mask_out) {
            break;
          }
        }
      }
    } else {
      inputs.emplace_back(cnode->input(i));
    }
  }
  if (has_depend) {
    cnode->set_inputs(inputs);
  }
}
int AnfExporter::SetQuantOutputTensorType(const std::unique_ptr<schema::MetaGraphT> &meta_graph,
                                          const std::shared_ptr<mindspore::Primitive> &primitive,
                                          const std::unique_ptr<schema::CNodeT> &dst_node) {
@@ -653,283 +501,58 @@ int AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode,
  return RET_OK;
}
int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> &input_anode,
                                       const std::shared_ptr<PrimitiveC> &primitive_c,
int AnfExporter::ConvertInputParameter(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive,
                                       const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                                       schema::CNodeT *output_cnode) {
  auto param_node = input_anode->cast<ParameterPtr>();
                                       schema::CNodeT *op_node) {
  auto param_node = cnode->input(index)->cast<ParameterPtr>();
  MS_ASSERT(param_node != nullptr);
  std::string input_name = param_node->fullname_with_scope();
  if (node_id_map_.find(input_name) != node_id_map_.end()) {
    output_cnode->inputIndex.emplace_back(node_id_map_[param_node->name()]);
    op_node->inputIndex.emplace_back(node_id_map_[param_node->name()]);
    return RET_OK;
  }
  auto schema_tensor = std::make_unique<schema::TensorT>();
  schema_tensor->format = GetFormatByFmk(meta_graphT->fmkType);
  if (schema_tensor->format != schema::Format_NHWC && schema_tensor->format != schema::Format_NCHW) {
    MS_LOG(ERROR) << "schema tensor format is wrong, " << schema_tensor->format;
  DataInfo data_info;
  if (FetchDataFromParameterNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info) !=
      RET_OK) {
    MS_LOG(ERROR) << "parse const node failed.";
    return RET_ERROR;
  }
  // attr weightFormat is only used by conv-like ops' second input
  if (output_cnode->inputIndex.size() == 1 && primitive_c->GetAttr(opt::kWeightFormat) != nullptr) {
    schema_tensor->format = static_cast<schema::Format>(GetValue<int64_t>(primitive_c->GetAttr(opt::kWeightFormat)));
  }
  auto schema_tensor = std::make_unique<schema::TensorT>();
  schema_tensor->format = static_cast<schema::Format>(data_info.format_);
  schema_tensor->name = param_node->name();
  ShapeVector shape_vector;
  TypeId data_type;
  auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "get data type and shape from param node failed.";
    return RET_ERROR;
  }
  schema_tensor->dataType = data_type;
  auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
  size_t offset = 0;
  if (!shape_vector.empty() && schema_tensor->dataType == kObjectTypeString) {
    status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "get shape vector from string tensor failed.";
      return RET_ERROR;
    }
  }
  std::vector<int32_t> dims;
  (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims),
                       [](const int64_t &value) { return static_cast<int32_t>(value); });
  schema_tensor->dims = dims;
  if (tensor_info != nullptr && tensor_info->Size() != 0) {
    if (schema_tensor->dataType == kObjectTypeTensorType && shape_vector.empty() &&
        meta_graphT->fmkType == converter::FmkType_ONNX) {
      schema_tensor->data.resize(0);
    } else {
      schema_tensor->data.resize(tensor_info->Size() - offset);
      if (EOK != memcpy_s(schema_tensor->data.data(), schema_tensor->data.size(),
                          static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
        MS_LOG(ERROR) << "memcpy_s failed.";
        return RET_ERROR;
      }
    }
  }
  schema_tensor->name = input_name;
  QuantParamHolderPtr quant_param_holder = primitive_c->GetAttr("quant_params") == nullptr
                                             ? nullptr
                                             : primitive_c->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
  if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() &&
      schema_tensor->dataType == kNumberTypeInt8) {
    schema_tensor->enableHuffmanCode = true;
  }
  schema_tensor->dims = data_info.shape_;
  schema_tensor->dataType = data_info.data_type_;
  schema_tensor->data = data_info.data_;
  schema_tensor->enableHuffmanCode = data_info.enable_huffman_code_;
  node_id_map_[input_name] = meta_graphT->allTensors.size();
  output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
  op_node->inputIndex.emplace_back(meta_graphT->allTensors.size());
  meta_graphT->allTensors.emplace_back(std::move(schema_tensor));
  return RET_OK;
}
int AnfExporter::ProcessTensor(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                               const std::shared_ptr<Value> &value, const std::shared_ptr<PrimitiveC> &primitive,
                               schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
  auto valueAbstract = value_node->abstract();
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
  if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
    MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
    return RET_ERROR;
  }
  auto typePtr = abstract_tensor->element()->GetTypeTrack();
  (*schema_tensor)->dataType = typePtr->type_id();
  auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
  std::vector<int32_t> dims;
  (void)std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(dims),
                       [](const int64_t &value) { return static_cast<int32_t>(value); });
  (*schema_tensor)->dims = dims;
  if (train_flag_ && (*schema_tensor)->dims.empty()) (*schema_tensor)->dims = {1};
  (*schema_tensor)->nodeType = NodeType_ValueNode;
  auto data = value->cast<tensor::TensorPtr>();
  (*schema_tensor)->data.resize(data->Size());
  (*schema_tensor)->format = schema::Format_NHWC;
  (*schema_tensor)->format = GetFormatByFmk(meta_graphT->fmkType);
  if ((*schema_tensor)->format != schema::Format_NHWC && (*schema_tensor)->format != schema::Format_NCHW) {
    MS_LOG(ERROR) << "schema tensor format is wrong, " << (*schema_tensor)->format;
    return RET_ERROR;
  }
  // process weight tensor
  if (data->Size() > 0) {
    if (memcpy_s((*schema_tensor)->data.data(), (*schema_tensor)->data.size(), data->data_c(), data->Size()) != EOK) {
      MS_LOG(ERROR) << "memcpy_s error.";
      return RET_ERROR;
    }
    if (primitive->GetAttr(opt::kWeightFormat) != nullptr) {
      (*schema_tensor)->format = static_cast<schema::Format>(GetValue<int64_t>(primitive->GetAttr(opt::kWeightFormat)));
    }
  }
  node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
  output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
  meta_graphT->allTensors.emplace_back(std::move(*schema_tensor));
  return RET_OK;
}
int AnfExporter::ProcessInt32OrInt64Imm(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                                        const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
                                        const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
  int ret;
  // data of int64 is converted to int32 here.
  (*schema_tensor)->dataType = kNumberTypeInt32;
  (*schema_tensor)->dims = {1};
  (*schema_tensor)->nodeType = NodeType_ValueNode;
  int real_data = opt::CastToInt(value).front();
  (*schema_tensor)->data.resize(sizeof(int32_t));
  ret = memcpy_s((*schema_tensor)->data.data(), sizeof(int32_t), &real_data, sizeof(int32_t));
  if (ret != EOK) {
    MS_LOG(ERROR) << "memcpy_s error.";
    return RET_ERROR;
  }
  node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
  output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
  meta_graphT->allTensors.emplace_back(std::move(*schema_tensor));
  return ret;
}
void AnfExporter::ProcessBoolImm(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                                 const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
                                 const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
  auto valueAbstract = value_node->abstract();
  auto abstractScalar = utils::cast<abstract::AbstractScalarPtr>(valueAbstract);
  auto typePtr = abstractScalar->GetTypeTrack();
  (*schema_tensor)->dataType = typePtr->type_id();
  (*schema_tensor)->dims = {1};
  (*schema_tensor)->nodeType = NodeType_ValueNode;
  auto data = value->cast<mindspore::BoolImmPtr>();
  (*schema_tensor)->data.emplace_back(data->value());
  node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
  output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
  meta_graphT->allTensors.emplace_back(std::move(*schema_tensor));
}
int AnfExporter::ProcessNumber(const ValueNodePtr &value_node, schema::TensorT *schema_tensor,
                               schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
  auto data = value_node->value()->cast<NumberPtr>();
  schema_tensor->data.resize(sizeof(int));
  int number_type = data->number_type();
  if (EOK != ::memcpy_s(schema_tensor->data.data(), sizeof(int), &number_type, sizeof(int))) {
    MS_LOG(ERROR) << "memcpy_s failed";
    return RET_MEMORY_FAILED;
  }
  schema_tensor->dataType = kNumberTypeInt32;
  schema_tensor->dims = {1};
  schema_tensor->nodeType = NodeType_ValueNode;
  node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
  output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
  meta_graphT->allTensors.emplace_back(schema_tensor);
  return RET_OK;
}
void AnfExporter::ProcessInt(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                             schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
  (*schema_tensor)->dataType = kNumberTypeInt32;
  (*schema_tensor)->dims = {1};
  (*schema_tensor)->nodeType = NodeType_ValueNode;
  (*schema_tensor)->data.emplace_back(kNumberTypeInt32);
  node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
  output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
  meta_graphT->allTensors.emplace_back(std::move(*schema_tensor));
}
int AnfExporter::ProcessValueSequence(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                                      const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
                                      const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
  int ret = RET_OK;
  auto valueAbstract = value_node->abstract();
  auto abstract_sequence = utils::cast<abstract::AbstractSequeuePtr>(valueAbstract);
  if (abstract_sequence->isa<abstract::AbstractTuple>()) {
    auto abstractTuple = utils::cast<abstract::AbstractTuplePtr>(valueAbstract);
    auto x_shape_data = abstractTuple->elements();
    std::vector<int32_t> shape;
    for (std::size_t i = 0; i < abstractTuple->size(); ++i) {
      auto value_track = x_shape_data[i]->GetValueTrack();
      MS_ASSERT(value_track != nullptr);
      if (value_track->isa<Int32Imm>()) {
        shape.push_back((GetValue<int>(value_track)));
      } else if (value_track->isa<Int64Imm>()) {
        shape.push_back((GetValue<int64_t>(value_track)));
      } else {
        MS_LOG(ERROR) << "Value type in ValueSequence is not an integer, it is " << value_track->ToString() << ".";
        return RET_ERROR;
      }
    }
    (*schema_tensor)->dataType = kNumberTypeInt32;
    (*schema_tensor)->dims = {static_cast<int32_t>(shape.size())};
    (*schema_tensor)->nodeType = NodeType_ValueNode;
    (*schema_tensor)->data.resize(shape.size() * sizeof(int));
    if (!shape.empty()) {
      if (EOK != memcpy_s((*schema_tensor)->data.data(), shape.size() * sizeof(int32_t), shape.data(),
                          shape.size() * sizeof(int32_t))) {
        MS_LOG(ERROR) << "memcpy_s data into schema_tensor failed.";
        return RET_MEMORY_FAILED;
      }
    }
    node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
    output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
    meta_graphT->allTensors.emplace_back(std::move(*schema_tensor));
  }
  return ret;
}
int AnfExporter::ProcessTensorInfo(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                                   const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
                                   const std::unique_ptr<schema::MetaGraphT> &meta_graphT) {
  auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(value);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Input value is not a tensor";
    return RET_INPUT_PARAM_INVALID;
  }
  auto ret = UpdateTensorTFromTensorInfo(tensor_info, schema_tensor);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "UpdateTensorTFromTensorInfo failed";
    return ret;
  }
  if (train_flag_ && (*schema_tensor)->dims.empty()) {
    (*schema_tensor)->dims = {1};
  }
  node_id_map_[value_node->fullname_with_scope()] = meta_graphT->allTensors.size();
  output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size());
  meta_graphT->allTensors.emplace_back(std::move(*schema_tensor));
  return ret;
}
int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode,
                                       const std::shared_ptr<PrimitiveC> &primitive,
int AnfExporter::ConvertInputValueNode(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive,
                                       const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                                       schema::CNodeT *output_cnode) {
  auto value_node = input_anode->cast<ValueNodePtr>();
  auto schema_tensor = std::make_unique<schema::TensorT>();
  auto value = value_node->value();
  int ret = RET_OK;
  if (train_flag_) {
    schema_tensor->name = value_node->fullname_with_scope();
  }
  if (value->isa<tensor::Tensor>()) {
    ret = ProcessTensor(value_node, &schema_tensor, value, primitive, output_cnode, meta_graphT);
  } else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
    ret = ProcessInt32OrInt64Imm(value_node, &schema_tensor, value, output_cnode, meta_graphT);
  } else if (value->isa<mindspore::BoolImm>()) {
    ProcessBoolImm(value_node, &schema_tensor, value, output_cnode, meta_graphT);
  } else if (value->isa<mindspore::Int>()) {
    ProcessInt(value_node, &schema_tensor, output_cnode, meta_graphT);
  } else if (value->isa<mindspore::ValueSequeue>()) {
    ret = ProcessValueSequence(value_node, &schema_tensor, value, output_cnode, meta_graphT);
  } else if (value->isa<Number>()) {
    ret = ProcessNumber(value_node, schema_tensor.release(), output_cnode, meta_graphT);
  } else if (value->isa<mindspore::tensor::Tensor>()) {
    ret = ProcessTensorInfo(value_node, &schema_tensor, value, output_cnode, meta_graphT);
  } else if (value->isa<FuncGraph>()) {
    MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is func_graph";
                                       schema::CNodeT *op_node) {
  DataInfo data_info;
  auto status = FetchDataFromValueNode(cnode, index, converter::FmkType(meta_graphT->fmkType), train_flag_, &data_info);
  if (status == RET_NO_CHANGE) {
    return RET_OK;
  } else if (value->isa<Monad>()) {
    MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is Monad";
    return RET_OK;
  } else {
    MS_LOG(ERROR) << "Unsupported value type, need to add support.";
    return RET_ERROR;
  }
  return ret;
  if (status != RET_OK) {
    MS_LOG(ERROR) << "parse value node failed.";
    return status;
  }
  auto schema_tensor = std::make_unique<schema::TensorT>();
  schema_tensor->name = cnode->input(index)->fullname_with_scope();
  schema_tensor->format = static_cast<schema::Format>(data_info.format_);
  schema_tensor->dataType = data_info.data_type_;
  schema_tensor->dims = data_info.shape_;
  schema_tensor->data = data_info.data_;
  node_id_map_[cnode->input(index)->fullname_with_scope()] = meta_graphT->allTensors.size();
  op_node->inputIndex.emplace_back(meta_graphT->allTensors.size());
  meta_graphT->allTensors.emplace_back(std::move(schema_tensor));
  return RET_OK;
}
int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
@@ -954,7 +577,7 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<sch
        return ret;
      }
    } else if (input_node->isa<Parameter>()) {
      auto ret = ConvertInputParameter(input_node, primitive_c, meta_graphT, fb_node);
      auto ret = ConvertInputParameter(cnode, i, primitive_c, meta_graphT, fb_node);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "ConvertInputParameter failed";
        return ret;
@@ -963,7 +586,7 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<sch
        is_graph_input = true;
      }
    } else if (input_node->isa<ValueNode>()) {
      auto ret = ConvertInputValueNode(input_node, primitive_c, meta_graphT, fb_node);
      auto ret = ConvertInputValueNode(cnode, i, primitive_c, meta_graphT, fb_node);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "ConvertInputValueNode failed";
        return RET_ERROR;
@@ -24,8 +24,10 @@
#include "schema/inner/model_generated.h"
#include "ops/primitive_c.h"
#include "ir/func_graph.h"
#include "tools/anf_exporter/fetch_content.h"
#include "tools/converter/converter_context.h"
#include "tools/converter/converter_flags.h"
#include "tools/optimizer/common/gllo_utils.h"
using mindspore::ops::PrimitiveC;
@@ -44,35 +46,14 @@ class AnfExporter {
                    schema::CNodeT *fb_node);
  int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT,
                     schema::CNodeT *fb_node);
  static void RemoveIfMakeTuple(const CNodePtr &cnode);
  static void RemoveIfDepend(const CNodePtr &cnode);
 protected:
  int ConvertInputCNode(const std::shared_ptr<AnfNode> &input_anode, schema::CNodeT *output_cnode);
  int ConvertInputCNodeCommonOp(const AnfNodePtr &input_anode, schema::CNodeT *output_cnode);
  int ConvertInputParameter(const std::shared_ptr<AnfNode> &input_anode, const std::shared_ptr<PrimitiveC> &primitive,
                            const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
  int ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_anode, const std::shared_ptr<PrimitiveC> &primitive,
                            const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode);
  int ProcessTensor(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                    const std::shared_ptr<Value> &value, const std::shared_ptr<PrimitiveC> &primitive,
                    schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
  int ProcessInt32OrInt64Imm(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                             const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
                             const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
  void ProcessBoolImm(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                      const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
                      const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
  void ProcessInt(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                  schema::CNodeT *output_cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
  int ProcessNumber(const ValueNodePtr &value_node, schema::TensorT *schema_tensor, schema::CNodeT *output_cnode,
                    const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
  int ProcessValueSequence(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                           const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
                           const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
  int ProcessTensorInfo(const ValueNodePtr &value_node, std::unique_ptr<schema::TensorT> *schema_tensor,
                        const std::shared_ptr<Value> &value, schema::CNodeT *output_cnode,
                        const std::unique_ptr<schema::MetaGraphT> &meta_graphT);
  int ConvertInputParameter(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive,
                            const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *op_node);
  int ConvertInputValueNode(const CNodePtr &cnode, size_t index, const PrimitivePtr &primitive,
                            const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *op_node);
  int SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const size_t &subgraph_index);
  int SetGraphoutputIndex(const CNodePtr &cnode, size_t subgraph_index,
                          const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *return_node);
@@ -0,0 +1,437 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tools/anf_exporter/fetch_content.h"
#include <algorithm>
#include <string>
#include <vector>
#include <unordered_map>
#include "tools/converter/quant_param_holder.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace lite {
namespace {
constexpr size_t kTensorListMinSize = 3 * sizeof(int32_t);
static const std::unordered_map<int, int> TypeToTypeMap = {
  {kNumberTypeInt, kNumberTypeInt32}, {kNumberTypeUInt, kNumberTypeUInt32}, {kNumberTypeFloat, kNumberTypeFloat32}};
STATUS GetShapeVectorFromStringTensor(const tensor::TensorPtr &tensor_info, ShapeVector *shape_vector, size_t *offset) {
  MS_ASSERT(tensor_info != nullptr && shape_vector != nullptr && offset != nullptr);
  auto data_type = tensor_info->data_type();
  if (data_type != kObjectTypeString) {
    MS_LOG(ERROR) << "This function is only used for string tensors.";
    return RET_ERROR;
  }
  shape_vector->clear();
  auto tensor_data = reinterpret_cast<uint8_t *>(tensor_info->data_c());
  std::string shape_str;
  std::string shape_size_str;
  *offset = 0;
  size_t cnt = 0;
  for (; *offset < tensor_info->Size(); (*offset)++) {
    if (tensor_data[*offset] == ',') {
      (*offset)++;
      break;
    }
    shape_size_str.push_back(tensor_data[*offset]);
  }
  if (*offset == 0) {
    MS_LOG(ERROR) << "string tensor's dim size not found.";
    return RET_ERROR;
  }
  size_t shape_size = std::stoi(shape_size_str);
  for (; *offset < tensor_info->Size(); (*offset)++) {
    if (tensor_data[*offset] == ',') {
      cnt++;
      shape_vector->push_back(std::stoi(shape_str));
      shape_str.clear();
    } else {
      shape_str.push_back(tensor_data[*offset]);
    }
    if (cnt == shape_size) {
      (*offset)++;
      break;
    }
  }
  if (shape_vector->empty()) {
    MS_LOG(ERROR) << "string tensor's shape shouldn't be empty.";
    return RET_ERROR;
  }
  return RET_OK;
}
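// Illustrative note (not part of the patch): the buffer layout assumed by
// GetShapeVectorFromStringTensor is an ASCII header "<dim_count>,<dim_0>,...,<dim_n-1>,"
// followed by the raw string payload. For example, a buffer beginning with
//   "2,3,4,<payload>"
// yields shape_vector = {3, 4} and leaves *offset pointing at the first payload byte.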
int GetFormatByFmk(int32_t fmk_type) {
  switch (fmk_type) {
    case converter::FmkType_ONNX:
    case lite::converter::FmkType_CAFFE:
    case lite::converter::FmkType_MS:
      return mindspore::NCHW;
    case lite::converter::FmkType_TF:
    case lite::converter::FmkType_TFLITE:
      return mindspore::NHWC;
    default:
      return -1;
  }
}
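// Note: unlike the removed exporter-side helper, this version returns -1 for an
// unsupported framework instead of logging, so callers such as
// FetchDataFromParameterNode below can report the error with their own context.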
STATUS GetDataTypeAndShape(const ParameterPtr &param_node, TypeId *data_type, ShapeVector *shape_vector) {
  MS_ASSERT(param_node != nullptr && data_type != nullptr && shape_vector != nullptr);
  auto abstract_base = param_node->abstract();
  if (abstract_base == nullptr) {
    MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
    return RET_PARAM_INVALID;
  }
  if (!utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
    MS_LOG(ERROR) << "Abstract of parameter should be abstract tensor, " << param_node->name();
    return RET_INPUT_TENSOR_ERROR;
  }
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
  auto typePtr = abstract_tensor->element()->GetTypeTrack();
  MS_ASSERT(typePtr != nullptr);
  *data_type = typePtr->type_id();
  if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
    MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << param_node->name();
    return RET_PARAM_INVALID;
  }
  *shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
  return RET_OK;
}
int FetchFromTensorValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, converter::FmkType fmk_type,
                         bool train_flag, DataInfo *data_info) {
  MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
  auto valueAbstract = value_node->abstract();
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract);
  if (abstract_tensor == nullptr || abstract_tensor->element() == nullptr) {
    MS_LOG(ERROR) << "abstract_tensor or abstract_tensor->element() is nullptr";
    return RET_ERROR;
  }
  auto typePtr = abstract_tensor->element()->GetTypeTrack();
  data_info->data_type_ = typePtr->type_id();
  auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
  std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
  data_info->shape_ = dims;
  if (train_flag && dims.empty()) {
    data_info->shape_ = {1};
  }
  auto value = value_node->value();
  MS_ASSERT(value != nullptr);
  auto data = value->cast<tensor::TensorPtr>();
  data_info->data_.resize(data->Size());
  data_info->format_ = GetFormatByFmk(fmk_type);
  if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
    MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
    return RET_ERROR;
  }
  // process weight tensor
  if (data->Size() > 0 && memcpy_s(data_info->data_.data(), data->Size(), data->data_c(), data->Size()) != EOK) {
    MS_LOG(ERROR) << "memcpy_s error.";
    return RET_ERROR;
  }
  return RET_OK;
}
int FetchFromInt32OrInt64ImmValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
  MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
  // data of int64 is converted to int32 here.
  data_info->data_type_ = kNumberTypeInt32;
  data_info->shape_ = {1};
  data_info->data_.resize(sizeof(int32_t));
  auto value = value_node->value();
  MS_ASSERT(value != nullptr);
  int real_data = opt::CastToInt(value).front();
  if (memcpy_s(data_info->data_.data(), sizeof(int32_t), &real_data, sizeof(int32_t)) != EOK) {
    MS_LOG(ERROR) << "memcpy_s failed";
    return RET_MEMORY_FAILED;
  }
  return RET_OK;
}
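// Illustrative note (not part of the patch): the int64 -> int32 narrowing above
// means an Int64Imm such as 5 is stored as the 4-byte payload of the int32
// value 5 with shape {1}; values outside the int32 range would be truncated
// by opt::CastToInt.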
int FetchFromBoolImmValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
  MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
  data_info->data_type_ = kNumberTypeBool;
  data_info->shape_ = {1};
  data_info->data_.resize(sizeof(bool));
  auto value = value_node->value();
  MS_ASSERT(value != nullptr);
  auto data = value->cast<mindspore::BoolImmPtr>();
  auto data_value = data->value();
  if (memcpy_s(data_info->data_.data(), sizeof(bool), &data_value, sizeof(bool)) != EOK) {
    MS_LOG(ERROR) << "memcpy_s failed";
    return RET_MEMORY_FAILED;
  }
  return RET_OK;
}
int FetchFromNumberValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
  MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
  data_info->data_type_ = kNumberTypeInt32;
  data_info->shape_ = {1};
  data_info->data_.resize(sizeof(int));
  auto data = value_node->value()->cast<NumberPtr>();
  int number_type = data->number_type();
  if (TypeToTypeMap.find(number_type) != TypeToTypeMap.end()) {
    number_type = TypeToTypeMap.at(number_type);
  }
  if (memcpy_s(data_info->data_.data(), sizeof(int), &number_type, sizeof(int)) != EOK) {
    MS_LOG(ERROR) << "memcpy_s failed";
    return RET_MEMORY_FAILED;
  }
  return RET_OK;
}
int FetchFromSequenceValue(const ValueNodePtr &value_node, const PrimitivePtr &primitive, DataInfo *data_info) {
  MS_ASSERT(value_node != nullptr && primitive != nullptr && data_info != nullptr);
  auto value = value_node->value();
  MS_ASSERT(value != nullptr);
  std::vector<int32_t> shape;
  auto value_seq = value->cast<ValueSequeuePtr>();
  MS_ASSERT(value_seq != nullptr);
  if (!value_seq->value().empty()) {
    if (value_seq->value().front()->type()->number_type() == kNumberTypeInt32 ||
        value_seq->value().front()->type()->number_type() == kNumberTypeInt) {
      shape = GetValue<std::vector<int>>(value);
    } else if (value_seq->value().front()->type()->number_type() == kNumberTypeInt64) {
      auto origin_value = GetValue<std::vector<int64_t>>(value);
      std::transform(origin_value.begin(), origin_value.end(), std::back_inserter(shape),
                     [](int64_t val) { return static_cast<int32_t>(val); });
    } else {
      MS_LOG(ERROR) << "Value type in ValueSequence is not an integer.";
      return RET_ERROR;
    }
  }
  data_info->data_type_ = kNumberTypeInt32;
  data_info->shape_ = {static_cast<int32_t>(shape.size())};
  data_info->data_.resize(shape.size() * sizeof(int));
  if (!shape.empty() && memcpy_s(data_info->data_.data(), shape.size() * sizeof(int32_t), shape.data(),
                                 shape.size() * sizeof(int32_t)) != EOK) {
    MS_LOG(ERROR) << "memcpy_s data into schema_tensor failed.";
    return RET_ERROR;
  }
  return RET_OK;
}
}  // namespace
int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
                               DataInfo *data_info) {
  MS_ASSERT(cnode != nullptr && data_info != nullptr);
  auto param_node = cnode->input(index)->cast<ParameterPtr>();
  data_info->format_ = GetFormatByFmk(fmk_type);
  if (data_info->format_ < 0) {
    MS_LOG(ERROR) << "don't support current fmk: " << fmk_type;
    return lite::RET_ERROR;
  }
  if (data_info->format_ != mindspore::NHWC && data_info->format_ != mindspore::NCHW) {
    MS_LOG(ERROR) << "schema tensor format is wrong, " << data_info->format_;
    return RET_ERROR;
  }
  // attr weightFormat is only used by conv-like ops' second input
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) {
    data_info->format_ = GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat));
  }
  ShapeVector shape_vector;
  TypeId data_type;
  auto status = GetDataTypeAndShape(param_node, &data_type, &shape_vector);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "get data type and shape from param node failed.";
    return RET_ERROR;
  }
  data_info->data_type_ = data_type;
  auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
  size_t offset = 0;
  if (!shape_vector.empty() && data_type == kObjectTypeString) {
    status = GetShapeVectorFromStringTensor(tensor_info, &shape_vector, &offset);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "get shape vector from string tensor failed.";
      return RET_ERROR;
    }
  }
  std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
  data_info->shape_ = dims;
  if (tensor_info != nullptr && tensor_info->Size() != 0) {
    if (data_type != kObjectTypeTensorType || tensor_info->Size() >= kTensorListMinSize) {
      data_info->data_.resize(tensor_info->Size() - offset);
      if (EOK != memcpy_s(data_info->data_.data(), data_info->data_.size(),
                          static_cast<uint8_t *>(tensor_info->data_c()) + offset, tensor_info->Size() - offset)) {
        MS_LOG(ERROR) << "memcpy_s failed.";
        return RET_ERROR;
      }
    }
  }
  QuantParamHolderPtr quant_param_holder =
    prim->GetAttr("quant_params") == nullptr ? nullptr : prim->GetAttr("quant_params")->cast<QuantParamHolderPtr>();
  if (quant_param_holder != nullptr && quant_param_holder->enable_huffman_code() && data_type == kNumberTypeInt8) {
    data_info->enable_huffman_code_ = true;
  }
  data_info->node_type_ = NodeType_ValueNode;
  return RET_OK;
}
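// Illustrative usage sketch (assumption, not part of the patch): this mirrors
// how ConvertInputParameter in anf_exporter.cc consumes the helper; the index
// variable i and FmkType_TF are placeholders.
//
//   DataInfo data_info;
//   if (FetchDataFromParameterNode(cnode, i, converter::FmkType_TF, /*train_flag=*/false, &data_info) == RET_OK) {
//     auto tensor = std::make_unique<schema::TensorT>();
//     tensor->format = static_cast<schema::Format>(data_info.format_);
//     tensor->dataType = data_info.data_type_;
//     tensor->dims = data_info.shape_;
//     tensor->data = data_info.data_;
//     tensor->enableHuffmanCode = data_info.enable_huffman_code_;
//   }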
int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
                           DataInfo *data_info) {
  MS_ASSERT(cnode != nullptr && data_info != nullptr);
  auto value_node = cnode->input(index)->cast<ValueNodePtr>();
  auto value = value_node->value();
  int ret = RET_OK;
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  MS_ASSERT(prim != nullptr);
  if (value->isa<tensor::Tensor>()) {
    ret = FetchFromTensorValue(value_node, prim, fmk_type, train_flag, data_info);
    if (index == 2 && prim->GetAttr(opt::kWeightFormat) != nullptr) {
      data_info->format_ = GetValue<int64_t>(prim->GetAttr(opt::kWeightFormat));
    }
  } else if (value->isa<mindspore::Int32Imm>() || value->isa<mindspore::Int64Imm>()) {
    ret = FetchFromInt32OrInt64ImmValue(value_node, prim, data_info);
  } else if (value->isa<mindspore::BoolImm>()) {
    ret = FetchFromBoolImmValue(value_node, prim, data_info);
  } else if (value->isa<mindspore::ValueSequeue>()) {
    ret = FetchFromSequenceValue(value_node, prim, data_info);
  } else if (value->isa<Number>()) {
    ret = FetchFromNumberValue(value_node, prim, data_info);
  } else if (value->isa<FuncGraph>()) {
    MS_LOG(INFO) << "op name:" << value_node->fullname_with_scope() << " input is func_graph";
    return RET_NO_CHANGE;
  } else if (value->isa<Monad>()) {
    MS_LOG(INFO) << "op name:" << value_node->fullname_with_scope() << " input is Monad";
    return RET_NO_CHANGE;
  } else {
    MS_LOG(ERROR) << "Unsupported value type, need to add support.";
    return RET_ERROR;
  }
  data_info->node_type_ = NodeType_ValueNode;
  return ret;
}
int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
                       DataInfo *data_info) {
  MS_ASSERT(cnode != nullptr && data_info != nullptr);
  auto abstract = opt::GetCNodeInputAbstract(cnode, index);
  if (abstract == nullptr) {
    MS_LOG(ERROR) << "Abstract cnode is nullptr.";
    return RET_ERROR;
  }
  if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
    MS_LOG(ERROR) << "Abstract should be abstract tensor.";
    return RET_ERROR;
  }
  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
  auto type_ptr = abstract_tensor->element()->GetTypeTrack();
  MS_ASSERT(type_ptr != nullptr);
  if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
    MS_LOG(ERROR) << "Shape of Abstract should be ShapePtr.";
    return RET_ERROR;
  }
  auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
  std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end());
  data_info->format_ = mindspore::NHWC;
  data_info->data_type_ = type_ptr->type_id();
  data_info->shape_ = dims;
  data_info->node_type_ = NodeType_CNode;
  if (type_ptr->type_id() == kObjectTypeTensorType) {
    auto tensor_info = abstract_tensor->GetValueTrack();
    if (tensor_info == nullptr || !utils::isa<tensor::TensorPtr>(tensor_info)) {
      MS_LOG(ERROR) << "tensor info is invalid.";
      return RET_ERROR;
    }
    auto tensor_value = tensor_info->cast<tensor::TensorPtr>();
    if (tensor_value->Size() >= kTensorListMinSize) {
      data_info->data_.resize(tensor_value->Size());
      if (memcpy_s(data_info->data_.data(), tensor_value->Size(), tensor_value->data_c(), tensor_value->Size()) !=
          EOK) {
        MS_LOG(ERROR) << "memcpy data failed.";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}
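// Illustrative note (assumption, not part of the patch): for a CNode input this
// helper only records format/type/shape metadata; the data_ payload is copied
// solely for kObjectTypeTensorType values whose buffer is at least
// kTensorListMinSize (3 * sizeof(int32_t)) -- presumably the minimal
// tensor-list header of element dtype, shape size, and element count.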
void RemoveIfDepend(const CNodePtr &cnode) {
  bool has_depend = false;
  std::vector<AnfNodePtr> inputs;
  inputs.clear();
  inputs.emplace_back(cnode->input(0));
  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
    AnfNodePtr inputNode = cnode->input(i);
    if (!inputNode->isa<CNode>()) {
      inputs.emplace_back(cnode->input(i));
      continue;
    }
    auto depend_node = utils::cast<CNodePtr>(inputNode);
    auto value_node = depend_node->input(0)->cast<ValueNodePtr>();
    if (value_node == nullptr) {
      MS_LOG(ERROR) << "value node is invalid.";
      return;
    }
    if (value_node->value() != nullptr && opt::CheckPrimitiveType(depend_node, prim::kPrimDepend)) {
      has_depend = true;
      bool mask_out = (depend_node->inputs().size() == 3);
      for (size_t j = 1; j < depend_node->inputs().size(); ++j) {
        AnfNodePtr depend_input_node = depend_node->input(j);
        if (depend_input_node->isa<CNode>()) {
          inputs.emplace_back(depend_input_node);
          if (mask_out) {
            break;
          }
        }
      }
    } else {
      inputs.emplace_back(cnode->input(i));
    }
  }
  if (has_depend) {
    cnode->set_inputs(inputs);
  }
}
void RemoveIfMakeTuple(const CNodePtr &cnode) {
  bool has_make_tuple = false;
  std::vector<AnfNodePtr> inputs;
  inputs.clear();
  inputs.emplace_back(cnode->input(0));
  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
    AnfNodePtr input_node = cnode->input(i);
    if (!input_node->isa<CNode>()) {
      inputs.emplace_back(cnode->input(i));
      continue;
    }
    auto make_tuple_node = utils::cast<CNodePtr>(input_node);
    auto value_node = make_tuple_node->input(0)->cast<ValueNodePtr>();
    if (value_node == nullptr) {
      MS_LOG(ERROR) << "value node is invalid.";
      return;
    }
    if (value_node->value() != nullptr && (opt::CheckPrimitiveType(make_tuple_node, prim::kPrimMakeTuple) ||
                                           opt::CheckPrimitiveType(make_tuple_node, opt::kPrimMakeTupleV2))) {
      has_make_tuple = true;
      for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) {
        inputs.emplace_back(make_tuple_node->input(j));
      }
    } else {
      inputs.emplace_back(cnode->input(i));
    }
  }
  if (has_make_tuple) {
    cnode->set_inputs(inputs);
  }
}
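// Illustrative note (not part of the patch): RemoveIfMakeTuple splices the
// operands of a MakeTuple input directly into the consumer, e.g.
//   consumer(MakeTuple(a, b), c)  ->  consumer(a, b, c)
// while RemoveIfDepend above keeps only the CNode operands of a Depend input,
// and only the first one when the Depend carries a mask-out third input.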
}  // namespace lite
}  // namespace mindspore
@@ -0,0 +1,49 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_ANF_EXPORTER_FETCH_CONTENT_H_
#define MINDSPORE_LITE_TOOLS_ANF_EXPORTER_FETCH_CONTENT_H_
#include <string>
#include <vector>
#include "ir/primitive.h"
#include "ir/func_graph.h"
#include "src/common/utils.h"
#include "tools/converter/converter_flags.h"
namespace mindspore {
namespace lite {
struct DataInfo {
  bool enable_huffman_code_;
  int format_;
  int data_type_;
  int node_type_;
  std::vector<int> shape_;
  std::vector<uint8_t> data_;
  DataInfo() : enable_huffman_code_(false), format_(0), data_type_(0), node_type_(0) {}
};
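// Illustrative note (assumption): format_, data_type_ and node_type_ carry the
// integer values of schema::Format, TypeId and NodeType respectively, so a
// consumer converts back with, e.g.,
//   auto format = static_cast<schema::Format>(data_info.format_);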
int FetchDataFromParameterNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
                               DataInfo *data_info);
int FetchDataFromValueNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
                           DataInfo *data_info);
int FetchDataFromCNode(const CNodePtr &cnode, size_t index, converter::FmkType fmk_type, bool train_flag,
                       DataInfo *data_info);
void RemoveIfDepend(const CNodePtr &cnode);
void RemoveIfMakeTuple(const CNodePtr &cnode);
}  // namespace lite
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_ANF_EXPORTER_FETCH_CONTENT_H_
@@ -168,9 +168,7 @@ int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converte
  auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
  const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>());
  if (!config->trainModel) {
    auto inner_context_ptr = std::make_shared<lite::InnerContext>();
    inner_context_ptr->Init();
    const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(inner_context_ptr));
    const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
  }
  auto update_conv2d_param_pass = std::make_shared<opt::UpdateConv2DParamPass>();
  update_conv2d_param_pass->SetFmkType(config->fmk);
@@ -21,15 +21,10 @@
#include "src/common/log_adapter.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.h"
#include "tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/infershape_pass.h"
#include "tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h"
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
#include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h"
#include "tools/converter/legacy_optimizer/graph/global_format_transform_pass.h"
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/dropout_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
@@ -129,7 +124,6 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
  auto old_nodes = GetGraphNodes();
  Optimizer format_trans_optimizer;
  if (!ctx.trainModel && ctx.fmk != converter::FmkType_ONNX) {
    format_trans_optimizer.AddPass(new (std::nothrow) GlobalFormatTransformPass());
    format_trans_optimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
    format_trans_optimizer.AddPass(new (std::nothrow) SubgraphNodePass(old_nodes));
  }
@@ -1,12 +1,8 @@
file(GLOB FUSION_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/fusion_pattern.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/fusion_pass.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/matmul_biasadd_fusion_pass.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/mul_add_fusion_pass.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast_fusion_pass.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_fold_fusion_pass.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_fusion_pass.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_transpose_fusion_pass.cc
     )
set_property(SOURCE ${FUSION_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
add_library(fusion_mid OBJECT ${FUSION_SRC})
@@ -1,125 +0,0 @@
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <string>
#include <unordered_map>
#include <vector>
#include <memory>
#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
#include "src/common/log_adapter.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace {
std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
}  // namespace
namespace lite {
#define kFormatTransMatchPathLen2 2
#define kFormatTransMatchPathLen3 3
STATUS FormatTransFusionPass::DefinePattern() {
  // nchw2nhwc + nhwc2nchw || nhwc2nchw + nchw2nhwc
  {
    auto transpose1 = std::make_shared<PatternOp>();
    transpose1->id = kFormatTransTranspose1;
    transpose1->types = {PrimitiveType_Transpose};
    auto transpose2 = std::make_shared<PatternOp>();
    transpose2->id = kFormatTransTranspose2;
    transpose2->types = {PrimitiveType_Transpose};
    transpose2->left = transpose1;
    auto pattern = std::make_unique<FusionPattern>(kNc2NhAndNh2NcFusionPattern);
    if (pattern == nullptr) {
      MS_LOG(ERROR) << "new " << kNc2NhAndNh2NcFusionPattern << " failed";
      return RET_ERROR;
    }
    pattern->AddPatternOp(transpose1);
    pattern->AddPatternOp(transpose2);
    pattern->Finish();
    this->patterns.emplace_back(pattern.release());
  }
  // nhwc2nchw + QuantDtypeCast + nchw2nhwc || nchw2nhwc + QuantDtypeCast + nhwc2nchw
  {
    auto transpose1 = std::make_shared<PatternOp>();
    transpose1->id = kFormatTransTranspose1;
    transpose1->types = {PrimitiveType_Transpose};
    auto passOp = std::make_shared<PatternOp>();
    passOp->id = kFormatTransPassOp;
    passOp->types = {PrimitiveType_QuantDTypeCast};
    auto transpose2 = std::make_shared<PatternOp>();
    transpose2->id = kFormatTransTranspose2;
    transpose2->types = {PrimitiveType_Transpose};
    passOp->left = transpose2;
    transpose1->left = passOp;
    auto pattern = std::make_unique<FusionPattern>(kNh2NcAndNc2NhPassFusionPattern);
    if (pattern == nullptr) {
      MS_LOG(ERROR) << "new " << kNh2NcAndNc2NhPassFusionPattern << " failed";
      return RET_ERROR;
    }
    pattern->AddPatternOp(transpose1);
    pattern->AddPatternOp(passOp);
    pattern->AddPatternOp(transpose2);
    pattern->Finish();
    this->patterns.emplace_back(pattern.release());
  }
  return RET_OK;
}
| STATUS FormatTransFusionPass::Run(schema::MetaGraphT *graph) { return FusionPass::Run(graph); } | |||
| STATUS FormatTransFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, | |||
| std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) { | |||
| MS_ASSERT(graph != nullptr); | |||
| if (matchedPath.size() != kFormatTransMatchPathLen2 && matchedPath.size() != kFormatTransMatchPathLen3) { | |||
| MS_LOG(ERROR) << "schema::Format-Transform-Fusion should have " << kFormatTransMatchPathLen2 << " or " | |||
| << kFormatTransMatchPathLen3 << " NodeIndex in matchedPair"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| std::shared_ptr<Path> srcPath = matchedPath[kFormatTransTranspose1]; | |||
| std::shared_ptr<Path> dstPath = matchedPath[kFormatTransTranspose2]; | |||
| if (srcPath == nullptr || dstPath == nullptr) { | |||
| MS_LOG(ERROR) << "srcPath or dstPath is failed to get"; | |||
| return RET_ERROR; | |||
| } | |||
| auto &srcNode = graph->nodes.at(srcPath->nodeIdx); | |||
| auto &dstNode = graph->nodes.at(dstPath->nodeIdx); | |||
| MS_ASSERT(srcNode != nullptr); | |||
| MS_ASSERT(dstNode != nullptr); | |||
| auto src_perm = GetTransposePerm(graph, srcNode); | |||
| auto dst_perm = GetTransposePerm(graph, dstNode); | |||
| bool isNc2NhAndNh2Nc = src_perm == nchw2nhwc_perm && dst_perm == nhwc2nchw_perm; | |||
| bool isNh2NcAndNc2Nh = src_perm == nhwc2nchw_perm && dst_perm == nchw2nhwc_perm; | |||
| if (isNc2NhAndNh2Nc || isNh2NcAndNc2Nh) { | |||
| auto status = IsolateOneWayNode(graph, srcPath->nodeIdx); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << srcNode->name << ", error: " << status; | |||
| return status; | |||
| } | |||
| status = IsolateOneWayNode(graph, dstPath->nodeIdx); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name << ", error: " << status; | |||
| return status; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,49 +0,0 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_FORMAT_TRANS_FUSION_PASS_H | |||
| #define MINDSPORE_PREDICT_FORMAT_TRANS_FUSION_PASS_H | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| constexpr const char *kFormatTransTranspose1 = "FormatTransTransposeOp1"; | |||
| constexpr const char *kFormatTransTranspose2 = "FormatTransTransposeOp2"; | |||
| constexpr const char *kFormatTransPassOp = "FormatTransPassOp"; | |||
| constexpr const char *kNc2NhAndNh2NcFusionPattern = "Nc2NhAndNh2NcFusionPattern"; | |||
| constexpr const char *kNh2NcAndNc2NhPassFusionPattern = "Nh2NcAndNc2NhPassFusionPattern"; | |||
| class FormatTransFusionPass : public FusionPass { | |||
| public: | |||
| FormatTransFusionPass() = default; | |||
| ~FormatTransFusionPass() override = default; | |||
| STATUS DefinePattern() override; | |||
| STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, | |||
| std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override; | |||
| STATUS Run(schema::MetaGraphT *graph) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_FORMAT_TRANS_FUSION_PASS_H | |||
| @@ -1,17 +1,12 @@ | |||
| file(GLOB GRAPH_PASS | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/trans_format_insert_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/dropout_node_remove_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_convert_scale_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/trans_format_remove_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/infershape_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/tensor_quant_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/infer_quant_param_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/global_format_transform_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/set_unused_quant_param_to_default_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/tensor_name_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/switch_pass.cc | |||
| @@ -1,461 +0,0 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" | |||
| #include "tools/common/node_util.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/common/common.h" | |||
| #include "src/common/utils.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #define kMinInputNum 1 | |||
| #define kOutputNum 1 | |||
| STATUS FormatTransPass::Run(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| auto status = DoModelInputFormatTrans(graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoModelInputFormatTrans failed : " << status; | |||
| return status; | |||
| } | |||
| status = DoNodeInoutFormatTrans(graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "DoNodeInoutFormatTrans failed : " << status; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS FormatTransPass::GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *before_node_type, | |||
| FormatTransNodeType *after_node_type) { | |||
| if (fmk_type_ == converter::FmkType_TFLITE) { // inference by nhwc | |||
| if (!IsContain(GetNchwOpList(), GetCNodeTType(node))) { | |||
| return RET_NO_CHANGE; | |||
| } | |||
| *before_node_type = kNHWC2NCHW; | |||
| *after_node_type = kNCHW2NHWC; | |||
| return RET_OK; | |||
| } else if (fmk_type_ == converter::FmkType_CAFFE || fmk_type_ == converter::FmkType_MS || | |||
| fmk_type_ == converter::FmkType_ONNX) { | |||
| if (!IsContain(GetNhwcOpList(), GetCNodeTType(node))) { | |||
| return RET_NO_CHANGE; | |||
| } | |||
| *before_node_type = kNCHW2NHWC; | |||
| *after_node_type = kNHWC2NCHW; | |||
| return RET_OK; | |||
| } else if (fmk_type_ == converter::FmkType_TF) { | |||
| if (IsContain(GetNhwcOpList(), GetCNodeTType(node)) && GetFormat(node) == schema::Format_NCHW) { | |||
| *before_node_type = kNCHW2NHWC; | |||
| *after_node_type = kNHWC2NCHW; | |||
| return RET_OK; | |||
| } | |||
| if (IsContain(GetNchwOpList(), GetCNodeTType(node))) { | |||
| *before_node_type = kNHWC2NCHW; | |||
| *after_node_type = kNCHW2NHWC; | |||
| return RET_OK; | |||
| } | |||
| return RET_NO_CHANGE; | |||
| } | |||
| MS_LOG(ERROR) << "Unsupported fmk: " << fmk_type_; | |||
| return RET_ERROR; | |||
| } | |||
| STATUS FormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) { | |||
| if (fmk_type_ == converter::FmkType_TF || fmk_type_ == converter::FmkType_TFLITE) { | |||
| return RET_OK; | |||
| } | |||
| MS_ASSERT(graph != nullptr); | |||
| // insert transpose nodes for the model input tensors | |||
| if (graph->nodes.empty()) { | |||
| return RET_OK; | |||
| } | |||
| // onnx input format may be nhwc | |||
| if (fmk_type_ == converter::FmkType_ONNX && graph->inputIndex.size() == 1) { | |||
| auto &input_tensor = graph->allTensors.at(graph->inputIndex[0]); | |||
| auto &input_dims = input_tensor->dims; | |||
| if (input_dims.size() == 4 && input_dims[3] != -1 && input_dims[1] == -1) { | |||
| return RET_OK; | |||
| } | |||
| } | |||
| auto graph_input_idxes = graph->inputIndex; | |||
| for (size_t i = 0; i < graph_input_idxes.size(); i++) { | |||
| bool transed = false; | |||
| auto input_idx = graph_input_idxes.at(i); | |||
| auto &tensor = graph->allTensors.at(input_idx); | |||
| if (tensor->dims.size() != kNCHWDimNumber) { | |||
| continue; | |||
| } | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| for (size_t input_index_idx = 0; input_index_idx < (*iter)->inputIndex.size(); input_index_idx++) { | |||
| if ((*iter)->inputIndex.at(input_index_idx) == input_idx) { | |||
| STATUS status = RET_OK; | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, input_index_idx, kNHWC2NCHW, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << (*iter)->name << " failed"; | |||
| return status; | |||
| } | |||
| // set first tensor format to nhwc | |||
| auto &trans_node = *(iter - 1); | |||
| MS_ASSERT(trans_node != nullptr); | |||
| MS_ASSERT(trans_node->inputIndex.size() == 1); | |||
| auto &graph_in_tensor = graph->allTensors.at(trans_node->inputIndex.front()); | |||
| graph_in_tensor->format = schema::Format::Format_NHWC; | |||
| // assume the parser has not reformatted the shape | |||
| auto old_dims = graph_in_tensor->dims; | |||
| if (!transed) { | |||
| graph_in_tensor->dims = {old_dims[NCHW_N], old_dims[NCHW_H], old_dims[NCHW_W], old_dims[NCHW_C]}; | |||
| transed = true; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
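| // Worked example (editor's illustration): a graph input with NCHW dims | |||
| //   {1, 3, 224, 224}  // N, C, H, W | |||
| // is relabelled above to NHWC dims | |||
| //   {1, 224, 224, 3}  // N, H, W, C | |||
| // and a nhwc2nchw Transpose is inserted so downstream nodes still see NCHW. | |||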
| // input format required for inference: | |||
| // conv deconv depth dedepth | |||
| // fp32 NCHW NCHW NCHW NCHW | |||
| // uint8 NCHW ? NCHW ? | |||
| STATUS FormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| // insert transposes before and after ops calculated in NCHW/NC4HW4 | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| FormatTransNodeType before_node_type = kNCHW2NHWC; | |||
| FormatTransNodeType after_node_type = kNHWC2NCHW; | |||
| STATUS status = RET_OK; | |||
| status = GetInsertFormatTrans(**iter, &before_node_type, &after_node_type); | |||
| if (status == RET_NO_CHANGE) { | |||
| continue; | |||
| } | |||
| if (status != RET_OK) { | |||
| return status; | |||
| } | |||
| auto &node = *iter; | |||
| auto nodeName = node->name; | |||
| if (node->inputIndex.size() < kMinInputNum) { | |||
| MS_LOG(ERROR) << "Op should have " << kMinInputNum << " input tensor at least"; | |||
| return RET_ERROR; | |||
| } | |||
| if (node->outputIndex.size() < kOutputNum) { | |||
| MS_LOG(ERROR) << "Op should have " << kOutputNum << " output tensor"; | |||
| return RET_ERROR; | |||
| } | |||
| void *attr = node->primitive->value.value; | |||
| if (node->primitive->value.type == schema::PrimitiveType_SpaceToDepth) { | |||
| reinterpret_cast<schema::SpaceToDepthT *>(attr)->format = schema::Format_NHWC; | |||
| } | |||
| if (node->primitive->value.type == schema::PrimitiveType_DepthToSpace) { | |||
| reinterpret_cast<schema::DepthToSpaceT *>(attr)->format = schema::Format_NHWC; | |||
| } | |||
| auto spec_insert_indexes = GetExtNhwcIndexes(); | |||
| auto op_type = GetCNodeTType(**iter); | |||
| if (spec_insert_indexes.find(op_type) != spec_insert_indexes.end()) { | |||
| for (auto insert_index : spec_insert_indexes[op_type]) { | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, insert_index, before_node_type, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } else if (IsContain(GetNhwcAllInputOpList(), op_type)) { | |||
| auto input_size = node->inputIndex.size(); | |||
| if (GetCNodeTType(**iter) == schema::PrimitiveType_ResizeGrad) { | |||
| if ((**iter).primitive->value.AsResizeGrad()->method == schema::ResizeMethod_NEAREST) { | |||
| input_size = 1; | |||
| } | |||
| } | |||
| for (size_t i = 0; i < input_size; i++) { | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, i, before_node_type, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNchw2NhwcNode before " << nodeName << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| } else { | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, 0, before_node_type, &status); | |||
| } | |||
| iter = InsertFormatTransNode(graph, iter, kAfter, 0, after_node_type, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
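| // Editor's sketch of the net effect for an NHWC-kernel op inside an NCHW | |||
| // model (e.g. Conv2DFusion converted from ONNX/Caffe): | |||
| //   ... -> nchw2nhwc Transpose -> Conv2DFusion -> nhwc2nchw Transpose -> ... | |||
| // so the surrounding graph keeps its layout while the kernel runs in NHWC. | |||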
| NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter exist_node_iter, InsertPlace place, | |||
| size_t inout_idx, FormatTransNodeType node_type, STATUS *error_code) { | |||
| MS_ASSERT((*exist_node_iter) != nullptr); | |||
| MS_ASSERT(graph != nullptr); | |||
| auto exist_node_name = (*exist_node_iter)->name; | |||
| std::string tile_name; | |||
| if (place == kBefore) { | |||
| tile_name = exist_node_name + "_pre"; | |||
| } else { | |||
| tile_name = exist_node_name + "_post"; | |||
| } | |||
| auto trans_node = std::make_unique<schema::CNodeT>(); | |||
| trans_node->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| trans_node->primitive->value.type = schema::PrimitiveType_Transpose; | |||
| auto perm_tensor = std::make_unique<schema::TensorT>(); | |||
| perm_tensor->dataType = kNumberTypeInt32; | |||
| perm_tensor->dims = {4}; | |||
| std::vector<int> perm; | |||
| if (node_type == kNCHW2NHWC) { | |||
| trans_node->name = "nchw2nhwc_" + tile_name + std::to_string(id_++); | |||
| perm = {0, 2, 3, 1}; | |||
| } else { | |||
| trans_node->name = "nhwc2nchw_" + tile_name + std::to_string(id_++); | |||
| perm = {0, 3, 1, 2}; | |||
| } | |||
| size_t bytes = perm.size() * sizeof(int); | |||
| perm_tensor->data.resize(bytes); | |||
| if (memcpy_s(perm_tensor->data.data(), bytes, perm.data(), bytes) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy data failed."; | |||
| } | |||
| perm_tensor->name = trans_node->name + "_perm"; | |||
| OpDefCopyer transpose_op_copyer = [](CNodeT *in_op_def) -> std::unique_ptr<CNodeT> { | |||
| auto new_op_def = std::make_unique<schema::CNodeT>(); | |||
| if (new_op_def == nullptr) { | |||
| MS_LOG(ERROR) << "new CNodeT failed"; | |||
| return nullptr; | |||
| } | |||
| new_op_def->name = in_op_def->name; | |||
| new_op_def->quantType = in_op_def->quantType; | |||
| new_op_def->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (new_op_def->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new PrimitiveT failed"; | |||
| return nullptr; | |||
| } | |||
| new_op_def->primitive->value.type = schema::PrimitiveType_Transpose; | |||
| return new_op_def; | |||
| }; | |||
| int insert_num = 0; | |||
| auto iter = InsertNode(graph, exist_node_iter, place, inout_idx, std::move(trans_node), error_code, &insert_num, | |||
| transpose_op_copyer); | |||
| size_t index = graph->allTensors.size(); | |||
| graph->allTensors.push_back(std::move(perm_tensor)); | |||
| for (int i = insert_num; i > 0; --i) { | |||
| (*(iter - i))->inputIndex.push_back(index); | |||
| } | |||
| return iter; | |||
| } | |||
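| // Editor's note: the perm tensor built above is a 1-D int32 tensor of shape | |||
| // {4}; for kNCHW2NHWC its payload is the 16 bytes of {0, 2, 3, 1}, so an | |||
| // illustrative sanity check before the tensor is moved would be: | |||
| //   assert(perm_tensor->data.size() == perm.size() * sizeof(int)); | |||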
| int FormatTransPass::GetFormat(const schema::CNodeT &node) { | |||
| switch (node.primitive->value.type) { | |||
| case schema::PrimitiveType_Conv2DFusion: | |||
| return node.primitive->value.AsConv2DFusion()->format; | |||
| case schema::PrimitiveType_Conv2dTransposeFusion: | |||
| return node.primitive->value.AsConv2dTransposeFusion()->format; | |||
| case schema::PrimitiveType_AvgPoolFusion: | |||
| return node.primitive->value.AsAvgPoolFusion()->format; | |||
| case schema::PrimitiveType_MaxPoolFusion: | |||
| return node.primitive->value.AsMaxPoolFusion()->format; | |||
| default: | |||
| return schema::Format_NHWC; | |||
| } | |||
| } | |||
| STATUS FormatTransPass::ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node) { | |||
| MS_ASSERT(node->primitive != nullptr); | |||
| auto type = node->primitive->value.type; | |||
| auto input1_ndim = graph->allTensors.at(node->inputIndex[0])->dims.size(); | |||
| if (input1_ndim != 4) { | |||
| if (node->inputIndex.size() > 1) { | |||
| auto input2_ndim = graph->allTensors.at(node->inputIndex[1])->dims.size(); | |||
| if (input2_ndim != 4 && input2_ndim != 0) { | |||
| MS_LOG(ERROR) << "change op axis only support 4 dims"; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| } else { | |||
| MS_LOG(DEBUG) << "change op axis only support 4 dims"; | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| } | |||
| if (type == schema::PrimitiveType_Concat) { | |||
| MS_ASSERT(node->primitive->value.AsConcat() != nullptr); | |||
| if (node->primitive->value.AsConcat() == nullptr) { | |||
| MS_LOG(ERROR) << "node->primitive->value.AsConcat() is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto origin_axis = node->primitive->value.AsConcat()->axis; | |||
| auto axis_map = GetNc2NhAxisMap(); | |||
| node->primitive->value.AsConcat()->axis = axis_map[origin_axis < 0 ? origin_axis + 4 : origin_axis]; | |||
| } | |||
| if (type == schema::PrimitiveType_Split) { | |||
| MS_ASSERT(node->primitive->value.AsSplit() != nullptr); | |||
| if (node->primitive->value.AsSplit() == nullptr) { | |||
| MS_LOG(ERROR) << "node->primitive->value.AsSplit() is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto origin_axis = node->primitive->value.AsSplit()->axis; | |||
| auto axis_map = GetNc2NhAxisMap(); | |||
| node->primitive->value.AsSplit()->axis = axis_map[origin_axis]; | |||
| } | |||
| if (type == schema::PrimitiveType_Crop) { | |||
| MS_ASSERT(node->primitive->value.AsCrop() != nullptr); | |||
| if (node->primitive->value.AsCrop() == nullptr) { | |||
| MS_LOG(ERROR) << "node->primitive->value.AsCrop() is nullptr"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto origin_axis = node->primitive->value.AsCrop()->axis; | |||
| auto offsets = node->primitive->value.AsCrop()->offsets; | |||
| auto axis_map = GetNc2NhAxisMap(); | |||
| // nchw -> nhwc: offsets may need to be padded with 0 | |||
| if (axis_map[origin_axis] == 0) { | |||
| offsets = {offsets[0], offsets[2], offsets[3], offsets[1]}; | |||
| } else if (axis_map[origin_axis] == 1 || axis_map[origin_axis] == 2) { | |||
| // origin_axis == 2 or origin_axis == 3 | |||
| offsets.push_back(0); | |||
| } else if (axis_map[origin_axis] == -1) { | |||
| // origin_axis = 1 | |||
| offsets = {offsets[1], offsets[2], offsets[0]}; | |||
| } else { | |||
| // axis error | |||
| MS_LOG(ERROR) << "Crop error"; | |||
| return RET_ERROR; | |||
| } | |||
| node->primitive->value.AsCrop()->axis = axis_map[origin_axis]; | |||
| node->primitive->value.AsCrop()->offsets = offsets; | |||
| } | |||
| if (type == schema::PrimitiveType_SliceFusion || type == schema::PrimitiveType_StridedSlice) { | |||
| return ChangeOpSliceAndStridedSlice(graph, node); | |||
| } | |||
| return RET_OK; | |||
| } | |||
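| // Worked example (editor's illustration) using the NC2NH axis map | |||
| // {0->0, 1->3, 2->1, 3->2}: a Concat along the NCHW channel axis (axis = 1, | |||
| // or equivalently axis = -3) is rewritten above to axis = 3, the NHWC | |||
| // channel axis. | |||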
| void FormatTransPass::TransformAttrByAxes(int *origin_attr, int *axes, int element_size) { | |||
| if (origin_attr == nullptr || axes == nullptr || element_size == 0) { | |||
| return; | |||
| } | |||
| auto axis_map = GetNc2NhAxisMap(); | |||
| std::vector<int> cur_attr; | |||
| for (int dim = 0; dim < 4; ++dim) { | |||
| for (int index = 0; index < element_size; ++index) { | |||
| int nhwc_dim = axis_map[axes[index] < 0 ? axes[index] + 4 : axes[index]]; | |||
| if (nhwc_dim == dim || (nhwc_dim + 4) == dim) { | |||
| cur_attr.push_back(origin_attr[index]); | |||
| } | |||
| } | |||
| } | |||
| for (int index = 0; index < element_size; ++index) { | |||
| origin_attr[index] = cur_attr[index]; | |||
| } | |||
| } | |||
| void FormatTransPass::TransformOpAxisAttr(int *origin_axis, int element_size) { | |||
| if (origin_axis == nullptr || element_size == 0) { | |||
| return; | |||
| } | |||
| auto axis_map = GetNc2NhAxisMap(); | |||
| std::vector<int> new_axis; | |||
| for (int i = 0; i < element_size; ++i) { | |||
| int axis = axis_map[origin_axis[i]]; | |||
| axis = axis < 0 ? axis + 4 : axis; | |||
| new_axis.push_back(axis); | |||
| } | |||
| std::sort(new_axis.begin(), new_axis.end()); | |||
| for (int i = 0; i < element_size; ++i) { | |||
| origin_axis[i] = new_axis[i]; | |||
| } | |||
| } | |||
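| // Worked example (editor's illustration): a Slice with axes = {1, 2} (NCHW | |||
| // C and H) and begin = {c0, h0} maps through {1->3, 2->1} to NHWC axes | |||
| // {3, 1}; TransformAttrByAxes reorders begin to {h0, c0} (ascending NHWC | |||
| // axis) and TransformOpAxisAttr rewrites axes to the sorted {1, 3}. | |||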
| STATUS FormatTransPass::ChangeOpSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node) { | |||
| auto attr = node->primitive->value.AsSliceFusion(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "node->primitive->value.AsSliceFusion() is nullptr."; | |||
| return RET_NULL_PTR; | |||
| } | |||
| // transform attr | |||
| if (node->inputIndex.size() < 2) { | |||
| MS_LOG(ERROR) << "slice input is error"; | |||
| return RET_ERROR; | |||
| } | |||
| for (size_t index = 1; index < node->inputIndex.size(); ++index) { | |||
| if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) { | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| } | |||
| int element_num = graph->allTensors[node->inputIndex[1]]->dims[0]; | |||
| std::vector<int> axes; | |||
| auto axes_attr = attr->axes; | |||
| if (axes_attr.empty()) { | |||
| for (int index = 0; index < element_num; ++index) { | |||
| axes.push_back(index); | |||
| } | |||
| } else { | |||
| std::transform(axes_attr.begin(), axes_attr.end(), std::back_inserter(axes), | |||
| [](int64_t val) { return static_cast<int>(val); }); | |||
| } | |||
| for (size_t index = 1; index < node->inputIndex.size(); ++index) { | |||
| TransformAttrByAxes(reinterpret_cast<int *>(graph->allTensors[node->inputIndex[index]]->data.data()), | |||
| reinterpret_cast<int *>(axes.data()), element_num); | |||
| } | |||
| TransformOpAxisAttr(axes.data(), element_num); | |||
| attr->axes.clear(); | |||
| for (int i = 0; i < element_num; ++i) { | |||
| attr->axes.push_back(static_cast<int64_t>(axes[i])); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS FormatTransPass::ChangeOpStridedSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node) { | |||
| // onnx strided_slice always has five inputs. | |||
| if (node->inputIndex.size() != 5) { | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| for (int index = 1; index < 5; ++index) { | |||
| if (graph->allTensors[node->inputIndex[index]]->data.data() == nullptr) { | |||
| return RET_NOT_SUPPORT; | |||
| } | |||
| } | |||
| int element_num = graph->allTensors[node->inputIndex[1]]->dims[0]; | |||
| auto axes = graph->allTensors[node->inputIndex[3]]->data; | |||
| for (int index = 1; index < 5; ++index) { | |||
| if (index == 3) { | |||
| continue; | |||
| } | |||
| TransformAttrByAxes(reinterpret_cast<int *>(graph->allTensors[node->inputIndex[index]]->data.data()), | |||
| reinterpret_cast<int *>(axes.data()), element_num); | |||
| } | |||
| TransformOpAxisAttr(reinterpret_cast<int *>(graph->allTensors[node->inputIndex[3]]->data.data()), element_num); | |||
| return RET_OK; | |||
| } | |||
| STATUS FormatTransPass::ChangeOpSliceAndStridedSlice(schema::MetaGraphT *graph, | |||
| const std::unique_ptr<schema::CNodeT> &node) { | |||
| auto type = node->primitive->value.type; | |||
| if (type == schema::PrimitiveType_StridedSlice) { | |||
| return ChangeOpStridedSlice(graph, node); | |||
| } | |||
| if (type == schema::PrimitiveType_SliceFusion) { | |||
| return ChangeOpSlice(graph, node); | |||
| } | |||
| return RET_ERROR; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,76 +0,0 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H | |||
| #define MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H | |||
| #include <memory> | |||
| #include "tools/converter/optimizer.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW, kNONE }; | |||
| class FormatTransPass : public GraphPass { | |||
| public: | |||
| FormatTransPass() : id_(0) {} | |||
| ~FormatTransPass() override = default; | |||
| STATUS Run(schema::MetaGraphT *graph) override; | |||
| void set_quant_type(QuantType quant_type) { this->quant_type_ = quant_type; } | |||
| void set_fmk_type(converter::FmkType fmk_type) { this->fmk_type_ = fmk_type; } | |||
| protected: | |||
| NodeIter InsertFormatTransNode(schema::MetaGraphT *in_op_def, NodeIter exist_node_iter, InsertPlace place, | |||
| size_t inout_idx, FormatTransNodeType node_type, STATUS *error_code); | |||
| STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node); | |||
| private: | |||
| STATUS DoModelInputFormatTrans(schema::MetaGraphT *graph); | |||
| STATUS DoNodeInoutFormatTrans(schema::MetaGraphT *graph); | |||
| void TransformAttrByAxes(int *origin_attr, int *axes, int element_size); | |||
| void TransformOpAxisAttr(int *origin_axis, int element_size); | |||
| STATUS ChangeOpSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node); | |||
| STATUS ChangeOpStridedSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node); | |||
| STATUS ChangeOpSliceAndStridedSlice(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNodeT> &node); | |||
| int GetFormat(const schema::CNodeT &); | |||
| STATUS GetInsertFormatTrans(const schema::CNodeT &node, FormatTransNodeType *before_node_type, | |||
| FormatTransNodeType *after_node_type); | |||
| protected: | |||
| size_t id_ = 0; | |||
| converter::FmkType fmk_type_ = converter::FmkType_TF; | |||
| private: | |||
| QuantType quant_type_ = QuantType_QUANT_NONE; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H | |||
| @@ -1,223 +0,0 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/legacy_optimizer/graph/global_format_transform_pass.h" | |||
| #include <algorithm> | |||
| #include "third_party/securec/include/securec.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/common/utils.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/common/node_util.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace { | |||
| std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1}; | |||
| std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2}; | |||
| } // namespace | |||
| namespace lite { | |||
| STATUS GlobalFormatTransformPass::Run(MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| std::set<size_t> need_del_nodes; | |||
| std::set<size_t> need_trans_format_nodes; | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto &node = *iter; | |||
| auto type = node->primitive->value.type; | |||
| if (type != PrimitiveType_Transpose) { | |||
| continue; | |||
| } | |||
| if (GetTransposePerm(graph, node) != nchw2nhwc_perm) { | |||
| continue; | |||
| } | |||
| std::vector<size_t> pre_nh2nc_nodes; | |||
| std::vector<size_t> pre_not_trans_nodes; | |||
| auto status = FindPreNh2NcNodes(graph, iter - graph->nodes.begin(), &pre_nh2nc_nodes, &pre_not_trans_nodes); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "GenNewScaleTensor failed: " << status; | |||
| return status; | |||
| } | |||
| std::copy(pre_nh2nc_nodes.begin(), pre_nh2nc_nodes.end(), std::inserter(need_del_nodes, need_del_nodes.end())); | |||
| std::copy(pre_not_trans_nodes.begin(), pre_not_trans_nodes.end(), | |||
| std::inserter(need_trans_format_nodes, need_trans_format_nodes.end())); | |||
| if (!pre_nh2nc_nodes.empty()) { | |||
| need_del_nodes.insert(iter - graph->nodes.begin()); | |||
| } | |||
| } | |||
| if (need_del_nodes.empty()) { | |||
| return RET_OK; | |||
| } | |||
| for (auto del_node_index : need_del_nodes) { | |||
| auto node_name = graph->nodes.at(del_node_index)->name; | |||
| auto status = IsolateOneWayNode(graph, del_node_index); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Isolate Node failed, node: " << node_name << ", error: " << status; | |||
| return status; | |||
| } | |||
| } | |||
| auto status = TransWeightToNhwc(graph, need_trans_format_nodes); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "trans weight to nhwc failed"; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS ConvertNcTensor2Nh(TensorT *tensor, const std::vector<int> &pad_dims) { | |||
| if (pad_dims.size() != 4) { | |||
| MS_LOG(ERROR) << "pad dims error"; | |||
| return RET_ERROR; | |||
| } | |||
| auto batch = pad_dims[NCHW_N]; | |||
| auto channel = pad_dims[NCHW_C]; | |||
| auto area = pad_dims[NCHW_H] * pad_dims[NCHW_W]; | |||
| auto size = batch * channel * area; | |||
| auto new_nhwc_data = new (std::nothrow) float[size]; | |||
| if (new_nhwc_data == nullptr) { | |||
| MS_LOG(ERROR) << "create new nhwc data failed"; | |||
| return RET_ERROR; | |||
| } | |||
| if (memset_s(new_nhwc_data, sizeof(float) * size, 0, sizeof(float) * size) != EOK) { | |||
| MS_LOG(ERROR) << "create new nhwc data failed"; | |||
| delete[] new_nhwc_data; | |||
| return RET_ERROR; | |||
| } | |||
| auto nchw_data = reinterpret_cast<float *>(tensor->data.data()); | |||
| // nchw to nhwc | |||
| for (auto i = 0; i < batch; i++) { | |||
| float *src_batch = nchw_data + i * channel * area; | |||
| float *dst_batch = new_nhwc_data + i * channel * area; | |||
| for (int j = 0; j < area; ++j) { | |||
| float *src_area = src_batch + j; | |||
| float *dst_area = dst_batch + j * channel; | |||
| for (int k = 0; k < channel; ++k) { | |||
| dst_area[k] = src_area[k * area]; | |||
| } | |||
| } | |||
| } | |||
| if (memcpy_s(nchw_data, tensor->data.size(), new_nhwc_data, sizeof(float) * size) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed"; | |||
| delete[] new_nhwc_data; | |||
| return RET_ERROR; | |||
| } | |||
| delete[] new_nhwc_data; | |||
| return RET_OK; | |||
| } | |||
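| // Numeric example (editor's illustration) for batch = 1, channel = 2, | |||
| // area = 2: NCHW data {c0a0, c0a1, c1a0, c1a1} becomes NHWC data | |||
| // {c0a0, c1a0, c0a1, c1a1}, i.e. dst[j * channel + k] = src[k * area + j]. | |||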
| STATUS GlobalFormatTransformPass::TransWeightToNhwc(MetaGraphT *graph, const std::set<size_t> &pre_not_trans_nodes) { | |||
| MS_ASSERT(graph != nullptr); | |||
| if (pre_not_trans_nodes.empty()) { | |||
| return RET_OK; | |||
| } | |||
| for (auto index : pre_not_trans_nodes) { | |||
| auto &cur_node = graph->nodes.at(index); | |||
| // ops like Concat and Slice need their axis changed from NCHW to NHWC | |||
| auto ret = ChangeOpAxis(graph, cur_node); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ChangeOpAxis error"; | |||
| return ret; | |||
| } | |||
| auto node_input_indexs = cur_node->inputIndex; | |||
| for (auto input_index : node_input_indexs) { | |||
| // weight data needs to be transformed to NHWC layout | |||
| if (!IsContain(graph->inputIndex, input_index) && | |||
| graph->allTensors.at(input_index)->nodeType == NodeType_ValueNode) { | |||
| auto &weight_tensor = graph->allTensors.at(input_index); | |||
| auto origin_dims = weight_tensor->dims; | |||
| weight_tensor->format = Format_NHWC; | |||
| if (origin_dims.size() > 4) { | |||
| MS_LOG(ERROR) << "tensor origin tensor size error"; | |||
| return RET_ERROR; | |||
| } | |||
| if (origin_dims.empty()) { | |||
| continue; | |||
| } | |||
| auto pad_dims = origin_dims; | |||
| if (origin_dims.size() == 1) { | |||
| pad_dims = {1, 1, 1, origin_dims[0]}; | |||
| } else if (origin_dims.size() == 2) { | |||
| pad_dims = {1, 1, origin_dims[0], origin_dims[1]}; | |||
| } else if (origin_dims.size() == 3) { | |||
| pad_dims = {1, origin_dims[0], origin_dims[1], origin_dims[2]}; | |||
| } | |||
| if (ConvertNcTensor2Nh(weight_tensor.get(), pad_dims) != RET_OK) { | |||
| MS_LOG(ERROR) << "Convert nchw to nhwc failed"; | |||
| return RET_ERROR; | |||
| } | |||
| weight_tensor->dims = {pad_dims[NCHW_N], pad_dims[NCHW_H], pad_dims[NCHW_W], pad_dims[NCHW_C]}; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc2nh_index, | |||
| std::vector<size_t> *pre_nh2nc_nodes, | |||
| std::vector<size_t> *pre_not_trans_nodes) { | |||
| MS_ASSERT(graph != nullptr); | |||
| std::vector<size_t> bfs_queue = {nc2nh_index}; | |||
| // walk the producers (depth-first) looking for nhwc2nchw transpose nodes | |||
| while (!bfs_queue.empty()) { | |||
| auto cur_node_index = bfs_queue.back(); | |||
| auto &cur_node = graph->nodes.at(cur_node_index); | |||
| bfs_queue.pop_back(); | |||
| auto input_node_indexes = GetInputNodeIdx(*graph, *cur_node); | |||
| for (auto input_node_index : input_node_indexes) { | |||
| MS_ASSERT(graph->nodes.size() > input_node_index); | |||
| auto &pre_node = graph->nodes.at(input_node_index); | |||
| MS_ASSERT(pre_node != nullptr); | |||
| auto node_type = pre_node->primitive->value.type; | |||
| if (node_type == schema::PrimitiveType_Transpose && GetTransposePerm(graph, pre_node) == nhwc2nchw_perm) { | |||
| if (!IsContain(*pre_nh2nc_nodes, input_node_index)) { | |||
| pre_nh2nc_nodes->emplace_back(input_node_index); | |||
| } | |||
| } else if (IsContain(GetInsertOpList(), node_type)) { | |||
| if (!IsContain(bfs_queue, input_node_index)) { | |||
| bfs_queue.emplace_back(input_node_index); | |||
| } | |||
| auto pre_node_output_indexs = GetOutputNodeIdx(*graph, *pre_node); | |||
| if (pre_node_output_indexs.size() != 1) { | |||
| if (node_type == schema::PrimitiveType_Activation || node_type == schema::PrimitiveType_Concat) { | |||
| pre_nh2nc_nodes->clear(); | |||
| pre_not_trans_nodes->clear(); | |||
| return RET_OK; | |||
| } | |||
| for (auto pre_node_output_index : pre_node_output_indexs) { | |||
| MS_ASSERT(graph->nodes.size() > pre_node_output_index); | |||
| if (graph->nodes.at(pre_node_output_index)->primitive->value.type == schema::PrimitiveType_PadFusion) { | |||
| pre_nh2nc_nodes->clear(); | |||
| pre_not_trans_nodes->clear(); | |||
| return RET_OK; | |||
| } | |||
| } | |||
| } | |||
| } else { | |||
| pre_nh2nc_nodes->clear(); | |||
| pre_not_trans_nodes->clear(); | |||
| return RET_OK; | |||
| } | |||
| if (!IsContain(*pre_not_trans_nodes, cur_node_index) && cur_node_index != nc2nh_index) { | |||
| pre_not_trans_nodes->emplace_back(cur_node_index); | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,49 +0,0 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H | |||
| #define MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H | |||
| #include <unordered_map> | |||
| #include <set> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/converter/optimizer.h" | |||
| #include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" | |||
| using mindspore::schema::TensorT; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class GlobalFormatTransformPass : public FormatTransPass { | |||
| public: | |||
| GlobalFormatTransformPass() = default; | |||
| ~GlobalFormatTransformPass() override = default; | |||
| STATUS Run(MetaGraphT *graph) override; | |||
| protected: | |||
| STATUS TransWeightToNhwc(MetaGraphT *graph, const std::set<size_t> &pre_not_trans_nodes); | |||
| STATUS FindPreNh2NcNodes(MetaGraphT *graph, size_t nc2nh_index, std::vector<size_t> *to_do_insert_nodes, | |||
| std::vector<size_t> *pre_not_trans_nodes); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_BATCHNORM_GLOBAL_FORMAT_TRANSFORM_PASS_H | |||
| @@ -1,193 +0,0 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h" | |||
| #include "tools/common/node_util.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "src/common/utils.h" | |||
| namespace mindspore { | |||
| namespace { | |||
| std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1}; | |||
| std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2}; | |||
| } // namespace | |||
| namespace lite { | |||
| bool IsInOutCanFusion(schema::MetaGraphT *graph, const std::vector<size_t> &node_indexes, size_t *has_trans_count, | |||
| FormatTransNodeType *trans_type) { | |||
| for (auto input_node_index : node_indexes) { | |||
| MS_ASSERT(graph->nodes.size() > input_node_index); | |||
| auto &pre_node = graph->nodes.at(input_node_index); | |||
| MS_ASSERT(pre_node != nullptr); | |||
| MS_ASSERT(pre_node->primitive != nullptr); | |||
| MS_ASSERT(pre_node->primitive->value != nullptr); | |||
| if (*trans_type == kNONE) { | |||
| if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) { | |||
| auto perm = GetTransposePerm(graph, pre_node); | |||
| if (perm == nchw2nhwc_perm) { | |||
| *trans_type = kNCHW2NHWC; | |||
| } else if (perm == nhwc2nchw_perm) { | |||
| *trans_type = kNHWC2NCHW; | |||
| } else { | |||
| return false; | |||
| } | |||
| (*has_trans_count)++; | |||
| } | |||
| } else { | |||
| if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) { | |||
| auto cur_type = kNONE; | |||
| auto perm = GetTransposePerm(graph, pre_node); | |||
| if (perm == nchw2nhwc_perm) { | |||
| cur_type = kNCHW2NHWC; | |||
| } else if (perm == nhwc2nchw_perm) { | |||
| cur_type = kNHWC2NCHW; | |||
| } else { | |||
| return false; | |||
| } | |||
| if (*trans_type != cur_type) { | |||
| return false; | |||
| } else { | |||
| (*has_trans_count)++; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) { | |||
| MS_ASSERT(graph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| auto input_node_indexes = GetInputNodeIdx(*graph, *node); | |||
| pre_type_ = kNONE; | |||
| size_t has_trans_count = 0; | |||
| if (!IsInOutCanFusion(graph, input_node_indexes, &has_trans_count, &pre_type_)) { | |||
| return false; | |||
| } | |||
| auto output_node_indexes = GetOutputNodeIdx(*graph, *node); | |||
| post_type_ = kNONE; | |||
| if (!IsInOutCanFusion(graph, output_node_indexes, &has_trans_count, &post_type_)) { | |||
| return false; | |||
| } | |||
| if (pre_type_ == kNONE && post_type_ == kNONE) { | |||
| return false; | |||
| } | |||
| auto output_size = output_node_indexes.empty() ? 1 : output_node_indexes.size(); | |||
| auto total_node_count = input_node_indexes.size() + output_size; | |||
| size_t half_count = total_node_count / 2; | |||
| if (GetCNodeTType(*node) == schema::PrimitiveType_Activation) { | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(node->primitive != nullptr); | |||
| MS_ASSERT(node->primitive->value != nullptr); | |||
| MS_ASSERT(node->primitive->value.AsActivation() != nullptr); | |||
| if (node->primitive->value.AsActivation() != nullptr && | |||
| node->primitive->value.AsActivation()->activation_type == schema::ActivationType_LEAKY_RELU) { | |||
| return has_trans_count >= half_count; | |||
| } | |||
| } | |||
| if (GetCNodeTType(*node) == schema::PrimitiveType_Split) { | |||
| return has_trans_count >= half_count; | |||
| } | |||
| return has_trans_count > half_count; | |||
| } | |||
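| // Worked example (editor's illustration): a node with two inputs and one | |||
| // output has total_node_count = 3 and half_count = 1, so fusion normally | |||
| // requires has_trans_count > 1, i.e. at least two of the three neighbours | |||
| // are already matching Transposes; LeakyRelu and Split only require >= 1. | |||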
| STATUS TransOpInsertPass::FindOutTransType() { | |||
| pre_insert_trans_type_ = kNHWC2NCHW; | |||
| post_insert_trans_type_ = kNHWC2NCHW; | |||
| if (pre_type_ == kNONE && post_type_ != kNONE) { | |||
| pre_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC; | |||
| post_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||
| } else if (pre_type_ != kNONE && post_type_ == kNONE) { | |||
| pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||
| post_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC; | |||
| } else if (pre_type_ == kNONE && post_type_ == kNONE) { | |||
| MS_ASSERT(false); | |||
| } else { | |||
| if (pre_type_ == post_type_) { | |||
| MS_LOG(ERROR) << "Unknown error"; | |||
| return RET_ERROR; | |||
| } | |||
| pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||
| post_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||
| } | |||
| return RET_OK; | |||
| } | |||
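| // Decision table implemented above (editor's illustration): | |||
| //   pre_type_    post_type_   -> pre_insert   post_insert | |||
| //   kNONE        kNHWC2NCHW   -> kNHWC2NCHW   kNCHW2NHWC | |||
| //   kNONE        kNCHW2NHWC   -> kNCHW2NHWC   kNHWC2NCHW | |||
| //   kNHWC2NCHW   kNONE        -> kNCHW2NHWC   kNHWC2NCHW | |||
| //   kNCHW2NHWC   kNONE        -> kNHWC2NCHW   kNCHW2NHWC | |||
| //   opposite pair             -> inverse of pre_type_, inverse of post_type_ | |||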
| STATUS TransOpInsertPass::Run(schema::MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| bool changed = true; | |||
| int run_counts = 0; | |||
| std::vector<CNodeT *> has_insert_nodes; | |||
| while (changed && run_counts < 10) { | |||
| changed = false; | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto &node = *iter; | |||
| if (node == nullptr || node->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "node or primitive null"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| auto type = node->primitive->value.type; | |||
| if (IsContain(has_insert_nodes, node.get()) || !IsContain(GetInsertOpList(), type)) { | |||
| continue; | |||
| } | |||
| auto node_name = node->name; | |||
| if (!CanFusion(graph, node)) { | |||
| continue; | |||
| } | |||
| auto ret = FindOutTransType(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "FindOutTransType error"; | |||
| return ret; | |||
| } | |||
| ret = ChangeOpAxis(graph, node); | |||
| if (ret == RET_NOT_SUPPORT) { | |||
| MS_LOG(INFO) << "not support to ChangeOpAxis"; | |||
| return RET_OK; | |||
| } else if (ret != RET_OK) { | |||
| MS_LOG(INFO) << "no need to ChangeOpAxis"; | |||
| return ret; | |||
| } | |||
| has_insert_nodes.push_back(node.get()); | |||
| STATUS status = RET_OK; | |||
| auto input_tensor_size = (*iter)->inputIndex.size(); | |||
| for (size_t i = 0; i < input_tensor_size; i++) { | |||
| auto &input_tensor = graph->allTensors.at((*iter)->inputIndex[i]); | |||
| if (input_tensor->nodeType == NodeType_ValueNode && input_tensor->dims.size() < 4) { | |||
| continue; | |||
| } | |||
| iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type_, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed"; | |||
| return status; | |||
| } | |||
| } | |||
| auto output_tensor_size = (*iter)->outputIndex.size(); | |||
| for (size_t i = 0; i < output_tensor_size; i++) { | |||
| iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type_, &status); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "Insert" << post_insert_trans_type_ << "Node before " << (*iter)->name << " failed"; | |||
| return status; | |||
| } | |||
| } | |||
| changed = true; | |||
| } | |||
| run_counts++; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,52 +0,0 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H | |||
| #define MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TransOpInsertPass : public FormatTransPass { | |||
| public: | |||
| TransOpInsertPass() : FormatTransPass() {} | |||
| ~TransOpInsertPass() override = default; | |||
| STATUS Run(schema::MetaGraphT *graph) override; | |||
| private: | |||
| bool CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node); | |||
| STATUS FindOutTransType(); | |||
| private: | |||
| FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW; | |||
| FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW; | |||
| FormatTransNodeType pre_type_ = kNONE; | |||
| std::vector<int> pre_perm_; | |||
| FormatTransNodeType post_type_ = kNONE; | |||
| std::vector<int> post_perm_; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H | |||
| @@ -1,52 +0,0 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h" | |||
| #include <vector> | |||
| #include "src/common/log_adapter.h" | |||
| #include "include/errorcode.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "src/tensor.h" | |||
| using mindspore::lite::Tensor; | |||
| namespace mindspore { | |||
| namespace { | |||
| std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1}; | |||
| std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2}; | |||
| } // namespace | |||
| namespace lite { | |||
| STATUS TransOpRemovePass::Run(MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto &node = *iter; | |||
| auto type = node->primitive->value.type; | |||
| if (type != schema::PrimitiveType_Transpose) { | |||
| continue; | |||
| } | |||
| auto perm = GetTransposePerm(graph, node); | |||
| if (perm == nchw2nhwc_perm || perm == nhwc2nchw_perm) { | |||
| auto &input_tensor = graph->allTensors.at(node->inputIndex.at(0)); | |||
| // transposes whose input has fewer than 4 dims can be removed | |||
| if (!input_tensor->dims.empty() && input_tensor->dims.size() < 4) { | |||
| auto status = IsolateOneWayNode(graph, node.get(), true); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << node->name.c_str() << ", error: " << status; | |||
| return status; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,40 +0,0 @@ | |||
| /** | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_TRANS_FORMAT_REMOVE_PASS_H | |||
| #define MINDSPORE_PREDICT_TRANS_FORMAT_REMOVE_PASS_H | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/converter/optimizer.h" | |||
| using mindspore::schema::TensorT; | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class TransOpRemovePass : public GraphPass { | |||
| public: | |||
| TransOpRemovePass() = default; | |||
| ~TransOpRemovePass() override = default; | |||
| STATUS Run(MetaGraphT *graph) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_TRANS_FORMAT_REMOVE_PASS_H | |||
| @@ -28,14 +28,15 @@ | |||
| #include "ops/concat.h" | |||
| #include "ops/crop.h" | |||
| #include "ops/depth_to_space.h" | |||
| #include "ops/fused_batch_norm.h" | |||
| #include "ops/fusion/activation.h" | |||
| #include "ops/fusion/add_fusion.h" | |||
| #include "ops/fused_batch_norm.h" | |||
| #include "ops/fusion/avg_pool_fusion.h" | |||
| #include "ops/fusion/conv2d_backprop_input_fusion.h" | |||
| #include "ops/fusion/conv2d_backprop_filter_fusion.h" | |||
| #include "ops/fusion/conv2d_fusion.h" | |||
| #include "ops/fusion/conv2d_transpose_fusion.h" | |||
| #include "ops/fusion/div_fusion.h" | |||
| #include "ops/fusion/max_pool_fusion.h" | |||
| #include "ops/fusion/mul_fusion.h" | |||
| #include "ops/fusion/pow_fusion.h" | |||
| @@ -61,6 +62,7 @@ | |||
| #include "ops/space_to_depth.h" | |||
| #include "ops/split.h" | |||
| #include "ops/strided_slice.h" | |||
| #include "tools/anf_exporter/fetch_content.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -96,9 +98,9 @@ static const std::unordered_map<std::string, std::vector<size_t>> NCHWOpMap = {{ | |||
| // a certain op whose input's format is not fixed. | |||
| static const std::vector<std::string> DynamicFormatOpList = { | |||
| ops::kNameEltwise, ops::kNameActivation, ops::kNameConcat, ops::kNamePowFusion, ops::kNameStridedSlice, | |||
| ops::kNameAddFusion, ops::kNameAddN, ops::kNameSplit, ops::kNameSliceFusion, ops::kNameCrop, | |||
| ops::kNameMulFusion, ops::kNameMaximum, ops::kNameActivationGrad, ops::kNameQuantDTypeCast}; | |||
| ops::kNameEltwise, ops::kNameActivation, ops::kNameConcat, ops::kNameDivFusion, ops::kNamePowFusion, | |||
| ops::kNameStridedSlice, ops::kNameAddFusion, ops::kNameAddN, ops::kNameSplit, ops::kNameSliceFusion, | |||
| ops::kNameCrop, ops::kNameMulFusion, ops::kNameMaximum, ops::kNameActivationGrad, ops::kNameQuantDTypeCast}; | |||
| static const std::unordered_map<int, int> NC2NHAxisMap = {{0, 0}, {1, 3}, {2, 1}, {3, 2}}; | |||
| @@ -120,33 +122,34 @@ Format GetFormat(const CNodePtr &cnode) { | |||
| return format; | |||
| } | |||
| STATUS GetTransposePerm(const AnfNodePtr &perm_node, std::vector<int> *perm) { | |||
| STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| if (!utils::isa<ParameterPtr>(perm_node)) { | |||
| return lite::RET_OK; | |||
| if (cnode->size() != 3) { | |||
| MS_LOG(ERROR) << "transpose op input size must be three."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto perm_param = perm_node->cast<ParameterPtr>(); | |||
| if (!perm_param->has_default() || perm_param->default_param() == nullptr) { | |||
| if (utils::isa<CNodePtr>(cnode->input(2))) { | |||
| return lite::RET_OK; | |||
| } | |||
| auto tensor_info = perm_param->default_param()->cast<tensor::TensorPtr>(); | |||
| if (tensor_info == nullptr) { | |||
| MS_LOG(ERROR) << "default param is not a tensor."; | |||
| return lite::RET_ERROR; | |||
| lite::DataInfo data_info; | |||
| int status; | |||
| if (utils::isa<ParameterPtr>(cnode->input(2))) { | |||
| status = lite::FetchDataFromParameterNode(cnode, 2, lite::converter::FmkType_MS, false, &data_info); | |||
| } else { | |||
| status = lite::FetchDataFromValueNode(cnode, 2, lite::converter::FmkType_MS, false, &data_info); | |||
| } | |||
| if (tensor_info->data_type() != kNumberTypeInt && tensor_info->data_type() != kNumberTypeInt32) { | |||
| MS_LOG(ERROR) << "data type is error, which is " << tensor_info->data_type(); | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "fetch transpose perm data failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto tensor_shape = tensor_info->shape(); | |||
| if (tensor_shape.empty()) { | |||
| return lite::RET_OK; | |||
| } | |||
| if (tensor_shape.size() > 1) { | |||
| if ((data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32) || | |||
| data_info.shape_.size() != 1) { | |||
| MS_LOG(ERROR) << "transpose perm data is invalid."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| perm->resize(tensor_shape[0]); | |||
| if (memcpy_s(perm->data(), tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != EOK) { | |||
| perm->resize(data_info.shape_[0]); | |||
| if (!data_info.data_.empty() && | |||
| memcpy_s(perm->data(), data_info.data_.size(), data_info.data_.data(), data_info.data_.size()) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy data failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| @@ -38,7 +38,7 @@ const std::unordered_map<std::string, std::vector<size_t>> &GetNCHWOpMap(); | |||
| const std::unordered_map<int, int> &GetNC2NHAxisMap(); | |||
| const std::vector<std::string> &GetDynamicFormatOpList(); | |||
| Format GetFormat(const CNodePtr &cnode); | |||
| STATUS GetTransposePerm(const AnfNodePtr &perm_node, std::vector<int> *perm); | |||
| STATUS GetTransposePerm(const CNodePtr &cnode, std::vector<int> *perm); | |||
| void RemoveIfMonad(const CNodePtr &cnode); | |||
| bool IsMonadNode(const AnfNodePtr &node); | |||
| } // namespace opt | |||
| @@ -15,12 +15,13 @@ | |||
| */ | |||
| #include "tools/optimizer/fusion/constant_folding_fusion.h" | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "tools/anf_exporter/fetch_content.h" | |||
| #include "tools/converter/quant_param_holder.h" | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| #include "tools/common/node_util.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "src/common/common.h" | |||
| @@ -36,47 +37,79 @@ using mindspore::lite::Tensor; | |||
| namespace mindspore::opt { | |||
| namespace { | |||
| constexpr size_t INITIAL_SIZE = 1024; | |||
| std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) { | |||
| void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *output_tensor) { | |||
| if (input_tensor != nullptr) { | |||
| for (auto &i : *input_tensor) { | |||
| delete i; | |||
| i = nullptr; | |||
| } | |||
| } | |||
| if (output_tensor != nullptr) { | |||
| for (auto &i : *output_tensor) { | |||
| delete i; | |||
| i = nullptr; | |||
| } | |||
| } | |||
| } | |||
| std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &cnode, lite::converter::FmkType fmk_type) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto tmp_meta_graph = std::make_unique<schema::MetaGraphT>(); | |||
| auto tmp_fb_node = std::make_unique<schema::CNodeT>(); | |||
| lite::AnfExporter anfExporter; | |||
| anfExporter.SetOpInputNode(CNode, tmp_meta_graph, tmp_fb_node.get()); | |||
| std::vector<Tensor *> input_tensors; | |||
| for (auto input_index : tmp_fb_node->inputIndex) { | |||
| auto tensorT = tmp_meta_graph->allTensors.at(input_index).get(); | |||
| auto tensor_shape = tensorT->dims; | |||
| auto lite_tensor = new (std::nothrow) Tensor( | |||
| TypeId(tensorT->dataType), tensor_shape, tensorT->format, | |||
| lite::TensorCategory(tensorT->nodeType, tensorT->dims.size(), TypeId(tensorT->dataType), tensorT->data.size())); | |||
| if (lite_tensor == nullptr) { | |||
| MS_LOG(ERROR) << "lite tensor is nullptr"; | |||
| return input_tensors; | |||
| std::vector<Tensor *> tensors; | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| int status; | |||
| lite::DataInfo data_info; | |||
| if (utils::isa<ParameterPtr>(cnode->input(i))) { | |||
| if (!cnode->input(i)->cast<ParameterPtr>()->has_default()) { | |||
| FreeTensors(&tensors, nullptr); | |||
| return {}; | |||
| } | |||
| status = lite::FetchDataFromParameterNode(cnode, i, fmk_type, false, &data_info); | |||
| } else if (utils::isa<ValueNodePtr>(cnode->input(i))) { | |||
| status = lite::FetchDataFromValueNode(cnode, i, fmk_type, false, &data_info); | |||
| } else { | |||
| MS_LOG(ERROR) << "input node is not const node."; | |||
| FreeTensors(&tensors, nullptr); | |||
| return {}; | |||
| } | |||
| auto lite_tensor_size = tensorT->data.size() * sizeof(uint8_t); | |||
| // when tensorT is a graph input | |||
| if (lite_tensor_size <= 0) { | |||
| delete lite_tensor; | |||
| return input_tensors; | |||
| if (status == lite::RET_NO_CHANGE) { | |||
| continue; | |||
| } | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "parser const data failed."; | |||
| FreeTensors(&tensors, nullptr); | |||
| return {}; | |||
| } | |||
| if (data_info.shape_.empty() && data_info.data_.empty()) { | |||
| FreeTensors(&tensors, nullptr); | |||
| MS_LOG(DEBUG) << "input node is graph input."; | |||
| return {}; | |||
| } | |||
| auto tensor_data = new (std::nothrow) uint8_t[lite_tensor_size / sizeof(char)]; | |||
| auto tensor = new (std::nothrow) | |||
| Tensor(TypeId(data_info.data_type_), data_info.shape_, schema::Format(data_info.format_), | |||
| lite::TensorCategory(0, data_info.shape_.size(), TypeId(data_info.data_type_), data_info.data_.size())); | |||
| if (tensor == nullptr) { | |||
| MS_LOG(ERROR) << "new a tensor is nullptr."; | |||
| FreeTensors(&tensors, nullptr); | |||
| return {}; | |||
| } | |||
| if (data_info.data_.empty()) { | |||
| tensors.emplace_back(tensor); | |||
| continue; | |||
| } | |||
| auto tensor_data = tensor->MutableData(); | |||
| if (tensor_data == nullptr) { | |||
| MS_LOG(ERROR) << "tensor_data is nullptr"; | |||
| delete lite_tensor; | |||
| return input_tensors; | |||
| MS_LOG(ERROR) << "malloc data failed."; | |||
| FreeTensors(&tensors, nullptr); | |||
| return {}; | |||
| } | |||
| auto ret = memcpy_s(tensor_data, lite_tensor_size, tensorT->data.data(), lite_tensor_size); | |||
| if (ret != EOK) { | |||
| delete lite_tensor; | |||
| delete[](tensor_data); | |||
| MS_LOG(ERROR) << "memcpy error: " << ret; | |||
| lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED); | |||
| if (memcpy_s(tensor_data, data_info.data_.size(), data_info.data_.data(), data_info.data_.size()) != EOK) { | |||
| MS_LOG(ERROR) << "memcpy data failed."; | |||
| FreeTensors(&tensors, nullptr); | |||
| return {}; | |||
| } | |||
| lite_tensor->set_data(tensor_data); | |||
| input_tensors.emplace_back(lite_tensor); | |||
| tensors.emplace_back(tensor); | |||
| } | |||
| return input_tensors; | |||
| return tensors; | |||
| } | |||
| ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { | |||
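| // Editor's note: the vector returned by GetCNodeInputTensors owns raw Tensor | |||
| // pointers, so every caller path must release them with FreeTensors. A minimal | |||
| // usage sketch, mirroring ConstFoldPass::Process later in this diff | |||
| // (FoldOneNode is a hypothetical wrapper): | |||
| void FoldOneNode(const CNodePtr &cnode, lite::converter::FmkType fmk_type) { | |||
| auto input_tensors = GetCNodeInputTensors(cnode, fmk_type); | |||
| if (input_tensors.empty()) { | |||
| return;  // not foldable; nothing was allocated | |||
| } | |||
| std::vector<Tensor *> output_tensors; | |||
| // ... schedule and run the lite kernel here ... | |||
| FreeTensors(&input_tensors, &output_tensors);  // release on every exit path | |||
| } | |||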
| @@ -229,21 +262,6 @@ lite::STATUS CopyQuantParams(const CNodePtr &cnode, const std::vector<Tensor *> | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| void FreeTensors(std::vector<Tensor *> *input_tensor, std::vector<Tensor *> *output_tensor) { | |||
| if (input_tensor != nullptr) { | |||
| for (auto &i : *input_tensor) { | |||
| delete i; | |||
| i = nullptr; | |||
| } | |||
| } | |||
| if (output_tensor != nullptr) { | |||
| for (auto &i : *output_tensor) { | |||
| delete i; | |||
| i = nullptr; | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| @@ -263,9 +281,8 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An | |||
| continue; | |||
| } | |||
| auto input_cnode = input_node->cast<CNodePtr>(); | |||
| auto input_tensors = GetCNodeInputTensors(input_cnode); | |||
| if (input_tensors.empty() || input_tensors.size() != input_cnode->inputs().size() - 1) { | |||
| FreeTensors(&input_tensors, nullptr); | |||
| auto input_tensors = GetCNodeInputTensors(input_cnode, fmk_type_); | |||
| if (input_tensors.empty()) { | |||
| continue; | |||
| } | |||
| changed = true; | |||
| @@ -279,7 +296,7 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An | |||
| FreeTensors(&input_tensors, &output_tensors); | |||
| return nullptr; | |||
| } | |||
| auto lite_kernel = GetLiteKernel(input_tensors, &output_tensors, input_cnode, context.get()); | |||
| auto lite_kernel = GetLiteKernel(input_tensors, &output_tensors, input_cnode, context_.get()); | |||
| if (lite_kernel == nullptr) { | |||
| FreeTensors(&input_tensors, &output_tensors); | |||
| MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr"; | |||
| @@ -24,18 +24,23 @@ | |||
| #include "src/lite_kernel.h" | |||
| #include "nnacl/op_base.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class ConstFoldPass : public PatternProcessPass { | |||
| public: | |||
| explicit ConstFoldPass(std::shared_ptr<lite::InnerContext> context_ptr = nullptr, bool multigraph = true) | |||
| : PatternProcessPass("constfold_pass", multigraph), context(std::move(context_ptr)) {} | |||
| explicit ConstFoldPass(lite::converter::FmkType fmk_type = lite::converter::FmkType_MS, bool multigraph = true) | |||
| : PatternProcessPass("constfold_pass", multigraph), fmk_type_(fmk_type) { | |||
| context_ = std::make_shared<lite::InnerContext>(); | |||
| context_->Init(); | |||
| } | |||
| ~ConstFoldPass() override = default; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| std::shared_ptr<lite::InnerContext> context; | |||
| lite::converter::FmkType fmk_type_{lite::converter::FmkType_MS}; | |||
| std::shared_ptr<lite::InnerContext> context_{nullptr}; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
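| // Editor's note: after this interface change the pass is constructed from a | |||
| // framework type instead of a shared InnerContext; a hedged usage sketch | |||
| // (the TFLite fmk value is only an example): | |||
| //   auto const_fold_pass = std::make_shared<opt::ConstFoldPass>(lite::converter::FmkType_TFLITE); | |||
| // The pass now builds and Init()s its own lite::InnerContext internally, so | |||
| // callers no longer need to create or share one. | |||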
| @@ -18,9 +18,9 @@ | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| #include "tools/common/node_util.h" | |||
| #include "tools/common/tensor_util.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/ops/populate/populate_register.h" | |||
| #include "src/ops/ops_utils.h" | |||
| #include "src/runtime/infer_manager.h" | |||
| @@ -67,8 +67,8 @@ bool DuceInferFlag(const CNodePtr &cnode, const std::vector<lite::Tensor *> &inp | |||
| } | |||
| } | |||
| auto origin_inputs = cnode->inputs(); | |||
| lite::AnfExporter::RemoveIfDepend(cnode); | |||
| lite::AnfExporter::RemoveIfMakeTuple(cnode); | |||
| lite::RemoveIfDepend(cnode); | |||
| lite::RemoveIfMakeTuple(cnode); | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| if (!utils::isa<CNodePtr>(cnode->input(i))) { | |||
| continue; | |||
| @@ -241,8 +241,8 @@ STATUS NodeInferShape::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l | |||
| MS_ASSERT(cnode != nullptr); | |||
| MS_ASSERT(inputs != nullptr); | |||
| auto origin_inputs = cnode->inputs(); | |||
| lite::AnfExporter::RemoveIfDepend(cnode); | |||
| lite::AnfExporter::RemoveIfMakeTuple(cnode); | |||
| lite::RemoveIfDepend(cnode); | |||
| lite::RemoveIfMakeTuple(cnode); | |||
| RemoveIfMonad(cnode); | |||
| std::vector<lite::Tensor *> const_inputs; | |||
| if (GetCNodeConstInput(cnode, &const_inputs) != lite::RET_OK) { | |||
| @@ -288,28 +288,29 @@ STATUS NodeInferShape::GetCNodeInputTensors(const CNodePtr &cnode, std::vector<l | |||
| } | |||
| STATUS NodeInferShape::GetCNodeConstInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *const_ms_inputs) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| auto origin_inputs = cnode->inputs(); | |||
| std::vector<AnfNodePtr> const_inputs; | |||
| for (auto &input : origin_inputs) { | |||
| if (utils::isa<CNodePtr>(input)) { | |||
| MS_ASSERT(cnode != nullptr && const_ms_inputs != nullptr); | |||
| std::vector<lite::DataInfo> data_infos; | |||
| for (size_t i = 1; i < cnode->size(); ++i) { | |||
| if (utils::isa<CNodePtr>(cnode->input(i))) { | |||
| continue; | |||
| } | |||
| const_inputs.push_back(input); | |||
| } | |||
| cnode->set_inputs(const_inputs); | |||
| auto meta_graph = std::make_unique<schema::MetaGraphT>(); | |||
| meta_graph->fmkType = fmk_type_; | |||
| auto fb_node = std::make_unique<schema::CNodeT>(); | |||
| lite::AnfExporter anf_exporter; | |||
| anf_exporter.set_train_flag(train_flag_); | |||
| auto status = anf_exporter.SetOpInputNode(cnode, meta_graph, fb_node.get()); | |||
| cnode->set_inputs(origin_inputs); | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "get const inputs failed."; | |||
| return status; | |||
| STATUS status; | |||
| lite::DataInfo data_info; | |||
| if (utils::isa<ParameterPtr>(cnode->input(i))) { | |||
| status = lite::FetchDataFromParameterNode(cnode, i, fmk_type_, train_flag_, &data_info); | |||
| } else { | |||
| status = lite::FetchDataFromValueNode(cnode, i, fmk_type_, train_flag_, &data_info); | |||
| } | |||
| if (status == lite::RET_NO_CHANGE) { | |||
| continue; | |||
| } | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "fetch const input data failed."; | |||
| return status; | |||
| } | |||
| data_infos.emplace_back(data_info); | |||
| } | |||
| return ConvertToLiteTensor(meta_graph, fb_node->inputIndex, const_ms_inputs); | |||
| return ConvertToLiteTensor(data_infos, const_ms_inputs); | |||
| } | |||
| STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *var_ms_inputs) { | |||
| @@ -319,29 +320,16 @@ STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite: | |||
| if (!utils::isa<CNodePtr>(cnode->input(i))) { | |||
| continue; | |||
| } | |||
| auto abstract = GetCNodeInputAbstract(cnode, i); | |||
| if (abstract == nullptr) { | |||
| MS_LOG(ERROR) << "Abstract cnode is nullptr."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) { | |||
| MS_LOG(ERROR) << "Abstract should be anstract tensor."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract); | |||
| auto type_ptr = abstract_tensor->element()->GetTypeTrack(); | |||
| MS_ASSERT(typePtr != nullptr); | |||
| if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) { | |||
| MS_LOG(ERROR) << "Shape of Abstract should be ShapePtr."; | |||
| lite::DataInfo data_info; | |||
| if (lite::FetchDataFromCNode(cnode, i, fmk_type_, train_flag_, &data_info) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "parse cnode failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto shape_vector = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape(); | |||
| std::vector<int32_t> dims(shape_vector.begin(), shape_vector.end()); | |||
| lite::Tensor *tensor = nullptr; | |||
| if (type_ptr->type_id() == kObjectTypeTensorType) { | |||
| tensor = GetCNodeTensorListVarInput(dims, abstract_tensor); | |||
| if (data_info.data_type_ == kObjectTypeTensorType) { | |||
| tensor = GetCNodeTensorListVarInput(data_info); | |||
| } else { | |||
| tensor = new (std::nothrow) lite::Tensor(TypeId(type_ptr->type_id()), dims); | |||
| tensor = new (std::nothrow) lite::Tensor(TypeId(data_info.data_type_), data_info.shape_); | |||
| } | |||
| if (tensor == nullptr) { | |||
| MS_LOG(ERROR) << "new a lite tensor failed"; | |||
| @@ -352,27 +340,16 @@ STATUS NodeInferShape::GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite: | |||
| return lite::RET_OK; | |||
| } | |||
| lite::Tensor *NodeInferShape::GetCNodeTensorListVarInput(std::vector<int> shape, | |||
| const abstract::AbstractTensorPtr &abstract_tensor) { | |||
| MS_ASSERT(abstract_tensor != nullptr); | |||
| auto tensor_list = new (std::nothrow) lite::TensorList(shape, {}); | |||
| lite::Tensor *NodeInferShape::GetCNodeTensorListVarInput(const lite::DataInfo &data_info) { | |||
| auto tensor_list = new (std::nothrow) lite::TensorList(data_info.shape_, {}); | |||
| if (tensor_list == nullptr) { | |||
| MS_LOG(ERROR) << "new a lite tensor list failed"; | |||
| return nullptr; | |||
| } | |||
| auto tensor_info = abstract_tensor->GetValueTrack(); | |||
| if (tensor_info == nullptr || !utils::isa<tensor::TensorPtr>(tensor_info)) { | |||
| delete tensor_list; | |||
| MS_LOG(ERROR) << "nsor list abstract is invalid."; | |||
| return nullptr; | |||
| } | |||
| auto tensor_value = tensor_info->cast<tensor::TensorPtr>(); | |||
| if (tensor_value->data_c() == nullptr) { | |||
| delete tensor_list; | |||
| MS_LOG(ERROR) << "cannot get tensor list abstract's info."; | |||
| return nullptr; | |||
| if (data_info.data_.empty()) { | |||
| return tensor_list; | |||
| } | |||
| auto status = tensor_list->Decode(static_cast<int *>(tensor_value->data_c())); | |||
| auto status = tensor_list->Decode(reinterpret_cast<const int *>(data_info.data_.data())); | |||
| if (status != lite::RET_OK) { | |||
| delete tensor_list; | |||
| MS_LOG(ERROR) << "decode tensor list failed."; | |||
| @@ -384,41 +361,78 @@ lite::Tensor *NodeInferShape::GetCNodeTensorListVarInput(std::vector<int> shape, | |||
| STATUS NodeInferShape::GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *outputs) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| MS_ASSERT(outputs != nullptr); | |||
| auto meta_graph = std::make_unique<schema::MetaGraphT>(); | |||
| meta_graph->fmkType = fmk_type_; | |||
| auto fb_node = std::make_unique<schema::CNodeT>(); | |||
| lite::AnfExporter anf_exporter; | |||
| anf_exporter.set_train_flag(train_flag_); | |||
| anf_exporter.SetOpOutputNode(cnode, meta_graph, fb_node.get()); | |||
| return ConvertToLiteTensor(meta_graph, fb_node->outputIndex, outputs); | |||
| std::vector<lite::DataInfo> data_infos; | |||
| if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) { | |||
| auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract()); | |||
| if (tuple == nullptr) { | |||
| MS_LOG(ERROR) << "tuple is nullptr"; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto elements = tuple->elements(); | |||
| for (size_t i = 0; i < elements.size(); i++) { | |||
| lite::DataInfo data_info; | |||
| data_info.node_type_ = lite::NodeType_CNode; | |||
| if (train_flag_) { | |||
| data_infos.emplace_back(data_info); | |||
| if (CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || CheckPrimitiveType(cnode, prim::kPrimAdam)) { | |||
| break; | |||
| } | |||
| } else { | |||
| if (!utils::isa<abstract::AbstractTensorPtr>(elements[i])) { | |||
| MS_LOG(ERROR) << "abstract is not AbstractTensor"; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto type = kNumberTypeFloat32; | |||
| if (utils::isa<abstract::AbstractTensorPtr>(elements[i])) { | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(elements[i]); | |||
| auto typePtr = abstract_tensor->element()->GetTypeTrack(); | |||
| type = typePtr->type_id(); | |||
| } | |||
| data_info.data_type_ = type; | |||
| data_infos.emplace_back(data_info); | |||
| if (CheckPrimitiveType(cnode, prim::kPrimConv2DFusion) || | |||
| CheckPrimitiveType(cnode, prim::kPrimFusedBatchNorm)) { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| } else { | |||
| lite::DataInfo data_info; | |||
| auto type = kNumberTypeFloat32; | |||
| if (utils::isa<abstract::AbstractTensorPtr>(cnode->abstract())) { | |||
| auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(cnode->abstract()); | |||
| auto typePtr = abstract_tensor->element()->GetTypeTrack(); | |||
| type = typePtr->type_id(); | |||
| } | |||
| data_info.data_type_ = type; | |||
| data_info.node_type_ = lite::NodeType_CNode; | |||
| data_infos.emplace_back(data_info); | |||
| } | |||
| return ConvertToLiteTensor(data_infos, outputs); | |||
| } | |||
| STATUS NodeInferShape::ConvertToLiteTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | |||
| const std::vector<uint32_t> &tensor_indexes, | |||
| STATUS NodeInferShape::ConvertToLiteTensor(const std::vector<lite::DataInfo> &data_infos, | |||
| std::vector<lite::Tensor *> *tensors) { | |||
| MS_ASSERT(meta_graph != nullptr); | |||
| MS_ASSERT(tensors != nullptr); | |||
| for (auto index : tensor_indexes) { | |||
| auto tensor_t = meta_graph->allTensors.at(index).get(); | |||
| auto tensor_shape = tensor_t->dims; | |||
| auto tensor_category = lite::TensorCategory(tensor_t->nodeType, tensor_t->dims.size(), TypeId(tensor_t->dataType), | |||
| tensor_t->data.size()); | |||
| for (auto &data_info : data_infos) { | |||
| auto tensor_category = lite::TensorCategory(lite::NodeType(data_info.node_type_), data_info.shape_.size(), | |||
| TypeId(data_info.data_type_), data_info.data_.size()); | |||
| lite::Tensor *tensor = nullptr; | |||
| if (tensor_t->dataType != kObjectTypeTensorType) { | |||
| tensor = | |||
| new (std::nothrow) lite::Tensor(TypeId(tensor_t->dataType), tensor_shape, tensor_t->format, tensor_category); | |||
| if (data_info.data_type_ != kObjectTypeTensorType) { | |||
| tensor = new (std::nothrow) lite::Tensor(TypeId(data_info.data_type_), data_info.shape_, | |||
| (schema::Format)data_info.format_, tensor_category); | |||
| } else { | |||
| tensor = new (std::nothrow) lite::TensorList(tensor_shape, std::vector<int>(), tensor_category); | |||
| tensor = new (std::nothrow) lite::TensorList(data_info.shape_, std::vector<int>(), tensor_category); | |||
| } | |||
| if (tensor == nullptr) { | |||
| MS_LOG(ERROR) << "new a lite tensor failed"; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto tensor_size = tensor_t->data.size() * sizeof(char); | |||
| auto tensor_size = data_info.data_.size(); | |||
| if (tensor_size > 0) { | |||
| if (tensor_t->dataType == kObjectTypeTensorType) { | |||
| if (data_info.data_type_ == kObjectTypeTensorType) { | |||
| auto tensor_list = reinterpret_cast<lite::TensorList *>(tensor); | |||
| if (tensor_list->Decode(reinterpret_cast<const int *>(tensor_t->data.data())) != RET_OK) { | |||
| if (tensor_list->Decode(reinterpret_cast<const int *>(data_info.data_.data())) != RET_OK) { | |||
| MS_LOG(ERROR) << "Decode tensorlist data failed"; | |||
| return RET_ERROR; | |||
| } | |||
| @@ -429,7 +443,7 @@ STATUS NodeInferShape::ConvertToLiteTensor(const std::unique_ptr<schema::MetaGra | |||
| delete tensor; | |||
| return lite::RET_ERROR; | |||
| } | |||
| if (memcpy_s(tensor_data, tensor_size, tensor_t->data.data(), tensor_size) != EOK) { | |||
| if (memcpy_s(tensor_data, tensor_size, data_info.data_.data(), tensor_size) != EOK) { | |||
| delete tensor; | |||
| delete[](tensor_data); | |||
| MS_LOG(ERROR) << "memcpy error: "; | |||
| @@ -22,6 +22,7 @@ | |||
| #include <string> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "src/tensor.h" | |||
| #include "tools/anf_exporter/fetch_content.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/optimizer/common/format_utils.h" | |||
| @@ -44,10 +45,9 @@ class NodeInferShape { | |||
| STATUS GetCNodeInputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *inputs); | |||
| STATUS GetCNodeConstInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *const_ms_inputs); | |||
| STATUS GetCNodeVarInput(const CNodePtr &cnode, std::vector<lite::Tensor *> *var_ms_inputs); | |||
| lite::Tensor *GetCNodeTensorListVarInput(std::vector<int> shape, const abstract::AbstractTensorPtr &abstract_tensor); | |||
| lite::Tensor *GetCNodeTensorListVarInput(const lite::DataInfo &data_info); | |||
| STATUS GetCNodeOutputTensors(const CNodePtr &cnode, std::vector<lite::Tensor *> *outputs); | |||
| STATUS ConvertToLiteTensor(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | |||
| const std::vector<uint32_t> &tensor_indexes, std::vector<lite::Tensor *> *tensors); | |||
| STATUS ConvertToLiteTensor(const std::vector<lite::DataInfo> &data_infos, std::vector<lite::Tensor *> *tensors); | |||
| STATUS SetCNodeAbstract(const std::shared_ptr<CNode> &cnode, const std::vector<lite::Tensor *> &outputs); | |||
| abstract::AbstractBasePtr ConvertLiteTensorToAbstract(lite::Tensor *tensor); | |||
| abstract::AbstractBasePtr ConvertTensorListToAbstract(lite::Tensor *tensor); | |||
| @@ -67,7 +67,7 @@ AnfNodePtr TransposeStrategy::TransposePairFuseWhenInsert(const FuncGraphPtr &fu | |||
| MS_LOG(ERROR) << "input node is invalid."; | |||
| return nullptr; | |||
| } | |||
| if (GetTransposePerm(input_cnode->input(kTransposePerm), &trans_perm) != lite::RET_OK) { | |||
| if (GetTransposePerm(input_cnode, &trans_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "transpose perm get failed."; | |||
| return nullptr; | |||
| } | |||
| @@ -142,8 +142,40 @@ bool TransposeStrategy::CanFusionIfInsert(const FuncGraphPtr &func_graph, const | |||
| return can_insert; | |||
| } | |||
| bool TransposeStrategy::CanChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| auto shape = node_infer_shape_.GetInputShape(cnode, 1); | |||
| if (shape.size() != 4) { | |||
| if (cnode->size() > 2) { | |||
| shape = node_infer_shape_.GetInputShape(cnode, 2); | |||
| if (shape.size() != 4 && !shape.empty()) { | |||
| return false; | |||
| } | |||
| } else { | |||
| return false; | |||
| } | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimConcat) || CheckPrimitiveType(cnode, prim::kPrimSplit)) { | |||
| auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||
| if (prim->GetAttr(ops::kAxis) == nullptr) { | |||
| return false; | |||
| } | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion) || CheckPrimitiveType(cnode, prim::kPrimStridedSlice)) { | |||
| for (size_t i = 2; i < cnode->size(); ++i) { | |||
| if (utils::isa<CNodePtr>(cnode->input(i))) { | |||
| return false; | |||
| } | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice) && cnode->size() != kOnnxStridedSlice) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| auto shape = node_infer_shape_.GetInputShape(cnode, 1); | |||
| if (shape.size() != 4) { | |||
| if (cnode->size() > 2) { | |||
| @@ -180,6 +212,7 @@ STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNo | |||
| } else { | |||
| offsets.push_back(0); | |||
| } | |||
| crop_prim->set_axis(new_axis); | |||
| crop_prim->set_offsets(offsets); | |||
| } | |||
| if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion)) { | |||
| @@ -231,7 +264,7 @@ bool TransposeStrategy::IsInOutCanFuison(const FuncGraphPtr &func_graph, const s | |||
| if (cnode == nullptr) { | |||
| return false; | |||
| } | |||
| if (GetTransposePerm(cnode->input(kTransposePerm), &perm) != lite::RET_OK) { | |||
| if (GetTransposePerm(cnode, &perm) != lite::RET_OK) { | |||
| return false; | |||
| } | |||
| if (perm == NH2NC) { | |||
| @@ -44,6 +44,7 @@ class TransposeStrategy { | |||
| bool CanFusionIfInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_info, | |||
| TransTypePair *trans_insert_info); | |||
| STATUS ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| bool CanChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| private: | |||
| STATUS TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool before, size_t index); | |||
| @@ -15,12 +15,14 @@ | |||
| */ | |||
| #include "tools/optimizer/graph/unify_format_pass.h" | |||
| #include <queue> | |||
| #include <set> | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include "ops/op_utils.h" | |||
| #include "src/common/common.h" | |||
| #include "src/common/utils.h" | |||
| #include "tools/anf_exporter/anf_exporter.h" | |||
| #include "tools/common/tensor_util.h" | |||
| using mindspore::lite::NCHW_SHAPE; | |||
| namespace mindspore { | |||
| @@ -37,6 +39,173 @@ bool IsSpecialType(const CNodePtr &cnode) { | |||
| } | |||
| return false; | |||
| } | |||
| STATUS FindAreaSurroundedByTranspose(const FuncGraphPtr &func_graph, const CNodePtr &root_node, | |||
| std::set<CNodePtr> *in_nodes, std::set<CNodePtr> *out_nodes, | |||
| std::set<CNodePtr> *middle_nodes) { | |||
| MS_ASSERT(func_graph != nullptr && root_node != nullptr); | |||
| MS_ASSERT(in_nodes != nullptr && out_nodes != nullptr && middle_nodes != nullptr); | |||
| std::queue<CNodePtr> queue_nodes; | |||
| queue_nodes.push(root_node); | |||
| std::queue<bool> is_pre_nodes; | |||
| is_pre_nodes.push(true); | |||
| while (!queue_nodes.empty()) { | |||
| auto cur_node = queue_nodes.front(); | |||
| auto is_pre_node = is_pre_nodes.front(); | |||
| queue_nodes.pop(); | |||
| is_pre_nodes.pop(); | |||
| if (CheckPrimitiveType(cur_node, prim::kPrimTranspose)) { | |||
| if (is_pre_node) { | |||
| in_nodes->insert(cur_node); | |||
| } else { | |||
| out_nodes->insert(cur_node); | |||
| continue; | |||
| } | |||
| } | |||
| if (middle_nodes->find(cur_node) != middle_nodes->end()) { | |||
| continue; | |||
| } | |||
| if (in_nodes->find(cur_node) == in_nodes->end()) { | |||
| middle_nodes->insert(cur_node); | |||
| // insert pre nodes. | |||
| auto origin_inputs = cur_node->inputs(); | |||
| lite::RemoveIfDepend(cur_node); | |||
| for (size_t i = 1; i < cur_node->size(); ++i) { | |||
| if (!utils::isa<CNodePtr>(cur_node->input(i))) { | |||
| continue; | |||
| } | |||
| auto cur_node_input = cur_node->input(i)->cast<CNodePtr>(); | |||
| if (middle_nodes->find(cur_node_input) != middle_nodes->end() || | |||
| in_nodes->find(cur_node_input) != in_nodes->end()) { | |||
| continue; | |||
| } | |||
| queue_nodes.push(cur_node_input); | |||
| is_pre_nodes.push(true); | |||
| } | |||
| if (CheckIsAllInputsParam(cur_node)) { | |||
| in_nodes->insert(cur_node); | |||
| } | |||
| cur_node->set_inputs(origin_inputs); | |||
| } | |||
| // insert post nodes | |||
| auto cur_node_users = func_graph->manager()->node_users()[cur_node]; | |||
| for (auto &cur_node_user : cur_node_users) { | |||
| if (!utils::isa<CNodePtr>(cur_node_user.first)) { | |||
| MS_LOG(ERROR) << "post node is not cnode."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto cur_node_post = cur_node_user.first->cast<CNodePtr>(); | |||
| if (middle_nodes->find(cur_node_post) != middle_nodes->end() || | |||
| out_nodes->find(cur_node_post) != out_nodes->end()) { | |||
| continue; | |||
| } | |||
| queue_nodes.push(cur_node_post); | |||
| is_pre_nodes.push(false); | |||
| } | |||
| if (cur_node_users.empty()) { | |||
| out_nodes->insert(cur_node); | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
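| // Editor's note: a framework-free sketch of the region search above, under the | |||
| // assumption of a tiny hand-rolled graph type. It grows an area from a seed | |||
| // node, classifying transposes reached from below as in-nodes, transposes | |||
| // reached from above as out-nodes, and everything else reached as middle nodes. | |||
| #include <queue> | |||
| #include <set> | |||
| #include <utility> | |||
| #include <vector> | |||
| struct Node { | |||
| std::vector<Node *> inputs; | |||
| std::vector<Node *> users; | |||
| bool is_transpose = false; | |||
| }; | |||
| void FindArea(Node *root, std::set<Node *> *in, std::set<Node *> *out, std::set<Node *> *middle) { | |||
| std::queue<std::pair<Node *, bool>> nodes;  // node plus an "is pre node" flag | |||
| nodes.push({root, true}); | |||
| while (!nodes.empty()) { | |||
| auto [node, is_pre] = nodes.front(); | |||
| nodes.pop(); | |||
| if (node->is_transpose) { | |||
| if (is_pre) { | |||
| in->insert(node);  // feeds the area: still expand its users below | |||
| } else { | |||
| out->insert(node);  // consumes the area: stop expanding here | |||
| continue; | |||
| } | |||
| } | |||
| if (middle->count(node) != 0) { | |||
| continue; | |||
| } | |||
| if (in->count(node) == 0) { | |||
| middle->insert(node); | |||
| for (auto *pre : node->inputs) {  // walk upward | |||
| if (middle->count(pre) == 0 && in->count(pre) == 0) { | |||
| nodes.push({pre, true}); | |||
| } | |||
| } | |||
| } | |||
| for (auto *post : node->users) {  // walk downward | |||
| if (middle->count(post) == 0 && out->count(post) == 0) { | |||
| nodes.push({post, false}); | |||
| } | |||
| } | |||
| } | |||
| } | |||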
| bool JudgeCanOptimizerForMultiOp(const FuncGraphPtr &func_graph, const std::set<CNodePtr> &in_nodes, | |||
| const std::set<CNodePtr> &out_nodes, const std::set<CNodePtr> &middle_nodes) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| for (auto &in_cnode : in_nodes) { | |||
| std::vector<int> perm; | |||
| if (!CheckPrimitiveType(in_cnode, prim::kPrimTranspose) || GetTransposePerm(in_cnode, &perm) != lite::RET_OK || | |||
| perm != NH2NC) { | |||
| return false; | |||
| } | |||
| } | |||
| for (auto &out_cnode : out_nodes) { | |||
| std::vector<int> perm; | |||
| if (!CheckPrimitiveType(out_cnode, prim::kPrimTranspose) || GetTransposePerm(out_cnode, &perm) != lite::RET_OK || | |||
| perm != NC2NH) { | |||
| return false; | |||
| } | |||
| } | |||
| auto &dynamic_ops = GetDynamicFormatOpList(); | |||
| TransposeStrategy transpose_strategy; | |||
| for (auto &middle_cnode : middle_nodes) { | |||
| if (IsSpecialType(middle_cnode)) { | |||
| continue; | |||
| } | |||
| auto middle_node_prim = GetValueNode<PrimitivePtr>(middle_cnode->input(0)); | |||
| if (!lite::IsContain(dynamic_ops, middle_node_prim->name()) || | |||
| !transpose_strategy.CanChangeOpAxis(func_graph, middle_cnode)) { | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| void ConvertNcTensor2Nh(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t index, FmkType fmk_type, | |||
| bool train_flag) { | |||
| MS_ASSERT(cnode != nullptr); | |||
| if (utils::isa<CNodePtr>(cnode->input(index))) { | |||
| return; | |||
| } | |||
| lite::DataInfo data_info; | |||
| int status; | |||
| if (utils::isa<ParameterPtr>(cnode->input(index))) { | |||
| status = lite::FetchDataFromParameterNode(cnode, index, fmk_type, train_flag, &data_info); | |||
| } else { | |||
| status = lite::FetchDataFromValueNode(cnode, index, fmk_type, train_flag, &data_info); | |||
| } | |||
| if (status != lite::RET_OK) { | |||
| return; | |||
| } | |||
| if (data_info.shape_.empty() || | |||
| (data_info.data_type_ != kNumberTypeFloat32 && data_info.data_type_ != kNumberTypeFloat)) { | |||
| return; | |||
| } | |||
| std::vector<int> new_shape; | |||
| if (data_info.shape_.size() == 1) { | |||
| new_shape = {1, 1, 1, data_info.shape_[0]}; | |||
| } else if (data_info.shape_.size() == 2) { | |||
| new_shape = {1, 1, data_info.shape_[0], data_info.shape_[1]}; | |||
| } else if (data_info.shape_.size() == 3) { | |||
| new_shape = {1, data_info.shape_[0], data_info.shape_[1], data_info.shape_[2]}; | |||
| } else if (data_info.shape_.size() == 4) { | |||
| new_shape = data_info.shape_; | |||
| } else { | |||
| return; | |||
| } | |||
| auto size = data_info.data_.size() / sizeof(float); | |||
| std::vector<float> new_data(size); | |||
| auto new_data_ptr = static_cast<float *>(new_data.data()); | |||
| auto nchw_data = reinterpret_cast<float *>(data_info.data_.data()); | |||
| // nchw to nhwc | |||
| auto batch = new_shape[lite::NCHW_N]; | |||
| auto channel = new_shape[lite::NCHW_C]; | |||
| auto area = new_shape[lite::NCHW_H] * new_shape[lite::NCHW_W]; | |||
| for (auto i = 0; i < batch; i++) { | |||
| float *src_batch = nchw_data + i * channel * area; | |||
| float *dst_batch = new_data_ptr + i * channel * area; | |||
| for (int j = 0; j < area; ++j) { | |||
| float *src_area = src_batch + j; | |||
| float *dst_area = dst_batch + j * channel; | |||
| for (int k = 0; k < channel; ++k) { | |||
| dst_area[k] = src_area[k * area]; | |||
| } | |||
| } | |||
| } | |||
| auto param_node = func_graph->add_parameter(); | |||
| param_node->set_name(cnode->input(index)->fullname_with_scope()); | |||
| std::vector<int64_t> shape_vec{new_shape[0], new_shape[2], new_shape[3], new_shape[1]}; | |||
| auto tensor_info = lite::CreateTensorInfo(new_data.data(), size * sizeof(float), shape_vec, kNumberTypeFloat32); | |||
| if (tensor_info == nullptr) { | |||
| MS_LOG(ERROR) << "Create tensor info failed"; | |||
| return; | |||
| } | |||
| status = lite::InitParameterFromTensorInfo(param_node, tensor_info); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "init parameter from tensor info failed"; | |||
| return; | |||
| } | |||
| auto tr = func_graph->manager()->Transact(); | |||
| tr.SetEdge(cnode, index, param_node); | |||
| tr.Commit(); | |||
| return; | |||
| } | |||
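| // Editor's note: a standalone check of the NCHW -> NHWC copy above (with the | |||
| // j-based indexing fixed), runnable outside the converter: | |||
| #include <cassert> | |||
| #include <vector> | |||
| int main() { | |||
| // shape {1, 2, 2, 2} in NCHW; each value encodes (c, h, w) as 100*c + 10*h + w | |||
| std::vector<float> nchw = {0, 1, 10, 11, 100, 101, 110, 111}; | |||
| const int batch = 1, channel = 2, area = 2 * 2; | |||
| std::vector<float> nhwc(nchw.size()); | |||
| for (int i = 0; i < batch; ++i) { | |||
| const float *src_batch = nchw.data() + i * channel * area; | |||
| float *dst_batch = nhwc.data() + i * channel * area; | |||
| for (int j = 0; j < area; ++j) { | |||
| const float *src_area = src_batch + j; | |||
| float *dst_area = dst_batch + j * channel; | |||
| for (int k = 0; k < channel; ++k) { | |||
| dst_area[k] = src_area[k * area]; | |||
| } | |||
| } | |||
| } | |||
| // pixel (h=0, w=1): its channel pair {1, 101} is now contiguous under NHWC. | |||
| assert(nhwc[2] == 1 && nhwc[3] == 101); | |||
| return 0; | |||
| } | |||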
| } // namespace | |||
| void UnifyFormatPass::GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info) { | |||
| @@ -79,7 +248,7 @@ bool UnifyFormatPass::TransTransFusion(const FuncGraphPtr &func_graph, const CNo | |||
| return false; | |||
| } | |||
| std::vector<int> post_perm; | |||
| if (GetTransposePerm(cnode->input(2), &post_perm) != lite::RET_OK) { | |||
| if (GetTransposePerm(cnode, &post_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "get tanspose perm failed."; | |||
| return false; | |||
| } | |||
| @@ -89,7 +258,7 @@ bool UnifyFormatPass::TransTransFusion(const FuncGraphPtr &func_graph, const CNo | |||
| if (pre_cnode == nullptr) { | |||
| return false; | |||
| } | |||
| if (GetTransposePerm(pre_cnode->input(2), &pre_perm) != lite::RET_OK) { | |||
| if (GetTransposePerm(pre_cnode, &pre_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "get tanspose perm failed."; | |||
| return false; | |||
| } | |||
| @@ -106,7 +275,7 @@ STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, cons | |||
| return lite::RET_OK; | |||
| } | |||
| std::vector<int> cur_perm; | |||
| if (GetTransposePerm(cnode->input(2), &cur_perm) != lite::RET_OK) { | |||
| if (GetTransposePerm(cnode, &cur_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "get transpose perm failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| @@ -116,7 +285,7 @@ STATUS UnifyFormatPass::PostTransposeFusion(const FuncGraphPtr &func_graph, cons | |||
| if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) { | |||
| std::vector<int> post_trans_perm; | |||
| auto post_trans_node = post_node->cast<CNodePtr>(); | |||
| if (GetTransposePerm(post_trans_node->input(2), &post_trans_perm) != lite::RET_OK) { | |||
| if (GetTransposePerm(post_trans_node, &post_trans_perm) != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "get post transpose node perm failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| @@ -218,7 +387,7 @@ STATUS UnifyFormatPass::InsertPreTransNode(const FuncGraphPtr &func_graph, const | |||
| MS_ASSERT(trans_insert_info != nullptr); | |||
| TransTypePair trans_info; | |||
| auto origin_inputs = cnode->inputs(); | |||
| lite::AnfExporter::RemoveIfMakeTuple(cnode); | |||
| lite::RemoveIfMakeTuple(cnode); | |||
| RemoveIfMonad(cnode); | |||
| if (!transpose_strategy_.CanFusionIfInsert(func_graph, cnode, &trans_info, trans_insert_info)) { | |||
| cnode->set_inputs(origin_inputs); | |||
| @@ -366,8 +535,8 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN | |||
| prim->AddAttr(kTransDone, MakeValue<bool>(true)); | |||
| TransTypePair trans_info; | |||
| GetTransNodeFormatType(cnode, &trans_info); | |||
| if (!need_reset_ && (trans_info.pre_ == kNONE || trans_info.post_ == kNONE)) { | |||
| if (TransTransFusion(func_graph, cnode)) { | |||
| if (trans_info.pre_ == kNONE || trans_info.post_ == kNONE) { | |||
| if (!need_reset_ && TransTransFusion(func_graph, cnode)) { | |||
| return lite::RET_OK; | |||
| } | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> match; | |||
| @@ -401,6 +570,65 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| std::set<CNodePtr> *visit_transposes) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr && visit_transposes != nullptr); | |||
| auto manager = func_graph->manager(); | |||
| MS_ASSERT(manager != nullptr); | |||
| std::set<CNodePtr> middle_nodes; | |||
| std::set<CNodePtr> in_nodes; | |||
| std::set<CNodePtr> out_nodes; | |||
| auto status = FindAreaSurroundedByTranspose(func_graph, cnode, &in_nodes, &out_nodes, &middle_nodes); | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "find an area surrounded by transpose failed."; | |||
| return status; | |||
| } | |||
| for (auto &in_cnode : in_nodes) { | |||
| if (CheckPrimitiveType(in_cnode, prim::kPrimTranspose)) { | |||
| visit_transposes->insert(in_cnode); | |||
| } | |||
| } | |||
| if (!JudgeCanOptimizerForMultiOp(func_graph, in_nodes, out_nodes, middle_nodes)) { | |||
| return lite::RET_NO_CHANGE; | |||
| } | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| std::vector<CNodePtr> middle_ops_vec; | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| if (middle_nodes.find(node->cast<CNodePtr>()) != middle_nodes.end()) { | |||
| middle_ops_vec.push_back(node->cast<CNodePtr>()); | |||
| middle_nodes.erase(node->cast<CNodePtr>()); | |||
| } | |||
| } | |||
| for (auto &in_cnode : in_nodes) { | |||
| manager->Replace(in_cnode, in_cnode->input(1)); | |||
| } | |||
| for (auto &out_cnode : out_nodes) { | |||
| manager->Replace(out_cnode, out_cnode->input(1)); | |||
| } | |||
| for (auto &middle_cnode : middle_ops_vec) { | |||
| if (IsSpecialType(middle_cnode)) { | |||
| continue; | |||
| } | |||
| for (size_t i = 1; i < middle_cnode->size(); ++i) { | |||
| ConvertNcTensor2Nh(func_graph, middle_cnode, i, fmk_type_, train_flag_); | |||
| } | |||
| status = transpose_strategy_.ChangeOpAxis(func_graph, middle_cnode); | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "change op attr failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| status = node_infer_shape_.InferShape(middle_cnode); | |||
| if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) { | |||
| MS_LOG(ERROR) << "infer shape failed."; | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
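| // Editor's note: schematically, HandleGraphMultiNode rewrites | |||
| //   x (NHWC) -> Transpose(NH2NC) -> op1 -> op2 -> Transpose(NC2NH) -> y (NHWC) | |||
| // into | |||
| //   x (NHWC) -> op1' -> op2' -> y (NHWC) | |||
| // where op1'/op2' are the middle ops with axes and const inputs converted to NHWC. | |||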
| void UnifyFormatPass::PreProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> *match) { | |||
| MS_ASSERT(func_graph != nullptr && cnode != nullptr); | |||
| @@ -482,8 +710,8 @@ void UnifyFormatPass::SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPt | |||
| MS_ASSERT(cnode != nullptr && sub_graph != nullptr); | |||
| auto return_node = sub_graph->get_return(); | |||
| auto origin_input = return_node->inputs(); | |||
| lite::AnfExporter::RemoveIfDepend(return_node); | |||
| lite::AnfExporter::RemoveIfMakeTuple(return_node); | |||
| lite::RemoveIfDepend(return_node); | |||
| lite::RemoveIfMakeTuple(return_node); | |||
| for (size_t i = 1; i < return_node->size(); ++i) { | |||
| if (!CheckPrimitiveType(return_node->input(i), prim::kPrimTranspose)) { | |||
| continue; | |||
| @@ -511,8 +739,8 @@ void UnifyFormatPass::SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraph | |||
| MS_ASSERT(cnode != nullptr && sub_graph != nullptr); | |||
| auto return_node = sub_graph->get_return(); | |||
| auto origin_inputs = return_node->inputs(); | |||
| lite::AnfExporter::RemoveIfDepend(return_node); | |||
| lite::AnfExporter::RemoveIfMakeTuple(return_node); | |||
| lite::RemoveIfDepend(return_node); | |||
| lite::RemoveIfMakeTuple(return_node); | |||
| AbstractBasePtrList abstract_list; | |||
| bool infer_done = true; | |||
| for (size_t i = 1; i < return_node->size(); ++i) { | |||
| @@ -679,6 +907,49 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap | |||
| return true; | |||
| } | |||
| bool UnifyFormatPass::DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto manager = Manage(func_graph, true); | |||
| if (manager == nullptr) { | |||
| MS_LOG(ERROR) << "manager is nullptr."; | |||
| return false; | |||
| } | |||
| auto node_list = TopoSort(func_graph->get_return()); | |||
| std::set<CNodePtr> visit_transposes; | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| if (IsSpecialType(cnode) || visit_transposes.find(cnode) != visit_transposes.end()) { | |||
| continue; | |||
| } | |||
| if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) { | |||
| auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); | |||
| if (sub_func_graph == nullptr) { | |||
| return false; | |||
| } | |||
| (void)DecreaseTransposeForMultiOp(sub_func_graph); | |||
| sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(2)); | |||
| if (sub_func_graph == nullptr) { | |||
| return false; | |||
| } | |||
| (void)DecreaseTransposeForMultiOp(sub_func_graph); | |||
| } | |||
| std::vector<int> perm; | |||
| if (!CheckPrimitiveType(cnode, prim::kPrimTranspose) || GetTransposePerm(cnode, &perm) != lite::RET_OK || | |||
| perm != NH2NC) { | |||
| continue; | |||
| } | |||
| auto status = HandleGraphMultiNode(func_graph, cnode, &visit_transposes); | |||
| if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "global optimizer failed."; | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool UnifyFormatPass::ResetFuncGraph(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto manager = Manage(func_graph, true); | |||
| @@ -774,11 +1045,17 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) { | |||
| MS_LOG(ERROR) << "run framework transpose unify failed."; | |||
| return false; | |||
| } | |||
| // if input's format of a certain op can be NHWC, can try transform this op to decrease the number of transpose op. | |||
| // if the input format of a certain op can be NHWC, we can try to transform it to decrease the number of transpose ops. | |||
| if (!DecreaseTransposeForSingleOp(func_graph)) { | |||
| MS_LOG(ERROR) << "run local trans insert optimizer failed."; | |||
| return false; | |||
| } | |||
| // if the input formats of several ops surrounded only by transpose ops can all be NHWC, | |||
| // we can delete these transpose ops and, at the same time, transform these middle ops. | |||
| if (!DecreaseTransposeForMultiOp(func_graph)) { | |||
| MS_LOG(ERROR) << "run global trans insert optimizer failed."; | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace opt | |||
| @@ -48,6 +48,7 @@ class UnifyFormatPass : public Pass { | |||
| bool ResetFuncGraph(const FuncGraphPtr &func_graph); | |||
| bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph); | |||
| bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph); | |||
| bool DecreaseTransposeForMultiOp(const FuncGraphPtr &func_graph); | |||
| bool TransTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| STATUS PostTransposeFusion(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<int> perm, bool before, | |||
| @@ -55,6 +56,8 @@ class UnifyFormatPass : public Pass { | |||
| void GetTransNodeFormatType(const CNodePtr &cnode, TransTypePair *trans_info); | |||
| STATUS HandleGraphInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode); | |||
| STATUS HandleGraphMultiNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| std::set<CNodePtr> *visit_transposes); | |||
| STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm); | |||
| STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info); | |||
| STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm); | |||