|
|
|
@@ -62,6 +62,7 @@ class LiteKernel { |
|
|
|
const lite::Primitive *primitive) |
|
|
|
: opParameter(parameter), inputs_(inputs), outputs_(outputs), primitive_(primitive), |
|
|
|
context_(ctx) { |
|
|
|
opParameter->thread_num_ = ctx->thread_num_; |
|
|
|
this->in_kernel_.clear(); |
|
|
|
this->out_kernel_.clear(); |
|
|
|
} |
|
|
|
@@ -69,12 +70,13 @@ class LiteKernel { |
|
|
|
virtual ~LiteKernel() { delete opParameter; } |
|
|
|
|
|
|
|
virtual int Prepare() { |
|
|
|
if (primitive_ != nullptr && !primitive_->GetInferFlag()) { |
|
|
|
if (!InferShapeDone()) { |
|
|
|
(const_cast<lite::Primitive *>(primitive_))->InferShape(inputs_, outputs_); |
|
|
|
if (need_reinit) { |
|
|
|
Init(); |
|
|
|
} |
|
|
|
} |
|
|
|
if (need_reinit) { |
|
|
|
Init(); |
|
|
|
} |
|
|
|
|
|
|
|
auto &outputs = this->GetOutputs(); |
|
|
|
for (auto *output : outputs) { |
|
|
|
MS_ASSERT(output != nullptr); |
|
|
|
@@ -126,6 +128,13 @@ class LiteKernel { |
|
|
|
} |
|
|
|
|
|
|
|
protected: |
|
|
|
bool InferShapeDone() { |
|
|
|
if (primitive_ != nullptr && !primitive_->GetInferFlag()) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
KernelKey desc; |
|
|
|
std::string name; |
|
|
|
OpParameter *opParameter = nullptr; |
|
|
|
|