From: @xu_anyue
Reviewed-by:
Signed-off-by:
tags/v1.1.0
@@ -536,7 +536,7 @@ build_lite()
           -DANDROID_STL=${ANDROID_STL} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
           -DPLATFORM_ARM64=on -DENABLE_NEON=on -DENABLE_FP16="off" \
           -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
-          -DSUPPORT_GPU=${LITE_ENABLE_GPU} -DSUPPORT_NPU=${LITE_ENABLE_NPU} \
+          -DSUPPORT_GPU=${LITE_ENABLE_GPU} -DSUPPORT_NPU=${LITE_ENABLE_NPU} -DENABLE_V0=on \
           -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
           -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp -DMS_VERSION_MAJOR=${VERSION_MAJOR} \
           -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \
@@ -548,7 +548,7 @@ build_lite()
           -DANDROID_STL=${ANDROID_STL} -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \
           -DPLATFORM_ARM32=on -DENABLE_NEON=on -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
           -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
-          -DSUPPORT_GPU=${ENABLE_GPU} -DSUPPORT_NPU=${ENABLE_NPU} \
+          -DSUPPORT_GPU=${ENABLE_GPU} -DSUPPORT_NPU=${ENABLE_NPU} -DENABLE_V0=on \
           -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
           -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp -DMS_VERSION_MAJOR=${VERSION_MAJOR} \
           -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} -DENABLE_VERBOSE=${ENABLE_VERBOSE} \
@@ -557,7 +557,7 @@ build_lite()
       cmake -DPLATFORM_ARM64=off -DSUPPORT_TRAIN=${SUPPORT_TRAIN} \
           -DENABLE_TOOLS=${ENABLE_TOOLS} -DENABLE_CONVERTER=${ENABLE_CONVERTER} -DBUILD_TESTCASES=${RUN_TESTCASES} \
           -DCMAKE_BUILD_TYPE=${BUILD_TYPE} -DSUPPORT_GPU=${ENABLE_GPU} -DSUPPORT_NPU=${ENABLE_NPU} \
-          -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} \
+          -DBUILD_MINDDATA=${COMPILE_MINDDATA_LITE} -DENABLE_V0=on \
           -DOFFLINE_COMPILE=${OPENCL_OFFLINE_COMPILE} -DCMAKE_INSTALL_PREFIX=${BASEPATH}/output/tmp \
           -DMS_VERSION_MAJOR=${VERSION_MAJOR} -DMS_VERSION_MINOR=${VERSION_MINOR} -DMS_VERSION_REVISION=${VERSION_REVISION} \
           -DENABLE_VERBOSE=${ENABLE_VERBOSE} -DX86_64_SIMD=${X86_64_SIMD} "${BASEPATH}/mindspore/lite"
@@ -19,9 +19,12 @@
 #include <vector>
 #include <string>
 #include <memory>
-#include "schema/model_generated.h"
 #include "include/ms_tensor.h"
+namespace mindspore::schema {
+struct Tensor;
+}  // namespace mindspore::schema
 namespace mindspore::lite {
 /// \brief Allocator defined a memory pool for malloc memory and free memory dynamically.
 ///
@@ -35,7 +38,7 @@ using TensorPtrVector = std::vector<mindspore::schema::Tensor *>;
 using DeviceContextVector = std::vector<DeviceContext>;
 using Uint32Vector = std::vector<uint32_t>;
 using String = std::string;
-using NodeType = schema::NodeType;
+using NodeType = int; /**< 0 : NodeType_ValueNode, 1 : NodeType_Parameter, 2 : NodeType_CNode. */
 using AllocatorPtr = std::shared_ptr<Allocator>;
 /// \brief Set data of MSTensor from string vector.
@@ -53,13 +53,10 @@ struct MS_API Model {
   static Model *Import(const char *model_buf, size_t size);
   /// \brief Free meta graph temporary buffer
-  virtual void Free();
-  /// \brief Free all temporay buffer.EG: nodes in the model.
-  void Destroy();
+  virtual void Free() = 0;
   /// \brief Model destruct, free all memory
-  virtual ~Model();
+  virtual ~Model() = default;
 };
 }  // namespace mindspore::lite
@@ -1,4 +1,7 @@
 add_compile_definitions(USE_ANDROID_LOG)
+if (ENABLE_V0)
+    add_definitions(-DENABLE_V0)
+endif()
 set(LITE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/..)
 include_directories(${LITE_DIR}/nnacl/)
 include_directories(${LITE_DIR}/nnacl/optimize)
@@ -29,13 +32,12 @@ set(LITE_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/tensorlist.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/executor.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/inner_context.cc
-    ${CMAKE_CURRENT_SOURCE_DIR}/model_common.cc
+    ${CMAKE_CURRENT_SOURCE_DIR}/lite_model.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/kernel_registry.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/lite_kernel.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/sub_graph_kernel.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/lite_session.cc
-    ${CMAKE_CURRENT_SOURCE_DIR}/model.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/errorcode.cc
 )
@@ -0,0 +1,44 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_COMMON_VERSION_MANAGER_H_
+#define MINDSPORE_LITE_SRC_COMMON_VERSION_MANAGER_H_
+#include <string>
+#include "src/lite_model.h"
+namespace mindspore {
+namespace lite {
+class VersionManager {
+ public:
+  static VersionManager *GetInstance() {
+    static VersionManager instance;
+    return &instance;
+  }
+  virtual ~VersionManager() = default;
+  void SetSchemaVersion(const int schema_version) { schema_version_ = schema_version; }
+  int GetSchemaVersion() const { return schema_version_; }
+ private:
+  VersionManager() = default;
+ private:
+  int schema_version_ = SCHEMA_VERSION::SCHEMA_CUR;
+};
+}  // namespace lite
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_SRC_COMMON_VERSION_MANAGER_H_
@@ -13,15 +13,115 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include "src/model_common.h"
+#include "src/lite_model.h"
+#include <vector>
+#include <set>
+#include <unordered_map>
 #include "src/ops/while.h"
+#ifdef ENABLE_V0
+#include "src/ops/compat/compat_register.h"
+#endif
 namespace mindspore::lite {
-int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) {
-  if (model == nullptr) {
-    MS_LOG(ERROR) << "model is null.";
+#ifdef ENABLE_V0
+int LiteModel::ConvertAttrs(Model::Node *node, const schema::v0::Primitive *prim,
+                            std::vector<schema::Tensor *> *dst_tensor) {
+  if (node == nullptr || dst_tensor == nullptr) {
+    MS_LOG(ERROR) << "node or tensor_vec is nullptr.";
     return RET_ERROR;
   }
+  int primitive_type = prim->value_type();
+  auto creator = CompatRegistry::GetInstance()->GetTransferAttrFunc(SCHEMA_VERSION::SCHEMA_V0, primitive_type);
+  if (creator == nullptr) {
+    MS_LOG(DEBUG) << "the node don't need to convert attr to tensor.";
+    return RET_OK;
+  }
+  int status = creator(reinterpret_cast<const void *>(prim), node, dst_tensor, &this->attr_tensor_bufs_);
+  if (status != RET_OK && status != RET_NO_CHANGE) {
+    MS_LOG(ERROR) << "translate attr to tensor failed.";
+    return status;
+  }
+  return RET_OK;
+}
+int LiteModel::ConvertAttrToTensors(const void *meta_graph) {
+  MS_ASSERT(meta_graph != nullptr);
+  int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
+  if (schema_version != SCHEMA_VERSION::SCHEMA_V0) {
+    MS_LOG(DEBUG) << "no need to convert attr to tensor.";
+    return RET_OK;
+  }
+  auto meta_graph_v0 = reinterpret_cast<const schema::v0::MetaGraph *>(meta_graph);
+  std::unordered_map<int, std::set<int>> subgraph_node_indexes;
+  for (size_t subgraph_index = 0; subgraph_index < this->sub_graphs_.size(); ++subgraph_index) {
+    for (size_t node_index = 0; node_index < this->sub_graphs_[subgraph_index]->node_indices_.size(); ++node_index) {
+      subgraph_node_indexes[subgraph_index].insert(this->sub_graphs_[subgraph_index]->node_indices_[node_index]);
+    }
+  }
+  int cur_all_tensors_size = this->all_tensors_.size();
+  for (size_t index = 0; index < this->all_nodes_.size(); ++index) {
+    std::vector<schema::Tensor *> dst_tensors;
+    auto prim = meta_graph_v0->nodes()->GetAs<schema::v0::CNode>(index)->primitive();
+    int status = ConvertAttrs(this->all_nodes_[index], prim, &dst_tensors);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "fail to convert attr to tensor.";
+      return RET_ERROR;
+    }
+    if (dst_tensors.empty()) {
+      continue;
+    }
+    std::vector<int> subgraphs_with_node;
+    for (size_t subgraph_index = 0; subgraph_index < this->sub_graphs_.size(); ++subgraph_index) {
+      if (subgraph_node_indexes[subgraph_index].find(index) == subgraph_node_indexes[subgraph_index].end()) {
+        continue;
+      }
+      subgraphs_with_node.push_back(subgraph_index);
+    }
+    for (auto tensor : dst_tensors) {
+      for (auto subgraph_index : subgraphs_with_node) {
+        this->sub_graphs_[subgraph_index]->tensor_indices_.push_back(cur_all_tensors_size);
+      }
+      this->all_nodes_[index]->input_indices_.push_back(cur_all_tensors_size++);
+      this->all_tensors_.push_back(tensor);
+    }
+  }
+  return RET_OK;
+}
+#endif
+void LiteModel::Free() {
+  if (this->buf != nullptr) {
+    free(this->buf);
+    this->buf = nullptr;
+  }
+  for (auto &tensor_buf : attr_tensor_bufs_) {
+    free(tensor_buf);
+  }
+  attr_tensor_bufs_.resize(0);
+}
+LiteModel::~LiteModel() {
+  Free();
+  auto nodes_size = this->all_nodes_.size();
+  for (size_t i = 0; i < nodes_size; ++i) {
+    auto node = this->all_nodes_[i];
+    MS_ASSERT(node != nullptr);
+    MS_ASSERT(node->primitive_ != nullptr);
+    delete node->primitive_;
+    node->primitive_ = nullptr;
+    delete node;
+  }
+  this->all_nodes_.clear();
+  auto sub_graph_size = this->sub_graphs_.size();
+  for (size_t i = 0; i < sub_graph_size; ++i) {
+    auto sub_graph = this->sub_graphs_[i];
+    delete sub_graph;
+  }
+}
+int LiteModel::ConvertSubGraph(const schema::SubGraph &sub_graph) {
   if (sub_graph.name() == nullptr || sub_graph.inputIndices() == nullptr || sub_graph.outputIndices() == nullptr ||
       sub_graph.nodeIndices() == nullptr || sub_graph.tensorIndices() == nullptr) {
     MS_LOG(ERROR) << "sub_graph is invalid.";
@@ -51,28 +151,31 @@ int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model) {
   for (uint32_t i = 0; i < tensor_count; ++i) {
     subgraph->tensor_indices_.push_back(sub_graph.tensorIndices()->Get(i));
   }
-  model->sub_graphs_.push_back(subgraph);
+  this->sub_graphs_.push_back(subgraph);
   return RET_OK;
 }
-int VersionVerify(flatbuffers::Verifier *verify) {
+int LiteModel::VersionVerify(flatbuffers::Verifier *verify) const {
   if (verify == nullptr) {
     MS_LOG(ERROR) << "verify is null.";
     return RET_ERROR;
   }
   if (schema::VerifyMetaGraphBuffer(*verify)) {
     return SCHEMA_VERSION::SCHEMA_CUR;
-  } else if (schema::v0::VerifyMetaGraphBuffer(*verify)) {
+  }
+#ifdef ENABLE_V0
+  if (schema::v0::VerifyMetaGraphBuffer(*verify)) {
     return SCHEMA_VERSION::SCHEMA_V0;
   }
+#endif
   return SCHEMA_VERSION::SCHEMA_INVALID;
 }
-int NodeVerify(const Model &model) {
-  auto tensor_size = model.all_tensors_.size();
-  uint32_t subGraph_size = model.sub_graphs_.size();
+int LiteModel::NodeVerify() const {
+  auto tensor_size = this->all_tensors_.size();
+  uint32_t subGraph_size = this->sub_graphs_.size();
-  for (auto &node : model.all_nodes_) {
+  for (auto &node : this->all_nodes_) {
     if (node == nullptr || node->primitive_ == nullptr) {
       MS_LOG(ERROR) << "node or its primitive_ is null.";
       return RET_ERROR;
@@ -105,11 +208,11 @@ int NodeVerify(const Model &model) {
   return RET_OK;
 }
-int SubGraphVerify(const Model &model) {
-  auto tensor_size = model.all_tensors_.size();
-  auto node_size = model.all_nodes_.size();
+int LiteModel::SubGraphVerify() const {
+  auto tensor_size = this->all_tensors_.size();
+  auto node_size = this->all_nodes_.size();
-  for (auto &graph : model.sub_graphs_) {
+  for (auto &graph : this->sub_graphs_) {
     if (graph == nullptr) {
       MS_LOG(ERROR) << "graph is null.";
       return RET_ERROR;
@@ -138,49 +241,78 @@ int SubGraphVerify(const Model &model) {
   return RET_OK;
 }
-bool ModelVerify(const Model &model) { return NodeVerify(model) == RET_OK && SubGraphVerify(model) == RET_OK; }
+bool LiteModel::ModelVerify() const { return NodeVerify() == RET_OK && SubGraphVerify() == RET_OK; }
-const void *GetMetaGraphByVerison(const char *buf, const int &schema_version) {
-  if (buf == nullptr) {
-    MS_LOG(ERROR) << "buf is null.";
-    return nullptr;
-  }
+const void *LiteModel::GetMetaGraphByVerison() {
+  MS_ASSERT(this->buf != nullptr);
+  auto schema_version = VersionManager::GetInstance()->GetSchemaVersion();
   if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) {
-    return reinterpret_cast<const void *>(schema::GetMetaGraph(buf));
-  } else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) {
+    return reinterpret_cast<const void *>(schema::GetMetaGraph(this->buf));
+  }
+#ifdef ENABLE_V0
+  if (schema_version == SCHEMA_VERSION::SCHEMA_V0) {
     return reinterpret_cast<const void *>(schema::v0::GetMetaGraph(buf));
   }
+#endif
   return nullptr;
 }
-int GenerateModelByVersion(const void *meta_graph, Model *model, const int &schema_version) {
-  if (meta_graph == nullptr || model == nullptr) {
-    MS_LOG(ERROR) << "meta_graph or model is null.";
-    return RET_ERROR;
-  }
+int LiteModel::GenerateModelByVersion(const void *meta_graph) {
+  MS_ASSERT(meta_graph != nullptr);
+  auto schema_version = VersionManager::GetInstance()->GetSchemaVersion();
   int status = RET_ERROR;
   if (schema_version == SCHEMA_VERSION::SCHEMA_CUR) {
-    status = GenerateModel<schema::MetaGraph, schema::CNode>(*reinterpret_cast<const schema::MetaGraph *>(meta_graph),
-                                                             model, schema_version);
-  } else if (schema_version == SCHEMA_VERSION::SCHEMA_V0) {
+    status = GenerateModel<schema::MetaGraph, schema::CNode>(*reinterpret_cast<const schema::MetaGraph *>(meta_graph));
+  }
+#ifdef ENABLE_V0
+  if (schema_version == SCHEMA_VERSION::SCHEMA_V0) {
     status = GenerateModel<schema::v0::MetaGraph, schema::v0::CNode>(
-      *reinterpret_cast<const schema::v0::MetaGraph *>(meta_graph), model, schema_version);
+      *reinterpret_cast<const schema::v0::MetaGraph *>(meta_graph));
   }
+#endif
   return status;
 }
-Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
-  if (model_buf == nullptr) {
-    MS_LOG(ERROR) << "The model buf is nullptr";
-    return nullptr;
+int LiteModel::ConstructModel() {
+  if (this->buf == nullptr || this->buf_size_ <= 0) {
+    MS_LOG(ERROR) << "cannot construct model.";
+    return RET_NULL_PTR;
   }
-  flatbuffers::Verifier verify((const uint8_t *)model_buf, size);
+  flatbuffers::Verifier verify((const uint8_t *)this->buf, this->buf_size_);
   int schema_version = VersionVerify(&verify);
   if (schema_version == SCHEMA_INVALID) {
     MS_LOG(ERROR) << "The buffer is invalid and fail to create graph.";
+    return RET_ERROR;
+  }
+  VersionManager::GetInstance()->SetSchemaVersion(schema_version);
+  const void *meta_graph = GetMetaGraphByVerison();
+  if (meta_graph == nullptr) {
+    MS_LOG(ERROR) << "meta_graph is nullptr!";
+    return RET_NULL_PTR;
+  }
+  int status = GenerateModelByVersion(meta_graph);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "fail to generate model";
+    return status;
+  }
+  if (this->version_ != Version()) {
+    MS_LOG(WARNING) << "model version is " << this->version_ << ", inference version is " << Version() << " not equal";
+  }
+  if (this->sub_graphs_.empty()) {
+    return RET_ERROR;
+  }
+  return ModelVerify() ? RET_OK : RET_ERROR;
+}
+Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
+  if (model_buf == nullptr) {
+    MS_LOG(ERROR) << "The model buf is nullptr";
     return nullptr;
   }
-  auto *model = new (std::nothrow) Model();
+  auto *model = new (std::nothrow) LiteModel();
   if (model == nullptr) {
     MS_LOG(ERROR) << "new model fail!";
     return nullptr;
   }
@@ -201,28 +333,15 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
     }
     memcpy(model->buf, model_buf, size);
   }
-  const void *meta_graph = GetMetaGraphByVerison(model->buf, schema_version);
-  if (meta_graph == nullptr) {
-    MS_LOG(ERROR) << "meta_graph is nullptr!";
-    delete (model);
-    return nullptr;
-  }
-  int status = GenerateModelByVersion(meta_graph, model, schema_version);
+  model->buf_size_ = size;
+  auto status = model->ConstructModel();
   if (status != RET_OK) {
-    delete (model);
-    MS_LOG(ERROR) << "fail to generate model";
+    MS_LOG(ERROR) << "construct model failed.";
+    delete model;
     return nullptr;
   }
-  if (model->version_ != Version()) {
-    MS_LOG(WARNING) << "model version is " << model->version_ << ", inference version is " << Version() << " not equal";
-  }
-  if (model->sub_graphs_.empty()) {
-    delete (model);
-    return nullptr;
-  }
-  return ModelVerify(*model) ? model : nullptr;
+  return model;
 }
+Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); }
 }  // namespace mindspore::lite
@@ -0,0 +1,223 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_SRC_LITE_MODEL_H_
+#define MINDSPORE_LITE_SRC_LITE_MODEL_H_
+#include <string>
+#include <vector>
+#include "include/model.h"
+#include "src/ops/primitive_c.h"
+#include "include/version.h"
+#include "schema/model_generated.h"
+#include "src/common/common.h"
+#include "src/common/version_manager.h"
+#ifndef PRIMITIVE_WRITEABLE
+#include "src/ops/ops_register.h"
+#endif
+#ifdef ENABLE_V0
+#include "schema/model_v0_generated.h"
+#endif
+namespace mindspore {
+namespace lite {
+class LiteModel : public Model {
+ public:
+  int ConstructModel();
+  bool ModelVerify() const;
+  void Free() override;
+  ~LiteModel() override;
+ private:
+#ifdef ENABLE_V0
+  int ConvertAttrs(Model::Node *node, const schema::v0::Primitive *prim, std::vector<schema::Tensor *> *dst_tensor);
+  int ConvertAttrToTensors(const void *meta_graph);
+#endif
+  template <typename T = schema::MetaGraph, typename U = schema::CNode>
+  bool ConvertNodes(const T &meta_graph) {
+    if (meta_graph.nodes() == nullptr) {
+      MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
+      return false;
+    }
+    for (size_t i = 0; i < meta_graph.nodes()->size(); ++i) {
+      auto *node = new (std::nothrow) Model::Node();
+      if (node == nullptr) {
+        MS_LOG(ERROR) << "new node fail!";
+        return false;
+      }
+      auto c_node = meta_graph.nodes()->template GetAs<U>(i);
+      auto src_prim = reinterpret_cast<const schema::Primitive *>(c_node->primitive());
+#ifdef PRIMITIVE_WRITEABLE
+      node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
+#else
+      auto primitive = const_cast<schema::Primitive *>(src_prim);
+      auto func_pointer = OpsRegistry::GetInstance()->GetPrimitiveCreator(primitive->value_type());
+      if (func_pointer == nullptr) {
+        MS_LOG(ERROR) << "PrimitiveCreator function pointer is nullptr, type: "
+                      << schema::EnumNamePrimitiveType(primitive->value_type());
+        delete node;
+        return false;
+      }
+      node->primitive_ = func_pointer(primitive);
+#endif
+      if (node->primitive_ == nullptr) {
+        MS_LOG(ERROR) << "unpack primitive == nullptr!";
+        delete node;
+        return false;
+      }
+      node->primitive_->set_quant_type(static_cast<schema::QuantType>(c_node->quantType()));
+      node->name_ = c_node->name()->c_str();
+      node->node_type_ = static_cast<NodeType>(c_node->nodeType());
+      auto count = c_node->inputIndex()->size();
+      for (uint32_t j = 0; j < count; ++j) {
+        node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j)));
+      }
+      if (c_node->outputIndex() != nullptr) {
+        count = c_node->outputIndex()->size();
+        for (uint32_t j = 0; j < count; ++j) {
+          node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j)));
+        }
+      }
+      this->all_nodes_.push_back(node);
+    }
+    return true;
+  }
+  template <typename T = schema::MetaGraph>
+  bool ConvertTensors(const T &meta_graph) {
+    if (meta_graph.allTensors() == nullptr) {
+      MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
+      return false;
+    }
+    auto tensor_count = meta_graph.allTensors()->size();
+    for (uint32_t i = 0; i < tensor_count; ++i) {
+      auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i);
+      if (tensor == nullptr) {
| MS_LOG(ERROR) << i << "the tensor in metagraph is nullptr"; | |||||
+        return false;
+      }
+      this->all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor));
+    }
+    return true;
+  }
+  template <typename T = schema::MetaGraph>
+  int MetaGraphMappingSubGraph(const T &meta_graph) {
+    if (meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr || meta_graph.nodes() == nullptr ||
+        meta_graph.allTensors() == nullptr) {
+      MS_LOG(ERROR) << "meta_graph is invalid, please check your model file.";
+      return RET_ERROR;
+    }
+    auto *subgraph = new (std::nothrow) Model::SubGraph();
+    if (subgraph == nullptr) {
+      MS_LOG(ERROR) << "new subGraph fail!";
+      return RET_ERROR;
+    }
+    if (meta_graph.name() != nullptr) {
+      subgraph->name_ = meta_graph.name()->c_str();
+    }
+    auto in_count = meta_graph.inputIndex()->size();
+    for (uint32_t i = 0; i < in_count; ++i) {
+      subgraph->input_indices_.push_back(size_t(meta_graph.inputIndex()->template GetAs<uint32_t>(i)));
+    }
+    auto out_count = meta_graph.outputIndex()->size();
+    for (uint32_t i = 0; i < out_count; ++i) {
+      subgraph->output_indices_.push_back(size_t(meta_graph.outputIndex()->template GetAs<uint32_t>(i)));
+    }
+    auto node_count = meta_graph.nodes()->size();
+    for (uint32_t i = 0; i < node_count; ++i) {
+      subgraph->node_indices_.push_back(i);
+    }
+    auto tensor_count = meta_graph.allTensors()->size();
+    for (uint32_t i = 0; i < tensor_count; ++i) {
+      subgraph->tensor_indices_.push_back(i);
+    }
+    this->sub_graphs_.push_back(subgraph);
+    return RET_OK;
+  }
+  template <typename T = schema::MetaGraph, typename U = schema::CNode>
+  int GenerateModel(const T &meta_graph) {
+    if (meta_graph.name() != nullptr) {
+      this->name_ = meta_graph.name()->c_str();
+    }
+    if (meta_graph.version() != nullptr) {
+      this->version_ = meta_graph.version()->c_str();
+    }
+    if (!ConvertNodes<T, U>(meta_graph)) {
+      MS_LOG(ERROR) << "convert node failed";
+      return RET_ERROR;
+    }
+    if (!ConvertTensors<T>(meta_graph)) {
+      MS_LOG(ERROR) << "convert tensor failed";
+      return RET_ERROR;
+    }
+    if (meta_graph.subGraph() == nullptr) {
+      int ret = MetaGraphMappingSubGraph<T>(meta_graph);
+      if (ret != RET_OK) {
+        MS_LOG(ERROR) << "converter old version model wrong.";
+        return ret;
+      }
+    } else {
+      auto sub_graphs = meta_graph.subGraph();
+      auto sub_graph_size = sub_graphs->size();
+      for (size_t i = 0; i < sub_graph_size; i++) {
+        auto sub_graph = sub_graphs->template GetAs<schema::SubGraph>(i);
+        int ret = ConvertSubGraph(*sub_graph);
+        if (ret != RET_OK) {
+          MS_LOG(ERROR) << "converter subgraph wrong.";
+          return ret;
+        }
+      }
+    }
+#ifdef ENABLE_V0
+    if (ConvertAttrToTensors(&meta_graph) != RET_OK) {
+      MS_LOG(ERROR) << "fail to convert attr to tensor.";
+      return RET_ERROR;
+    }
+#endif
+    return RET_OK;
+  }
+  int VersionVerify(flatbuffers::Verifier *verify) const;
+  const void *GetMetaGraphByVerison();
+  int GenerateModelByVersion(const void *meta_graph);
+  int ConvertSubGraph(const schema::SubGraph &sub_graph);
+  int NodeVerify() const;
+  int SubGraphVerify() const;
+ public:
+  size_t buf_size_ = 0;
+ protected:
+  std::vector<char *> attr_tensor_bufs_;
+};
+Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf);
+}  // namespace lite
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_SRC_LITE_MODEL_H_
@@ -26,7 +26,7 @@
 #include "src/common/utils.h"
 #include "src/common/graph_util.h"
 #include "src/kernel_registry.h"
-#include "src/model_common.h"
+#include "src/lite_model.h"
 #include "src/runtime/kernel/arm/base/dequant.h"
 #if SUPPORT_NPU
 #include "src/runtime/agent/npu/npu_manager.h"
@@ -363,7 +363,7 @@ int LiteSession::CompileGraph(Model *model) {
     is_running_.store(false);
     return RET_PARAM_INVALID;
   }
-  if (!ModelVerify(*model)) {
+  if (!reinterpret_cast<LiteModel *>(model)->ModelVerify()) {
     MS_LOG(ERROR) << "wrong model input, please check";
     is_running_.store(false);
     return RET_ERROR;
@@ -1,52 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "src/ops/primitive_c.h"
-#include "include/model.h"
-#include "src/common/log_adapter.h"
-#include "src/model_common.h"
-namespace mindspore::lite {
-Model *Model::Import(const char *model_buf, size_t size) { return ImportFromBuffer(model_buf, size, false); }
-void Model::Free() {
-  if (this->buf != nullptr) {
-    free(this->buf);
-    this->buf = nullptr;
-  }
-}
-void Model::Destroy() {
-  Free();
-  auto nodes_size = this->all_nodes_.size();
-  for (size_t i = 0; i < nodes_size; ++i) {
-    auto node = this->all_nodes_[i];
-    MS_ASSERT(node != nullptr);
-    MS_ASSERT(node->primitive_ != nullptr);
-    delete node->primitive_;
-    node->primitive_ = nullptr;
-    delete node;
-  }
-  this->all_nodes_.clear();
-  auto sub_graph_size = this->sub_graphs_.size();
-  for (size_t i = 0; i < sub_graph_size; ++i) {
-    auto sub_graph = this->sub_graphs_[i];
-    delete sub_graph;
-  }
-}
-Model::~Model() { Destroy(); }
-}  // namespace mindspore::lite
@@ -1,192 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_LITE_SRC_MODEL_COMMON_H_
-#define MINDSPORE_LITE_SRC_MODEL_COMMON_H_
-#include <string>
-#include "src/ops/primitive_c.h"
-#include "include/model.h"
-#include "include/version.h"
-#include "schema/model_generated.h"
-#include "schema/model_v0_generated.h"
-#include "src/common/common.h"
-#ifndef PRIMITIVE_WRITEABLE
-#include "src/ops/ops_register.h"
-#endif
-namespace mindspore::lite {
-int ConvertSubGraph(const schema::SubGraph &sub_graph, Model *model);
-template <typename T = schema::MetaGraph, typename U = schema::CNode>
-bool ConvertNodes(const T &meta_graph, Model *model, int schema_version = SCHEMA_CUR) {
-  if (model == nullptr || meta_graph.nodes() == nullptr) {
-    MS_LOG(ERROR) << "model or meta_graph is invalid, please check your model file.";
-    return false;
-  }
-  for (size_t i = 0; i < meta_graph.nodes()->size(); ++i) {
-    auto *node = new (std::nothrow) Model::Node();
-    if (node == nullptr) {
-      MS_LOG(ERROR) << "new node fail!";
-      return false;
-    }
-    auto c_node = meta_graph.nodes()->template GetAs<U>(i);
-    auto src_prim = reinterpret_cast<const schema::Primitive *>(c_node->primitive());
-#ifdef PRIMITIVE_WRITEABLE
-    node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim));
-#else
-    auto primitive = const_cast<schema::Primitive *>(src_prim);
-    auto func_pointer = OpsRegistry::GetInstance()->GetPrimitiveCreator(primitive->value_type());
-    if (func_pointer == nullptr) {
-      MS_LOG(ERROR) << "PrimitiveCreator function pointer is nullptr, type: "
-                    << schema::EnumNamePrimitiveType(primitive->value_type());
-      delete node;
-      return false;
-    }
-    node->primitive_ = func_pointer(primitive);
-#endif
-    if (node->primitive_ == nullptr) {
-      MS_LOG(ERROR) << "unpack primitive == nullptr!";
-      delete node;
-      return false;
-    }
-    node->primitive_->set_quant_type(static_cast<schema::QuantType>(c_node->quantType()));
-    node->name_ = c_node->name()->c_str();
-    node->node_type_ = static_cast<NodeType>(c_node->nodeType());
-    auto count = c_node->inputIndex()->size();
-    for (uint32_t j = 0; j < count; ++j) {
-      node->input_indices_.push_back(size_t(c_node->inputIndex()->template GetAs<uint32_t>(j)));
-    }
-    if (c_node->outputIndex() != nullptr) {
-      count = c_node->outputIndex()->size();
-      for (uint32_t j = 0; j < count; ++j) {
-        node->output_indices_.push_back(size_t(c_node->outputIndex()->template GetAs<uint32_t>(j)));
-      }
-    }
-    model->all_nodes_.push_back(node);
-  }
-  return true;
-}
-template <typename T = schema::MetaGraph>
-bool ConvertTensors(const T &meta_graph, Model *model) {
-  if (model == nullptr || meta_graph.allTensors() == nullptr) {
-    MS_LOG(ERROR) << "model or meta_graph is invalid, please check your model file.";
-    return false;
-  }
-  auto tensor_count = meta_graph.allTensors()->size();
-  for (uint32_t i = 0; i < tensor_count; ++i) {
-    auto *tensor = meta_graph.allTensors()->template GetAs<schema::Tensor>(i);
-    if (tensor == nullptr) {
-      MS_LOG(ERROR) << i << "th tensor in model is nullptr";
-      return false;
-    }
-    model->all_tensors_.push_back(const_cast<mindspore::schema::Tensor *>(tensor));
-  }
-  return true;
-}
-template <typename T = schema::MetaGraph>
-int MetaGraphMappingSubGraph(const T &meta_graph, Model *model) {
-  if (model == nullptr || meta_graph.inputIndex() == nullptr || meta_graph.outputIndex() == nullptr ||
-      meta_graph.nodes() == nullptr || meta_graph.allTensors() == nullptr) {
-    MS_LOG(ERROR) << "model or meta_graph is invalid, please check your model file.";
-    return RET_ERROR;
-  }
-  auto *subgraph = new (std::nothrow) Model::SubGraph();
-  if (subgraph == nullptr) {
-    MS_LOG(ERROR) << "new subGraph fail!";
-    return RET_ERROR;
-  }
-  if (meta_graph.name() != nullptr) {
-    subgraph->name_ = meta_graph.name()->c_str();
-  }
-  auto in_count = meta_graph.inputIndex()->size();
-  for (uint32_t i = 0; i < in_count; ++i) {
-    subgraph->input_indices_.push_back(size_t(meta_graph.inputIndex()->template GetAs<uint32_t>(i)));
-  }
-  auto out_count = meta_graph.outputIndex()->size();
-  for (uint32_t i = 0; i < out_count; ++i) {
-    subgraph->output_indices_.push_back(size_t(meta_graph.outputIndex()->template GetAs<uint32_t>(i)));
-  }
-  auto node_count = meta_graph.nodes()->size();
-  for (uint32_t i = 0; i < node_count; ++i) {
-    subgraph->node_indices_.push_back(i);
-  }
-  auto tensor_count = meta_graph.allTensors()->size();
-  for (uint32_t i = 0; i < tensor_count; ++i) {
-    subgraph->tensor_indices_.push_back(i);
-  }
-  model->sub_graphs_.push_back(subgraph);
-  return RET_OK;
-}
-template <typename T = schema::MetaGraph, typename U = schema::CNode>
-int GenerateModel(const T &meta_graph, Model *model, int schema_version = 0) {
-  if (model == nullptr) {
-    MS_LOG(ERROR) << "model is nullptr.";
-    return RET_ERROR;
-  }
-  if (meta_graph.name() != nullptr) {
-    model->name_ = meta_graph.name()->c_str();
-  }
-  if (meta_graph.version() != nullptr) {
-    model->version_ = meta_graph.version()->c_str();
-  }
-  if (!ConvertNodes<T, U>(meta_graph, model, schema_version)) {
-    MS_LOG(ERROR) << "convert node failed";
-    return RET_ERROR;
-  }
-  if (!ConvertTensors<T>(meta_graph, model)) {
-    MS_LOG(ERROR) << "convert tensor failed";
-    return RET_ERROR;
-  }
-  if (meta_graph.subGraph() == nullptr) {
-    int ret = MetaGraphMappingSubGraph<T>(meta_graph, model);
-    if (ret != RET_OK) {
-      MS_LOG(ERROR) << "converter old version model wrong.";
-      return ret;
-    }
-  } else {
-    auto sub_graphs = meta_graph.subGraph();
-    auto sub_graph_size = sub_graphs->size();
-    for (size_t i = 0; i < sub_graph_size; i++) {
-      auto sub_graph = sub_graphs->template GetAs<schema::SubGraph>(i);
-      int ret = ConvertSubGraph(*sub_graph, model);
-      if (ret != RET_OK) {
-        MS_LOG(ERROR) << "converter subgraph wrong.";
-        return ret;
-      }
-    }
-  }
-  return RET_OK;
-}
-int VersionVerify(flatbuffers::Verifier *verify);
-int NodeVerify(const Model &model);
-int SubGraphVerify(const Model &model);
-bool ModelVerify(const Model &model);
-const void *GetMetaGraphByVerison(const char *buf, const int &schema_version);
-int GenerateModelByVersion(const void *meta_graph, Model *model, const int &schema_version);
-Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf);
-}  // namespace mindspore::lite
-#endif  // MINDSPORE_LITE_SRC_MODEL_COMMON_H_
@@ -4,6 +4,10 @@ file(GLOB OPS_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/*.cc
     ${CMAKE_CURRENT_SOURCE_DIR}/populate/*.cc
     )
+if (ENABLE_V0)
+    file(GLOB_RECURSE COMPAT_SRC ${CMAKE_CURRENT_SOURCE_DIR}/compat/*.cc)
+    set(OPS_SRC ${OPS_SRC} ${COMPAT_SRC})
+endif ()
 add_library(cpu_ops_mid OBJECT ${OPS_SRC})
 add_dependencies(cpu_ops_mid fbs_src)
@@ -0,0 +1,65 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/ops/compat/attr_transfer_common.h"
+#include <vector>
+#include "src/common/log_adapter.h"
+namespace mindspore {
+namespace lite {
+schema::Tensor *AttrToTensor(void *data, int data_size, bool is_array, TypeId type_id,
+                             std::vector<char *> *tensor_bufs) {
+  if (data == nullptr || tensor_bufs == nullptr) {
+    MS_LOG(ERROR) << "the parameter of this function is nullptr.";
+    return nullptr;
+  }
+  auto dst_tensor =
+    (is_array ? new (std::nothrow) Tensor(type_id, {data_size}, schema::Format_NHWC, Tensor::Category::CONST_TENSOR)
+              : new (std::nothrow) Tensor(type_id, {}, schema::Format_NHWC, Tensor::Category::CONST_SCALAR));
+  auto dst_data = dst_tensor->MutableData();
+  if (dst_data == nullptr) {
+    MS_LOG(ERROR) << "Data from tensor is nullptr";
+    return nullptr;
+  }
+  std::vector<uint8_t> uint8_data;
+  uint8_data.resize(dst_tensor->Size());
+  memcpy(uint8_data.data(), data, dst_tensor->Size());
+  auto shape = dst_tensor->shape();
+  flatbuffers::FlatBufferBuilder fbb(1024);
+  auto tensor_offset = schema::CreateTensorDirect(fbb, schema::NodeType_ValueNode, type_id, &shape, schema::Format_NHWC,
+                                                  0, 0, &uint8_data);
+  fbb.Finish(tensor_offset);
+  delete dst_tensor;
+  auto buf = fbb.GetBufferPointer();
+  if (buf == nullptr) {
+    MS_LOG(ERROR) << "GetBufferPointer return nullptr";
+    fbb.Clear();
+    return nullptr;
+  }
+  auto tensor_buf = reinterpret_cast<char *>(malloc(fbb.GetSize()));
+  if (tensor_buf == nullptr) {
+    MS_LOG(ERROR) << "malloc primitive_buf_ failed";
+    fbb.Clear();
+    return nullptr;
+  }
+  memcpy(tensor_buf, buf, fbb.GetSize());
+  auto tensor = flatbuffers::GetRoot<schema::Tensor>(tensor_buf);
+  tensor_bufs->push_back(tensor_buf);
+  fbb.Clear();
+  return const_cast<schema::Tensor *>(tensor);
+}
+}  // namespace lite
+}  // namespace mindspore
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef LITE_MINDSPORE_LITE_C_OPS_OP_ATTR_TRANSFER_COMMON_H_
+#define LITE_MINDSPORE_LITE_C_OPS_OP_ATTR_TRANSFER_COMMON_H_
+#include <vector>
+#include "ir/dtype/type_id.h"
+#include "src/tensor.h"
+#include "include/errorcode.h"
+#include "schema/model_v0_generated.h"
+#include "src/common/common.h"
+#include "src/ops/compat/compat_register.h"
+namespace mindspore {
+namespace lite {
+schema::Tensor *AttrToTensor(void *data, int data_size, bool is_array, TypeId type_id,
+                             std::vector<char *> *tensor_bufs);
+}  // namespace lite
+}  // namespace mindspore
+#endif  // LITE_MINDSPORE_LITE_C_OPS_OP_ATTR_TRANSFER_COMMON_H_
@@ -0,0 +1,67 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef LITE_MINDSPORE_LITE_C_OPS_OP_COMPAT_REGISTER_H_
+#define LITE_MINDSPORE_LITE_C_OPS_OP_COMPAT_REGISTER_H_
+#include <unordered_map>
+#include <string>
+#include <vector>
+#include "include/model.h"
+#include "schema/model_generated.h"
+#include "src/common/log_adapter.h"
+namespace mindspore {
+namespace lite {
+// compatibility, transfer attr to input tensor.
+typedef int (*TransferAttrFunc)(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *tensor,
+                                std::vector<char *> *tensor_bufs);
+class CompatRegistry {
+ public:
+  static CompatRegistry *GetInstance() {
+    static CompatRegistry registry;
+    return &registry;
+  }
+  void InsertTransferAttrFuncMap(int schema_version, int primitive_type, TransferAttrFunc transfer_attr_func) {
+    int key = primitive_type * 10 + schema_version;
+    transfer_attr_funcs_[key] = transfer_attr_func;
+  }
+  TransferAttrFunc GetTransferAttrFunc(int schema_version, int primitive_type) {
+    int key = primitive_type * 10 + schema_version;
+    if (transfer_attr_funcs_.find(key) != transfer_attr_funcs_.end()) {
+      return transfer_attr_funcs_[key];
+    } else {
+      MS_LOG(DEBUG) << "Unsupported transformer type in Create : " << key;
+      return nullptr;
+    }
+  }
+ protected:
+  std::unordered_map<int, TransferAttrFunc> transfer_attr_funcs_;
+};
+class Register {
+ public:
+  Register(int schema_version, int primitive_type, TransferAttrFunc transfer_attr_func) {
+    CompatRegistry::GetInstance()->InsertTransferAttrFuncMap(schema_version, primitive_type, transfer_attr_func);
+  }
+  virtual ~Register() = default;
+};
+}  // namespace lite
+}  // namespace mindspore
+#endif  // LITE_MINDSPORE_LITE_C_OPS_OP_COMPAT_REGISTER_H_
@@ -0,0 +1,48 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/ops/compat/attr_transfer_common.h"
+namespace mindspore {
+namespace lite {
+int TransferBroadcastToAttr(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *dst_tensors,
+                            std::vector<char *> *tensor_bufs) {
+  if (primitive == nullptr || node == nullptr || dst_tensors == nullptr || tensor_bufs == nullptr) {
+    MS_LOG(ERROR) << "the parameter of this function is nullptr.";
+    return RET_ERROR;
+  }
+  if (node->input_indices_.size() != 1) {
+    MS_LOG(DEBUG) << "broadcast_to don't need to convert attr to tensor.";
+    return RET_OK;
+  }
+  dst_tensors->clear();
+  tensor_bufs->clear();
+  auto prim = reinterpret_cast<const schema::v0::Primitive *>(primitive);
+  auto dst_shape_attr = prim->value_as_BroadcastTo()->dst_shape();
+  std::vector<int> dst_shape = std::vector<int>(dst_shape_attr->begin(), dst_shape_attr->end());
+  auto dst_shape_tensor = AttrToTensor(dst_shape.data(), dst_shape.size(), true, kNumberTypeInt32, tensor_bufs);
+  if (dst_shape_tensor == nullptr) {
+    MS_LOG(ERROR) << "attr tensor is nullptr, transform is failed.";
+    return RET_NULL_PTR;
+  }
+  dst_tensors->push_back(dst_shape_tensor);
+  return RET_OK;
+}
+Register BroadcastToTransferRegistry(SCHEMA_VERSION::SCHEMA_V0, schema::v0::PrimitiveType_BroadcastTo,
+                                     TransferBroadcastToAttr);
+}  // namespace lite
+}  // namespace mindspore
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/ops/compat/attr_transfer_common.h"
+namespace mindspore {
+namespace lite {
+int TransferReshapeAttr(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *dst_tensors,
+                        std::vector<char *> *tensor_bufs) {
+  if (primitive == nullptr || node == nullptr || dst_tensors == nullptr || tensor_bufs == nullptr) {
+    MS_LOG(ERROR) << "the parameter of this function is nullptr.";
+    return RET_ERROR;
+  }
+  if (node->input_indices_.size() != 1) {
| MS_LOG(DEBUG) << "reshape need to convert attr to tensor."; | |||||
| return RET_OK; | |||||
| } | |||||
| dst_tensors->clear(); | |||||
| tensor_bufs->clear(); | |||||
| auto prim = reinterpret_cast<const schema::v0::Primitive *>(primitive); | |||||
| auto dst_shape_attr = prim->value_as_Reshape()->shape(); | |||||
| std::vector<int> dst_shape = std::vector<int>(dst_shape_attr->begin(), dst_shape_attr->end()); | |||||
| auto dst_shape_tensor = AttrToTensor(dst_shape.data(), dst_shape.size(), true, kNumberTypeInt32, tensor_bufs); | |||||
| if (dst_shape_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "attr tensor is nullptr, transform is failed."; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| dst_tensors->push_back(dst_shape_tensor); | |||||
| return RET_OK; | |||||
| } | |||||
| Register ReshapeTransferRegistry(SCHEMA_VERSION::SCHEMA_V0, schema::v0::PrimitiveType_Reshape, TransferReshapeAttr); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -0,0 +1,67 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/compat/attr_transfer_common.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| int TransferStridedSliceAttr(const void *primitive, Model::Node *node, std::vector<schema::Tensor *> *dst_tensors, | |||||
| std::vector<char *> *tensor_bufs) { | |||||
| if (primitive == nullptr || node == nullptr || dst_tensors == nullptr || tensor_bufs == nullptr) { | |||||
| MS_LOG(ERROR) << "the parameter of this function is nullptr."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| dst_tensors->clear(); | |||||
| tensor_bufs->clear(); | |||||
| auto prim = reinterpret_cast<const schema::v0::Primitive *>(primitive); | |||||
| int inputs_size = static_cast<int>(node->input_indices_.size()); | |||||
| switch (inputs_size) { | |||||
| case 1: { | |||||
| auto begins_attr = prim->value_as_StridedSlice()->begin(); | |||||
| std::vector<int> dst_begins = std::vector<int>(begins_attr->begin(), begins_attr->end()); | |||||
| auto dst_begins_tensor = AttrToTensor(dst_begins.data(), dst_begins.size(), true, kNumberTypeInt32, tensor_bufs); | |||||
| dst_tensors->push_back(dst_begins_tensor); | |||||
| }  // fall through: end and stride must also be converted from attrs | |||||
| case 2: { | |||||
| auto ends_attr = prim->value_as_StridedSlice()->end(); | |||||
| std::vector<int> dst_ends = std::vector<int>(ends_attr->begin(), ends_attr->end()); | |||||
| auto dst_ends_tensor = AttrToTensor(dst_ends.data(), dst_ends.size(), true, kNumberTypeInt32, tensor_bufs); | |||||
| dst_tensors->push_back(dst_ends_tensor); | |||||
| }  // fall through: stride must also be converted from attrs | |||||
| case 3: { | |||||
| auto strides_attr = prim->value_as_StridedSlice()->stride(); | |||||
| std::vector<int> dst_strides = std::vector<int>(strides_attr->begin(), strides_attr->end()); | |||||
| auto dst_strides_tensor = | |||||
| AttrToTensor(dst_strides.data(), dst_strides.size(), true, kNumberTypeInt32, tensor_bufs); | |||||
| dst_tensors->push_back(dst_strides_tensor); | |||||
| break; | |||||
| } | |||||
| default: { | |||||
| MS_LOG(DEBUG) << "stride_slice don't need to convert attr to tensor."; | |||||
| return RET_OK; | |||||
| } | |||||
| } | |||||
| if (std::any_of(dst_tensors->begin(), dst_tensors->end(), [](schema::Tensor *tensor) { return tensor == nullptr; })) { | |||||
| MS_LOG(ERROR) << "convert attr to tensor failed."; | |||||
| return RET_ERROR; | |||||
| } | |||||
| return RET_OK; | |||||
| } | |||||
| Register StridedSliceTransferRegistry(SCHEMA_VERSION::SCHEMA_V0, schema::v0::PrimitiveType_StridedSlice, | |||||
| TransferStridedSliceAttr); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
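The `switch` in `TransferStridedSliceAttr` above deliberately omits `break` in cases 1 and 2: with a single graph input, begin, end, and stride all have to be materialized from attributes; with two inputs only end and stride are missing; with three only stride. Here is a minimal, self-contained sketch of that cascading-fallthrough control flow, using simplified types, hypothetical names, and the C++17 `[[fallthrough]]` marker; it is not the actual MindSpore Lite code.

```cpp
#include <iostream>
#include <vector>

// Returns the attribute vectors that still need to become tensors, given how
// many of the node's operands are already supplied as real graph inputs.
std::vector<std::vector<int>> MaterializeMissingInputs(int inputs_size,
                                                       const std::vector<int> &begins,
                                                       const std::vector<int> &ends,
                                                       const std::vector<int> &strides) {
  std::vector<std::vector<int>> extra;
  switch (inputs_size) {
    case 1:
      extra.push_back(begins);
      [[fallthrough]];  // ends and strides are also missing
    case 2:
      extra.push_back(ends);
      [[fallthrough]];  // strides are also missing
    case 3:
      extra.push_back(strides);
      break;
    default:
      break;  // 4+ inputs: every operand is already a tensor, nothing to add
  }
  return extra;
}

int main() {
  // Two inputs present (data + begins), so ends and strides come from attrs.
  auto extra = MaterializeMissingInputs(2, {0, 0}, {4, 4}, {1, 1});
  std::cout << "tensors materialized from attrs: " << extra.size() << std::endl;  // prints 2
  return 0;
}
```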
| @@ -18,7 +18,6 @@ | |||||
| #include "src/common/log_adapter.h" | #include "src/common/log_adapter.h" | ||||
| #include "include/errorcode.h" | #include "include/errorcode.h" | ||||
| #include "src/common/graph_util.h" | #include "src/common/graph_util.h" | ||||
| #include "src/model_common.h" | |||||
| namespace mindspore::lite { | namespace mindspore::lite { | ||||
| @@ -27,12 +26,6 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) { | |||||
| MS_LOG(ERROR) << "The model buf is nullptr"; | MS_LOG(ERROR) << "The model buf is nullptr"; | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| flatbuffers::Verifier verify((const uint8_t *)model_buf, size); | |||||
| int schema_version = VersionVerify(&verify); | |||||
| if (schema_version == -1) { | |||||
| MS_LOG(ERROR) << "The model buffer is invalid, cannot get schema version"; | |||||
| return nullptr; | |||||
| } | |||||
| TrainModel *model = new (std::nothrow) TrainModel(); | TrainModel *model = new (std::nothrow) TrainModel(); | ||||
| if (model == nullptr) { | if (model == nullptr) { | ||||
| MS_LOG(ERROR) << "new model fail!"; | MS_LOG(ERROR) << "new model fail!"; | ||||
| @@ -46,19 +39,10 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) { | |||||
| } | } | ||||
| memcpy(model->buf, model_buf, size); | memcpy(model->buf, model_buf, size); | ||||
| model->buf_size_ = size; | model->buf_size_ = size; | ||||
| const void *meta_graph = GetMetaGraphByVerison(model->buf, schema_version); | |||||
| if (meta_graph == nullptr) { | |||||
| MS_LOG(ERROR) << "meta_graph is nullptr!"; | |||||
| free(model->buf); | |||||
| delete (model); | |||||
| return nullptr; | |||||
| } | |||||
| int status = GenerateModelByVersion(meta_graph, model, schema_version); | |||||
| auto status = model->ConstructModel(); | |||||
| if (status != RET_OK) { | if (status != RET_OK) { | ||||
| free(model->buf); | |||||
| delete (model); | |||||
| MS_LOG(ERROR) << "fail to generate model"; | |||||
| MS_LOG(ERROR) << "construct model failed."; | |||||
| delete model; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return model; | return model; | ||||
| @@ -91,6 +75,4 @@ char *TrainModel::ExportBuf(char *buffer, size_t *len) const { | |||||
| *len = buf_size_; | *len = buf_size_; | ||||
| return buffer; | return buffer; | ||||
| } | } | ||||
| TrainModel::~TrainModel() { Model::Free(); } | |||||
| } // namespace mindspore::lite | } // namespace mindspore::lite | ||||
| @@ -16,13 +16,13 @@ | |||||
| #ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_ | #ifndef MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_ | ||||
| #define MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_ | #define MINDSPORE_LITE_SRC_TRAIN_TRAIN_MODEL_H_ | ||||
| #include <vector> | #include <vector> | ||||
| #include "include/model.h" | |||||
| #include "src/lite_model.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| /// \brief TrainModel Defines a class that allows importing and exporting a mindspore trainable model | /// \brief TrainModel Defines a class that allows importing and exporting a mindspore trainable model | ||||
| struct TrainModel : public lite::Model { | |||||
| struct TrainModel : public lite::LiteModel { | |||||
| /// \brief Static method to create a TrainModel object | /// \brief Static method to create a TrainModel object | ||||
| /// | /// | ||||
| /// \param[in] model_buf A buffer that was read from a MS model file | /// \param[in] model_buf A buffer that was read from a MS model file | ||||
| @@ -35,7 +35,7 @@ struct TrainModel : public lite::Model { | |||||
| void Free() override; | void Free() override; | ||||
| /// \brief Class destructor, free all memory | /// \brief Class destructor, free all memory | ||||
| virtual ~TrainModel(); | |||||
| virtual ~TrainModel() = default; | |||||
| /// \brief Export Model into a buffer | /// \brief Export Model into a buffer | ||||
| /// | /// | ||||
| @@ -44,8 +44,6 @@ struct TrainModel : public lite::Model { | |||||
| /// | /// | ||||
| /// \return Pointer to buffer with exported model | /// \return Pointer to buffer with exported model | ||||
| char *ExportBuf(char *buf, size_t *len) const; | char *ExportBuf(char *buf, size_t *len) const; | ||||
| size_t buf_size_; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -105,7 +105,8 @@ if (PLATFORM_ARM32 OR PLATFORM_ARM64) | |||||
| endif() | endif() | ||||
| endif() | endif() | ||||
| ### runtime framework | ### runtime framework | ||||
| file(GLOB_RECURSE OPS_SRC ${LITE_DIR}/src/ops/*.cc ${LITE_DIR}/src/ops/populate/*.cc) | |||||
| add_definitions(-DENABLE_V0) | |||||
| file(GLOB_RECURSE OPS_SRC ${LITE_DIR}/src/ops/*.cc) | |||||
| set(TEST_LITE_SRC | set(TEST_LITE_SRC | ||||
| ${TEST_LITE_SRC} | ${TEST_LITE_SRC} | ||||
| ${CCSRC_SRC} | ${CCSRC_SRC} | ||||
| @@ -123,8 +124,7 @@ set(TEST_LITE_SRC | |||||
| ${LITE_DIR}/src/lite_kernel.cc | ${LITE_DIR}/src/lite_kernel.cc | ||||
| ${LITE_DIR}/src/lite_session.cc | ${LITE_DIR}/src/lite_session.cc | ||||
| ${LITE_DIR}/src/sub_graph_kernel.cc | ${LITE_DIR}/src/sub_graph_kernel.cc | ||||
| ${LITE_DIR}/src/model.cc | |||||
| ${LITE_DIR}/src/model_common.cc | |||||
| ${LITE_DIR}/src/lite_model.cc | |||||
| ${LITE_DIR}/src/scheduler.cc | ${LITE_DIR}/src/scheduler.cc | ||||
| ${LITE_DIR}/src/common/graph_util.cc | ${LITE_DIR}/src/common/graph_util.cc | ||||
| ${LITE_DIR}/src/common/file_utils.cc | ${LITE_DIR}/src/common/file_utils.cc | ||||
| @@ -9,7 +9,7 @@ set(CCSRC_SRC | |||||
| include(${TOP_DIR}/cmake/external_libs/glog.cmake) | include(${TOP_DIR}/cmake/external_libs/glog.cmake) | ||||
| file(GLOB_RECURSE OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/*.cc | |||||
| file(GLOB OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/*.cc | |||||
| ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/populate/*.cc) | ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ops/populate/*.cc) | ||||
| file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | ||||
| @@ -88,8 +88,7 @@ set(LITE_SRC | |||||
| ${SRC_DIR}/sub_graph_kernel.cc | ${SRC_DIR}/sub_graph_kernel.cc | ||||
| ${SRC_DIR}/lite_session.cc | ${SRC_DIR}/lite_session.cc | ||||
| ${SRC_DIR}/executor.cc | ${SRC_DIR}/executor.cc | ||||
| ${SRC_DIR}/model.cc | |||||
| ${SRC_DIR}/model_common.cc | |||||
| ${SRC_DIR}/lite_model.cc | |||||
| ${SRC_DIR}/errorcode.cc | ${SRC_DIR}/errorcode.cc | ||||
| ) | ) | ||||
| if (SUPPORT_TRAIN) | if (SUPPORT_TRAIN) | ||||
| @@ -1581,6 +1581,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { | |||||
| flatbuffers::FlatBufferBuilder builder(1024); | flatbuffers::FlatBufferBuilder builder(1024); | ||||
| auto offset = schema::MetaGraph::Pack(builder, meta_graph); | auto offset = schema::MetaGraph::Pack(builder, meta_graph); | ||||
| builder.Finish(offset); | builder.Finish(offset); | ||||
| schema::FinishMetaGraphBuffer(builder, offset); | |||||
| size_t size = builder.GetSize(); | size_t size = builder.GetSize(); | ||||
| auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer()); | auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer()); | ||||
| if (content == nullptr) { | if (content == nullptr) { | ||||
| @@ -1662,6 +1663,7 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr func_graph) { | |||||
| flatbuffers::FlatBufferBuilder int8_builder(1024); | flatbuffers::FlatBufferBuilder int8_builder(1024); | ||||
| auto int8_offset = schema::MetaGraph::Pack(int8_builder, int8_meta_graph); | auto int8_offset = schema::MetaGraph::Pack(int8_builder, int8_meta_graph); | ||||
| int8_builder.Finish(int8_offset); | int8_builder.Finish(int8_offset); | ||||
| schema::FinishMetaGraphBuffer(int8_builder, int8_offset); | |||||
| size = int8_builder.GetSize(); | size = int8_builder.GetSize(); | ||||
| auto *int8_content = reinterpret_cast<const char *>(int8_builder.GetBufferPointer()); | auto *int8_content = reinterpret_cast<const char *>(int8_builder.GetBufferPointer()); | ||||
| if (int8_content == nullptr) { | if (int8_content == nullptr) { | ||||
| @@ -24,6 +24,7 @@ | |||||
| #include "tools/common/flag_parser.h" | #include "tools/common/flag_parser.h" | ||||
| #include "src/common/file_utils.h" | #include "src/common/file_utils.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "schema/model_generated.h" | |||||
| #include "include/lite_session.h" | #include "include/lite_session.h" | ||||
| #include "tools/lib_cropper/cropper_flags.h" | #include "tools/lib_cropper/cropper_flags.h" | ||||