diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h index eee0fc670c..2c6cf10d6c 100644 --- a/mindspore/core/abstract/abstract_value.h +++ b/mindspore/core/abstract/abstract_value.h @@ -278,6 +278,8 @@ class AbstractTensor : public AbstractUndetermined { AbstractBasePtr Broaden() const override; AbstractBasePtr BroadenWithShape() const; AbstractBasePtr Join(const AbstractBasePtr &other) final; + int format() const { return this->format_; } + void set_format(int format) { this->format_ = format; } bool operator==(const AbstractTensor &other) const; bool operator==(const AbstractBase &other) const override; @@ -294,6 +296,9 @@ class AbstractTensor : public AbstractUndetermined { } return hash_sum; } + + protected: + int format_ = 0; }; using AbstractTensorPtr = std::shared_ptr; using AbstractTensorPtrList = std::vector; diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index 36d2a4c432..e8a3a1cd68 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -97,16 +97,12 @@ if (BUILD_CONVERTER) set(PYTHON_LIBRARIES "${py_lib}") endif() include_directories(${PYTHON_INCLUDE_DIRS}) -# include(${TOP_DIR}/cmake/utils.cmake) -# include(${TOP_DIR}/cmake/dependency_utils.cmake) include(${TOP_DIR}/cmake/external_libs/json.cmake) -# include(${TOP_DIR}/cmake/dependency_securec.cmake) include(${TOP_DIR}/cmake/external_libs/pybind11.cmake) include(${TOP_DIR}/cmake/external_libs/eigen.cmake) include_directories(${TOP_DIR}/third_party/protobuf/build/include) link_directories(${TOP_DIR}/third_party/protobuf/build/lib) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter) - add_subdirectory(src/common/anf_importer) endif() if (BUILD_DEVICE) diff --git a/mindspore/lite/java/build_aar.sh b/mindspore/lite/java/build_aar.sh index d9c61ea407..7c42644376 100644 --- a/mindspore/lite/java/build_aar.sh +++ b/mindspore/lite/java/build_aar.sh @@ -17,6 +17,7 @@ fi cd ${TOP_PATH}/output/ rm -rf MSLite-0.6.0-linux_arm64 tar -zxvf MSLite-0.6.0-linux_arm64.tar.gz +mkdir -p ${BASE_PATH}/lib/ cp ${TOP_PATH}/output/MSLite-0.6.0-linux_arm64/lib/libmindspore-lite.so ${BASE_PATH}/lib/ cp ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android/libc++_shared.so ${BASE_PATH}/lib/ @@ -35,6 +36,7 @@ cp ${BASE_PATH}/native/build/libmindspore-lite-jni.so ${BASE_PATH}/lib/ ## check sdk gradle cd ${BASE_PATH}/java rm -rf .gradle build gradle gradlew gradlew.bat build app/build +mkdir -p ${BASE_PATH}/java/app/libs/arm64-v8a/ rm -rf ${BASE_PATH}/java/app/libs/arm64-v8a/* cp ${BASE_PATH}/lib/*.so ${BASE_PATH}/java/app/libs/arm64-v8a/ gradle init diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc deleted file mode 100644 index 6e7aea3894..0000000000 --- a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc +++ /dev/null @@ -1,418 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/common/anf_exporter/anf_exporter.h" - -#include -#include -#include -#include -#include - -#include "abstract/abstract_value.h" -#include "base/core_ops.h" -#include "mindspore/core/ir/primitive.h" -// #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" -#include "src/ir/primitive_t_value.h" -#include "src/ir/tensor.h" -#include "src/param_value_lite.h" -#include "src/common/utils.h" - -namespace mindspore::lite { -std::set RemoveNodeInAnfExporter{"tuple_getitem", "make_tuple"}; - -void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { - bool hasMakeTuple = false; - std::vector inputs; - inputs.clear(); - - inputs.emplace_back(cnode->input(0)); - for (size_t i = 1; i < cnode->inputs().size(); ++i) { - AnfNodePtr inputNode = cnode->input(i); - if (!inputNode->isa()) { - inputs.emplace_back(cnode->input(i)); - continue; - } - auto makeTupleNode = utils::cast(inputNode); - if (IsPrimitiveCNode(makeTupleNode, prim::kPrimMakeTuple)) { - hasMakeTuple = true; - for (size_t j = 1; j < makeTupleNode->inputs().size(); ++j) { - inputs.emplace_back(makeTupleNode->input(j)); - } - } else { - inputs.emplace_back(cnode->input(i)); - } - } - if (hasMakeTuple) { - cnode->set_inputs(inputs); - } -} - -bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) { - bool hasTupleGetItem = false; - std::vector inputs; - inputs.clear(); - inputs.emplace_back(cnode->input(0)); - for (size_t i = 1; i < cnode->inputs().size(); ++i) { - AnfNodePtr inputNode = cnode->input(i); - if (!inputNode->isa()) { - inputs.emplace_back(cnode->input(i)); - continue; - } - auto tupleGetItemNode = utils::cast(inputNode); - if (IsPrimitiveCNode(tupleGetItemNode, prim::kPrimTupleGetItem)) { - hasTupleGetItem = true; - inputs.emplace_back(tupleGetItemNode->input(1)); - AnfNodePtr indexNode = tupleGetItemNode->input(2); - if (!utils::isa(indexNode)) { - MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode"; - return false; - } - ValueNodePtr valueNode = utils::cast(indexNode); - mapRemoveGetItem_[tupleGetItemNode->input(1)->fullname_with_scope()] = GetValue(valueNode->value()); - } else { - inputs.emplace_back(cnode->input(i)); - } - } - if (hasTupleGetItem) { - cnode->set_inputs(inputs); - } - return true; -} - -bool AnfExporter::AddOutPutIfReturn(const std::unique_ptr &metaGraphT, const CNodePtr &cnode) { - for (size_t i = 1; i < cnode->inputs().size(); ++i) { - auto inputNode = cnode->input(i); - if (!inputNode->isa()) { - MS_LOG(ERROR) << "Node of Return's input is not CNode"; - return false; - } - auto inputCNode = utils::cast(inputNode); - auto inputPrimitive = GetValueNode(inputCNode->input(0)); - std::string inputName = inputNode->fullname_with_scope(); - auto graphOutput = nodeIdMap[inputName]; - metaGraphT->outputIndex.emplace_back(graphOutput); - } - return true; -} - -schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { - auto cnodes = funcGraph->GetOrderedCnodes(); - auto metaGraphT = std::make_unique(); - for (const auto &cnode : cnodes) { - auto primitive = GetValueNode(cnode->input(0)); - if (primitive != nullptr) { - if (RemoveNodeInAnfExporter.count(primitive->name()) != 0) { - continue; - } - } else { - auto primitiveT_value = GetValueNode>(cnode->input(0)); - auto primT = primitiveT_value->GetPrimitiveT(); - if (primT->value.type == schema::PrimitiveType_TupleGetItem || - primT->value.type == schema::PrimitiveType_MakeTuple) { - continue; - } - } - mapRemoveGetItem_.clear(); - RemoveIfMakeTuple(cnode); - RemoveIfTupleGetItem(cnode); - - if (primitive != nullptr) { - if (primitive->name() == prim::kPrimReturn->name()) { - AddOutPutIfReturn(metaGraphT, cnode); - continue; - } - } else { - auto primitiveT_value = GetValueNode>(cnode->input(0)); - auto primT = primitiveT_value->GetPrimitiveT(); - if (primT->value.type == schema::PrimitiveType_Return) { - AddOutPutIfReturn(metaGraphT, cnode); - continue; - } - } - - auto node = std::make_unique(); - node->name = cnode->fullname_with_scope(); - node->nodeType = schema::NodeType_CNode; - // populate primitive - // if (primitive != nullptr) { - // primitive = GetValueNode(cnode->input(0)); - // MS_ASSERT(primitive != nullptr); - // std::string opType = primitive->name(); - // auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); - // if (nodeParser == nullptr) { - // MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; - // return nullptr; - // } - // std::vector outputs; - // if (utils::isa(cnode->abstract())) { - // auto abstract_cnode = utils::cast(cnode->abstract()); - // outputs.resize(abstract_cnode->size()); - // } - // - // nodeParser->Parse(cnode, node.get(), &outputs); - // SetOpInputNode(cnode, metaGraphT.get(), node.get()); - // SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get()); - // metaGraphT->nodes.emplace_back(std::move(node)); - // continue; - // } - auto primitiveT_value = GetValueNode>(cnode->input(0)); - if (primitiveT_value == nullptr) { - MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; - return nullptr; - } - - auto *lite_primitive = primitiveT_value->GetPrimitiveT(); - if (lite_primitive == nullptr) { - MS_LOG(ERROR) << "Primitive in primitiveT_value is nullptr"; - return nullptr; - } - - node->primitive = std::unique_ptr(primitiveT_value->GetPrimitiveT()); - std::vector outputs; - SetOpInputNode(cnode, metaGraphT.get(), node.get()); - SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get()); - - // add quant param - node->quantType = primitiveT_value->GetQuantType(); - if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTraining) { - MS_LOG(INFO) << "node: " << node->name << " add QuantParam"; - // activation - auto input_quant_params = primitiveT_value->GetInputQuantParams(); - auto node_type = primitiveT_value->GetPrimitiveT()->value.type; - for (int i = 0; i < input_quant_params.size(); i++) { - if (i >= node->inputIndex.size()) { - MS_LOG(ERROR) << "node: " << node->name << " input has " << input_quant_params.size() - << " quant_params; but only " << node->inputIndex.size() << " input"; - break; - } - auto activate_index = node->inputIndex[i]; - auto tensor_input = metaGraphT->allTensors[activate_index].get(); - if (tensor_input->quantParams.empty()) { - for (auto input_quant_param : input_quant_params[i]) { - std::unique_ptr input_quant_param_ptr = - std::make_unique(input_quant_param); - MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param_ptr->scale - << " zp: " << input_quant_param_ptr->zeroPoint; - tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr)); - } - } - } - - // output - auto output_index = node->outputIndex[0]; - auto tensor_output = metaGraphT->allTensors[output_index].get(); - auto output_quant_params = primitiveT_value->GetOutputQuantParams(); - if (output_quant_params.empty()) { - MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty"; - } else { - for (auto output_quant_param : output_quant_params[0]) { - if (tensor_output->quantParams.empty()) { - std::unique_ptr output_quant_param_ptr = - std::make_unique(output_quant_param); - MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << output_quant_param_ptr->scale - << " zp: " << output_quant_param_ptr->zeroPoint; - tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr)); - } - } - } - if (node->quantType != schema::QuantType_AwareTraining && - !(node_type == schema::PrimitiveType_QuantDTypeCast && - primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) { - tensor_output->dataType = kNumberTypeInt8; - } - // // TensorType - // valuePtr = primitive->GetAttr(kInputTensorDataType); - // if (valuePtr != nullptr) { - // MS_LOG(INFO) << "node: " << node->name << " input tensor data - // type: " << GetValue(valuePtr); for (auto input : - // node->inputIndex) { - // auto tensor = subGraph->allTensors[input].get(); - // tensor->dataType = kNumberTypeUInt8; - // } - // } - } - - metaGraphT->nodes.emplace_back(std::move(node)); - } - // set graph input tensors - for (auto node : graphInputNodes) { - for (auto input : node->inputIndex) { - auto tensor = metaGraphT->allTensors[input].get(); - if (tensor->data.empty()) { - tensor->nodeType = schema::NodeType_ValueNode; - tensor->format = schema::Format_NHWC; - if (!IsContain(metaGraphT->inputIndex, input)) { - metaGraphT->inputIndex.emplace_back(input); - } - } - } - } - return metaGraphT.release(); -} - -void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode) { - MS_ASSERT(nullptr != meta_graph); - MS_ASSERT(nullptr != fbNode); - if (cnode->inputs().size() <= 1) { - return; - } - std::string cNodeName = cnode->fullname_with_scope(); - bool isGraphInput = true; - for (int i = 1; i < static_cast(cnode->inputs().size()); i++) { - auto inputNode = cnode->input(i); - if (inputNode->isa()) { - isGraphInput = false; - std::string inputName = inputNode->fullname_with_scope(); - if (!mapRemoveGetItem_.empty()) { - for (auto name : mapRemoveGetItem_) { - if (name.first == inputName) { - inputName = inputName + "_o:" + std::to_string(name.second); - } - } - } - if (nodeIdMap.find(inputName) != nodeIdMap.end()) { - fbNode->inputIndex.emplace_back(nodeIdMap[inputName]); - } - } else if (inputNode->isa()) { - auto paramNode = inputNode->cast(); - if (paramNode->name().empty()) { - paramNode->set_name(cNodeName + "_i:" + std::to_string(i - 1)); - } - if (nodeIdMap.find(paramNode->name()) != nodeIdMap.end()) { - fbNode->inputIndex.emplace_back(nodeIdMap[paramNode->name()]); - continue; - } - auto paramTensor = std::make_unique(); - auto abstractBase = paramNode->abstract(); - if (abstractBase == nullptr) { - MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name(); - MS_ASSERT(false); - return; - } - if (!utils::isa(abstractBase)) { - MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << paramNode->name(); - MS_ASSERT(false); - return; - } - auto abstractTensor = utils::cast(abstractBase); - auto typePtr = abstractTensor->element()->GetTypeTrack(); - MS_ASSERT(typePtr != nullptr); - paramTensor->dataType = typePtr->type_id(); - if (!utils::isa(abstractTensor->BuildShape())) { - MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name(); - MS_ASSERT(false); - return; - } - paramTensor->dims = utils::cast(abstractTensor->BuildShape())->shape(); - auto paramValue = std::dynamic_pointer_cast(paramNode->default_param()); - if (paramValue != nullptr) { - paramTensor->nodeType = schema::NodeType_ValueNode; - paramTensor->data.resize(paramValue->tensor_size()); - memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size()); - } - nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size(); - fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); - meta_graph->allTensors.emplace_back(std::move(paramTensor)); - } else if (inputNode->isa()) { - auto valueNode = inputNode->cast(); - auto paramTensor = std::make_unique(); - auto value = valueNode->value(); - if (value->isa()) { - auto valueAbstract = valueNode->abstract(); - auto abstractTensor = utils::cast(valueAbstract); - auto typePtr = abstractTensor->element()->GetTypeTrack(); - paramTensor->dataType = typePtr->type_id(); - paramTensor->dims = utils::cast(abstractTensor->BuildShape())->shape(); - paramTensor->nodeType = schema::NodeType_ValueNode; - auto data = value->cast(); - paramTensor->data.resize(data->Size()); - memcpy(paramTensor->data.data(), data->Data(), data->Size()); - nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size(); - fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); - meta_graph->allTensors.emplace_back(std::move(paramTensor)); - } else if (value->isa()) { - auto valueAbstract = valueNode->abstract(); - auto abstractScalar = utils::cast(valueAbstract); - auto typePtr = abstractScalar->GetTypeTrack(); - paramTensor->dataType = typePtr->type_id(); - paramTensor->dims = {1}; - paramTensor->nodeType = schema::NodeType_ValueNode; - auto data = value->cast(); - paramTensor->data.emplace_back(data->value()); - nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size(); - fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); - meta_graph->allTensors.emplace_back(std::move(paramTensor)); - } else if (value->isa()) { - MS_LOG(INFO) << "Value type is ValueSequence."; - break; - } else { - MS_LOG(ERROR) << "Not support value type , need add support."; - } - } - } - if (isGraphInput) { - graphInputNodes.emplace_back(fbNode); - } -} - -void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::vector &outputTensors, - schema::MetaGraphT *graph, schema::CNodeT *fbnode) { - MS_ASSERT(nullptr != graph); - MS_ASSERT(nullptr != fbnode); - std::string cnodeName = fbnode->name; - if (!outputTensors.empty()) { - int i = 0; - for (auto outputTensor : outputTensors) { - std::string name = cnodeName + "_o:" + std::to_string(i); - auto msTensor = new schema::TensorT(); - msTensor->nodeType = schema::NodeType_Parameter; - nodeIdMap[name] = graph->allTensors.size(); - fbnode->outputIndex.emplace_back(graph->allTensors.size()); - graph->allTensors.emplace_back(msTensor); - i++; - } - return; - } - - if (utils::isa(cnode->abstract())) { - auto tuple = std::reinterpret_pointer_cast(cnode->abstract()); - for (int i = 0; i < tuple->size(); i++) { - auto msTensor = new schema::TensorT(); - msTensor->nodeType = schema::NodeType_Parameter; - fbnode->outputIndex.emplace_back(graph->allTensors.size()); - if (tuple->size() == 1) { - nodeIdMap[cnodeName] = graph->allTensors.size(); - } else { - std::string name = cnodeName + "_o:" + std::to_string(i); - nodeIdMap[name] = graph->allTensors.size(); - } - graph->allTensors.emplace_back(msTensor); - } - } else { - auto msTensor = new schema::TensorT(); - msTensor->nodeType = schema::NodeType_Parameter; - fbnode->outputIndex.emplace_back(graph->allTensors.size()); - nodeIdMap[cnodeName] = graph->allTensors.size(); - graph->allTensors.emplace_back(msTensor); - } -} - -schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph) { - AnfExporter anfExporter; - return anfExporter.Export(funcGraph); -} -} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.h b/mindspore/lite/src/common/anf_exporter/anf_exporter.h deleted file mode 100644 index c4f5d7a398..0000000000 --- a/mindspore/lite/src/common/anf_exporter/anf_exporter.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). - * - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ -#define MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ - -#include -#include -#include -#include -#include "schema/inner/model_generated.h" -#include "ir/func_graph.h" - -namespace mindspore::lite { -class AnfExporter { - public: - AnfExporter() = default; - virtual ~AnfExporter() = default; - schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph); - void SetOpOutputNode(const CNodePtr &cnode, const std::vector &outputTensors, - schema::MetaGraphT *graph, schema::CNodeT *fbnode); - void SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode); - void RemoveIfMakeTuple(const CNodePtr &cnode); - bool RemoveIfTupleGetItem(const CNodePtr &cnode); - bool AddOutPutIfReturn(const std::unique_ptr &metaGraphT, const CNodePtr &cnode); - private: - std::map nodeIdMap; - std::vector graphInputNodes; - std::map mapRemoveGetItem_; -}; - -schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph); -} // namespace mindspore::lite -#endif // MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ diff --git a/mindspore/lite/src/common/anf_importer/CMakeLists.txt b/mindspore/lite/src/common/anf_importer/CMakeLists.txt deleted file mode 100644 index 07111121bf..0000000000 --- a/mindspore/lite/src/common/anf_importer/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -file(GLOB_RECURSE ANF_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - *.cc - ) -list(REMOVE_ITEM ANF_SRC_LIST import_from_meta_graph.cc) -add_library(anf_importer_mid OBJECT - ${ANF_SRC_LIST} - ) diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graph.cc b/mindspore/lite/src/common/anf_importer/import_from_meta_graph.cc deleted file mode 100644 index b0d96fa3ee..0000000000 --- a/mindspore/lite/src/common/anf_importer/import_from_meta_graph.cc +++ /dev/null @@ -1,169 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "src/common/anf_importer/import_from_meta_graph.h" -#include -#include -#include -#include "frontend/operator/ops.h" -#include "src/param_value_lite.h" -#include "utils/log_adapter.h" -#include "abstract/abstract_value.h" -#include "src/ir/primitive_value.h" -#include "include/errorcode.h" - -namespace mindspore::lite { -void AnfImporterFromMetaGraph::ConverterConstTensor() { - MS_EXCEPTION_IF_NULL(model_); - auto *meta_graph = model_->GetMetaGraph(); - MS_EXCEPTION_IF_NULL(meta_graph); - num_of_tensors_ = meta_graph->allTensors()->size(); - for (size_t i = 0; i < num_of_tensors_; i++) { - auto *tensor = meta_graph->allTensors()->GetAs(i); - MS_EXCEPTION_IF_NULL(tensor); - if ((tensor->nodeType() != schema::NodeType_ValueNode) && (tensor->nodeType() != schema::NodeType_Parameter)) { - continue; - } - MS_ASSERT(tensor->dims() != nullptr); - auto parameter = model_->add_parameter(); - std::vector shape; - for (size_t j = 0; j < tensor->dims()->size(); ++j) { - shape.push_back(tensor->dims()->data()[j]); - } - auto type_id = static_cast(tensor->dataType()); // todo: check error - auto type_ptr = TypeIdToType(type_id); - auto abstractBase = std::make_shared(type_ptr, shape); - // XXX TODO copy format - parameter->set_abstract(abstractBase); - parameter->set_name(std::string("Parameter")); - - if (tensor->nodeType() == schema::NodeType_ValueNode) { - ParamValueLitePtr param_value = std::make_shared(); - MS_EXCEPTION_IF_NULL(param_value); - param_value->set_tensor_shape(shape); - param_value->set_tensor_type(type_id); - if (tensor->data() != nullptr) { - auto size = tensor->data()->size(); - char *tensor_data = new char[size](); - std::memcpy(tensor_data, tensor->data()->data(), size); - MS_EXCEPTION_IF_NULL(tensor_data); - param_value->set_tensor_addr(tensor_data); - param_value->set_tensor_size(size); - } - parameter->set_default_param(param_value); - } - AddNode(i, parameter); - model_->AddAnfNode(i, parameter); - } -} - -int AnfImporterFromMetaGraph::ConverterCNode() { - MS_EXCEPTION_IF_NULL(model_); - auto *meta_graph = model_->GetMetaGraph(); - MS_EXCEPTION_IF_NULL(meta_graph); - - // Crate CNode -- Order of inputs is as follows - // First input should be the Primitive - // Then we have CNodes that contribute to this CNode - // Finally we Have the parameters - - // first itteration -- create CNode with primitive, create originator map - for (size_t i = 0; i < meta_graph->nodes()->size(); i++) { - auto cNode = meta_graph->nodes()->GetAs(i); - MS_EXCEPTION_IF_NULL(cNode); - auto prim = std::make_shared(model_->GetOp(cNode->name()->str())); - if (prim == nullptr) { - MS_LOG(ERROR) << "th tensorDef in subGraphDef is nullptr"; - return RET_ERROR; - } - auto value_node = NewValueNode(prim); - // auto prim_name = std::string("PrimitivePy: ") + std::string(cNode->name()->c_str()); - // value_node->set_fullname_with_scope(prim_name); - std::vector op_inputs = {value_node}; - - auto cnode = model_->NewCNode(op_inputs); - auto node_name = std::string(cNode->name()->c_str()) + std::to_string(i); - cnode->set_fullname_with_scope(node_name); - AddNode(num_of_tensors_ + i, cnode); - - for (size_t j = 0; j < cNode->outputIndex()->size(); j++) { - int tensor_id = cNode->outputIndex()->data()[j]; - originator_[tensor_id] = cnode; - } - } - // second itteration -- fill in input CNodes and Parameters - // populate map - for (size_t i = 0; i < meta_graph->nodes()->size(); i++) { - std::vector input; - std::vector output; - int tensor_id; - auto cNode = meta_graph->nodes()->GetAs(i); - MS_EXCEPTION_IF_NULL(cNode); - auto cnode = std::dynamic_pointer_cast(GetNode(num_of_tensors_ + i)); - - for (size_t j = 0; j < cNode->outputIndex()->size(); j++) { - tensor_id = cNode->outputIndex()->data()[j]; - output.push_back(tensor_id); - } - - MS_EXCEPTION_IF_NULL(cNode->inputIndex()); - for (size_t j = 0; j < cNode->inputIndex()->size(); j++) { - tensor_id = cNode->inputIndex()->data()[j]; - input.push_back(tensor_id); - auto *tensor = meta_graph->allTensors()->GetAs(tensor_id); - MS_EXCEPTION_IF_NULL(tensor); - if ((tensor->nodeType() == schema::NodeType_Parameter) && (originator_[tensor_id] != nullptr)) { - cnode->add_input(originator_[tensor_id]); - } - } - // finally add all the Parameters (which are ValueNodes) - for (size_t j = 0; j < cNode->inputIndex()->size(); j++) { - tensor_id = cNode->inputIndex()->data()[j]; - auto *tensor = meta_graph->allTensors()->GetAs(tensor_id); - MS_EXCEPTION_IF_NULL(tensor); - if ((tensor->nodeType() == schema::NodeType_ValueNode) && (GetNode(tensor_id) != nullptr)) { - cnode->add_input(GetNode(tensor_id)); - } - } - - model_->AddCNodeInputOutput(cnode->fullname_with_scope(), input, output); - } - - return RET_OK; -} - -void AnfImporterFromMetaGraph::AddReturnCNode() { - MS_EXCEPTION_IF_NULL(model_); - auto *meta_graph = model_->GetMetaGraph(); - MS_EXCEPTION_IF_NULL(meta_graph); - std::vector input; - std::vector output; - std::vector op_inputs; - auto value_node = NewValueNode(prim::kPrimReturn); - // value_node->set_fullname_with_scope("Primitive"); - op_inputs.push_back(value_node); - for (int i = 0; i < meta_graph->outputIndex()->size(); i++) { - auto prev_cnode = originator_[meta_graph->outputIndex()->data()[i]]; - if (prev_cnode != nullptr) op_inputs.push_back(prev_cnode); - input.push_back(meta_graph->outputIndex()->data()[i]); - } - auto cnode = model_->NewCNode(op_inputs); - cnode->set_fullname_with_scope("return"); - model_->set_return(cnode); - model_->AddCNodeInputOutput(cnode->fullname_with_scope(), input, output); -} -FuncGraphPtr AnfImporterFromMetaGraph::GetResult() { return this->model_; } -} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graph.h b/mindspore/lite/src/common/anf_importer/import_from_meta_graph.h deleted file mode 100644 index b8389f42d1..0000000000 --- a/mindspore/lite/src/common/anf_importer/import_from_meta_graph.h +++ /dev/null @@ -1,49 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ -#define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ - -#include -#include -#include "src/train/model_impl.h" -#include "schema/model_generated.h" -#include "src/common/anf_importer/anf_importer.h" - -namespace mindspore::lite { -class AnfImporterFromMetaGraph : public AnfImporter { - public: - explicit AnfImporterFromMetaGraph(std::shared_ptr model) : model_(model) {} - - ~AnfImporterFromMetaGraph() override = default; - - FuncGraphPtr GetResult() override; - - private: - void ConverterConstTensor() override; - - int ConverterCNode() override; - - void AddReturnCNode() override; - - private: - std::shared_ptr model_ = nullptr; - std::map originator_; - int num_of_tensors_ = 0; -}; -} // namespace mindspore::lite - -#endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc b/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc deleted file mode 100644 index 4056480f4b..0000000000 --- a/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.cc +++ /dev/null @@ -1,175 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include "schema/inner/model_generated.h" -#include "frontend/operator/ops.h" -#include "src/param_value_lite.h" -#include "import_from_meta_graphT.h" -#include "utils/log_adapter.h" -#include "abstract/abstract_value.h" -#include "src/ir/primitive_value.h" -#include "src/ir/primitive_t_value.h" -#include "include/errorcode.h" -#include "src/ops/ops.h" - -namespace mindspore::lite { -void AnfImporterFromMetaGraphT::ConverterConstTensor() { - MS_EXCEPTION_IF_NULL(meta_graph_); - MS_EXCEPTION_IF_NULL(func_graph_); - for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) { - auto &tensor = meta_graph_->allTensors.at(i); - MS_EXCEPTION_IF_NULL(tensor); - if (tensor->nodeType != schema::NodeType_ValueNode) { - continue; - } - MS_ASSERT(tensor->dims() != nullptr); - auto parameter = func_graph_->add_parameter(); - std::vector shape; - for (int &dim : tensor->dims) { - shape.push_back(dim); - } - auto type_id = static_cast(tensor->dataType); - auto type_ptr = TypeIdToType(type_id); - auto abstract_tensor = std::make_shared(type_ptr, shape); - parameter->set_abstract(abstract_tensor); - - ParamValueLitePtr param_value = std::make_shared(); - MS_EXCEPTION_IF_NULL(param_value); - param_value->set_tensor_shape(shape); - param_value->set_tensor_type(type_id); - if (!tensor->data.empty()) { - auto size = tensor->data.size(); - char *tensor_data = new char[size]; - std::memcpy(tensor_data, tensor->data.data(), size); - MS_EXCEPTION_IF_NULL(tensor_data); - param_value->set_tensor_addr(tensor_data); - param_value->set_tensor_size(size); - } - if (tensor->quantParams.size() > 0) { - std::unique_ptr quantParam = std::make_unique(); - quantParam->scale = tensor->quantParams[0]->scale; - quantParam->zeroPoint = tensor->quantParams[0]->zeroPoint; - quantParam->min = tensor->quantParams[0]->min; - quantParam->max = tensor->quantParams[0]->max; - quantParam->narrowRange = tensor->quantParams[0]->narrowRange; - quantParam->numBits = tensor->quantParams[0]->numBits; - quantParam->inited = tensor->quantParams[0]->inited; - param_value->set_quant_param(quantParam); - } - parameter->set_default_param(param_value); - AddNode(i, parameter); - } -} - -int AnfImporterFromMetaGraphT::ConverterCNode() { - MS_EXCEPTION_IF_NULL(meta_graph_); - MS_EXCEPTION_IF_NULL(func_graph_); - for (size_t i = 0; i < meta_graph_->nodes.size(); i++) { - auto &cNode = meta_graph_->nodes.at(i); - MS_EXCEPTION_IF_NULL(cNode); - - bool flag = false; - if (cNode->outputIndex.size() > 1) { - flag = true; - } - auto primTValue = std::make_shared(cNode->primitive.release()); - // add quant parameter - if (cNode->quantType == schema::QuantType_AwareTraining) { - primTValue->SetQuantType(cNode->quantType); - for (int index : cNode->inputIndex) { - std::vector quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; - primTValue->AddInputQuantParam(quant_params); - } - for (int index : cNode->outputIndex) { - std::vector quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; - primTValue->AddOutputQuantParam(quant_params); - } - } - cNode->primitive = nullptr; - auto value_node = NewValueNode(primTValue); - - std::vector op_inputs = {value_node}; - for (size_t j = 0; j < cNode->inputIndex.size(); j++) { - auto node = GetNode(cNode->inputIndex.at(j)); - if (nullptr == node) { - MS_LOG(ERROR) << "Can't find input node."; - return RET_ERROR; - } - // todo: CheckInputNodeType, the first node should be op; - op_inputs.push_back(node); - } - - auto new_cnode = func_graph_->NewCNode(op_inputs); - new_cnode->set_fullname_with_scope(cNode->name); - - std::vector out_tensor_ids = cNode->outputIndex; - - AbstractBasePtrList ptr_list; - int total = 0; - for (auto out_tensor_id : out_tensor_ids) { - if (nullptr != GetNode(out_tensor_id)) { - ptr_list.push_back(GetNode(out_tensor_id)->abstract()); - continue; - } - std::vector shape; - auto &tensor = meta_graph_->allTensors.at(out_tensor_id); - for (int &dim : tensor->dims) { - shape.push_back(dim); - } - auto type_id = static_cast(tensor->dataType); - auto type_ptr = TypeIdToType(type_id); - auto abstract_tensor = std::make_shared(type_ptr, shape); - auto getItemPrim = NewValueNode(prim::kPrimTupleGetItem); - if (flag) { - auto getItemIndex = NewValueNode(MakeValue(total++)); - std::vector inputs{getItemPrim, new_cnode, getItemIndex}; - CNodePtr new_item_cnode = func_graph_->NewCNode(inputs); - AddNode(out_tensor_id, new_item_cnode); - } else { - AddNode(out_tensor_id, new_cnode); - } - ptr_list.push_back(std::move(abstract_tensor)); - } - new_cnode->set_abstract(std::make_shared(ptr_list)); - } - return RET_OK; -} - -void AnfImporterFromMetaGraphT::AddReturnCNode() { - MS_EXCEPTION_IF_NULL(meta_graph_); - MS_EXCEPTION_IF_NULL(func_graph_); - std::vector make_tuple_inputs; - auto make_tuple_value_node = NewValueNode(prim::kPrimMakeTuple); - make_tuple_inputs.emplace_back(make_tuple_value_node); - for (auto tensor_id : meta_graph_->outputIndex) { - make_tuple_inputs.emplace_back(GetNode(tensor_id)); - } - auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs); - make_tuple_cnode->set_fullname_with_scope("return tuple"); - - std::vector op_inputs; - auto value_node = NewValueNode(prim::kPrimReturn); - op_inputs.emplace_back(value_node); - op_inputs.emplace_back(make_tuple_cnode); - auto cnode = func_graph_->NewCNode(op_inputs); - cnode->set_fullname_with_scope("return"); - func_graph_->set_return(cnode); -} - -FuncGraphPtr AnfImporterFromMetaGraphT::GetResult() { return this->func_graph_; } -} // namespace mindspore::lite diff --git a/mindspore/lite/src/ir/primitive_t_value.cc b/mindspore/lite/src/ir/primitive_t_value.cc index 9c27cc66fd..ecb757c6ee 100644 --- a/mindspore/lite/src/ir/primitive_t_value.cc +++ b/mindspore/lite/src/ir/primitive_t_value.cc @@ -15,3 +15,26 @@ */ #include "src/ir/primitive_t_value.h" + +namespace mindspore::lite { +std::shared_ptr GetReturnPrim() { + auto return_primitiveT = new schema::PrimitiveT; + return_primitiveT->value.type = schema::PrimitiveType_Return; + return_primitiveT->value.value = new schema::ReturnT; + return std::make_shared(return_primitiveT); +} + +std::shared_ptr GetMakeTuplePrim() { + auto make_tuple_primitiveT = new schema::PrimitiveT; + make_tuple_primitiveT->value.type = schema::PrimitiveType_MakeTuple; + make_tuple_primitiveT->value.value = new schema::MakeTupleT; + return std::make_shared(make_tuple_primitiveT); +} + +std::shared_ptr GetTupleGetItemPrim() { + auto tuple_get_item_primitiveT = new schema::PrimitiveT(); + tuple_get_item_primitiveT->value.type = schema::PrimitiveType_TupleGetItem; + tuple_get_item_primitiveT->value.value = new schema::TupleGetItemT; + return std::make_shared(tuple_get_item_primitiveT); +} +} // namespace mindspore::lite diff --git a/mindspore/lite/src/ir/primitive_t_value.h b/mindspore/lite/src/ir/primitive_t_value.h index bc71bc7306..4f4ce9ac5b 100644 --- a/mindspore/lite/src/ir/primitive_t_value.h +++ b/mindspore/lite/src/ir/primitive_t_value.h @@ -18,8 +18,9 @@ #define MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_ #include +#include +#include "schema/inner/model_generated.h" #include "ir/value.h" -#include "mindspore/lite/schema/inner/model_generated.h" namespace mindspore::lite { @@ -46,22 +47,17 @@ class PrimitiveTValue : public Value { } } - void SetInputQuantParam(std::vector> vec_quant_param) { - } + void SetInputQuantParam(std::vector> vec_quant_param) {} void AddInputQuantParam(std::vector quant_param) { this->input_quant_param_.emplace_back(quant_param); } - std::vector> GetInputQuantParams() const { - return input_quant_param_; - } + std::vector> GetInputQuantParams() const { return input_quant_param_; } void AddOutputQuantParam(std::vector quant_param) { this->output_quant_param_.emplace_back(quant_param); } - std::vector> GetOutputQuantParams() const { - return output_quant_param_; - } + std::vector> GetOutputQuantParams() const { return output_quant_param_; } void SetQuantType(schema::QuantType quant_type) { this->quant_type_ = quant_type; } @@ -73,7 +69,12 @@ class PrimitiveTValue : public Value { std::vector> output_quant_param_; schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; }; + +std::shared_ptr GetReturnPrim(); + +std::shared_ptr GetMakeTuplePrim(); + +std::shared_ptr GetTupleGetItemPrim(); } // namespace mindspore::lite #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_ - diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index cf94fd8319..7aa7160e16 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -170,6 +170,7 @@ set(TEST_LITE_SRC ${LITE_DIR}/src/runtime/thread_pool.cc ${LITE_DIR}/src/runtime/workspace_pool.cc ${LITE_DIR}/src/ir/tensor.cc + ${LITE_DIR}/src/ir/primitive_t_value.cc ${LITE_DIR}/src/context.cc ${LITE_DIR}/src/executor.cc ${LITE_DIR}/src/kernel_factory.cc @@ -218,9 +219,6 @@ if(BUILD_CONVERTER) ${TEST_CASE_TFLITE_PARSERS_SRC} ${TOP_DIR}/mindspore/core/utils/flags.cc ${LITE_DIR}/tools/converter/optimizer.cc - # ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc - # ${LITE_DIR}/src/common/anf_importer/import_from_meta_graphT.cc - # ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc ${LITE_DIR}/tools/converter/anf_transform.cc ${LITE_DIR}/tools/converter/graphdef_transform.cc ${LITE_DIR}/tools/converter/converter_flags.cc @@ -300,7 +298,6 @@ if (SUPPORT_TRAIN) set(TEST_SRC ${TEST_SRC} ${TEST_CASE_KERNEL_TRAIN_SRC} - # ${TEST_DIR}/ut/src/train_test.cc ${TEST_DIR}/ut/src/infer_test.cc # temporary ) else() @@ -350,6 +347,7 @@ endif() if (BUILD_CONVERTER) target_link_libraries(lite-test anf_importer_mid + anf_exporter_mid tflite_parser_mid caffe_parser_mid onnx_parser_mid diff --git a/mindspore/lite/test/ut/src/infer_test.cc b/mindspore/lite/test/ut/src/infer_test.cc index 47c58abdf8..7f46069ff5 100644 --- a/mindspore/lite/test/ut/src/infer_test.cc +++ b/mindspore/lite/test/ut/src/infer_test.cc @@ -246,7 +246,7 @@ TEST_F(InferTest, TestAddNode) { TEST_F(InferTest, TestModel) { auto buf = new char *[1]; size_t model_size; - std::string model_path = "./model.ms"; + std::string model_path = "./models/model_hebing_3branch.ms"; ReadFile(model_path.c_str(), &model_size, buf); ASSERT_NE(nullptr, buf[0]); diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc index b86fb765ee..912d2dd38b 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/constant_folding_fusion_test.cc @@ -25,7 +25,7 @@ #include "tools/converter/model_parser.h" #include "tools/converter/anf_transform.h" #include "tools/optimizer/fusion/constant_folding_fusion.h" -#include "src/common/anf_exporter/anf_exporter.h" +#include "tools/anf_exporter/anf_exporter.h" namespace mindspore { class ConstantFoldingFusionTest : public mindspore::CommonTest { diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc index 24d494c46c..b73d5b652d 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_activation_fusion_test.cc @@ -24,7 +24,7 @@ #include "utils/log_adapter.h" #include "tools/converter/model_parser.h" #include "tools/converter/anf_transform.h" -#include "src/common/anf_exporter/anf_exporter.h" +#include "tools/anf_exporter/anf_exporter.h" namespace mindspore { class ConvActivationFusionTest : public mindspore::CommonTest { diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc index 32f203275e..e5b44ea8c8 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_biasadd_fusion_test.cc @@ -24,7 +24,7 @@ #include "utils/log_adapter.h" #include "tools/converter/model_parser.h" #include "tools/converter/anf_transform.h" -#include "src/common/anf_exporter/anf_exporter.h" +#include "tools/anf_exporter/anf_exporter.h" namespace mindspore { class ConvBiasAddFusionTest : public mindspore::CommonTest { diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc index 347b4498e1..8dcd9789ba 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_bn_fusion_test.cc @@ -24,7 +24,7 @@ #include "mindspore/core/utils/log_adapter.h" #include "tools/converter/model_parser.h" #include "tools/converter/anf_transform.h" -#include "src/common/anf_exporter/anf_exporter.h" +#include "tools/anf_exporter/anf_exporter.h" namespace mindspore { class ConvBNFusionTest : public mindspore::CommonTest { diff --git a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc index 92f119c74a..06f47ed5a3 100644 --- a/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc +++ b/mindspore/lite/test/ut/tools/optimizer/fusion/conv_scale_fusion_test.cc @@ -24,7 +24,7 @@ #include "utils/log_adapter.h" #include "tools/converter/model_parser.h" #include "tools/converter/anf_transform.h" -#include "src/common/anf_exporter/anf_exporter.h" +#include "tools/anf_exporter/anf_exporter.h" namespace mindspore { class ConvScaleFusionTest : public mindspore::CommonTest { diff --git a/mindspore/lite/tools/anf_exporter/CMakeLists.txt b/mindspore/lite/tools/anf_exporter/CMakeLists.txt new file mode 100644 index 0000000000..d6c4808226 --- /dev/null +++ b/mindspore/lite/tools/anf_exporter/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB_RECURSE ANF_EXPORTER_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + *.cc + ) +add_library(anf_exporter_mid OBJECT + ${ANF_EXPORTER_SRC_LIST} + ) diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc new file mode 100644 index 0000000000..ee70a8c2ff --- /dev/null +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -0,0 +1,431 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tools/anf_exporter/anf_exporter.h" + +#include +#include +#include +#include + +#include "abstract/abstract_value.h" +#include "mindspore/core/ir/primitive.h" +#include "src/ir/tensor.h" +#include "src/param_value_lite.h" +#include "src/common/utils.h" + +namespace mindspore::lite { +void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { + bool has_make_tuple = false; + std::vector inputs; + inputs.clear(); + + inputs.emplace_back(cnode->input(0)); + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + AnfNodePtr inputNode = cnode->input(i); + if (!inputNode->isa()) { + inputs.emplace_back(cnode->input(i)); + continue; + } + auto make_tuple_node = utils::cast(inputNode); + if (IsPrimitiveCNode(make_tuple_node, schema::PrimitiveType_MakeTuple)) { + has_make_tuple = true; + for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) { + inputs.emplace_back(make_tuple_node->input(j)); + } + } else { + inputs.emplace_back(cnode->input(i)); + } + } + if (has_make_tuple) { + cnode->set_inputs(inputs); + } +} + +bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) { + MS_ASSERT(cnode != nullptr); + bool has_tuple_get_item = false; + std::vector inputs; + inputs.clear(); + inputs.emplace_back(cnode->input(0)); + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + AnfNodePtr inputNode = cnode->input(i); + if (!inputNode->isa()) { + inputs.emplace_back(cnode->input(i)); + continue; + } + auto tuple_get_item_node = utils::cast(inputNode); + if (IsPrimitiveCNode(tuple_get_item_node, schema::PrimitiveType_TupleGetItem)) { + has_tuple_get_item = true; + inputs.emplace_back(tuple_get_item_node->input(1)); + AnfNodePtr indexNode = tuple_get_item_node->input(2); + if (!utils::isa(indexNode)) { + MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode"; + return false; + } + ValueNodePtr value_node = utils::cast(indexNode); + map_remove_get_item_[tuple_get_item_node->input(1)->fullname_with_scope()] = GetValue(value_node->value()); + } else { + inputs.emplace_back(cnode->input(i)); + } + } + if (has_tuple_get_item) { + cnode->set_inputs(inputs); + } + return true; +} + +bool AnfExporter::AddOutPutIfReturn(const std::unique_ptr &meta_graphT, const CNodePtr &cnode) { + MS_ASSERT(meta_graphT != nullptr); + MS_ASSERT(cnode != nullptr); + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + auto inputNode = cnode->input(i); + if (!inputNode->isa()) { + MS_LOG(ERROR) << "Node of Return's input is not CNode"; + return false; + } + auto inputCNode = utils::cast(inputNode); + std::string inputName = inputNode->fullname_with_scope(); + auto graphOutput = node_id_map_[inputName]; + meta_graphT->outputIndex.emplace_back(graphOutput); + } + return true; +} + +int AnfExporter::ConvertQuantParam(const std::unique_ptr &meta_graph, + const std::shared_ptr primitive, + const std::unique_ptr &dst_node) { + MS_ASSERT(meta_graph != nullptr); + MS_ASSERT(primitive != nullptr); + MS_ASSERT(dst_node != nullptr); + // add quant param + dst_node->quantType = primitive->GetQuantType(); + if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining) { + MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam"; + // activation + auto input_quant_params = primitive->GetInputQuantParams(); + auto node_type = primitive->GetPrimitiveT()->value.type; + for (int i = 0; i < input_quant_params.size(); i++) { + if (i >= dst_node->inputIndex.size()) { + MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size() + << " quant_params; but only " << dst_node->inputIndex.size() << " input"; + break; + } + auto activate_index = dst_node->inputIndex[i]; + auto tensor_input = meta_graph->allTensors[activate_index].get(); + if (tensor_input->quantParams.empty()) { + for (auto input_quant_param : input_quant_params[i]) { + std::unique_ptr input_quant_param_ptr = + std::make_unique(input_quant_param); + MS_LOG(DEBUG) << "[input]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale + << " zp: " << input_quant_param_ptr->zeroPoint; + tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr)); + } + } + } + + // output + auto output_index = dst_node->outputIndex[0]; + auto tensor_output = meta_graph->allTensors[output_index].get(); + auto output_quant_params = primitive->GetOutputQuantParams(); + if (output_quant_params.empty()) { + MS_LOG(WARNING) << "node: " << dst_node->name << " output quant params is empty"; + } else { + for (auto output_quant_param : output_quant_params[0]) { + if (tensor_output->quantParams.empty()) { + std::unique_ptr output_quant_param_ptr = + std::make_unique(output_quant_param); + MS_LOG(DEBUG) << "[input]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale + << " zp: " << output_quant_param_ptr->zeroPoint; + tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr)); + } + } + } + if (dst_node->quantType != schema::QuantType_AwareTraining && + !(node_type == schema::PrimitiveType_QuantDTypeCast && + primitive->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) { + tensor_output->dataType = kNumberTypeInt8; + } + // // TensorType + // valuePtr = primitive->GetAttr(kInputTensorDataType); + // if (valuePtr != nullptr) { + // MS_LOG(INFO) << "node: " << node->name << " input tensor data + // type: " << GetValue(valuePtr); for (auto input : + // node->inputIndex) { + // auto tensor = subGraph->allTensors[input].get(); + // tensor->dataType = kNumberTypeUInt8; + // } + // } + } + return RET_OK; +} + +void AnfExporter::SetGraphInputIndex(const std::unique_ptr &meta_graphT) { + for (auto node : graph_input_nodes_) { + for (auto input : node->inputIndex) { + auto tensor = meta_graphT->allTensors[input].get(); + if (tensor->data.empty()) { + tensor->nodeType = schema::NodeType_ValueNode; + tensor->format = schema::Format_NHWC; + if (!IsContain(meta_graphT->inputIndex, input)) { + meta_graphT->inputIndex.emplace_back(input); + } + } + } + } +} + +schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) { + auto cnodes = func_graph->GetOrderedCnodes(); + auto meta_graphT = std::make_unique(); + for (const auto &cnode : cnodes) { + auto primitiveT_value = GetValueNode>(cnode->input(0)); + if (primitiveT_value == nullptr) { + MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; + return nullptr; + } + auto primT = primitiveT_value->GetPrimitiveT(); + if (primT == nullptr) { + MS_LOG(ERROR) << "PrimitiveT is nullptr"; + return nullptr; + } + if (primT->value.type == schema::PrimitiveType_TupleGetItem || + primT->value.type == schema::PrimitiveType_MakeTuple) { + continue; + } + map_remove_get_item_.clear(); + RemoveIfMakeTuple(cnode); + if (!RemoveIfTupleGetItem(cnode)) { + MS_LOG(ERROR) << "RemoveIfTupleGetItem failed"; + return nullptr; + } + + if (primT->value.type == schema::PrimitiveType_Return) { + AddOutPutIfReturn(meta_graphT, cnode); + continue; + } + + auto node = std::make_unique(); + node->name = cnode->fullname_with_scope(); + node->nodeType = schema::NodeType_CNode; + + node->primitive = std::unique_ptr(primT); + auto ret = SetOpInputNode(cnode, meta_graphT, node.get()); + if (ret != RET_OK) { + MS_LOG(ERROR) << "SetOpInputNode failed"; + return nullptr; + } + + SetOpOutputNode(cnode, meta_graphT, node.get()); + + ret = ConvertQuantParam(meta_graphT, primitiveT_value, node); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvertQuantParam failed"; + return nullptr; + } + + meta_graphT->nodes.emplace_back(std::move(node)); + primitiveT_value->SetPrimitiveT(nullptr); + } + // set graph input tensors + SetGraphInputIndex(meta_graphT); + return meta_graphT.release(); +} + +void AnfExporter::ConvertInputCNode(const std::shared_ptr input_anode, schema::CNodeT *output_cnode) { + std::string input_name = input_anode->fullname_with_scope(); + if (!map_remove_get_item_.empty()) { + for (auto name : map_remove_get_item_) { + if (name.first == input_name) { + input_name = input_name + "_o:" + std::to_string(name.second); + } + } + } + if (node_id_map_.find(input_name) != node_id_map_.end()) { + output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); + } +} + +int AnfExporter::ConvertInputParameter(const std::shared_ptr input_anode, size_t anode_index, + const std::unique_ptr &meta_graphT, + schema::CNodeT *output_cnode) { + std::string input_name = input_anode->fullname_with_scope(); + auto paramNode = input_anode->cast(); + if (paramNode->name().empty()) { + paramNode->set_name(input_name + "_i:" + std::to_string(anode_index - 1)); + } + if (node_id_map_.find(paramNode->name()) != node_id_map_.end()) { + output_cnode->inputIndex.emplace_back(node_id_map_[paramNode->name()]); + return RET_OK; + } + auto paramTensor = std::make_unique(); + auto abstractBase = paramNode->abstract(); + if (abstractBase == nullptr) { + MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name(); + return RET_ERROR; + } + if (!utils::isa(abstractBase)) { + MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << paramNode->name(); + return RET_ERROR; + } + auto abstractTensor = utils::cast(abstractBase); + auto typePtr = abstractTensor->element()->GetTypeTrack(); + MS_ASSERT(typePtr != nullptr); + paramTensor->dataType = typePtr->type_id(); + if (!utils::isa(abstractTensor->BuildShape())) { + MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name(); + return RET_ERROR; + } + paramTensor->dims = utils::cast(abstractTensor->BuildShape())->shape(); + auto paramValue = std::dynamic_pointer_cast(paramNode->default_param()); + if (paramValue != nullptr) { + paramTensor->nodeType = schema::NodeType_ValueNode; + paramTensor->data.resize(paramValue->tensor_size()); + memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size()); + } + node_id_map_[paramNode->fullname_with_scope()] = meta_graphT->allTensors.size(); + output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); + meta_graphT->allTensors.emplace_back(std::move(paramTensor)); + return RET_OK; +} + +int AnfExporter::ConvertInputValueNode(std::shared_ptr input_anode, + const std::unique_ptr &meta_graphT, + schema::CNodeT *output_cnode) { + auto valueNode = input_anode->cast(); + auto paramTensor = std::make_unique(); + auto value = valueNode->value(); + if (value->isa()) { + auto valueAbstract = valueNode->abstract(); + auto abstractTensor = utils::cast(valueAbstract); + auto typePtr = abstractTensor->element()->GetTypeTrack(); + paramTensor->dataType = typePtr->type_id(); + paramTensor->dims = utils::cast(abstractTensor->BuildShape())->shape(); + paramTensor->nodeType = schema::NodeType_ValueNode; + auto data = value->cast(); + paramTensor->data.resize(data->Size()); + memcpy(paramTensor->data.data(), data->Data(), data->Size()); + node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); + output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); + meta_graphT->allTensors.emplace_back(std::move(paramTensor)); + } else if (value->isa()) { + auto valueAbstract = valueNode->abstract(); + auto abstractScalar = utils::cast(valueAbstract); + auto typePtr = abstractScalar->GetTypeTrack(); + paramTensor->dataType = typePtr->type_id(); + paramTensor->dims = {1}; + paramTensor->nodeType = schema::NodeType_ValueNode; + auto data = value->cast(); + paramTensor->data.emplace_back(data->value()); + node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); + output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); + meta_graphT->allTensors.emplace_back(std::move(paramTensor)); + } else if (value->isa()) { + MS_LOG(DEBUG) << "Value type is ValueSequence."; + return RET_OK; + } else { + MS_LOG(ERROR) << "Not support value type , need add support."; + return RET_ERROR; + } + return RET_OK; +} + +int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, + schema::CNodeT *fb_node) { + MS_ASSERT(nullptr != meta_graph); + MS_ASSERT(nullptr != fb_node); + if (cnode->inputs().size() <= 1) { + return RET_OK; + } + bool is_graph_input = true; + for (size_t i = 1; i < cnode->inputs().size(); i++) { + auto input_node = cnode->input(i); + if (input_node->isa()) { + is_graph_input = false; + ConvertInputCNode(input_node, fb_node); + } else if (input_node->isa()) { + auto ret = ConvertInputParameter(input_node, i, meta_graphT, fb_node); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvertInputParameter failed"; + return RET_ERROR; + } + } else if (input_node->isa()) { + auto ret = ConvertInputValueNode(input_node, meta_graphT, fb_node); + if (ret != RET_OK) { + MS_LOG(ERROR) << "ConvertInputValueNode failed"; + return RET_ERROR; + } + } + } + fb_node->name = cnode->fullname_with_scope(); + if (is_graph_input) { + graph_input_nodes_.emplace_back(fb_node); + } + return RET_OK; +} + +void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, + schema::CNodeT *fb_node) { + MS_ASSERT(nullptr != graph); + MS_ASSERT(nullptr != fb_node); + std::string cnode_name = fb_node->name; + + if (utils::isa(cnode->abstract())) { + auto tuple = std::reinterpret_pointer_cast(cnode->abstract()); + for (int i = 0; i < tuple->size(); i++) { + auto msTensor = new schema::TensorT(); + msTensor->nodeType = schema::NodeType_Parameter; + fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); + if (tuple->size() == 1) { + node_id_map_[cnode_name] = meta_graphT->allTensors.size(); + } else { + std::string name = cnode_name + "_o:" + std::to_string(i); + node_id_map_[name] = meta_graphT->allTensors.size(); + } + meta_graphT->allTensors.emplace_back(msTensor); + } + } else { + auto ms_tensor = new schema::TensorT(); + ms_tensor->nodeType = schema::NodeType_Parameter; + fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); + node_id_map_[cnode_name] = meta_graphT->allTensors.size(); + meta_graphT->allTensors.emplace_back(ms_tensor); + } +} + +bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type) { + MS_ASSERT(node != nullptr); + auto cnode = node->cast(); + if (cnode == nullptr) { + return false; + } + + const auto &prim = GetValueNode>(cnode->input(0)); + if (prim == nullptr) { + return false; + } + auto *primitiveT = prim->GetPrimitiveT(); + if (primitiveT == nullptr) { + return false; + } + return primitiveT->value.type == type; +} + +schema::MetaGraphT *Export(const FuncGraphPtr &func_graph) { + AnfExporter anf_exporter; + return anf_exporter.Export(func_graph); +} +} // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.h b/mindspore/lite/tools/anf_exporter/anf_exporter.h new file mode 100644 index 0000000000..cb05852232 --- /dev/null +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.h @@ -0,0 +1,64 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ +#define MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ + +#include +#include +#include +#include +#include "schema/inner/model_generated.h" +#include "src/ir/primitive_t_value.h" +#include "ir/func_graph.h" + +namespace mindspore::lite { +class AnfExporter { + public: + AnfExporter() = default; + virtual ~AnfExporter() = default; + schema::MetaGraphT *Export(const FuncGraphPtr &func_graph); + void SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, + schema::CNodeT *fb_node); + int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, + schema::CNodeT *fb_node); + void RemoveIfMakeTuple(const CNodePtr &cnode); + bool RemoveIfTupleGetItem(const CNodePtr &cnode); + bool AddOutPutIfReturn(const std::unique_ptr &meta_graphT, const CNodePtr &cnode); + + protected: + void ConvertInputCNode(const std::shared_ptr input_anode, schema::CNodeT *output_cnode); + int ConvertInputParameter(const std::shared_ptr input_anode, size_t anode_index, + const std::unique_ptr &meta_graphT, schema::CNodeT *output_cnode); + int ConvertInputValueNode(std::shared_ptr input_anode, + const std::unique_ptr &meta_graphT, schema::CNodeT *output_cnode); + void SetGraphInputIndex(const std::unique_ptr &meta_graphT); + bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type); + int ConvertQuantParam(const std::unique_ptr &meta_graph, + const std::shared_ptr primitive, + const std::unique_ptr &dst_node); + + private: + std::map node_id_map_; + std::vector graph_input_nodes_; + std::map map_remove_get_item_; +}; + +schema::MetaGraphT *Export(const FuncGraphPtr &func_graph); +} // namespace mindspore::lite +#endif // MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ diff --git a/mindspore/lite/tools/anf_importer/CMakeLists.txt b/mindspore/lite/tools/anf_importer/CMakeLists.txt new file mode 100644 index 0000000000..040ebda12d --- /dev/null +++ b/mindspore/lite/tools/anf_importer/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB_RECURSE ANF_IMPORTER_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} + *.cc + ) +add_library(anf_importer_mid OBJECT + ${ANF_IMPORTER_SRC_LIST} + ) diff --git a/mindspore/lite/src/common/anf_importer/anf_importer.cc b/mindspore/lite/tools/anf_importer/anf_importer.cc similarity index 94% rename from mindspore/lite/src/common/anf_importer/anf_importer.cc rename to mindspore/lite/tools/anf_importer/anf_importer.cc index 2921f9422b..50a78e7fea 100644 --- a/mindspore/lite/src/common/anf_importer/anf_importer.cc +++ b/mindspore/lite/tools/anf_importer/anf_importer.cc @@ -18,7 +18,7 @@ #include #include #include -#include "src/common/anf_importer/anf_importer.h" +#include "tools/anf_importer/anf_importer.h" #include "schema/model_generated.h" #include "ir/dtype.h" #include "ir/primitive.h" @@ -160,13 +160,21 @@ void MinnieBuildGraph::FbTest(const GraphDef *graph_def) { #endif int AnfImporter::Import(const schema::QuantType &quantType) { - ConverterConstTensor(); - auto ret = ConverterCNode(); + auto ret = ConverterConstTensor(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "ConverterConstTensor failed " << ret; + return ret; + } + ret = ConverterCNode(); if (RET_OK != ret) { MS_LOG(ERROR) << "ConverterCNode failed " << ret; return ret; } - AddReturnCNode(); + ret = AddReturnCNode(); + if (RET_OK != ret) { + MS_LOG(ERROR) << "AddReturnCNode failed " << ret; + return ret; + } return RET_OK; } @@ -181,4 +189,3 @@ AnfNodePtr AnfImporter::GetNode(int tensor_id) { void AnfImporter::AddNode(int tensor_id, AnfNodePtr node) { nodes_[tensor_id] = std::move(node); } } // namespace lite } // namespace mindspore - diff --git a/mindspore/lite/src/common/anf_importer/anf_importer.h b/mindspore/lite/tools/anf_importer/anf_importer.h similarity index 94% rename from mindspore/lite/src/common/anf_importer/anf_importer.h rename to mindspore/lite/tools/anf_importer/anf_importer.h index 87e0edd3dc..8ac3406db3 100644 --- a/mindspore/lite/src/common/anf_importer/anf_importer.h +++ b/mindspore/lite/tools/anf_importer/anf_importer.h @@ -36,11 +36,11 @@ class AnfImporter { protected: // convert const tensor into parameter and save in nodes_ - virtual void ConverterConstTensor() = 0; + virtual int ConverterConstTensor() = 0; // convert other node into cnode and save in nodes_ virtual int ConverterCNode() = 0; - virtual void AddReturnCNode() = 0; + virtual int AddReturnCNode() = 0; AnfNodePtr GetNode(int tensor_id); @@ -52,4 +52,3 @@ class AnfImporter { } // namespace mindspore::lite #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_ANF_IMPORTER_H_ - diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_activation_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.cc similarity index 92% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_activation_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.cc index 942b6b4311..12eae7f4f0 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_activation_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.cc @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_activation_populater.h" +#include "tools/anf_importer/anf_populater/anf_activation_populater.h" #include #include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_activation_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.h similarity index 94% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_activation_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.h index ab43ff7d77..d976a8a4e9 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_activation_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_activation_populater.h @@ -16,7 +16,7 @@ #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H #define MINDSPORE_ANF_ACTIVATION_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfActivationPopulater : public AnfNodePopulater { diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_batchnorm_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.cc similarity index 90% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_batchnorm_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.cc index f66c61562b..cca157e370 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_batchnorm_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.cc @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_batchnorm_populater.h" +#include "tools/anf_importer/anf_populater/anf_batchnorm_populater.h" #include #include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_batchnorm_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.h similarity index 94% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_batchnorm_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.h index 84ce7e0567..92fb87e6bb 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_batchnorm_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_batchnorm_populater.h @@ -15,7 +15,7 @@ */ #ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H #define MINDSPORE_ANF_BATCHNORM_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfBatchnormPopulater : public AnfNodePopulater { diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_biasadd_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.cc similarity index 89% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_biasadd_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.cc index 44e2a35330..e72ce6dba8 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_biasadd_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.cc @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_biasadd_populater.h" +#include "tools/anf_importer/anf_populater/anf_biasadd_populater.h" #include #include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_biasadd_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.h similarity index 93% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_biasadd_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.h index 3fbf17ee49..508e47ef04 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_biasadd_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_biasadd_populater.h @@ -15,7 +15,7 @@ */ #ifndef MINDSPORE_ANF_BIASADD_PARSER_H #define MINDSPORE_ANF_BIASADD_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfBiasAddPopulater : public AnfNodePopulater { diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_concat_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.cc similarity index 91% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_concat_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.cc index 51c52eca68..30f964cb1c 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_concat_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.cc @@ -16,11 +16,11 @@ * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_concat_populater.h" +#include "tools/anf_importer/anf_populater/anf_concat_populater.h" #include #include #include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_concat_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.h similarity index 94% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_concat_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.h index aa59219f92..c9af84fdad 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_concat_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_concat_populater.h @@ -18,7 +18,7 @@ #ifndef MINDSPORE_ANF_CONCAT_PARSER_H #define MINDSPORE_ANF_CONCAT_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfConcatPopulater : public AnfNodePopulater { diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc similarity index 85% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc index 4e4206845f..0a00578fa6 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.cc @@ -17,24 +17,19 @@ * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_conv_populater.h" - -#include - -#include +#include "tools/anf_importer/anf_populater/anf_conv_populater.h" #include #include - +#include +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" -#include "ir/primitive.h" -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "src/ir/tensor.h" #include "tools/converter/quantizer/quantize_util.h" namespace mindspore::lite { -void AnfConvPopulater::PopulaterConv2DMultiGroup( - const PrimitivePtr &prim, - const std::unique_ptr &primitive, const int &group) { +void AnfConvPopulater::PopulaterConv2DMultiGroup(const PrimitivePtr &prim, + const std::unique_ptr &primitive, + const int &group) { auto attr = std::make_unique(); auto format = GetValue(prim->GetAttr("data_format")); if (format == "NCHW") { @@ -75,9 +70,9 @@ void AnfConvPopulater::PopulaterConv2DMultiGroup( primitive->value.value = attr.release(); } -void AnfConvPopulater::PopulaterConv2DSingleGroup( - const PrimitivePtr &prim, - const std::unique_ptr &primitive, const int &group) { +void AnfConvPopulater::PopulaterConv2DSingleGroup(const PrimitivePtr &prim, + const std::unique_ptr &primitive, + const int &group) { auto attr = std::make_unique(); attr->group = group; auto format = GetValue(prim->GetAttr("data_format")); @@ -120,17 +115,15 @@ void AnfConvPopulater::PopulaterConv2DSingleGroup( primitive->value.value = attr.release(); } -void AnfConvPopulater::CalQuantParam(const double &mean, const double &stdDev, - float *mMin, float *mMax) { +void AnfConvPopulater::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { constexpr float qmin = 0; constexpr float qmax = 255; *mMin = static_cast((qmin - mean) / stdDev); *mMax = static_cast((qmax - mean) / stdDev); } -void AnfConvPopulater::PopulaterQuantParam( - const PrimitivePtr &prim, - std::vector> *vecQuantParam) { +void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim, + std::vector> *vecQuantParam) { auto narrow_range = prim->GetAttr("narrow_range"); bool narrowRangeQuantParam = GetValue(narrow_range); auto num_bits = prim->GetAttr("num_bits"); @@ -158,8 +151,8 @@ void AnfConvPopulater::PopulaterQuantParam( quantParam.min = *minBuf; quantParam.max = *maxBuf; } - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, - narrowRangeQuantParam, numbitsRangeQuantParam); + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); quants.emplace_back(quantParam); vecQuantParam->emplace_back(quants); @@ -176,8 +169,7 @@ void AnfConvPopulater::PopulaterQuantParam( for (int i = 0; i < biasQuantSize; ++i) { quantParam.min = *(minBuf++); quantParam.max = *(maxBuf++); - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, - narrowRangeQuantParam, + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, numbitsRangeQuantParam); quants.emplace_back(quantParam); } @@ -189,8 +181,7 @@ void AnfConvPopulater::PopulaterQuantParam( quantParam.min = 0.0; quantParam.max = 0.0; quantParam.zeroPoint = 0; - quantParam.scale = - vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale; + quantParam.scale = vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale; quants.emplace_back(quantParam); } vecQuantParam->emplace_back(quants); @@ -205,15 +196,14 @@ void AnfConvPopulater::PopulaterQuantParam( float *maxBuf = static_cast(outputMaxPtr->Data()); quantParam.min = *minBuf; quantParam.max = *maxBuf; - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, - narrowRangeQuantParam, numbitsRangeQuantParam); + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); quants.emplace_back(quantParam); vecQuantParam->emplace_back(quants); } } -int AnfConvPopulater::Populate(const PrimitivePtr &prim, - PrimitiveTValue *primitiveTValuePtr, +int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, const std::vector &inputs) { MS_ASSERT(primitiveTValuePtr != nullptr); auto primitive = std::make_unique(); diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h similarity index 69% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h index eb2905a8bb..e4befe36df 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_conv_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h @@ -19,9 +19,10 @@ #ifndef MINDSPORE_ANF_CONV_PARSER_H #define MINDSPORE_ANF_CONV_PARSER_H + +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include #include -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" namespace mindspore::lite { class AnfConvPopulater : public AnfNodePopulater { public: @@ -31,16 +32,12 @@ class AnfConvPopulater : public AnfNodePopulater { const std::vector &inputs) override; private: - void PopulaterConv2DMultiGroup( - const PrimitivePtr &prim, - const std::unique_ptr &primitive, const int &group); - void PopulaterConv2DSingleGroup( - const PrimitivePtr &prim, - const std::unique_ptr &primitive, const int &group); - void PopulaterQuantParam(const PrimitivePtr &prim, - std::vector> *vecQuantParam); - void CalQuantParam(const double &mean, const double &stdDev, float *mMin, - float *mMax); + void PopulaterConv2DMultiGroup(const PrimitivePtr &prim, const std::unique_ptr &primitive, + const int &group); + void PopulaterConv2DSingleGroup(const PrimitivePtr &prim, const std::unique_ptr &primitive, + const int &group); + void PopulaterQuantParam(const PrimitivePtr &prim, std::vector> *vecQuantParam); + void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc similarity index 82% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc index 9db2805947..6cf45542e9 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc @@ -13,31 +13,26 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h" -#include -#include +#include "tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h" #include - +#include +#include +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" -#include "ir/primitive.h" -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "src/ir/tensor.h" #include "tools/converter/quantizer/quantize_util.h" namespace mindspore::lite { -void AnfDepwiseconv2DPopulater::CalQuantParam(const double &mean, - const double &stdDev, float *mMin, - float *mMax) { +void AnfDepwiseconv2DPopulater::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { constexpr float qmin = 0; constexpr float qmax = 255; *mMin = static_cast((qmin - mean) / stdDev); *mMax = static_cast((qmax - mean) / stdDev); } -void AnfDepwiseconv2DPopulater::PopulaterQuantParam( - const PrimitivePtr &prim, - std::vector> *vecQuantParam) { +void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim, + std::vector> *vecQuantParam) { auto narrow_range = prim->GetAttr("narrow_range"); bool narrowRangeQuantParam = GetValue(narrow_range); auto num_bits = prim->GetAttr("num_bits"); @@ -65,8 +60,8 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam( quantParam.min = *minBuf; quantParam.max = *maxBuf; } - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, - narrowRangeQuantParam, numbitsRangeQuantParam); + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); quants.emplace_back(quantParam); vecQuantParam->emplace_back(quants); @@ -83,8 +78,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam( for (int i = 0; i < biasQuantSize; ++i) { quantParam.min = *(minBuf++); quantParam.max = *(maxBuf++); - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, - narrowRangeQuantParam, + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, numbitsRangeQuantParam); quants.emplace_back(quantParam); } @@ -96,8 +90,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam( quantParam.min = 0.0; quantParam.max = 0.0; quantParam.zeroPoint = 0; - quantParam.scale = - vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale; + quantParam.scale = vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale; quants.emplace_back(quantParam); } vecQuantParam->emplace_back(quants); @@ -112,15 +105,14 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam( float *maxBuf = static_cast(outputMaxPtr->Data()); quantParam.min = *minBuf; quantParam.max = *maxBuf; - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, - narrowRangeQuantParam, numbitsRangeQuantParam); + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); quants.emplace_back(quantParam); vecQuantParam->emplace_back(quants); } } -int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, - PrimitiveTValue *primitiveTValuePtr, +int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); @@ -171,13 +163,10 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, auto abstractBase = paramNode->abstract(); MS_ASSERT(abstractBase != nullptr); if (utils::isa(abstractBase)) { - auto abstractTensor = - utils::cast(abstractBase); + auto abstractTensor = utils::cast(abstractBase); MS_ASSERT(abstractTensor != nullptr); if (utils::isa(abstractTensor->BuildShape())) { - auto dims = - utils::cast(abstractTensor->BuildShape()) - ->shape(); + auto dims = utils::cast(abstractTensor->BuildShape())->shape(); attr->channelIn = dims[kAnfPopulaterOne]; } } @@ -195,8 +184,6 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, } return 0; } -AnfNodePopulaterRegistrar anfdepthwise2dPopulater( - "DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); -AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater( - "DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); +AnfNodePopulaterRegistrar anfdepthwise2dPopulater("DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); +AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h similarity index 82% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h index 6377ea372f..5b58bf3b6e 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h @@ -15,9 +15,10 @@ */ #ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H #define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H + +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" namespace mindspore::lite { class AnfDepwiseconv2DPopulater : public AnfNodePopulater { public: @@ -25,11 +26,10 @@ class AnfDepwiseconv2DPopulater : public AnfNodePopulater { ~AnfDepwiseconv2DPopulater() override = default; int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, const std::vector &inputs) override; + private: - void PopulaterQuantParam(const PrimitivePtr &prim, - std::vector> *vecQuantParam); - void CalQuantParam(const double &mean, const double &stdDev, float *mMin, - float *mMax); + void PopulaterQuantParam(const PrimitivePtr &prim, std::vector> *vecQuantParam); + void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_dequant_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.cc similarity index 90% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_dequant_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.cc index 5df3a75c92..4c88cce9da 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_dequant_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.cc @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_dequant_populater.h" +#include "tools/anf_importer/anf_populater/anf_dequant_populater.h" #include #include #include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_dequant_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.h similarity index 93% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_dequant_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.h index 936468d85e..77bb3f2b5f 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_dequant_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_dequant_populater.h @@ -15,7 +15,7 @@ */ #ifndef MINDSPORE_ANF_DEQUANT_PARSER_H #define MINDSPORE_ANF_DEQUANT_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfDequantPopulater : public AnfNodePopulater { diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_flatten_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.cc similarity index 89% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_flatten_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.cc index 0e669345e6..db80e41463 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_flatten_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.cc @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_flatten_populater.h" +#include "tools/anf_importer/anf_populater/anf_flatten_populater.h" #include #include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_flatten_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.h similarity index 93% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_flatten_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.h index 8ec178b213..5366873fc1 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_flatten_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_flatten_populater.h @@ -15,7 +15,7 @@ */ #ifndef MINDSPORE_ANF_FLATTEN_PARSER_H #define MINDSPORE_ANF_FLATTEN_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfFlattenPopulater : public AnfNodePopulater { diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc similarity index 81% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc index b6bb890856..fe315d780e 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc @@ -13,29 +13,25 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_matmul_populater.h" -#include +#include "tools/anf_importer/anf_populater/anf_matmul_populater.h" #include - +#include +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" -#include "ir/primitive.h" -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "src/ir/tensor.h" #include "tools/converter/quantizer/quantize_util.h" namespace mindspore::lite { -void AnfMatmulPopulater::CalQuantParam(const double &mean, const double &stdDev, - float *mMin, float *mMax) { +void AnfMatmulPopulater::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { constexpr float qmin = 0; constexpr float qmax = 255; *mMin = static_cast((qmin - mean) / stdDev); *mMax = static_cast((qmax - mean) / stdDev); } -void AnfMatmulPopulater::PopulaterQuantParam( - const PrimitivePtr &prim, - std::vector> *vecQuantParam) { +void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim, + std::vector> *vecQuantParam) { auto narrow_range = prim->GetAttr("narrow_range"); bool narrowRangeQuantParam = GetValue(narrow_range); auto num_bits = prim->GetAttr("num_bits"); @@ -63,8 +59,8 @@ void AnfMatmulPopulater::PopulaterQuantParam( quantParam.min = *minBuf; quantParam.max = *maxBuf; } - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, - narrowRangeQuantParam, numbitsRangeQuantParam); + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); quants.emplace_back(quantParam); vecQuantParam->emplace_back(quants); @@ -79,8 +75,7 @@ void AnfMatmulPopulater::PopulaterQuantParam( for (int i = 0; i < filterMinPtr->DataSize(); ++i) { quantParam.min = *(minBuf++); quantParam.max = *(maxBuf++); - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, - narrowRangeQuantParam, + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, numbitsRangeQuantParam); quants.emplace_back(quantParam); } @@ -97,15 +92,14 @@ void AnfMatmulPopulater::PopulaterQuantParam( float *maxBuf = static_cast(outputMaxPtr->Data()); quantParam.min = *minBuf; quantParam.max = *maxBuf; - quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, - narrowRangeQuantParam, numbitsRangeQuantParam); + quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, + numbitsRangeQuantParam); quants.emplace_back(quantParam); vecQuantParam->emplace_back(quants); } } -int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, - PrimitiveTValue *primitiveTValuePtr, +int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, const std::vector &inputs) { auto primitive = std::make_unique(); auto attr = std::make_unique(); @@ -124,8 +118,6 @@ int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, return 0; } -AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", - new AnfMatmulPopulater()); -AnfNodePopulaterRegistrar anfMatMulPopulater("MatMul", - new AnfMatmulPopulater()); +AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater()); +AnfNodePopulaterRegistrar anfMatMulPopulater("MatMul", new AnfMatmulPopulater()); } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h similarity index 81% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h index 3ce23f5389..39b7be7f5a 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_matmul_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h @@ -15,7 +15,7 @@ */ #ifndef MINDSPORE_ANF_MATMUL_PARSER_H #define MINDSPORE_ANF_MATMUL_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfMatmulPopulater : public AnfNodePopulater { @@ -24,11 +24,10 @@ class AnfMatmulPopulater : public AnfNodePopulater { ~AnfMatmulPopulater() override = default; int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, const std::vector &inputs) override; + private: - void PopulaterQuantParam(const PrimitivePtr &prim, - std::vector> *vecQuantParam); - void CalQuantParam(const double &mean, const double &stdDev, float *mMin, - float *mMax); + void PopulaterQuantParam(const PrimitivePtr &prim, std::vector> *vecQuantParam); + void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); }; } // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_mul_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.cc similarity index 89% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_mul_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.cc index 7edf1a9328..0ba673b49a 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_mul_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.cc @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_mul_populater.h" +#include "tools/anf_importer/anf_populater/anf_mul_populater.h" #include #include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_mul_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.h similarity index 93% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_mul_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.h index 2761300d46..30eb3f7173 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_mul_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_mul_populater.h @@ -15,7 +15,7 @@ */ #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H #define MINDSPORE_ANF_ACTIVATION_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfMulPopulater : public AnfNodePopulater { diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.cc similarity index 90% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.cc index 9bec531c00..609a73b7ef 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.cc @@ -14,6 +14,6 @@ * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" namespace mindspore::lite {} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.h similarity index 100% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater.h diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater_registry.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.cc similarity index 93% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater_registry.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.cc index 1d8d36bf10..d877e48c97 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater_registry.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include namespace mindspore { namespace lite { diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater_registry.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.h similarity index 95% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater_registry.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.h index 0d88eec3b1..2d1ebb74bf 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_node_populater_registry.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_node_populater_registry.h @@ -16,7 +16,7 @@ #ifndef MINDSPORE_ANF_NODE_PARSER_REGISTRY_H #define MINDSPORE_ANF_NODE_PARSER_REGISTRY_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include #include namespace mindspore::lite { diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_pool_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.cc similarity index 94% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_pool_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.cc index 0aa53df227..5f06a84d2a 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_pool_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.cc @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_pool_populater.h" +#include "tools/anf_importer/anf_populater/anf_pool_populater.h" #include #include #include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_pool_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.h similarity index 93% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_pool_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.h index 0589172505..7aefc44409 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_pool_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_pool_populater.h @@ -15,7 +15,7 @@ */ #ifndef MINDSPORE_ANF_POOL_PARSER_H #define MINDSPORE_ANF_POOL_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfPoolPopulater : public AnfNodePopulater { diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_quant_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.cc similarity index 90% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_quant_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.cc index 1f4c7e0716..98b858180d 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_quant_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.cc @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_quant_populater.h" +#include "tools/anf_importer/anf_populater/anf_quant_populater.h" #include #include #include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_quant_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.h similarity index 93% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_quant_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.h index a9aed77da6..e7eec3cb09 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_quant_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_quant_populater.h @@ -15,7 +15,7 @@ */ #ifndef MINDSPORE_ANF_QUANT_PARSER_H #define MINDSPORE_ANF_QUANT_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfQuantPopulater : public AnfNodePopulater { diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_reducemean_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.cc similarity index 90% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_reducemean_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.cc index 00bf3d7105..5da3735db4 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_reducemean_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.cc @@ -13,18 +13,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_reducemean_populater.h" +#include "tools/anf_importer/anf_populater/anf_reducemean_populater.h" #include #include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" namespace mindspore::lite { namespace { - constexpr int kReduceInputNum = 3; - constexpr int kReduceInputIndex = 2; -} +constexpr int kReduceInputNum = 3; +constexpr int kReduceInputIndex = 2; +} // namespace int AnfReduceMeanPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, const std::vector &inputs) { auto primitive = std::make_unique(); diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_reducemean_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.h similarity index 94% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_reducemean_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.h index f82a3997f0..ba4a1c6b3a 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_reducemean_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_reducemean_populater.h @@ -15,7 +15,7 @@ */ #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H #define MINDSPORE_ANF_ACTIVATION_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfReduceMeanPopulater : public AnfNodePopulater { diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_reshape_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.cc similarity index 92% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_reshape_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.cc index 6695faaae6..ce86fc780b 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_reshape_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.cc @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_reshape_populater.h" +#include "tools/anf_importer/anf_populater/anf_reshape_populater.h" #include #include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_reshape_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.h similarity index 94% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_reshape_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.h index b46d931cf9..fd2d35a875 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_reshape_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_reshape_populater.h @@ -16,7 +16,7 @@ #ifndef MINDSPORE_ANF_RESHAPE_PARSER_H #define MINDSPORE_ANF_RESHAPE_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfReshapePopulater : public AnfNodePopulater { diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_tensoradd_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.cc similarity index 89% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_tensoradd_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.cc index d0bb01c4c9..4b8d951db8 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_tensoradd_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.cc @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_tensoradd_populater.h" +#include "tools/anf_importer/anf_populater/anf_tensoradd_populater.h" #include #include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_tensoradd_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.h similarity index 94% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_tensoradd_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.h index b7ecf326fb..7b990bbb85 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_tensoradd_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_tensoradd_populater.h @@ -15,7 +15,7 @@ */ #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H #define MINDSPORE_ANF_ACTIVATION_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfTensorAddPopulater : public AnfNodePopulater { diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_transpose_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.cc similarity index 92% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_transpose_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.cc index e2c1548ff6..9df1f97b33 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_transpose_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.cc @@ -13,11 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_transpose_populater.h" +#include "tools/anf_importer/anf_populater/anf_transpose_populater.h" #include #include #include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_transpose_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.h similarity index 94% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_transpose_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.h index 583912d2b1..60281c8fd1 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_transpose_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_transpose_populater.h @@ -15,7 +15,7 @@ */ #ifndef MINDSPORE_ANF_TRANSPOSE_PARSER_H #define MINDSPORE_ANF_TRANSPOSE_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfTransposePopulater : public AnfNodePopulater { diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.cc similarity index 89% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.cc rename to mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.cc index ec5e6b7433..a63f244ce0 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.cc @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.h" +#include "tools/anf_importer/anf_populater/anf_tuple_getitem_populater.h" #include #include -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" #include "ir/func_graph.h" #include "ir/primitive.h" diff --git a/mindspore/lite/src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.h similarity index 94% rename from mindspore/lite/src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.h rename to mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.h index b6b256a39a..40f4c0a15d 100644 --- a/mindspore/lite/src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_tuple_getitem_populater.h @@ -15,7 +15,7 @@ */ #ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H #define MINDSPORE_ANF_BATCHNORM_PARSER_H -#include "src/common/anf_importer/anf_populater/anf_node_populater.h" +#include "tools/anf_importer/anf_populater/anf_node_populater.h" #include namespace mindspore::lite { class AnfTupleGetItemPopulater : public AnfNodePopulater { diff --git a/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc new file mode 100644 index 0000000000..1f22b39b35 --- /dev/null +++ b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.cc @@ -0,0 +1,194 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "schema/inner/model_generated.h" +#include "frontend/operator/ops.h" +#include "src/param_value_lite.h" +#include "import_from_meta_graphT.h" +#include "utils/log_adapter.h" +#include "include/errorcode.h" +#include "src/ops/ops.h" + +namespace mindspore::lite { +int AnfImporterFromMetaGraphT::ConverterConstTensor() { + MS_ASSERT(nullptr != meta_graph_); + MS_ASSERT(nullptr != func_graph_); + for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) { + auto &tensor = meta_graph_->allTensors.at(i); + MS_ASSERT(tensor != nullptr); + // converter weight and graph input into parameter node + if (tensor->nodeType != schema::NodeType_ValueNode) { + continue; + } + MS_ASSERT(tensor->dims() != nullptr); + auto parameter = func_graph_->add_parameter(); + std::vector shape(tensor->dims.size()); + std::copy(tensor->dims.begin(), tensor->dims.end(), shape.begin()); + auto type_id = static_cast(tensor->dataType); + auto type_ptr = TypeIdToType(type_id); + auto abstract_tensor = std::make_shared(type_ptr, shape); + abstract_tensor->set_format(tensor->format); + parameter->set_abstract(abstract_tensor); + parameter->set_name("const_" + std::to_string(i)); + + ParamValueLitePtr param_value = std::make_shared(); + MS_ASSERT(param_value != nullptr); + param_value->set_tensor_shape(shape); + param_value->set_tensor_type(type_id); + if (!tensor->data.empty()) { + auto size = tensor->data.size(); + char *tensor_data = new (std::nothrow) char[size]; + if (tensor_data == nullptr) { + MS_LOG(ERROR) << "new char[] failed"; + return RET_ERROR; + } + std::memcpy(tensor_data, tensor->data.data(), size); + param_value->set_tensor_addr(tensor_data); + param_value->set_tensor_size(size); + } + if (!tensor->quantParams.empty()) { + std::unique_ptr quantParam = std::make_unique(); + quantParam->scale = tensor->quantParams[0]->scale; + quantParam->zeroPoint = tensor->quantParams[0]->zeroPoint; + quantParam->min = tensor->quantParams[0]->min; + quantParam->max = tensor->quantParams[0]->max; + quantParam->narrowRange = tensor->quantParams[0]->narrowRange; + quantParam->numBits = tensor->quantParams[0]->numBits; + quantParam->inited = tensor->quantParams[0]->inited; + param_value->set_quant_param(quantParam); + } + parameter->set_default_param(param_value); + AddNode(i, parameter); + } + return RET_OK; +} + +ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr &cNode) { + MS_ASSERT(nullptr != meta_graph_); + MS_ASSERT(nullptr != cNode); + auto primTValue = std::make_shared(cNode->primitive.release()); + cNode->primitive = nullptr; + // add quant parameter + if (cNode->quantType == schema::QuantType_AwareTraining) { + primTValue->SetQuantType(cNode->quantType); + for (int index : cNode->inputIndex) { + std::vector quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; + primTValue->AddInputQuantParam(quant_params); + } + for (int index : cNode->outputIndex) { + std::vector quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; + primTValue->AddOutputQuantParam(quant_params); + } + } + auto value_node = NewValueNode(primTValue); + return value_node; +} + +abstract::AbstractTensorPtr AnfImporterFromMetaGraphT::ConvertTensorToAbstractTensor( + const std::unique_ptr &tensor) { + MS_ASSERT(nullptr != tensor); + std::vector shape(tensor->dims.size()); + std::copy(tensor->dims.begin(), tensor->dims.end(), shape.begin()); + auto type_id = static_cast(tensor->dataType); + auto type_ptr = TypeIdToType(type_id); + return std::make_shared(type_ptr, shape); +} + +void AnfImporterFromMetaGraphT::ConvertAbstract(const std::unique_ptr &src_cnode, + const CNodePtr &dst_cnode) { + MS_ASSERT(nullptr != meta_graph_); + MS_ASSERT(nullptr != src_cnode); + MS_ASSERT(nullptr != dst_cnode); + std::vector out_tensor_ids = src_cnode->outputIndex; + if (out_tensor_ids.size() == 1) { + auto out_tensor_id = out_tensor_ids.front(); + MS_ASSERT(meta_graph_->allTensors.size() > out_tensor_id); + auto &tensor = meta_graph_->allTensors.at(out_tensor_id); + MS_ASSERT(nullptr != tensor); + dst_cnode->set_abstract(ConvertTensorToAbstractTensor(tensor)); + AddNode(out_tensor_id, dst_cnode); + } else { + AbstractBasePtrList abstract_list; + for (size_t i = 0; i < out_tensor_ids.size(); i++) { + auto out_tensor_id = out_tensor_ids.at(i); + MS_ASSERT(meta_graph_->allTensors.size() > out_tensor_id); + auto &tensor = meta_graph_->allTensors.at(out_tensor_id); + MS_ASSERT(nullptr != tensor); + abstract_list.emplace_back(ConvertTensorToAbstractTensor(tensor)); + auto tuple_get_item_prim = NewValueNode(GetTupleGetItemPrim()); + auto get_item_value = NewValueNode(MakeValue(i)); + std::vector inputs{tuple_get_item_prim, dst_cnode, get_item_value}; + CNodePtr get_item_cnode = func_graph_->NewCNode(inputs); + AddNode(out_tensor_id, get_item_cnode); + } + dst_cnode->set_abstract(std::make_shared(abstract_list)); + } +} + +int AnfImporterFromMetaGraphT::ConverterCNode() { + MS_ASSERT(nullptr != meta_graph_); + MS_ASSERT(nullptr != func_graph_); + for (const auto &cNode : meta_graph_->nodes) { + MS_ASSERT(nullptr != cNode); + + std::vector op_inputs = {ConvertPrimitive(cNode)}; + for (unsigned int j : cNode->inputIndex) { + auto node = GetNode(j); + if (nullptr == node) { + MS_LOG(ERROR) << "Can't find input node."; + return RET_ERROR; + } + // todo: CheckInputNodeType, the first node should be op; + op_inputs.push_back(node); + } + auto new_cnode = func_graph_->NewCNode(op_inputs); + new_cnode->set_fullname_with_scope(cNode->name); + ConvertAbstract(cNode, new_cnode); + } + return RET_OK; +} + +int AnfImporterFromMetaGraphT::AddReturnCNode() { + MS_EXCEPTION_IF_NULL(meta_graph_); + MS_EXCEPTION_IF_NULL(func_graph_); + std::vector make_tuple_inputs; + auto make_tuple_prim = NewValueNode(GetMakeTuplePrim()); + make_tuple_inputs.emplace_back(make_tuple_prim); + for (auto tensor_id : meta_graph_->outputIndex) { + auto cNode = GetNode(tensor_id); + if (nullptr == cNode) { + MS_LOG(ERROR) << "Can't find input node."; + return RET_ERROR; + } + make_tuple_inputs.emplace_back(cNode); + } + auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs); + make_tuple_cnode->set_fullname_with_scope("return tuple"); + + std::vector op_inputs; + auto value_node = NewValueNode(GetReturnPrim()); + op_inputs.emplace_back(value_node); + op_inputs.emplace_back(make_tuple_cnode); + auto cnode = func_graph_->NewCNode(op_inputs); + cnode->set_fullname_with_scope("return"); + func_graph_->set_return(cnode); + return RET_OK; +} + +FuncGraphPtr AnfImporterFromMetaGraphT::GetResult() { return this->func_graph_; } +} // namespace mindspore::lite diff --git a/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.h b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.h similarity index 73% rename from mindspore/lite/src/common/anf_importer/import_from_meta_graphT.h rename to mindspore/lite/tools/anf_importer/import_from_meta_graphT.h index 5b3799a256..0e8f3e8ca2 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_meta_graphT.h +++ b/mindspore/lite/tools/anf_importer/import_from_meta_graphT.h @@ -18,9 +18,11 @@ #define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ #include - +#include #include "schema/inner/model_generated.h" -#include "src/common/anf_importer/anf_importer.h" +#include "tools/anf_importer/anf_importer.h" +#include "src/ir/primitive_t_value.h" +#include "abstract/abstract_value.h" namespace mindspore::lite { class AnfImporterFromMetaGraphT : public AnfImporter { @@ -33,11 +35,15 @@ class AnfImporterFromMetaGraphT : public AnfImporter { FuncGraphPtr GetResult() override; private: - void ConverterConstTensor() override; + int ConverterConstTensor() override; int ConverterCNode() override; - void AddReturnCNode() override; + ValueNodePtr ConvertPrimitive(const std::unique_ptr &cNode); + abstract::AbstractTensorPtr ConvertTensorToAbstractTensor(const std::unique_ptr &tensor); + void ConvertAbstract(const std::unique_ptr &src_cnode, const CNodePtr &dst_cnode); + + int AddReturnCNode() override; private: schema::MetaGraphT *meta_graph_; @@ -46,4 +52,3 @@ class AnfImporterFromMetaGraphT : public AnfImporter { } // namespace mindspore::lite #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ - diff --git a/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc similarity index 84% rename from mindspore/lite/src/common/anf_importer/import_from_protobuf.cc rename to mindspore/lite/tools/anf_importer/import_from_protobuf.cc index 9bd7a41e28..3f41d3a197 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_protobuf.cc +++ b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "src/common/anf_importer/import_from_protobuf.h" +#include "tools/anf_importer/import_from_protobuf.h" #include #include @@ -35,11 +35,11 @@ #include "ir/func_graph.h" #include "schema/inner/model_generated.h" #include "securec/include/securec.h" -#include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" #include "src/ir/tensor.h" #include "src/param_value_lite.h" #include "tools/converter/parser/onnx/onnx.pb.h" #include "utils/log_adapter.h" +#include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" using string = std::string; using int32 = int32_t; @@ -60,24 +60,16 @@ enum ParseForm : int { }; static std::map kParseTypeSwitchMap{ - {"type", FORM_PARSE_TYPE}, - {"scalar", FORM_PARSE_SCALAR}, - {"tensor", FORM_PARSE_TENSOR}}; + {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; static std::unordered_map kDefaultValueSwitchMap{ - {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, - {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, - {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, - {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, - {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, - {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, - {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, - {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, - {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, - {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, - {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, - {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, - {onnx::TensorProto_DataType_STRING, kObjectTypeString}, + {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, + {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, + {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, + {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, + {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, + {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, + {onnx::TensorProto_DataType_STRING, kObjectTypeString}, }; #if 0 @@ -197,16 +189,15 @@ ParserAttrShape(const std::string &attr_name, const std::unordered_map(attr_tensor.type##_data(0)); \ - return MakeValue(value); \ - } else { \ - MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ - } \ - return {}; \ +#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ + ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \ + if (attr_tensor.type##_data_size() == 1) { \ + auto value = static_cast(attr_tensor.type##_data(0)); \ + return MakeValue(value); \ + } else { \ + MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ + } \ + return {}; \ } PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) @@ -652,21 +643,20 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc } #else -#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ - void ParseAttrInScalar_##type##_##valuetype( \ - const PrimitivePtr &prim, const std::string &attr_name, \ - const onnx::TensorProto &attr_tensor) { \ - MS_EXCEPTION_IF_NULL(prim); \ - std::vector attr_value_vec; \ - for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \ - auto value = static_cast(attr_tensor.type##_data(i)); \ - attr_value_vec.push_back(MakeValue(value)); \ - } \ - if (attr_value_vec.size() == 1) { \ - prim->AddAttr(attr_name, attr_value_vec[0]); \ - } else { \ - prim->AddAttr(attr_name, std::make_shared(attr_value_vec)); \ - } \ +#define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ + void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \ + const onnx::TensorProto &attr_tensor) { \ + MS_EXCEPTION_IF_NULL(prim); \ + std::vector attr_value_vec; \ + for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \ + auto value = static_cast(attr_tensor.type##_data(i)); \ + attr_value_vec.push_back(MakeValue(value)); \ + } \ + if (attr_value_vec.size() == 1) { \ + prim->AddAttr(attr_name, attr_value_vec[0]); \ + } else { \ + prim->AddAttr(attr_name, std::make_shared(attr_value_vec)); \ + } \ } PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) @@ -677,8 +667,8 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool) PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) -bool AnfImporterFromProtobuf::BuildParameterForFuncGraph( - const ParameterPtr &node, const onnx::ValueInfoProto &value_proto) { +bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node, + const onnx::ValueInfoProto &value_proto) { MS_EXCEPTION_IF_NULL(node); if (!value_proto.has_type() || !value_proto.has_name()) { MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; @@ -701,30 +691,24 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph( shape.push_back(tensor_shape.dim(i).dim_value()); } - if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == - kDefaultValueSwitchMap.end()) { + if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) { MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; return false; } - auto type_ptr = - TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); - auto abstract_tensor = - std::make_shared(type_ptr, shape); + auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); + auto abstract_tensor = std::make_shared(type_ptr, shape); node->set_abstract(abstract_tensor); if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { - tensor::Tensor *tensor_info = new tensor::Tensor( - kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); + tensor::Tensor *tensor_info = new tensor::Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); MS_EXCEPTION_IF_NULL(tensor_info); tensor_info->MallocData(); - const onnx::TensorProto initialize_proto = - default_para_map_[value_proto.name()]; + const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()]; std::string initial_data = initialize_proto.raw_data(); auto *tensor_data_buf = reinterpret_cast(tensor_info->Data()); MS_EXCEPTION_IF_NULL(tensor_data_buf); - auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), - initial_data.data(), initial_data.size()); + auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size()); if (EOK != ret) { MS_LOG(ERROR) << "memcpy_s error"; return false; @@ -740,18 +724,15 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph( return true; } -bool AnfImporterFromProtobuf::ImportParametersForGraph( - const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { +bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto) { MS_EXCEPTION_IF_NULL(outputFuncGraph); - MS_LOG(INFO) << "Parameters had default paramerer size is: " - << importProto.initializer_size(); + MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size(); for (int i = 0; i < importProto.initializer_size(); ++i) { const onnx::TensorProto &initializer_proto = importProto.initializer(i); if (!initializer_proto.has_name()) { - MS_LOG(ERROR) - << "initializer vector of onnx GraphProto has no name at index: " - << i; + MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i; return false; } default_para_map_[initializer_proto.name()] = initializer_proto; @@ -760,8 +741,7 @@ bool AnfImporterFromProtobuf::ImportParametersForGraph( MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); for (int i = 0; i < importProto.input_size(); ++i) { const onnx::ValueInfoProto &input_proto = importProto.input(i); - if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), - input_proto)) { + if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) { MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; return false; } @@ -769,25 +749,20 @@ bool AnfImporterFromProtobuf::ImportParametersForGraph( return true; } -bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm( - const PrimitivePtr &prim, const std::string &attr_name, - const onnx::TensorProto &attr_tensor) { +bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor) { MS_EXCEPTION_IF_NULL(prim); const int attr_tensor_type = attr_tensor.data_type(); - if (kDefaultValueSwitchMap.find(attr_tensor_type) == - kDefaultValueSwitchMap.end()) { - MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" - << attr_tensor_type; + if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type; return false; } - prim->AddAttr(attr_name, - TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); + prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); return true; } -bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm( - const PrimitivePtr &prim, const std::string &attr_name, - const onnx::TensorProto &attr_tensor) { +bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor) { MS_EXCEPTION_IF_NULL(prim); const int attr_tensor_type = attr_tensor.data_type(); switch (attr_tensor_type) { @@ -821,16 +796,14 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm( break; } default: - MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " - << attr_tensor_type; + MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; return false; } return true; } -bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm( - const PrimitivePtr &prim, const std::string &attr_name, - const onnx::TensorProto &attr_tensor) { +bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, + const onnx::TensorProto &attr_tensor) { MS_EXCEPTION_IF_NULL(prim); const int attr_tensor_type = attr_tensor.data_type(); const std::string &tensor_buf = attr_tensor.raw_data(); @@ -840,31 +813,26 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm( for (int i = 0; i < attr_tensor.dims_size(); ++i) { shape.push_back(attr_tensor.dims(i)); } - tensor::TensorPtr tensor_info = std::make_shared( - kDefaultValueSwitchMap[attr_tensor_type], shape); + tensor::TensorPtr tensor_info = std::make_shared(kDefaultValueSwitchMap[attr_tensor_type], shape); tensor_info->MallocData(); auto *tensor_data_buf = reinterpret_cast(tensor_info->Data()); - ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), - tensor_buf.size()); + ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size()); prim->set_attr(attr_name, MakeValue(tensor_info)); } else { if (attr_tensor_type == onnx::TensorProto_DataType_DOUBLE) { size_t data_size = sizeof(double); double attr_value = 0.0; - ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), - tensor_buf.size()); + ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), tensor_buf.size()); prim->set_attr(attr_name, MakeValue(attr_value)); } else if (attr_tensor_type == onnx::TensorProto_DataType_INT64) { size_t data_size = sizeof(int64_t); int32_t attr_value = 0; - ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), - tensor_buf.size()); + ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), tensor_buf.size()); prim->set_attr(attr_name, MakeValue(attr_value)); } else if (attr_tensor_type == onnx::TensorProto_DataType_BOOL) { size_t data_size = sizeof(bool); bool attr_value = false; - ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), - tensor_buf.size()); + ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), tensor_buf.size()); prim->set_attr(attr_name, MakeValue(attr_value)); } } @@ -872,8 +840,7 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm( return ret == EOK; } -bool AnfImporterFromProtobuf::GetAttrValueForCNode( - const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { +bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { MS_EXCEPTION_IF_NULL(prim); const std::string &attr_name = attr_proto.name(); if (!attr_proto.has_ref_attr_name()) { @@ -897,20 +864,18 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode( return false; } } -bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm( - const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { +bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name, + const onnx::TensorProto &attr_tensor) { const int attr_tensor_type = attr_tensor.data_type(); std::vector shape; for (int i = 0; i < attr_tensor.dims_size(); ++i) { shape.push_back(attr_tensor.dims(i)); } - tensor::TensorPtr tensor_info = std::make_shared( - kDefaultValueSwitchMap[attr_tensor_type], shape); + tensor::TensorPtr tensor_info = std::make_shared(kDefaultValueSwitchMap[attr_tensor_type], shape); tensor_info->MallocData(); const std::string &tensor_buf = attr_tensor.raw_data(); auto *tensor_data_buf = reinterpret_cast(tensor_info->Data()); - auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), - tensor_buf.size()); + auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size()); if (EOK != ret) { MS_LOG(ERROR) << "memcpy_s error"; return false; @@ -918,15 +883,14 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm( auto new_value_node = NewValueNode(MakeValue(tensor_info)); MS_EXCEPTION_IF_NULL(new_value_node); auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); - auto abstract_tensor = - std::make_shared(type_ptr, shape); + auto abstract_tensor = std::make_shared(type_ptr, shape); new_value_node->set_abstract(abstract_tensor); anfnode_build_map_[value_node_name] = new_value_node; return true; } -bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm( - const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { +bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &value_node_name, + const onnx::TensorProto &attr_tensor) { const int attr_tensor_type = attr_tensor.data_type(); ValuePtr value_ptr = nullptr; switch (attr_tensor_type) { @@ -961,8 +925,7 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm( break; } default: - MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " - << attr_tensor_type; + MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; return false; } auto new_value_node = NewValueNode(value_ptr); @@ -973,28 +936,23 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm( return true; } -bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm( - const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { +bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name, + const onnx::TensorProto &attr_tensor) { const int attr_tensor_type = attr_tensor.data_type(); - if (kDefaultValueSwitchMap.find(attr_tensor_type) == - kDefaultValueSwitchMap.end()) { - MS_LOG(ERROR) - << "Obtain ValueNode attr in type-form has not support input type: " - << attr_tensor_type; + if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { + MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; return false; } - auto new_value_node = - NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); - abstract::AbstractTypePtr abs_type = - std::make_shared(std::make_shared()); + auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); + abstract::AbstractTypePtr abs_type = std::make_shared(std::make_shared()); new_value_node->set_abstract(abs_type); anfnode_build_map_[value_node_name] = new_value_node; return true; } -bool AnfImporterFromProtobuf::GetAttrValueForValueNode( - const std::string &ref_attr_name, const std::string &value_node_name, - const onnx::TensorProto &attr_tensor) { +bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_attr_name, + const std::string &value_node_name, + const onnx::TensorProto &attr_tensor) { switch (kParseTypeSwitchMap[ref_attr_name]) { case FORM_PARSE_SCALAR: { return ObtainValueNodeInScalarForm(value_node_name, attr_tensor); @@ -1006,14 +964,12 @@ bool AnfImporterFromProtobuf::GetAttrValueForValueNode( return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); } default: - MS_LOG(ERROR) - << "parse ValueNode value don't support input of ref_attr_name"; + MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name"; return false; } } -bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph( - const onnx::NodeProto &node_proto) { +bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { const std::string &value_node_name = node_proto.output(0); const onnx::AttributeProto &attr_proto = node_proto.attribute(0); if (!attr_proto.has_ref_attr_name()) { @@ -1026,23 +982,21 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph( return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor); } -abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode( - const onnx::AttributeProto &attr_proto) { +abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) { std::vector shape_vec; const onnx::TensorProto &attr_tensor = attr_proto.t(); for (int i = 0; i < attr_tensor.dims_size(); ++i) { shape_vec.push_back(attr_tensor.dims(i)); } auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); - auto abstract_tensor = - std::make_shared(type_ptr, shape_vec); + auto abstract_tensor = std::make_shared(type_ptr, shape_vec); MS_EXCEPTION_IF_NULL(abstract_tensor); return abstract_tensor; } -CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph( - const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, - const schema::QuantType &quantType) { +CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::NodeProto &node_proto, + const schema::QuantType &quantType) { MS_EXCEPTION_IF_NULL(outputFuncGraph); if (!node_proto.has_op_type()) { MS_LOG(ERROR) << "Get CNode op_type failed!"; @@ -1082,23 +1036,20 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph( for (int i = 0; i < node_proto.input_size(); ++i) { const std::string &input_name = node_proto.input(i); if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { - MS_LOG(ERROR) << node_name << " input " << i << input_name - << "can't find in nodes have parsed"; + MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; return nullptr; } inputs.push_back(anfnode_build_map_[input_name]); } std::string opType = prim->name(); - auto node_parser = - AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); + auto node_parser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); if (node_parser == nullptr) { MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; return nullptr; } auto primitiveT = std::make_unique(); // auto * primitiveTValue = new PrimitiveTValue(primitiveT.release()); - std::shared_ptr primitiveTValuePtr = - std::make_shared(primitiveT.release()); + std::shared_ptr primitiveTValuePtr = std::make_shared(primitiveT.release()); primitiveTValuePtr->SetQuantType(quantType); node_parser->Populate(prim, primitiveTValuePtr.get(), inputs); MS_ASSERT(primitiveTValuePtr != nullptr); @@ -1130,9 +1081,8 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph( return cnode_ptr; } -bool AnfImporterFromProtobuf::BuildReturnForFuncGraph( - const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, - const CNodePtr &cnode_ptr) { +bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) { MS_EXCEPTION_IF_NULL(outputFuncGraph); MS_EXCEPTION_IF_NULL(cnode_ptr); std::vector inputs; @@ -1147,8 +1097,7 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph( elem.push_back(anfnode_build_map_[out_tuple]->abstract()); } auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); - maketuple_ptr->set_abstract( - std::make_shared(elem)); + maketuple_ptr->set_abstract(std::make_shared(elem)); inputs.clear(); inputs.push_back(NewValueNode(prim::kPrimReturn)); inputs.push_back(maketuple_ptr); @@ -1161,14 +1110,11 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph( const onnx::TypeProto &output_typeproto = output_node.type(); int output_type = output_typeproto.tensor_type().elem_type(); std::vector output_shape; - for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); - ++i) { - output_shape.push_back( - output_typeproto.tensor_type().shape().dim(i).dim_value()); + for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) { + output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value()); } auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); - auto abstract_tensor = - std::make_shared(type_ptr, output_shape); + auto abstract_tensor = std::make_shared(type_ptr, output_shape); inputs.clear(); inputs.push_back(NewValueNode(prim::kPrimReturn)); @@ -1182,9 +1128,9 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph( return true; } -bool AnfImporterFromProtobuf::ImportNodesForGraph( - const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, - const schema::QuantType &quantType) { +bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, + const onnx::GraphProto &importProto, + const schema::QuantType &quantType) { MS_EXCEPTION_IF_NULL(outputFuncGraph); MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); CNodePtr cnode_ptr = nullptr; @@ -1210,9 +1156,8 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph( } #endif -bool AnfImporterFromProtobuf::BuildFuncGraph( - const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, - const schema::QuantType &quantType) { +bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, + const schema::QuantType &quantType) { MS_EXCEPTION_IF_NULL(outputFuncGraph); GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); MS_EXCEPTION_IF_NULL(debug_info_ptr); @@ -1228,8 +1173,7 @@ bool AnfImporterFromProtobuf::BuildFuncGraph( return ImportNodesForGraph(outputFuncGraph, importProto, quantType); } -bool AnfImporterFromProtobuf::ParseModelConfigureInfo( - const onnx::ModelProto &model_proto) { +bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) { if (!model_proto.has_producer_name()) { MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; return false; @@ -1267,8 +1211,7 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { return RET_OK; } -onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary( - const std::string &model_path) { +onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { std::unique_ptr onnx_file(new (std::nothrow) char[PATH_MAX]{0}); if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) { MS_LOG(ERROR) << "open file failed."; diff --git a/mindspore/lite/src/common/anf_importer/import_from_protobuf.h b/mindspore/lite/tools/anf_importer/import_from_protobuf.h similarity index 70% rename from mindspore/lite/src/common/anf_importer/import_from_protobuf.h rename to mindspore/lite/tools/anf_importer/import_from_protobuf.h index e7064fab39..446f8a5be5 100644 --- a/mindspore/lite/src/common/anf_importer/import_from_protobuf.h +++ b/mindspore/lite/tools/anf_importer/import_from_protobuf.h @@ -22,15 +22,15 @@ #include #include -#include "abstract/abstract_value.h" -#include "src/common/anf_importer/anf_importer.h" +#include "include/errorcode.h" #include "tools/converter/parser/onnx/onnx.pb.h" +#include "tools/anf_importer/anf_importer.h" +#include "abstract/abstract_value.h" namespace mindspore::lite { class AnfImporterFromProtobuf : public AnfImporter { public: - explicit AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, - FuncGraphPtr func_graph) + explicit AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) : onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {} ~AnfImporterFromProtobuf() override = default; @@ -39,16 +39,14 @@ class AnfImporterFromProtobuf : public AnfImporter { FuncGraphPtr GetResult() override; - int Import(const schema::QuantType &quantType = - schema::QuantType_QUANT_NONE) override; + int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE) override; private: - void ConverterConstTensor() override{}; - int ConverterCNode() override{}; - void AddReturnCNode() override{}; + int ConverterConstTensor() override{ return RET_ERROR; }; + int ConverterCNode() override{ return RET_ERROR; }; + int AddReturnCNode() override{ return RET_ERROR; }; bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto); - bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, - const onnx::GraphProto &importProto, + bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, const schema::QuantType &quantType); #if 0 bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, @@ -81,43 +79,29 @@ class AnfImporterFromProtobuf : public AnfImporter { std::unordered_map GetAbstractForCNode(const onnx::AttributeProto &attr_proto); #else - bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, - const onnx::GraphProto &importProto); - bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, - const onnx::GraphProto &importProto, + bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); + bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, const schema::QuantType &quantType); - bool BuildParameterForFuncGraph(const ParameterPtr &node, - const onnx::ValueInfoProto &value_proto); - CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, - const onnx::NodeProto &node_proto, + bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); + CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, const schema::QuantType &quantType); - bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, - const onnx::GraphProto &importProto, + bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr); - bool GetAttrValueForCNode(const PrimitivePtr &prim, - const onnx::AttributeProto &attr_proto); - bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, - const std::string &attr_name, + bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto); + bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, const onnx::TensorProto &attr_tensor); - bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, - const std::string &attr_name, + bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, const onnx::TensorProto &attr_tensor); - bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, - const std::string &attr_name, + bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, const onnx::TensorProto &attr_tensor); bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto); - bool ObtainValueNodeInTensorForm(const string &value_node_name, - const onnx::TensorProto &attr_tensor); + bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); - bool ObtainValueNodeInScalarForm(const string &value_node_name, - const onnx::TensorProto &attr_tensor); - bool GetAttrValueForValueNode(const string &ref_attr_name, - const std::string &value_node_name, + bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); + bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name, const onnx::TensorProto &attr_tensor); - bool ObtainValueNodeInTypeForm(const string &value_node_name, - const onnx::TensorProto &attr_tensor); - abstract::AbstractTensorPtr GetAbstractForCNode( - const onnx::AttributeProto &attr_proto); + bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); + abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto); #endif diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 0553a1e000..706721e1a1 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -70,7 +70,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc - # ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_exporter/anf_exporter.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ir/primitive_t_value.cc ../optimizer/common/node_pass_extends.cc ../optimizer/common/pass_manager_extends.cc @@ -83,6 +83,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/fusion/constant_folding_fusion.cc ) +add_subdirectory(../anf_importer anf_importer) +add_subdirectory(../anf_exporter anf_exporter) add_subdirectory(parser/caffe) add_subdirectory(parser/tflite) add_subdirectory(parser/onnx) @@ -100,6 +102,7 @@ target_link_libraries(converter_lite PRIVATE caffe_parser_mid onnx_parser_mid anf_importer_mid + anf_exporter_mid node_mid graph_pass_mid fusion_mid diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index c1ad5cff8b..8ff14c519b 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -28,8 +28,8 @@ #include "parser/caffe/caffe_converter.h" #include "parser/tflite/tflite_converter.h" #include "parser/onnx/onnx_converter.h" -#include "src/common/anf_exporter/anf_exporter.h" -#include "src/common/anf_importer/import_from_protobuf.h" +#include "tools/anf_exporter/anf_exporter.h" +#include "tools/anf_importer/import_from_protobuf.h" #include "tools/converter/parser/onnx/onnx.pb.h" #include "tools/converter/quantizer/weight_quantizer.h" #include "tools/converter/quantizer/post_training_quantizer.h" diff --git a/mindspore/lite/tools/converter/converter.h b/mindspore/lite/tools/converter/converter.h index 54e3560a87..f4f20e5640 100644 --- a/mindspore/lite/tools/converter/converter.h +++ b/mindspore/lite/tools/converter/converter.h @@ -22,7 +22,7 @@ #include "schema/inner/model_generated.h" #include "tools/converter/graphdef_transform.h" #include "tools/converter/model_parser.h" -#include "src/common/anf_importer/anf_importer.h" +#include "tools/anf_importer/anf_importer.h" #include "tools/converter/converter_flags.h" #include "tools/converter/anf_transform.h" #include "tools/converter/quantizer/quantizer.h" diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt index c3c03af6a1..c519b768ca 100755 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt @@ -1,5 +1,6 @@ add_library(graph_pass_mid OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_pass.cc + ${CMAKE_CURRENT_SOURCE_DIR}/eltwise_format_trans_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.cc new file mode 100644 index 0000000000..097f3c2eca --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.cc @@ -0,0 +1,200 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h" +#include "tools/common/converter_op_utils.h" +#include "tools/common/node_util.h" +#include "utils/log_adapter.h" +#include "src/common/common.h" +#include "src/common/utils.h" + +namespace mindspore { +namespace lite { +#define kMinInputNum 1 +#define kOutputNum 1 + +STATUS EltwiseFormatTransPass::Run(schema::MetaGraphT *graph) { + if (fmkType == converter::FmkType_TF) { + return RET_OK; + } + MS_ASSERT(graph != nullptr); + auto status = DoModelInputFormatTrans(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoModelInputFormatTrans failed : " << status; + return status; + } + status = DoNodeInoutFormatTrans(graph); + if (status != RET_OK) { + MS_LOG(ERROR) << "DoNodeInoutFormatTrans failed : " << status; + return status; + } + return RET_OK; +} + +STATUS EltwiseFormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) { + if (fmkType == converter::FmkType_TF || fmkType == converter::FmkType_TFLITE) { + return RET_OK; + } + MS_ASSERT(graph != nullptr); + // insert trans node in model input tensor + if (graph->nodes.empty()) { + return RET_OK; + } + auto graphInputIdxes = graph->inputIndex; + for (size_t i = 0; i < graphInputIdxes.size(); i++) { + auto inputIdx = graphInputIdxes.at(i); + MS_ASSERT(inputIdx < subGraph->allTensors.size()); + auto &tensor = graph->allTensors.at(inputIdx); + if (tensor->dims.size() != kNCHWDimNumber) { + continue; + } + + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { + auto &node = *iter; + for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) { + if (node->inputIndex.at(inputIndexIdx) == inputIdx) { + STATUS status = RET_OK; + iter = InsertFormatTransNode(graph, iter, kBefore, inputIndexIdx, kNHWC2NCHW, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << (*iter)->name << " failed"; + return status; + } + // set first tensor format to nhwc + auto &transNode = *(iter - 1); + MS_ASSERT(transNode != nullptr); + MS_ASSERT(transNode->inputIndex.size() == 1); + MS_ASSERT(subGraph->allTensors.size() > transNode->inputIndex.front()); + auto &graphInTensor = graph->allTensors.at(transNode->inputIndex.front()); + graphInTensor->format = schema::Format_NHWC; + // assume parser not reformat shape + auto oldDims = graphInTensor->dims; + graphInTensor->dims = {oldDims[NCHW_N], oldDims[NCHW_H], oldDims[NCHW_W], oldDims[NCHW_C]}; + break; + } + } + } + } + return RET_OK; +} + +// inference needed inputFormat: +// conv deconv depth dedepth +// fp32 NCHW NCHW NCHW NCHW +// uint8 NCHW ? NCHW ? +STATUS EltwiseFormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { + MS_ASSERT(graph != nullptr); + // insert before and after the op cal by nchw/nc4hw4 + for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { + FormatTransNodeType beforeNodeType, afterNodeType; + if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc + // if (quantType == QuantType_AwareTrainning) { // awaretrainning op use + // nhwc + // if (IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only + // support nhwc + // continue; + // } + // if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { + // continue; + // } + // } else { + // if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { + continue; + // } + // } + // beforeNodeType = kNCHW2NHWC; + // afterNodeType = kNHWC2NCHW; + } else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw + // if (quantType == QuantType_AwareTrainning) { // awaretrainning op use nhwc + // if (!IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only support nhwc + // continue; + // } + // } else { + // continue; + // } + if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { + continue; + } + beforeNodeType = kNCHW2NHWC; + afterNodeType = kNHWC2NCHW; + } else if (fmkType == converter::FmkType_MS) { + if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { + continue; + } + beforeNodeType = kNCHW2NHWC; + afterNodeType = kNHWC2NCHW; + } else { + MS_LOG(ERROR) << "Unsupported fmk: " << fmkType; + return RET_ERROR; + } + auto &node = *iter; + auto nodeName = node->name; + if (node->inputIndex.size() < kMinInputNum) { + MS_LOG(ERROR) << "Op should have " << kMinInputNum << " input tensor at least"; + return RET_ERROR; + } + if (node->outputIndex.size() != kOutputNum) { + MS_LOG(ERROR) << "Op should have " << kOutputNum << " output tensor"; + return RET_ERROR; + } + STATUS status; + iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << nodeName << "failed"; + return RET_ERROR; + } + + iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status); + if (status != RET_OK) { + MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; + return RET_ERROR; + } + } + return RET_OK; +} + +NodeIter EltwiseFormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, + InsertPlace place, size_t inoutIdx, FormatTransNodeType nodeType, + STATUS *errorCode) { + MS_ASSERT((*existNodeIter) != nullptr); + auto existNodeName = (*existNodeIter)->name; + std::string tileName; + if (place == kBefore) { + tileName = existNodeName + "_pre"; + } else { + tileName = existNodeName + "_post"; + } + auto transNode = std::make_unique(); + transNode->primitive = std::make_unique(); + + if (nodeType == kNCHW2NHWC) { + transNode->name = "nchw2nhwc_" + tileName + std::to_string(id++); + transNode->primitive->value.type = schema::PrimitiveType_Nchw2Nhwc; + } else { + transNode->name = "nhwc2nchw_" + tileName + std::to_string(id++); + transNode->primitive->value.type = schema::PrimitiveType_Nhwc2Nchw; + } + return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode); +} + +void EltwiseFormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } + +void EltwiseFormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkType; } + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h new file mode 100644 index 0000000000..5a5d754ac1 --- /dev/null +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H +#define MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H + +#include "tools/converter/optimizer.h" +#include "tools/common/graph_util.h" +#include "tools/converter/converter_flags.h" + +namespace mindspore { +namespace lite { +enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW }; + +class EltwiseFormatTransPass : public GraphPass { + public: + EltwiseFormatTransPass() : id(0) {} + + ~EltwiseFormatTransPass() override = default; + + STATUS Run(schema::MetaGraphT *graph) override; + + void SetQuantType(QuantType quantType); + + void SetFmk(converter::FmkType fmkType); + + private: + STATUS DoModelInputFormatTrans(schema::MetaGraphT *graph); + + STATUS DoNodeInoutFormatTrans(schema::MetaGraphT *graph); + + NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, + FormatTransNodeType nodeType, STATUS *errorCode); + + private: + size_t id; + QuantType quantType = QuantType_QUANT_NONE; + converter::FmkType fmkType = converter::FmkType_TF; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H diff --git a/mindspore/lite/tools/converter/model_parser.h b/mindspore/lite/tools/converter/model_parser.h index 02ebf0f1d8..64c3cfac52 100644 --- a/mindspore/lite/tools/converter/model_parser.h +++ b/mindspore/lite/tools/converter/model_parser.h @@ -20,7 +20,7 @@ #include #include #include "schema/inner/model_generated.h" -#include "src/common/anf_importer/import_from_meta_graphT.h" +#include "tools/anf_importer/import_from_meta_graphT.h" #include "ir/anf.h" #include "include/errorcode.h" diff --git a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt index 335ebcbfec..009fd01ae5 100644 --- a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt +++ b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt @@ -12,7 +12,6 @@ add_library(quantizer_mid OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc - ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/common/anf_exporter/anf_exporter.cc ) if(ENABLE_ASAN) diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index 01d85b2674..d9eba3480e 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -27,7 +27,7 @@ #include #include "schema/inner/model_generated.h" #include "src/ir/tensor.h" -#include "src/common/anf_exporter/anf_exporter.h" +#include "tools/anf_exporter/anf_exporter.h" #include "tools/converter/quantizer/post_training_quantizer.h" #include "tools/converter/quantizer/quantize_util.h" #include "src/common/common.h" diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index 119cec8bba..f665f63ee0 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -283,7 +283,7 @@ void CheckLeastInputSize(const CNodePtr &node, const int size) { } } -AnfNodePtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, +ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, const ParamValueLitePtr &weight_tensor) { auto bias_parameter = func_graph->add_parameter(); MS_ASSERT(bias_parameter != nullptr); diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h index 9827877be1..def65d0537 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h @@ -47,7 +47,7 @@ void CheckIfNodeIsParam(const AnfNodePtr &node); void CheckLeastInputSize(const CNodePtr &node, int size); -AnfNodePtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, +ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, const ParamValueLitePtr &weight_tensor); schema::PrimitiveType GetCNodeType(const BaseRef &node); diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index 14be819ad5..0cd2036cec 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -21,7 +21,7 @@ #include "schema/inner/model_generated.h" #include "tools/optimizer/common/gllo_utils.h" #include "src/kernel_factory.h" -#include "src/common/anf_exporter/anf_exporter.h" +#include "tools/anf_exporter/anf_exporter.h" #include "src/scheduler.h" #include "include/context.h" #include "src/lite_session.h" @@ -38,7 +38,7 @@ const std::vector GetCNodeInputTensors(const CNodePtr &CNode) { auto tmp_meta_graph = std::make_unique(); auto tmp_fb_node = std::make_unique(); lite::AnfExporter anfExporter; - anfExporter.SetOpInputNode(CNode, tmp_meta_graph.get(), tmp_fb_node.get()); + anfExporter.SetOpInputNode(CNode, tmp_meta_graph, tmp_fb_node.get()); std::vector input_tensors; for (auto input_index : tmp_fb_node->inputIndex) { auto tensorT = tmp_meta_graph->allTensors.at(input_index).get(); diff --git a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc index 4b47a9f469..e3e8bc2728 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc @@ -33,8 +33,8 @@ constexpr size_t kConvWithBiasLen = 4; bool IsConvExtendNode(const BaseRef &n) { if (utils::isa(n) || utils::isa(n)) { auto type = opt::GetCNodeType(n); - return type == schema::PrimitiveType_Conv2D || type == schema::PrimitiveType_DepthwiseConv2D - || type == schema::PrimitiveType_DeConv2D; + return type == schema::PrimitiveType_Conv2D || type == schema::PrimitiveType_DepthwiseConv2D || + type == schema::PrimitiveType_DeConv2D; } return false; } @@ -59,8 +59,8 @@ int Get_Kenrnel_nums(const CNodePtr &conv_node) { if (type == schema::PrimitiveType_Conv2D) { return primitive->GetPrimitiveT()->value.AsConv2D()->channelOut; } else if (type == schema::PrimitiveType_DepthwiseConv2D) { - return primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelMultiplier - * primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelIn; + return primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelMultiplier * + primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelIn; } else if (type == schema::PrimitiveType_DeConv2D) { return primitive->GetPrimitiveT()->value.AsDeConv2D()->channelOut; } else { @@ -83,16 +83,16 @@ void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, c if (kernel_nums <= 0) { MS_LOG(EXCEPTION) << "kernel num less than 0"; } - auto add_bias_data = new(std::nothrow) float[kernel_nums]; + auto add_bias_data = new (std::nothrow) float[kernel_nums]; auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX); CheckIfNodeIsParam(bias_add_weight); auto add_weight_param = bias_add_weight->cast()->default_param(); auto add_weight_tensor = std::dynamic_pointer_cast(add_weight_param); auto add_weight_data = reinterpret_cast(add_weight_tensor->tensor_addr()); auto add_weight_shape = add_weight_tensor->tensor_shape(); - if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] ==1)) { - for (size_t i = 0; i < kernel_nums; i++) { - add_bias_data[i] = *add_weight_data; + if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] == 1)) { + for (size_t i = 0; i < kernel_nums; i++) { + add_bias_data[i] = *add_weight_data; } } else { if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) { @@ -115,6 +115,7 @@ void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, c auto conv_weight_param = conv_weight_node->cast()->default_param(); auto conv_weight_tensor = std::dynamic_pointer_cast(conv_weight_param); auto conv_new_bias = AddNewBiasNode(add_bias_data, func_graph, kernel_nums, conv_weight_tensor); + conv_new_bias->set_name(conv_node->fullname_with_scope() + "_bias"); conv_node->add_input(conv_new_bias); } } @@ -159,4 +160,3 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons return conv_node; } } // namespace mindspore::opt - diff --git a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc index a5c09287a3..c858214fb6 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc @@ -44,8 +44,8 @@ int Get_Kenrnel_nums(const CNodePtr &conv_node) { if (type == schema::PrimitiveType_Conv2D) { return primitive->GetPrimitiveT()->value.AsConv2D()->channelOut; } else if (type == schema::PrimitiveType_DepthwiseConv2D) { - return primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelMultiplier - * primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelIn; + return primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelMultiplier * + primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelIn; } else { MS_LOG(ERROR) << "Unsupported opType, " << type; return 0; @@ -74,8 +74,8 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co MS_LOG(ERROR) << "Unsupported conv node, " << conv_node->DebugString(); return node; } - auto trans_scale = new(std::nothrow) float[kernel_nums]; - auto trans_bias = new(std::nothrow) float[kernel_nums]; + auto trans_scale = new (std::nothrow) float[kernel_nums]; + auto trans_bias = new (std::nothrow) float[kernel_nums]; GenTransParam(transform_node, kernel_nums, trans_scale, trans_bias); GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias); delete[] trans_bias; @@ -93,8 +93,8 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co return pre_node; } -const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, - float *trans_scale, float *trans_bias) const { +const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, float *trans_scale, + float *trans_bias) const { if (trans_scale == nullptr) { MS_LOG(EXCEPTION) << "new transScale failed"; } @@ -112,8 +112,8 @@ const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, in } const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, - int kernel_num, const float *trans_scale, const float *trans_bias) -const { + int kernel_num, const float *trans_scale, + const float *trans_bias) const { MS_ASSERT(conv_node != nullptr); AnfNodePtr conv_weight_node = nullptr; AnfNodePtr conv_bias_node = nullptr; @@ -152,18 +152,19 @@ const { bias_data = reinterpret_cast(bias_tensor->tensor_addr()); bias_flag = true; } else { - bias_data = new(std::nothrow) float[kernel_num]; + bias_data = new (std::nothrow) float[kernel_num]; } CalNewBiasTensor(bias_data, kernel_num, bias_flag, trans_scale, trans_bias); if (!bias_flag) { auto bias_node = AddNewBiasNode(bias_data, func_graph, kernel_num, weight_tensor); + bias_node->set_name(conv_node->fullname_with_scope() + "_bias"); conv_node->add_input(bias_node); } } const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int kernel_size, const float *trans_scale) const { MS_ASSERT(weight_data != nullptr); - auto tmp_weight_data = new(std::nothrow) float[kernel_num * kernel_size]; + auto tmp_weight_data = new (std::nothrow) float[kernel_num * kernel_size]; MS_ASSERT(new_weight_data != nullptr); auto data_size = kernel_num * kernel_size * sizeof(float); if (0 != memset_s(tmp_weight_data, data_size, 0, data_size)) { @@ -189,7 +190,7 @@ const void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_nu const float *trans_scale, const float *trans_bias) const { MS_ASSERT(bias_data != nullptr); if (bias_flag) { - auto tmp_bias_data = new(std::nothrow) float[kernel_num]; + auto tmp_bias_data = new (std::nothrow) float[kernel_num]; if (EOK != memset_s(tmp_bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float))) { MS_LOG(EXCEPTION) << "memset bias data failed"; }