From: @yeyunpeng2020 Reviewed-by: @hangangqiang,@HilbertDavid Signed-off-by: @hangangqiangtags/v1.1.0
| @@ -102,8 +102,13 @@ kernel::LiteKernel *KernelRegistry::GetKernel(const std::vector<Tensor *> &in_te | |||
| const InnerContext *ctx, const kernel::KernelKey &key) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != ctx); | |||
| auto parameter = | |||
| PopulateRegistry::GetInstance()->getParameterCreator(schema::PrimitiveType(primitive->Type()))(primitive); | |||
| auto func_pointer = PopulateRegistry::GetInstance()->GetParameterCreator(schema::PrimitiveType(primitive->Type())); | |||
| if (func_pointer == nullptr) { | |||
| MS_LOG(ERROR) << "ParameterCreator function pointer is nullptr, type: " | |||
| << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type()); | |||
| return nullptr; | |||
| } | |||
| auto parameter = func_pointer(primitive); | |||
| if (parameter == nullptr) { | |||
| MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " | |||
| << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type()); | |||
| @@ -18,6 +18,7 @@ | |||
| #include <algorithm> | |||
| #include <queue> | |||
| #include "src/tensor.h" | |||
| #include "src/common/utils.h" | |||
| namespace mindspore::kernel { | |||
| using mindspore::lite::RET_ERROR; | |||
| @@ -196,7 +197,9 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect | |||
| if (outer_in_kernels.empty()) { | |||
| for (auto &in_kernel_in_tensor : in_kernel_in_tensors) { | |||
| if (!in_kernel_in_tensor->IsConst()) { | |||
| input_tensors.push_back(in_kernel_in_tensor); | |||
| if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) { | |||
| input_tensors.push_back(in_kernel_in_tensor); | |||
| } | |||
| } | |||
| } | |||
| continue; | |||
| @@ -211,7 +214,9 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphInputTensors(const std::vect | |||
| auto outer_in_kernel_out_tensors_iter = | |||
| std::find(outer_in_kernel_out_tensors.begin(), outer_in_kernel_out_tensors.end(), in_kernel_in_tensor); | |||
| if (outer_in_kernel_out_tensors_iter != outer_in_kernel_out_tensors.end()) { | |||
| input_tensors.emplace_back(in_kernel_in_tensor); | |||
| if (!lite::IsContain(input_tensors, in_kernel_in_tensor)) { | |||
| input_tensors.emplace_back(in_kernel_in_tensor); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -226,7 +231,11 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vec | |||
| auto &outer_out_kernels = output_kernel->out_kernels(); | |||
| auto &out_kernel_out_tensors = output_kernel->out_tensors(); | |||
| if (outer_out_kernels.empty()) { | |||
| output_tensors.insert(output_tensors.end(), out_kernel_out_tensors.begin(), out_kernel_out_tensors.end()); | |||
| for (auto out_kernel_out_tensor : out_kernel_out_tensors) { | |||
| if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) { | |||
| output_tensors.push_back(out_kernel_out_tensor); | |||
| } | |||
| } | |||
| continue; | |||
| } | |||
| for (auto outer_out_kernel : outer_out_kernels) { | |||
| @@ -239,7 +248,9 @@ std::vector<lite::Tensor *> LiteKernelUtil::SubgraphOutputTensors(const std::vec | |||
| auto outer_out_kernel_in_tensors_iter = | |||
| std::find(outer_out_kernel_in_tensors.begin(), outer_out_kernel_in_tensors.end(), out_kernel_out_tensor); | |||
| if (outer_out_kernel_in_tensors_iter != outer_out_kernel_in_tensors.end()) { | |||
| output_tensors.emplace_back(out_kernel_out_tensor); | |||
| if (!lite::IsContain(output_tensors, out_kernel_out_tensor)) { | |||
| output_tensors.emplace_back(out_kernel_out_tensor); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -50,7 +50,14 @@ bool ConvertNodes(const T &meta_graph, Model *model, int schema_version = SCHEMA | |||
| node->primitive_ = PrimitiveC::Create(const_cast<schema::Primitive *>(src_prim)); | |||
| #else | |||
| auto primitive = const_cast<schema::Primitive *>(src_prim); | |||
| node->primitive_ = OpsRegistry::GetInstance()->getPrimitiveCreator(primitive->value_type())(primitive); | |||
| auto func_pointer = OpsRegistry::GetInstance()->GetPrimitiveCreator(primitive->value_type()); | |||
| if (func_pointer == nullptr) { | |||
| MS_LOG(ERROR) << "PrimitiveCreator function pointer is nullptr, type: " | |||
| << schema::EnumNamePrimitiveType(primitive->value_type()); | |||
| delete node; | |||
| return false; | |||
| } | |||
| node->primitive_ = func_pointer(primitive); | |||
| #endif | |||
| if (node->primitive_ == nullptr) { | |||
| MS_LOG(ERROR) << "unpack primitive == nullptr!"; | |||
| @@ -28,10 +28,10 @@ class OpsRegistry { | |||
| return ®istry; | |||
| } | |||
| void insertPrimitiveCMap(schema::PrimitiveType type, PrimitiveCCreator creator) { | |||
| void InsertPrimitiveCMap(schema::PrimitiveType type, PrimitiveCCreator creator) { | |||
| primitive_creators[type] = creator; | |||
| } | |||
| PrimitiveCCreator getPrimitiveCreator(schema::PrimitiveType type) { | |||
| PrimitiveCCreator GetPrimitiveCreator(schema::PrimitiveType type) { | |||
| if (primitive_creators.find(type) != primitive_creators.end()) { | |||
| return primitive_creators[type]; | |||
| } else { | |||
| @@ -47,7 +47,7 @@ class OpsRegistry { | |||
| class Registry { | |||
| public: | |||
| Registry(schema::PrimitiveType primitive_type, PrimitiveCCreator creator) { | |||
| OpsRegistry::GetInstance()->insertPrimitiveCMap(primitive_type, creator); | |||
| OpsRegistry::GetInstance()->InsertPrimitiveCMap(primitive_type, creator); | |||
| } | |||
| }; | |||
| @@ -18,6 +18,7 @@ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_OP_POPULATE_REGISTER_H | |||
| #include <map> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| @@ -29,9 +30,9 @@ class PopulateRegistry { | |||
| return ®istry; | |||
| } | |||
| void insertParameterMap(schema::PrimitiveType type, ParameterCreator creator) { parameter_creators[type] = creator; } | |||
| void InsertParameterMap(schema::PrimitiveType type, ParameterCreator creator) { parameter_creators[type] = creator; } | |||
| ParameterCreator getParameterCreator(schema::PrimitiveType type) { | |||
| ParameterCreator GetParameterCreator(schema::PrimitiveType type) { | |||
| if (parameter_creators.find(type) != parameter_creators.end()) { | |||
| return parameter_creators[type]; | |||
| } else { | |||
| @@ -47,7 +48,7 @@ class PopulateRegistry { | |||
| class Registry { | |||
| public: | |||
| Registry(schema::PrimitiveType primitive_type, ParameterCreator creator) { | |||
| PopulateRegistry::GetInstance()->insertParameterMap(primitive_type, creator); | |||
| PopulateRegistry::GetInstance()->InsertParameterMap(primitive_type, creator); | |||
| } | |||
| ~Registry() = default; | |||
| }; | |||
| @@ -244,8 +244,14 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An | |||
| auto primitive = lite_primitive.get(); | |||
| MS_ASSERT(primitive != nullptr); | |||
| MS_ASSERT(primitive->Type() != nullptr); | |||
| auto parameter = | |||
| lite::PopulateRegistry::GetInstance()->getParameterCreator(schema::PrimitiveType(primitive->Type()))(primitive); | |||
| auto func_pointer = | |||
| lite::PopulateRegistry::GetInstance()->GetParameterCreator(schema::PrimitiveType(primitive->Type())); | |||
| if (func_pointer == nullptr) { | |||
| MS_LOG(ERROR) << "ParameterCreator function pointer is nullptr, type: " | |||
| << schema::EnumNamePrimitiveType((schema::PrimitiveType)primitive->Type()); | |||
| return nullptr; | |||
| } | |||
| auto parameter = func_pointer(primitive); | |||
| if (parameter == nullptr) { | |||
| MS_LOG(ERROR) << "PopulateParameter return nullptr, type: " | |||