Browse Source

[MS][LITE] solve static check

[MS][LITE] add
pull/15689/head
cjh9368 4 years ago
parent
commit
c91b7d15fa
11 changed files with 32 additions and 4 deletions
  1. +11
    -0
      mindspore/lite/src/common/prim_util.cc
  2. +8
    -1
      mindspore/lite/src/common/tensor_util.cc
  3. +1
    -1
      mindspore/lite/src/common/tensor_util.h
  4. +1
    -1
      mindspore/lite/src/cxx_api/tensor/tensor_impl.h
  5. +1
    -0
      mindspore/lite/src/kernel_interface.h
  6. +1
    -0
      mindspore/lite/src/kernel_interface_registry.h
  7. +4
    -0
      mindspore/lite/src/ops/schema_register.h
  8. +1
    -0
      mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h
  9. +2
    -0
      mindspore/lite/src/runtime/agent/npu/optimizer/npu_transform_pass.h
  10. +1
    -1
      mindspore/lite/tools/converter/graphdef_transform.h
  11. +1
    -0
      mindspore/lite/tools/optimizer/graph/primitive_adjust_pass.h

+ 11
- 0
mindspore/lite/src/common/prim_util.cc View File

@@ -24,6 +24,7 @@
namespace mindspore {
namespace lite {
int GetPrimitiveType(const void *primitive) {
MS_ASSERT(primitive != nullptr);
if (primitive == nullptr) {
return -1;
}
@@ -51,6 +52,7 @@ const char *PrimitiveCurVersionTypeName(int type) {
int GenPrimVersionKey(int primitive_type, int schema_version) { return primitive_type * 1000 + schema_version; }

bool IsPartialNode(const void *primitive) {
MS_ASSERT(primitive != nullptr);
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) {
return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_PartialFusion;
@@ -65,9 +67,11 @@ bool IsPartialNode(const void *primitive) {
}

int GetPartialGraphIndex(const void *primitive) {
MS_ASSERT(primitive != nullptr);
int index = -1;
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) {
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_PartialFusion() != nullptr);
index = static_cast<const schema::Primitive *>(primitive)->value_as_PartialFusion()->sub_graph_index();
}
#ifdef ENABLE_V0
@@ -79,6 +83,7 @@ int GetPartialGraphIndex(const void *primitive) {
}

bool IsWhileNode(const void *primitive) {
MS_ASSERT(primitive != nullptr);
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) {
return reinterpret_cast<const schema::Primitive *>(primitive)->value_type() == schema::PrimitiveType_While;
@@ -92,13 +97,16 @@ bool IsWhileNode(const void *primitive) {
}

int GetWhileBodySubgraphIndex(const void *primitive) {
MS_ASSERT(primitive != nullptr);
int index = -1;
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) {
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr);
index = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_While()->body_subgraph_index();
}
#ifdef ENABLE_V0
if (schema_version == SCHEMA_V0) {
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr);
index = reinterpret_cast<const schema::v0::Primitive *>(primitive)->value_as_While()->bodySubgraphIndex();
}
#endif
@@ -106,13 +114,16 @@ int GetWhileBodySubgraphIndex(const void *primitive) {
}

int GetWhileCondSubgraphIndex(const void *primitive) {
MS_ASSERT(primitive != nullptr);
int index = -1;
int schema_version = VersionManager::GetInstance()->GetSchemaVersion();
if (schema_version == SCHEMA_CUR) {
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr);
index = reinterpret_cast<const schema::Primitive *>(primitive)->value_as_While()->cond_subgraph_index();
}
#ifdef ENABLE_V0
if (schema_version == SCHEMA_V0) {
MS_ASSERT(static_cast<const schema::Primitive *>(primitive)->value_as_While() != nullptr);
index = reinterpret_cast<const schema::v0::Primitive *>(primitive)->value_as_While()->condSubgraphIndex();
}
#endif


+ 8
- 1
mindspore/lite/src/common/tensor_util.cc View File

@@ -128,7 +128,7 @@ int TensorList2TensorListC(TensorList *src, TensorListC *dst) {
return NNACL_OK;
}

void TensorListC2TensorList(TensorListC *src, TensorList *dst) {
int TensorListC2TensorList(TensorListC *src, TensorList *dst) {
dst->set_data_type(static_cast<TypeId>(src->data_type_));
dst->set_format(static_cast<schema::Format>(src->format_));
dst->set_shape(std::vector<int>(1, src->element_num_));
@@ -136,11 +136,17 @@ void TensorListC2TensorList(TensorListC *src, TensorList *dst) {

// Set Tensors
for (size_t i = 0; i < src->element_num_; i++) {
if (dst->GetTensor(i) == nullptr) {
MS_LOG(ERROR) << "Tensor i is null ptr";
return RET_NULL_PTR;
}

TensorC2Tensor(&src->tensors_[i], dst->GetTensor(i));
}

dst->set_element_shape(std::vector<int>(src->element_shape_, src->element_shape_ + src->element_shape_size_));
dst->set_max_elements_num(src->max_elements_num_);
return RET_OK;
}

int GenerateMergeSwitchOutTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs,
@@ -189,6 +195,7 @@ int GenerateInTensorC(const OpParameter *const parameter, const std::vector<lite
memset(tensor_list_c, 0, sizeof(TensorListC));
ret = TensorList2TensorListC(tensor_list, tensor_list_c);
if (ret != RET_OK) {
free(tensor_list_c);
return NNACL_ERR;
}
in_tensor_c->push_back(reinterpret_cast<TensorC *>(tensor_list_c));


+ 1
- 1
mindspore/lite/src/common/tensor_util.h View File

@@ -31,7 +31,7 @@ void FreeTensorListC(TensorListC *tensorListC);
void Tensor2TensorC(Tensor *src, TensorC *dst);
void TensorC2Tensor(TensorC *src, Tensor *dst);
int TensorList2TensorListC(TensorList *src, TensorListC *dst);
void TensorListC2TensorList(TensorListC *src, TensorList *dst);
int TensorListC2TensorList(TensorListC *src, TensorList *dst);
int GenerateMergeSwitchOutTensorC(const std::vector<lite::Tensor *> &inputs, std::vector<lite::Tensor *> *outputs,
std::vector<TensorC *> *out_tensor_c);
int GenerateInTensorC(const OpParameter *const parameter, const std::vector<lite::Tensor *> &inputs,


+ 1
- 1
mindspore/lite/src/cxx_api/tensor/tensor_impl.h View File

@@ -140,7 +140,7 @@ class MSTensor::Impl {

virtual bool IsDevice() const { return false; }

tensor::MSTensor *lite_tensor() { return lite_tensor_; }
tensor::MSTensor *lite_tensor() const { return lite_tensor_; }

Status set_lite_tensor(tensor::MSTensor *tensor) {
if (tensor == nullptr) {


+ 1
- 0
mindspore/lite/src/kernel_interface.h View File

@@ -48,6 +48,7 @@ class RegisterKernelInterface {
public:
static RegisterKernelInterface *Instance();
int Reg(const std::string &vendor, const int op_type, KernelInterfaceCreator creator);
virtual ~RegisterKernelInterface() = default;

private:
RegisterKernelInterface() = default;


+ 1
- 0
mindspore/lite/src/kernel_interface_registry.h View File

@@ -31,6 +31,7 @@ class KernelInterfaceRegistry {
}

int Reg(const std::string &vendor, const int &op_type, kernel::KernelInterfaceCreator creator);
virtual ~KernelInterfaceRegistry() = default;

private:
KernelInterfaceRegistry() = default;


+ 4
- 0
mindspore/lite/src/ops/schema_register.h View File

@@ -37,6 +37,8 @@ class SchemaRegisterImpl {

GetSchemaDef GetPrimTypeGenFunc() const { return prim_type_gen_; }

virtual ~SchemaRegisterImpl() = default;

private:
std::vector<GetSchemaDef> op_def_funcs_;
GetSchemaDef prim_type_gen_;
@@ -45,11 +47,13 @@ class SchemaRegisterImpl {
class SchemaOpRegister {
public:
explicit SchemaOpRegister(GetSchemaDef func) { SchemaRegisterImpl::Instance()->OpPush(func); }
virtual ~SchemaOpRegister() = default;
};

class PrimitiveTypeRegister {
public:
explicit PrimitiveTypeRegister(GetSchemaDef func) { SchemaRegisterImpl::Instance()->SetPrimTypeGenFunc(func); }
virtual ~PrimitiveTypeRegister() = default;
};
} // namespace mindspore::lite::ops



+ 1
- 0
mindspore/lite/src/runtime/agent/npu/optimizer/npu_insert_transform_pass.h View File

@@ -31,6 +31,7 @@ class NPUInsertTransformPass : public NPUBasePass {
name_ = "NPUInsertTransformPass";
}

virtual ~NPUInsertTransformPass() = default;
int Run() override;

private:


+ 2
- 0
mindspore/lite/src/runtime/agent/npu/optimizer/npu_transform_pass.h View File

@@ -36,6 +36,8 @@ class NPUTransformPass : public NPUBasePass {
name_ = "NPUTransformPass";
}

virtual ~NPUTransformPass() = default;

private:
int InsertPreNodes(kernel::LiteKernel *kernel, std::vector<kernel::LiteKernel *> *trans_kernels);



+ 1
- 1
mindspore/lite/tools/converter/graphdef_transform.h View File

@@ -37,7 +37,7 @@ class GraphDefTransform {
virtual ~GraphDefTransform();
virtual int Transform(const converter::Flags &ctx);
void SetGraphDef(schema::MetaGraphT *dst_def);
inline schema::MetaGraphT *GetOutput() { return graph_defT_; }
inline schema::MetaGraphT *GetOutput() const { return graph_defT_; }

protected:
std::vector<schema::CNodeT *> GetGraphNodes();


+ 1
- 0
mindspore/lite/tools/optimizer/graph/primitive_adjust_pass.h View File

@@ -57,6 +57,7 @@ class RegistryPrimitiveAdjust {
RegistryPrimitiveAdjust(const std::string &key, PrimitiveAdjustCreator creator) {
PrimitiveAdjustRegistry::GetInstance()->InsertPrimitiveAdjustMap(key, creator);
}
virtual ~RegistryPrimitiveAdjust() = default;
};

#define REGIST_PRIMITIVE_ADJUST(type, primitive_adjust_func) \


Loading…
Cancel
Save