diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc index 702046c65f..c885786d76 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc @@ -30,10 +30,6 @@ using mindspore::schema::PrimitiveType_Transpose; namespace mindspore::kernel { int TransposeCPUKernel::Init() { - TransposeParameter *param = reinterpret_cast(this->op_parameter_); - num_unit_ = static_cast(in_tensors_[kInputIndex]->shape().at(param->perm_[kNHWC_H])); - thread_h_num_ = MSMIN(thread_num_, num_unit_); - thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_); if (!InferShapeDone()) { return RET_OK; } @@ -41,9 +37,13 @@ int TransposeCPUKernel::Init() { } int TransposeCPUKernel::ReSize() { + TransposeParameter *param = reinterpret_cast(op_parameter_); + num_unit_ = static_cast(in_tensors_[kInputIndex]->shape().at(param->perm_[kNHWC_H])); + thread_h_num_ = MSMIN(thread_num_, num_unit_); + thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_); + auto &inTensor = in_tensors_.front(); auto &outTensor = out_tensors_.front(); - auto param = reinterpret_cast(op_parameter_); auto in_shape = inTensor->shape(); auto out_shape = outTensor->shape(); param->strides_[param->num_axes_ - 1] = 1;