@@ -29,12 +29,10 @@ using mindspore::schema::PrimitiveType_ExpandDims;
 namespace mindspore::kernel {
 
 int ExpandDimsCPUKernel::Init() {
-  if (context_->infer_shape_interrupt_ && !context_->running_) {
-    set_need_reinit();
+  if (!InferShapeDone()) {
     return RET_OK;
   }
-  int ret = ReSize();
-  return ret;
+  return ReSize();
 }
 
 int ExpandDimsCPUKernel::ReSize() {
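The same Init()/ReSize() split is applied to every kernel in this change: Init() keeps only the shape-independent work (parameter parsing, sanity checks), returns early with RET_OK while shape inference has not completed, and otherwise falls through to ReSize(); all shape-dependent setup moves into ReSize(), which presumably gets re-invoked once shapes become available or change. A minimal sketch of that contract follows; SomeCPUKernel, InferShapeDone() and RET_OK below are local stand-ins so the sketch compiles on its own (in the patch they come from LiteKernel and the lite error codes):

// Sketch of the Init()/ReSize() contract applied across these CPU kernels.
constexpr int RET_OK = 0;

class SomeCPUKernel {
 public:
  int Init();
  int ReSize();

 private:
  bool InferShapeDone() const { return infer_shape_done_; }
  bool infer_shape_done_ = false;  // set by the framework once shapes are known
};

int SomeCPUKernel::Init() {
  // Shape-independent work (parameter parsing, sanity checks) happens here.
  if (!InferShapeDone()) {
    // Shapes not known yet: succeed now, do the shape-dependent setup later.
    return RET_OK;
  }
  return ReSize();
}

int SomeCPUKernel::ReSize() {
  // Everything that depends on tensor shapes lives here; it must be safe to
  // call repeatedly (free or clear stale buffers before recomputing).
  return RET_OK;
}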
@@ -35,18 +35,19 @@ constexpr int kOutputNum = 1;
 }  // namespace
 
 int FillCPUKernel::Init() {
-  if (context_->infer_shape_interrupt_ && !context_->running_) {
-    set_need_reinit();
+  if (!InferShapeDone()) {
     return RET_OK;
   }
+  return ReSize();
+}
+
+int FillCPUKernel::ReSize() {
   data_size_ = out_tensors_.front()->ElementsNum();
   thread_sz_count_ = MSMIN(thread_count_, data_size_);
   thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
   return RET_OK;
 }
 
-int FillCPUKernel::ReSize() { return RET_OK; }
-
 int FillCPUKernel::DoFill(int task_id) {
   int size = MSMIN(thread_sz_stride_, data_size_ - task_id * thread_sz_stride_);
   if (size <= 0) {
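For reference, the per-thread work split that FillCPUKernel::ReSize() now computes is the usual MSMIN/UP_DIV partitioning: each task handles a stride of UP_DIV(data_size_, thread_sz_count_) elements and the last task gets whatever remains. A standalone sketch of the arithmetic, with plain C++ stand-ins for the MSMIN/UP_DIV macros and made-up sizes for illustration:

#include <algorithm>
#include <cstdio>

// Mirrors MSMIN(a, b) and UP_DIV(a, b) == (a + b - 1) / b from nnacl.
int main() {
  const int data_size = 10;    // e.g. ElementsNum() of the output tensor
  const int thread_count = 4;  // configured thread count
  const int thread_sz_count = std::min(thread_count, data_size);                      // 4
  const int thread_sz_stride = (data_size + thread_sz_count - 1) / thread_sz_count;   // UP_DIV -> 3
  for (int task_id = 0; task_id < thread_sz_count; ++task_id) {
    // Same guard as DoFill(): tasks that start past the end do nothing.
    const int size = std::min(thread_sz_stride, data_size - task_id * thread_sz_stride);
    if (size <= 0) {
      continue;
    }
    std::printf("task %d fills %d elements starting at offset %d\n",
                task_id, size, task_id * thread_sz_stride);
  }
  return 0;  // prints splits of 3, 3, 3, 1 for this example
}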
@@ -32,7 +32,10 @@ namespace mindspore::kernel {
 int GatherCPUKernel::Init() {
   axis_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;
   batchDims_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->batchDims_;
-  return RET_OK;
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
 }
 
 int GatherCPUKernel::ReSize() { return RET_OK; }
@@ -38,10 +38,17 @@ GatherNdCPUKernel::~GatherNdCPUKernel() {
 }
 
 int GatherNdCPUKernel::Init() {
-  if (context_->infer_shape_interrupt_ && !context_->running_) {
-    set_need_reinit();
+  if (!InferShapeDone()) {
     return RET_OK;
   }
+  return ReSize();
+}
+
+int GatherNdCPUKernel::ReSize() {
+  if (in_offset_ != nullptr) {
+    free(in_offset_);
+    in_offset_ = nullptr;
+  }
   auto indices_tensor = in_tensors_.at(1);
   auto indices_shape = indices_tensor->shape();
   int indices_rank = indices_shape.size();
@@ -59,16 +66,9 @@ int GatherNdCPUKernel::Init() {
   thread_sz_count_ = MSMIN(thread_count_, count_);
   thread_sz_stride_ = UP_DIV(count_, thread_sz_count_);
-  int ret = ReSize();
-  return ret;
-}
-
-int GatherNdCPUKernel::ReSize() {
   auto in_shape = in_tensors_.front()->shape();
   int in_rank = in_shape.size();
-  auto indices_tensor = in_tensors_.at(1);
-  auto indices_shape = indices_tensor->shape();
-  int indices_rank = indices_shape.size();
   int idx_lastshape = indices_shape[indices_rank - 1];
   auto indices_ptr = reinterpret_cast<int *>(indices_tensor->Data());
   area_ = 1;
@@ -35,40 +35,49 @@ constexpr size_t kOutputNum = 1;
 }  // namespace
 
 int OneHotCPUKernel::Init() {
-  if (context_->infer_shape_interrupt_ && !context_->running_) {
-    set_need_reinit();
-    return RET_OK;
-  }
   // indices depth on_value off_value
   if (in_tensors_.size() != kInputNum || out_tensors_.size() != kOutputNum) {
     MS_LOG(ERROR) << "OneHot input size should be " << kInputNum << ", got " << in_tensors_.size()
                   << ", output size should be" << kOutputNum << ", got " << out_tensors_.size();
     return RET_ERROR;
   }
+  if (context_ == nullptr) {
+    MS_LOG(ERROR) << "OneHot context nullptr";
+    return RET_NULL_PTR;
+  }
+  thread_num_ = context_->thread_num_;
+  auto param = reinterpret_cast<OneHotParameter *>(op_parameter_);
+  if (param == nullptr) {
+    MS_LOG(ERROR) << "OneHot op_parameter_ nullptr";
+    return RET_NULL_PTR;
+  }
+  axis_ = param->axis_;
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
+}
+
+int OneHotCPUKernel::ReSize() {
   auto indices = in_tensors_.at(0);
   if (indices == nullptr) {
     MS_LOG(ERROR) << "OneHot inputs[0] indices nullptr";
     return RET_NULL_PTR;
   }
   auto indices_shape = indices->shape();
+  const int indices_rank = static_cast<int>(indices_shape.size());
+  if (axis_ < 0) {
+    axis_ += indices_rank + 1;
+  }
   outer_size_ = 1;
   for (size_t i = 0; i < static_cast<size_t>(axis_); i++) {
     outer_size_ *= indices_shape[i];
   }
   inner_size_ = indices->ElementsNum() / outer_size_;
-  if (context_ == nullptr) {
-    MS_LOG(ERROR) << "OneHot context nullptr";
-    return RET_NULL_PTR;
-  }
-  thread_num_ = context_->thread_num_;
-  const int indices_rank = static_cast<int>(in_tensors_.at(0)->shape().size());
-  if (axis_ < 0) {
-    axis_ += indices_rank + 1;
-  }
   return RET_OK;
 }
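A note on the axis handling that moved into ReSize(): OneHot inserts a new depth dimension, so the output has one more dimension than the indices tensor, and a negative axis_ is normalized against that output rank via axis_ += indices_rank + 1. For example (values chosen for illustration), indices of shape {3, 4} with axis_ = -1 normalize to axis_ = 2, giving outer_size_ = 3 * 4 = 12 and inner_size_ = ElementsNum() / outer_size_ = 1, i.e. the one-hot depth becomes the innermost dimension.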
@@ -26,12 +26,12 @@ class OneHotCPUKernel : public LiteKernel {
   OneHotCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                   const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
                   const lite::Primitive *primitive)
-      : LiteKernel(parameter, inputs, outputs, ctx, primitive), context_(ctx) {}
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
   ~OneHotCPUKernel() override = default;
 
   int Init() override;
-  int ReSize() override { return 0; };
+  int ReSize() override;
   int Run() override;
   int OneHotImpl(int task_id);
@@ -39,7 +39,6 @@ class OneHotCPUKernel : public LiteKernel {
   int GetParams();
 
  private:
-  const lite::Context *context_;
   int thread_num_;
   int axis_;
   int outer_size_;
@@ -36,16 +36,19 @@ constexpr int kOutputNum = 1;
 }  // namespace
 
 int PadCPUKernel::Init() {
-  if (context_->infer_shape_interrupt_ && !context_->running_) {
-    set_need_reinit();
-    return RET_OK;
-  }
   if (in_tensors_.size() != kInputNum || out_tensors_.size() != kOutputNum) {
     MS_LOG(ERROR) << "Pad input size should be " << kInputNum << ", got " << in_tensors_.size()
                   << ", output size should be" << kOutputNum << ", got " << out_tensors_.size();
     return RET_ERROR;
   }
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
+}
+
+int PadCPUKernel::ReSize() {
   auto input = in_tensors_.at(0);
   auto output = out_tensors_.at(0);
   if (input == nullptr || output == nullptr) {
@@ -35,7 +35,7 @@ class PadCPUKernel : public LiteKernel {
   ~PadCPUKernel() {}
 
   int Init() override;
-  int ReSize() override { return 0; };
+  int ReSize() override;
   int Run() override;
   int RunImpl(int task_id);
@@ -44,10 +44,7 @@ int ReduceCPUKernel::Init() {
   if (ret != RET_OK) {
     return ret;
   }
-  ret = MallocTmpBuffer();
-  if (ret != RET_OK) {
-    return ret;
-  }
   switch (mode_) {
     case static_cast<int>(ReduceMode_ReduceSum): {
       reducer_ = ReduceSum;
@@ -77,12 +74,15 @@ int ReduceCPUKernel::Init() {
       MS_LOG(ERROR) << "Reduce unsupported reduce mode: " << mode_;
       return RET_ERROR;
   }
 
   if (!InferShapeDone()) {
     return RET_OK;
   }
 
   return ReSize();
 }
 
+int ReduceCPUKernel::ReSize() { return MallocTmpBuffer(); }
+
 int ReduceCPUKernel::CallReduceUnit(int task_id) {
   auto ret = reducer_(outer_size_, inner_size_, axis_size_, src_data_, tmp_shape_.data(), dst_data_, task_id,
                       context_->thread_num_);
@@ -149,6 +149,14 @@ int ReduceCPUKernel::Run() {
 }
 
 int ReduceCPUKernel::MallocTmpBuffer() {
+  for (auto buffer : data_buffers_) {
+    if (buffer != nullptr) {
+      free(buffer);
+      buffer = nullptr;
+    }
+  }
+  data_buffers_.clear();
   auto input_shape = in_tensors_.at(0)->shape();
   for (auto i = 0; i < num_axes_ - 1; i++) {
     int axis = axes_[i];
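Since ReSize() now delegates to MallocTmpBuffer() and can run more than once, the function first releases the buffers left over from the previous call. One detail worth noting: the range-for iterates by value, so buffer = nullptr only clears the local copy; it is the data_buffers_.clear() afterwards that actually drops the stale pointers. If nulling the stored entries were the intent, the loop would need a reference, along the lines of this sketch (FreeBuffers is a hypothetical stand-alone helper, not part of the patch):

#include <cstdlib>
#include <vector>

// Sketch: freeing a vector of malloc'd buffers by reference, so the stored
// pointers really are nulled before the vector is emptied.
void FreeBuffers(std::vector<float *> *data_buffers) {
  for (auto &buffer : *data_buffers) {  // reference: updates the stored pointer
    free(buffer);
    buffer = nullptr;
  }
  data_buffers->clear();
}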
@@ -48,15 +48,15 @@ class ReduceCPUKernel : public ReduceBaseCPUKernel {
   }
 
   int Init() override;
-  int ReSize() override { return 0; };
+  int ReSize() override;
   int Run() override;
   int CallReduceUnit(int task_id);
 
  private:
-  Reducer reducer_;
+  Reducer reducer_ = nullptr;
   std::vector<float *> data_buffers_;
-  const float *src_data_;
-  float *dst_data_;
+  const float *src_data_ = nullptr;
+  float *dst_data_ = nullptr;
 
  private:
   int MallocTmpBuffer();
@@ -38,6 +38,10 @@ int ReverseCPUKernel::Stride(int index) {
 }
 
 int ReverseCPUKernel::ReSize() {
+  data_size_ = in_tensors_.at(0)->ElementsNum();
+  thread_sz_count_ = MSMIN(thread_count_, data_size_);
+  thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
   auto *param = reinterpret_cast<ReverseParameter *>(op_parameter_);
   auto input_shape = in_tensors_[0]->shape();
   if (param->num_axis_ > input_shape.size()) {
@@ -89,13 +93,9 @@ int ReverseCPUKernel::ReSize() {
 }
 
 int ReverseCPUKernel::Init() {
-  if (context_->infer_shape_interrupt_ && !context_->running_) {
-    set_need_reinit();
+  if (!InferShapeDone()) {
     return RET_OK;
   }
-  data_size_ = in_tensors_.at(0)->ElementsNum();
-  thread_sz_count_ = MSMIN(thread_count_, data_size_);
-  thread_sz_stride_ = UP_DIV(data_size_, thread_sz_count_);
   int ret = ReSize();
   return ret;
 }