| @@ -30,10 +30,6 @@ using mindspore::schema::PrimitiveType_Transpose; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int TransposeCPUKernel::Init() { | int TransposeCPUKernel::Init() { | ||||
| TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_); | |||||
| num_unit_ = static_cast<int>(in_tensors_[kInputIndex]->shape().at(param->perm_[kNHWC_H])); | |||||
| thread_h_num_ = MSMIN(thread_num_, num_unit_); | |||||
| thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_); | |||||
| if (!InferShapeDone()) { | if (!InferShapeDone()) { | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
| @@ -41,9 +37,13 @@ int TransposeCPUKernel::Init() { | |||||
| } | } | ||||
| int TransposeCPUKernel::ReSize() { | int TransposeCPUKernel::ReSize() { | ||||
| TransposeParameter *param = reinterpret_cast<TransposeParameter *>(op_parameter_); | |||||
| num_unit_ = static_cast<int>(in_tensors_[kInputIndex]->shape().at(param->perm_[kNHWC_H])); | |||||
| thread_h_num_ = MSMIN(thread_num_, num_unit_); | |||||
| thread_h_stride_ = UP_DIV(num_unit_, thread_h_num_); | |||||
| auto &inTensor = in_tensors_.front(); | auto &inTensor = in_tensors_.front(); | ||||
| auto &outTensor = out_tensors_.front(); | auto &outTensor = out_tensors_.front(); | ||||
| auto param = reinterpret_cast<TransposeParameter *>(op_parameter_); | |||||
| auto in_shape = inTensor->shape(); | auto in_shape = inTensor->shape(); | ||||
| auto out_shape = outTensor->shape(); | auto out_shape = outTensor->shape(); | ||||
| param->strides_[param->num_axes_ - 1] = 1; | param->strides_[param->num_axes_ - 1] = 1; | ||||