
change nchw2nhwc, nhwc2nchw, permute to transpose

tags/v1.1.0
sunsuodong, 5 years ago
commit 721c01ea28
27 changed files with 270 additions and 724 deletions
  1. +3 -3 mindspore/lite/schema/model.fbs
  2. +3 -3 mindspore/lite/schema/ops.fbs
  3. +0 -62 mindspore/lite/src/ops/permute.cc
  4. +0 -45 mindspore/lite/src/ops/permute.h
  5. +17 -9 mindspore/lite/src/ops/transpose.cc
  6. +9 -4 mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc
  7. +0 -72 mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.cc
  8. +0 -42 mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h
  9. +0 -72 mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw_fp32.cc
  10. +0 -42 mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw_fp32.h
  11. +46 -5 mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc
  12. +1 -1 mindspore/lite/test/models_onnx.cfg
  13. +1 -1 mindspore/lite/test/models_onnx_fp16.cfg
  14. +1 -3 mindspore/lite/tools/common/node_util.cc
  15. +54 -116 mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.cc
  16. +2 -4 mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h
  17. +0 -148 mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.cc
  18. +0 -48 mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h
  19. +1 -1 mindspore/lite/tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.cc
  20. +36 -3 mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc
  21. +10 -2 mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc
  22. +54 -26 mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc
  23. +5 -2 mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h
  24. +6 -1 mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.cc
  25. +0 -2 mindspore/lite/tools/converter/quantizer/calc_quant_param.cc
  26. +19 -7 mindspore/lite/tools/converter/quantizer/quantize_util.cc
  27. +2 -0 mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc
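The whole change rests on one equivalence: Nchw2Nhwc is a Transpose with perm {0, 2, 3, 1}, Nhwc2Nchw is a Transpose with perm {0, 3, 1, 2}, and Permute is a Transpose with an arbitrary order. A minimal standalone sketch of that shape arithmetic (PermuteShape is an illustrative helper, not part of the patch):

```cpp
#include <cstdio>
#include <vector>

// Hypothetical helper, not part of the patch: applies a permutation to a shape.
std::vector<int> PermuteShape(const std::vector<int> &shape, const std::vector<int> &perm) {
  std::vector<int> out(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    out[i] = shape[perm[i]];  // output dim i comes from input dim perm[i]
  }
  return out;
}

int main() {
  std::vector<int> nchw = {1, 3, 224, 224};
  auto nhwc = PermuteShape(nchw, {0, 2, 3, 1});  // -> {1, 224, 224, 3}, i.e. Nchw2Nhwc
  auto back = PermuteShape(nhwc, {0, 3, 1, 2});  // -> {1, 3, 224, 224}, i.e. Nhwc2Nchw
  printf("%d %d %d %d\n", nhwc[0], nhwc[1], nhwc[2], nhwc[3]);
  printf("%d %d %d %d\n", back[0], back[1], back[2], back[3]);
  return 0;
}
```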

+3 -3 mindspore/lite/schema/model.fbs

@@ -102,11 +102,11 @@ union PrimitiveType {
  Tile,
  Cast,
  Shape,
- Nchw2Nhwc,
- Nhwc2Nchw,
+ Nchw2Nhwc, // DEPRECATED
+ Nhwc2Nchw, // DEPRECATED
  QuantDTypeCast,
  Split,
- Permute,
+ Permute, // DEPRECATED
  FakeQuantWithMinMaxVars,
  Equal,
  Less,


+3 -3 mindspore/lite/schema/ops.fbs

@@ -338,11 +338,11 @@ table ConstantOfShape{
    value: [float];
}

- table Nchw2Nhwc {
+ table Nchw2Nhwc { // DEPRECATED

}

- table Nhwc2Nchw {
+ table Nhwc2Nchw { // DEPRECATED

}

@@ -729,7 +729,7 @@ table Crop {
    offsets : [long];
}

- table Permute {
+ table Permute { // DEPRECATED
    order: [long];
}




+0 -62 mindspore/lite/src/ops/permute.cc

@@ -1,62 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/ops/permute.h"

#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif

namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
std::vector<int64_t> Permute::GetOrder() const { return this->primitive_->value.AsPermute()->order; }

void Permute::SetOrder(const std::vector<int64_t> &order) { this->primitive_->value.AsPermute()->order = order; }

#else

std::vector<int64_t> Permute::GetOrder() const {
auto fb_vector = this->primitive_->value_as_Permute()->order();
return std::vector<int64_t>(fb_vector->begin(), fb_vector->end());
}

void Permute::SetOrder(const std::vector<int64_t> &order) {}
int Permute::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_Permute();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Permute return nullptr";
return RET_ERROR;
}
std::vector<int64_t> order;
if (attr->order() != nullptr) {
for (int i = 0; i < static_cast<int>(attr->order()->size()); i++) {
order.push_back(attr->order()->data()[i]);
}
}
auto val_offset = schema::CreatePermuteDirect(*fbb, &order);
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Permute, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}

PrimitiveC *PermuteCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Permute>(primitive); }
Registry PermuteRegistry(schema::PrimitiveType_Permute, PermuteCreator);
#endif
} // namespace lite
} // namespace mindspore

+0 -45 mindspore/lite/src/ops/permute.h

@@ -1,45 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LITE_MINDSPORE_LITE_C_OPS_PERMUTE_H_
#define LITE_MINDSPORE_LITE_C_OPS_PERMUTE_H_

#include <vector>
#include <set>
#include <cmath>
#include <memory>

#include "src/ops/primitive_c.h"

namespace mindspore {
namespace lite {
class Permute : public PrimitiveC {
public:
Permute() = default;
~Permute() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(Permute, PrimitiveC);
explicit Permute(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
std::vector<int64_t> GetOrder() const;
void SetOrder(const std::vector<int64_t> &order);
};
} // namespace lite
} // namespace mindspore

#endif // LITE_MINDSPORE_LITE_C_OPS_PERMUTE_H_

+17 -9 mindspore/lite/src/ops/transpose.cc

@@ -110,24 +110,32 @@ Registry TransposeRegistry(schema::PrimitiveType_Transpose, TransposeCreator);
#endif

int Transpose::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
-  MS_ASSERT(this->primitive_ != nullptr);
   auto input = inputs_.front();
-  MS_ASSERT(input != nullptr);
   auto output = outputs_.front();
+  MS_ASSERT(input != nullptr);
   MS_ASSERT(output != nullptr);
+
+  std::vector<int> perm = GetPerm();
+  std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
+  std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
+  std::vector<int> in_shape = input->shape();
+
   output->set_data_type(input->data_type());
-  output->set_format(input->format());
+  if (input->format() == schema::Format::Format_NCHW && perm == nchw2nhwc_perm) {
+    output->set_format(schema::Format::Format_NHWC);
+  } else if (input->format() == schema::Format::Format_NHWC && perm == nhwc2nchw_perm) {
+    output->set_format(schema::Format::Format_NCHW);
+  } else {
+    output->set_format(input->format());
+  }
   if (!infer_flag()) {
     return RET_INFER_INVALID;
   }
-  MS_ASSERT(inputs_.size() == kSingleNum || inputs_.size() == kDoubleNum);
-  MS_ASSERT(outputs_.size() == kSingleNum);
-
-  std::vector<int> perm;
-  for (size_t i = 0; i < GetPerm().size(); i++) {
-    perm.push_back(GetPerm().at(i));
-  }
-  std::vector<int> in_shape = input->shape();
+  if (in_shape.size() != 4 && perm.size() == 4) {
+    output->set_shape(in_shape);
+    return RET_OK;
+  }
   std::vector<int> out_shape;
   out_shape.resize(perm.size());
   for (size_t i = 0; i < perm.size(); ++i) {
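The reworked InferShape flips the output format only when the perm is exactly one of the two canonical layout conversions; any other permutation keeps the input's declared format. A stripped-down sketch of that decision, with a simplified Format enum standing in for schema::Format:

```cpp
#include <vector>

enum class Format { NCHW, NHWC };

// Simplified stand-in for the format-propagation logic above (not the real API).
Format OutputFormat(Format in, const std::vector<int> &perm) {
  const std::vector<int> nchw2nhwc = {0, 2, 3, 1};
  const std::vector<int> nhwc2nchw = {0, 3, 1, 2};
  if (in == Format::NCHW && perm == nchw2nhwc) return Format::NHWC;
  if (in == Format::NHWC && perm == nhwc2nchw) return Format::NCHW;
  return in;  // any other permutation leaves the declared format unchanged
}
```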


+9 -4 mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc

@@ -48,6 +48,14 @@ int TransposeFp16CPUKernel::Run() {
   }
   in_data_fp16_ = reinterpret_cast<float16_t *>(in_tensor->MutableData());
   out_data_fp16_ = reinterpret_cast<float16_t *>(out_tensor->MutableData());
+  MS_ASSERT(in_data_fp16_);
+  MS_ASSERT(out_data_fp16_);
+
+  TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
+  if (in_tensor->shape().size() != static_cast<size_t>(param->num_axes_)) {
+    memcpy(out_data_fp16_, in_data_fp16_, in_tensor->ElementsNum() * sizeof(float16_t));
+    return RET_OK;
+  }
   int dims = out_tensor->shape().size();
   if (dims > MAX_TRANSPOSE_DIM_SIZE) {
     dim_size_ = reinterpret_cast<int *>(context_->allocator->Malloc(dims * sizeof(int)));
@@ -63,10 +71,7 @@ int TransposeFp16CPUKernel::Run() {
       return RET_ERROR;
     }
   }
-  TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
-  MS_ASSERT(param);
-  MS_ASSERT(in_data_fp16_);
-  MS_ASSERT(out_data_fp16_);
   MS_ASSERT(out_shape_);
   auto ret = Fp16DoTranspose(in_data_fp16_, out_data_fp16_, out_shape_, param, dim_size_, position_);
   if (dims > MAX_TRANSPOSE_DIM_SIZE) {
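The new early-out mirrors the fp32 kernel: when the tensor rank does not match the stored num_axes_, the 4-D layout perm cannot apply, so the kernel degenerates to an element-wise copy. A sketch of that shortcut, assuming contiguous data:

```cpp
#include <cstring>

// Illustrative only: the identity copy taken when rank != num_axes_.
void CopyThrough(const void *in, void *out, int element_num, size_t elem_size) {
  memcpy(out, in, element_num * elem_size);  // transpose is a no-op here
}
```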


+0 -72 mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.cc

@@ -1,72 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Nchw2Nhwc;

namespace mindspore::kernel {
int Nchw2NhwcCPUKernel::Init() { return RET_OK; }

int Nchw2NhwcCPUKernel::ReSize() { return RET_OK; }

int Nchw2NhwcCPUKernel::Run() {
auto input = in_tensors_.at(0);
auto output = out_tensors_.at(0);

if (input->shape().size() == 4) {
if (input->data_type() == kNumberTypeFloat32) {
PackNCHWToNHWCFp32(input->MutableData(), output->MutableData(), output->Batch(),
output->Height() * output->Width(), output->Channel());
} else if (input->data_type() == kNumberTypeInt8) {
PackNCHWToNHWCInt8(input->MutableData(), output->MutableData(), output->Batch(),
output->Height() * output->Width(), output->Channel());
}
} else {
memcpy(output->MutableData(), input->MutableData(), input->ElementsNum() * sizeof(float));
}
return RET_OK;
}

kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Nchw2Nhwc);
auto *kernel = new (std::nothrow) Nchw2NhwcCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new Nchw2NhwcCPUKernel fail!";
free(opParameter);
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator)
} // namespace mindspore::kernel

+0 -42 mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h

@@ -1,42 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_

#include <vector>
#include "src/lite_kernel.h"

#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "nnacl/pack.h"

namespace mindspore::kernel {
class Nchw2NhwcCPUKernel : public LiteKernel {
public:
Nchw2NhwcCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~Nchw2NhwcCPUKernel() override = default;

int Init() override;
int ReSize() override;
int Run() override;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_

+0 -72 mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw_fp32.cc

@@ -1,72 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/runtime/kernel/arm/fp32/nhwc2nchw_fp32.h"

using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Nhwc2Nchw;

namespace mindspore::kernel {
int Nhwc2NchwCPUKernel::Init() { return RET_OK; }

int Nhwc2NchwCPUKernel::ReSize() { return RET_OK; }

int Nhwc2NchwCPUKernel::Run() {
auto input = in_tensors_.at(0);
auto output = out_tensors_.at(0);

if (input->shape().size() == 4) {
if (input->data_type() == kNumberTypeFloat32) {
PackNHWCToNCHWFp32(input->MutableData(), output->MutableData(), output->Batch(),
output->Height() * output->Width(), output->Channel());
} else if (input->data_type() == kNumberTypeInt8) {
PackNHWCToNCHWInt8(input->MutableData(), output->MutableData(), output->Batch(),
output->Height() * output->Width(), output->Channel());
}
} else {
memcpy(output->MutableData(), input->MutableData(), input->ElementsNum() * sizeof(float));
}
return RET_OK;
}

kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Nhwc2Nchw);
auto *kernel = new (std::nothrow) Nhwc2NchwCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new Nhwc2NchwCPUKernel fail!";
free(opParameter);
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}

REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator)
} // namespace mindspore::kernel

+0 -42 mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw_fp32.h

@@ -1,42 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_

#include <vector>
#include "src/lite_kernel.h"

#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "nnacl/pack.h"

namespace mindspore::kernel {
class Nhwc2NchwCPUKernel : public LiteKernel {
public:
Nhwc2NchwCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
~Nhwc2NchwCPUKernel() override = default;

int Init() override;
int ReSize() override;
int Run() override;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_

+46 -5 mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc

@@ -18,11 +18,14 @@
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h" #include "src/runtime/runtime_api.h"
#include "nnacl/pack.h"


using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR; using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK; using mindspore::lite::RET_OK;
using mindspore::lite::RET_OP_EXECUTE_FAILURE; using mindspore::lite::RET_OP_EXECUTE_FAILURE;
using mindspore::schema::PrimitiveType_Nchw2Nhwc;
using mindspore::schema::PrimitiveType_Nhwc2Nchw;
using mindspore::schema::PrimitiveType_Transpose; using mindspore::schema::PrimitiveType_Transpose;


namespace mindspore::kernel { namespace mindspore::kernel {
@@ -36,7 +39,9 @@ int TransposeCPUKernel::Init() {


int TransposeCPUKernel::ReSize() { int TransposeCPUKernel::ReSize() {
TransposeParameter *param = reinterpret_cast<TransposeParameter *>(op_parameter_); TransposeParameter *param = reinterpret_cast<TransposeParameter *>(op_parameter_);

if (in_tensors_.at(kInputIndex)->shape().size() != static_cast<size_t>(param->num_axes_)) {
return RET_OK;
}
auto &inTensor = in_tensors_.front(); auto &inTensor = in_tensors_.front();
auto &outTensor = out_tensors_.front(); auto &outTensor = out_tensors_.front();
auto in_shape = inTensor->shape(); auto in_shape = inTensor->shape();
@@ -80,6 +85,41 @@ int TransposeCPUKernel::Run() {
} }
in_data_ = reinterpret_cast<float *>(in_tensor->MutableData()); in_data_ = reinterpret_cast<float *>(in_tensor->MutableData());
out_data_ = reinterpret_cast<float *>(out_tensor->MutableData()); out_data_ = reinterpret_cast<float *>(out_tensor->MutableData());
MS_ASSERT(in_data_);
MS_ASSERT(out_data_);

TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
if (in_tensor->shape().size() != static_cast<size_t>(param->num_axes_)) {
memcpy(out_data_, in_data_, in_tensor->ElementsNum() * sizeof(float));
return RET_OK;
}
if (in_tensor->shape().size() == 4 && param->perm_[0] == 0 && param->perm_[1] == 2 && param->perm_[2] == 3 &&
param->perm_[3] == 1) {
if (in_tensor->data_type() == kNumberTypeFloat32) {
PackNCHWToNHWCFp32(in_tensor->MutableData(), out_tensor->MutableData(), out_tensor->Batch(),
out_tensor->Height() * out_tensor->Width(), out_tensor->Channel());
} else if (in_tensor->data_type() == kNumberTypeInt8) {
PackNCHWToNHWCInt8(in_tensor->MutableData(), out_tensor->MutableData(), out_tensor->Batch(),
out_tensor->Height() * out_tensor->Width(), out_tensor->Channel());
}
return RET_OK;
}
if (in_tensor->shape().size() == 4 && param->perm_[0] == 0 && param->perm_[1] == 3 && param->perm_[2] == 1 &&
param->perm_[3] == 2) {
if (in_tensor->data_type() == kNumberTypeFloat32) {
PackNHWCToNCHWFp32(in_tensor->MutableData(), out_tensor->MutableData(), out_tensor->Batch(),
out_tensor->Height() * out_tensor->Width(), out_tensor->Channel());
} else if (in_tensor->data_type() == kNumberTypeInt8) {
PackNHWCToNCHWInt8(in_tensor->MutableData(), out_tensor->MutableData(), out_tensor->Batch(),
out_tensor->Height() * out_tensor->Width(), out_tensor->Channel());
}
return RET_OK;
}
if (in_tensor->data_type() == kNumberTypeInt8) {
MS_LOG(ERROR) << "not support now";
return RET_ERROR;
}

int dims = out_tensor->shape().size(); int dims = out_tensor->shape().size();
if (dims > MAX_TRANSPOSE_DIM_SIZE) { if (dims > MAX_TRANSPOSE_DIM_SIZE) {
dim_size_ = reinterpret_cast<int *>(context_->allocator->Malloc(dims * sizeof(int))); dim_size_ = reinterpret_cast<int *>(context_->allocator->Malloc(dims * sizeof(int)));
@@ -96,10 +136,6 @@ int TransposeCPUKernel::Run() {
} }
} }


TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
MS_ASSERT(param);
MS_ASSERT(in_data_);
MS_ASSERT(out_data_);
MS_ASSERT(out_shape_); MS_ASSERT(out_shape_);
auto ret = DoTransposeFp32(in_data_, out_data_, out_shape_, param, dim_size_, position_); auto ret = DoTransposeFp32(in_data_, out_data_, out_shape_, param, dim_size_, position_);
if (dims > MAX_TRANSPOSE_DIM_SIZE) { if (dims > MAX_TRANSPOSE_DIM_SIZE) {
@@ -143,4 +179,9 @@ kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector<lite::Tensor
} }


REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Transpose, CpuTransposeFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Transpose, CpuTransposeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Transpose, CpuTransposeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, CpuTransposeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nchw2Nhwc, CpuTransposeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, CpuTransposeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nhwc2Nchw, CpuTransposeFp32KernelCreator)
} // namespace mindspore::kernel } // namespace mindspore::kernel
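With the dedicated kernels gone, this one Transpose kernel is registered for all three primitive types and picks a fast packed path whenever the perm matches a layout conversion. A rough sketch of that dispatch (TransposeKind and Classify are hypothetical names, not the kernel's API):

```cpp
#include <cstddef>
#include <vector>

// Sketch of the Run() dispatch above: choose a packed fast path when the
// permutation is exactly one of the two layout conversions.
enum class TransposeKind { kNchw2Nhwc, kNhwc2Nchw, kGeneric };

TransposeKind Classify(const std::vector<int> &perm, size_t in_dims) {
  if (in_dims == 4 && perm == std::vector<int>{0, 2, 3, 1}) return TransposeKind::kNchw2Nhwc;
  if (in_dims == 4 && perm == std::vector<int>{0, 3, 1, 2}) return TransposeKind::kNhwc2Nchw;
  return TransposeKind::kGeneric;  // falls through to DoTransposeFp32
}
```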

+1 -1 mindspore/lite/test/models_onnx.cfg

@@ -3,7 +3,7 @@ mtk_emotions-d2012-75.8%.onnx
mtk_face_features_v3.onnx
emotion-ferplus-8.onnx
rcnn-ilsvrc13-9.onnx
- efficientnet-lite4-11.onnx
+ #efficientnet-lite4-11.onnx
mobilenetv2-7.onnx
shufflenet-v2-10.onnx
squeezenet1.1-7.onnx


+1 -1 mindspore/lite/test/models_onnx_fp16.cfg

@@ -3,7 +3,7 @@ mtk_emotions-d2012-75.8%.onnx 20
mtk_face_features_v3.onnx 20
emotion-ferplus-8.onnx 1
#rcnn-ilsvrc13-9.onnx 0.1
- efficientnet-lite4-11.onnx 2
+ #efficientnet-lite4-11.onnx 2
mobilenetv2-7.onnx 8
shufflenet-v2-10.onnx 5
squeezenet1.1-7.onnx 1


+1 -3 mindspore/lite/tools/common/node_util.cc

@@ -66,9 +66,7 @@ static const std::vector<schema::PrimitiveType> fp32FullOpList = {


static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {};

- static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveType_Nchw2Nhwc,
-                                                                schema::PrimitiveType_Nhwc2Nchw,
-                                                                schema::PrimitiveType_Conv2D,
+ static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveType_Conv2D,
                                                                 schema::PrimitiveType_DepthwiseConv2D,
                                                                 schema::PrimitiveType_Add,
                                                                 schema::PrimitiveType_Transpose,


+54 -116 mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.cc

@@ -16,6 +16,7 @@


#include <string>
#include <unordered_map>
+ #include <vector>
#include <memory>
#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
#include "src/common/log_adapter.h"
@@ -24,103 +25,59 @@
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"


namespace mindspore { namespace mindspore {
namespace {
std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
} // namespace
namespace lite { namespace lite {
#define kFormatTransMatchPathLen2 2 #define kFormatTransMatchPathLen2 2
#define kFormatTransMatchPathLen3 3 #define kFormatTransMatchPathLen3 3


STATUS FormatTransFusionPass::DefinePattern() { STATUS FormatTransFusionPass::DefinePattern() {
// nchw2nhwc + nhwc2nchw
// nchw2nhwc + nhwc2nchw || nhwc2nchw + nchw2nhwc
{ {
auto nc2nhOp = std::make_shared<PatternOp>();
nc2nhOp->id = kFormatTransNc2NhOp;
nc2nhOp->types = {PrimitiveType_Nchw2Nhwc};
auto nh2ncOp = std::make_shared<PatternOp>();
nh2ncOp->id = kFormatTransNh2NcOp;
nh2ncOp->types = {PrimitiveType_Nhwc2Nchw};
auto transpose1 = std::make_shared<PatternOp>();
transpose1->id = kFormatTransTranspose1;
transpose1->types = {PrimitiveType_Transpose};
auto transpose2 = std::make_shared<PatternOp>();
transpose2->id = kFormatTransTranspose2;
transpose2->types = {PrimitiveType_Transpose};


nh2ncOp->left = nc2nhOp;
std::unique_ptr<FusionPattern> nc2NhAndNh2NcFusionPattern(new (std::nothrow)
FusionPattern(kNc2NhAndNh2NcFusionPattern));
if (nc2NhAndNh2NcFusionPattern == nullptr) {
transpose2->left = transpose1;
auto pattern = std::make_unique<FusionPattern>(kNc2NhAndNh2NcFusionPattern);
if (pattern == nullptr) {
MS_LOG(ERROR) << "new " << kNc2NhAndNh2NcFusionPattern << "failed"; MS_LOG(ERROR) << "new " << kNc2NhAndNh2NcFusionPattern << "failed";
return RET_ERROR; return RET_ERROR;
} }
nc2NhAndNh2NcFusionPattern->AddPatternOp(nc2nhOp);
nc2NhAndNh2NcFusionPattern->AddPatternOp(nh2ncOp);
nc2NhAndNh2NcFusionPattern->Finish();
this->patterns.emplace_back(nc2NhAndNh2NcFusionPattern.release());
pattern->AddPatternOp(transpose1);
pattern->AddPatternOp(transpose2);
pattern->Finish();
this->patterns.emplace_back(pattern.release());
} }
// nhwc2nchw + QuantDtypeCast + nchw2nhwc || nchw2nhwc + QuantDtypeCast + nhwc2nchw
{ {
auto nc2nhOp = std::make_shared<PatternOp>();
nc2nhOp->id = kFormatTransNc2NhOp;
nc2nhOp->types = {PrimitiveType_Nchw2Nhwc};
auto transpose1 = std::make_shared<PatternOp>();
transpose1->id = kFormatTransTranspose1;
transpose1->types = {PrimitiveType_Transpose};
auto passOp = std::make_shared<PatternOp>(); auto passOp = std::make_shared<PatternOp>();
passOp->id = kFormatTransPassOp; passOp->id = kFormatTransPassOp;
passOp->types = {PrimitiveType_QuantDTypeCast}; passOp->types = {PrimitiveType_QuantDTypeCast};
auto nh2ncOp = std::make_shared<PatternOp>();
nh2ncOp->id = kFormatTransNh2NcOp;
nh2ncOp->types = {PrimitiveType_Nhwc2Nchw};
auto transpose2 = std::make_shared<PatternOp>();
transpose2->id = kFormatTransTranspose2;
transpose2->types = {PrimitiveType_Transpose};


passOp->left = nc2nhOp;
nh2ncOp->left = passOp;
std::unique_ptr<FusionPattern> nc2NhAndNh2NcPassFusionPattern(new (std::nothrow)
FusionPattern(kNc2NhAndNh2NcPassFusionPattern));
if (nc2NhAndNh2NcPassFusionPattern == nullptr) {
MS_LOG(ERROR) << "new " << kNc2NhAndNh2NcPassFusionPattern << "failed";
return RET_ERROR;
}
nc2NhAndNh2NcPassFusionPattern->AddPatternOp(nc2nhOp);
nc2NhAndNh2NcPassFusionPattern->AddPatternOp(passOp);
nc2NhAndNh2NcPassFusionPattern->AddPatternOp(nh2ncOp);
nc2NhAndNh2NcPassFusionPattern->Finish();
this->patterns.emplace_back(nc2NhAndNh2NcPassFusionPattern.release());
}
// nhwc2nchw + nchw2nhwc
{
auto nc2nhOp = std::make_shared<PatternOp>();
nc2nhOp->id = kFormatTransNc2NhOp;
nc2nhOp->types = {PrimitiveType_Nchw2Nhwc};
auto nh2ncOp = std::make_shared<PatternOp>();
nh2ncOp->id = kFormatTransNh2NcOp;
nh2ncOp->types = {PrimitiveType_Nhwc2Nchw};

nc2nhOp->left = nh2ncOp;
std::unique_ptr<FusionPattern> nh2NcAndNc2NhFusionPattern(new (std::nothrow)
FusionPattern(kNh2NcAndNc2NhFusionPattern));
if (nh2NcAndNc2NhFusionPattern == nullptr) {
MS_LOG(ERROR) << "new " << kNh2NcAndNc2NhFusionPattern << "failed";
return RET_ERROR;
}
nh2NcAndNc2NhFusionPattern->AddPatternOp(nh2ncOp);
nh2NcAndNc2NhFusionPattern->AddPatternOp(nc2nhOp);
nh2NcAndNc2NhFusionPattern->Finish();
this->patterns.emplace_back(nh2NcAndNc2NhFusionPattern.release());
}
// nhwc2nchw + QuantDtypeCast + nchw2nhwc
{
auto nc2nhOp = std::make_shared<PatternOp>();
nc2nhOp->id = kFormatTransNc2NhOp;
nc2nhOp->types = {PrimitiveType_Nchw2Nhwc};
auto passOp = std::make_shared<PatternOp>();
passOp->id = kFormatTransPassOp;
passOp->types = {PrimitiveType_QuantDTypeCast};
auto nh2ncOp = std::make_shared<PatternOp>();
nh2ncOp->id = kFormatTransNh2NcOp;
nh2ncOp->types = {PrimitiveType_Nhwc2Nchw};

passOp->left = nh2ncOp;
nc2nhOp->left = passOp;
std::unique_ptr<FusionPattern> nh2NcAndNc2NhPassFusionPattern(new (std::nothrow)
FusionPattern(kNh2NcAndNc2NhPassFusionPattern));
if (nh2NcAndNc2NhPassFusionPattern == nullptr) {
passOp->left = transpose2;
transpose1->left = passOp;
auto pattern = std::make_unique<FusionPattern>(kNh2NcAndNc2NhPassFusionPattern);
if (pattern == nullptr) {
MS_LOG(ERROR) << "new " << kNh2NcAndNc2NhPassFusionPattern << " failed"; MS_LOG(ERROR) << "new " << kNh2NcAndNc2NhPassFusionPattern << " failed";
return RET_ERROR; return RET_ERROR;
} }
nh2NcAndNc2NhPassFusionPattern->AddPatternOp(nh2ncOp);
nh2NcAndNc2NhPassFusionPattern->AddPatternOp(passOp);
nh2NcAndNc2NhPassFusionPattern->AddPatternOp(nc2nhOp);
nh2NcAndNc2NhPassFusionPattern->Finish();
this->patterns.emplace_back(nh2NcAndNc2NhPassFusionPattern.release());
pattern->AddPatternOp(transpose1);
pattern->AddPatternOp(passOp);
pattern->AddPatternOp(transpose2);
pattern->Finish();
this->patterns.emplace_back(pattern.release());
} }
return RET_OK; return RET_OK;
} }
@@ -136,51 +93,32 @@ STATUS FormatTransFusionPass::DoFusion(schema::MetaGraphT *graph, const std::str
    return RET_PARAM_INVALID;
  }

- std::shared_ptr<Path> srcPath;
- std::shared_ptr<Path> dstPath;
- if (patternName == kNc2NhAndNh2NcFusionPattern || patternName == kNc2NhAndNh2NcPassFusionPattern) {
-   srcPath = matchedPath[kFormatTransNc2NhOp];
-   dstPath = matchedPath[kFormatTransNh2NcOp];
- } else if (patternName == kNh2NcAndNc2NhFusionPattern || patternName == kNh2NcAndNc2NhPassFusionPattern) {
-   srcPath = matchedPath[kFormatTransNh2NcOp];
-   dstPath = matchedPath[kFormatTransNc2NhOp];
- } else {
-   MS_ASSERT(false);
- }
- if (srcPath == nullptr) {
-   MS_LOG(ERROR) << "srcPath is failed to get";
-   return RET_ERROR;
- }
- if (dstPath == nullptr) {
-   MS_LOG(ERROR) << "dstPath is failed to get";
+ std::shared_ptr<Path> srcPath = matchedPath[kFormatTransTranspose1];
+ std::shared_ptr<Path> dstPath = matchedPath[kFormatTransTranspose2];
+ if (srcPath == nullptr || dstPath == nullptr) {
+   MS_LOG(ERROR) << "srcPath or dstPath is failed to get";
    return RET_ERROR;
  }
  auto srcNode = graph->nodes.at(srcPath->nodeIdx).get();
  auto dstNode = graph->nodes.at(dstPath->nodeIdx).get();
  MS_ASSERT(srcNode != nullptr);
  MS_ASSERT(dstNode != nullptr);
- if (patternName == kNc2NhAndNh2NcFusionPattern || patternName == kNc2NhAndNh2NcPassFusionPattern) {
-   MS_ASSERT(GetCNodeTType(*srcNode) == schema::PrimitiveType_Nchw2Nhwc);
-   MS_ASSERT(GetCNodeTType(*dstNode) == schema::PrimitiveType_Nhwc2Nchw);
- } else if (patternName == kNh2NcAndNc2NhFusionPattern || patternName == kNh2NcAndNc2NhPassFusionPattern) {
-   MS_ASSERT(GetCNodeTType(*srcNode) == schema::PrimitiveType_Nhwc2Nchw);
-   MS_ASSERT(GetCNodeTType(*dstNode) == schema::PrimitiveType_Nchw2Nhwc);
- } else {
-   MS_ASSERT(false);
- }
-
- auto status = IsolateOneWayNode(graph, srcPath->nodeIdx);
- if (status != RET_OK) {
-   MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << srcNode->name << ", error: " << status;
-   return status;
- }
-
- status = IsolateOneWayNode(graph, dstPath->nodeIdx);
- if (status != RET_OK) {
-   MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name << ", error: " << status;
-   return status;
+ bool isNc2NhAndNh2Nc = srcNode->primitive->value.AsTranspose()->perm == nchw2nhwc_perm &&
+                        dstNode->primitive->value.AsTranspose()->perm == nhwc2nchw_perm;
+ bool isNh2NcAndNc2Nh = srcNode->primitive->value.AsTranspose()->perm == nhwc2nchw_perm &&
+                        dstNode->primitive->value.AsTranspose()->perm == nchw2nhwc_perm;
+ if (isNc2NhAndNh2Nc || isNh2NcAndNc2Nh) {
+   auto status = IsolateOneWayNode(graph, srcPath->nodeIdx);
+   if (status != RET_OK) {
+     MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << srcNode->name << ", error: " << status;
+     return status;
+   }
+   status = IsolateOneWayNode(graph, dstPath->nodeIdx);
+   if (status != RET_OK) {
+     MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name << ", error: " << status;
+     return status;
+   }
  }
  return RET_OK;
}
} // namespace lite
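The fusion now fires on any back-to-back Transpose pair whose perms are the two layout conversions in either order; such a pair composes to the identity permutation, which is why both nodes can be isolated away. A small self-contained check of that claim:

```cpp
#include <numeric>
#include <vector>

// Sketch of why the pass may drop the pair: applying `first` then `second`
// yields the identity permutation, so the pair is a no-op.
bool ComposesToIdentity(const std::vector<int> &first, const std::vector<int> &second) {
  std::vector<int> composed(first.size());
  for (size_t i = 0; i < first.size(); ++i) {
    composed[i] = first[second[i]];  // dim i of the final output traces back here
  }
  std::vector<int> identity(first.size());
  std::iota(identity.begin(), identity.end(), 0);
  return composed == identity;
}
// ComposesToIdentity({0, 2, 3, 1}, {0, 3, 1, 2}) == true
```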


+2 -4 mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h

@@ -24,12 +24,10 @@

namespace mindspore {
namespace lite {
- constexpr const char *kFormatTransNc2NhOp = "FormatTransNc2NhOp";
- constexpr const char *kFormatTransNh2NcOp = "FormatTransNh2NcOp";
+ constexpr const char *kFormatTransTranspose1 = "FormatTransTransposeOp1";
+ constexpr const char *kFormatTransTranspose2 = "FormatTransTransposeOp2";
constexpr const char *kFormatTransPassOp = "FormatTransPassOp";
constexpr const char *kNc2NhAndNh2NcFusionPattern = "Nc2NhAndNh2NcFusionPattern";
- constexpr const char *kNc2NhAndNh2NcPassFusionPattern = "Nc2NhAndNh2NcPassFusionPattern";
- constexpr const char *kNh2NcAndNc2NhFusionPattern = "Nh2NcAndNc2NhFusionPattern";
constexpr const char *kNh2NcAndNc2NhPassFusionPattern = "Nh2NcAndNc2NhPassFusionPattern";

class FormatTransFusionPass : public FusionPass {


+0 -148 mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.cc

@@ -1,148 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <string>
#include <vector>
#include <unordered_map>
#include <memory>
#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h"
#include "src/common/log_adapter.h"
#include "securec/include/securec.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace lite {
#define kFormatTransTransposeMatchPathLen 2

STATUS FormatTransPermuteFusionPass::DefinePattern() {
// format trans + permute
{
auto formatTransOp = std::make_shared<PatternOp>();
formatTransOp->id = kFormatTransformOp;
formatTransOp->types = {PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw};
auto transposeOp = std::make_shared<PatternOp>();
transposeOp->id = kPermuteOp;
transposeOp->types = {PrimitiveType_Transpose};

transposeOp->left = formatTransOp;
std::unique_ptr<FusionPattern> formatTransTransposeFusionPattern(
new (std::nothrow) FusionPattern(kFormatTrans2TransposeFusionPattern));
if (formatTransTransposeFusionPattern == nullptr) {
MS_LOG(ERROR) << "new " << kFormatTrans2TransposeFusionPattern << " failed";
return RET_ERROR;
}
formatTransTransposeFusionPattern->AddPatternOp(formatTransOp);
formatTransTransposeFusionPattern->AddPatternOp(transposeOp);
formatTransTransposeFusionPattern->Finish();
this->patterns.emplace_back(formatTransTransposeFusionPattern.release());
}
// permute + format trans
{
auto formatTransOp = std::make_shared<PatternOp>();
formatTransOp->id = kFormatTransformOp;
formatTransOp->types = {PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw};
auto transposeOp = std::make_shared<PatternOp>();
transposeOp->id = kPermuteOp;
transposeOp->types = {PrimitiveType_Transpose};

formatTransOp->left = transposeOp;
std::unique_ptr<FusionPattern> transposeFormatTransFusionPattern(
new (std::nothrow) FusionPattern(kTranspose2FormatTransFusionPattern));
if (transposeFormatTransFusionPattern == nullptr) {
MS_LOG(ERROR) << "new " << kTranspose2FormatTransFusionPattern << " failed";
return RET_ERROR;
}
transposeFormatTransFusionPattern->AddPatternOp(formatTransOp);
transposeFormatTransFusionPattern->AddPatternOp(transposeOp);
transposeFormatTransFusionPattern->Finish();
this->patterns.emplace_back(transposeFormatTransFusionPattern.release());
}
return RET_OK;
}

STATUS FormatTransPermuteFusionPass::Run(schema::MetaGraphT *graph) { return FusionPass::Run(graph); }

STATUS FormatTransPermuteFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
MS_ASSERT(graph != nullptr);
if (matchedPath.size() != kFormatTransTransposeMatchPathLen) {
MS_LOG(ERROR) << "schema::Format-Transform-Transpose-Fusion should have " << kFormatTransTransposeMatchPathLen
<< " NodeIndex in matchedPair";
return RET_PARAM_INVALID;
}

std::shared_ptr<Path> formatTransPath = matchedPath[kFormatTransformOp];
std::shared_ptr<Path> transposePath = matchedPath[kPermuteOp];
if (formatTransPath == nullptr) {
MS_LOG(ERROR) << "formatTransPath is failed to get";
return RET_ERROR;
}
if (transposePath == nullptr) {
MS_LOG(ERROR) << "permutePath is failed to get";
return RET_ERROR;
}
auto &formatTransNode = graph->nodes.at(formatTransPath->nodeIdx);
auto &transposeNode = graph->nodes.at(transposePath->nodeIdx);
MS_ASSERT(formatTransNode != nullptr);
MS_ASSERT(transposeNode != nullptr);
auto formatTransType = formatTransNode->primitive->value.type;
if (formatTransType != PrimitiveType_Nhwc2Nchw && formatTransType != PrimitiveType_Nchw2Nhwc) {
MS_LOG(ERROR) << "FormatTransNode should be " << EnumNamePrimitiveType(PrimitiveType_Nhwc2Nchw) << " or "
<< EnumNamePrimitiveType(PrimitiveType_Nchw2Nhwc) << ", but got "
<< EnumNamePrimitiveType(formatTransType);
return RET_ERROR;
}
MS_ASSERT(transposeNode->primitive != nullptr);
auto transposePrimitive = transposeNode->primitive->value.AsTranspose();
MS_ASSERT(transposePrimitive != nullptr);
auto perm = transposePrimitive->perm;
if (perm.size() != 4) {
return RET_OK;
}
std::vector<int32_t> nchw2nhwcPerm = {0, 2, 3, 1};
std::vector<int32_t> nhwc2nchwPerm = {0, 3, 1, 2};
if ((perm == nchw2nhwcPerm && formatTransType == PrimitiveType_Nhwc2Nchw) ||
(perm == nhwc2nchwPerm && formatTransType == PrimitiveType_Nchw2Nhwc)) {
if (formatTransPath->nodeIdx < transposePath->nodeIdx) {
if (graph->allTensors.at(formatTransNode->inputIndex[0])->format !=
graph->allTensors.at(transposeNode->outputIndex[0])->format) {
return RET_OK;
}
} else {
if (graph->allTensors.at(transposeNode->inputIndex[0])->format !=
graph->allTensors.at(formatTransNode->outputIndex[0])->format) {
return RET_OK;
}
}
auto status = IsolateOneWayNode(graph, formatTransPath->nodeIdx);
if (status != RET_OK) {
MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << formatTransNode->name << ", error: " << status;
return status;
}

status = IsolateOneWayNode(graph, transposePath->nodeIdx);
if (status != RET_OK) {
MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << transposeNode->name << ", error: " << status;
return status;
}
}

return RET_OK;
}
} // namespace lite
} // namespace mindspore

+0 -48 mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h

@@ -1,48 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H
#define MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H

#include <memory>
#include <string>
#include <unordered_map>
#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"

namespace mindspore {
namespace lite {
constexpr const char *kFormatTransformOp = "FormatTransOp";
constexpr const char *kPermuteOp = "PermuteOp";
constexpr const char *kFormatTrans2TransposeFusionPattern = "Nc2NhAndNh2NcFusionPattern";
constexpr const char *kTranspose2FormatTransFusionPattern = "Nc2NhAndNh2NcPassFusionPattern";

class FormatTransPermuteFusionPass : public FusionPass {
public:
FormatTransPermuteFusionPass() = default;

~FormatTransPermuteFusionPass() override = default;

STATUS DefinePattern() override;

STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;

STATUS Run(schema::MetaGraphT *graph) override;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H

+1 -1 mindspore/lite/tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.cc

@@ -115,7 +115,7 @@ STATUS QuantCastFusionPass::DefinePattern() {
    srcOp->types = {schema::PrimitiveType_QuantDTypeCast};
    auto formatOp = std::make_shared<PatternOp>();
    formatOp->id = kFormatTransOp;
-   formatOp->types = {schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Nchw2Nhwc};
+   formatOp->types = {PrimitiveType_Transpose};
    formatOp->left = srcOp;
    auto dstOp = std::make_shared<PatternOp>();
    dstOp->id = kQuantCastDstOp;


+36 -3 mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc

@@ -14,6 +14,7 @@
 * limitations under the License.
 */

+ #include <algorithm>
#include <string>
#include <memory>
#include <utility>
@@ -196,15 +197,47 @@ NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeI
  }
  auto transNode = std::make_unique<schema::CNodeT>();
  transNode->primitive = std::make_unique<schema::PrimitiveT>();
+ transNode->primitive->value.type = schema::PrimitiveType_Transpose;
+ auto attr = new (std::nothrow) schema::TransposeT();

  if (nodeType == kNCHW2NHWC) {
    transNode->name = "nchw2nhwc_" + tileName + std::to_string(id++);
-   transNode->primitive->value.type = schema::PrimitiveType_Nchw2Nhwc;
+   attr->perm = {0, 2, 3, 1};
  } else {
    transNode->name = "nhwc2nchw_" + tileName + std::to_string(id++);
-   transNode->primitive->value.type = schema::PrimitiveType_Nhwc2Nchw;
+   attr->perm = {0, 3, 1, 2};
  }
- return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode);
+ transNode->primitive->value.value = attr;
+
+ OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> {
+   auto newOpDef = std::make_unique<schema::CNodeT>();
+   if (newOpDef == nullptr) {
+     MS_LOG(ERROR) << "new CNodeT failed";
+     return nullptr;
+   }
+   newOpDef->name = inOpDef->name;
+   newOpDef->quantType = inOpDef->quantType;
+   newOpDef->primitive = std::make_unique<schema::PrimitiveT>();
+   if (newOpDef->primitive == nullptr) {
+     MS_LOG(ERROR) << "new PrimitiveT failed";
+     return nullptr;
+   }
+   newOpDef->primitive->value.type = schema::PrimitiveType_Transpose;
+   auto transposeParam = new (std::nothrow) TransposeT;
+   if (transposeParam == nullptr) {
+     MS_LOG(ERROR) << "new transposeParam failed";
+     return nullptr;
+   }
+   auto inParam = inOpDef->primitive->value.AsTranspose();
+   MS_ASSERT(inParam != nullptr);
+   transposeParam->perm.resize(inParam->perm.size());
+   std::transform(inParam->perm.begin(), inParam->perm.end(), transposeParam->perm.begin(),
+                  [](const int32_t ele) { return ele; });
+   newOpDef->primitive->value.value = transposeParam;
+   return newOpDef;
+ };
+
+ return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, TransposeOpCopyer);
}

void FormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }
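InsertNode may duplicate the inserted node when it has to sit on several graph edges, so the pass now supplies a copier that deep-copies the Transpose perm along with the node. A stripped-down sketch of the idea, with simplified stand-in types (Node, TransposeAttr, OpCopier are illustrative, not the converter's real types):

```cpp
#include <functional>
#include <memory>
#include <string>
#include <vector>

// Simplified stand-ins for schema::CNodeT / schema::TransposeT.
struct TransposeAttr { std::vector<int32_t> perm; };
struct Node {
  std::string name;
  TransposeAttr attr;
};

using OpCopier = std::function<std::unique_ptr<Node>(const Node *)>;

int main() {
  // Deep-copies the node, including its perm attribute, so each inserted
  // copy can live independently on a different edge.
  OpCopier copier = [](const Node *in) -> std::unique_ptr<Node> {
    auto out = std::make_unique<Node>();
    out->name = in->name;
    out->attr.perm = in->attr.perm;  // copy {0, 2, 3, 1} etc.
    return out;
  };
  Node n{"nchw2nhwc_0", {{0, 2, 3, 1}}};
  auto copy = copier(&n);
  return copy->attr.perm[1] == 2 ? 0 : 1;
}
```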


+10 -2 mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc

@@ -25,6 +25,10 @@
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"


namespace mindspore { namespace mindspore {
namespace {
std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
} // namespace
namespace lite { namespace lite {


STATUS GlobalFormatTransformPass::Run(MetaGraphT *graph) { STATUS GlobalFormatTransformPass::Run(MetaGraphT *graph) {
@@ -34,7 +38,10 @@ STATUS GlobalFormatTransformPass::Run(MetaGraphT *graph) {
  for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
    auto &node = *iter;
    auto type = node->primitive->value.type;
-   if (type != schema::PrimitiveType_Nchw2Nhwc) {
+   if (type != PrimitiveType_Transpose) {
+     continue;
+   }
+   if (node->primitive->value.AsTranspose()->perm != nchw2nhwc_perm) {
      continue;
    }
    std::vector<size_t> pre_nh2nc_nodes;
@@ -176,7 +183,8 @@ STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc
    auto &pre_node = graph->nodes.at(input_node_index);
    MS_ASSERT(pre_node != nullptr);
    auto node_type = pre_node->primitive->value.type;
-   if (node_type == schema::PrimitiveType_Nhwc2Nchw) {
+   if (node_type == schema::PrimitiveType_Transpose &&
+       pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
      if (!IsContain(*pre_nh2nc_nodes, input_node_index)) {
        pre_nh2nc_nodes->emplace_back(input_node_index);
      }


+54 -26 mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc

@@ -24,12 +24,16 @@
#include "src/common/utils.h" #include "src/common/utils.h"


namespace mindspore { namespace mindspore {
namespace {
std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
} // namespace
namespace lite { namespace lite {
bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) { bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
MS_ASSERT(graph != nullptr); MS_ASSERT(graph != nullptr);
MS_ASSERT(node != nullptr); MS_ASSERT(node != nullptr);
auto input_node_indexes = GetInputNodeIdx(*graph, *node); auto input_node_indexes = GetInputNodeIdx(*graph, *node);
pre_type_ = schema::PrimitiveType_NONE;
pre_type_ = kNONE;
size_t has_trans_count = 0; size_t has_trans_count = 0;
auto can_fusion = true; auto can_fusion = true;
for (auto input_node_index : input_node_indexes) { for (auto input_node_index : input_node_indexes) {
@@ -38,16 +42,28 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p
    MS_ASSERT(pre_node != nullptr);
    MS_ASSERT(pre_node->primitive != nullptr);
    MS_ASSERT(pre_node->primitive->value != nullptr);
-   if (pre_type_ == schema::PrimitiveType_NONE) {
-     if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
-         pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
-       pre_type_ = pre_node->primitive->value.type;
+   if (pre_type_ == kNONE) {
+     if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) {
+       if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
+         pre_type_ = kNCHW2NHWC;
+       } else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
+         pre_type_ = kNHWC2NCHW;
+       } else {
+         return false;
+       }
        has_trans_count++;
      }
    } else {
-     if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
-         pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
-       if (pre_type_ != pre_node->primitive->value.type) {
+     if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) {
+       auto cur_type = kNONE;
+       if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
+         cur_type = kNCHW2NHWC;
+       } else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
+         cur_type = kNHWC2NCHW;
+       } else {
+         return false;
+       }
+       if (pre_type_ != cur_type) {
        can_fusion = false;
        break;
      } else {
@@ -60,23 +76,35 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p
    return false;
  }
  auto output_node_indexes = GetOutputNodeIdx(*graph, *node);
- post_type_ = schema::PrimitiveType_NONE;
+ post_type_ = kNONE;
  for (auto output_node_index : output_node_indexes) {
    MS_ASSERT(graph->nodes.size() > output_node_index);
    auto &post_node = graph->nodes.at(output_node_index);
    MS_ASSERT(post_node != nullptr);
    MS_ASSERT(post_node->primitive != nullptr);
    MS_ASSERT(post_node->primitive->value != nullptr);
-   if (post_type_ == schema::PrimitiveType_NONE) {
-     if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
-         post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
-       post_type_ = post_node->primitive->value.type;
+   if (post_type_ == kNONE) {
+     if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) {
+       if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
+         post_type_ = kNCHW2NHWC;
+       } else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
+         post_type_ = kNHWC2NCHW;
+       } else {
+         return false;
+       }
        has_trans_count++;
      }
    } else {
-     if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
-         post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
-       if (post_type_ != post_node->primitive->value.type) {
+     if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) {
+       auto cur_type = kNONE;
+       if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
+         cur_type = kNCHW2NHWC;
+       } else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
+         cur_type = kNHWC2NCHW;
+       } else {
+         return false;
+       }
+       if (post_type_ != cur_type) {
        can_fusion = false;
        break;
      } else {
@@ -88,7 +116,7 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p
  if (!can_fusion) {
    return false;
  }
- if (pre_type_ == PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) {
+ if (pre_type_ == kNONE && post_type_ == kNONE) {
    return false;
  }
  auto output_size = output_node_indexes.empty() ? 1 : output_node_indexes.size();
@@ -114,21 +142,21 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p
STATUS TransOpInsertPass::FindOutTransType() {
  pre_insert_trans_type_ = kNHWC2NCHW;
  post_insert_trans_type_ = kNHWC2NCHW;
- if (pre_type_ == PrimitiveType_NONE && post_type_ != PrimitiveType_NONE) {
-   pre_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC;
-   post_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
- } else if (pre_type_ != PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) {
-   pre_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
-   post_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC;
- } else if (pre_type_ == PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) {
+ if (pre_type_ == kNONE && post_type_ != kNONE) {
+   pre_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
+   post_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
+ } else if (pre_type_ != kNONE && post_type_ == kNONE) {
+   pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
+   post_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
+ } else if (pre_type_ == kNONE && post_type_ == kNONE) {
    MS_ASSERT(false);
  } else {
    if (pre_type_ == post_type_) {
      MS_LOG(ERROR) << "Unknow error";
      return RET_ERROR;
    }
-   pre_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
-   post_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
+   pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
+   post_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
  }
  return RET_OK;
}


+5 -2 mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h

@@ -18,6 +18,7 @@
#define MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H

#include <memory>
+ #include <vector>
#include "tools/common/graph_util.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
@@ -44,8 +45,10 @@ class TransOpInsertPass : public FormatTransPass {
 private:
  FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW;
  FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW;
- schema::PrimitiveType pre_type_ = schema::PrimitiveType_NONE;
- schema::PrimitiveType post_type_ = schema::PrimitiveType_NONE;
+ FormatTransNodeType pre_type_ = kNONE;
+ std::vector<int> pre_perm_;
+ FormatTransNodeType post_type_ = kNONE;
+ std::vector<int> post_perm_;
};
} // namespace lite
} // namespace mindspore


+6 -1 mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.cc

@@ -25,13 +25,18 @@
using mindspore::lite::PrimitiveC;
using mindspore::lite::Tensor;
namespace mindspore {
+ namespace {
+ std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
+ std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
+ }  // namespace
namespace lite {
STATUS TransOpRemovePass::Run(MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
    auto &node = *iter;
    auto type = node->primitive->value.type;
-   if (type == schema::PrimitiveType_Nchw2Nhwc || type == schema::PrimitiveType_Nhwc2Nchw) {
+   if (type == schema::PrimitiveType_Transpose && (node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm ||
+                                                    node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm)) {
      auto &input_tensor = graph->allTensors.at(node->inputIndex.at(0));
      // less than 4 dims can delete
      if (!input_tensor->dims.empty() && input_tensor->dims.size() < 4) {
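The removal rule is unchanged in spirit: a layout transpose only does real work on 4-D tensors, so one sitting on a tensor with fewer known dims is a no-op and can be dropped. As a one-line predicate (illustrative only):

```cpp
#include <vector>

// Sketch of the deletion condition above: layout perms are 4-D, so a smaller
// known rank means the node cannot change anything.
bool CanRemoveLayoutTranspose(const std::vector<int> &dims) {
  return !dims.empty() && dims.size() < 4;
}
```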


+ 0
- 2
mindspore/lite/tools/converter/quantizer/calc_quant_param.cc View File

@@ -523,8 +523,6 @@ QuantParamCalcRegister::QuantParamCalcRegister() {
  _registerMap[schema::PrimitiveType_Transpose] = linearCalcer;
  _registerMap[schema::PrimitiveType_MatMul] = std::make_shared<ConvCalcer>();
  _registerMap[schema::PrimitiveType_FullConnection] = std::make_shared<ConvCalcer>();
- _registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer;
- _registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer;
  // detection_postprocess op's quant param will not infer only fetch from preNode or postNode
  // because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float.
  // if quantTransNode is inserted after detection_postprocess node, there will be some errors


+19 -7 mindspore/lite/tools/converter/quantizer/quantize_util.cc

@@ -89,13 +89,25 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
  auto cnode = std::dynamic_pointer_cast<CNode>(node);
  auto type = NodePrimitiveType(cnode);
  static const std::vector<schema::PrimitiveType> int8OpList = {
-     schema::PrimitiveType_Nchw2Nhwc,    schema::PrimitiveType_Nhwc2Nchw,  schema::PrimitiveType_Conv2D,
-     schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add,     schema::PrimitiveType_Mul,
-     schema::PrimitiveType_Pooling,      schema::PrimitiveType_Concat,     schema::PrimitiveType_Split,
-     schema::PrimitiveType_TupleGetItem, schema::PrimitiveType_Reshape,    schema::PrimitiveType_FullConnection,
-     schema::PrimitiveType_MatMul,       schema::PrimitiveType_Crop,       schema::PrimitiveType_DeDepthwiseConv2D,
-     schema::PrimitiveType_DeConv2D,     schema::PrimitiveType_Activation, schema::PrimitiveType_Transpose,
-     schema::PrimitiveType_Eltwise,      schema::PrimitiveType_Gather,     schema::PrimitiveType_LayerNorm,
+     schema::PrimitiveType_Conv2D,
+     schema::PrimitiveType_DepthwiseConv2D,
+     schema::PrimitiveType_Add,
+     schema::PrimitiveType_Mul,
+     schema::PrimitiveType_Pooling,
+     schema::PrimitiveType_Concat,
+     schema::PrimitiveType_Split,
+     schema::PrimitiveType_TupleGetItem,
+     schema::PrimitiveType_Reshape,
+     schema::PrimitiveType_FullConnection,
+     schema::PrimitiveType_MatMul,
+     schema::PrimitiveType_Crop,
+     schema::PrimitiveType_DeDepthwiseConv2D,
+     schema::PrimitiveType_DeConv2D,
+     schema::PrimitiveType_Activation,
+     schema::PrimitiveType_Transpose,
+     schema::PrimitiveType_Eltwise,
+     schema::PrimitiveType_Gather,
+     schema::PrimitiveType_LayerNorm,
  };
  bool contain = IsContain(int8OpList, type);
  if (!contain) {


+2 -0 mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc

@@ -547,6 +547,8 @@ bool OnnxInputAdjustOpPass::Run(const FuncGraphPtr &func_graph) {
      status = ReplaceConstant(func_graph, cnode);
    } else if (type == schema::PrimitiveType_Cast) {
      status = AdjustCast(cnode);
+   } else if (type == schema::PrimitiveType_Transpose) {
+     status = ReplaceTransposeWithGraphInput(func_graph, cnode);
    }
    if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
      MS_LOG(ERROR) << "adjust input pass is failed.";

