Browse Source

!8639 add dims parameter for tile ops

From: @liuwenhao4
Reviewed-by: @HilbertDavid,@zhanghaibo5,@zhang_xue_tong
Signed-off-by:
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
b4773004fb
4 changed files with 15 additions and 11 deletions
  1. +1
    -0
      mindspore/lite/nnacl/fp32/tile.h
  2. +6
    -3
      mindspore/lite/src/ops/populate/tile_populate.cc
  3. +7
    -8
      mindspore/lite/src/ops/tile.cc
  4. +1
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/tile_fp32.cc

+ 1
- 0
mindspore/lite/nnacl/fp32/tile.h View File

@@ -24,6 +24,7 @@ typedef struct TileParameter {
int in_dim_; int in_dim_;
int in_shape_[5]; int in_shape_[5];
int out_shape_[5]; int out_shape_[5];
int dims_[5];
int multiples_[5]; int multiples_[5];
int in_strides_[5]; int in_strides_[5];
int out_strides_[5]; int out_strides_[5];


+ 6
- 3
mindspore/lite/src/ops/populate/tile_populate.cc View File

@@ -31,10 +31,13 @@ OpParameter *PopulateTileParameter(const mindspore::lite::PrimitiveC *primitive)
memset(tile_param, 0, sizeof(TileParameter)); memset(tile_param, 0, sizeof(TileParameter));
tile_param->op_parameter_.type_ = primitive->Type(); tile_param->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::Tile *>(const_cast<mindspore::lite::PrimitiveC *>(primitive)); auto param = reinterpret_cast<mindspore::lite::Tile *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
auto dims = param->GetDims();
auto multiples = param->GetMultiples(); auto multiples = param->GetMultiples();
tile_param->in_dim_ = multiples.size();
for (int i = 0; i < tile_param->in_dim_; ++i) {
tile_param->multiples_[i] = multiples[i];
for (size_t i = 0; i < kDimension_4d; ++i) {
tile_param->multiples_[i] = 1;
}
for (size_t i = 0; i < dims.size(); ++i) {
tile_param->multiples_[dims[i]] = multiples[i];
} }
return reinterpret_cast<OpParameter *>(tile_param); return reinterpret_cast<OpParameter *>(tile_param);
} }


+ 7
- 8
mindspore/lite/src/ops/tile.cc View File

@@ -140,18 +140,17 @@ int Tile::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output


std::vector<int> out_shape; std::vector<int> out_shape;
std::vector<int> multiples = GetMultiples(); std::vector<int> multiples = GetMultiples();
std::vector<int> dims = GetDims();
const size_t in_dims = input->shape().size(); const size_t in_dims = input->shape().size();
const size_t delta_dims = in_dims - multiples.size();


size_t i = 0;
for (; i < delta_dims; ++i) {
int tmp = input->shape()[i];
out_shape.push_back(tmp);
MS_ASSERT(multiples.size() == dims.size());
for (size_t i = 0; i < in_dims; ++i) {
out_shape.push_back(input->shape()[i]);
} }
for (; i < in_dims; ++i) {
int tmp = input->shape()[i] * (multiples[i - delta_dims]);
out_shape.push_back(tmp);
for (size_t i = 0; i < dims.size(); ++i) {
out_shape[dims[i]] = input->shape()[dims[i]] * (multiples[i]);
} }

output->set_shape(out_shape); output->set_shape(out_shape);
return RET_OK; return RET_OK;
} }


+ 1
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/tile_fp32.cc View File

@@ -41,6 +41,7 @@ void TileCPUKernel::ComputeStrides(const int *shape, int *strides, int ndim) {


int TileCPUKernel::ReSize() { int TileCPUKernel::ReSize() {
auto tile_parameter_ = reinterpret_cast<TileParameter *>(op_parameter_); auto tile_parameter_ = reinterpret_cast<TileParameter *>(op_parameter_);
tile_parameter_->in_dim_ = in_tensors_[0]->shape().size();
for (int i = 0; i < tile_parameter_->in_dim_; ++i) { for (int i = 0; i < tile_parameter_->in_dim_; ++i) {
tile_parameter_->in_shape_[i] = in_tensors_[0]->shape()[i]; tile_parameter_->in_shape_[i] = in_tensors_[0]->shape()[i];
tile_parameter_->out_shape_[i] = out_tensors_[0]->shape()[i]; tile_parameter_->out_shape_[i] = out_tensors_[0]->shape()[i];


Loading…
Cancel
Save