From: @yeyunpeng2020 Reviewed-by: @hangangqiang,@HilbertDavid Signed-off-by: @hangangqiangtags/v1.1.0
| @@ -102,8 +102,13 @@ kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_te | |||||
| const InnerContext *ctx, const kernel::KernelKey &key) { | const InnerContext *ctx, const kernel::KernelKey &key) { | ||||
| MS_ASSERT(nullptr != primitive); | MS_ASSERT(nullptr != primitive); | ||||
| MS_ASSERT(nullptr != ctx); | MS_ASSERT(nullptr != ctx); | ||||
| auto parameter = | |||||
| PopulateRegistry::GetInstance()->getParameterCreator(schema::PrimitiveType(primitive->Type()))(primitive); | |||||
| auto func_pointer = PopulateRegistry::GetInstance()->GetParameterCreator(schema::PrimitiveType(primitive->Type())); | |||||
| if (func_pointer == nullptr) { | |||||
| MS_LOG(ERROR) << "ParameterCreator function pointer is nullptr, type: " | |||||
| << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type()); | |||||
| return nullptr; | |||||
| } | |||||
| auto parameter = func_pointer(primitive); | |||||
| if (parameter == nullptr) { | if (parameter == nullptr) { | ||||
| MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " | MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " | ||||
| << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type()); | << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type()); | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <queue> | #include <queue> | ||||
| #include "src/tensor.h" | #include "src/tensor.h" | ||||
| #include "src/common/utils.h" | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| using mindspore::lite::RET_ERROR; | using mindspore::lite::RET_ERROR; | ||||
| @@ -196,7 +197,9 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect | |||||
| if (outer_in_kernels.empty()) { | if (outer_in_kernels.empty()) { | ||||
| for (auto &in_kernel_in_tensor : in_kernel_in_tensors) { | for (auto &in_kernel_in_tensor : in_kernel_in_tensors) { | ||||
| if (!in_kernel_in_tensor->IsConst()) { | if (!in_kernel_in_tensor->IsConst()) { | ||||
| input_tensors.push_back(in_kernel_in_tensor); | |||||
| if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) { | |||||
| input_tensors.push_back(in_kernel_in_tensor); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| continue; | continue; | ||||
| @@ -211,7 +214,9 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect | |||||
| auto outer_in_kernel_out_tensors_iter = | auto outer_in_kernel_out_tensors_iter = | ||||
| std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_kernel_in_tensor); | std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_kernel_in_tensor); | ||||
| if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) { | if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) { | ||||
| input_tensors.emplace_back(in_kernel_in_tensor); | |||||
| if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) { | |||||
| input_tensors.emplace_back(in_kernel_in_tensor); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -226,7 +231,11 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vec | |||||
| auto &outer_out_kernels = output_kernel->out_kernels(); | auto &outer_out_kernels = output_kernel->out_kernels(); | ||||
| auto &out_kernel_out_tensors = output_kernel->out_tensors(); | auto &out_kernel_out_tensors = output_kernel->out_tensors(); | ||||
| if (outer_out_kernels.empty()) { | if (outer_out_kernels.empty()) { | ||||
| output_tensors.insert(output_tensors.end(), out_kernel_out_tensors.begin(), out_kernel_out_tensors.end()); | |||||
| for (auto out_kernel_out_tensor : out_kernel_out_tensors) { | |||||
| if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) { | |||||
| output_tensors.push_back(out_kernel_out_tensor); | |||||
| } | |||||
| } | |||||
| continue; | continue; | ||||
| } | } | ||||
| for (auto outer_out_kernel : outer_out_kernels) { | for (auto outer_out_kernel : outer_out_kernels) { | ||||
| @@ -239,7 +248,9 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vec | |||||
| auto outer_out_kernel_in_tensors_iter = | auto outer_out_kernel_in_tensors_iter = | ||||
| std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor); | std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor); | ||||
| if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) { | if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) { | ||||
| output_tensors.emplace_back(out_kernel_out_tensor); | |||||
| if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) { | |||||
| output_tensors.emplace_back(out_kernel_out_tensor); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -50,7 +50,14 @@ bool ConvertNodes(const T &meta_graph, Model *model, int schema_version = SCHEMA | |||||
| node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim)); | node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim)); | ||||
| #else | #else | ||||
| auto primitive = const_cast<schema::Primitive *>(src_prim); | auto primitive = const_cast<schema::Primitive *>(src_prim); | ||||
| node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive); | |||||
| auto func_pointer = OpsRegistry::GetInstance()->GetPrimitiveCreator(primitive->value_type()); | |||||
| if (func_pointer == nullptr) { | |||||
| MS_LOG(ERROR) << "PrimitiveCreator function pointer is nullptr, type: " | |||||
| << schema::EnumNamePrimitiveType(primitive->value_type()); | |||||
| delete node; | |||||
| return false; | |||||
| } | |||||
| node->primitive_ = func_pointer(primitive); | |||||
| #endif | #endif | ||||
| if (node->primitive_ == nullptr) { | if (node->primitive_ == nullptr) { | ||||
| MS_LOG(ERROR) << "unpack primitive == nullptr!"; | MS_LOG(ERROR) << "unpack primitive == nullptr!"; | ||||
| @@ -28,10 +28,10 @@ class OpsRegistry { | |||||
| return ®istry; | return ®istry; | ||||
| } | } | ||||
| void insertPrimitiveCMap(schema::PrimitiveType type, PrimitiveCCreator creator) { | |||||
| void InsertPrimitiveCMap(schema::PrimitiveType type, PrimitiveCCreator creator) { | |||||
| primitive_creators[type] = creator; | primitive_creators[type] = creator; | ||||
| } | } | ||||
| PrimitiveCCreator getPrimitiveCreator(schema::PrimitiveType type) { | |||||
| PrimitiveCCreator GetPrimitiveCreator(schema::PrimitiveType type) { | |||||
| if (primitive_creators.find(type) != primitive_creators.end()) { | if (primitive_creators.find(type) != primitive_creators.end()) { | ||||
| return primitive_creators[type]; | return primitive_creators[type]; | ||||
| } else { | } else { | ||||
| @@ -47,7 +47,7 @@ class OpsRegistry { | |||||
| class Registry { | class Registry { | ||||
| public: | public: | ||||
| Registry(schema::PrimitiveType primitive_type, PrimitiveCCreator creator) { | Registry(schema::PrimitiveType primitive_type, PrimitiveCCreator creator) { | ||||
| OpsRegistry::GetInstance()->insertPrimitiveCMap(primitive_type, creator); | |||||
| OpsRegistry::GetInstance()->InsertPrimitiveCMap(primitive_type, creator); | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -18,6 +18,7 @@ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_OP_POPULATE_REGISTER_H | #define LITE_MINDSPORE_LITE_C_OPS_OP_POPULATE_REGISTER_H | ||||
| #include <map> | #include <map> | ||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace lite { | namespace lite { | ||||
| @@ -29,9 +30,9 @@ class PopulateRegistry { | |||||
| return ®istry; | return ®istry; | ||||
| } | } | ||||
| void insertParameterMap(schema::PrimitiveType type, ParameterCreator creator) { parameter_creators[type] = creator; } | |||||
| void InsertParameterMap(schema::PrimitiveType type, ParameterCreator creator) { parameter_creators[type] = creator; } | |||||
| ParameterCreator getParameterCreator(schema::PrimitiveType type) { | |||||
| ParameterCreator GetParameterCreator(schema::PrimitiveType type) { | |||||
| if (parameter_creators.find(type) != parameter_creators.end()) { | if (parameter_creators.find(type) != parameter_creators.end()) { | ||||
| return parameter_creators[type]; | return parameter_creators[type]; | ||||
| } else { | } else { | ||||
| @@ -47,7 +48,7 @@ class PopulateRegistry { | |||||
| class Registry { | class Registry { | ||||
| public: | public: | ||||
| Registry(schema::PrimitiveType primitive_type, ParameterCreator creator) { | Registry(schema::PrimitiveType primitive_type, ParameterCreator creator) { | ||||
| PopulateRegistry::GetInstance()->insertParameterMap(primitive_type, creator); | |||||
| PopulateRegistry::GetInstance()->InsertParameterMap(primitive_type, creator); | |||||
| } | } | ||||
| ~Registry() = default; | ~Registry() = default; | ||||
| }; | }; | ||||
| @@ -244,8 +244,14 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An | |||||
| auto primitive = lite_primitive.get(); | auto primitive = lite_primitive.get(); | ||||
| MS_ASSERT(primitive != nullptr); | MS_ASSERT(primitive != nullptr); | ||||
| MS_ASSERT(primitive->Type() != nullptr); | MS_ASSERT(primitive->Type() != nullptr); | ||||
| auto parameter = | |||||
| lite::PopulateRegistry::GetInstance()->getParameterCreator(schema::PrimitiveType(primitive->Type()))(primitive); | |||||
| auto func_pointer = | |||||
| lite::PopulateRegistry::GetInstance()->GetParameterCreator(schema::PrimitiveType(primitive->Type())); | |||||
| if (func_pointer == nullptr) { | |||||
| MS_LOG(ERROR) << "ParameterCreator function pointer is nullptr, type: " | |||||
| << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type()); | |||||
| return nullptr; | |||||
| } | |||||
| auto parameter = func_pointer(primitive); | |||||
| if (parameter == nullptr) { | if (parameter == nullptr) { | ||||
| MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " | MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " | ||||