| @@ -24,6 +24,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| int GetPrimitiveType(const void *primitive) { | int GetPrimitiveType(const void *primitive) { | ||||
| MS_ASSERT(primitive != nullptr); | |||||
| if (primitive == nullptr) { | if (primitive == nullptr) { | ||||
| return -1; | return -1; | ||||
| } | } | ||||
| @@ -51,6 +52,7 @@ const char *PrimitiveCurVersionTypeName(int type) { | |||||
| int GenPrimVersionKey(int primitive_type, int schema_version) { return primitive_type * 1000 + schema_version; } | int GenPrimVersionKey(int primitive_type, int schema_version) { return primitive_type * 1000 + schema_version; } | ||||
| bool IsPartialNode(const void *primitive) { | bool IsPartialNode(const void *primitive) { | ||||
| MS_ASSERT(primitive != nullptr); | |||||
| int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); | int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); | ||||
| if (schema_version == SCHEMA_CUR) { | if (schema_version == SCHEMA_CUR) { | ||||
| return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_PartialFusion; | return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_PartialFusion; | ||||
| @@ -65,9 +67,11 @@ bool IsPartialNode(const void *primitive) { | |||||
| } | } | ||||
| int GetPartialGraphIndex(const void *primitive) { | int GetPartialGraphIndex(const void *primitive) { | ||||
| MS_ASSERT(primitive != nullptr); | |||||
| int index = -1; | int index = -1; | ||||
| int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); | int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); | ||||
| if (schema_version == SCHEMA_CUR) { | if (schema_version == SCHEMA_CUR) { | ||||
| MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_PartialFusion() != nullptr); | |||||
| index = static_cast<const schema::Primitive *>(primitive)->value_as_PartialFusion()->sub_graph_index(); | index = static_cast<const schema::Primitive *>(primitive)->value_as_PartialFusion()->sub_graph_index(); | ||||
| } | } | ||||
| #ifdef ENABLE_V0 | #ifdef ENABLE_V0 | ||||
| @@ -79,6 +83,7 @@ int GetPartialGraphIndex(const void *primitive) { | |||||
| } | } | ||||
| bool IsWhileNode(const void *primitive) { | bool IsWhileNode(const void *primitive) { | ||||
| MS_ASSERT(primitive != nullptr); | |||||
| int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); | int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); | ||||
| if (schema_version == SCHEMA_CUR) { | if (schema_version == SCHEMA_CUR) { | ||||
| return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_While; | return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_While; | ||||
| @@ -92,13 +97,16 @@ bool IsWhileNode(const void *primitive) { | |||||
| } | } | ||||
| int GetWhileBodySubgraphIndex(const void *primitive) { | int GetWhileBodySubgraphIndex(const void *primitive) { | ||||
| MS_ASSERT(primitive != nullptr); | |||||
| int index = -1; | int index = -1; | ||||
| int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); | int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); | ||||
| if (schema_version == SCHEMA_CUR) { | if (schema_version == SCHEMA_CUR) { | ||||
| MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr); | |||||
| index = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_While()->body_subgraph_index(); | index = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_While()->body_subgraph_index(); | ||||
| } | } | ||||
| #ifdef ENABLE_V0 | #ifdef ENABLE_V0 | ||||
| if (schema_version == SCHEMA_V0) { | if (schema_version == SCHEMA_V0) { | ||||
| MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr); | |||||
| index = reinterpret_cast<const schema::v0::Primitive *>(primitive)->value_as_While()->bodySubgraphIndex(); | index = reinterpret_cast<const schema::v0::Primitive *>(primitive)->value_as_While()->bodySubgraphIndex(); | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -106,13 +114,16 @@ int GetWhileBodySubgraphIndex(const void *primitive) { | |||||
| } | } | ||||
| int GetWhileCondSubgraphIndex(const void *primitive) { | int GetWhileCondSubgraphIndex(const void *primitive) { | ||||
| MS_ASSERT(primitive != nullptr); | |||||
| int index = -1; | int index = -1; | ||||
| int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); | int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); | ||||
| if (schema_version == SCHEMA_CUR) { | if (schema_version == SCHEMA_CUR) { | ||||
| MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr); | |||||
| index = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_While()->cond_subgraph_index(); | index = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_While()->cond_subgraph_index(); | ||||
| } | } | ||||
| #ifdef ENABLE_V0 | #ifdef ENABLE_V0 | ||||
| if (schema_version == SCHEMA_V0) { | if (schema_version == SCHEMA_V0) { | ||||
| MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr); | |||||
| index = reinterpret_cast<const schema::v0::Primitive *>(primitive)->value_as_While()->condSubgraphIndex(); | index = reinterpret_cast<const schema::v0::Primitive *>(primitive)->value_as_While()->condSubgraphIndex(); | ||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -128,7 +128,7 @@ int TensorList2TensorListC(TensorList *src, TensorListC *dst) { | |||||
| return NNACL_OK; | return NNACL_OK; | ||||
| } | } | ||||
| void TensorListC2TensorList(TensorListC *src, TensorList *dst) { | |||||
| int TensorListC2TensorList(TensorListC *src, TensorList *dst) { | |||||
| dst->set_data_type(static_cast<TypeId>(src->data_type_)); | dst->set_data_type(static_cast<TypeId>(src->data_type_)); | ||||
| dst->set_format(static_cast<schema::Format>(src->format_)); | dst->set_format(static_cast<schema::Format>(src->format_)); | ||||
| dst->set_shape(std::vector<int>(1, src->element_num_)); | dst->set_shape(std::vector<int>(1, src->element_num_)); | ||||
| @@ -136,11 +136,17 @@ void TensorListC2TensorList(TensorListC *src, TensorList *dst) { | |||||
| // Set Tensors | // Set Tensors | ||||
| for (size_t i = 0; i < src->element_num_; i++) { | for (size_t i = 0; i < src->element_num_; i++) { | ||||
| if (dst->GetTensor(i) == nullptr) { | |||||
| MS_LOG(ERROR) << "Tensor i is null ptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| TensorC2Tensor(&src->tensors_[i], dst->GetTensor(i)); | TensorC2Tensor(&src->tensors_[i], dst->GetTensor(i)); | ||||
| } | } | ||||
| dst->set_element_shape(std::vector<int>(src->element_shape_, src->element_shape_ + src->element_shape_size_)); | dst->set_element_shape(std::vector<int>(src->element_shape_, src->element_shape_ + src->element_shape_size_)); | ||||
| dst->set_max_elements_num(src->max_elements_num_); | dst->set_max_elements_num(src->max_elements_num_); | ||||
| return RET_OK; | |||||
| } | } | ||||
| int GenerateMergeSwitchOutTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs, | int GenerateMergeSwitchOutTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs, | ||||
| @@ -189,6 +195,7 @@ int GenerateInTensorC(const OpParameter *const parameter, const std::vector<lite | |||||
| memset(tensor_list_c, 0, sizeof(TensorListC)); | memset(tensor_list_c, 0, sizeof(TensorListC)); | ||||
| ret = TensorList2TensorListC(tensor_list, tensor_list_c); | ret = TensorList2TensorListC(tensor_list, tensor_list_c); | ||||
| if (ret != RET_OK) { | if (ret != RET_OK) { | ||||
| free(tensor_list_c); | |||||
| return NNACL_ERR; | return NNACL_ERR; | ||||
| } | } | ||||
| in_tensor_c->push_back(reinterpret_cast<TensorC *>(tensor_list_c)); | in_tensor_c->push_back(reinterpret_cast<TensorC *>(tensor_list_c)); | ||||
| @@ -31,7 +31,7 @@ void FreeTensorListC(TensorListC *tensorListC); | |||||
| void Tensor2TensorC(Tensor *src, TensorC *dst); | void Tensor2TensorC(Tensor *src, TensorC *dst); | ||||
| void TensorC2Tensor(TensorC *src, Tensor *dst); | void TensorC2Tensor(TensorC *src, Tensor *dst); | ||||
| int TensorList2TensorListC(TensorList *src, TensorListC *dst); | int TensorList2TensorListC(TensorList *src, TensorListC *dst); | ||||
| void TensorListC2TensorList(TensorListC *src, TensorList *dst); | |||||
| int TensorListC2TensorList(TensorListC *src, TensorList *dst); | |||||
| int GenerateMergeSwitchOutTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs, | int GenerateMergeSwitchOutTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs, | ||||
| std::vector<TensorC *> *out_tensor_c); | std::vector<TensorC *> *out_tensor_c); | ||||
| int GenerateInTensorC(const OpParameter *const parameter, const std::vector<lite::Tensor *> &inputs, | int GenerateInTensorC(const OpParameter *const parameter, const std::vector<lite::Tensor *> &inputs, | ||||
| @@ -140,7 +140,7 @@ class MSTensor::Impl { | |||||
| virtual bool IsDevice() const { return false; } | virtual bool IsDevice() const { return false; } | ||||
| tensor::MSTensor *lite_tensor() { return lite_tensor_; } | |||||
| tensor::MSTensor *lite_tensor() const { return lite_tensor_; } | |||||
| Status set_lite_tensor(tensor::MSTensor *tensor) { | Status set_lite_tensor(tensor::MSTensor *tensor) { | ||||
| if (tensor == nullptr) { | if (tensor == nullptr) { | ||||
| @@ -48,6 +48,7 @@ class RegisterKernelInterface { | |||||
| public: | public: | ||||
| static RegisterKernelInterface *Instance(); | static RegisterKernelInterface *Instance(); | ||||
| int Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator); | int Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator); | ||||
| virtual ~RegisterKernelInterface() = default; | |||||
| private: | private: | ||||
| RegisterKernelInterface() = default; | RegisterKernelInterface() = default; | ||||
| @@ -31,6 +31,7 @@ class KernelInterfaceRegistry { | |||||
| } | } | ||||
| int Reg(const std::string &vendor, const int &op_type, kernel::KernelInterfaceCreator creator); | int Reg(const std::string &vendor, const int &op_type, kernel::KernelInterfaceCreator creator); | ||||
| virtual ~KernelInterfaceRegistry() = default; | |||||
| private: | private: | ||||
| KernelInterfaceRegistry() = default; | KernelInterfaceRegistry() = default; | ||||
| @@ -37,6 +37,8 @@ class SchemaRegisterImpl { | |||||
| GetSchemaDef GetPrimTypeGenFunc() const { return prim_type_gen_; } | GetSchemaDef GetPrimTypeGenFunc() const { return prim_type_gen_; } | ||||
| virtual ~SchemaRegisterImpl() = default; | |||||
| private: | private: | ||||
| std::vector<GetSchemaDef> op_def_funcs_; | std::vector<GetSchemaDef> op_def_funcs_; | ||||
| GetSchemaDef prim_type_gen_; | GetSchemaDef prim_type_gen_; | ||||
| @@ -45,11 +47,13 @@ class SchemaRegisterImpl { | |||||
| class SchemaOpRegister { | class SchemaOpRegister { | ||||
| public: | public: | ||||
| explicit SchemaOpRegister(GetSchemaDef func) { SchemaRegisterImpl::Instance()->OpPush(func); } | explicit SchemaOpRegister(GetSchemaDef func) { SchemaRegisterImpl::Instance()->OpPush(func); } | ||||
| virtual ~SchemaOpRegister() = default; | |||||
| }; | }; | ||||
| class PrimitiveTypeRegister { | class PrimitiveTypeRegister { | ||||
| public: | public: | ||||
| explicit PrimitiveTypeRegister(GetSchemaDef func) { SchemaRegisterImpl::Instance()->SetPrimTypeGenFunc(func); } | explicit PrimitiveTypeRegister(GetSchemaDef func) { SchemaRegisterImpl::Instance()->SetPrimTypeGenFunc(func); } | ||||
| virtual ~PrimitiveTypeRegister() = default; | |||||
| }; | }; | ||||
| } // namespace mindspore::lite::ops | } // namespace mindspore::lite::ops | ||||
| @@ -31,6 +31,7 @@ class NPUInsertTransformPass : public NPUBasePass { | |||||
| name_ = "NPUInsertTransformPass"; | name_ = "NPUInsertTransformPass"; | ||||
| } | } | ||||
| virtual ~NPUInsertTransformPass() = default; | |||||
| int Run() override; | int Run() override; | ||||
| private: | private: | ||||
| @@ -36,6 +36,8 @@ class NPUTransformPass : public NPUBasePass { | |||||
| name_ = "NPUTransformPass"; | name_ = "NPUTransformPass"; | ||||
| } | } | ||||
| virtual ~NPUTransformPass() = default; | |||||
| private: | private: | ||||
| int InsertPreNodes(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *trans_kernels); | int InsertPreNodes(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *trans_kernels); | ||||
| @@ -37,7 +37,7 @@ class GraphDefTransform { | |||||
| virtual ~GraphDefTransform(); | virtual ~GraphDefTransform(); | ||||
| virtual int Transform(const converter::Flags &ctx); | virtual int Transform(const converter::Flags &ctx); | ||||
| void SetGraphDef(schema::MetaGraphT *dst_def); | void SetGraphDef(schema::MetaGraphT *dst_def); | ||||
| inline schema::MetaGraphT *GetOutput() { return graph_defT_; } | |||||
| inline schema::MetaGraphT *GetOutput() const { return graph_defT_; } | |||||
| protected: | protected: | ||||
| std::vector<schema::CNodeT *> GetGraphNodes(); | std::vector<schema::CNodeT *> GetGraphNodes(); | ||||
| @@ -57,6 +57,7 @@ class RegistryPrimitiveAdjust { | |||||
| RegistryPrimitiveAdjust(const std::string &key, PrimitiveAdjustCreator creator) { | RegistryPrimitiveAdjust(const std::string &key, PrimitiveAdjustCreator creator) { | ||||
| PrimitiveAdjustRegistry::GetInstance()->InsertPrimitiveAdjustMap(key, creator); | PrimitiveAdjustRegistry::GetInstance()->InsertPrimitiveAdjustMap(key, creator); | ||||
| } | } | ||||
| virtual ~RegistryPrimitiveAdjust() = default; | |||||
| }; | }; | ||||
| #define REGIST_PRIMITIVE_ADJUST(type, primitive_adjust_func) \ | #define REGIST_PRIMITIVE_ADJUST(type, primitive_adjust_func) \ | ||||