Browse Source

[MS][LITE] solve static check

[MS][LITE] add
pull/15689/head
cjh9368 4 years ago
parent
commit
c91b7d15fa
11 changed files with 32 additions and 4 deletions
  1. +11
    -0
      mindspore/lite/src/common/prim_util.cc
  2. +8
    -1
      mindspore/lite/src/common/tensor_util.cc
  3. +1
    -1
      mindspore/lite/src/common/tensor_util.h
  4. +1
    -1
      mindspore/lite/src/cxx_api/tensor/tensor_impl.h
  5. +1
    -0
      mindspore/lite/src/kernel_interface.h
  6. +1
    -0
      mindspore/lite/src/kernel_interface_registry.h
  7. +4
    -0
      mindspore/lite/src/ops/schema_register.h
  8. +1
    -0
      mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h
  9. +2
    -0
      mindspore/lite/src/runtime/agent/npu/optimizer/npu_transform_pass.h
  10. +1
    -1
      mindspore/lite/tools/converter/graphdef_transform.h
  11. +1
    -0
      mindspore/lite/tools/optimizer/graph/primitive_adjust_pass.h

+ 11
- 0
mindspore/lite/src/common/prim_util.cc View File

@@ -24,6 +24,7 @@
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
int GetPrimitiveType(const void *primitive) { int GetPrimitiveType(const void *primitive) {
MS_ASSERT(primitive != nullptr);
if (primitive == nullptr) { if (primitive == nullptr) {
return -1; return -1;
} }
@@ -51,6 +52,7 @@ const char *PrimitiveCurVersionTypeName(int type) {
int GenPrimVersionKey(int primitive_type, int schema_version) { return primitive_type * 1000 + schema_version; } int GenPrimVersionKey(int primitive_type, int schema_version) { return primitive_type * 1000 + schema_version; }


bool IsPartialNode(const void *primitive) { bool IsPartialNode(const void *primitive) {
MS_ASSERT(primitive != nullptr);
int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) { if (schema_version == SCHEMA_CUR) {
return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_PartialFusion; return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_PartialFusion;
@@ -65,9 +67,11 @@ bool IsPartialNode(const void *primitive) {
} }


int GetPartialGraphIndex(const void *primitive) { int GetPartialGraphIndex(const void *primitive) {
MS_ASSERT(primitive != nullptr);
int index = -1; int index = -1;
int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) { if (schema_version == SCHEMA_CUR) {
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_PartialFusion() != nullptr);
index = static_cast<const schema::Primitive *>(primitive)->value_as_PartialFusion()->sub_graph_index(); index = static_cast<const schema::Primitive *>(primitive)->value_as_PartialFusion()->sub_graph_index();
} }
#ifdef ENABLE_V0 #ifdef ENABLE_V0
@@ -79,6 +83,7 @@ int GetPartialGraphIndex(const void *primitive) {
} }


bool IsWhileNode(const void *primitive) { bool IsWhileNode(const void *primitive) {
MS_ASSERT(primitive != nullptr);
int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) { if (schema_version == SCHEMA_CUR) {
return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_While; return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_While;
@@ -92,13 +97,16 @@ bool IsWhileNode(const void *primitive) {
} }


int GetWhileBodySubgraphIndex(const void *primitive) { int GetWhileBodySubgraphIndex(const void *primitive) {
MS_ASSERT(primitive != nullptr);
int index = -1; int index = -1;
int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) { if (schema_version == SCHEMA_CUR) {
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr);
index = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_While()->body_subgraph_index(); index = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_While()->body_subgraph_index();
} }
#ifdef ENABLE_V0 #ifdef ENABLE_V0
if (schema_version == SCHEMA_V0) { if (schema_version == SCHEMA_V0) {
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr);
index = reinterpret_cast<const schema::v0::Primitive *>(primitive)->value_as_While()->bodySubgraphIndex(); index = reinterpret_cast<const schema::v0::Primitive *>(primitive)->value_as_While()->bodySubgraphIndex();
} }
#endif #endif
@@ -106,13 +114,16 @@ int GetWhileBodySubgraphIndex(const void *primitive) {
} }


int GetWhileCondSubgraphIndex(const void *primitive) { int GetWhileCondSubgraphIndex(const void *primitive) {
MS_ASSERT(primitive != nullptr);
int index = -1; int index = -1;
int schema_version = VersionManager::GetInstance()->GetSchemaVersion(); int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) { if (schema_version == SCHEMA_CUR) {
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr);
index = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_While()->cond_subgraph_index(); index = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_While()->cond_subgraph_index();
} }
#ifdef ENABLE_V0 #ifdef ENABLE_V0
if (schema_version == SCHEMA_V0) { if (schema_version == SCHEMA_V0) {
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr);
index = reinterpret_cast<const schema::v0::Primitive *>(primitive)->value_as_While()->condSubgraphIndex(); index = reinterpret_cast<const schema::v0::Primitive *>(primitive)->value_as_While()->condSubgraphIndex();
} }
#endif #endif


+ 8
- 1
mindspore/lite/src/common/tensor_util.cc View File

@@ -128,7 +128,7 @@ int TensorList2TensorListC(TensorList *src, TensorListC *dst) {
return NNACL_OK; return NNACL_OK;
} }


void TensorListC2TensorList(TensorListC *src, TensorList *dst) {
int TensorListC2TensorList(TensorListC *src, TensorList *dst) {
dst->set_data_type(static_cast<TypeId>(src->data_type_)); dst->set_data_type(static_cast<TypeId>(src->data_type_));
dst->set_format(static_cast<schema::Format>(src->format_)); dst->set_format(static_cast<schema::Format>(src->format_));
dst->set_shape(std::vector<int>(1, src->element_num_)); dst->set_shape(std::vector<int>(1, src->element_num_));
@@ -136,11 +136,17 @@ void TensorListC2TensorList(TensorListC *src, TensorList *dst) {


// Set Tensors // Set Tensors
for (size_t i = 0; i < src->element_num_; i++) { for (size_t i = 0; i < src->element_num_; i++) {
if (dst->GetTensor(i) == nullptr) {
MS_LOG(ERROR) << "Tensor i is null ptr";
return RET_NULL_PTR;
}

TensorC2Tensor(&src->tensors_[i], dst->GetTensor(i)); TensorC2Tensor(&src->tensors_[i], dst->GetTensor(i));
} }


dst->set_element_shape(std::vector<int>(src->element_shape_, src->element_shape_ + src->element_shape_size_)); dst->set_element_shape(std::vector<int>(src->element_shape_, src->element_shape_ + src->element_shape_size_));
dst->set_max_elements_num(src->max_elements_num_); dst->set_max_elements_num(src->max_elements_num_);
return RET_OK;
} }


int GenerateMergeSwitchOutTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs, int GenerateMergeSwitchOutTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs,
@@ -189,6 +195,7 @@ int GenerateInTensorC(const OpParameter *const parameter, const std::vector<lite
memset(tensor_list_c, 0, sizeof(TensorListC)); memset(tensor_list_c, 0, sizeof(TensorListC));
ret = TensorList2TensorListC(tensor_list, tensor_list_c); ret = TensorList2TensorListC(tensor_list, tensor_list_c);
if (ret != RET_OK) { if (ret != RET_OK) {
free(tensor_list_c);
return NNACL_ERR; return NNACL_ERR;
} }
in_tensor_c->push_back(reinterpret_cast<TensorC *>(tensor_list_c)); in_tensor_c->push_back(reinterpret_cast<TensorC *>(tensor_list_c));


+ 1
- 1
mindspore/lite/src/common/tensor_util.h View File

@@ -31,7 +31,7 @@ void FreeTensorListC(TensorListC *tensorListC);
void Tensor2TensorC(Tensor *src, TensorC *dst); void Tensor2TensorC(Tensor *src, TensorC *dst);
void TensorC2Tensor(TensorC *src, Tensor *dst); void TensorC2Tensor(TensorC *src, Tensor *dst);
int TensorList2TensorListC(TensorList *src, TensorListC *dst); int TensorList2TensorListC(TensorList *src, TensorListC *dst);
void TensorListC2TensorList(TensorListC *src, TensorList *dst);
int TensorListC2TensorList(TensorListC *src, TensorList *dst);
int GenerateMergeSwitchOutTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs, int GenerateMergeSwitchOutTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs,
std::vector<TensorC *> *out_tensor_c); std::vector<TensorC *> *out_tensor_c);
int GenerateInTensorC(const OpParameter *const parameter, const std::vector<lite::Tensor *> &inputs, int GenerateInTensorC(const OpParameter *const parameter, const std::vector<lite::Tensor *> &inputs,


+ 1
- 1
mindspore/lite/src/cxx_api/tensor/tensor_impl.h View File

@@ -140,7 +140,7 @@ class MSTensor::Impl {


virtual bool IsDevice() const { return false; } virtual bool IsDevice() const { return false; }


tensor::MSTensor *lite_tensor() { return lite_tensor_; }
tensor::MSTensor *lite_tensor() const { return lite_tensor_; }


Status set_lite_tensor(tensor::MSTensor *tensor) { Status set_lite_tensor(tensor::MSTensor *tensor) {
if (tensor == nullptr) { if (tensor == nullptr) {


+ 1
- 0
mindspore/lite/src/kernel_interface.h View File

@@ -48,6 +48,7 @@ class RegisterKernelInterface {
public: public:
static RegisterKernelInterface *Instance(); static RegisterKernelInterface *Instance();
int Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator); int Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator);
virtual ~RegisterKernelInterface() = default;


private: private:
RegisterKernelInterface() = default; RegisterKernelInterface() = default;


+ 1
- 0
mindspore/lite/src/kernel_interface_registry.h View File

@@ -31,6 +31,7 @@ class KernelInterfaceRegistry {
} }


int Reg(const std::string &vendor, const int &op_type, kernel::KernelInterfaceCreator creator); int Reg(const std::string &vendor, const int &op_type, kernel::KernelInterfaceCreator creator);
virtual ~KernelInterfaceRegistry() = default;


private: private:
KernelInterfaceRegistry() = default; KernelInterfaceRegistry() = default;


+ 4
- 0
mindspore/lite/src/ops/schema_register.h View File

@@ -37,6 +37,8 @@ class SchemaRegisterImpl {


GetSchemaDef GetPrimTypeGenFunc() const { return prim_type_gen_; } GetSchemaDef GetPrimTypeGenFunc() const { return prim_type_gen_; }


virtual ~SchemaRegisterImpl() = default;

private: private:
std::vector<GetSchemaDef> op_def_funcs_; std::vector<GetSchemaDef> op_def_funcs_;
GetSchemaDef prim_type_gen_; GetSchemaDef prim_type_gen_;
@@ -45,11 +47,13 @@ class SchemaRegisterImpl {
class SchemaOpRegister { class SchemaOpRegister {
public: public:
explicit SchemaOpRegister(GetSchemaDef func) { SchemaRegisterImpl::Instance()->OpPush(func); } explicit SchemaOpRegister(GetSchemaDef func) { SchemaRegisterImpl::Instance()->OpPush(func); }
virtual ~SchemaOpRegister() = default;
}; };


class PrimitiveTypeRegister { class PrimitiveTypeRegister {
public: public:
explicit PrimitiveTypeRegister(GetSchemaDef func) { SchemaRegisterImpl::Instance()->SetPrimTypeGenFunc(func); } explicit PrimitiveTypeRegister(GetSchemaDef func) { SchemaRegisterImpl::Instance()->SetPrimTypeGenFunc(func); }
virtual ~PrimitiveTypeRegister() = default;
}; };
} // namespace mindspore::lite::ops } // namespace mindspore::lite::ops




+ 1
- 0
mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h View File

@@ -31,6 +31,7 @@ class NPUInsertTransformPass : public NPUBasePass {
name_ = "NPUInsertTransformPass"; name_ = "NPUInsertTransformPass";
} }


virtual ~NPUInsertTransformPass() = default;
int Run() override; int Run() override;


private: private:


+ 2
- 0
mindspore/lite/src/runtime/agent/npu/optimizer/npu_transform_pass.h View File

@@ -36,6 +36,8 @@ class NPUTransformPass : public NPUBasePass {
name_ = "NPUTransformPass"; name_ = "NPUTransformPass";
} }


virtual ~NPUTransformPass() = default;

private: private:
int InsertPreNodes(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *trans_kernels); int InsertPreNodes(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *trans_kernels);




+ 1
- 1
mindspore/lite/tools/converter/graphdef_transform.h View File

@@ -37,7 +37,7 @@ class GraphDefTransform {
virtual ~GraphDefTransform(); virtual ~GraphDefTransform();
virtual int Transform(const converter::Flags &ctx); virtual int Transform(const converter::Flags &ctx);
void SetGraphDef(schema::MetaGraphT *dst_def); void SetGraphDef(schema::MetaGraphT *dst_def);
inline schema::MetaGraphT *GetOutput() { return graph_defT_; }
inline schema::MetaGraphT *GetOutput() const { return graph_defT_; }


protected: protected:
std::vector<schema::CNodeT *> GetGraphNodes(); std::vector<schema::CNodeT *> GetGraphNodes();


+ 1
- 0
mindspore/lite/tools/optimizer/graph/primitive_adjust_pass.h View File

@@ -57,6 +57,7 @@ class RegistryPrimitiveAdjust {
RegistryPrimitiveAdjust(const std::string &key, PrimitiveAdjustCreator creator) { RegistryPrimitiveAdjust(const std::string &key, PrimitiveAdjustCreator creator) {
PrimitiveAdjustRegistry::GetInstance()->InsertPrimitiveAdjustMap(key, creator); PrimitiveAdjustRegistry::GetInstance()->InsertPrimitiveAdjustMap(key, creator);
} }
virtual ~RegistryPrimitiveAdjust() = default;
}; };


#define REGIST_PRIMITIVE_ADJUST(type, primitive_adjust_func) \ #define REGIST_PRIMITIVE_ADJUST(type, primitive_adjust_func) \


Loading…
Cancel
Save