diff --git a/mindspore/lite/nnacl/fp32/tile.h b/mindspore/lite/nnacl/fp32/tile.h index 967b7ff524..3c2de3fe0d 100644 --- a/mindspore/lite/nnacl/fp32/tile.h +++ b/mindspore/lite/nnacl/fp32/tile.h @@ -24,6 +24,7 @@ typedef struct TileParameter { int in_dim_; int in_shape_[5]; int out_shape_[5]; + int dims_[5]; int multiples_[5]; int in_strides_[5]; int out_strides_[5]; diff --git a/mindspore/lite/src/ops/populate/tile_populate.cc b/mindspore/lite/src/ops/populate/tile_populate.cc index d745331483..8d8bbdd2b5 100644 --- a/mindspore/lite/src/ops/populate/tile_populate.cc +++ b/mindspore/lite/src/ops/populate/tile_populate.cc @@ -31,10 +31,13 @@ OpParameter *PopulateTileParameter(const mindspore::lite::PrimitiveC *primitive) memset(tile_param, 0, sizeof(TileParameter)); tile_param->op_parameter_.type_ = primitive->Type(); auto param = reinterpret_cast(const_cast(primitive)); + auto dims = param->GetDims(); auto multiples = param->GetMultiples(); - tile_param->in_dim_ = multiples.size(); - for (int i = 0; i < tile_param->in_dim_; ++i) { - tile_param->multiples_[i] = multiples[i]; + for (size_t i = 0; i < kDimension_4d; ++i) { + tile_param->multiples_[i] = 1; + } + for (size_t i = 0; i < dims.size(); ++i) { + tile_param->multiples_[dims[i]] = multiples[i]; } return reinterpret_cast(tile_param); } diff --git a/mindspore/lite/src/ops/tile.cc b/mindspore/lite/src/ops/tile.cc index 49e907c9ab..8b18d709fe 100644 --- a/mindspore/lite/src/ops/tile.cc +++ b/mindspore/lite/src/ops/tile.cc @@ -140,18 +140,17 @@ int Tile::InferShape(std::vector inputs_, std::vector output std::vector out_shape; std::vector multiples = GetMultiples(); + std::vector dims = GetDims(); const size_t in_dims = input->shape().size(); - const size_t delta_dims = in_dims - multiples.size(); - size_t i = 0; - for (; i < delta_dims; ++i) { - int tmp = input->shape()[i]; - out_shape.push_back(tmp); + MS_ASSERT(multiples.size() == dims.size()); + for (size_t i = 0; i < in_dims; ++i) { + out_shape.push_back(input->shape()[i]); } - for (; i < in_dims; ++i) { - int tmp = input->shape()[i] * (multiples[i - delta_dims]); - out_shape.push_back(tmp); + for (size_t i = 0; i < dims.size(); ++i) { + out_shape[dims[i]] = input->shape()[dims[i]] * (multiples[i]); } + output->set_shape(out_shape); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tile_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tile_fp32.cc index 76aaff4422..3b212ea887 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/tile_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tile_fp32.cc @@ -41,6 +41,7 @@ void TileCPUKernel::ComputeStrides(const int *shape, int *strides, int ndim) { int TileCPUKernel::ReSize() { auto tile_parameter_ = reinterpret_cast(op_parameter_); + tile_parameter_->in_dim_ = in_tensors_[0]->shape().size(); for (int i = 0; i < tile_parameter_->in_dim_; ++i) { tile_parameter_->in_shape_[i] = in_tensors_[0]->shape()[i]; tile_parameter_->out_shape_[i] = out_tensors_[0]->shape()[i];