Browse Source

transpose put shape related to ReSize

tags/v1.0.0
zhaozhenlong 5 years ago
parent
commit
b2d43bbe1b
1 changed files with 5 additions and 5 deletions
  1. +5
    -5
      mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc

+ 5
- 5
mindspore/lite/src/runtime/kernel/arm/fp32/transpose.cc View File

@@ -30,10 +30,6 @@ using mindspore::schema::PrimitiveType_Transpose;

namespace mindspore::kernel {
int TransposeCPUKernel::Init() {
TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
num_unit_ = static_cast<int>(in_tensors_[kInputIndex]->shape().at(param->perm_[kNHWC_H]));
thread_h_num_ = MSMIN(thread_num_, num_unit_);
thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_);
if (!InferShapeDone()) {
return RET_OK;
}
@@ -41,9 +37,13 @@ int TransposeCPUKernel::Init() {
}

int TransposeCPUKernel::ReSize() {
TransposeParameter *param = reinterpret_cast<TransposeParameter *>(op_parameter_);
num_unit_ = static_cast<int>(in_tensors_[kInputIndex]->shape().at(param->perm_[kNHWC_H]));
thread_h_num_ = MSMIN(thread_num_, num_unit_);
thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_);

auto &inTensor = in_tensors_.front();
auto &outTensor = out_tensors_.front();
auto param = reinterpret_cast<TransposeParameter *>(op_parameter_);
auto in_shape = inTensor->shape();
auto out_shape = outTensor->shape();
param->strides_[param->num_axes_ - 1] = 1;


Loading…
Cancel
Save