diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs
index 19737f823f..5a4ad2de76 100644
--- a/mindspore/lite/schema/model.fbs
+++ b/mindspore/lite/schema/model.fbs
@@ -102,11 +102,11 @@ union PrimitiveType {
     Tile,
     Cast,
     Shape,
-    Nchw2Nhwc,
-    Nhwc2Nchw,
+    Nchw2Nhwc,  // DEPRECATED
+    Nhwc2Nchw,  // DEPRECATED
     QuantDTypeCast,
     Split,
-    Permute,
+    Permute,  // DEPRECATED
     FakeQuantWithMinMaxVars,
     Equal,
     Less,
diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
index 942d75ce3e..011f006440 100644
--- a/mindspore/lite/schema/ops.fbs
+++ b/mindspore/lite/schema/ops.fbs
@@ -338,11 +338,11 @@ table ConstantOfShape{
     value: [float];
 }
 
-table Nchw2Nhwc {
+table Nchw2Nhwc {  // DEPRECATED
 }
 
-table Nhwc2Nchw {
+table Nhwc2Nchw {  // DEPRECATED
 }
 
@@ -729,7 +729,7 @@ table Crop {
     offsets : [long];
 }
 
-table Permute {
+table Permute {  // DEPRECATED
     order: [long];
 }
diff --git a/mindspore/lite/src/ops/permute.cc b/mindspore/lite/src/ops/permute.cc
deleted file mode 100644
index de1dccc1d6..0000000000
--- a/mindspore/lite/src/ops/permute.cc
+++ /dev/null
@@ -1,62 +0,0 @@
-/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "src/ops/permute.h"
-
-#ifndef PRIMITIVE_WRITEABLE
-#include "src/ops/ops_register.h"
-#endif
-
-namespace mindspore {
-namespace lite {
-#ifdef PRIMITIVE_WRITEABLE
-std::vector<int64_t> Permute::GetOrder() const { return this->primitive_->value.AsPermute()->order; }
-
-void Permute::SetOrder(const std::vector<int64_t> &order) { this->primitive_->value.AsPermute()->order = order; }
-
-#else
-
-std::vector<int64_t> Permute::GetOrder() const {
-  auto fb_vector = this->primitive_->value_as_Permute()->order();
-  return std::vector<int64_t>(fb_vector->begin(), fb_vector->end());
-}
-
-void Permute::SetOrder(const std::vector<int64_t> &order) {}
-int Permute::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
-  MS_ASSERT(nullptr != primitive);
-  MS_ASSERT(nullptr != fbb);
-  auto attr = primitive->value_as_Permute();
-  if (attr == nullptr) {
-    MS_LOG(ERROR) << "value_as_Permute return nullptr";
-    return RET_ERROR;
-  }
-  std::vector<int64_t> order;
-  if (attr->order() != nullptr) {
-    for (int i = 0; i < static_cast<int>(attr->order()->size()); i++) {
-      order.push_back(attr->order()->data()[i]);
-    }
-  }
-  auto val_offset = schema::CreatePermuteDirect(*fbb, &order);
-  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Permute, val_offset.o);
-  fbb->Finish(prim_offset);
-  return RET_OK;
-}
-
-PrimitiveC *PermuteCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Permute>(primitive); }
-Registry PermuteRegistry(schema::PrimitiveType_Permute, PermuteCreator);
-#endif
-}  // namespace lite
-}  // namespace mindspore
diff --git a/mindspore/lite/src/ops/permute.h b/mindspore/lite/src/ops/permute.h
deleted file mode 100644
index a4b5acfd4b..0000000000
--- a/mindspore/lite/src/ops/permute.h
+++ /dev/null
@@ -1,45 +0,0 @@
-/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LITE_MINDSPORE_LITE_C_OPS_PERMUTE_H_
-#define LITE_MINDSPORE_LITE_C_OPS_PERMUTE_H_
-
-#include <vector>
-#include <set>
-#include <cmath>
-#include <memory>
-
-#include "src/ops/primitive_c.h"
-
-namespace mindspore {
-namespace lite {
-class Permute : public PrimitiveC {
- public:
-  Permute() = default;
-  ~Permute() = default;
-#ifdef PRIMITIVE_WRITEABLE
-  MS_DECLARE_PARENT(Permute, PrimitiveC);
-  explicit Permute(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
-#else
-  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
-#endif
-  std::vector<int64_t> GetOrder() const;
-  void SetOrder(const std::vector<int64_t> &order);
-};
-}  // namespace lite
-}  // namespace mindspore
-
-#endif  // LITE_MINDSPORE_LITE_C_OPS_PERMUTE_H_
diff --git a/mindspore/lite/src/ops/transpose.cc b/mindspore/lite/src/ops/transpose.cc
index 63e2f111a4..1b09454972 100644
--- a/mindspore/lite/src/ops/transpose.cc
+++ b/mindspore/lite/src/ops/transpose.cc
@@ -110,24 +110,32 @@ Registry TransposeRegistry(schema::PrimitiveType_Transpose, TransposeCreator);
 #endif
 
 int Transpose::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
-  MS_ASSERT(this->primitive_ != nullptr);
   auto input = inputs_.front();
-  MS_ASSERT(input != nullptr);
   auto output = outputs_.front();
+  MS_ASSERT(input != nullptr);
   MS_ASSERT(output != nullptr);
+
+  std::vector<int> perm = GetPerm();
+  std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
+  std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
+  std::vector<int> in_shape = input->shape();
+
   output->set_data_type(input->data_type());
-  output->set_format(input->format());
+  if (input->format() == schema::Format::Format_NCHW && perm == nchw2nhwc_perm) {
+    output->set_format(schema::Format::Format_NHWC);
+  } else if (input->format() == schema::Format::Format_NHWC && perm == nhwc2nchw_perm) {
+    output->set_format(schema::Format::Format_NCHW);
+  } else {
+    output->set_format(input->format());
+  }
   if (!infer_flag()) {
     return RET_INFER_INVALID;
   }
-  MS_ASSERT(inputs_.size() == kSingleNum || inputs_.size() == kDoubleNum);
-  MS_ASSERT(outputs_.size() == kSingleNum);
-  std::vector<int> perm;
-  for (size_t i = 0; i < GetPerm().size(); i++) {
-    perm.push_back(GetPerm().at(i));
+  if (in_shape.size() != 4 && perm.size() == 4) {
+    output->set_shape(in_shape);
+    return RET_OK;
   }
-  std::vector<int> in_shape = input->shape();
   std::vector<int> out_shape;
   out_shape.resize(perm.size());
   for (size_t i = 0; i < perm.size(); ++i) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc
index 7b4b7a9c9c..bdab6b2da9 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/transpose_fp16.cc
@@ -48,6 +48,14 @@ int TransposeFp16CPUKernel::Run() {
   }
   in_data_fp16_ = reinterpret_cast<float16_t *>(in_tensor->MutableData());
   out_data_fp16_ = reinterpret_cast<float16_t *>(out_tensor->MutableData());
+  MS_ASSERT(in_data_fp16_);
+  MS_ASSERT(out_data_fp16_);
+
+  TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
+  if (in_tensor->shape().size() != static_cast<size_t>(param->num_axes_)) {
+    memcpy(out_data_fp16_, in_data_fp16_, in_tensor->ElementsNum() * sizeof(float16_t));
+    return RET_OK;
+  }
   int dims = out_tensor->shape().size();
   if (dims > MAX_TRANSPOSE_DIM_SIZE) {
     dim_size_ = reinterpret_cast<int *>(context_->allocator->Malloc(dims * sizeof(int)));
@@ -63,10 +71,7 @@ int TransposeFp16CPUKernel::Run() {
       return RET_ERROR;
     }
   }
-  TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
-  MS_ASSERT(param);
-  MS_ASSERT(in_data_fp16_);
-  MS_ASSERT(out_data_fp16_);
+
   MS_ASSERT(out_shape_);
   auto ret = Fp16DoTranspose(in_data_fp16_, out_data_fp16_, out_shape_, param, dim_size_, position_);
   if (dims > MAX_TRANSPOSE_DIM_SIZE) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.cc
deleted file mode 100644
index 6c1acf7982..0000000000
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.cc
+++ /dev/null
@@ -1,72 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h"
-
-using mindspore::kernel::KERNEL_ARCH::kCPU;
-using mindspore::lite::KernelRegistrar;
-using mindspore::lite::RET_ERROR;
-using mindspore::lite::RET_OK;
-using mindspore::schema::PrimitiveType_Nchw2Nhwc;
-
-namespace mindspore::kernel {
-int Nchw2NhwcCPUKernel::Init() { return RET_OK; }
-
-int Nchw2NhwcCPUKernel::ReSize() { return RET_OK; }
-
-int Nchw2NhwcCPUKernel::Run() {
-  auto input = in_tensors_.at(0);
-  auto output = out_tensors_.at(0);
-
-  if (input->shape().size() == 4) {
-    if (input->data_type() == kNumberTypeFloat32) {
-      PackNCHWToNHWCFp32(input->MutableData(), output->MutableData(), output->Batch(),
-                         output->Height() * output->Width(), output->Channel());
-    } else if (input->data_type() == kNumberTypeInt8) {
-      PackNCHWToNHWCInt8(input->MutableData(), output->MutableData(), output->Batch(),
-                         output->Height() * output->Width(), output->Channel());
-    }
-  } else {
-    memcpy(output->MutableData(), input->MutableData(), input->ElementsNum() * sizeof(float));
-  }
-  return RET_OK;
-}
-
-kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
-                                                  const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
-                                                  const lite::InnerContext *ctx, const kernel::KernelKey &desc,
-                                                  const mindspore::lite::PrimitiveC *primitive) {
-  MS_ASSERT(opParameter != nullptr);
-  MS_ASSERT(desc.type == schema::PrimitiveType_Nchw2Nhwc);
-  auto *kernel = new (std::nothrow) Nchw2NhwcCPUKernel(opParameter, inputs, outputs, ctx, primitive);
-  if (kernel == nullptr) {
-    MS_LOG(ERROR) << "new Nchw2NhwcCPUKernel fail!";
-    free(opParameter);
-    return nullptr;
-  }
-  auto ret = kernel->Init();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
-                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    delete kernel;
-    return nullptr;
-  }
-  return kernel;
-}
-
-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator)
-REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator)
-}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h
deleted file mode 100644
index 29ccbc6bb6..0000000000
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h
+++ /dev/null
@@ -1,42 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_
-#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_
-
-#include <vector>
-#include "src/lite_kernel.h"
-
-#include "schema/model_generated.h"
-#include "src/kernel_registry.h"
-#include "include/errorcode.h"
-#include "nnacl/pack.h"
-
-namespace mindspore::kernel {
-class Nchw2NhwcCPUKernel : public LiteKernel {
- public:
-  Nchw2NhwcCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
-                     const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
-                     const mindspore::lite::PrimitiveC *primitive)
-      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
-  ~Nchw2NhwcCPUKernel() override = default;
-
-  int Init() override;
-  int ReSize() override;
-  int Run() override;
-};
-}  // namespace mindspore::kernel
-#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw_fp32.cc
deleted file mode 100644
index 0bcc47fb96..0000000000
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw_fp32.cc
+++ /dev/null
@@ -1,72 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "src/runtime/kernel/arm/fp32/nhwc2nchw_fp32.h"
-
-using mindspore::kernel::KERNEL_ARCH::kCPU;
-using mindspore::lite::KernelRegistrar;
-using mindspore::lite::RET_ERROR;
-using mindspore::lite::RET_OK;
-using mindspore::schema::PrimitiveType_Nhwc2Nchw;
-
-namespace mindspore::kernel {
-int Nhwc2NchwCPUKernel::Init() { return RET_OK; }
-
-int Nhwc2NchwCPUKernel::ReSize() { return RET_OK; }
-
-int Nhwc2NchwCPUKernel::Run() {
-  auto input = in_tensors_.at(0);
-  auto output = out_tensors_.at(0);
-
-  if (input->shape().size() == 4) {
-    if (input->data_type() == kNumberTypeFloat32) {
-      PackNHWCToNCHWFp32(input->MutableData(), output->MutableData(), output->Batch(),
-                         output->Height() * output->Width(), output->Channel());
-    } else if (input->data_type() == kNumberTypeInt8) {
-      PackNHWCToNCHWInt8(input->MutableData(), output->MutableData(), output->Batch(),
-                         output->Height() * output->Width(), output->Channel());
-    }
-  } else {
-    memcpy(output->MutableData(), input->MutableData(), input->ElementsNum() * sizeof(float));
-  }
-  return RET_OK;
-}
-
-kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
-                                                  const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
-                                                  const lite::InnerContext *ctx, const kernel::KernelKey &desc,
-                                                  const mindspore::lite::PrimitiveC *primitive) {
-  MS_ASSERT(opParameter != nullptr);
-  MS_ASSERT(desc.type == schema::PrimitiveType_Nhwc2Nchw);
-  auto *kernel = new (std::nothrow) Nhwc2NchwCPUKernel(opParameter, inputs, outputs, ctx, primitive);
-  if (kernel == nullptr) {
-    MS_LOG(ERROR) << "new Nhwc2NchwCPUKernel fail!";
-    free(opParameter);
-    return nullptr;
-  }
-  auto ret = kernel->Init();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
-                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    delete kernel;
-    return nullptr;
-  }
-  return kernel;
-}
-
-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator)
-REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator)
-}  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw_fp32.h
deleted file mode 100644
index 16cd5f599a..0000000000
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/nhwc2nchw_fp32.h
+++ /dev/null
@@ -1,42 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_
-#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_
-
-#include <vector>
-#include "src/lite_kernel.h"
-
-#include "schema/model_generated.h"
-#include "src/kernel_registry.h"
-#include "include/errorcode.h"
-#include "nnacl/pack.h"
-
-namespace mindspore::kernel {
-class Nhwc2NchwCPUKernel : public LiteKernel {
- public:
-  Nhwc2NchwCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
-                     const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
-                     const mindspore::lite::PrimitiveC *primitive)
-      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
-  ~Nhwc2NchwCPUKernel() override = default;
-
-  int Init() override;
-  int ReSize() override;
-  int Run() override;
-};
-}  // namespace mindspore::kernel
-#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc
index 0288052088..35e53a79af 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/transpose_fp32.cc
@@ -18,11 +18,14 @@
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "src/runtime/runtime_api.h"
+#include "nnacl/pack.h"
 
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
 using mindspore::lite::RET_OP_EXECUTE_FAILURE;
+using mindspore::schema::PrimitiveType_Nchw2Nhwc;
+using mindspore::schema::PrimitiveType_Nhwc2Nchw;
 using mindspore::schema::PrimitiveType_Transpose;
 
 namespace mindspore::kernel {
@@ -36,7 +39,9 @@ int TransposeCPUKernel::Init() {
 
 int TransposeCPUKernel::ReSize() {
   TransposeParameter *param = reinterpret_cast<TransposeParameter *>(op_parameter_);
-
+  if (in_tensors_.at(kInputIndex)->shape().size() != static_cast<size_t>(param->num_axes_)) {
+    return RET_OK;
+  }
   auto &inTensor = in_tensors_.front();
   auto &outTensor = out_tensors_.front();
   auto in_shape = inTensor->shape();
@@ -80,6 +85,41 @@ int TransposeCPUKernel::Run() {
   }
   in_data_ = reinterpret_cast<float *>(in_tensor->MutableData());
   out_data_ = reinterpret_cast<float *>(out_tensor->MutableData());
+  MS_ASSERT(in_data_);
+  MS_ASSERT(out_data_);
+
+  TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
+  if (in_tensor->shape().size() != static_cast<size_t>(param->num_axes_)) {
+    memcpy(out_data_, in_data_, in_tensor->ElementsNum() * sizeof(float));
+    return RET_OK;
+  }
+  if (in_tensor->shape().size() == 4 && param->perm_[0] == 0 && param->perm_[1] == 2 && param->perm_[2] == 3 &&
+      param->perm_[3] == 1) {
+    if (in_tensor->data_type() == kNumberTypeFloat32) {
+      PackNCHWToNHWCFp32(in_tensor->MutableData(), out_tensor->MutableData(), out_tensor->Batch(),
+                         out_tensor->Height() * out_tensor->Width(), out_tensor->Channel());
+    } else if (in_tensor->data_type() == kNumberTypeInt8) {
+      PackNCHWToNHWCInt8(in_tensor->MutableData(), out_tensor->MutableData(), out_tensor->Batch(),
+                         out_tensor->Height() * out_tensor->Width(), out_tensor->Channel());
+    }
+    return RET_OK;
+  }
+  if (in_tensor->shape().size() == 4 && param->perm_[0] == 0 && param->perm_[1] == 3 && param->perm_[2] == 1 &&
+      param->perm_[3] == 2) {
+    if (in_tensor->data_type() == kNumberTypeFloat32) {
+      PackNHWCToNCHWFp32(in_tensor->MutableData(), out_tensor->MutableData(), out_tensor->Batch(),
+                         out_tensor->Height() * out_tensor->Width(), out_tensor->Channel());
+    } else if (in_tensor->data_type() == kNumberTypeInt8) {
+      PackNHWCToNCHWInt8(in_tensor->MutableData(), out_tensor->MutableData(), out_tensor->Batch(),
+                         out_tensor->Height() * out_tensor->Width(), out_tensor->Channel());
+    }
+    return RET_OK;
+  }
+  if (in_tensor->data_type() == kNumberTypeInt8) {
+    MS_LOG(ERROR) << "not support now";
+    return RET_ERROR;
+  }
+
   int dims = out_tensor->shape().size();
   if (dims > MAX_TRANSPOSE_DIM_SIZE) {
     dim_size_ = reinterpret_cast<int *>(context_->allocator->Malloc(dims * sizeof(int)));
@@ -96,10 +136,6 @@ int TransposeCPUKernel::Run() {
     }
   }
 
-  TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
-  MS_ASSERT(param);
-  MS_ASSERT(in_data_);
-  MS_ASSERT(out_data_);
   MS_ASSERT(out_shape_);
   auto ret = DoTransposeFp32(in_data_, out_data_, out_shape_, param, dim_size_, position_);
   if (dims > MAX_TRANSPOSE_DIM_SIZE) {
@@ -143,4 +179,9 @@ kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector
 static const std::vector<schema::PrimitiveType> fp32FullOpList = {
 
 static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {};
 
-static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveType_Nchw2Nhwc,
-                                                              schema::PrimitiveType_Nhwc2Nchw,
-                                                              schema::PrimitiveType_Conv2D,
+static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveType_Conv2D,
                                                               schema::PrimitiveType_DepthwiseConv2D,
                                                               schema::PrimitiveType_Add,
                                                               schema::PrimitiveType_Transpose,
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.cc
index 14d99e76b9..ee4e36c350 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.cc
@@ -16,6 +16,7 @@
 
 #include <string>
 #include <unordered_map>
+#include <vector>
 #include <memory>
 #include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
 #include "src/common/log_adapter.h"
@@ -24,103 +25,59 @@
 #include "schema/inner/model_generated.h"
 
 namespace mindspore {
+namespace {
+std::vector<int32_t> nchw2nhwc_perm = {0, 2, 3, 1};
+std::vector<int32_t> nhwc2nchw_perm = {0, 3, 1, 2};
+}  // namespace
 namespace lite {
 #define kFormatTransMatchPathLen2 2
 #define kFormatTransMatchPathLen3 3
 
 STATUS FormatTransFusionPass::DefinePattern() {
-  // nchw2nhwc + nhwc2nchw
+  // nchw2nhwc + nhwc2nchw || nhwc2nchw + nchw2nhwc
   {
-    auto nc2nhOp = std::make_shared<PatternOp>();
-    nc2nhOp->id = kFormatTransNc2NhOp;
-    nc2nhOp->types = {PrimitiveType_Nchw2Nhwc};
-    auto nh2ncOp = std::make_shared<PatternOp>();
-    nh2ncOp->id = kFormatTransNh2NcOp;
-    nh2ncOp->types = {PrimitiveType_Nhwc2Nchw};
+    auto transpose1 = std::make_shared<PatternOp>();
+    transpose1->id = kFormatTransTranspose1;
+    transpose1->types = {PrimitiveType_Transpose};
+    auto transpose2 = std::make_shared<PatternOp>();
+    transpose2->id = kFormatTransTranspose2;
+    transpose2->types = {PrimitiveType_Transpose};
 
-    nh2ncOp->left = nc2nhOp;
-    std::unique_ptr<FusionPattern> nc2NhAndNh2NcFusionPattern(new (std::nothrow)
-                                                                FusionPattern(kNc2NhAndNh2NcFusionPattern));
-    if (nc2NhAndNh2NcFusionPattern == nullptr) {
+    transpose2->left = transpose1;
+    auto pattern = std::make_unique<FusionPattern>(kNc2NhAndNh2NcFusionPattern);
+    if (pattern == nullptr) {
       MS_LOG(ERROR) << "new " << kNc2NhAndNh2NcFusionPattern << "failed";
       return RET_ERROR;
     }
-    nc2NhAndNh2NcFusionPattern->AddPatternOp(nc2nhOp);
-    nc2NhAndNh2NcFusionPattern->AddPatternOp(nh2ncOp);
-    nc2NhAndNh2NcFusionPattern->Finish();
-    this->patterns.emplace_back(nc2NhAndNh2NcFusionPattern.release());
+    pattern->AddPatternOp(transpose1);
+    pattern->AddPatternOp(transpose2);
+    pattern->Finish();
+    this->patterns.emplace_back(pattern.release());
   }
+  // nhwc2nchw + QuantDtypeCast + nchw2nhwc || nchw2nhwc + QuantDtypeCast + nhwc2nchw
   {
-    auto nc2nhOp = std::make_shared<PatternOp>();
-    nc2nhOp->id = kFormatTransNc2NhOp;
-    nc2nhOp->types = {PrimitiveType_Nchw2Nhwc};
+    auto transpose1 = std::make_shared<PatternOp>();
+    transpose1->id = kFormatTransTranspose1;
+    transpose1->types = {PrimitiveType_Transpose};
     auto passOp = std::make_shared<PatternOp>();
     passOp->id = kFormatTransPassOp;
     passOp->types = {PrimitiveType_QuantDTypeCast};
-    auto nh2ncOp = std::make_shared<PatternOp>();
-    nh2ncOp->id = kFormatTransNh2NcOp;
-    nh2ncOp->types = {PrimitiveType_Nhwc2Nchw};
+    auto transpose2 = std::make_shared<PatternOp>();
+    transpose2->id = kFormatTransTranspose2;
+    transpose2->types = {PrimitiveType_Transpose};
 
-    passOp->left = nc2nhOp;
-    nh2ncOp->left = passOp;
-    std::unique_ptr<FusionPattern> nc2NhAndNh2NcPassFusionPattern(new (std::nothrow)
-                                                                    FusionPattern(kNc2NhAndNh2NcPassFusionPattern));
-    if (nc2NhAndNh2NcPassFusionPattern == nullptr) {
-      MS_LOG(ERROR) << "new " << kNc2NhAndNh2NcPassFusionPattern << "failed";
-      return RET_ERROR;
-    }
-    nc2NhAndNh2NcPassFusionPattern->AddPatternOp(nc2nhOp);
-    nc2NhAndNh2NcPassFusionPattern->AddPatternOp(passOp);
-    nc2NhAndNh2NcPassFusionPattern->AddPatternOp(nh2ncOp);
-    nc2NhAndNh2NcPassFusionPattern->Finish();
-    this->patterns.emplace_back(nc2NhAndNh2NcPassFusionPattern.release());
-  }
-  // nhwc2nchw + nchw2nhwc
-  {
-    auto nc2nhOp = std::make_shared<PatternOp>();
-    nc2nhOp->id = kFormatTransNc2NhOp;
-    nc2nhOp->types = {PrimitiveType_Nchw2Nhwc};
-    auto nh2ncOp = std::make_shared<PatternOp>();
-    nh2ncOp->id = kFormatTransNh2NcOp;
-    nh2ncOp->types = {PrimitiveType_Nhwc2Nchw};
-
-    nc2nhOp->left = nh2ncOp;
-    std::unique_ptr<FusionPattern> nh2NcAndNc2NhFusionPattern(new (std::nothrow)
-                                                                FusionPattern(kNh2NcAndNc2NhFusionPattern));
-    if (nh2NcAndNc2NhFusionPattern == nullptr) {
-      MS_LOG(ERROR) << "new " << kNh2NcAndNc2NhFusionPattern << "failed";
-      return RET_ERROR;
-    }
-    nh2NcAndNc2NhFusionPattern->AddPatternOp(nh2ncOp);
-    nh2NcAndNc2NhFusionPattern->AddPatternOp(nc2nhOp);
-    nh2NcAndNc2NhFusionPattern->Finish();
-    this->patterns.emplace_back(nh2NcAndNc2NhFusionPattern.release());
-  }
-  // nhwc2nchw + QuantDtypeCast + nchw2nhwc
-  {
-    auto nc2nhOp = std::make_shared<PatternOp>();
-    nc2nhOp->id = kFormatTransNc2NhOp;
-    nc2nhOp->types = {PrimitiveType_Nchw2Nhwc};
-    auto passOp = std::make_shared<PatternOp>();
-    passOp->id = kFormatTransPassOp;
-    passOp->types = {PrimitiveType_QuantDTypeCast};
-    auto nh2ncOp = std::make_shared<PatternOp>();
-    nh2ncOp->id = kFormatTransNh2NcOp;
-    nh2ncOp->types = {PrimitiveType_Nhwc2Nchw};
-
-    passOp->left = nh2ncOp;
-    nc2nhOp->left = passOp;
-    std::unique_ptr<FusionPattern> nh2NcAndNc2NhPassFusionPattern(new (std::nothrow)
-                                                                    FusionPattern(kNh2NcAndNc2NhPassFusionPattern));
-    if (nh2NcAndNc2NhPassFusionPattern == nullptr) {
+    passOp->left = transpose2;
+    transpose1->left = passOp;
+    auto pattern = std::make_unique<FusionPattern>(kNh2NcAndNc2NhPassFusionPattern);
+    if (pattern == nullptr) {
       MS_LOG(ERROR) << "new " << kNh2NcAndNc2NhPassFusionPattern << " failed";
       return RET_ERROR;
    }
-    nh2NcAndNc2NhPassFusionPattern->AddPatternOp(nh2ncOp);
-    nh2NcAndNc2NhPassFusionPattern->AddPatternOp(passOp);
-    nh2NcAndNc2NhPassFusionPattern->AddPatternOp(nc2nhOp);
-    nh2NcAndNc2NhPassFusionPattern->Finish();
-    this->patterns.emplace_back(nh2NcAndNc2NhPassFusionPattern.release());
+    pattern->AddPatternOp(transpose1);
+    pattern->AddPatternOp(passOp);
+    pattern->AddPatternOp(transpose2);
+    pattern->Finish();
+    this->patterns.emplace_back(pattern.release());
   }
   return RET_OK;
 }
@@ -136,51 +93,32 @@ STATUS FormatTransFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
     return RET_PARAM_INVALID;
   }
 
-  std::shared_ptr<Path> srcPath;
-  std::shared_ptr<Path> dstPath;
-
-  if (patternName == kNc2NhAndNh2NcFusionPattern || patternName == kNc2NhAndNh2NcPassFusionPattern) {
-    srcPath = matchedPath[kFormatTransNc2NhOp];
-    dstPath = matchedPath[kFormatTransNh2NcOp];
-  } else if (patternName == kNh2NcAndNc2NhFusionPattern || patternName == kNh2NcAndNc2NhPassFusionPattern) {
-    srcPath = matchedPath[kFormatTransNh2NcOp];
-    dstPath = matchedPath[kFormatTransNc2NhOp];
-  } else {
-    MS_ASSERT(false);
-  }
-  if (srcPath == nullptr) {
-    MS_LOG(ERROR) << "srcPath is failed to get";
-    return RET_ERROR;
-  }
-  if (dstPath == nullptr) {
-    MS_LOG(ERROR) << "dstPath is failed to get";
+  std::shared_ptr<Path> srcPath = matchedPath[kFormatTransTranspose1];
+  std::shared_ptr<Path> dstPath = matchedPath[kFormatTransTranspose2];
+  if (srcPath == nullptr || dstPath == nullptr) {
+    MS_LOG(ERROR) << "srcPath or dstPath is failed to get";
     return RET_ERROR;
   }
   auto srcNode = graph->nodes.at(srcPath->nodeIdx).get();
   auto dstNode = graph->nodes.at(dstPath->nodeIdx).get();
   MS_ASSERT(srcNode != nullptr);
   MS_ASSERT(dstNode != nullptr);
-  if (patternName == kNc2NhAndNh2NcFusionPattern || patternName == kNc2NhAndNh2NcPassFusionPattern) {
-    MS_ASSERT(GetCNodeTType(*srcNode) == schema::PrimitiveType_Nchw2Nhwc);
-    MS_ASSERT(GetCNodeTType(*dstNode) == schema::PrimitiveType_Nhwc2Nchw);
-  } else if (patternName == kNh2NcAndNc2NhFusionPattern || patternName == kNh2NcAndNc2NhPassFusionPattern) {
-    MS_ASSERT(GetCNodeTType(*srcNode) == schema::PrimitiveType_Nhwc2Nchw);
-    MS_ASSERT(GetCNodeTType(*dstNode) == schema::PrimitiveType_Nchw2Nhwc);
-  } else {
-    MS_ASSERT(false);
-  }
-
-  auto status = IsolateOneWayNode(graph, srcPath->nodeIdx);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << srcNode->name << ", error: " << status;
-    return status;
-  }
-
-  status = IsolateOneWayNode(graph, dstPath->nodeIdx);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name << ", error: " << status;
-    return status;
+  bool isNc2NhAndNh2Nc = srcNode->primitive->value.AsTranspose()->perm == nchw2nhwc_perm &&
+                         dstNode->primitive->value.AsTranspose()->perm == nhwc2nchw_perm;
+  bool isNh2NcAndNc2Nh = srcNode->primitive->value.AsTranspose()->perm == nhwc2nchw_perm &&
+                         dstNode->primitive->value.AsTranspose()->perm == nchw2nhwc_perm;
+  if (isNc2NhAndNh2Nc || isNh2NcAndNc2Nh) {
+    auto status = IsolateOneWayNode(graph, srcPath->nodeIdx);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << srcNode->name << ", error: " << status;
+      return status;
+    }
+    status = IsolateOneWayNode(graph, dstPath->nodeIdx);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name << ", error: " << status;
+      return status;
+    }
   }
-
   return RET_OK;
 }
 }  // namespace lite
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h
index 60541edaad..9ca9c9b2c7 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h
+++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h
@@ -24,12 +24,10 @@
 namespace mindspore {
 namespace lite {
-constexpr const char *kFormatTransNc2NhOp = "FormatTransNc2NhOp";
-constexpr const char *kFormatTransNh2NcOp = "FormatTransNh2NcOp";
+constexpr const char *kFormatTransTranspose1 = "FormatTransTransposeOp1";
+constexpr const char *kFormatTransTranspose2 = "FormatTransTransposeOp2";
 constexpr const char *kFormatTransPassOp = "FormatTransPassOp";
 constexpr const char *kNc2NhAndNh2NcFusionPattern = "Nc2NhAndNh2NcFusionPattern";
-constexpr const char *kNc2NhAndNh2NcPassFusionPattern = "Nc2NhAndNh2NcPassFusionPattern";
-constexpr const char *kNh2NcAndNc2NhFusionPattern = "Nh2NcAndNc2NhFusionPattern";
 constexpr const char *kNh2NcAndNc2NhPassFusionPattern = "Nh2NcAndNc2NhPassFusionPattern";
 
 class FormatTransFusionPass : public FusionPass {
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.cc
deleted file mode 100644
index f4ec0bb792..0000000000
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.cc
+++ /dev/null
@@ -1,148 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <string>
-#include <unordered_map>
-#include <memory>
-#include <vector>
-#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h"
-#include "src/common/log_adapter.h"
-#include "securec/include/securec.h"
-#include "tools/common/graph_util.h"
-#include "include/errorcode.h"
-#include "schema/inner/model_generated.h"
-
-namespace mindspore {
-namespace lite {
-#define kFormatTransTransposeMatchPathLen 2
-
-STATUS FormatTransPermuteFusionPass::DefinePattern() {
-  // format trans + permute
-  {
-    auto formatTransOp = std::make_shared<PatternOp>();
-    formatTransOp->id = kFormatTransformOp;
-    formatTransOp->types = {PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw};
-    auto transposeOp = std::make_shared<PatternOp>();
-    transposeOp->id = kPermuteOp;
-    transposeOp->types = {PrimitiveType_Transpose};
-
-    transposeOp->left = formatTransOp;
-    std::unique_ptr<FusionPattern> formatTransTransposeFusionPattern(
-      new (std::nothrow) FusionPattern(kFormatTrans2TransposeFusionPattern));
-    if (formatTransTransposeFusionPattern == nullptr) {
-      MS_LOG(ERROR) << "new " << kFormatTrans2TransposeFusionPattern << " failed";
-      return RET_ERROR;
-    }
-    formatTransTransposeFusionPattern->AddPatternOp(formatTransOp);
-    formatTransTransposeFusionPattern->AddPatternOp(transposeOp);
-    formatTransTransposeFusionPattern->Finish();
-    this->patterns.emplace_back(formatTransTransposeFusionPattern.release());
-  }
-  // permute + format trans
-  {
-    auto formatTransOp = std::make_shared<PatternOp>();
-    formatTransOp->id = kFormatTransformOp;
-    formatTransOp->types = {PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw};
-    auto transposeOp = std::make_shared<PatternOp>();
-    transposeOp->id = kPermuteOp;
-    transposeOp->types = {PrimitiveType_Transpose};
-
-    formatTransOp->left = transposeOp;
-    std::unique_ptr<FusionPattern> transposeFormatTransFusionPattern(
-      new (std::nothrow) FusionPattern(kTranspose2FormatTransFusionPattern));
-    if (transposeFormatTransFusionPattern == nullptr) {
-      MS_LOG(ERROR) << "new " << kTranspose2FormatTransFusionPattern << " failed";
-      return RET_ERROR;
-    }
-    transposeFormatTransFusionPattern->AddPatternOp(formatTransOp);
-    transposeFormatTransFusionPattern->AddPatternOp(transposeOp);
-    transposeFormatTransFusionPattern->Finish();
-    this->patterns.emplace_back(transposeFormatTransFusionPattern.release());
-  }
-  return RET_OK;
-}
-
-STATUS FormatTransPermuteFusionPass::Run(schema::MetaGraphT *graph) { return FusionPass::Run(graph); }
-
-STATUS FormatTransPermuteFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
-                                              std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
-  MS_ASSERT(graph != nullptr);
-  if (matchedPath.size() != kFormatTransTransposeMatchPathLen) {
-    MS_LOG(ERROR) << "schema::Format-Transform-Transpose-Fusion should have " << kFormatTransTransposeMatchPathLen
-                  << " NodeIndex in matchedPair";
-    return RET_PARAM_INVALID;
-  }
-
-  std::shared_ptr<Path> formatTransPath = matchedPath[kFormatTransformOp];
-  std::shared_ptr<Path> transposePath = matchedPath[kPermuteOp];
-  if (formatTransPath == nullptr) {
-    MS_LOG(ERROR) << "formatTransPath is failed to get";
-    return RET_ERROR;
-  }
-  if (transposePath == nullptr) {
-    MS_LOG(ERROR) << "permutePath is failed to get";
-    return RET_ERROR;
-  }
-  auto &formatTransNode = graph->nodes.at(formatTransPath->nodeIdx);
-  auto &transposeNode = graph->nodes.at(transposePath->nodeIdx);
-  MS_ASSERT(formatTransNode != nullptr);
-  MS_ASSERT(transposeNode != nullptr);
-  auto formatTransType = formatTransNode->primitive->value.type;
-  if (formatTransType != PrimitiveType_Nhwc2Nchw && formatTransType != PrimitiveType_Nchw2Nhwc) {
-    MS_LOG(ERROR) << "FormatTransNode should be " << EnumNamePrimitiveType(PrimitiveType_Nhwc2Nchw) << " or "
-                  << EnumNamePrimitiveType(PrimitiveType_Nchw2Nhwc) << ", but got "
-                  << EnumNamePrimitiveType(formatTransType);
-    return RET_ERROR;
-  }
-  MS_ASSERT(transposeNode->primitive != nullptr);
-  auto transposePrimitive = transposeNode->primitive->value.AsTranspose();
-  MS_ASSERT(transposePrimitive != nullptr);
-  auto perm = transposePrimitive->perm;
-  if (perm.size() != 4) {
-    return RET_OK;
-  }
-  std::vector<int32_t> nchw2nhwcPerm = {0, 2, 3, 1};
-  std::vector<int32_t> nhwc2nchwPerm = {0, 3, 1, 2};
-  if ((perm == nchw2nhwcPerm && formatTransType == PrimitiveType_Nhwc2Nchw) ||
-      (perm == nhwc2nchwPerm && formatTransType == PrimitiveType_Nchw2Nhwc)) {
-    if (formatTransPath->nodeIdx < transposePath->nodeIdx) {
-      if (graph->allTensors.at(formatTransNode->inputIndex[0])->format !=
-          graph->allTensors.at(transposeNode->outputIndex[0])->format) {
-        return RET_OK;
-      }
-    } else {
-      if (graph->allTensors.at(transposeNode->inputIndex[0])->format !=
-          graph->allTensors.at(formatTransNode->outputIndex[0])->format) {
-        return RET_OK;
-      }
-    }
-    auto status = IsolateOneWayNode(graph, formatTransPath->nodeIdx);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << formatTransNode->name << ", error: " << status;
-      return status;
-    }
-
-    status = IsolateOneWayNode(graph, transposePath->nodeIdx);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << transposeNode->name << ", error: " << status;
-      return status;
-    }
-  }
-
-  return RET_OK;
-}
-}  // namespace lite
-}  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h
deleted file mode 100644
index 722227f19a..0000000000
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h
+++ /dev/null
@@ -1,48 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H
-#define MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H
-
-#include <string>
-#include <unordered_map>
-#include <memory>
-#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
-
-namespace mindspore {
-namespace lite {
-constexpr const char *kFormatTransformOp = "FormatTransOp";
-constexpr const char *kPermuteOp = "PermuteOp";
-constexpr const char *kFormatTrans2TransposeFusionPattern = "Nc2NhAndNh2NcFusionPattern";
-constexpr const char *kTranspose2FormatTransFusionPattern = "Nc2NhAndNh2NcPassFusionPattern";
-
-class FormatTransPermuteFusionPass : public FusionPass {
- public:
-  FormatTransPermuteFusionPass() = default;
-
-  ~FormatTransPermuteFusionPass() override = default;
-
-  STATUS DefinePattern() override;
-
-  STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
-                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;
-
-  STATUS Run(schema::MetaGraphT *graph) override;
-};
-}  // namespace lite
-}  // namespace mindspore
-
-#endif  // MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.cc
index d3c5231290..6d968f41a5 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.cc
@@ -115,7 +115,7 @@ STATUS QuantCastFusionPass::DefinePattern() {
     srcOp->types = {schema::PrimitiveType_QuantDTypeCast};
     auto formatOp = std::make_shared<PatternOp>();
     formatOp->id = kFormatTransOp;
-    formatOp->types = {schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Nchw2Nhwc};
+    formatOp->types = {PrimitiveType_Transpose};
     formatOp->left = srcOp;
     auto dstOp = std::make_shared<PatternOp>();
     dstOp->id = kQuantCastDstOp;
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc
index 0020e090b0..547e6ad02e 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.cc
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include <algorithm>
 #include <string>
 #include <vector>
 #include <memory>
@@ -196,15 +197,47 @@ NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter,
   }
   auto transNode = std::make_unique<schema::CNodeT>();
   transNode->primitive = std::make_unique<schema::PrimitiveT>();
+  transNode->primitive->value.type = schema::PrimitiveType_Transpose;
+  auto attr = new (std::nothrow) schema::TransposeT();
   if (nodeType == kNCHW2NHWC) {
     transNode->name = "nchw2nhwc_" + tileName + std::to_string(id++);
-    transNode->primitive->value.type = schema::PrimitiveType_Nchw2Nhwc;
+    attr->perm = {0, 2, 3, 1};
   } else {
     transNode->name = "nhwc2nchw_" + tileName + std::to_string(id++);
-    transNode->primitive->value.type = schema::PrimitiveType_Nhwc2Nchw;
+    attr->perm = {0, 3, 1, 2};
   }
-  return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode);
+  transNode->primitive->value.value = attr;
+
+  OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> {
+    auto newOpDef = std::make_unique<CNodeT>();
+    if (newOpDef == nullptr) {
+      MS_LOG(ERROR) << "new CNodeT failed";
+      return nullptr;
+    }
+    newOpDef->name = inOpDef->name;
+    newOpDef->quantType = inOpDef->quantType;
+    newOpDef->primitive = std::make_unique<PrimitiveT>();
+    if (newOpDef->primitive == nullptr) {
+      MS_LOG(ERROR) << "new PrimitiveT failed";
+      return nullptr;
+    }
+    newOpDef->primitive->value.type = schema::PrimitiveType_Transpose;
+    auto transposeParam = new (std::nothrow) TransposeT;
+    if (transposeParam == nullptr) {
+      MS_LOG(ERROR) << "new transposeParam failed";
+      return nullptr;
+    }
+    auto inParam = inOpDef->primitive->value.AsTranspose();
+    MS_ASSERT(inParam != nullptr);
+    transposeParam->perm.resize(inParam->perm.size());
+    std::transform(inParam->perm.begin(), inParam->perm.end(), transposeParam->perm.begin(),
+                   [](const int32_t ele) { return ele; });
+    newOpDef->primitive->value.value = transposeParam;
+    return newOpDef;
+  };
+
+  return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, TransposeOpCopyer);
 }
 
 void FormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc
index b78738f1c7..188366dbf8 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/global_format_transform_pass.cc
@@ -25,6 +25,10 @@
 #include "schema/inner/model_generated.h"
 
 namespace mindspore {
+namespace {
+std::vector<int32_t> nchw2nhwc_perm = {0, 2, 3, 1};
+std::vector<int32_t> nhwc2nchw_perm = {0, 3, 1, 2};
+}  // namespace
 namespace lite {
 
 STATUS GlobalFormatTransformPass::Run(MetaGraphT *graph) {
@@ -34,7 +38,10 @@ STATUS GlobalFormatTransformPass::Run(MetaGraphT *graph) {
   for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
     auto &node = *iter;
     auto type = node->primitive->value.type;
-    if (type != schema::PrimitiveType_Nchw2Nhwc) {
+    if (type != PrimitiveType_Transpose) {
+      continue;
+    }
+    if (node->primitive->value.AsTranspose()->perm != nchw2nhwc_perm) {
       continue;
     }
     std::vector<size_t> pre_nh2nc_nodes;
@@ -176,7 +183,8 @@ STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc
     auto &pre_node = graph->nodes.at(input_node_index);
     MS_ASSERT(pre_node != nullptr);
     auto node_type = pre_node->primitive->value.type;
-    if (node_type == schema::PrimitiveType_Nhwc2Nchw) {
+    if (node_type == schema::PrimitiveType_Transpose &&
+        pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
       if (!IsContain(*pre_nh2nc_nodes, input_node_index)) {
         pre_nh2nc_nodes->emplace_back(input_node_index);
       }
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc
index 0c4c62efd4..3efc133e5b 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.cc
@@ -24,12 +24,16 @@
 #include "src/common/utils.h"
 
 namespace mindspore {
+namespace {
+std::vector<int32_t> nchw2nhwc_perm = {0, 2, 3, 1};
+std::vector<int32_t> nhwc2nchw_perm = {0, 3, 1, 2};
+}  // namespace
 namespace lite {
 
 bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
   MS_ASSERT(graph != nullptr);
   MS_ASSERT(node != nullptr);
   auto input_node_indexes = GetInputNodeIdx(*graph, *node);
-  pre_type_ = schema::PrimitiveType_NONE;
+  pre_type_ = kNONE;
   size_t has_trans_count = 0;
   auto can_fusion = true;
   for (auto input_node_index : input_node_indexes) {
@@ -38,16 +42,28 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node)
     MS_ASSERT(pre_node != nullptr);
     MS_ASSERT(pre_node->primitive != nullptr);
     MS_ASSERT(pre_node->primitive->value != nullptr);
-    if (pre_type_ == schema::PrimitiveType_NONE) {
-      if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
-          pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
-        pre_type_ = pre_node->primitive->value.type;
+    if (pre_type_ == kNONE) {
+      if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) {
+        if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
+          pre_type_ = kNCHW2NHWC;
+        } else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
+          pre_type_ = kNHWC2NCHW;
+        } else {
+          return false;
+        }
         has_trans_count++;
       }
    } else {
-      if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
-          pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
-        if (pre_type_ != pre_node->primitive->value.type) {
+      if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) {
+        auto cur_type = kNONE;
+        if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
+          cur_type = kNCHW2NHWC;
+        } else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
+          cur_type = kNHWC2NCHW;
+        } else {
+          return false;
+        }
+        if (pre_type_ != cur_type) {
           can_fusion = false;
           break;
         } else {
@@ -60,23 +76,35 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node)
     return false;
   }
   auto output_node_indexes = GetOutputNodeIdx(*graph, *node);
-  post_type_ = schema::PrimitiveType_NONE;
+  post_type_ = kNONE;
   for (auto output_node_index : output_node_indexes) {
     MS_ASSERT(graph->nodes.size() > output_node_index);
     auto &post_node = graph->nodes.at(output_node_index);
     MS_ASSERT(post_node != nullptr);
     MS_ASSERT(post_node->primitive != nullptr);
     MS_ASSERT(post_node->primitive->value != nullptr);
-    if (post_type_ == schema::PrimitiveType_NONE) {
-      if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
-          post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
-        post_type_ = post_node->primitive->value.type;
+    if (post_type_ == kNONE) {
+      if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) {
+        if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
+          post_type_ = kNCHW2NHWC;
+        } else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
+          post_type_ = kNHWC2NCHW;
+        } else {
+          return false;
+        }
         has_trans_count++;
       }
     } else {
-      if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
-          post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
-        if (post_type_ != post_node->primitive->value.type) {
+      if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) {
+        auto cur_type = kNONE;
+        if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) {
+          cur_type = kNCHW2NHWC;
+        } else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
+          cur_type = kNHWC2NCHW;
+        } else {
+          return false;
+        }
+        if (post_type_ != cur_type) {
           can_fusion = false;
           break;
         } else {
@@ -88,7 +116,7 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node)
   if (!can_fusion) {
     return false;
   }
-  if (pre_type_ == PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) {
+  if (pre_type_ == kNONE && post_type_ == kNONE) {
     return false;
   }
   auto output_size = output_node_indexes.empty() ? 1 : output_node_indexes.size();
@@ -114,21 +142,21 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node)
 STATUS TransOpInsertPass::FindOutTransType() {
   pre_insert_trans_type_ = kNHWC2NCHW;
   post_insert_trans_type_ = kNHWC2NCHW;
-  if (pre_type_ == PrimitiveType_NONE && post_type_ != PrimitiveType_NONE) {
-    pre_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC;
-    post_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
-  } else if (pre_type_ != PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) {
-    pre_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
-    post_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC;
-  } else if (pre_type_ == PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) {
+  if (pre_type_ == kNONE && post_type_ != kNONE) {
+    pre_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
+    post_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
+  } else if (pre_type_ != kNONE && post_type_ == kNONE) {
+    pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
+    post_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
+  } else if (pre_type_ == kNONE && post_type_ == kNONE) {
     MS_ASSERT(false);
   } else {
     if (pre_type_ == post_type_) {
       MS_LOG(ERROR) << "Unknow error";
       return RET_ERROR;
     }
-    pre_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
-    post_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
+    pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
+    post_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
   }
   return RET_OK;
 }
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h
index 3b5e21c78d..e3172302a7 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h
@@ -18,6 +18,7 @@
 #define MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H
 
 #include <memory>
+#include <vector>
 #include "tools/common/graph_util.h"
 #include "tools/converter/converter_flags.h"
 #include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
@@ -44,8 +45,10 @@ class TransOpInsertPass : public FormatTransPass {
  private:
   FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW;
   FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW;
-  schema::PrimitiveType pre_type_ = schema::PrimitiveType_NONE;
-  schema::PrimitiveType post_type_ = schema::PrimitiveType_NONE;
+  FormatTransNodeType pre_type_ = kNONE;
+  std::vector<int32_t> pre_perm_;
+  FormatTransNodeType post_type_ = kNONE;
+  std::vector<int32_t> post_perm_;
 };
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.cc
index 8fe7afe469..d3fc0a7c2b 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/trans_format_remove_pass.cc
@@ -25,13 +25,18 @@
 using mindspore::lite::PrimitiveC;
 using mindspore::lite::Tensor;
 
 namespace mindspore {
+namespace {
+std::vector<int32_t> nchw2nhwc_perm = {0, 2, 3, 1};
+std::vector<int32_t> nhwc2nchw_perm = {0, 3, 1, 2};
+}  // namespace
 namespace lite {
 STATUS TransOpRemovePass::Run(MetaGraphT *graph) {
   MS_ASSERT(graph != nullptr);
   for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
     auto &node = *iter;
     auto type = node->primitive->value.type;
-    if (type == schema::PrimitiveType_Nchw2Nhwc || type == schema::PrimitiveType_Nhwc2Nchw) {
+    if (type == schema::PrimitiveType_Transpose && (node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm ||
+                                                    node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm)) {
       auto &input_tensor = graph->allTensors.at(node->inputIndex.at(0));
       // less than 4 dims can delete
       if (!input_tensor->dims.empty() && input_tensor->dims.size() < 4) {
diff --git a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc
index 45f318b1dd..0218a02520 100644
--- a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc
+++ b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc
@@ -523,8 +523,6 @@ QuantParamCalcRegister::QuantParamCalcRegister() {
   _registerMap[schema::PrimitiveType_Transpose] = linearCalcer;
   _registerMap[schema::PrimitiveType_MatMul] = std::make_shared();
   _registerMap[schema::PrimitiveType_FullConnection] = std::make_shared();
-  _registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer;
-  _registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer;
   // detection_postprocess op's quant param will not infer only fetch from preNode or postNode
   // because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float.
   // if quantTransNode is inserted after detection_postprocess node, there will be some errors
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
index f7693eb5b0..5de9b7d43d 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
@@ -89,13 +89,25 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
   auto cnode = std::dynamic_pointer_cast<CNode>(node);
   auto type = NodePrimitiveType(cnode);
   static const std::vector<schema::PrimitiveType> int8OpList = {
-    schema::PrimitiveType_Nchw2Nhwc,       schema::PrimitiveType_Nhwc2Nchw,  schema::PrimitiveType_Conv2D,
-    schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add,        schema::PrimitiveType_Mul,
-    schema::PrimitiveType_Pooling,         schema::PrimitiveType_Concat,     schema::PrimitiveType_Split,
-    schema::PrimitiveType_TupleGetItem,    schema::PrimitiveType_Reshape,    schema::PrimitiveType_FullConnection,
-    schema::PrimitiveType_MatMul,          schema::PrimitiveType_Crop,       schema::PrimitiveType_DeDepthwiseConv2D,
-    schema::PrimitiveType_DeConv2D,        schema::PrimitiveType_Activation, schema::PrimitiveType_Transpose,
-    schema::PrimitiveType_Eltwise,         schema::PrimitiveType_Gather,     schema::PrimitiveType_LayerNorm,
+    schema::PrimitiveType_Conv2D,
+    schema::PrimitiveType_DepthwiseConv2D,
+    schema::PrimitiveType_Add,
+    schema::PrimitiveType_Mul,
+    schema::PrimitiveType_Pooling,
+    schema::PrimitiveType_Concat,
+    schema::PrimitiveType_Split,
+    schema::PrimitiveType_TupleGetItem,
+    schema::PrimitiveType_Reshape,
+    schema::PrimitiveType_FullConnection,
+    schema::PrimitiveType_MatMul,
+    schema::PrimitiveType_Crop,
+    schema::PrimitiveType_DeDepthwiseConv2D,
+    schema::PrimitiveType_DeConv2D,
+    schema::PrimitiveType_Activation,
+    schema::PrimitiveType_Transpose,
+    schema::PrimitiveType_Eltwise,
+    schema::PrimitiveType_Gather,
+    schema::PrimitiveType_LayerNorm,
   };
   bool contain = IsContain(int8OpList, type);
   if (!contain) {
diff --git a/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc b/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc
index fdd4e07d0f..92b1983644 100644
--- a/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/onnx_inputs_adjust_pass.cc
@@ -547,6 +547,8 @@ bool OnnxInputAdjustOpPass::Run(const FuncGraphPtr &func_graph) {
       status = ReplaceConstant(func_graph, cnode);
     } else if (type == schema::PrimitiveType_Cast) {
       status = AdjustCast(cnode);
+    } else if (type == schema::PrimitiveType_Transpose) {
+      status = ReplaceTransposeWithGraphInput(func_graph, cnode);
     }
     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
       MS_LOG(ERROR) << "adjust input pass is failed.";
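
The common thread of this patch is that the dedicated Nchw2Nhwc/Nhwc2Nchw/Permute ops are folded into plain Transpose nodes, and every pass now recognizes the two fixed permutations instead of op types. The standalone sketch below illustrates that perm-based classification idiom; it is not MindSpore code, and the names FormatTransType, Classify, and the main() driver are hypothetical.

// Minimal sketch: classify a Transpose perm the way the rewritten passes do.
#include <cstdint>
#include <iostream>
#include <vector>

enum class FormatTransType { kNCHW2NHWC, kNHWC2NCHW, kGeneric };

// A Transpose node now carries the format conversion in its perm attribute:
// {0, 2, 3, 1} reorders NCHW to NHWC, {0, 3, 1, 2} reorders NHWC to NCHW.
FormatTransType Classify(const std::vector<int32_t> &perm) {
  static const std::vector<int32_t> nchw2nhwc_perm = {0, 2, 3, 1};
  static const std::vector<int32_t> nhwc2nchw_perm = {0, 3, 1, 2};
  if (perm == nchw2nhwc_perm) return FormatTransType::kNCHW2NHWC;
  if (perm == nhwc2nchw_perm) return FormatTransType::kNHWC2NCHW;
  return FormatTransType::kGeneric;
}

int main() {
  // Two back-to-back transposes with inverse perms cancel out, which is the
  // condition FormatTransFusionPass::DoFusion now checks by comparing perms
  // instead of comparing Nchw2Nhwc/Nhwc2Nchw op types.
  const std::vector<int32_t> first = {0, 2, 3, 1};
  const std::vector<int32_t> second = {0, 3, 1, 2};
  const bool fusible = Classify(first) == FormatTransType::kNCHW2NHWC &&
                       Classify(second) == FormatTransType::kNHWC2NCHW;
  std::cout << (fusible ? "drop both transposes" : "keep nodes") << std::endl;
  return 0;
}

Any other perm falls through to kGeneric, mirroring the return-false paths in TransOpInsertPass::CanFusion, so a genuine data-reordering Transpose is never mistaken for a format conversion.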