|
|
|
@@ -104,14 +104,13 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) { |
|
|
|
return parameter; |
|
|
|
} |
|
|
|
kernel::LiteKernel *GetLiteKernel(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs, OpParameter *parameter, |
|
|
|
mindspore::lite::PrimitiveC *primitive) { |
|
|
|
lite::Context *context, mindspore::lite::PrimitiveC *primitive) { |
|
|
|
MS_ASSERT(nullptr != lite_primitive); |
|
|
|
auto data_type = inputs.front()->data_type(); |
|
|
|
kernel::KernelKey desc{kernel::KERNEL_ARCH::kCPU, data_type, (schema::PrimitiveType)primitive->Type()}; |
|
|
|
lite::Context context; |
|
|
|
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc); |
|
|
|
if (creator != nullptr) { |
|
|
|
auto lite_kernel = creator(inputs, outputs, parameter, &context, desc, primitive); |
|
|
|
auto lite_kernel = creator(inputs, outputs, parameter, context, desc, primitive); |
|
|
|
return lite_kernel; |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
@@ -235,7 +234,8 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An |
|
|
|
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type())); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, lite_primitive.get()); |
|
|
|
lite::Context context; |
|
|
|
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, &context, lite_primitive.get()); |
|
|
|
if (lite_kernel == nullptr) { |
|
|
|
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr"; |
|
|
|
FreeTensors(&input_tensors, &output_tensors); |
|
|
|
|