Merge pull request !4486 from hangq/mastertags/v0.7.0-beta
| @@ -278,6 +278,8 @@ class AbstractTensor : public AbstractUndetermined { | |||||
| AbstractBasePtr Broaden() const override; | AbstractBasePtr Broaden() const override; | ||||
| AbstractBasePtr BroadenWithShape() const; | AbstractBasePtr BroadenWithShape() const; | ||||
| AbstractBasePtr Join(const AbstractBasePtr &other) final; | AbstractBasePtr Join(const AbstractBasePtr &other) final; | ||||
| int format() const { return this->format_; } | |||||
| void set_format(int format) { this->format_ = format; } | |||||
| bool operator==(const AbstractTensor &other) const; | bool operator==(const AbstractTensor &other) const; | ||||
| bool operator==(const AbstractBase &other) const override; | bool operator==(const AbstractBase &other) const override; | ||||
| @@ -294,6 +296,9 @@ class AbstractTensor : public AbstractUndetermined { | |||||
| } | } | ||||
| return hash_sum; | return hash_sum; | ||||
| } | } | ||||
| protected: | |||||
| int format_ = 0; | |||||
| }; | }; | ||||
| using AbstractTensorPtr = std::shared_ptr<AbstractTensor>; | using AbstractTensorPtr = std::shared_ptr<AbstractTensor>; | ||||
| using AbstractTensorPtrList = std::vector<AbstractTensorPtr>; | using AbstractTensorPtrList = std::vector<AbstractTensorPtr>; | ||||
| @@ -97,16 +97,12 @@ if (BUILD_CONVERTER) | |||||
| set(PYTHON_LIBRARIES "${py_lib}") | set(PYTHON_LIBRARIES "${py_lib}") | ||||
| endif() | endif() | ||||
| include_directories(${PYTHON_INCLUDE_DIRS}) | include_directories(${PYTHON_INCLUDE_DIRS}) | ||||
| # include(${TOP_DIR}/cmake/utils.cmake) | |||||
| # include(${TOP_DIR}/cmake/dependency_utils.cmake) | |||||
| include(${TOP_DIR}/cmake/external_libs/json.cmake) | include(${TOP_DIR}/cmake/external_libs/json.cmake) | ||||
| # include(${TOP_DIR}/cmake/dependency_securec.cmake) | |||||
| include(${TOP_DIR}/cmake/external_libs/pybind11.cmake) | include(${TOP_DIR}/cmake/external_libs/pybind11.cmake) | ||||
| include(${TOP_DIR}/cmake/external_libs/eigen.cmake) | include(${TOP_DIR}/cmake/external_libs/eigen.cmake) | ||||
| include_directories(${TOP_DIR}/third_party/protobuf/build/include) | include_directories(${TOP_DIR}/third_party/protobuf/build/include) | ||||
| link_directories(${TOP_DIR}/third_party/protobuf/build/lib) | link_directories(${TOP_DIR}/third_party/protobuf/build/lib) | ||||
| add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter) | add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools/converter) | ||||
| add_subdirectory(src/common/anf_importer) | |||||
| endif() | endif() | ||||
| if (BUILD_DEVICE) | if (BUILD_DEVICE) | ||||
| @@ -17,6 +17,7 @@ fi | |||||
| cd ${TOP_PATH}/output/ | cd ${TOP_PATH}/output/ | ||||
| rm -rf MSLite-0.6.0-linux_arm64 | rm -rf MSLite-0.6.0-linux_arm64 | ||||
| tar -zxvf MSLite-0.6.0-linux_arm64.tar.gz | tar -zxvf MSLite-0.6.0-linux_arm64.tar.gz | ||||
| mkdir -p ${BASE_PATH}/lib/ | |||||
| cp ${TOP_PATH}/output/MSLite-0.6.0-linux_arm64/lib/libmindspore-lite.so ${BASE_PATH}/lib/ | cp ${TOP_PATH}/output/MSLite-0.6.0-linux_arm64/lib/libmindspore-lite.so ${BASE_PATH}/lib/ | ||||
| cp ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android/libc++_shared.so ${BASE_PATH}/lib/ | cp ${ANDROID_NDK}/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android/libc++_shared.so ${BASE_PATH}/lib/ | ||||
| @@ -35,6 +36,7 @@ cp ${BASE_PATH}/native/build/libmindspore-lite-jni.so ${BASE_PATH}/lib/ | |||||
| ## check sdk gradle | ## check sdk gradle | ||||
| cd ${BASE_PATH}/java | cd ${BASE_PATH}/java | ||||
| rm -rf .gradle build gradle gradlew gradlew.bat build app/build | rm -rf .gradle build gradle gradlew gradlew.bat build app/build | ||||
| mkdir -p ${BASE_PATH}/java/app/libs/arm64-v8a/ | |||||
| rm -rf ${BASE_PATH}/java/app/libs/arm64-v8a/* | rm -rf ${BASE_PATH}/java/app/libs/arm64-v8a/* | ||||
| cp ${BASE_PATH}/lib/*.so ${BASE_PATH}/java/app/libs/arm64-v8a/ | cp ${BASE_PATH}/lib/*.so ${BASE_PATH}/java/app/libs/arm64-v8a/ | ||||
| gradle init | gradle init | ||||
| @@ -1,418 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/common/anf_exporter/anf_exporter.h" | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "base/core_ops.h" | |||||
| #include "mindspore/core/ir/primitive.h" | |||||
| // #include "src/common/anf_exporter/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/ir/primitive_t_value.h" | |||||
| #include "src/ir/tensor.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #include "src/common/utils.h" | |||||
| namespace mindspore::lite { | |||||
| std::set<std::string> RemoveNodeInAnfExporter{"tuple_getitem", "make_tuple"}; | |||||
| void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { | |||||
| bool hasMakeTuple = false; | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.clear(); | |||||
| inputs.emplace_back(cnode->input(0)); | |||||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||||
| AnfNodePtr inputNode = cnode->input(i); | |||||
| if (!inputNode->isa<CNode>()) { | |||||
| inputs.emplace_back(cnode->input(i)); | |||||
| continue; | |||||
| } | |||||
| auto makeTupleNode = utils::cast<CNodePtr>(inputNode); | |||||
| if (IsPrimitiveCNode(makeTupleNode, prim::kPrimMakeTuple)) { | |||||
| hasMakeTuple = true; | |||||
| for (size_t j = 1; j < makeTupleNode->inputs().size(); ++j) { | |||||
| inputs.emplace_back(makeTupleNode->input(j)); | |||||
| } | |||||
| } else { | |||||
| inputs.emplace_back(cnode->input(i)); | |||||
| } | |||||
| } | |||||
| if (hasMakeTuple) { | |||||
| cnode->set_inputs(inputs); | |||||
| } | |||||
| } | |||||
| bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) { | |||||
| bool hasTupleGetItem = false; | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.clear(); | |||||
| inputs.emplace_back(cnode->input(0)); | |||||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||||
| AnfNodePtr inputNode = cnode->input(i); | |||||
| if (!inputNode->isa<CNode>()) { | |||||
| inputs.emplace_back(cnode->input(i)); | |||||
| continue; | |||||
| } | |||||
| auto tupleGetItemNode = utils::cast<CNodePtr>(inputNode); | |||||
| if (IsPrimitiveCNode(tupleGetItemNode, prim::kPrimTupleGetItem)) { | |||||
| hasTupleGetItem = true; | |||||
| inputs.emplace_back(tupleGetItemNode->input(1)); | |||||
| AnfNodePtr indexNode = tupleGetItemNode->input(2); | |||||
| if (!utils::isa<ValueNode>(indexNode)) { | |||||
| MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode"; | |||||
| return false; | |||||
| } | |||||
| ValueNodePtr valueNode = utils::cast<ValueNodePtr>(indexNode); | |||||
| mapRemoveGetItem_[tupleGetItemNode->input(1)->fullname_with_scope()] = GetValue<int>(valueNode->value()); | |||||
| } else { | |||||
| inputs.emplace_back(cnode->input(i)); | |||||
| } | |||||
| } | |||||
| if (hasTupleGetItem) { | |||||
| cnode->set_inputs(inputs); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool AnfExporter::AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &metaGraphT, const CNodePtr &cnode) { | |||||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||||
| auto inputNode = cnode->input(i); | |||||
| if (!inputNode->isa<CNode>()) { | |||||
| MS_LOG(ERROR) << "Node of Return's input is not CNode"; | |||||
| return false; | |||||
| } | |||||
| auto inputCNode = utils::cast<CNodePtr>(inputNode); | |||||
| auto inputPrimitive = GetValueNode<PrimitivePtr>(inputCNode->input(0)); | |||||
| std::string inputName = inputNode->fullname_with_scope(); | |||||
| auto graphOutput = nodeIdMap[inputName]; | |||||
| metaGraphT->outputIndex.emplace_back(graphOutput); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { | |||||
| auto cnodes = funcGraph->GetOrderedCnodes(); | |||||
| auto metaGraphT = std::make_unique<schema::MetaGraphT>(); | |||||
| for (const auto &cnode : cnodes) { | |||||
| auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||||
| if (primitive != nullptr) { | |||||
| if (RemoveNodeInAnfExporter.count(primitive->name()) != 0) { | |||||
| continue; | |||||
| } | |||||
| } else { | |||||
| auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0)); | |||||
| auto primT = primitiveT_value->GetPrimitiveT(); | |||||
| if (primT->value.type == schema::PrimitiveType_TupleGetItem || | |||||
| primT->value.type == schema::PrimitiveType_MakeTuple) { | |||||
| continue; | |||||
| } | |||||
| } | |||||
| mapRemoveGetItem_.clear(); | |||||
| RemoveIfMakeTuple(cnode); | |||||
| RemoveIfTupleGetItem(cnode); | |||||
| if (primitive != nullptr) { | |||||
| if (primitive->name() == prim::kPrimReturn->name()) { | |||||
| AddOutPutIfReturn(metaGraphT, cnode); | |||||
| continue; | |||||
| } | |||||
| } else { | |||||
| auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0)); | |||||
| auto primT = primitiveT_value->GetPrimitiveT(); | |||||
| if (primT->value.type == schema::PrimitiveType_Return) { | |||||
| AddOutPutIfReturn(metaGraphT, cnode); | |||||
| continue; | |||||
| } | |||||
| } | |||||
| auto node = std::make_unique<schema::CNodeT>(); | |||||
| node->name = cnode->fullname_with_scope(); | |||||
| node->nodeType = schema::NodeType_CNode; | |||||
| // populate primitive | |||||
| // if (primitive != nullptr) { | |||||
| // primitive = GetValueNode<PrimitivePtr>(cnode->input(0)); | |||||
| // MS_ASSERT(primitive != nullptr); | |||||
| // std::string opType = primitive->name(); | |||||
| // auto nodeParser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); | |||||
| // if (nodeParser == nullptr) { | |||||
| // MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; | |||||
| // return nullptr; | |||||
| // } | |||||
| // std::vector<schema::TensorT *> outputs; | |||||
| // if (utils::isa<abstract::AbstractSequeue>(cnode->abstract())) { | |||||
| // auto abstract_cnode = utils::cast<abstract::AbstractSequeuePtr>(cnode->abstract()); | |||||
| // outputs.resize(abstract_cnode->size()); | |||||
| // } | |||||
| // | |||||
| // nodeParser->Parse(cnode, node.get(), &outputs); | |||||
| // SetOpInputNode(cnode, metaGraphT.get(), node.get()); | |||||
| // SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get()); | |||||
| // metaGraphT->nodes.emplace_back(std::move(node)); | |||||
| // continue; | |||||
| // } | |||||
| auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0)); | |||||
| if (primitiveT_value == nullptr) { | |||||
| MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto *lite_primitive = primitiveT_value->GetPrimitiveT(); | |||||
| if (lite_primitive == nullptr) { | |||||
| MS_LOG(ERROR) << "Primitive in primitiveT_value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| node->primitive = std::unique_ptr<schema::PrimitiveT>(primitiveT_value->GetPrimitiveT()); | |||||
| std::vector<schema::TensorT *> outputs; | |||||
| SetOpInputNode(cnode, metaGraphT.get(), node.get()); | |||||
| SetOpOutputNode(cnode, outputs, metaGraphT.get(), node.get()); | |||||
| // add quant param | |||||
| node->quantType = primitiveT_value->GetQuantType(); | |||||
| if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTraining) { | |||||
| MS_LOG(INFO) << "node: " << node->name << " add QuantParam"; | |||||
| // activation | |||||
| auto input_quant_params = primitiveT_value->GetInputQuantParams(); | |||||
| auto node_type = primitiveT_value->GetPrimitiveT()->value.type; | |||||
| for (int i = 0; i < input_quant_params.size(); i++) { | |||||
| if (i >= node->inputIndex.size()) { | |||||
| MS_LOG(ERROR) << "node: " << node->name << " input has " << input_quant_params.size() | |||||
| << " quant_params; but only " << node->inputIndex.size() << " input"; | |||||
| break; | |||||
| } | |||||
| auto activate_index = node->inputIndex[i]; | |||||
| auto tensor_input = metaGraphT->allTensors[activate_index].get(); | |||||
| if (tensor_input->quantParams.empty()) { | |||||
| for (auto input_quant_param : input_quant_params[i]) { | |||||
| std::unique_ptr<schema::QuantParamT> input_quant_param_ptr = | |||||
| std::make_unique<schema::QuantParamT>(input_quant_param); | |||||
| MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param_ptr->scale | |||||
| << " zp: " << input_quant_param_ptr->zeroPoint; | |||||
| tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr)); | |||||
| } | |||||
| } | |||||
| } | |||||
| // output | |||||
| auto output_index = node->outputIndex[0]; | |||||
| auto tensor_output = metaGraphT->allTensors[output_index].get(); | |||||
| auto output_quant_params = primitiveT_value->GetOutputQuantParams(); | |||||
| if (output_quant_params.empty()) { | |||||
| MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty"; | |||||
| } else { | |||||
| for (auto output_quant_param : output_quant_params[0]) { | |||||
| if (tensor_output->quantParams.empty()) { | |||||
| std::unique_ptr<schema::QuantParamT> output_quant_param_ptr = | |||||
| std::make_unique<schema::QuantParamT>(output_quant_param); | |||||
| MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << output_quant_param_ptr->scale | |||||
| << " zp: " << output_quant_param_ptr->zeroPoint; | |||||
| tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr)); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (node->quantType != schema::QuantType_AwareTraining && | |||||
| !(node_type == schema::PrimitiveType_QuantDTypeCast && | |||||
| primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) { | |||||
| tensor_output->dataType = kNumberTypeInt8; | |||||
| } | |||||
| // // TensorType | |||||
| // valuePtr = primitive->GetAttr(kInputTensorDataType); | |||||
| // if (valuePtr != nullptr) { | |||||
| // MS_LOG(INFO) << "node: " << node->name << " input tensor data | |||||
| // type: " << GetValue<int>(valuePtr); for (auto input : | |||||
| // node->inputIndex) { | |||||
| // auto tensor = subGraph->allTensors[input].get(); | |||||
| // tensor->dataType = kNumberTypeUInt8; | |||||
| // } | |||||
| // } | |||||
| } | |||||
| metaGraphT->nodes.emplace_back(std::move(node)); | |||||
| } | |||||
| // set graph input tensors | |||||
| for (auto node : graphInputNodes) { | |||||
| for (auto input : node->inputIndex) { | |||||
| auto tensor = metaGraphT->allTensors[input].get(); | |||||
| if (tensor->data.empty()) { | |||||
| tensor->nodeType = schema::NodeType_ValueNode; | |||||
| tensor->format = schema::Format_NHWC; | |||||
| if (!IsContain(metaGraphT->inputIndex, input)) { | |||||
| metaGraphT->inputIndex.emplace_back(input); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return metaGraphT.release(); | |||||
| } | |||||
| void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode) { | |||||
| MS_ASSERT(nullptr != meta_graph); | |||||
| MS_ASSERT(nullptr != fbNode); | |||||
| if (cnode->inputs().size() <= 1) { | |||||
| return; | |||||
| } | |||||
| std::string cNodeName = cnode->fullname_with_scope(); | |||||
| bool isGraphInput = true; | |||||
| for (int i = 1; i < static_cast<int>(cnode->inputs().size()); i++) { | |||||
| auto inputNode = cnode->input(i); | |||||
| if (inputNode->isa<CNode>()) { | |||||
| isGraphInput = false; | |||||
| std::string inputName = inputNode->fullname_with_scope(); | |||||
| if (!mapRemoveGetItem_.empty()) { | |||||
| for (auto name : mapRemoveGetItem_) { | |||||
| if (name.first == inputName) { | |||||
| inputName = inputName + "_o:" + std::to_string(name.second); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (nodeIdMap.find(inputName) != nodeIdMap.end()) { | |||||
| fbNode->inputIndex.emplace_back(nodeIdMap[inputName]); | |||||
| } | |||||
| } else if (inputNode->isa<Parameter>()) { | |||||
| auto paramNode = inputNode->cast<ParameterPtr>(); | |||||
| if (paramNode->name().empty()) { | |||||
| paramNode->set_name(cNodeName + "_i:" + std::to_string(i - 1)); | |||||
| } | |||||
| if (nodeIdMap.find(paramNode->name()) != nodeIdMap.end()) { | |||||
| fbNode->inputIndex.emplace_back(nodeIdMap[paramNode->name()]); | |||||
| continue; | |||||
| } | |||||
| auto paramTensor = std::make_unique<schema::TensorT>(); | |||||
| auto abstractBase = paramNode->abstract(); | |||||
| if (abstractBase == nullptr) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name(); | |||||
| MS_ASSERT(false); | |||||
| return; | |||||
| } | |||||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << paramNode->name(); | |||||
| MS_ASSERT(false); | |||||
| return; | |||||
| } | |||||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); | |||||
| auto typePtr = abstractTensor->element()->GetTypeTrack(); | |||||
| MS_ASSERT(typePtr != nullptr); | |||||
| paramTensor->dataType = typePtr->type_id(); | |||||
| if (!utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) { | |||||
| MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name(); | |||||
| MS_ASSERT(false); | |||||
| return; | |||||
| } | |||||
| paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); | |||||
| auto paramValue = std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param()); | |||||
| if (paramValue != nullptr) { | |||||
| paramTensor->nodeType = schema::NodeType_ValueNode; | |||||
| paramTensor->data.resize(paramValue->tensor_size()); | |||||
| memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size()); | |||||
| } | |||||
| nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size(); | |||||
| fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); | |||||
| meta_graph->allTensors.emplace_back(std::move(paramTensor)); | |||||
| } else if (inputNode->isa<ValueNode>()) { | |||||
| auto valueNode = inputNode->cast<ValueNodePtr>(); | |||||
| auto paramTensor = std::make_unique<schema::TensorT>(); | |||||
| auto value = valueNode->value(); | |||||
| if (value->isa<lite::tensor::Tensor>()) { | |||||
| auto valueAbstract = valueNode->abstract(); | |||||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract); | |||||
| auto typePtr = abstractTensor->element()->GetTypeTrack(); | |||||
| paramTensor->dataType = typePtr->type_id(); | |||||
| paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); | |||||
| paramTensor->nodeType = schema::NodeType_ValueNode; | |||||
| auto data = value->cast<lite::tensor::TensorPtr>(); | |||||
| paramTensor->data.resize(data->Size()); | |||||
| memcpy(paramTensor->data.data(), data->Data(), data->Size()); | |||||
| nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size(); | |||||
| fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); | |||||
| meta_graph->allTensors.emplace_back(std::move(paramTensor)); | |||||
| } else if (value->isa<mindspore::Int32Imm>()) { | |||||
| auto valueAbstract = valueNode->abstract(); | |||||
| auto abstractScalar = utils::cast<abstract::AbstractScalarPtr>(valueAbstract); | |||||
| auto typePtr = abstractScalar->GetTypeTrack(); | |||||
| paramTensor->dataType = typePtr->type_id(); | |||||
| paramTensor->dims = {1}; | |||||
| paramTensor->nodeType = schema::NodeType_ValueNode; | |||||
| auto data = value->cast<mindspore::Int32ImmPtr>(); | |||||
| paramTensor->data.emplace_back(data->value()); | |||||
| nodeIdMap[valueNode->fullname_with_scope()] = meta_graph->allTensors.size(); | |||||
| fbNode->inputIndex.emplace_back(meta_graph->allTensors.size()); | |||||
| meta_graph->allTensors.emplace_back(std::move(paramTensor)); | |||||
| } else if (value->isa<mindspore::ValueSequeue>()) { | |||||
| MS_LOG(INFO) << "Value type is ValueSequence."; | |||||
| break; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Not support value type , need add support."; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (isGraphInput) { | |||||
| graphInputNodes.emplace_back(fbNode); | |||||
| } | |||||
| } | |||||
| void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::vector<schema::TensorT *> &outputTensors, | |||||
| schema::MetaGraphT *graph, schema::CNodeT *fbnode) { | |||||
| MS_ASSERT(nullptr != graph); | |||||
| MS_ASSERT(nullptr != fbnode); | |||||
| std::string cnodeName = fbnode->name; | |||||
| if (!outputTensors.empty()) { | |||||
| int i = 0; | |||||
| for (auto outputTensor : outputTensors) { | |||||
| std::string name = cnodeName + "_o:" + std::to_string(i); | |||||
| auto msTensor = new schema::TensorT(); | |||||
| msTensor->nodeType = schema::NodeType_Parameter; | |||||
| nodeIdMap[name] = graph->allTensors.size(); | |||||
| fbnode->outputIndex.emplace_back(graph->allTensors.size()); | |||||
| graph->allTensors.emplace_back(msTensor); | |||||
| i++; | |||||
| } | |||||
| return; | |||||
| } | |||||
| if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) { | |||||
| auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract()); | |||||
| for (int i = 0; i < tuple->size(); i++) { | |||||
| auto msTensor = new schema::TensorT(); | |||||
| msTensor->nodeType = schema::NodeType_Parameter; | |||||
| fbnode->outputIndex.emplace_back(graph->allTensors.size()); | |||||
| if (tuple->size() == 1) { | |||||
| nodeIdMap[cnodeName] = graph->allTensors.size(); | |||||
| } else { | |||||
| std::string name = cnodeName + "_o:" + std::to_string(i); | |||||
| nodeIdMap[name] = graph->allTensors.size(); | |||||
| } | |||||
| graph->allTensors.emplace_back(msTensor); | |||||
| } | |||||
| } else { | |||||
| auto msTensor = new schema::TensorT(); | |||||
| msTensor->nodeType = schema::NodeType_Parameter; | |||||
| fbnode->outputIndex.emplace_back(graph->allTensors.size()); | |||||
| nodeIdMap[cnodeName] = graph->allTensors.size(); | |||||
| graph->allTensors.emplace_back(msTensor); | |||||
| } | |||||
| } | |||||
| schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph) { | |||||
| AnfExporter anfExporter; | |||||
| return anfExporter.Export(funcGraph); | |||||
| } | |||||
| } // namespace mindspore::lite | |||||
| @@ -1,49 +0,0 @@ | |||||
| /** | |||||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||||
| * | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ | |||||
| #define MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "ir/func_graph.h" | |||||
| namespace mindspore::lite { | |||||
| class AnfExporter { | |||||
| public: | |||||
| AnfExporter() = default; | |||||
| virtual ~AnfExporter() = default; | |||||
| schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph); | |||||
| void SetOpOutputNode(const CNodePtr &cnode, const std::vector<schema::TensorT *> &outputTensors, | |||||
| schema::MetaGraphT *graph, schema::CNodeT *fbnode); | |||||
| void SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta_graph, schema::CNodeT *fbNode); | |||||
| void RemoveIfMakeTuple(const CNodePtr &cnode); | |||||
| bool RemoveIfTupleGetItem(const CNodePtr &cnode); | |||||
| bool AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &metaGraphT, const CNodePtr &cnode); | |||||
| private: | |||||
| std::map<std::string, int> nodeIdMap; | |||||
| std::vector<schema::CNodeT *> graphInputNodes; | |||||
| std::map<std::string, int> mapRemoveGetItem_; | |||||
| }; | |||||
| schema::MetaGraphT *Export(const FuncGraphPtr &funcGraph); | |||||
| } // namespace mindspore::lite | |||||
| #endif // MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ | |||||
| @@ -1,7 +0,0 @@ | |||||
| file(GLOB_RECURSE ANF_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| *.cc | |||||
| ) | |||||
| list(REMOVE_ITEM ANF_SRC_LIST import_from_meta_graph.cc) | |||||
| add_library(anf_importer_mid OBJECT | |||||
| ${ANF_SRC_LIST} | |||||
| ) | |||||
| @@ -1,169 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/common/anf_importer/import_from_meta_graph.h" | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "frontend/operator/ops.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "src/ir/primitive_value.h" | |||||
| #include "include/errorcode.h" | |||||
| namespace mindspore::lite { | |||||
| void AnfImporterFromMetaGraph::ConverterConstTensor() { | |||||
| MS_EXCEPTION_IF_NULL(model_); | |||||
| auto *meta_graph = model_->GetMetaGraph(); | |||||
| MS_EXCEPTION_IF_NULL(meta_graph); | |||||
| num_of_tensors_ = meta_graph->allTensors()->size(); | |||||
| for (size_t i = 0; i < num_of_tensors_; i++) { | |||||
| auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(i); | |||||
| MS_EXCEPTION_IF_NULL(tensor); | |||||
| if ((tensor->nodeType() != schema::NodeType_ValueNode) && (tensor->nodeType() != schema::NodeType_Parameter)) { | |||||
| continue; | |||||
| } | |||||
| MS_ASSERT(tensor->dims() != nullptr); | |||||
| auto parameter = model_->add_parameter(); | |||||
| std::vector<int> shape; | |||||
| for (size_t j = 0; j < tensor->dims()->size(); ++j) { | |||||
| shape.push_back(tensor->dims()->data()[j]); | |||||
| } | |||||
| auto type_id = static_cast<TypeId>(tensor->dataType()); // todo: check error | |||||
| auto type_ptr = TypeIdToType(type_id); | |||||
| auto abstractBase = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| // XXX TODO copy format | |||||
| parameter->set_abstract(abstractBase); | |||||
| parameter->set_name(std::string("Parameter")); | |||||
| if (tensor->nodeType() == schema::NodeType_ValueNode) { | |||||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||||
| MS_EXCEPTION_IF_NULL(param_value); | |||||
| param_value->set_tensor_shape(shape); | |||||
| param_value->set_tensor_type(type_id); | |||||
| if (tensor->data() != nullptr) { | |||||
| auto size = tensor->data()->size(); | |||||
| char *tensor_data = new char[size](); | |||||
| std::memcpy(tensor_data, tensor->data()->data(), size); | |||||
| MS_EXCEPTION_IF_NULL(tensor_data); | |||||
| param_value->set_tensor_addr(tensor_data); | |||||
| param_value->set_tensor_size(size); | |||||
| } | |||||
| parameter->set_default_param(param_value); | |||||
| } | |||||
| AddNode(i, parameter); | |||||
| model_->AddAnfNode(i, parameter); | |||||
| } | |||||
| } | |||||
| int AnfImporterFromMetaGraph::ConverterCNode() { | |||||
| MS_EXCEPTION_IF_NULL(model_); | |||||
| auto *meta_graph = model_->GetMetaGraph(); | |||||
| MS_EXCEPTION_IF_NULL(meta_graph); | |||||
| // Crate CNode -- Order of inputs is as follows | |||||
| // First input should be the Primitive | |||||
| // Then we have CNodes that contribute to this CNode | |||||
| // Finally we Have the parameters | |||||
| // first itteration -- create CNode with primitive, create originator map | |||||
| for (size_t i = 0; i < meta_graph->nodes()->size(); i++) { | |||||
| auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(i); | |||||
| MS_EXCEPTION_IF_NULL(cNode); | |||||
| auto prim = std::make_shared<PrimitiveValue>(model_->GetOp(cNode->name()->str())); | |||||
| if (prim == nullptr) { | |||||
| MS_LOG(ERROR) << "th tensorDef in subGraphDef is nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto value_node = NewValueNode(prim); | |||||
| // auto prim_name = std::string("PrimitivePy: ") + std::string(cNode->name()->c_str()); | |||||
| // value_node->set_fullname_with_scope(prim_name); | |||||
| std::vector<AnfNodePtr> op_inputs = {value_node}; | |||||
| auto cnode = model_->NewCNode(op_inputs); | |||||
| auto node_name = std::string(cNode->name()->c_str()) + std::to_string(i); | |||||
| cnode->set_fullname_with_scope(node_name); | |||||
| AddNode(num_of_tensors_ + i, cnode); | |||||
| for (size_t j = 0; j < cNode->outputIndex()->size(); j++) { | |||||
| int tensor_id = cNode->outputIndex()->data()[j]; | |||||
| originator_[tensor_id] = cnode; | |||||
| } | |||||
| } | |||||
| // second itteration -- fill in input CNodes and Parameters | |||||
| // populate map | |||||
| for (size_t i = 0; i < meta_graph->nodes()->size(); i++) { | |||||
| std::vector<int> input; | |||||
| std::vector<int> output; | |||||
| int tensor_id; | |||||
| auto cNode = meta_graph->nodes()->GetAs<schema::CNode>(i); | |||||
| MS_EXCEPTION_IF_NULL(cNode); | |||||
| auto cnode = std::dynamic_pointer_cast<CNode>(GetNode(num_of_tensors_ + i)); | |||||
| for (size_t j = 0; j < cNode->outputIndex()->size(); j++) { | |||||
| tensor_id = cNode->outputIndex()->data()[j]; | |||||
| output.push_back(tensor_id); | |||||
| } | |||||
| MS_EXCEPTION_IF_NULL(cNode->inputIndex()); | |||||
| for (size_t j = 0; j < cNode->inputIndex()->size(); j++) { | |||||
| tensor_id = cNode->inputIndex()->data()[j]; | |||||
| input.push_back(tensor_id); | |||||
| auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(tensor_id); | |||||
| MS_EXCEPTION_IF_NULL(tensor); | |||||
| if ((tensor->nodeType() == schema::NodeType_Parameter) && (originator_[tensor_id] != nullptr)) { | |||||
| cnode->add_input(originator_[tensor_id]); | |||||
| } | |||||
| } | |||||
| // finally add all the Parameters (which are ValueNodes) | |||||
| for (size_t j = 0; j < cNode->inputIndex()->size(); j++) { | |||||
| tensor_id = cNode->inputIndex()->data()[j]; | |||||
| auto *tensor = meta_graph->allTensors()->GetAs<schema::Tensor>(tensor_id); | |||||
| MS_EXCEPTION_IF_NULL(tensor); | |||||
| if ((tensor->nodeType() == schema::NodeType_ValueNode) && (GetNode(tensor_id) != nullptr)) { | |||||
| cnode->add_input(GetNode(tensor_id)); | |||||
| } | |||||
| } | |||||
| model_->AddCNodeInputOutput(cnode->fullname_with_scope(), input, output); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| void AnfImporterFromMetaGraph::AddReturnCNode() { | |||||
| MS_EXCEPTION_IF_NULL(model_); | |||||
| auto *meta_graph = model_->GetMetaGraph(); | |||||
| MS_EXCEPTION_IF_NULL(meta_graph); | |||||
| std::vector<int> input; | |||||
| std::vector<int> output; | |||||
| std::vector<AnfNodePtr> op_inputs; | |||||
| auto value_node = NewValueNode(prim::kPrimReturn); | |||||
| // value_node->set_fullname_with_scope("Primitive"); | |||||
| op_inputs.push_back(value_node); | |||||
| for (int i = 0; i < meta_graph->outputIndex()->size(); i++) { | |||||
| auto prev_cnode = originator_[meta_graph->outputIndex()->data()[i]]; | |||||
| if (prev_cnode != nullptr) op_inputs.push_back(prev_cnode); | |||||
| input.push_back(meta_graph->outputIndex()->data()[i]); | |||||
| } | |||||
| auto cnode = model_->NewCNode(op_inputs); | |||||
| cnode->set_fullname_with_scope("return"); | |||||
| model_->set_return(cnode); | |||||
| model_->AddCNodeInputOutput(cnode->fullname_with_scope(), input, output); | |||||
| } | |||||
| FuncGraphPtr AnfImporterFromMetaGraph::GetResult() { return this->model_; } | |||||
| } // namespace mindspore::lite | |||||
| @@ -1,49 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ | |||||
| #define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ | |||||
| #include <memory> | |||||
| #include <map> | |||||
| #include "src/train/model_impl.h" | |||||
| #include "schema/model_generated.h" | |||||
| #include "src/common/anf_importer/anf_importer.h" | |||||
| namespace mindspore::lite { | |||||
| class AnfImporterFromMetaGraph : public AnfImporter { | |||||
| public: | |||||
| explicit AnfImporterFromMetaGraph(std::shared_ptr<ModelImpl> model) : model_(model) {} | |||||
| ~AnfImporterFromMetaGraph() override = default; | |||||
| FuncGraphPtr GetResult() override; | |||||
| private: | |||||
| void ConverterConstTensor() override; | |||||
| int ConverterCNode() override; | |||||
| void AddReturnCNode() override; | |||||
| private: | |||||
| std::shared_ptr<ModelImpl> model_ = nullptr; | |||||
| std::map<int, AnfNodePtr> originator_; | |||||
| int num_of_tensors_ = 0; | |||||
| }; | |||||
| } // namespace mindspore::lite | |||||
| #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPH_H_ | |||||
| @@ -1,175 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "frontend/operator/ops.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #include "import_from_meta_graphT.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "src/ir/primitive_value.h" | |||||
| #include "src/ir/primitive_t_value.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/ops/ops.h" | |||||
| namespace mindspore::lite { | |||||
| void AnfImporterFromMetaGraphT::ConverterConstTensor() { | |||||
| MS_EXCEPTION_IF_NULL(meta_graph_); | |||||
| MS_EXCEPTION_IF_NULL(func_graph_); | |||||
| for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) { | |||||
| auto &tensor = meta_graph_->allTensors.at(i); | |||||
| MS_EXCEPTION_IF_NULL(tensor); | |||||
| if (tensor->nodeType != schema::NodeType_ValueNode) { | |||||
| continue; | |||||
| } | |||||
| MS_ASSERT(tensor->dims() != nullptr); | |||||
| auto parameter = func_graph_->add_parameter(); | |||||
| std::vector<int> shape; | |||||
| for (int &dim : tensor->dims) { | |||||
| shape.push_back(dim); | |||||
| } | |||||
| auto type_id = static_cast<TypeId>(tensor->dataType); | |||||
| auto type_ptr = TypeIdToType(type_id); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| parameter->set_abstract(abstract_tensor); | |||||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||||
| MS_EXCEPTION_IF_NULL(param_value); | |||||
| param_value->set_tensor_shape(shape); | |||||
| param_value->set_tensor_type(type_id); | |||||
| if (!tensor->data.empty()) { | |||||
| auto size = tensor->data.size(); | |||||
| char *tensor_data = new char[size]; | |||||
| std::memcpy(tensor_data, tensor->data.data(), size); | |||||
| MS_EXCEPTION_IF_NULL(tensor_data); | |||||
| param_value->set_tensor_addr(tensor_data); | |||||
| param_value->set_tensor_size(size); | |||||
| } | |||||
| if (tensor->quantParams.size() > 0) { | |||||
| std::unique_ptr<AnfQuantParam> quantParam = std::make_unique<AnfQuantParam>(); | |||||
| quantParam->scale = tensor->quantParams[0]->scale; | |||||
| quantParam->zeroPoint = tensor->quantParams[0]->zeroPoint; | |||||
| quantParam->min = tensor->quantParams[0]->min; | |||||
| quantParam->max = tensor->quantParams[0]->max; | |||||
| quantParam->narrowRange = tensor->quantParams[0]->narrowRange; | |||||
| quantParam->numBits = tensor->quantParams[0]->numBits; | |||||
| quantParam->inited = tensor->quantParams[0]->inited; | |||||
| param_value->set_quant_param(quantParam); | |||||
| } | |||||
| parameter->set_default_param(param_value); | |||||
| AddNode(i, parameter); | |||||
| } | |||||
| } | |||||
| int AnfImporterFromMetaGraphT::ConverterCNode() { | |||||
| MS_EXCEPTION_IF_NULL(meta_graph_); | |||||
| MS_EXCEPTION_IF_NULL(func_graph_); | |||||
| for (size_t i = 0; i < meta_graph_->nodes.size(); i++) { | |||||
| auto &cNode = meta_graph_->nodes.at(i); | |||||
| MS_EXCEPTION_IF_NULL(cNode); | |||||
| bool flag = false; | |||||
| if (cNode->outputIndex.size() > 1) { | |||||
| flag = true; | |||||
| } | |||||
| auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release()); | |||||
| // add quant parameter | |||||
| if (cNode->quantType == schema::QuantType_AwareTraining) { | |||||
| primTValue->SetQuantType(cNode->quantType); | |||||
| for (int index : cNode->inputIndex) { | |||||
| std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; | |||||
| primTValue->AddInputQuantParam(quant_params); | |||||
| } | |||||
| for (int index : cNode->outputIndex) { | |||||
| std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; | |||||
| primTValue->AddOutputQuantParam(quant_params); | |||||
| } | |||||
| } | |||||
| cNode->primitive = nullptr; | |||||
| auto value_node = NewValueNode(primTValue); | |||||
| std::vector<AnfNodePtr> op_inputs = {value_node}; | |||||
| for (size_t j = 0; j < cNode->inputIndex.size(); j++) { | |||||
| auto node = GetNode(cNode->inputIndex.at(j)); | |||||
| if (nullptr == node) { | |||||
| MS_LOG(ERROR) << "Can't find input node."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| // todo: CheckInputNodeType, the first node should be op; | |||||
| op_inputs.push_back(node); | |||||
| } | |||||
| auto new_cnode = func_graph_->NewCNode(op_inputs); | |||||
| new_cnode->set_fullname_with_scope(cNode->name); | |||||
| std::vector<uint32_t> out_tensor_ids = cNode->outputIndex; | |||||
| AbstractBasePtrList ptr_list; | |||||
| int total = 0; | |||||
| for (auto out_tensor_id : out_tensor_ids) { | |||||
| if (nullptr != GetNode(out_tensor_id)) { | |||||
| ptr_list.push_back(GetNode(out_tensor_id)->abstract()); | |||||
| continue; | |||||
| } | |||||
| std::vector<int> shape; | |||||
| auto &tensor = meta_graph_->allTensors.at(out_tensor_id); | |||||
| for (int &dim : tensor->dims) { | |||||
| shape.push_back(dim); | |||||
| } | |||||
| auto type_id = static_cast<TypeId>(tensor->dataType); | |||||
| auto type_ptr = TypeIdToType(type_id); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| auto getItemPrim = NewValueNode(prim::kPrimTupleGetItem); | |||||
| if (flag) { | |||||
| auto getItemIndex = NewValueNode(MakeValue<int>(total++)); | |||||
| std::vector<AnfNodePtr> inputs{getItemPrim, new_cnode, getItemIndex}; | |||||
| CNodePtr new_item_cnode = func_graph_->NewCNode(inputs); | |||||
| AddNode(out_tensor_id, new_item_cnode); | |||||
| } else { | |||||
| AddNode(out_tensor_id, new_cnode); | |||||
| } | |||||
| ptr_list.push_back(std::move(abstract_tensor)); | |||||
| } | |||||
| new_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(ptr_list)); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| void AnfImporterFromMetaGraphT::AddReturnCNode() { | |||||
| MS_EXCEPTION_IF_NULL(meta_graph_); | |||||
| MS_EXCEPTION_IF_NULL(func_graph_); | |||||
| std::vector<AnfNodePtr> make_tuple_inputs; | |||||
| auto make_tuple_value_node = NewValueNode(prim::kPrimMakeTuple); | |||||
| make_tuple_inputs.emplace_back(make_tuple_value_node); | |||||
| for (auto tensor_id : meta_graph_->outputIndex) { | |||||
| make_tuple_inputs.emplace_back(GetNode(tensor_id)); | |||||
| } | |||||
| auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs); | |||||
| make_tuple_cnode->set_fullname_with_scope("return tuple"); | |||||
| std::vector<AnfNodePtr> op_inputs; | |||||
| auto value_node = NewValueNode(prim::kPrimReturn); | |||||
| op_inputs.emplace_back(value_node); | |||||
| op_inputs.emplace_back(make_tuple_cnode); | |||||
| auto cnode = func_graph_->NewCNode(op_inputs); | |||||
| cnode->set_fullname_with_scope("return"); | |||||
| func_graph_->set_return(cnode); | |||||
| } | |||||
| FuncGraphPtr AnfImporterFromMetaGraphT::GetResult() { return this->func_graph_; } | |||||
| } // namespace mindspore::lite | |||||
| @@ -15,3 +15,26 @@ | |||||
| */ | */ | ||||
| #include "src/ir/primitive_t_value.h" | #include "src/ir/primitive_t_value.h" | ||||
| namespace mindspore::lite { | |||||
| std::shared_ptr<PrimitiveTValue> GetReturnPrim() { | |||||
| auto return_primitiveT = new schema::PrimitiveT; | |||||
| return_primitiveT->value.type = schema::PrimitiveType_Return; | |||||
| return_primitiveT->value.value = new schema::ReturnT; | |||||
| return std::make_shared<PrimitiveTValue>(return_primitiveT); | |||||
| } | |||||
| std::shared_ptr<PrimitiveTValue> GetMakeTuplePrim() { | |||||
| auto make_tuple_primitiveT = new schema::PrimitiveT; | |||||
| make_tuple_primitiveT->value.type = schema::PrimitiveType_MakeTuple; | |||||
| make_tuple_primitiveT->value.value = new schema::MakeTupleT; | |||||
| return std::make_shared<PrimitiveTValue>(make_tuple_primitiveT); | |||||
| } | |||||
| std::shared_ptr<PrimitiveTValue> GetTupleGetItemPrim() { | |||||
| auto tuple_get_item_primitiveT = new schema::PrimitiveT(); | |||||
| tuple_get_item_primitiveT->value.type = schema::PrimitiveType_TupleGetItem; | |||||
| tuple_get_item_primitiveT->value.value = new schema::TupleGetItemT; | |||||
| return std::make_shared<PrimitiveTValue>(tuple_get_item_primitiveT); | |||||
| } | |||||
| } // namespace mindspore::lite | |||||
| @@ -18,8 +18,9 @@ | |||||
| #define MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_ | #define MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_ | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "ir/value.h" | #include "ir/value.h" | ||||
| #include "mindspore/lite/schema/inner/model_generated.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| @@ -46,22 +47,17 @@ class PrimitiveTValue : public Value { | |||||
| } | } | ||||
| } | } | ||||
| void SetInputQuantParam(std::vector<std::vector<schema::QuantParamT>> vec_quant_param) { | |||||
| } | |||||
| void SetInputQuantParam(std::vector<std::vector<schema::QuantParamT>> vec_quant_param) {} | |||||
| void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) { | void AddInputQuantParam(std::vector<schema::QuantParamT> quant_param) { | ||||
| this->input_quant_param_.emplace_back(quant_param); | this->input_quant_param_.emplace_back(quant_param); | ||||
| } | } | ||||
| std::vector<std::vector<schema::QuantParamT>> GetInputQuantParams() const { | |||||
| return input_quant_param_; | |||||
| } | |||||
| std::vector<std::vector<schema::QuantParamT>> GetInputQuantParams() const { return input_quant_param_; } | |||||
| void AddOutputQuantParam(std::vector<schema::QuantParamT> quant_param) { | void AddOutputQuantParam(std::vector<schema::QuantParamT> quant_param) { | ||||
| this->output_quant_param_.emplace_back(quant_param); | this->output_quant_param_.emplace_back(quant_param); | ||||
| } | } | ||||
| std::vector<std::vector<schema::QuantParamT>> GetOutputQuantParams() const { | |||||
| return output_quant_param_; | |||||
| } | |||||
| std::vector<std::vector<schema::QuantParamT>> GetOutputQuantParams() const { return output_quant_param_; } | |||||
| void SetQuantType(schema::QuantType quant_type) { this->quant_type_ = quant_type; } | void SetQuantType(schema::QuantType quant_type) { this->quant_type_ = quant_type; } | ||||
| @@ -73,7 +69,12 @@ class PrimitiveTValue : public Value { | |||||
| std::vector<std::vector<schema::QuantParamT>> output_quant_param_; | std::vector<std::vector<schema::QuantParamT>> output_quant_param_; | ||||
| schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; | schema::QuantType quant_type_{schema::QuantType_QUANT_NONE}; | ||||
| }; | }; | ||||
| std::shared_ptr<PrimitiveTValue> GetReturnPrim(); | |||||
| std::shared_ptr<PrimitiveTValue> GetMakeTuplePrim(); | |||||
| std::shared_ptr<PrimitiveTValue> GetTupleGetItemPrim(); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_ | #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_PRIMITIVET_H_ | ||||
| @@ -170,6 +170,7 @@ set(TEST_LITE_SRC | |||||
| ${LITE_DIR}/src/runtime/thread_pool.cc | ${LITE_DIR}/src/runtime/thread_pool.cc | ||||
| ${LITE_DIR}/src/runtime/workspace_pool.cc | ${LITE_DIR}/src/runtime/workspace_pool.cc | ||||
| ${LITE_DIR}/src/ir/tensor.cc | ${LITE_DIR}/src/ir/tensor.cc | ||||
| ${LITE_DIR}/src/ir/primitive_t_value.cc | |||||
| ${LITE_DIR}/src/context.cc | ${LITE_DIR}/src/context.cc | ||||
| ${LITE_DIR}/src/executor.cc | ${LITE_DIR}/src/executor.cc | ||||
| ${LITE_DIR}/src/kernel_factory.cc | ${LITE_DIR}/src/kernel_factory.cc | ||||
| @@ -218,9 +219,6 @@ if(BUILD_CONVERTER) | |||||
| ${TEST_CASE_TFLITE_PARSERS_SRC} | ${TEST_CASE_TFLITE_PARSERS_SRC} | ||||
| ${TOP_DIR}/mindspore/core/utils/flags.cc | ${TOP_DIR}/mindspore/core/utils/flags.cc | ||||
| ${LITE_DIR}/tools/converter/optimizer.cc | ${LITE_DIR}/tools/converter/optimizer.cc | ||||
| # ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc | |||||
| # ${LITE_DIR}/src/common/anf_importer/import_from_meta_graphT.cc | |||||
| # ${LITE_DIR}/src/common/anf_importer/import_from_protobuf.cc | |||||
| ${LITE_DIR}/tools/converter/anf_transform.cc | ${LITE_DIR}/tools/converter/anf_transform.cc | ||||
| ${LITE_DIR}/tools/converter/graphdef_transform.cc | ${LITE_DIR}/tools/converter/graphdef_transform.cc | ||||
| ${LITE_DIR}/tools/converter/converter_flags.cc | ${LITE_DIR}/tools/converter/converter_flags.cc | ||||
| @@ -300,7 +298,6 @@ if (SUPPORT_TRAIN) | |||||
| set(TEST_SRC | set(TEST_SRC | ||||
| ${TEST_SRC} | ${TEST_SRC} | ||||
| ${TEST_CASE_KERNEL_TRAIN_SRC} | ${TEST_CASE_KERNEL_TRAIN_SRC} | ||||
| # ${TEST_DIR}/ut/src/train_test.cc | |||||
| ${TEST_DIR}/ut/src/infer_test.cc # temporary | ${TEST_DIR}/ut/src/infer_test.cc # temporary | ||||
| ) | ) | ||||
| else() | else() | ||||
| @@ -350,6 +347,7 @@ endif() | |||||
| if (BUILD_CONVERTER) | if (BUILD_CONVERTER) | ||||
| target_link_libraries(lite-test | target_link_libraries(lite-test | ||||
| anf_importer_mid | anf_importer_mid | ||||
| anf_exporter_mid | |||||
| tflite_parser_mid | tflite_parser_mid | ||||
| caffe_parser_mid | caffe_parser_mid | ||||
| onnx_parser_mid | onnx_parser_mid | ||||
| @@ -246,7 +246,7 @@ TEST_F(InferTest, TestAddNode) { | |||||
| TEST_F(InferTest, TestModel) { | TEST_F(InferTest, TestModel) { | ||||
| auto buf = new char *[1]; | auto buf = new char *[1]; | ||||
| size_t model_size; | size_t model_size; | ||||
| std::string model_path = "./model.ms"; | |||||
| std::string model_path = "./models/model_hebing_3branch.ms"; | |||||
| ReadFile(model_path.c_str(), &model_size, buf); | ReadFile(model_path.c_str(), &model_size, buf); | ||||
| ASSERT_NE(nullptr, buf[0]); | ASSERT_NE(nullptr, buf[0]); | ||||
| @@ -25,7 +25,7 @@ | |||||
| #include "tools/converter/model_parser.h" | #include "tools/converter/model_parser.h" | ||||
| #include "tools/converter/anf_transform.h" | #include "tools/converter/anf_transform.h" | ||||
| #include "tools/optimizer/fusion/constant_folding_fusion.h" | #include "tools/optimizer/fusion/constant_folding_fusion.h" | ||||
| #include "src/common/anf_exporter/anf_exporter.h" | |||||
| #include "tools/anf_exporter/anf_exporter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class ConstantFoldingFusionTest : public mindspore::CommonTest { | class ConstantFoldingFusionTest : public mindspore::CommonTest { | ||||
| @@ -24,7 +24,7 @@ | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "tools/converter/model_parser.h" | #include "tools/converter/model_parser.h" | ||||
| #include "tools/converter/anf_transform.h" | #include "tools/converter/anf_transform.h" | ||||
| #include "src/common/anf_exporter/anf_exporter.h" | |||||
| #include "tools/anf_exporter/anf_exporter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class ConvActivationFusionTest : public mindspore::CommonTest { | class ConvActivationFusionTest : public mindspore::CommonTest { | ||||
| @@ -24,7 +24,7 @@ | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "tools/converter/model_parser.h" | #include "tools/converter/model_parser.h" | ||||
| #include "tools/converter/anf_transform.h" | #include "tools/converter/anf_transform.h" | ||||
| #include "src/common/anf_exporter/anf_exporter.h" | |||||
| #include "tools/anf_exporter/anf_exporter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class ConvBiasAddFusionTest : public mindspore::CommonTest { | class ConvBiasAddFusionTest : public mindspore::CommonTest { | ||||
| @@ -24,7 +24,7 @@ | |||||
| #include "mindspore/core/utils/log_adapter.h" | #include "mindspore/core/utils/log_adapter.h" | ||||
| #include "tools/converter/model_parser.h" | #include "tools/converter/model_parser.h" | ||||
| #include "tools/converter/anf_transform.h" | #include "tools/converter/anf_transform.h" | ||||
| #include "src/common/anf_exporter/anf_exporter.h" | |||||
| #include "tools/anf_exporter/anf_exporter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class ConvBNFusionTest : public mindspore::CommonTest { | class ConvBNFusionTest : public mindspore::CommonTest { | ||||
| @@ -24,7 +24,7 @@ | |||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "tools/converter/model_parser.h" | #include "tools/converter/model_parser.h" | ||||
| #include "tools/converter/anf_transform.h" | #include "tools/converter/anf_transform.h" | ||||
| #include "src/common/anf_exporter/anf_exporter.h" | |||||
| #include "tools/anf_exporter/anf_exporter.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class ConvScaleFusionTest : public mindspore::CommonTest { | class ConvScaleFusionTest : public mindspore::CommonTest { | ||||
| @@ -0,0 +1,6 @@ | |||||
| file(GLOB_RECURSE ANF_EXPORTER_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| *.cc | |||||
| ) | |||||
| add_library(anf_exporter_mid OBJECT | |||||
| ${ANF_EXPORTER_SRC_LIST} | |||||
| ) | |||||
| @@ -0,0 +1,431 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tools/anf_exporter/anf_exporter.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "abstract/abstract_value.h" | |||||
| #include "mindspore/core/ir/primitive.h" | |||||
| #include "src/ir/tensor.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #include "src/common/utils.h" | |||||
| namespace mindspore::lite { | |||||
| void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { | |||||
| bool has_make_tuple = false; | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.clear(); | |||||
| inputs.emplace_back(cnode->input(0)); | |||||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||||
| AnfNodePtr inputNode = cnode->input(i); | |||||
| if (!inputNode->isa<CNode>()) { | |||||
| inputs.emplace_back(cnode->input(i)); | |||||
| continue; | |||||
| } | |||||
| auto make_tuple_node = utils::cast<CNodePtr>(inputNode); | |||||
| if (IsPrimitiveCNode(make_tuple_node, schema::PrimitiveType_MakeTuple)) { | |||||
| has_make_tuple = true; | |||||
| for (size_t j = 1; j < make_tuple_node->inputs().size(); ++j) { | |||||
| inputs.emplace_back(make_tuple_node->input(j)); | |||||
| } | |||||
| } else { | |||||
| inputs.emplace_back(cnode->input(i)); | |||||
| } | |||||
| } | |||||
| if (has_make_tuple) { | |||||
| cnode->set_inputs(inputs); | |||||
| } | |||||
| } | |||||
| bool AnfExporter::RemoveIfTupleGetItem(const CNodePtr &cnode) { | |||||
| MS_ASSERT(cnode != nullptr); | |||||
| bool has_tuple_get_item = false; | |||||
| std::vector<AnfNodePtr> inputs; | |||||
| inputs.clear(); | |||||
| inputs.emplace_back(cnode->input(0)); | |||||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||||
| AnfNodePtr inputNode = cnode->input(i); | |||||
| if (!inputNode->isa<CNode>()) { | |||||
| inputs.emplace_back(cnode->input(i)); | |||||
| continue; | |||||
| } | |||||
| auto tuple_get_item_node = utils::cast<CNodePtr>(inputNode); | |||||
| if (IsPrimitiveCNode(tuple_get_item_node, schema::PrimitiveType_TupleGetItem)) { | |||||
| has_tuple_get_item = true; | |||||
| inputs.emplace_back(tuple_get_item_node->input(1)); | |||||
| AnfNodePtr indexNode = tuple_get_item_node->input(2); | |||||
| if (!utils::isa<ValueNode>(indexNode)) { | |||||
| MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode"; | |||||
| return false; | |||||
| } | |||||
| ValueNodePtr value_node = utils::cast<ValueNodePtr>(indexNode); | |||||
| map_remove_get_item_[tuple_get_item_node->input(1)->fullname_with_scope()] = GetValue<int>(value_node->value()); | |||||
| } else { | |||||
| inputs.emplace_back(cnode->input(i)); | |||||
| } | |||||
| } | |||||
| if (has_tuple_get_item) { | |||||
| cnode->set_inputs(inputs); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| bool AnfExporter::AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const CNodePtr &cnode) { | |||||
| MS_ASSERT(meta_graphT != nullptr); | |||||
| MS_ASSERT(cnode != nullptr); | |||||
| for (size_t i = 1; i < cnode->inputs().size(); ++i) { | |||||
| auto inputNode = cnode->input(i); | |||||
| if (!inputNode->isa<CNode>()) { | |||||
| MS_LOG(ERROR) << "Node of Return's input is not CNode"; | |||||
| return false; | |||||
| } | |||||
| auto inputCNode = utils::cast<CNodePtr>(inputNode); | |||||
| std::string inputName = inputNode->fullname_with_scope(); | |||||
| auto graphOutput = node_id_map_[inputName]; | |||||
| meta_graphT->outputIndex.emplace_back(graphOutput); | |||||
| } | |||||
| return true; | |||||
| } | |||||
| int AnfExporter::ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | |||||
| const std::shared_ptr<PrimitiveTValue> primitive, | |||||
| const std::unique_ptr<schema::CNodeT> &dst_node) { | |||||
| MS_ASSERT(meta_graph != nullptr); | |||||
| MS_ASSERT(primitive != nullptr); | |||||
| MS_ASSERT(dst_node != nullptr); | |||||
| // add quant param | |||||
| dst_node->quantType = primitive->GetQuantType(); | |||||
| if (dst_node->quantType == schema::QuantType_PostTraining || dst_node->quantType == schema::QuantType_AwareTraining) { | |||||
| MS_LOG(DEBUG) << "node: " << dst_node->name << " add QuantParam"; | |||||
| // activation | |||||
| auto input_quant_params = primitive->GetInputQuantParams(); | |||||
| auto node_type = primitive->GetPrimitiveT()->value.type; | |||||
| for (int i = 0; i < input_quant_params.size(); i++) { | |||||
| if (i >= dst_node->inputIndex.size()) { | |||||
| MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size() | |||||
| << " quant_params; but only " << dst_node->inputIndex.size() << " input"; | |||||
| break; | |||||
| } | |||||
| auto activate_index = dst_node->inputIndex[i]; | |||||
| auto tensor_input = meta_graph->allTensors[activate_index].get(); | |||||
| if (tensor_input->quantParams.empty()) { | |||||
| for (auto input_quant_param : input_quant_params[i]) { | |||||
| std::unique_ptr<schema::QuantParamT> input_quant_param_ptr = | |||||
| std::make_unique<schema::QuantParamT>(input_quant_param); | |||||
| MS_LOG(DEBUG) << "[input]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale | |||||
| << " zp: " << input_quant_param_ptr->zeroPoint; | |||||
| tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr)); | |||||
| } | |||||
| } | |||||
| } | |||||
| // output | |||||
| auto output_index = dst_node->outputIndex[0]; | |||||
| auto tensor_output = meta_graph->allTensors[output_index].get(); | |||||
| auto output_quant_params = primitive->GetOutputQuantParams(); | |||||
| if (output_quant_params.empty()) { | |||||
| MS_LOG(WARNING) << "node: " << dst_node->name << " output quant params is empty"; | |||||
| } else { | |||||
| for (auto output_quant_param : output_quant_params[0]) { | |||||
| if (tensor_output->quantParams.empty()) { | |||||
| std::unique_ptr<schema::QuantParamT> output_quant_param_ptr = | |||||
| std::make_unique<schema::QuantParamT>(output_quant_param); | |||||
| MS_LOG(DEBUG) << "[input]node: " << dst_node->name << " scale: " << output_quant_param_ptr->scale | |||||
| << " zp: " << output_quant_param_ptr->zeroPoint; | |||||
| tensor_output->quantParams.emplace_back(std::move(output_quant_param_ptr)); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (dst_node->quantType != schema::QuantType_AwareTraining && | |||||
| !(node_type == schema::PrimitiveType_QuantDTypeCast && | |||||
| primitive->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) { | |||||
| tensor_output->dataType = kNumberTypeInt8; | |||||
| } | |||||
| // // TensorType | |||||
| // valuePtr = primitive->GetAttr(kInputTensorDataType); | |||||
| // if (valuePtr != nullptr) { | |||||
| // MS_LOG(INFO) << "node: " << node->name << " input tensor data | |||||
| // type: " << GetValue<int>(valuePtr); for (auto input : | |||||
| // node->inputIndex) { | |||||
| // auto tensor = subGraph->allTensors[input].get(); | |||||
| // tensor->dataType = kNumberTypeUInt8; | |||||
| // } | |||||
| // } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT) { | |||||
| for (auto node : graph_input_nodes_) { | |||||
| for (auto input : node->inputIndex) { | |||||
| auto tensor = meta_graphT->allTensors[input].get(); | |||||
| if (tensor->data.empty()) { | |||||
| tensor->nodeType = schema::NodeType_ValueNode; | |||||
| tensor->format = schema::Format_NHWC; | |||||
| if (!IsContain(meta_graphT->inputIndex, input)) { | |||||
| meta_graphT->inputIndex.emplace_back(input); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph) { | |||||
| auto cnodes = func_graph->GetOrderedCnodes(); | |||||
| auto meta_graphT = std::make_unique<schema::MetaGraphT>(); | |||||
| for (const auto &cnode : cnodes) { | |||||
| auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0)); | |||||
| if (primitiveT_value == nullptr) { | |||||
| MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| auto primT = primitiveT_value->GetPrimitiveT(); | |||||
| if (primT == nullptr) { | |||||
| MS_LOG(ERROR) << "PrimitiveT is nullptr"; | |||||
| return nullptr; | |||||
| } | |||||
| if (primT->value.type == schema::PrimitiveType_TupleGetItem || | |||||
| primT->value.type == schema::PrimitiveType_MakeTuple) { | |||||
| continue; | |||||
| } | |||||
| map_remove_get_item_.clear(); | |||||
| RemoveIfMakeTuple(cnode); | |||||
| if (!RemoveIfTupleGetItem(cnode)) { | |||||
| MS_LOG(ERROR) << "RemoveIfTupleGetItem failed"; | |||||
| return nullptr; | |||||
| } | |||||
| if (primT->value.type == schema::PrimitiveType_Return) { | |||||
| AddOutPutIfReturn(meta_graphT, cnode); | |||||
| continue; | |||||
| } | |||||
| auto node = std::make_unique<schema::CNodeT>(); | |||||
| node->name = cnode->fullname_with_scope(); | |||||
| node->nodeType = schema::NodeType_CNode; | |||||
| node->primitive = std::unique_ptr<schema::PrimitiveT>(primT); | |||||
| auto ret = SetOpInputNode(cnode, meta_graphT, node.get()); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "SetOpInputNode failed"; | |||||
| return nullptr; | |||||
| } | |||||
| SetOpOutputNode(cnode, meta_graphT, node.get()); | |||||
| ret = ConvertQuantParam(meta_graphT, primitiveT_value, node); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "ConvertQuantParam failed"; | |||||
| return nullptr; | |||||
| } | |||||
| meta_graphT->nodes.emplace_back(std::move(node)); | |||||
| primitiveT_value->SetPrimitiveT(nullptr); | |||||
| } | |||||
| // set graph input tensors | |||||
| SetGraphInputIndex(meta_graphT); | |||||
| return meta_graphT.release(); | |||||
| } | |||||
| void AnfExporter::ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, schema::CNodeT *output_cnode) { | |||||
| std::string input_name = input_anode->fullname_with_scope(); | |||||
| if (!map_remove_get_item_.empty()) { | |||||
| for (auto name : map_remove_get_item_) { | |||||
| if (name.first == input_name) { | |||||
| input_name = input_name + "_o:" + std::to_string(name.second); | |||||
| } | |||||
| } | |||||
| } | |||||
| if (node_id_map_.find(input_name) != node_id_map_.end()) { | |||||
| output_cnode->inputIndex.emplace_back(node_id_map_[input_name]); | |||||
| } | |||||
| } | |||||
| int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anode, size_t anode_index, | |||||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||||
| schema::CNodeT *output_cnode) { | |||||
| std::string input_name = input_anode->fullname_with_scope(); | |||||
| auto paramNode = input_anode->cast<ParameterPtr>(); | |||||
| if (paramNode->name().empty()) { | |||||
| paramNode->set_name(input_name + "_i:" + std::to_string(anode_index - 1)); | |||||
| } | |||||
| if (node_id_map_.find(paramNode->name()) != node_id_map_.end()) { | |||||
| output_cnode->inputIndex.emplace_back(node_id_map_[paramNode->name()]); | |||||
| return RET_OK; | |||||
| } | |||||
| auto paramTensor = std::make_unique<schema::TensorT>(); | |||||
| auto abstractBase = paramNode->abstract(); | |||||
| if (abstractBase == nullptr) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << paramNode->name(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { | |||||
| MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << paramNode->name(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); | |||||
| auto typePtr = abstractTensor->element()->GetTypeTrack(); | |||||
| MS_ASSERT(typePtr != nullptr); | |||||
| paramTensor->dataType = typePtr->type_id(); | |||||
| if (!utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) { | |||||
| MS_LOG(ERROR) << "Shape of Abstract of parameter should be ShapePtr, " << paramNode->name(); | |||||
| return RET_ERROR; | |||||
| } | |||||
| paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); | |||||
| auto paramValue = std::dynamic_pointer_cast<ParamValueLite>(paramNode->default_param()); | |||||
| if (paramValue != nullptr) { | |||||
| paramTensor->nodeType = schema::NodeType_ValueNode; | |||||
| paramTensor->data.resize(paramValue->tensor_size()); | |||||
| memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size()); | |||||
| } | |||||
| node_id_map_[paramNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | |||||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | |||||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | |||||
| return RET_OK; | |||||
| } | |||||
| int AnfExporter::ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||||
| schema::CNodeT *output_cnode) { | |||||
| auto valueNode = input_anode->cast<ValueNodePtr>(); | |||||
| auto paramTensor = std::make_unique<schema::TensorT>(); | |||||
| auto value = valueNode->value(); | |||||
| if (value->isa<lite::tensor::Tensor>()) { | |||||
| auto valueAbstract = valueNode->abstract(); | |||||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(valueAbstract); | |||||
| auto typePtr = abstractTensor->element()->GetTypeTrack(); | |||||
| paramTensor->dataType = typePtr->type_id(); | |||||
| paramTensor->dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); | |||||
| paramTensor->nodeType = schema::NodeType_ValueNode; | |||||
| auto data = value->cast<lite::tensor::TensorPtr>(); | |||||
| paramTensor->data.resize(data->Size()); | |||||
| memcpy(paramTensor->data.data(), data->Data(), data->Size()); | |||||
| node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | |||||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | |||||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | |||||
| } else if (value->isa<mindspore::Int32Imm>()) { | |||||
| auto valueAbstract = valueNode->abstract(); | |||||
| auto abstractScalar = utils::cast<abstract::AbstractScalarPtr>(valueAbstract); | |||||
| auto typePtr = abstractScalar->GetTypeTrack(); | |||||
| paramTensor->dataType = typePtr->type_id(); | |||||
| paramTensor->dims = {1}; | |||||
| paramTensor->nodeType = schema::NodeType_ValueNode; | |||||
| auto data = value->cast<mindspore::Int32ImmPtr>(); | |||||
| paramTensor->data.emplace_back(data->value()); | |||||
| node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); | |||||
| output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); | |||||
| meta_graphT->allTensors.emplace_back(std::move(paramTensor)); | |||||
| } else if (value->isa<mindspore::ValueSequeue>()) { | |||||
| MS_LOG(DEBUG) << "Value type is ValueSequence."; | |||||
| return RET_OK; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Not support value type , need add support."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||||
| schema::CNodeT *fb_node) { | |||||
| MS_ASSERT(nullptr != meta_graph); | |||||
| MS_ASSERT(nullptr != fb_node); | |||||
| if (cnode->inputs().size() <= 1) { | |||||
| return RET_OK; | |||||
| } | |||||
| bool is_graph_input = true; | |||||
| for (size_t i = 1; i < cnode->inputs().size(); i++) { | |||||
| auto input_node = cnode->input(i); | |||||
| if (input_node->isa<CNode>()) { | |||||
| is_graph_input = false; | |||||
| ConvertInputCNode(input_node, fb_node); | |||||
| } else if (input_node->isa<Parameter>()) { | |||||
| auto ret = ConvertInputParameter(input_node, i, meta_graphT, fb_node); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "ConvertInputParameter failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } else if (input_node->isa<ValueNode>()) { | |||||
| auto ret = ConvertInputValueNode(input_node, meta_graphT, fb_node); | |||||
| if (ret != RET_OK) { | |||||
| MS_LOG(ERROR) << "ConvertInputValueNode failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| } | |||||
| fb_node->name = cnode->fullname_with_scope(); | |||||
| if (is_graph_input) { | |||||
| graph_input_nodes_.emplace_back(fb_node); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||||
| schema::CNodeT *fb_node) { | |||||
| MS_ASSERT(nullptr != graph); | |||||
| MS_ASSERT(nullptr != fb_node); | |||||
| std::string cnode_name = fb_node->name; | |||||
| if (utils::isa<abstract::AbstractTuple>(cnode->abstract())) { | |||||
| auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract()); | |||||
| for (int i = 0; i < tuple->size(); i++) { | |||||
| auto msTensor = new schema::TensorT(); | |||||
| msTensor->nodeType = schema::NodeType_Parameter; | |||||
| fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); | |||||
| if (tuple->size() == 1) { | |||||
| node_id_map_[cnode_name] = meta_graphT->allTensors.size(); | |||||
| } else { | |||||
| std::string name = cnode_name + "_o:" + std::to_string(i); | |||||
| node_id_map_[name] = meta_graphT->allTensors.size(); | |||||
| } | |||||
| meta_graphT->allTensors.emplace_back(msTensor); | |||||
| } | |||||
| } else { | |||||
| auto ms_tensor = new schema::TensorT(); | |||||
| ms_tensor->nodeType = schema::NodeType_Parameter; | |||||
| fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size()); | |||||
| node_id_map_[cnode_name] = meta_graphT->allTensors.size(); | |||||
| meta_graphT->allTensors.emplace_back(ms_tensor); | |||||
| } | |||||
| } | |||||
| bool AnfExporter::IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type) { | |||||
| MS_ASSERT(node != nullptr); | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| if (cnode == nullptr) { | |||||
| return false; | |||||
| } | |||||
| const auto &prim = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0)); | |||||
| if (prim == nullptr) { | |||||
| return false; | |||||
| } | |||||
| auto *primitiveT = prim->GetPrimitiveT(); | |||||
| if (primitiveT == nullptr) { | |||||
| return false; | |||||
| } | |||||
| return primitiveT->value.type == type; | |||||
| } | |||||
| schema::MetaGraphT *Export(const FuncGraphPtr &func_graph) { | |||||
| AnfExporter anf_exporter; | |||||
| return anf_exporter.Export(func_graph); | |||||
| } | |||||
| } // namespace mindspore::lite | |||||
| @@ -0,0 +1,64 @@ | |||||
| /** | |||||
| * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). | |||||
| * | |||||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ | |||||
| #define MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "src/ir/primitive_t_value.h" | |||||
| #include "ir/func_graph.h" | |||||
| namespace mindspore::lite { | |||||
| class AnfExporter { | |||||
| public: | |||||
| AnfExporter() = default; | |||||
| virtual ~AnfExporter() = default; | |||||
| schema::MetaGraphT *Export(const FuncGraphPtr &func_graph); | |||||
| void SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||||
| schema::CNodeT *fb_node); | |||||
| int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<schema::MetaGraphT> &meta_graphT, | |||||
| schema::CNodeT *fb_node); | |||||
| void RemoveIfMakeTuple(const CNodePtr &cnode); | |||||
| bool RemoveIfTupleGetItem(const CNodePtr &cnode); | |||||
| bool AddOutPutIfReturn(const std::unique_ptr<schema::MetaGraphT> &meta_graphT, const CNodePtr &cnode); | |||||
| protected: | |||||
| void ConvertInputCNode(const std::shared_ptr<AnfNode> input_anode, schema::CNodeT *output_cnode); | |||||
| int ConvertInputParameter(const std::shared_ptr<AnfNode> input_anode, size_t anode_index, | |||||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode); | |||||
| int ConvertInputValueNode(std::shared_ptr<AnfNode> input_anode, | |||||
| const std::unique_ptr<schema::MetaGraphT> &meta_graphT, schema::CNodeT *output_cnode); | |||||
| void SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &meta_graphT); | |||||
| bool IsPrimitiveCNode(const AnfNodePtr &node, schema::PrimitiveType type); | |||||
| int ConvertQuantParam(const std::unique_ptr<schema::MetaGraphT> &meta_graph, | |||||
| const std::shared_ptr<PrimitiveTValue> primitive, | |||||
| const std::unique_ptr<schema::CNodeT> &dst_node); | |||||
| private: | |||||
| std::map<std::string, int> node_id_map_; | |||||
| std::vector<schema::CNodeT *> graph_input_nodes_; | |||||
| std::map<std::string, int> map_remove_get_item_; | |||||
| }; | |||||
| schema::MetaGraphT *Export(const FuncGraphPtr &func_graph); | |||||
| } // namespace mindspore::lite | |||||
| #endif // MINDSPORE_LITE_SRC_ANF_EXPORTER_ANF_EXPORTER_H_ | |||||
| @@ -0,0 +1,6 @@ | |||||
| file(GLOB_RECURSE ANF_IMPORTER_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| *.cc | |||||
| ) | |||||
| add_library(anf_importer_mid OBJECT | |||||
| ${ANF_IMPORTER_SRC_LIST} | |||||
| ) | |||||
| @@ -18,7 +18,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_importer.h" | |||||
| #include "tools/anf_importer/anf_importer.h" | |||||
| #include "schema/model_generated.h" | #include "schema/model_generated.h" | ||||
| #include "ir/dtype.h" | #include "ir/dtype.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| @@ -160,13 +160,21 @@ void MinnieBuildGraph::FbTest(const GraphDef *graph_def) { | |||||
| #endif | #endif | ||||
| int AnfImporter::Import(const schema::QuantType &quantType) { | int AnfImporter::Import(const schema::QuantType &quantType) { | ||||
| ConverterConstTensor(); | |||||
| auto ret = ConverterCNode(); | |||||
| auto ret = ConverterConstTensor(); | |||||
| if (RET_OK != ret) { | |||||
| MS_LOG(ERROR) << "ConverterConstTensor failed " << ret; | |||||
| return ret; | |||||
| } | |||||
| ret = ConverterCNode(); | |||||
| if (RET_OK != ret) { | if (RET_OK != ret) { | ||||
| MS_LOG(ERROR) << "ConverterCNode failed " << ret; | MS_LOG(ERROR) << "ConverterCNode failed " << ret; | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| AddReturnCNode(); | |||||
| ret = AddReturnCNode(); | |||||
| if (RET_OK != ret) { | |||||
| MS_LOG(ERROR) << "AddReturnCNode failed " << ret; | |||||
| return ret; | |||||
| } | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -181,4 +189,3 @@ AnfNodePtr AnfImporter::GetNode(int tensor_id) { | |||||
| void AnfImporter::AddNode(int tensor_id, AnfNodePtr node) { nodes_[tensor_id] = std::move(node); } | void AnfImporter::AddNode(int tensor_id, AnfNodePtr node) { nodes_[tensor_id] = std::move(node); } | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -36,11 +36,11 @@ class AnfImporter { | |||||
| protected: | protected: | ||||
| // convert const tensor into parameter and save in nodes_ | // convert const tensor into parameter and save in nodes_ | ||||
| virtual void ConverterConstTensor() = 0; | |||||
| virtual int ConverterConstTensor() = 0; | |||||
| // convert other node into cnode and save in nodes_ | // convert other node into cnode and save in nodes_ | ||||
| virtual int ConverterCNode() = 0; | virtual int ConverterCNode() = 0; | ||||
| virtual void AddReturnCNode() = 0; | |||||
| virtual int AddReturnCNode() = 0; | |||||
| AnfNodePtr GetNode(int tensor_id); | AnfNodePtr GetNode(int tensor_id); | ||||
| @@ -52,4 +52,3 @@ class AnfImporter { | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_ANF_IMPORTER_H_ | #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_ANF_IMPORTER_H_ | ||||
| @@ -13,10 +13,10 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_activation_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_activation_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| @@ -16,7 +16,7 @@ | |||||
| #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | ||||
| #define MINDSPORE_ANF_ACTIVATION_PARSER_H | #define MINDSPORE_ANF_ACTIVATION_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfActivationPopulater : public AnfNodePopulater { | class AnfActivationPopulater : public AnfNodePopulater { | ||||
| @@ -13,10 +13,10 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_batchnorm_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_batchnorm_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H | #ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H | ||||
| #define MINDSPORE_ANF_BATCHNORM_PARSER_H | #define MINDSPORE_ANF_BATCHNORM_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfBatchnormPopulater : public AnfNodePopulater { | class AnfBatchnormPopulater : public AnfNodePopulater { | ||||
| @@ -13,10 +13,10 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_biasadd_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_biasadd_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_BIASADD_PARSER_H | #ifndef MINDSPORE_ANF_BIASADD_PARSER_H | ||||
| #define MINDSPORE_ANF_BIASADD_PARSER_H | #define MINDSPORE_ANF_BIASADD_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfBiasAddPopulater : public AnfNodePopulater { | class AnfBiasAddPopulater : public AnfNodePopulater { | ||||
| @@ -16,11 +16,11 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_concat_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_concat_populater.h" | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| @@ -18,7 +18,7 @@ | |||||
| #ifndef MINDSPORE_ANF_CONCAT_PARSER_H | #ifndef MINDSPORE_ANF_CONCAT_PARSER_H | ||||
| #define MINDSPORE_ANF_CONCAT_PARSER_H | #define MINDSPORE_ANF_CONCAT_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfConcatPopulater : public AnfNodePopulater { | class AnfConcatPopulater : public AnfNodePopulater { | ||||
| @@ -17,24 +17,19 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_conv_populater.h" | |||||
| #include <mindspore/lite/src/ir/tensor.h> | |||||
| #include <memory> | |||||
| #include "tools/anf_importer/anf_populater/anf_conv_populater.h" | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/ir/tensor.h" | #include "src/ir/tensor.h" | ||||
| #include "tools/converter/quantizer/quantize_util.h" | #include "tools/converter/quantizer/quantize_util.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| void AnfConvPopulater::PopulaterConv2DMultiGroup( | |||||
| const PrimitivePtr &prim, | |||||
| const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group) { | |||||
| void AnfConvPopulater::PopulaterConv2DMultiGroup(const PrimitivePtr &prim, | |||||
| const std::unique_ptr<schema::PrimitiveT> &primitive, | |||||
| const int &group) { | |||||
| auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | ||||
| auto format = GetValue<std::string>(prim->GetAttr("data_format")); | auto format = GetValue<std::string>(prim->GetAttr("data_format")); | ||||
| if (format == "NCHW") { | if (format == "NCHW") { | ||||
| @@ -75,9 +70,9 @@ void AnfConvPopulater::PopulaterConv2DMultiGroup( | |||||
| primitive->value.value = attr.release(); | primitive->value.value = attr.release(); | ||||
| } | } | ||||
| void AnfConvPopulater::PopulaterConv2DSingleGroup( | |||||
| const PrimitivePtr &prim, | |||||
| const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group) { | |||||
| void AnfConvPopulater::PopulaterConv2DSingleGroup(const PrimitivePtr &prim, | |||||
| const std::unique_ptr<schema::PrimitiveT> &primitive, | |||||
| const int &group) { | |||||
| auto attr = std::make_unique<schema::Conv2DT>(); | auto attr = std::make_unique<schema::Conv2DT>(); | ||||
| attr->group = group; | attr->group = group; | ||||
| auto format = GetValue<std::string>(prim->GetAttr("data_format")); | auto format = GetValue<std::string>(prim->GetAttr("data_format")); | ||||
| @@ -120,17 +115,15 @@ void AnfConvPopulater::PopulaterConv2DSingleGroup( | |||||
| primitive->value.value = attr.release(); | primitive->value.value = attr.release(); | ||||
| } | } | ||||
| void AnfConvPopulater::CalQuantParam(const double &mean, const double &stdDev, | |||||
| float *mMin, float *mMax) { | |||||
| void AnfConvPopulater::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { | |||||
| constexpr float qmin = 0; | constexpr float qmin = 0; | ||||
| constexpr float qmax = 255; | constexpr float qmax = 255; | ||||
| *mMin = static_cast<float>((qmin - mean) / stdDev); | *mMin = static_cast<float>((qmin - mean) / stdDev); | ||||
| *mMax = static_cast<float>((qmax - mean) / stdDev); | *mMax = static_cast<float>((qmax - mean) / stdDev); | ||||
| } | } | ||||
| void AnfConvPopulater::PopulaterQuantParam( | |||||
| const PrimitivePtr &prim, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) { | |||||
| void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) { | |||||
| auto narrow_range = prim->GetAttr("narrow_range"); | auto narrow_range = prim->GetAttr("narrow_range"); | ||||
| bool narrowRangeQuantParam = GetValue<bool>(narrow_range); | bool narrowRangeQuantParam = GetValue<bool>(narrow_range); | ||||
| auto num_bits = prim->GetAttr("num_bits"); | auto num_bits = prim->GetAttr("num_bits"); | ||||
| @@ -158,8 +151,8 @@ void AnfConvPopulater::PopulaterQuantParam( | |||||
| quantParam.min = *minBuf; | quantParam.min = *minBuf; | ||||
| quantParam.max = *maxBuf; | quantParam.max = *maxBuf; | ||||
| } | } | ||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, numbitsRangeQuantParam); | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | |||||
| numbitsRangeQuantParam); | |||||
| quants.emplace_back(quantParam); | quants.emplace_back(quantParam); | ||||
| vecQuantParam->emplace_back(quants); | vecQuantParam->emplace_back(quants); | ||||
| @@ -176,8 +169,7 @@ void AnfConvPopulater::PopulaterQuantParam( | |||||
| for (int i = 0; i < biasQuantSize; ++i) { | for (int i = 0; i < biasQuantSize; ++i) { | ||||
| quantParam.min = *(minBuf++); | quantParam.min = *(minBuf++); | ||||
| quantParam.max = *(maxBuf++); | quantParam.max = *(maxBuf++); | ||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | |||||
| numbitsRangeQuantParam); | numbitsRangeQuantParam); | ||||
| quants.emplace_back(quantParam); | quants.emplace_back(quantParam); | ||||
| } | } | ||||
| @@ -189,8 +181,7 @@ void AnfConvPopulater::PopulaterQuantParam( | |||||
| quantParam.min = 0.0; | quantParam.min = 0.0; | ||||
| quantParam.max = 0.0; | quantParam.max = 0.0; | ||||
| quantParam.zeroPoint = 0; | quantParam.zeroPoint = 0; | ||||
| quantParam.scale = | |||||
| vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale; | |||||
| quantParam.scale = vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale; | |||||
| quants.emplace_back(quantParam); | quants.emplace_back(quantParam); | ||||
| } | } | ||||
| vecQuantParam->emplace_back(quants); | vecQuantParam->emplace_back(quants); | ||||
| @@ -205,15 +196,14 @@ void AnfConvPopulater::PopulaterQuantParam( | |||||
| float *maxBuf = static_cast<float *>(outputMaxPtr->Data()); | float *maxBuf = static_cast<float *>(outputMaxPtr->Data()); | ||||
| quantParam.min = *minBuf; | quantParam.min = *minBuf; | ||||
| quantParam.max = *maxBuf; | quantParam.max = *maxBuf; | ||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, numbitsRangeQuantParam); | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | |||||
| numbitsRangeQuantParam); | |||||
| quants.emplace_back(quantParam); | quants.emplace_back(quantParam); | ||||
| vecQuantParam->emplace_back(quants); | vecQuantParam->emplace_back(quants); | ||||
| } | } | ||||
| } | } | ||||
| int AnfConvPopulater::Populate(const PrimitivePtr &prim, | |||||
| PrimitiveTValue *primitiveTValuePtr, | |||||
| int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | const std::vector<AnfNodePtr> &inputs) { | ||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | MS_ASSERT(primitiveTValuePtr != nullptr); | ||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | auto primitive = std::make_unique<schema::PrimitiveT>(); | ||||
| @@ -19,9 +19,10 @@ | |||||
| #ifndef MINDSPORE_ANF_CONV_PARSER_H | #ifndef MINDSPORE_ANF_CONV_PARSER_H | ||||
| #define MINDSPORE_ANF_CONV_PARSER_H | #define MINDSPORE_ANF_CONV_PARSER_H | ||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfConvPopulater : public AnfNodePopulater { | class AnfConvPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| @@ -31,16 +32,12 @@ class AnfConvPopulater : public AnfNodePopulater { | |||||
| const std::vector<AnfNodePtr> &inputs) override; | const std::vector<AnfNodePtr> &inputs) override; | ||||
| private: | private: | ||||
| void PopulaterConv2DMultiGroup( | |||||
| const PrimitivePtr &prim, | |||||
| const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group); | |||||
| void PopulaterConv2DSingleGroup( | |||||
| const PrimitivePtr &prim, | |||||
| const std::unique_ptr<schema::PrimitiveT> &primitive, const int &group); | |||||
| void PopulaterQuantParam(const PrimitivePtr &prim, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecQuantParam); | |||||
| void CalQuantParam(const double &mean, const double &stdDev, float *mMin, | |||||
| float *mMax); | |||||
| void PopulaterConv2DMultiGroup(const PrimitivePtr &prim, const std::unique_ptr<schema::PrimitiveT> &primitive, | |||||
| const int &group); | |||||
| void PopulaterConv2DSingleGroup(const PrimitivePtr &prim, const std::unique_ptr<schema::PrimitiveT> &primitive, | |||||
| const int &group); | |||||
| void PopulaterQuantParam(const PrimitivePtr &prim, std::vector<std::vector<schema::QuantParamT>> *vecQuantParam); | |||||
| void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,31 +13,26 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h" | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | |||||
| #include <memory> | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/ir/tensor.h" | #include "src/ir/tensor.h" | ||||
| #include "tools/converter/quantizer/quantize_util.h" | #include "tools/converter/quantizer/quantize_util.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| void AnfDepwiseconv2DPopulater::CalQuantParam(const double &mean, | |||||
| const double &stdDev, float *mMin, | |||||
| float *mMax) { | |||||
| void AnfDepwiseconv2DPopulater::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { | |||||
| constexpr float qmin = 0; | constexpr float qmin = 0; | ||||
| constexpr float qmax = 255; | constexpr float qmax = 255; | ||||
| *mMin = static_cast<float>((qmin - mean) / stdDev); | *mMin = static_cast<float>((qmin - mean) / stdDev); | ||||
| *mMax = static_cast<float>((qmax - mean) / stdDev); | *mMax = static_cast<float>((qmax - mean) / stdDev); | ||||
| } | } | ||||
| void AnfDepwiseconv2DPopulater::PopulaterQuantParam( | |||||
| const PrimitivePtr &prim, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) { | |||||
| void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) { | |||||
| auto narrow_range = prim->GetAttr("narrow_range"); | auto narrow_range = prim->GetAttr("narrow_range"); | ||||
| bool narrowRangeQuantParam = GetValue<bool>(narrow_range); | bool narrowRangeQuantParam = GetValue<bool>(narrow_range); | ||||
| auto num_bits = prim->GetAttr("num_bits"); | auto num_bits = prim->GetAttr("num_bits"); | ||||
| @@ -65,8 +60,8 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam( | |||||
| quantParam.min = *minBuf; | quantParam.min = *minBuf; | ||||
| quantParam.max = *maxBuf; | quantParam.max = *maxBuf; | ||||
| } | } | ||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, numbitsRangeQuantParam); | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | |||||
| numbitsRangeQuantParam); | |||||
| quants.emplace_back(quantParam); | quants.emplace_back(quantParam); | ||||
| vecQuantParam->emplace_back(quants); | vecQuantParam->emplace_back(quants); | ||||
| @@ -83,8 +78,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam( | |||||
| for (int i = 0; i < biasQuantSize; ++i) { | for (int i = 0; i < biasQuantSize; ++i) { | ||||
| quantParam.min = *(minBuf++); | quantParam.min = *(minBuf++); | ||||
| quantParam.max = *(maxBuf++); | quantParam.max = *(maxBuf++); | ||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | |||||
| numbitsRangeQuantParam); | numbitsRangeQuantParam); | ||||
| quants.emplace_back(quantParam); | quants.emplace_back(quantParam); | ||||
| } | } | ||||
| @@ -96,8 +90,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam( | |||||
| quantParam.min = 0.0; | quantParam.min = 0.0; | ||||
| quantParam.max = 0.0; | quantParam.max = 0.0; | ||||
| quantParam.zeroPoint = 0; | quantParam.zeroPoint = 0; | ||||
| quantParam.scale = | |||||
| vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale; | |||||
| quantParam.scale = vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale; | |||||
| quants.emplace_back(quantParam); | quants.emplace_back(quantParam); | ||||
| } | } | ||||
| vecQuantParam->emplace_back(quants); | vecQuantParam->emplace_back(quants); | ||||
| @@ -112,15 +105,14 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam( | |||||
| float *maxBuf = static_cast<float *>(outputMaxPtr->Data()); | float *maxBuf = static_cast<float *>(outputMaxPtr->Data()); | ||||
| quantParam.min = *minBuf; | quantParam.min = *minBuf; | ||||
| quantParam.max = *maxBuf; | quantParam.max = *maxBuf; | ||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, numbitsRangeQuantParam); | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | |||||
| numbitsRangeQuantParam); | |||||
| quants.emplace_back(quantParam); | quants.emplace_back(quantParam); | ||||
| vecQuantParam->emplace_back(quants); | vecQuantParam->emplace_back(quants); | ||||
| } | } | ||||
| } | } | ||||
| int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, | |||||
| PrimitiveTValue *primitiveTValuePtr, | |||||
| int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | const std::vector<AnfNodePtr> &inputs) { | ||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | auto primitive = std::make_unique<schema::PrimitiveT>(); | ||||
| auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | auto attr = std::make_unique<schema::DepthwiseConv2DT>(); | ||||
| @@ -171,13 +163,10 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, | |||||
| auto abstractBase = paramNode->abstract(); | auto abstractBase = paramNode->abstract(); | ||||
| MS_ASSERT(abstractBase != nullptr); | MS_ASSERT(abstractBase != nullptr); | ||||
| if (utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { | if (utils::isa<abstract::AbstractTensorPtr>(abstractBase)) { | ||||
| auto abstractTensor = | |||||
| utils::cast<abstract::AbstractTensorPtr>(abstractBase); | |||||
| auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase); | |||||
| MS_ASSERT(abstractTensor != nullptr); | MS_ASSERT(abstractTensor != nullptr); | ||||
| if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) { | if (utils::isa<abstract::ShapePtr>(abstractTensor->BuildShape())) { | ||||
| auto dims = | |||||
| utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape()) | |||||
| ->shape(); | |||||
| auto dims = utils::cast<abstract::ShapePtr>(abstractTensor->BuildShape())->shape(); | |||||
| attr->channelIn = dims[kAnfPopulaterOne]; | attr->channelIn = dims[kAnfPopulaterOne]; | ||||
| } | } | ||||
| } | } | ||||
| @@ -195,8 +184,6 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, | |||||
| } | } | ||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfdepthwise2dPopulater( | |||||
| "DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); | |||||
| AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater( | |||||
| "DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); | |||||
| AnfNodePopulaterRegistrar anfdepthwise2dPopulater("DepthwiseConv2D", new AnfDepwiseconv2DPopulater()); | |||||
| AnfNodePopulaterRegistrar anfdepthwise2dnativePopulater("DepthwiseConv2dNative", new AnfDepwiseconv2DPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,9 +15,10 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H | #ifndef MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H | ||||
| #define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H | #define MINDSPORE_ANF_DEPTHWISECONV2D_PARSER_H | ||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfDepwiseconv2DPopulater : public AnfNodePopulater { | class AnfDepwiseconv2DPopulater : public AnfNodePopulater { | ||||
| public: | public: | ||||
| @@ -25,11 +26,10 @@ class AnfDepwiseconv2DPopulater : public AnfNodePopulater { | |||||
| ~AnfDepwiseconv2DPopulater() override = default; | ~AnfDepwiseconv2DPopulater() override = default; | ||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | ||||
| const std::vector<AnfNodePtr> &inputs) override; | const std::vector<AnfNodePtr> &inputs) override; | ||||
| private: | private: | ||||
| void PopulaterQuantParam(const PrimitivePtr &prim, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecQuantParam); | |||||
| void CalQuantParam(const double &mean, const double &stdDev, float *mMin, | |||||
| float *mMax); | |||||
| void PopulaterQuantParam(const PrimitivePtr &prim, std::vector<std::vector<schema::QuantParamT>> *vecQuantParam); | |||||
| void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,11 +13,11 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_dequant_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_dequant_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_DEQUANT_PARSER_H | #ifndef MINDSPORE_ANF_DEQUANT_PARSER_H | ||||
| #define MINDSPORE_ANF_DEQUANT_PARSER_H | #define MINDSPORE_ANF_DEQUANT_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfDequantPopulater : public AnfNodePopulater { | class AnfDequantPopulater : public AnfNodePopulater { | ||||
| @@ -13,10 +13,10 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_flatten_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_flatten_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_FLATTEN_PARSER_H | #ifndef MINDSPORE_ANF_FLATTEN_PARSER_H | ||||
| #define MINDSPORE_ANF_FLATTEN_PARSER_H | #define MINDSPORE_ANF_FLATTEN_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfFlattenPopulater : public AnfNodePopulater { | class AnfFlattenPopulater : public AnfNodePopulater { | ||||
| @@ -13,29 +13,25 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_matmul_populater.h" | |||||
| #include <memory> | |||||
| #include "tools/anf_importer/anf_populater/anf_matmul_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | |||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/ir/tensor.h" | #include "src/ir/tensor.h" | ||||
| #include "tools/converter/quantizer/quantize_util.h" | #include "tools/converter/quantizer/quantize_util.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| void AnfMatmulPopulater::CalQuantParam(const double &mean, const double &stdDev, | |||||
| float *mMin, float *mMax) { | |||||
| void AnfMatmulPopulater::CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax) { | |||||
| constexpr float qmin = 0; | constexpr float qmin = 0; | ||||
| constexpr float qmax = 255; | constexpr float qmax = 255; | ||||
| *mMin = static_cast<float>((qmin - mean) / stdDev); | *mMin = static_cast<float>((qmin - mean) / stdDev); | ||||
| *mMax = static_cast<float>((qmax - mean) / stdDev); | *mMax = static_cast<float>((qmax - mean) / stdDev); | ||||
| } | } | ||||
| void AnfMatmulPopulater::PopulaterQuantParam( | |||||
| const PrimitivePtr &prim, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) { | |||||
| void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) { | |||||
| auto narrow_range = prim->GetAttr("narrow_range"); | auto narrow_range = prim->GetAttr("narrow_range"); | ||||
| bool narrowRangeQuantParam = GetValue<bool>(narrow_range); | bool narrowRangeQuantParam = GetValue<bool>(narrow_range); | ||||
| auto num_bits = prim->GetAttr("num_bits"); | auto num_bits = prim->GetAttr("num_bits"); | ||||
| @@ -63,8 +59,8 @@ void AnfMatmulPopulater::PopulaterQuantParam( | |||||
| quantParam.min = *minBuf; | quantParam.min = *minBuf; | ||||
| quantParam.max = *maxBuf; | quantParam.max = *maxBuf; | ||||
| } | } | ||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, numbitsRangeQuantParam); | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | |||||
| numbitsRangeQuantParam); | |||||
| quants.emplace_back(quantParam); | quants.emplace_back(quantParam); | ||||
| vecQuantParam->emplace_back(quants); | vecQuantParam->emplace_back(quants); | ||||
| @@ -79,8 +75,7 @@ void AnfMatmulPopulater::PopulaterQuantParam( | |||||
| for (int i = 0; i < filterMinPtr->DataSize(); ++i) { | for (int i = 0; i < filterMinPtr->DataSize(); ++i) { | ||||
| quantParam.min = *(minBuf++); | quantParam.min = *(minBuf++); | ||||
| quantParam.max = *(maxBuf++); | quantParam.max = *(maxBuf++); | ||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | |||||
| numbitsRangeQuantParam); | numbitsRangeQuantParam); | ||||
| quants.emplace_back(quantParam); | quants.emplace_back(quantParam); | ||||
| } | } | ||||
| @@ -97,15 +92,14 @@ void AnfMatmulPopulater::PopulaterQuantParam( | |||||
| float *maxBuf = static_cast<float *>(outputMaxPtr->Data()); | float *maxBuf = static_cast<float *>(outputMaxPtr->Data()); | ||||
| quantParam.min = *minBuf; | quantParam.min = *minBuf; | ||||
| quantParam.max = *maxBuf; | quantParam.max = *maxBuf; | ||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, | |||||
| narrowRangeQuantParam, numbitsRangeQuantParam); | |||||
| quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, | |||||
| numbitsRangeQuantParam); | |||||
| quants.emplace_back(quantParam); | quants.emplace_back(quantParam); | ||||
| vecQuantParam->emplace_back(quants); | vecQuantParam->emplace_back(quants); | ||||
| } | } | ||||
| } | } | ||||
| int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, | |||||
| PrimitiveTValue *primitiveTValuePtr, | |||||
| int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | |||||
| const std::vector<AnfNodePtr> &inputs) { | const std::vector<AnfNodePtr> &inputs) { | ||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | auto primitive = std::make_unique<schema::PrimitiveT>(); | ||||
| auto attr = std::make_unique<schema::MatMulT>(); | auto attr = std::make_unique<schema::MatMulT>(); | ||||
| @@ -124,8 +118,6 @@ int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", | |||||
| new AnfMatmulPopulater()); | |||||
| AnfNodePopulaterRegistrar anfMatMulPopulater("MatMul", | |||||
| new AnfMatmulPopulater()); | |||||
| AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater()); | |||||
| AnfNodePopulaterRegistrar anfMatMulPopulater("MatMul", new AnfMatmulPopulater()); | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_MATMUL_PARSER_H | #ifndef MINDSPORE_ANF_MATMUL_PARSER_H | ||||
| #define MINDSPORE_ANF_MATMUL_PARSER_H | #define MINDSPORE_ANF_MATMUL_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfMatmulPopulater : public AnfNodePopulater { | class AnfMatmulPopulater : public AnfNodePopulater { | ||||
| @@ -24,11 +24,10 @@ class AnfMatmulPopulater : public AnfNodePopulater { | |||||
| ~AnfMatmulPopulater() override = default; | ~AnfMatmulPopulater() override = default; | ||||
| int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | int Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | ||||
| const std::vector<AnfNodePtr> &inputs) override; | const std::vector<AnfNodePtr> &inputs) override; | ||||
| private: | private: | ||||
| void PopulaterQuantParam(const PrimitivePtr &prim, | |||||
| std::vector<std::vector<schema::QuantParamT>> *vecQuantParam); | |||||
| void CalQuantParam(const double &mean, const double &stdDev, float *mMin, | |||||
| float *mMax); | |||||
| void PopulaterQuantParam(const PrimitivePtr &prim, std::vector<std::vector<schema::QuantParamT>> *vecQuantParam); | |||||
| void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); | |||||
| }; | }; | ||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -13,10 +13,10 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_mul_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_mul_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | ||||
| #define MINDSPORE_ANF_ACTIVATION_PARSER_H | #define MINDSPORE_ANF_ACTIVATION_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfMulPopulater : public AnfNodePopulater { | class AnfMulPopulater : public AnfNodePopulater { | ||||
| @@ -14,6 +14,6 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| namespace mindspore::lite {} // namespace mindspore::lite | namespace mindspore::lite {} // namespace mindspore::lite | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include <string> | #include <string> | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -16,7 +16,7 @@ | |||||
| #ifndef MINDSPORE_ANF_NODE_PARSER_REGISTRY_H | #ifndef MINDSPORE_ANF_NODE_PARSER_REGISTRY_H | ||||
| #define MINDSPORE_ANF_NODE_PARSER_REGISTRY_H | #define MINDSPORE_ANF_NODE_PARSER_REGISTRY_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <string> | #include <string> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| @@ -13,11 +13,11 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_pool_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_pool_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_POOL_PARSER_H | #ifndef MINDSPORE_ANF_POOL_PARSER_H | ||||
| #define MINDSPORE_ANF_POOL_PARSER_H | #define MINDSPORE_ANF_POOL_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfPoolPopulater : public AnfNodePopulater { | class AnfPoolPopulater : public AnfNodePopulater { | ||||
| @@ -13,11 +13,11 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_quant_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_quant_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_QUANT_PARSER_H | #ifndef MINDSPORE_ANF_QUANT_PARSER_H | ||||
| #define MINDSPORE_ANF_QUANT_PARSER_H | #define MINDSPORE_ANF_QUANT_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfQuantPopulater : public AnfNodePopulater { | class AnfQuantPopulater : public AnfNodePopulater { | ||||
| @@ -13,18 +13,18 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_reducemean_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_reducemean_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| namespace { | namespace { | ||||
| constexpr int kReduceInputNum = 3; | |||||
| constexpr int kReduceInputIndex = 2; | |||||
| } | |||||
| constexpr int kReduceInputNum = 3; | |||||
| constexpr int kReduceInputIndex = 2; | |||||
| } // namespace | |||||
| int AnfReduceMeanPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | int AnfReduceMeanPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primitiveTValuePtr, | ||||
| const std::vector<AnfNodePtr> &inputs) { | const std::vector<AnfNodePtr> &inputs) { | ||||
| auto primitive = std::make_unique<schema::PrimitiveT>(); | auto primitive = std::make_unique<schema::PrimitiveT>(); | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | ||||
| #define MINDSPORE_ANF_ACTIVATION_PARSER_H | #define MINDSPORE_ANF_ACTIVATION_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfReduceMeanPopulater : public AnfNodePopulater { | class AnfReduceMeanPopulater : public AnfNodePopulater { | ||||
| @@ -13,10 +13,10 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_reshape_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_reshape_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| @@ -16,7 +16,7 @@ | |||||
| #ifndef MINDSPORE_ANF_RESHAPE_PARSER_H | #ifndef MINDSPORE_ANF_RESHAPE_PARSER_H | ||||
| #define MINDSPORE_ANF_RESHAPE_PARSER_H | #define MINDSPORE_ANF_RESHAPE_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfReshapePopulater : public AnfNodePopulater { | class AnfReshapePopulater : public AnfNodePopulater { | ||||
| @@ -13,10 +13,10 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_tensoradd_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_tensoradd_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | #ifndef MINDSPORE_ANF_ACTIVATION_PARSER_H | ||||
| #define MINDSPORE_ANF_ACTIVATION_PARSER_H | #define MINDSPORE_ANF_ACTIVATION_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfTensorAddPopulater : public AnfNodePopulater { | class AnfTensorAddPopulater : public AnfNodePopulater { | ||||
| @@ -13,11 +13,11 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_transpose_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_transpose_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_TRANSPOSE_PARSER_H | #ifndef MINDSPORE_ANF_TRANSPOSE_PARSER_H | ||||
| #define MINDSPORE_ANF_TRANSPOSE_PARSER_H | #define MINDSPORE_ANF_TRANSPOSE_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfTransposePopulater : public AnfNodePopulater { | class AnfTransposePopulater : public AnfNodePopulater { | ||||
| @@ -13,10 +13,10 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/anf_populater/anf_tuple_getitem_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_tuple_getitem_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| #include <memory> | #include <memory> | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "ir/primitive.h" | #include "ir/primitive.h" | ||||
| @@ -15,7 +15,7 @@ | |||||
| */ | */ | ||||
| #ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H | #ifndef MINDSPORE_ANF_BATCHNORM_PARSER_H | ||||
| #define MINDSPORE_ANF_BATCHNORM_PARSER_H | #define MINDSPORE_ANF_BATCHNORM_PARSER_H | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include "tools/anf_importer/anf_populater/anf_node_populater.h" | |||||
| #include <vector> | #include <vector> | ||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfTupleGetItemPopulater : public AnfNodePopulater { | class AnfTupleGetItemPopulater : public AnfNodePopulater { | ||||
| @@ -0,0 +1,194 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <vector> | |||||
| #include <algorithm> | |||||
| #include "schema/inner/model_generated.h" | |||||
| #include "frontend/operator/ops.h" | |||||
| #include "src/param_value_lite.h" | |||||
| #include "import_from_meta_graphT.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "src/ops/ops.h" | |||||
| namespace mindspore::lite { | |||||
| int AnfImporterFromMetaGraphT::ConverterConstTensor() { | |||||
| MS_ASSERT(nullptr != meta_graph_); | |||||
| MS_ASSERT(nullptr != func_graph_); | |||||
| for (size_t i = 0; i < meta_graph_->allTensors.size(); i++) { | |||||
| auto &tensor = meta_graph_->allTensors.at(i); | |||||
| MS_ASSERT(tensor != nullptr); | |||||
| // converter weight and graph input into parameter node | |||||
| if (tensor->nodeType != schema::NodeType_ValueNode) { | |||||
| continue; | |||||
| } | |||||
| MS_ASSERT(tensor->dims() != nullptr); | |||||
| auto parameter = func_graph_->add_parameter(); | |||||
| std::vector<int> shape(tensor->dims.size()); | |||||
| std::copy(tensor->dims.begin(), tensor->dims.end(), shape.begin()); | |||||
| auto type_id = static_cast<TypeId>(tensor->dataType); | |||||
| auto type_ptr = TypeIdToType(type_id); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| abstract_tensor->set_format(tensor->format); | |||||
| parameter->set_abstract(abstract_tensor); | |||||
| parameter->set_name("const_" + std::to_string(i)); | |||||
| ParamValueLitePtr param_value = std::make_shared<ParamValueLite>(); | |||||
| MS_ASSERT(param_value != nullptr); | |||||
| param_value->set_tensor_shape(shape); | |||||
| param_value->set_tensor_type(type_id); | |||||
| if (!tensor->data.empty()) { | |||||
| auto size = tensor->data.size(); | |||||
| char *tensor_data = new (std::nothrow) char[size]; | |||||
| if (tensor_data == nullptr) { | |||||
| MS_LOG(ERROR) << "new char[] failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| std::memcpy(tensor_data, tensor->data.data(), size); | |||||
| param_value->set_tensor_addr(tensor_data); | |||||
| param_value->set_tensor_size(size); | |||||
| } | |||||
| if (!tensor->quantParams.empty()) { | |||||
| std::unique_ptr<AnfQuantParam> quantParam = std::make_unique<AnfQuantParam>(); | |||||
| quantParam->scale = tensor->quantParams[0]->scale; | |||||
| quantParam->zeroPoint = tensor->quantParams[0]->zeroPoint; | |||||
| quantParam->min = tensor->quantParams[0]->min; | |||||
| quantParam->max = tensor->quantParams[0]->max; | |||||
| quantParam->narrowRange = tensor->quantParams[0]->narrowRange; | |||||
| quantParam->numBits = tensor->quantParams[0]->numBits; | |||||
| quantParam->inited = tensor->quantParams[0]->inited; | |||||
| param_value->set_quant_param(quantParam); | |||||
| } | |||||
| parameter->set_default_param(param_value); | |||||
| AddNode(i, parameter); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| ValueNodePtr AnfImporterFromMetaGraphT::ConvertPrimitive(const std::unique_ptr<schema::CNodeT> &cNode) { | |||||
| MS_ASSERT(nullptr != meta_graph_); | |||||
| MS_ASSERT(nullptr != cNode); | |||||
| auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release()); | |||||
| cNode->primitive = nullptr; | |||||
| // add quant parameter | |||||
| if (cNode->quantType == schema::QuantType_AwareTraining) { | |||||
| primTValue->SetQuantType(cNode->quantType); | |||||
| for (int index : cNode->inputIndex) { | |||||
| std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; | |||||
| primTValue->AddInputQuantParam(quant_params); | |||||
| } | |||||
| for (int index : cNode->outputIndex) { | |||||
| std::vector<schema::QuantParamT> quant_params = {*(meta_graph_->allTensors[index]->quantParams[0])}; | |||||
| primTValue->AddOutputQuantParam(quant_params); | |||||
| } | |||||
| } | |||||
| auto value_node = NewValueNode(primTValue); | |||||
| return value_node; | |||||
| } | |||||
| abstract::AbstractTensorPtr AnfImporterFromMetaGraphT::ConvertTensorToAbstractTensor( | |||||
| const std::unique_ptr<schema::TensorT> &tensor) { | |||||
| MS_ASSERT(nullptr != tensor); | |||||
| std::vector<int> shape(tensor->dims.size()); | |||||
| std::copy(tensor->dims.begin(), tensor->dims.end(), shape.begin()); | |||||
| auto type_id = static_cast<TypeId>(tensor->dataType); | |||||
| auto type_ptr = TypeIdToType(type_id); | |||||
| return std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| } | |||||
| void AnfImporterFromMetaGraphT::ConvertAbstract(const std::unique_ptr<schema::CNodeT> &src_cnode, | |||||
| const CNodePtr &dst_cnode) { | |||||
| MS_ASSERT(nullptr != meta_graph_); | |||||
| MS_ASSERT(nullptr != src_cnode); | |||||
| MS_ASSERT(nullptr != dst_cnode); | |||||
| std::vector<uint32_t> out_tensor_ids = src_cnode->outputIndex; | |||||
| if (out_tensor_ids.size() == 1) { | |||||
| auto out_tensor_id = out_tensor_ids.front(); | |||||
| MS_ASSERT(meta_graph_->allTensors.size() > out_tensor_id); | |||||
| auto &tensor = meta_graph_->allTensors.at(out_tensor_id); | |||||
| MS_ASSERT(nullptr != tensor); | |||||
| dst_cnode->set_abstract(ConvertTensorToAbstractTensor(tensor)); | |||||
| AddNode(out_tensor_id, dst_cnode); | |||||
| } else { | |||||
| AbstractBasePtrList abstract_list; | |||||
| for (size_t i = 0; i < out_tensor_ids.size(); i++) { | |||||
| auto out_tensor_id = out_tensor_ids.at(i); | |||||
| MS_ASSERT(meta_graph_->allTensors.size() > out_tensor_id); | |||||
| auto &tensor = meta_graph_->allTensors.at(out_tensor_id); | |||||
| MS_ASSERT(nullptr != tensor); | |||||
| abstract_list.emplace_back(ConvertTensorToAbstractTensor(tensor)); | |||||
| auto tuple_get_item_prim = NewValueNode(GetTupleGetItemPrim()); | |||||
| auto get_item_value = NewValueNode(MakeValue<int>(i)); | |||||
| std::vector<AnfNodePtr> inputs{tuple_get_item_prim, dst_cnode, get_item_value}; | |||||
| CNodePtr get_item_cnode = func_graph_->NewCNode(inputs); | |||||
| AddNode(out_tensor_id, get_item_cnode); | |||||
| } | |||||
| dst_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||||
| } | |||||
| } | |||||
| int AnfImporterFromMetaGraphT::ConverterCNode() { | |||||
| MS_ASSERT(nullptr != meta_graph_); | |||||
| MS_ASSERT(nullptr != func_graph_); | |||||
| for (const auto &cNode : meta_graph_->nodes) { | |||||
| MS_ASSERT(nullptr != cNode); | |||||
| std::vector<AnfNodePtr> op_inputs = {ConvertPrimitive(cNode)}; | |||||
| for (unsigned int j : cNode->inputIndex) { | |||||
| auto node = GetNode(j); | |||||
| if (nullptr == node) { | |||||
| MS_LOG(ERROR) << "Can't find input node."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| // todo: CheckInputNodeType, the first node should be op; | |||||
| op_inputs.push_back(node); | |||||
| } | |||||
| auto new_cnode = func_graph_->NewCNode(op_inputs); | |||||
| new_cnode->set_fullname_with_scope(cNode->name); | |||||
| ConvertAbstract(cNode, new_cnode); | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| int AnfImporterFromMetaGraphT::AddReturnCNode() { | |||||
| MS_EXCEPTION_IF_NULL(meta_graph_); | |||||
| MS_EXCEPTION_IF_NULL(func_graph_); | |||||
| std::vector<AnfNodePtr> make_tuple_inputs; | |||||
| auto make_tuple_prim = NewValueNode(GetMakeTuplePrim()); | |||||
| make_tuple_inputs.emplace_back(make_tuple_prim); | |||||
| for (auto tensor_id : meta_graph_->outputIndex) { | |||||
| auto cNode = GetNode(tensor_id); | |||||
| if (nullptr == cNode) { | |||||
| MS_LOG(ERROR) << "Can't find input node."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| make_tuple_inputs.emplace_back(cNode); | |||||
| } | |||||
| auto make_tuple_cnode = func_graph_->NewCNode(make_tuple_inputs); | |||||
| make_tuple_cnode->set_fullname_with_scope("return tuple"); | |||||
| std::vector<AnfNodePtr> op_inputs; | |||||
| auto value_node = NewValueNode(GetReturnPrim()); | |||||
| op_inputs.emplace_back(value_node); | |||||
| op_inputs.emplace_back(make_tuple_cnode); | |||||
| auto cnode = func_graph_->NewCNode(op_inputs); | |||||
| cnode->set_fullname_with_scope("return"); | |||||
| func_graph_->set_return(cnode); | |||||
| return RET_OK; | |||||
| } | |||||
| FuncGraphPtr AnfImporterFromMetaGraphT::GetResult() { return this->func_graph_; } | |||||
| } // namespace mindspore::lite | |||||
| @@ -18,9 +18,11 @@ | |||||
| #define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ | #define MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ | ||||
| #include <utility> | #include <utility> | ||||
| #include <memory> | |||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| #include "src/common/anf_importer/anf_importer.h" | |||||
| #include "tools/anf_importer/anf_importer.h" | |||||
| #include "src/ir/primitive_t_value.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfImporterFromMetaGraphT : public AnfImporter { | class AnfImporterFromMetaGraphT : public AnfImporter { | ||||
| @@ -33,11 +35,15 @@ class AnfImporterFromMetaGraphT : public AnfImporter { | |||||
| FuncGraphPtr GetResult() override; | FuncGraphPtr GetResult() override; | ||||
| private: | private: | ||||
| void ConverterConstTensor() override; | |||||
| int ConverterConstTensor() override; | |||||
| int ConverterCNode() override; | int ConverterCNode() override; | ||||
| void AddReturnCNode() override; | |||||
| ValueNodePtr ConvertPrimitive(const std::unique_ptr<schema::CNodeT> &cNode); | |||||
| abstract::AbstractTensorPtr ConvertTensorToAbstractTensor(const std::unique_ptr<schema::TensorT> &tensor); | |||||
| void ConvertAbstract(const std::unique_ptr<schema::CNodeT> &src_cnode, const CNodePtr &dst_cnode); | |||||
| int AddReturnCNode() override; | |||||
| private: | private: | ||||
| schema::MetaGraphT *meta_graph_; | schema::MetaGraphT *meta_graph_; | ||||
| @@ -46,4 +52,3 @@ class AnfImporterFromMetaGraphT : public AnfImporter { | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ | #endif // MINDSPORE_LITE_SRC_ANF_IMPORTER_IMPORTER_FROM_META_GRAPHT_H_ | ||||
| @@ -14,7 +14,7 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "src/common/anf_importer/import_from_protobuf.h" | |||||
| #include "tools/anf_importer/import_from_protobuf.h" | |||||
| #include <fcntl.h> | #include <fcntl.h> | ||||
| #include <unistd.h> | #include <unistd.h> | ||||
| @@ -35,11 +35,11 @@ | |||||
| #include "ir/func_graph.h" | #include "ir/func_graph.h" | ||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| #include "securec/include/securec.h" | #include "securec/include/securec.h" | ||||
| #include "src/common/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| #include "src/ir/tensor.h" | #include "src/ir/tensor.h" | ||||
| #include "src/param_value_lite.h" | #include "src/param_value_lite.h" | ||||
| #include "tools/converter/parser/onnx/onnx.pb.h" | #include "tools/converter/parser/onnx/onnx.pb.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #include "tools/anf_importer/anf_populater/anf_node_populater_registry.h" | |||||
| using string = std::string; | using string = std::string; | ||||
| using int32 = int32_t; | using int32 = int32_t; | ||||
| @@ -60,24 +60,16 @@ enum ParseForm : int { | |||||
| }; | }; | ||||
| static std::map<std::string, ParseForm> kParseTypeSwitchMap{ | static std::map<std::string, ParseForm> kParseTypeSwitchMap{ | ||||
| {"type", FORM_PARSE_TYPE}, | |||||
| {"scalar", FORM_PARSE_SCALAR}, | |||||
| {"tensor", FORM_PARSE_TENSOR}}; | |||||
| {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}}; | |||||
| static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{ | static std::unordered_map<int, TypeId> kDefaultValueSwitchMap{ | ||||
| {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, | |||||
| {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, | |||||
| {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, | |||||
| {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, | |||||
| {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, | |||||
| {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, | |||||
| {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, | |||||
| {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, | |||||
| {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, | |||||
| {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, | |||||
| {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, | |||||
| {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, | |||||
| {onnx::TensorProto_DataType_STRING, kObjectTypeString}, | |||||
| {onnx::TensorProto_DataType_BOOL, kNumberTypeBool}, {onnx::TensorProto_DataType_INT8, kNumberTypeInt8}, | |||||
| {onnx::TensorProto_DataType_INT16, kNumberTypeInt16}, {onnx::TensorProto_DataType_INT32, kNumberTypeInt32}, | |||||
| {onnx::TensorProto_DataType_INT64, kNumberTypeInt64}, {onnx::TensorProto_DataType_UINT8, kNumberTypeUInt8}, | |||||
| {onnx::TensorProto_DataType_UINT16, kNumberTypeUInt16}, {onnx::TensorProto_DataType_UINT32, kNumberTypeUInt32}, | |||||
| {onnx::TensorProto_DataType_UINT64, kNumberTypeUInt64}, {onnx::TensorProto_DataType_FLOAT16, kNumberTypeFloat16}, | |||||
| {onnx::TensorProto_DataType_FLOAT, kNumberTypeFloat32}, {onnx::TensorProto_DataType_DOUBLE, kNumberTypeFloat64}, | |||||
| {onnx::TensorProto_DataType_STRING, kObjectTypeString}, | |||||
| }; | }; | ||||
| #if 0 | #if 0 | ||||
| @@ -197,16 +189,15 @@ ParserAttrShape(const std::string &attr_name, const std::unordered_map<string, a | |||||
| return {}; | return {}; | ||||
| } | } | ||||
| #define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ | |||||
| ValuePtr ParseAttrInScalar_##type##_##valuetype( \ | |||||
| const onnx::TensorProto &attr_tensor) { \ | |||||
| if (attr_tensor.type##_data_size() == 1) { \ | |||||
| auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \ | |||||
| return MakeValue<valuetype>(value); \ | |||||
| } else { \ | |||||
| MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ | |||||
| } \ | |||||
| return {}; \ | |||||
| #define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ | |||||
| ValuePtr ParseAttrInScalar_##type##_##valuetype(const onnx::TensorProto &attr_tensor) { \ | |||||
| if (attr_tensor.type##_data_size() == 1) { \ | |||||
| auto value = static_cast<valuetype>(attr_tensor.type##_data(0)); \ | |||||
| return MakeValue<valuetype>(value); \ | |||||
| } else { \ | |||||
| MS_LOG(ERROR) << "size of scalar tensor doesn't equal 1!"; \ | |||||
| } \ | |||||
| return {}; \ | |||||
| } | } | ||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) | PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) | ||||
| @@ -652,21 +643,20 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFunc | |||||
| } | } | ||||
| #else | #else | ||||
| #define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ | |||||
| void ParseAttrInScalar_##type##_##valuetype( \ | |||||
| const PrimitivePtr &prim, const std::string &attr_name, \ | |||||
| const onnx::TensorProto &attr_tensor) { \ | |||||
| MS_EXCEPTION_IF_NULL(prim); \ | |||||
| std::vector<ValuePtr> attr_value_vec; \ | |||||
| for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \ | |||||
| auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \ | |||||
| attr_value_vec.push_back(MakeValue<valuetype>(value)); \ | |||||
| } \ | |||||
| if (attr_value_vec.size() == 1) { \ | |||||
| prim->AddAttr(attr_name, attr_value_vec[0]); \ | |||||
| } else { \ | |||||
| prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \ | |||||
| } \ | |||||
| #define PARSE_ONNXATTR_IN_SCALAR_FORM(type, valuetype) \ | |||||
| void ParseAttrInScalar_##type##_##valuetype(const PrimitivePtr &prim, const std::string &attr_name, \ | |||||
| const onnx::TensorProto &attr_tensor) { \ | |||||
| MS_EXCEPTION_IF_NULL(prim); \ | |||||
| std::vector<ValuePtr> attr_value_vec; \ | |||||
| for (int i = 0; i < attr_tensor.type##_data_size(); ++i) { \ | |||||
| auto value = static_cast<valuetype>(attr_tensor.type##_data(i)); \ | |||||
| attr_value_vec.push_back(MakeValue<valuetype>(value)); \ | |||||
| } \ | |||||
| if (attr_value_vec.size() == 1) { \ | |||||
| prim->AddAttr(attr_name, attr_value_vec[0]); \ | |||||
| } else { \ | |||||
| prim->AddAttr(attr_name, std::make_shared<ValueList>(attr_value_vec)); \ | |||||
| } \ | |||||
| } | } | ||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) | PARSE_ONNXATTR_IN_SCALAR_FORM(double, double) | ||||
| @@ -677,8 +667,8 @@ PARSE_ONNXATTR_IN_SCALAR_FORM(int32, bool) | |||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) | PARSE_ONNXATTR_IN_SCALAR_FORM(int64, int64) | ||||
| PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) | PARSE_ONNXATTR_IN_SCALAR_FORM(uint64, uint64) | ||||
| bool AnfImporterFromProtobuf::BuildParameterForFuncGraph( | |||||
| const ParameterPtr &node, const onnx::ValueInfoProto &value_proto) { | |||||
| bool AnfImporterFromProtobuf::BuildParameterForFuncGraph(const ParameterPtr &node, | |||||
| const onnx::ValueInfoProto &value_proto) { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (!value_proto.has_type() || !value_proto.has_name()) { | if (!value_proto.has_type() || !value_proto.has_name()) { | ||||
| MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; | MS_LOG(ERROR) << "onnx ValueInfoProto has no type or name! "; | ||||
| @@ -701,30 +691,24 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph( | |||||
| shape.push_back(tensor_shape.dim(i).dim_value()); | shape.push_back(tensor_shape.dim(i).dim_value()); | ||||
| } | } | ||||
| if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == | |||||
| kDefaultValueSwitchMap.end()) { | |||||
| if (kDefaultValueSwitchMap.find(tensor_typeproto.elem_type()) == kDefaultValueSwitchMap.end()) { | |||||
| MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; | MS_LOG(ERROR) << "onnx TypeProto_Tensor elem_type is not support yet!"; | ||||
| return false; | return false; | ||||
| } | } | ||||
| auto type_ptr = | |||||
| TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); | |||||
| auto abstract_tensor = | |||||
| std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[tensor_typeproto.elem_type()]); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| node->set_abstract(abstract_tensor); | node->set_abstract(abstract_tensor); | ||||
| if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { | if (default_para_map_.find(value_proto.name()) != default_para_map_.end()) { | ||||
| tensor::Tensor *tensor_info = new tensor::Tensor( | |||||
| kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); | |||||
| tensor::Tensor *tensor_info = new tensor::Tensor(kDefaultValueSwitchMap[tensor_typeproto.elem_type()], shape); | |||||
| MS_EXCEPTION_IF_NULL(tensor_info); | MS_EXCEPTION_IF_NULL(tensor_info); | ||||
| tensor_info->MallocData(); | tensor_info->MallocData(); | ||||
| const onnx::TensorProto initialize_proto = | |||||
| default_para_map_[value_proto.name()]; | |||||
| const onnx::TensorProto initialize_proto = default_para_map_[value_proto.name()]; | |||||
| std::string initial_data = initialize_proto.raw_data(); | std::string initial_data = initialize_proto.raw_data(); | ||||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); | auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); | ||||
| MS_EXCEPTION_IF_NULL(tensor_data_buf); | MS_EXCEPTION_IF_NULL(tensor_data_buf); | ||||
| auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), | |||||
| initial_data.data(), initial_data.size()); | |||||
| auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), initial_data.data(), initial_data.size()); | |||||
| if (EOK != ret) { | if (EOK != ret) { | ||||
| MS_LOG(ERROR) << "memcpy_s error"; | MS_LOG(ERROR) << "memcpy_s error"; | ||||
| return false; | return false; | ||||
| @@ -740,18 +724,15 @@ bool AnfImporterFromProtobuf::BuildParameterForFuncGraph( | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ImportParametersForGraph( | |||||
| const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto) { | |||||
| bool AnfImporterFromProtobuf::ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | MS_EXCEPTION_IF_NULL(outputFuncGraph); | ||||
| MS_LOG(INFO) << "Parameters had default paramerer size is: " | |||||
| << importProto.initializer_size(); | |||||
| MS_LOG(INFO) << "Parameters had default paramerer size is: " << importProto.initializer_size(); | |||||
| for (int i = 0; i < importProto.initializer_size(); ++i) { | for (int i = 0; i < importProto.initializer_size(); ++i) { | ||||
| const onnx::TensorProto &initializer_proto = importProto.initializer(i); | const onnx::TensorProto &initializer_proto = importProto.initializer(i); | ||||
| if (!initializer_proto.has_name()) { | if (!initializer_proto.has_name()) { | ||||
| MS_LOG(ERROR) | |||||
| << "initializer vector of onnx GraphProto has no name at index: " | |||||
| << i; | |||||
| MS_LOG(ERROR) << "initializer vector of onnx GraphProto has no name at index: " << i; | |||||
| return false; | return false; | ||||
| } | } | ||||
| default_para_map_[initializer_proto.name()] = initializer_proto; | default_para_map_[initializer_proto.name()] = initializer_proto; | ||||
| @@ -760,8 +741,7 @@ bool AnfImporterFromProtobuf::ImportParametersForGraph( | |||||
| MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); | MS_LOG(INFO) << "all parameters size: " << importProto.input_size(); | ||||
| for (int i = 0; i < importProto.input_size(); ++i) { | for (int i = 0; i < importProto.input_size(); ++i) { | ||||
| const onnx::ValueInfoProto &input_proto = importProto.input(i); | const onnx::ValueInfoProto &input_proto = importProto.input(i); | ||||
| if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), | |||||
| input_proto)) { | |||||
| if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), input_proto)) { | |||||
| MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; | MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i; | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -769,25 +749,20 @@ bool AnfImporterFromProtobuf::ImportParametersForGraph( | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm( | |||||
| const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| const int attr_tensor_type = attr_tensor.data_type(); | const int attr_tensor_type = attr_tensor.data_type(); | ||||
| if (kDefaultValueSwitchMap.find(attr_tensor_type) == | |||||
| kDefaultValueSwitchMap.end()) { | |||||
| MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" | |||||
| << attr_tensor_type; | |||||
| if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { | |||||
| MS_LOG(ERROR) << "Obtain attr in type-form has not support input type:" << attr_tensor_type; | |||||
| return false; | return false; | ||||
| } | } | ||||
| prim->AddAttr(attr_name, | |||||
| TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); | |||||
| prim->AddAttr(attr_name, TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm( | |||||
| const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| const int attr_tensor_type = attr_tensor.data_type(); | const int attr_tensor_type = attr_tensor.data_type(); | ||||
| switch (attr_tensor_type) { | switch (attr_tensor_type) { | ||||
| @@ -821,16 +796,14 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInScalarForm( | |||||
| break; | break; | ||||
| } | } | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " | |||||
| << attr_tensor_type; | |||||
| MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; | |||||
| return false; | return false; | ||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm( | |||||
| const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| const int attr_tensor_type = attr_tensor.data_type(); | const int attr_tensor_type = attr_tensor.data_type(); | ||||
| const std::string &tensor_buf = attr_tensor.raw_data(); | const std::string &tensor_buf = attr_tensor.raw_data(); | ||||
| @@ -840,31 +813,26 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm( | |||||
| for (int i = 0; i < attr_tensor.dims_size(); ++i) { | for (int i = 0; i < attr_tensor.dims_size(); ++i) { | ||||
| shape.push_back(attr_tensor.dims(i)); | shape.push_back(attr_tensor.dims(i)); | ||||
| } | } | ||||
| tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>( | |||||
| kDefaultValueSwitchMap[attr_tensor_type], shape); | |||||
| tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape); | |||||
| tensor_info->MallocData(); | tensor_info->MallocData(); | ||||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); | auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); | ||||
| ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), | |||||
| tensor_buf.size()); | |||||
| ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size()); | |||||
| prim->set_attr(attr_name, MakeValue(tensor_info)); | prim->set_attr(attr_name, MakeValue(tensor_info)); | ||||
| } else { | } else { | ||||
| if (attr_tensor_type == onnx::TensorProto_DataType_DOUBLE) { | if (attr_tensor_type == onnx::TensorProto_DataType_DOUBLE) { | ||||
| size_t data_size = sizeof(double); | size_t data_size = sizeof(double); | ||||
| double attr_value = 0.0; | double attr_value = 0.0; | ||||
| ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), | |||||
| tensor_buf.size()); | |||||
| ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), tensor_buf.size()); | |||||
| prim->set_attr(attr_name, MakeValue<double>(attr_value)); | prim->set_attr(attr_name, MakeValue<double>(attr_value)); | ||||
| } else if (attr_tensor_type == onnx::TensorProto_DataType_INT64) { | } else if (attr_tensor_type == onnx::TensorProto_DataType_INT64) { | ||||
| size_t data_size = sizeof(int64_t); | size_t data_size = sizeof(int64_t); | ||||
| int32_t attr_value = 0; | int32_t attr_value = 0; | ||||
| ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), | |||||
| tensor_buf.size()); | |||||
| ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), tensor_buf.size()); | |||||
| prim->set_attr(attr_name, MakeValue<int32_t>(attr_value)); | prim->set_attr(attr_name, MakeValue<int32_t>(attr_value)); | ||||
| } else if (attr_tensor_type == onnx::TensorProto_DataType_BOOL) { | } else if (attr_tensor_type == onnx::TensorProto_DataType_BOOL) { | ||||
| size_t data_size = sizeof(bool); | size_t data_size = sizeof(bool); | ||||
| bool attr_value = false; | bool attr_value = false; | ||||
| ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), | |||||
| tensor_buf.size()); | |||||
| ret = memcpy_s(&attr_value, data_size, tensor_buf.data(), tensor_buf.size()); | |||||
| prim->set_attr(attr_name, MakeValue<bool>(attr_value)); | prim->set_attr(attr_name, MakeValue<bool>(attr_value)); | ||||
| } | } | ||||
| } | } | ||||
| @@ -872,8 +840,7 @@ bool AnfImporterFromProtobuf::ObtainCNodeAttrInTensorForm( | |||||
| return ret == EOK; | return ret == EOK; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::GetAttrValueForCNode( | |||||
| const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { | |||||
| bool AnfImporterFromProtobuf::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto) { | |||||
| MS_EXCEPTION_IF_NULL(prim); | MS_EXCEPTION_IF_NULL(prim); | ||||
| const std::string &attr_name = attr_proto.name(); | const std::string &attr_name = attr_proto.name(); | ||||
| if (!attr_proto.has_ref_attr_name()) { | if (!attr_proto.has_ref_attr_name()) { | ||||
| @@ -897,20 +864,18 @@ bool AnfImporterFromProtobuf::GetAttrValueForCNode( | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm( | |||||
| const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm(const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | const int attr_tensor_type = attr_tensor.data_type(); | ||||
| std::vector<int> shape; | std::vector<int> shape; | ||||
| for (int i = 0; i < attr_tensor.dims_size(); ++i) { | for (int i = 0; i < attr_tensor.dims_size(); ++i) { | ||||
| shape.push_back(attr_tensor.dims(i)); | shape.push_back(attr_tensor.dims(i)); | ||||
| } | } | ||||
| tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>( | |||||
| kDefaultValueSwitchMap[attr_tensor_type], shape); | |||||
| tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[attr_tensor_type], shape); | |||||
| tensor_info->MallocData(); | tensor_info->MallocData(); | ||||
| const std::string &tensor_buf = attr_tensor.raw_data(); | const std::string &tensor_buf = attr_tensor.raw_data(); | ||||
| auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); | auto *tensor_data_buf = reinterpret_cast<uint8_t *>(tensor_info->Data()); | ||||
| auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), | |||||
| tensor_buf.size()); | |||||
| auto ret = memcpy_s(tensor_data_buf, tensor_info->Size(), tensor_buf.data(), tensor_buf.size()); | |||||
| if (EOK != ret) { | if (EOK != ret) { | ||||
| MS_LOG(ERROR) << "memcpy_s error"; | MS_LOG(ERROR) << "memcpy_s error"; | ||||
| return false; | return false; | ||||
| @@ -918,15 +883,14 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInTensorForm( | |||||
| auto new_value_node = NewValueNode(MakeValue(tensor_info)); | auto new_value_node = NewValueNode(MakeValue(tensor_info)); | ||||
| MS_EXCEPTION_IF_NULL(new_value_node); | MS_EXCEPTION_IF_NULL(new_value_node); | ||||
| auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); | auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type]); | ||||
| auto abstract_tensor = | |||||
| std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape); | |||||
| new_value_node->set_abstract(abstract_tensor); | new_value_node->set_abstract(abstract_tensor); | ||||
| anfnode_build_map_[value_node_name] = new_value_node; | anfnode_build_map_[value_node_name] = new_value_node; | ||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm( | |||||
| const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm(const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | const int attr_tensor_type = attr_tensor.data_type(); | ||||
| ValuePtr value_ptr = nullptr; | ValuePtr value_ptr = nullptr; | ||||
| switch (attr_tensor_type) { | switch (attr_tensor_type) { | ||||
| @@ -961,8 +925,7 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm( | |||||
| break; | break; | ||||
| } | } | ||||
| default: | default: | ||||
| MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " | |||||
| << attr_tensor_type; | |||||
| MS_LOG(ERROR) << "Obtain attr in scalar-form has not support input type: " << attr_tensor_type; | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto new_value_node = NewValueNode(value_ptr); | auto new_value_node = NewValueNode(value_ptr); | ||||
| @@ -973,28 +936,23 @@ bool AnfImporterFromProtobuf::ObtainValueNodeInScalarForm( | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm( | |||||
| const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromProtobuf::ObtainValueNodeInTypeForm(const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| const int attr_tensor_type = attr_tensor.data_type(); | const int attr_tensor_type = attr_tensor.data_type(); | ||||
| if (kDefaultValueSwitchMap.find(attr_tensor_type) == | |||||
| kDefaultValueSwitchMap.end()) { | |||||
| MS_LOG(ERROR) | |||||
| << "Obtain ValueNode attr in type-form has not support input type: " | |||||
| << attr_tensor_type; | |||||
| if (kDefaultValueSwitchMap.find(attr_tensor_type) == kDefaultValueSwitchMap.end()) { | |||||
| MS_LOG(ERROR) << "Obtain ValueNode attr in type-form has not support input type: " << attr_tensor_type; | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto new_value_node = | |||||
| NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); | |||||
| abstract::AbstractTypePtr abs_type = | |||||
| std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>()); | |||||
| auto new_value_node = NewValueNode(TypeIdToType(kDefaultValueSwitchMap[attr_tensor_type])); | |||||
| abstract::AbstractTypePtr abs_type = std::make_shared<abstract::AbstractType>(std::make_shared<TypeType>()); | |||||
| new_value_node->set_abstract(abs_type); | new_value_node->set_abstract(abs_type); | ||||
| anfnode_build_map_[value_node_name] = new_value_node; | anfnode_build_map_[value_node_name] = new_value_node; | ||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::GetAttrValueForValueNode( | |||||
| const std::string &ref_attr_name, const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| bool AnfImporterFromProtobuf::GetAttrValueForValueNode(const std::string &ref_attr_name, | |||||
| const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor) { | |||||
| switch (kParseTypeSwitchMap[ref_attr_name]) { | switch (kParseTypeSwitchMap[ref_attr_name]) { | ||||
| case FORM_PARSE_SCALAR: { | case FORM_PARSE_SCALAR: { | ||||
| return ObtainValueNodeInScalarForm(value_node_name, attr_tensor); | return ObtainValueNodeInScalarForm(value_node_name, attr_tensor); | ||||
| @@ -1006,14 +964,12 @@ bool AnfImporterFromProtobuf::GetAttrValueForValueNode( | |||||
| return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); | return ObtainValueNodeInTypeForm(value_node_name, attr_tensor); | ||||
| } | } | ||||
| default: | default: | ||||
| MS_LOG(ERROR) | |||||
| << "parse ValueNode value don't support input of ref_attr_name"; | |||||
| MS_LOG(ERROR) << "parse ValueNode value don't support input of ref_attr_name"; | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph( | |||||
| const onnx::NodeProto &node_proto) { | |||||
| bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto) { | |||||
| const std::string &value_node_name = node_proto.output(0); | const std::string &value_node_name = node_proto.output(0); | ||||
| const onnx::AttributeProto &attr_proto = node_proto.attribute(0); | const onnx::AttributeProto &attr_proto = node_proto.attribute(0); | ||||
| if (!attr_proto.has_ref_attr_name()) { | if (!attr_proto.has_ref_attr_name()) { | ||||
| @@ -1026,23 +982,21 @@ bool AnfImporterFromProtobuf::BuildValueNodeForFuncGraph( | |||||
| return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor); | return GetAttrValueForValueNode(ref_attr_name, value_node_name, attr_tensor); | ||||
| } | } | ||||
| abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode( | |||||
| const onnx::AttributeProto &attr_proto) { | |||||
| abstract::AbstractTensorPtr AnfImporterFromProtobuf::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) { | |||||
| std::vector<int> shape_vec; | std::vector<int> shape_vec; | ||||
| const onnx::TensorProto &attr_tensor = attr_proto.t(); | const onnx::TensorProto &attr_tensor = attr_proto.t(); | ||||
| for (int i = 0; i < attr_tensor.dims_size(); ++i) { | for (int i = 0; i < attr_tensor.dims_size(); ++i) { | ||||
| shape_vec.push_back(attr_tensor.dims(i)); | shape_vec.push_back(attr_tensor.dims(i)); | ||||
| } | } | ||||
| auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); | auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[attr_tensor.data_type()]); | ||||
| auto abstract_tensor = | |||||
| std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, shape_vec); | |||||
| MS_EXCEPTION_IF_NULL(abstract_tensor); | MS_EXCEPTION_IF_NULL(abstract_tensor); | ||||
| return abstract_tensor; | return abstract_tensor; | ||||
| } | } | ||||
| CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph( | |||||
| const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, | |||||
| const schema::QuantType &quantType) { | |||||
| CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::NodeProto &node_proto, | |||||
| const schema::QuantType &quantType) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | MS_EXCEPTION_IF_NULL(outputFuncGraph); | ||||
| if (!node_proto.has_op_type()) { | if (!node_proto.has_op_type()) { | ||||
| MS_LOG(ERROR) << "Get CNode op_type failed!"; | MS_LOG(ERROR) << "Get CNode op_type failed!"; | ||||
| @@ -1082,23 +1036,20 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph( | |||||
| for (int i = 0; i < node_proto.input_size(); ++i) { | for (int i = 0; i < node_proto.input_size(); ++i) { | ||||
| const std::string &input_name = node_proto.input(i); | const std::string &input_name = node_proto.input(i); | ||||
| if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { | if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { | ||||
| MS_LOG(ERROR) << node_name << " input " << i << input_name | |||||
| << "can't find in nodes have parsed"; | |||||
| MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| inputs.push_back(anfnode_build_map_[input_name]); | inputs.push_back(anfnode_build_map_[input_name]); | ||||
| } | } | ||||
| std::string opType = prim->name(); | std::string opType = prim->name(); | ||||
| auto node_parser = | |||||
| AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); | |||||
| auto node_parser = AnfNodePopulaterRegistry::GetInstance()->GetNodePopulater(opType); | |||||
| if (node_parser == nullptr) { | if (node_parser == nullptr) { | ||||
| MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; | MS_LOG(ERROR) << "Find op parser failed, opType: " << opType; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto primitiveT = std::make_unique<schema::PrimitiveT>(); | auto primitiveT = std::make_unique<schema::PrimitiveT>(); | ||||
| // auto * primitiveTValue = new PrimitiveTValue(primitiveT.release()); | // auto * primitiveTValue = new PrimitiveTValue(primitiveT.release()); | ||||
| std::shared_ptr<PrimitiveTValue> primitiveTValuePtr = | |||||
| std::make_shared<PrimitiveTValue>(primitiveT.release()); | |||||
| std::shared_ptr<PrimitiveTValue> primitiveTValuePtr = std::make_shared<PrimitiveTValue>(primitiveT.release()); | |||||
| primitiveTValuePtr->SetQuantType(quantType); | primitiveTValuePtr->SetQuantType(quantType); | ||||
| node_parser->Populate(prim, primitiveTValuePtr.get(), inputs); | node_parser->Populate(prim, primitiveTValuePtr.get(), inputs); | ||||
| MS_ASSERT(primitiveTValuePtr != nullptr); | MS_ASSERT(primitiveTValuePtr != nullptr); | ||||
| @@ -1130,9 +1081,8 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph( | |||||
| return cnode_ptr; | return cnode_ptr; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::BuildReturnForFuncGraph( | |||||
| const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||||
| const CNodePtr &cnode_ptr) { | |||||
| bool AnfImporterFromProtobuf::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto, const CNodePtr &cnode_ptr) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | MS_EXCEPTION_IF_NULL(outputFuncGraph); | ||||
| MS_EXCEPTION_IF_NULL(cnode_ptr); | MS_EXCEPTION_IF_NULL(cnode_ptr); | ||||
| std::vector<AnfNodePtr> inputs; | std::vector<AnfNodePtr> inputs; | ||||
| @@ -1147,8 +1097,7 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph( | |||||
| elem.push_back(anfnode_build_map_[out_tuple]->abstract()); | elem.push_back(anfnode_build_map_[out_tuple]->abstract()); | ||||
| } | } | ||||
| auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); | auto maketuple_ptr = outputFuncGraph->NewCNode(inputs); | ||||
| maketuple_ptr->set_abstract( | |||||
| std::make_shared<abstract::AbstractTuple>(elem)); | |||||
| maketuple_ptr->set_abstract(std::make_shared<abstract::AbstractTuple>(elem)); | |||||
| inputs.clear(); | inputs.clear(); | ||||
| inputs.push_back(NewValueNode(prim::kPrimReturn)); | inputs.push_back(NewValueNode(prim::kPrimReturn)); | ||||
| inputs.push_back(maketuple_ptr); | inputs.push_back(maketuple_ptr); | ||||
| @@ -1161,14 +1110,11 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph( | |||||
| const onnx::TypeProto &output_typeproto = output_node.type(); | const onnx::TypeProto &output_typeproto = output_node.type(); | ||||
| int output_type = output_typeproto.tensor_type().elem_type(); | int output_type = output_typeproto.tensor_type().elem_type(); | ||||
| std::vector<int> output_shape; | std::vector<int> output_shape; | ||||
| for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); | |||||
| ++i) { | |||||
| output_shape.push_back( | |||||
| output_typeproto.tensor_type().shape().dim(i).dim_value()); | |||||
| for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) { | |||||
| output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value()); | |||||
| } | } | ||||
| auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); | auto type_ptr = TypeIdToType(kDefaultValueSwitchMap[output_type]); | ||||
| auto abstract_tensor = | |||||
| std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape); | |||||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, output_shape); | |||||
| inputs.clear(); | inputs.clear(); | ||||
| inputs.push_back(NewValueNode(prim::kPrimReturn)); | inputs.push_back(NewValueNode(prim::kPrimReturn)); | ||||
| @@ -1182,9 +1128,9 @@ bool AnfImporterFromProtobuf::BuildReturnForFuncGraph( | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ImportNodesForGraph( | |||||
| const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||||
| const schema::QuantType &quantType) { | |||||
| bool AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto, | |||||
| const schema::QuantType &quantType) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | MS_EXCEPTION_IF_NULL(outputFuncGraph); | ||||
| MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); | MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); | ||||
| CNodePtr cnode_ptr = nullptr; | CNodePtr cnode_ptr = nullptr; | ||||
| @@ -1210,9 +1156,8 @@ bool AnfImporterFromProtobuf::ImportNodesForGraph( | |||||
| } | } | ||||
| #endif | #endif | ||||
| bool AnfImporterFromProtobuf::BuildFuncGraph( | |||||
| const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||||
| const schema::QuantType &quantType) { | |||||
| bool AnfImporterFromProtobuf::BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||||
| const schema::QuantType &quantType) { | |||||
| MS_EXCEPTION_IF_NULL(outputFuncGraph); | MS_EXCEPTION_IF_NULL(outputFuncGraph); | ||||
| GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); | GraphDebugInfoPtr debug_info_ptr = outputFuncGraph->debug_info(); | ||||
| MS_EXCEPTION_IF_NULL(debug_info_ptr); | MS_EXCEPTION_IF_NULL(debug_info_ptr); | ||||
| @@ -1228,8 +1173,7 @@ bool AnfImporterFromProtobuf::BuildFuncGraph( | |||||
| return ImportNodesForGraph(outputFuncGraph, importProto, quantType); | return ImportNodesForGraph(outputFuncGraph, importProto, quantType); | ||||
| } | } | ||||
| bool AnfImporterFromProtobuf::ParseModelConfigureInfo( | |||||
| const onnx::ModelProto &model_proto) { | |||||
| bool AnfImporterFromProtobuf::ParseModelConfigureInfo(const onnx::ModelProto &model_proto) { | |||||
| if (!model_proto.has_producer_name()) { | if (!model_proto.has_producer_name()) { | ||||
| MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; | MS_LOG(ERROR) << "Parse model producer name from pb file failed!"; | ||||
| return false; | return false; | ||||
| @@ -1267,8 +1211,7 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { | |||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary( | |||||
| const std::string &model_path) { | |||||
| onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { | |||||
| std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0}); | std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0}); | ||||
| if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) { | if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) { | ||||
| MS_LOG(ERROR) << "open file failed."; | MS_LOG(ERROR) << "open file failed."; | ||||
| @@ -22,15 +22,15 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | #include <utility> | ||||
| #include "abstract/abstract_value.h" | |||||
| #include "src/common/anf_importer/anf_importer.h" | |||||
| #include "include/errorcode.h" | |||||
| #include "tools/converter/parser/onnx/onnx.pb.h" | #include "tools/converter/parser/onnx/onnx.pb.h" | ||||
| #include "tools/anf_importer/anf_importer.h" | |||||
| #include "abstract/abstract_value.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| class AnfImporterFromProtobuf : public AnfImporter { | class AnfImporterFromProtobuf : public AnfImporter { | ||||
| public: | public: | ||||
| explicit AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, | |||||
| FuncGraphPtr func_graph) | |||||
| explicit AnfImporterFromProtobuf(onnx::ModelProto *onnx_model, FuncGraphPtr func_graph) | |||||
| : onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {} | : onnx_model_(onnx_model), func_graph_(std::move(func_graph)) {} | ||||
| ~AnfImporterFromProtobuf() override = default; | ~AnfImporterFromProtobuf() override = default; | ||||
| @@ -39,16 +39,14 @@ class AnfImporterFromProtobuf : public AnfImporter { | |||||
| FuncGraphPtr GetResult() override; | FuncGraphPtr GetResult() override; | ||||
| int Import(const schema::QuantType &quantType = | |||||
| schema::QuantType_QUANT_NONE) override; | |||||
| int Import(const schema::QuantType &quantType = schema::QuantType_QUANT_NONE) override; | |||||
| private: | private: | ||||
| void ConverterConstTensor() override{}; | |||||
| int ConverterCNode() override{}; | |||||
| void AddReturnCNode() override{}; | |||||
| int ConverterConstTensor() override{ return RET_ERROR; }; | |||||
| int ConverterCNode() override{ return RET_ERROR; }; | |||||
| int AddReturnCNode() override{ return RET_ERROR; }; | |||||
| bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto); | bool ParseModelConfigureInfo(const onnx::ModelProto &model_proto); | ||||
| bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto, | |||||
| bool BuildFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||||
| const schema::QuantType &quantType); | const schema::QuantType &quantType); | ||||
| #if 0 | #if 0 | ||||
| bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | ||||
| @@ -81,43 +79,29 @@ class AnfImporterFromProtobuf : public AnfImporter { | |||||
| std::unordered_map<std::string, abstract::AbstractTensorPtr> | std::unordered_map<std::string, abstract::AbstractTensorPtr> | ||||
| GetAbstractForCNode(const onnx::AttributeProto &attr_proto); | GetAbstractForCNode(const onnx::AttributeProto &attr_proto); | ||||
| #else | #else | ||||
| bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto); | |||||
| bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto, | |||||
| bool ImportParametersForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto); | |||||
| bool ImportNodesForGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||||
| const schema::QuantType &quantType); | const schema::QuantType &quantType); | ||||
| bool BuildParameterForFuncGraph(const ParameterPtr &node, | |||||
| const onnx::ValueInfoProto &value_proto); | |||||
| CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::NodeProto &node_proto, | |||||
| bool BuildParameterForFuncGraph(const ParameterPtr &node, const onnx::ValueInfoProto &value_proto); | |||||
| CNodePtr BuildCNodeForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::NodeProto &node_proto, | |||||
| const schema::QuantType &quantType); | const schema::QuantType &quantType); | ||||
| bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, | |||||
| const onnx::GraphProto &importProto, | |||||
| bool BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGraph, const onnx::GraphProto &importProto, | |||||
| const CNodePtr &cnode_ptr); | const CNodePtr &cnode_ptr); | ||||
| bool GetAttrValueForCNode(const PrimitivePtr &prim, | |||||
| const onnx::AttributeProto &attr_proto); | |||||
| bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, | |||||
| const std::string &attr_name, | |||||
| bool GetAttrValueForCNode(const PrimitivePtr &prim, const onnx::AttributeProto &attr_proto); | |||||
| bool ObtainCNodeAttrInTypeForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor); | const onnx::TensorProto &attr_tensor); | ||||
| bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, | |||||
| const std::string &attr_name, | |||||
| bool ObtainCNodeAttrInScalarForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor); | const onnx::TensorProto &attr_tensor); | ||||
| bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, | |||||
| const std::string &attr_name, | |||||
| bool ObtainCNodeAttrInTensorForm(const PrimitivePtr &prim, const std::string &attr_name, | |||||
| const onnx::TensorProto &attr_tensor); | const onnx::TensorProto &attr_tensor); | ||||
| bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto); | bool BuildValueNodeForFuncGraph(const onnx::NodeProto &node_proto); | ||||
| bool ObtainValueNodeInTensorForm(const string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor); | |||||
| bool ObtainValueNodeInTensorForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); | |||||
| bool ObtainValueNodeInScalarForm(const string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor); | |||||
| bool GetAttrValueForValueNode(const string &ref_attr_name, | |||||
| const std::string &value_node_name, | |||||
| bool ObtainValueNodeInScalarForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); | |||||
| bool GetAttrValueForValueNode(const string &ref_attr_name, const std::string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor); | const onnx::TensorProto &attr_tensor); | ||||
| bool ObtainValueNodeInTypeForm(const string &value_node_name, | |||||
| const onnx::TensorProto &attr_tensor); | |||||
| abstract::AbstractTensorPtr GetAbstractForCNode( | |||||
| const onnx::AttributeProto &attr_proto); | |||||
| bool ObtainValueNodeInTypeForm(const string &value_node_name, const onnx::TensorProto &attr_tensor); | |||||
| abstract::AbstractTensorPtr GetAbstractForCNode(const onnx::AttributeProto &attr_proto); | |||||
| #endif | #endif | ||||
| @@ -70,7 +70,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc | ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc | ||||
| # ${CMAKE_CURRENT_SOURCE_DIR}/../../src/common/anf_exporter/anf_exporter.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ir/primitive_t_value.cc | |||||
| ../optimizer/common/node_pass_extends.cc | ../optimizer/common/node_pass_extends.cc | ||||
| ../optimizer/common/pass_manager_extends.cc | ../optimizer/common/pass_manager_extends.cc | ||||
| @@ -83,6 +83,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||||
| ../optimizer/fusion/constant_folding_fusion.cc | ../optimizer/fusion/constant_folding_fusion.cc | ||||
| ) | ) | ||||
| add_subdirectory(../anf_importer anf_importer) | |||||
| add_subdirectory(../anf_exporter anf_exporter) | |||||
| add_subdirectory(parser/caffe) | add_subdirectory(parser/caffe) | ||||
| add_subdirectory(parser/tflite) | add_subdirectory(parser/tflite) | ||||
| add_subdirectory(parser/onnx) | add_subdirectory(parser/onnx) | ||||
| @@ -100,6 +102,7 @@ target_link_libraries(converter_lite PRIVATE | |||||
| caffe_parser_mid | caffe_parser_mid | ||||
| onnx_parser_mid | onnx_parser_mid | ||||
| anf_importer_mid | anf_importer_mid | ||||
| anf_exporter_mid | |||||
| node_mid | node_mid | ||||
| graph_pass_mid | graph_pass_mid | ||||
| fusion_mid | fusion_mid | ||||
| @@ -28,8 +28,8 @@ | |||||
| #include "parser/caffe/caffe_converter.h" | #include "parser/caffe/caffe_converter.h" | ||||
| #include "parser/tflite/tflite_converter.h" | #include "parser/tflite/tflite_converter.h" | ||||
| #include "parser/onnx/onnx_converter.h" | #include "parser/onnx/onnx_converter.h" | ||||
| #include "src/common/anf_exporter/anf_exporter.h" | |||||
| #include "src/common/anf_importer/import_from_protobuf.h" | |||||
| #include "tools/anf_exporter/anf_exporter.h" | |||||
| #include "tools/anf_importer/import_from_protobuf.h" | |||||
| #include "tools/converter/parser/onnx/onnx.pb.h" | #include "tools/converter/parser/onnx/onnx.pb.h" | ||||
| #include "tools/converter/quantizer/weight_quantizer.h" | #include "tools/converter/quantizer/weight_quantizer.h" | ||||
| #include "tools/converter/quantizer/post_training_quantizer.h" | #include "tools/converter/quantizer/post_training_quantizer.h" | ||||
| @@ -22,7 +22,7 @@ | |||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| #include "tools/converter/graphdef_transform.h" | #include "tools/converter/graphdef_transform.h" | ||||
| #include "tools/converter/model_parser.h" | #include "tools/converter/model_parser.h" | ||||
| #include "src/common/anf_importer/anf_importer.h" | |||||
| #include "tools/anf_importer/anf_importer.h" | |||||
| #include "tools/converter/converter_flags.h" | #include "tools/converter/converter_flags.h" | ||||
| #include "tools/converter/anf_transform.h" | #include "tools/converter/anf_transform.h" | ||||
| #include "tools/converter/quantizer/quantizer.h" | #include "tools/converter/quantizer/quantizer.h" | ||||
| @@ -1,5 +1,6 @@ | |||||
| add_library(graph_pass_mid OBJECT | add_library(graph_pass_mid OBJECT | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_pass.cc | ${CMAKE_CURRENT_SOURCE_DIR}/format_trans_pass.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/eltwise_format_trans_pass.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc | ${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc | ${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc | ${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc | ||||
| @@ -0,0 +1,200 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| #include <utility> | |||||
| #include "tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h" | |||||
| #include "tools/common/converter_op_utils.h" | |||||
| #include "tools/common/node_util.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "src/common/common.h" | |||||
| #include "src/common/utils.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #define kMinInputNum 1 | |||||
| #define kOutputNum 1 | |||||
| STATUS EltwiseFormatTransPass::Run(schema::MetaGraphT *graph) { | |||||
| if (fmkType == converter::FmkType_TF) { | |||||
| return RET_OK; | |||||
| } | |||||
| MS_ASSERT(graph != nullptr); | |||||
| auto status = DoModelInputFormatTrans(graph); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "DoModelInputFormatTrans failed : " << status; | |||||
| return status; | |||||
| } | |||||
| status = DoNodeInoutFormatTrans(graph); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "DoNodeInoutFormatTrans failed : " << status; | |||||
| return status; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| STATUS EltwiseFormatTransPass::DoModelInputFormatTrans(schema::MetaGraphT *graph) { | |||||
| if (fmkType == converter::FmkType_TF || fmkType == converter::FmkType_TFLITE) { | |||||
| return RET_OK; | |||||
| } | |||||
| MS_ASSERT(graph != nullptr); | |||||
| // insert trans node in model input tensor | |||||
| if (graph->nodes.empty()) { | |||||
| return RET_OK; | |||||
| } | |||||
| auto graphInputIdxes = graph->inputIndex; | |||||
| for (size_t i = 0; i < graphInputIdxes.size(); i++) { | |||||
| auto inputIdx = graphInputIdxes.at(i); | |||||
| MS_ASSERT(inputIdx < subGraph->allTensors.size()); | |||||
| auto &tensor = graph->allTensors.at(inputIdx); | |||||
| if (tensor->dims.size() != kNCHWDimNumber) { | |||||
| continue; | |||||
| } | |||||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||||
| auto &node = *iter; | |||||
| for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) { | |||||
| if (node->inputIndex.at(inputIndexIdx) == inputIdx) { | |||||
| STATUS status = RET_OK; | |||||
| iter = InsertFormatTransNode(graph, iter, kBefore, inputIndexIdx, kNHWC2NCHW, &status); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << (*iter)->name << " failed"; | |||||
| return status; | |||||
| } | |||||
| // set first tensor format to nhwc | |||||
| auto &transNode = *(iter - 1); | |||||
| MS_ASSERT(transNode != nullptr); | |||||
| MS_ASSERT(transNode->inputIndex.size() == 1); | |||||
| MS_ASSERT(subGraph->allTensors.size() > transNode->inputIndex.front()); | |||||
| auto &graphInTensor = graph->allTensors.at(transNode->inputIndex.front()); | |||||
| graphInTensor->format = schema::Format_NHWC; | |||||
| // assume parser not reformat shape | |||||
| auto oldDims = graphInTensor->dims; | |||||
| graphInTensor->dims = {oldDims[NCHW_N], oldDims[NCHW_H], oldDims[NCHW_W], oldDims[NCHW_C]}; | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| // inference needed inputFormat: | |||||
| // conv deconv depth dedepth | |||||
| // fp32 NCHW NCHW NCHW NCHW | |||||
| // uint8 NCHW ? NCHW ? | |||||
| STATUS EltwiseFormatTransPass::DoNodeInoutFormatTrans(schema::MetaGraphT *graph) { | |||||
| MS_ASSERT(graph != nullptr); | |||||
| // insert before and after the op cal by nchw/nc4hw4 | |||||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||||
| FormatTransNodeType beforeNodeType, afterNodeType; | |||||
| if (fmkType == converter::FmkType_TFLITE) { // inference by nhwc | |||||
| // if (quantType == QuantType_AwareTrainning) { // awaretrainning op use | |||||
| // nhwc | |||||
| // if (IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only | |||||
| // support nhwc | |||||
| // continue; | |||||
| // } | |||||
| // if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { | |||||
| // continue; | |||||
| // } | |||||
| // } else { | |||||
| // if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { | |||||
| continue; | |||||
| // } | |||||
| // } | |||||
| // beforeNodeType = kNCHW2NHWC; | |||||
| // afterNodeType = kNHWC2NCHW; | |||||
| } else if (fmkType == converter::FmkType_CAFFE) { // inference by nchw | |||||
| // if (quantType == QuantType_AwareTrainning) { // awaretrainning op use nhwc | |||||
| // if (!IsContain(GetUint8NhwcOpList(), GetCNodeTType(**iter))) { // uint8NhwcOp only support nhwc | |||||
| // continue; | |||||
| // } | |||||
| // } else { | |||||
| // continue; | |||||
| // } | |||||
| if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { | |||||
| continue; | |||||
| } | |||||
| beforeNodeType = kNCHW2NHWC; | |||||
| afterNodeType = kNHWC2NCHW; | |||||
| } else if (fmkType == converter::FmkType_MS) { | |||||
| if (!IsContain(GetNhwcOpList(), GetCNodeTType(**iter))) { | |||||
| continue; | |||||
| } | |||||
| beforeNodeType = kNCHW2NHWC; | |||||
| afterNodeType = kNHWC2NCHW; | |||||
| } else { | |||||
| MS_LOG(ERROR) << "Unsupported fmk: " << fmkType; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto &node = *iter; | |||||
| auto nodeName = node->name; | |||||
| if (node->inputIndex.size() < kMinInputNum) { | |||||
| MS_LOG(ERROR) << "Op should have " << kMinInputNum << " input tensor at least"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (node->outputIndex.size() != kOutputNum) { | |||||
| MS_LOG(ERROR) << "Op should have " << kOutputNum << " output tensor"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| STATUS status; | |||||
| iter = InsertFormatTransNode(graph, iter, kBefore, 0, beforeNodeType, &status); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "InsertNhwc2NchwNode before " << nodeName << "failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| iter = InsertFormatTransNode(graph, iter, kAfter, 0, afterNodeType, &status); | |||||
| if (status != RET_OK) { | |||||
| MS_LOG(ERROR) << "InsertNhwc2NchwNode after " << nodeName << "failed"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| NodeIter EltwiseFormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, | |||||
| InsertPlace place, size_t inoutIdx, FormatTransNodeType nodeType, | |||||
| STATUS *errorCode) { | |||||
| MS_ASSERT((*existNodeIter) != nullptr); | |||||
| auto existNodeName = (*existNodeIter)->name; | |||||
| std::string tileName; | |||||
| if (place == kBefore) { | |||||
| tileName = existNodeName + "_pre"; | |||||
| } else { | |||||
| tileName = existNodeName + "_post"; | |||||
| } | |||||
| auto transNode = std::make_unique<schema::CNodeT>(); | |||||
| transNode->primitive = std::make_unique<schema::PrimitiveT>(); | |||||
| if (nodeType == kNCHW2NHWC) { | |||||
| transNode->name = "nchw2nhwc_" + tileName + std::to_string(id++); | |||||
| transNode->primitive->value.type = schema::PrimitiveType_Nchw2Nhwc; | |||||
| } else { | |||||
| transNode->name = "nhwc2nchw_" + tileName + std::to_string(id++); | |||||
| transNode->primitive->value.type = schema::PrimitiveType_Nhwc2Nchw; | |||||
| } | |||||
| return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode); | |||||
| } | |||||
| void EltwiseFormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } | |||||
| void EltwiseFormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkType; } | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,56 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H | |||||
| #define MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H | |||||
| #include "tools/converter/optimizer.h" | |||||
| #include "tools/common/graph_util.h" | |||||
| #include "tools/converter/converter_flags.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| enum FormatTransNodeType { kNCHW2NHWC, kNHWC2NCHW }; | |||||
| class EltwiseFormatTransPass : public GraphPass { | |||||
| public: | |||||
| EltwiseFormatTransPass() : id(0) {} | |||||
| ~EltwiseFormatTransPass() override = default; | |||||
| STATUS Run(schema::MetaGraphT *graph) override; | |||||
| void SetQuantType(QuantType quantType); | |||||
| void SetFmk(converter::FmkType fmkType); | |||||
| private: | |||||
| STATUS DoModelInputFormatTrans(schema::MetaGraphT *graph); | |||||
| STATUS DoNodeInoutFormatTrans(schema::MetaGraphT *graph); | |||||
| NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx, | |||||
| FormatTransNodeType nodeType, STATUS *errorCode); | |||||
| private: | |||||
| size_t id; | |||||
| QuantType quantType = QuantType_QUANT_NONE; | |||||
| converter::FmkType fmkType = converter::FmkType_TF; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H | |||||
| @@ -20,7 +20,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <memory> | #include <memory> | ||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| #include "src/common/anf_importer/import_from_meta_graphT.h" | |||||
| #include "tools/anf_importer/import_from_meta_graphT.h" | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| @@ -12,7 +12,6 @@ add_library(quantizer_mid OBJECT | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc | ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc | ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc | ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc | ||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../../src/common/anf_exporter/anf_exporter.cc | |||||
| ) | ) | ||||
| if(ENABLE_ASAN) | if(ENABLE_ASAN) | ||||
| @@ -27,7 +27,7 @@ | |||||
| #include <fstream> | #include <fstream> | ||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| #include "src/ir/tensor.h" | #include "src/ir/tensor.h" | ||||
| #include "src/common/anf_exporter/anf_exporter.h" | |||||
| #include "tools/anf_exporter/anf_exporter.h" | |||||
| #include "tools/converter/quantizer/post_training_quantizer.h" | #include "tools/converter/quantizer/post_training_quantizer.h" | ||||
| #include "tools/converter/quantizer/quantize_util.h" | #include "tools/converter/quantizer/quantize_util.h" | ||||
| #include "src/common/common.h" | #include "src/common/common.h" | ||||
| @@ -283,7 +283,7 @@ void CheckLeastInputSize(const CNodePtr &node, const int size) { | |||||
| } | } | ||||
| } | } | ||||
| AnfNodePtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, | |||||
| ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, | |||||
| const ParamValueLitePtr &weight_tensor) { | const ParamValueLitePtr &weight_tensor) { | ||||
| auto bias_parameter = func_graph->add_parameter(); | auto bias_parameter = func_graph->add_parameter(); | ||||
| MS_ASSERT(bias_parameter != nullptr); | MS_ASSERT(bias_parameter != nullptr); | ||||
| @@ -47,7 +47,7 @@ void CheckIfNodeIsParam(const AnfNodePtr &node); | |||||
| void CheckLeastInputSize(const CNodePtr &node, int size); | void CheckLeastInputSize(const CNodePtr &node, int size); | ||||
| AnfNodePtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, | |||||
| ParameterPtr AddNewBiasNode(float *bias_data, const FuncGraphPtr &func_graph, int kernel_num, | |||||
| const ParamValueLitePtr &weight_tensor); | const ParamValueLitePtr &weight_tensor); | ||||
| schema::PrimitiveType GetCNodeType(const BaseRef &node); | schema::PrimitiveType GetCNodeType(const BaseRef &node); | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include "schema/inner/model_generated.h" | #include "schema/inner/model_generated.h" | ||||
| #include "tools/optimizer/common/gllo_utils.h" | #include "tools/optimizer/common/gllo_utils.h" | ||||
| #include "src/kernel_factory.h" | #include "src/kernel_factory.h" | ||||
| #include "src/common/anf_exporter/anf_exporter.h" | |||||
| #include "tools/anf_exporter/anf_exporter.h" | |||||
| #include "src/scheduler.h" | #include "src/scheduler.h" | ||||
| #include "include/context.h" | #include "include/context.h" | ||||
| #include "src/lite_session.h" | #include "src/lite_session.h" | ||||
| @@ -38,7 +38,7 @@ const std::vector<Tensor *> GetCNodeInputTensors(const CNodePtr &CNode) { | |||||
| auto tmp_meta_graph = std::make_unique<schema::MetaGraphT>(); | auto tmp_meta_graph = std::make_unique<schema::MetaGraphT>(); | ||||
| auto tmp_fb_node = std::make_unique<schema::CNodeT>(); | auto tmp_fb_node = std::make_unique<schema::CNodeT>(); | ||||
| lite::AnfExporter anfExporter; | lite::AnfExporter anfExporter; | ||||
| anfExporter.SetOpInputNode(CNode, tmp_meta_graph.get(), tmp_fb_node.get()); | |||||
| anfExporter.SetOpInputNode(CNode, tmp_meta_graph, tmp_fb_node.get()); | |||||
| std::vector<Tensor *> input_tensors; | std::vector<Tensor *> input_tensors; | ||||
| for (auto input_index : tmp_fb_node->inputIndex) { | for (auto input_index : tmp_fb_node->inputIndex) { | ||||
| auto tensorT = tmp_meta_graph->allTensors.at(input_index).get(); | auto tensorT = tmp_meta_graph->allTensors.at(input_index).get(); | ||||
| @@ -33,8 +33,8 @@ constexpr size_t kConvWithBiasLen = 4; | |||||
| bool IsConvExtendNode(const BaseRef &n) { | bool IsConvExtendNode(const BaseRef &n) { | ||||
| if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) { | if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) { | ||||
| auto type = opt::GetCNodeType(n); | auto type = opt::GetCNodeType(n); | ||||
| return type == schema::PrimitiveType_Conv2D || type == schema::PrimitiveType_DepthwiseConv2D | |||||
| || type == schema::PrimitiveType_DeConv2D; | |||||
| return type == schema::PrimitiveType_Conv2D || type == schema::PrimitiveType_DepthwiseConv2D || | |||||
| type == schema::PrimitiveType_DeConv2D; | |||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -59,8 +59,8 @@ int Get_Kenrnel_nums(const CNodePtr &conv_node) { | |||||
| if (type == schema::PrimitiveType_Conv2D) { | if (type == schema::PrimitiveType_Conv2D) { | ||||
| return primitive->GetPrimitiveT()->value.AsConv2D()->channelOut; | return primitive->GetPrimitiveT()->value.AsConv2D()->channelOut; | ||||
| } else if (type == schema::PrimitiveType_DepthwiseConv2D) { | } else if (type == schema::PrimitiveType_DepthwiseConv2D) { | ||||
| return primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelMultiplier | |||||
| * primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelIn; | |||||
| return primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelMultiplier * | |||||
| primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelIn; | |||||
| } else if (type == schema::PrimitiveType_DeConv2D) { | } else if (type == schema::PrimitiveType_DeConv2D) { | ||||
| return primitive->GetPrimitiveT()->value.AsDeConv2D()->channelOut; | return primitive->GetPrimitiveT()->value.AsDeConv2D()->channelOut; | ||||
| } else { | } else { | ||||
| @@ -83,16 +83,16 @@ void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, c | |||||
| if (kernel_nums <= 0) { | if (kernel_nums <= 0) { | ||||
| MS_LOG(EXCEPTION) << "kernel num less than 0"; | MS_LOG(EXCEPTION) << "kernel num less than 0"; | ||||
| } | } | ||||
| auto add_bias_data = new(std::nothrow) float[kernel_nums]; | |||||
| auto add_bias_data = new (std::nothrow) float[kernel_nums]; | |||||
| auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX); | auto bias_add_weight = bias_node->input(kAddWEIGHTINDEX); | ||||
| CheckIfNodeIsParam(bias_add_weight); | CheckIfNodeIsParam(bias_add_weight); | ||||
| auto add_weight_param = bias_add_weight->cast<ParameterPtr>()->default_param(); | auto add_weight_param = bias_add_weight->cast<ParameterPtr>()->default_param(); | ||||
| auto add_weight_tensor = std::dynamic_pointer_cast<ParamValueLite>(add_weight_param); | auto add_weight_tensor = std::dynamic_pointer_cast<ParamValueLite>(add_weight_param); | ||||
| auto add_weight_data = reinterpret_cast<float *>(add_weight_tensor->tensor_addr()); | auto add_weight_data = reinterpret_cast<float *>(add_weight_tensor->tensor_addr()); | ||||
| auto add_weight_shape = add_weight_tensor->tensor_shape(); | auto add_weight_shape = add_weight_tensor->tensor_shape(); | ||||
| if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] ==1)) { | |||||
| for (size_t i = 0; i < kernel_nums; i++) { | |||||
| add_bias_data[i] = *add_weight_data; | |||||
| if (add_weight_shape.empty() || (add_weight_shape.size() == 1 && add_weight_shape[0] == 1)) { | |||||
| for (size_t i = 0; i < kernel_nums; i++) { | |||||
| add_bias_data[i] = *add_weight_data; | |||||
| } | } | ||||
| } else { | } else { | ||||
| if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) { | if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) { | ||||
| @@ -115,6 +115,7 @@ void GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, c | |||||
| auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param(); | auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param(); | ||||
| auto conv_weight_tensor = std::dynamic_pointer_cast<ParamValueLite>(conv_weight_param); | auto conv_weight_tensor = std::dynamic_pointer_cast<ParamValueLite>(conv_weight_param); | ||||
| auto conv_new_bias = AddNewBiasNode(add_bias_data, func_graph, kernel_nums, conv_weight_tensor); | auto conv_new_bias = AddNewBiasNode(add_bias_data, func_graph, kernel_nums, conv_weight_tensor); | ||||
| conv_new_bias->set_name(conv_node->fullname_with_scope() + "_bias"); | |||||
| conv_node->add_input(conv_new_bias); | conv_node->add_input(conv_new_bias); | ||||
| } | } | ||||
| } | } | ||||
| @@ -159,4 +160,3 @@ const AnfNodePtr ConvBiasaddFusion::Process(const FuncGraphPtr &func_graph, cons | |||||
| return conv_node; | return conv_node; | ||||
| } | } | ||||
| } // namespace mindspore::opt | } // namespace mindspore::opt | ||||
| @@ -44,8 +44,8 @@ int Get_Kenrnel_nums(const CNodePtr &conv_node) { | |||||
| if (type == schema::PrimitiveType_Conv2D) { | if (type == schema::PrimitiveType_Conv2D) { | ||||
| return primitive->GetPrimitiveT()->value.AsConv2D()->channelOut; | return primitive->GetPrimitiveT()->value.AsConv2D()->channelOut; | ||||
| } else if (type == schema::PrimitiveType_DepthwiseConv2D) { | } else if (type == schema::PrimitiveType_DepthwiseConv2D) { | ||||
| return primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelMultiplier | |||||
| * primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelIn; | |||||
| return primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelMultiplier * | |||||
| primitive->GetPrimitiveT()->value.AsDepthwiseConv2D()->channelIn; | |||||
| } else { | } else { | ||||
| MS_LOG(ERROR) << "Unsupported opType, " << type; | MS_LOG(ERROR) << "Unsupported opType, " << type; | ||||
| return 0; | return 0; | ||||
| @@ -74,8 +74,8 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co | |||||
| MS_LOG(ERROR) << "Unsupported conv node, " << conv_node->DebugString(); | MS_LOG(ERROR) << "Unsupported conv node, " << conv_node->DebugString(); | ||||
| return node; | return node; | ||||
| } | } | ||||
| auto trans_scale = new(std::nothrow) float[kernel_nums]; | |||||
| auto trans_bias = new(std::nothrow) float[kernel_nums]; | |||||
| auto trans_scale = new (std::nothrow) float[kernel_nums]; | |||||
| auto trans_bias = new (std::nothrow) float[kernel_nums]; | |||||
| GenTransParam(transform_node, kernel_nums, trans_scale, trans_bias); | GenTransParam(transform_node, kernel_nums, trans_scale, trans_bias); | ||||
| GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias); | GenNewConvTensor(func_graph, conv_node, kernel_nums, trans_scale, trans_bias); | ||||
| delete[] trans_bias; | delete[] trans_bias; | ||||
| @@ -93,8 +93,8 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co | |||||
| return pre_node; | return pre_node; | ||||
| } | } | ||||
| const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, | |||||
| float *trans_scale, float *trans_bias) const { | |||||
| const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, float *trans_scale, | |||||
| float *trans_bias) const { | |||||
| if (trans_scale == nullptr) { | if (trans_scale == nullptr) { | ||||
| MS_LOG(EXCEPTION) << "new transScale failed"; | MS_LOG(EXCEPTION) << "new transScale failed"; | ||||
| } | } | ||||
| @@ -112,8 +112,8 @@ const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, in | |||||
| } | } | ||||
| const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, | const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, | ||||
| int kernel_num, const float *trans_scale, const float *trans_bias) | |||||
| const { | |||||
| int kernel_num, const float *trans_scale, | |||||
| const float *trans_bias) const { | |||||
| MS_ASSERT(conv_node != nullptr); | MS_ASSERT(conv_node != nullptr); | ||||
| AnfNodePtr conv_weight_node = nullptr; | AnfNodePtr conv_weight_node = nullptr; | ||||
| AnfNodePtr conv_bias_node = nullptr; | AnfNodePtr conv_bias_node = nullptr; | ||||
| @@ -152,18 +152,19 @@ const { | |||||
| bias_data = reinterpret_cast<float *>(bias_tensor->tensor_addr()); | bias_data = reinterpret_cast<float *>(bias_tensor->tensor_addr()); | ||||
| bias_flag = true; | bias_flag = true; | ||||
| } else { | } else { | ||||
| bias_data = new(std::nothrow) float[kernel_num]; | |||||
| bias_data = new (std::nothrow) float[kernel_num]; | |||||
| } | } | ||||
| CalNewBiasTensor(bias_data, kernel_num, bias_flag, trans_scale, trans_bias); | CalNewBiasTensor(bias_data, kernel_num, bias_flag, trans_scale, trans_bias); | ||||
| if (!bias_flag) { | if (!bias_flag) { | ||||
| auto bias_node = AddNewBiasNode(bias_data, func_graph, kernel_num, weight_tensor); | auto bias_node = AddNewBiasNode(bias_data, func_graph, kernel_num, weight_tensor); | ||||
| bias_node->set_name(conv_node->fullname_with_scope() + "_bias"); | |||||
| conv_node->add_input(bias_node); | conv_node->add_input(bias_node); | ||||
| } | } | ||||
| } | } | ||||
| const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int kernel_size, | const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int kernel_size, | ||||
| const float *trans_scale) const { | const float *trans_scale) const { | ||||
| MS_ASSERT(weight_data != nullptr); | MS_ASSERT(weight_data != nullptr); | ||||
| auto tmp_weight_data = new(std::nothrow) float[kernel_num * kernel_size]; | |||||
| auto tmp_weight_data = new (std::nothrow) float[kernel_num * kernel_size]; | |||||
| MS_ASSERT(new_weight_data != nullptr); | MS_ASSERT(new_weight_data != nullptr); | ||||
| auto data_size = kernel_num * kernel_size * sizeof(float); | auto data_size = kernel_num * kernel_size * sizeof(float); | ||||
| if (0 != memset_s(tmp_weight_data, data_size, 0, data_size)) { | if (0 != memset_s(tmp_weight_data, data_size, 0, data_size)) { | ||||
| @@ -189,7 +190,7 @@ const void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_nu | |||||
| const float *trans_scale, const float *trans_bias) const { | const float *trans_scale, const float *trans_bias) const { | ||||
| MS_ASSERT(bias_data != nullptr); | MS_ASSERT(bias_data != nullptr); | ||||
| if (bias_flag) { | if (bias_flag) { | ||||
| auto tmp_bias_data = new(std::nothrow) float[kernel_num]; | |||||
| auto tmp_bias_data = new (std::nothrow) float[kernel_num]; | |||||
| if (EOK != memset_s(tmp_bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float))) { | if (EOK != memset_s(tmp_bias_data, kernel_num * sizeof(float), 0, kernel_num * sizeof(float))) { | ||||
| MS_LOG(EXCEPTION) << "memset bias data failed"; | MS_LOG(EXCEPTION) << "memset bias data failed"; | ||||
| } | } | ||||