From e6c7ddb9eade51fba0d80b55726f868b06ca4c1b Mon Sep 17 00:00:00 2001 From: wangyanling Date: Sat, 17 Apr 2021 10:09:35 +0800 Subject: [PATCH] optimize cpu adam op --- .../lite/src/runtime/kernel/arm/base/tile_base.cc | 14 ++++++-------- .../lite/src/runtime/kernel/arm/base/tile_base.h | 5 +++-- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/base/tile_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/tile_base.cc index ba382314c5..02c026e783 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/tile_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/tile_base.cc @@ -43,7 +43,7 @@ void TileCPUKernel::ComputeStrides(const int *shape, int *strides, int ndim) { } int TileCPUKernel::ReSize() { - auto tile_parameter_ = reinterpret_cast(op_parameter_); + tile_parameter_ = reinterpret_cast(op_parameter_); MS_ASSERT(tile_parameter_); if (in_tensors_.size() == kDoubleInputsSize) { if (in_tensors_[1]->ElementsNum() > static_cast(in_tensors_[0]->shape().size())) { @@ -90,8 +90,6 @@ int SimpleTile(void *cdata, int task_id) { void TileCPUKernel::FillOneDimTileParam() { // check if tile exact one dim - auto tile_parameter_ = reinterpret_cast(op_parameter_); - MS_ASSERT(tile_parameter_); int large_one_multiple_count = 0; int multiple; int mul_index; @@ -114,19 +112,19 @@ void TileCPUKernel::FillOneDimTileParam() { } int TileCPUKernel::SimpleTileImpl(int task_id) { - auto param = reinterpret_cast(op_parameter_); - MS_ASSERT(param); - size_t unit = UP_DIV(param->fast_outer_size_, static_cast(context_->thread_num_)); + size_t unit = UP_DIV(tile_parameter_->fast_outer_size_, static_cast(context_->thread_num_)); if (unit == 0 && task_id > 0) { return RET_OK; } size_t begin = unit * static_cast(task_id); - size_t end = MSMIN(begin + unit, param->fast_outer_size_); - TileSimple(input_addr_, output_addr_, begin, end, param); + size_t end = MSMIN(begin + unit, tile_parameter_->fast_outer_size_); + TileSimple(input_addr_, output_addr_, begin, end, tile_parameter_); return RET_OK; } int TileCPUKernel::RunSimpleTile() { + auto data_type = in_tensors_.at(0)->data_type(); + tile_parameter_->data_size_ = lite::DataTypeSize(data_type); auto ret = ParallelLaunch(static_cast(this->context_)->thread_pool_, SimpleTile, this, context_->thread_num_); if (ret != RET_OK) { diff --git a/mindspore/lite/src/runtime/kernel/arm/base/tile_base.h b/mindspore/lite/src/runtime/kernel/arm/base/tile_base.h index 86dc677142..8e6020a6fb 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/tile_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/tile_base.h @@ -38,8 +38,9 @@ class TileCPUKernel : public LiteKernel { void ComputeStrides(const int *shape, int *strides, int ndim); void FillOneDimTileParam(); bool one_dim_tile_; - uint8_t *input_addr_; - uint8_t *output_addr_; + uint8_t *input_addr_ = nullptr; + uint8_t *output_addr_ = nullptr; + TileParameter *tile_parameter_ = nullptr; }; } // namespace mindspore::kernel