|
|
|
@@ -43,7 +43,7 @@ void TileCPUKernel::ComputeStrides(const int *shape, int *strides, int ndim) { |
|
|
|
} |
|
|
|
|
|
|
|
int TileCPUKernel::ReSize() { |
|
|
|
auto tile_parameter_ = reinterpret_cast<TileParameter *>(op_parameter_); |
|
|
|
tile_parameter_ = reinterpret_cast<TileParameter *>(op_parameter_); |
|
|
|
MS_ASSERT(tile_parameter_); |
|
|
|
if (in_tensors_.size() == kDoubleInputsSize) { |
|
|
|
if (in_tensors_[1]->ElementsNum() > static_cast<int>(in_tensors_[0]->shape().size())) { |
|
|
|
@@ -90,8 +90,6 @@ int SimpleTile(void *cdata, int task_id) { |
|
|
|
|
|
|
|
void TileCPUKernel::FillOneDimTileParam() { |
|
|
|
// check if tile exact one dim |
|
|
|
auto tile_parameter_ = reinterpret_cast<TileParameter *>(op_parameter_); |
|
|
|
MS_ASSERT(tile_parameter_); |
|
|
|
int large_one_multiple_count = 0; |
|
|
|
int multiple; |
|
|
|
int mul_index; |
|
|
|
@@ -114,19 +112,19 @@ void TileCPUKernel::FillOneDimTileParam() { |
|
|
|
} |
|
|
|
|
|
|
|
int TileCPUKernel::SimpleTileImpl(int task_id) { |
|
|
|
auto param = reinterpret_cast<TileParameter *>(op_parameter_); |
|
|
|
MS_ASSERT(param); |
|
|
|
size_t unit = UP_DIV(param->fast_outer_size_, static_cast<size_t>(context_->thread_num_)); |
|
|
|
size_t unit = UP_DIV(tile_parameter_->fast_outer_size_, static_cast<size_t>(context_->thread_num_)); |
|
|
|
if (unit == 0 && task_id > 0) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
size_t begin = unit * static_cast<size_t>(task_id); |
|
|
|
size_t end = MSMIN(begin + unit, param->fast_outer_size_); |
|
|
|
TileSimple(input_addr_, output_addr_, begin, end, param); |
|
|
|
size_t end = MSMIN(begin + unit, tile_parameter_->fast_outer_size_); |
|
|
|
TileSimple(input_addr_, output_addr_, begin, end, tile_parameter_); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
int TileCPUKernel::RunSimpleTile() { |
|
|
|
auto data_type = in_tensors_.at(0)->data_type(); |
|
|
|
tile_parameter_->data_size_ = lite::DataTypeSize(data_type); |
|
|
|
auto ret = ParallelLaunch(static_cast<const lite::InnerContext *>(this->context_)->thread_pool_, SimpleTile, this, |
|
|
|
context_->thread_num_); |
|
|
|
if (ret != RET_OK) { |
|
|
|
|