Browse Source

add parameter for nchw2nhwc and nhwc2nchw

tags/v1.1.0
hangq 5 years ago
parent
commit
d0213d2da2
1 changed files with 16 additions and 6 deletions
  1. +16
    -6
      mindspore/lite/src/populate_parameter.cc

+ 16
- 6
mindspore/lite/src/populate_parameter.cc View File

@@ -863,25 +863,35 @@ OpParameter *PopulateTopKParameter(const mindspore::lite::PrimitiveC *primitive)
}

OpParameter *PopulateNhwc2NchwParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
TransposeParameter *parameter = reinterpret_cast<TransposeParameter *>(malloc(sizeof(TransposeParameter)));
if (parameter == nullptr) {
MS_LOG(ERROR) << "malloc OpParameter failed.";
return nullptr;
}
memset(parameter, 0, sizeof(OpParameter));
parameter->type_ = primitive->Type();
return parameter;
parameter->op_parameter_.type_ = primitive->Type();
parameter->num_axes_ = 4;
parameter->perm_[0] = 0;
parameter->perm_[1] = 3;
parameter->perm_[2] = 1;
parameter->perm_[3] = 2;
return reinterpret_cast<OpParameter *>(parameter);
}

OpParameter *PopulateNchw2NhwcParameter(const mindspore::lite::PrimitiveC *primitive) {
OpParameter *parameter = reinterpret_cast<OpParameter *>(malloc(sizeof(OpParameter)));
TransposeParameter *parameter = reinterpret_cast<TransposeParameter *>(malloc(sizeof(TransposeParameter)));
if (parameter == nullptr) {
MS_LOG(ERROR) << "malloc OpParameter failed.";
return nullptr;
}
memset(parameter, 0, sizeof(OpParameter));
parameter->type_ = primitive->Type();
return parameter;
parameter->op_parameter_.type_ = primitive->Type();
parameter->num_axes_ = 4;
parameter->perm_[0] = 0;
parameter->perm_[1] = 2;
parameter->perm_[2] = 3;
parameter->perm_[3] = 1;
return reinterpret_cast<OpParameter *>(parameter);
}

OpParameter *PopulateTransposeParameter(const mindspore::lite::PrimitiveC *primitive) {


Loading…
Cancel
Save