Browse Source

optimize cpu adam op

pull/15173/head
wangyanling 4 years ago
parent
commit
e6c7ddb9ea
2 changed files with 9 additions and 10 deletions
  1. +6
    -8
      mindspore/lite/src/runtime/kernel/arm/base/tile_base.cc
  2. +3
    -2
      mindspore/lite/src/runtime/kernel/arm/base/tile_base.h

+ 6
- 8
mindspore/lite/src/runtime/kernel/arm/base/tile_base.cc View File

@@ -43,7 +43,7 @@ void TileCPUKernel::ComputeStrides(const int *shape, int *strides, int ndim) {
}

int TileCPUKernel::ReSize() {
auto tile_parameter_ = reinterpret_cast<TileParameter *>(op_parameter_);
tile_parameter_ = reinterpret_cast<TileParameter *>(op_parameter_);
MS_ASSERT(tile_parameter_);
if (in_tensors_.size() == kDoubleInputsSize) {
if (in_tensors_[1]->ElementsNum() > static_cast<int>(in_tensors_[0]->shape().size())) {
@@ -90,8 +90,6 @@ int SimpleTile(void *cdata, int task_id) {

void TileCPUKernel::FillOneDimTileParam() {
// check if tile exact one dim
auto tile_parameter_ = reinterpret_cast<TileParameter *>(op_parameter_);
MS_ASSERT(tile_parameter_);
int large_one_multiple_count = 0;
int multiple;
int mul_index;
@@ -114,19 +112,19 @@ void TileCPUKernel::FillOneDimTileParam() {
}

int TileCPUKernel::SimpleTileImpl(int task_id) {
auto param = reinterpret_cast<TileParameter *>(op_parameter_);
MS_ASSERT(param);
size_t unit = UP_DIV(param->fast_outer_size_, static_cast<size_t>(context_->thread_num_));
size_t unit = UP_DIV(tile_parameter_->fast_outer_size_, static_cast<size_t>(context_->thread_num_));
if (unit == 0 && task_id > 0) {
return RET_OK;
}
size_t begin = unit * static_cast<size_t>(task_id);
size_t end = MSMIN(begin + unit, param->fast_outer_size_);
TileSimple(input_addr_, output_addr_, begin, end, param);
size_t end = MSMIN(begin + unit, tile_parameter_->fast_outer_size_);
TileSimple(input_addr_, output_addr_, begin, end, tile_parameter_);
return RET_OK;
}

int TileCPUKernel::RunSimpleTile() {
auto data_type = in_tensors_.at(0)->data_type();
tile_parameter_->data_size_ = lite::DataTypeSize(data_type);
auto ret = ParallelLaunch(static_cast<const lite::InnerContext *>(this->context_)->thread_pool_, SimpleTile, this,
context_->thread_num_);
if (ret != RET_OK) {


+ 3
- 2
mindspore/lite/src/runtime/kernel/arm/base/tile_base.h View File

@@ -38,8 +38,9 @@ class TileCPUKernel : public LiteKernel {
void ComputeStrides(const int *shape, int *strides, int ndim);
void FillOneDimTileParam();
bool one_dim_tile_;
uint8_t *input_addr_;
uint8_t *output_addr_;
uint8_t *input_addr_ = nullptr;
uint8_t *output_addr_ = nullptr;
TileParameter *tile_parameter_ = nullptr;
};
} // namespace mindspore::kernel



Loading…
Cancel
Save