diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 4a966dfe84..59983c1974 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -863,25 +863,35 @@ OpParameter *PopulateTopKParameter(const mindspore::lite::PrimitiveC *primitive) } OpParameter *PopulateNhwc2NchwParameter(const mindspore::lite::PrimitiveC *primitive) { - OpParameter *parameter = reinterpret_cast(malloc(sizeof(OpParameter))); + TransposeParameter *parameter = reinterpret_cast(malloc(sizeof(TransposeParameter))); if (parameter == nullptr) { MS_LOG(ERROR) << "malloc OpParameter failed."; return nullptr; } memset(parameter, 0, sizeof(OpParameter)); - parameter->type_ = primitive->Type(); - return parameter; + parameter->op_parameter_.type_ = primitive->Type(); + parameter->num_axes_ = 4; + parameter->perm_[0] = 0; + parameter->perm_[1] = 3; + parameter->perm_[2] = 1; + parameter->perm_[3] = 2; + return reinterpret_cast(parameter); } OpParameter *PopulateNchw2NhwcParameter(const mindspore::lite::PrimitiveC *primitive) { - OpParameter *parameter = reinterpret_cast(malloc(sizeof(OpParameter))); + TransposeParameter *parameter = reinterpret_cast(malloc(sizeof(TransposeParameter))); if (parameter == nullptr) { MS_LOG(ERROR) << "malloc OpParameter failed."; return nullptr; } memset(parameter, 0, sizeof(OpParameter)); - parameter->type_ = primitive->Type(); - return parameter; + parameter->op_parameter_.type_ = primitive->Type(); + parameter->num_axes_ = 4; + parameter->perm_[0] = 0; + parameter->perm_[1] = 2; + parameter->perm_[2] = 3; + parameter->perm_[3] = 1; + return reinterpret_cast(parameter); } OpParameter *PopulateTransposeParameter(const mindspore::lite::PrimitiveC *primitive) {