@@ -102,11 +102,11 @@ union PrimitiveType {
     Tile,
     Cast,
     Shape,
-    Nchw2Nhwc,
-    Nhwc2Nchw,
+    Nchw2Nhwc,  // DEPRECATED
+    Nhwc2Nchw,  // DEPRECATED
     QuantDTypeCast,
     Split,
-    Permute,
+    Permute,  // DEPRECATED
     FakeQuantWithMinMaxVars,
     Equal,
     Less,
@@ -338,11 +338,11 @@ table ConstantOfShape{
     value: [float];
 }
-table Nchw2Nhwc {
+table Nchw2Nhwc {  // DEPRECATED
 }
-table Nhwc2Nchw {
+table Nhwc2Nchw {  // DEPRECATED
 }
@@ -729,7 +729,7 @@ table Crop {
     offsets : [long];
 }
-table Permute {
+table Permute {  // DEPRECATED
     order: [long];
 }
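Note: the three schema entries above are only marked DEPRECATED rather than deleted, presumably because flatbuffers union variants are encoded by index, so removing one would silently renumber the variants behind it in every previously converted model. The rest of the patch rewires producers and consumers to emit Transpose with an explicit perm instead. A minimal sketch of the mapping (the helper below is hypothetical, not part of this patch):

```cpp
#include <string>
#include <vector>

// Hypothetical helper, for illustration only: each deprecated format op is
// just a Transpose with a fixed 4-D permutation.
std::vector<int32_t> PermForDeprecatedOp(const std::string &op_name) {
  if (op_name == "Nchw2Nhwc") return {0, 2, 3, 1};  // N,C,H,W -> N,H,W,C
  if (op_name == "Nhwc2Nchw") return {0, 3, 1, 2};  // N,H,W,C -> N,C,H,W
  return {};  // Permute already carries its own order and is forwarded as-is.
}
```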
@@ -1,62 +0,0 @@
-/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "src/ops/permute.h"
-#ifndef PRIMITIVE_WRITEABLE
-#include "src/ops/ops_register.h"
-#endif
-namespace mindspore {
-namespace lite {
-#ifdef PRIMITIVE_WRITEABLE
-std::vector<int64_t> Permute::GetOrder() const { return this->primitive_->value.AsPermute()->order; }
-void Permute::SetOrder(const std::vector<int64_t> &order) { this->primitive_->value.AsPermute()->order = order; }
-#else
-std::vector<int64_t> Permute::GetOrder() const {
-  auto fb_vector = this->primitive_->value_as_Permute()->order();
-  return std::vector<int64_t>(fb_vector->begin(), fb_vector->end());
-}
-void Permute::SetOrder(const std::vector<int64_t> &order) {}
-int Permute::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
-  MS_ASSERT(nullptr != primitive);
-  MS_ASSERT(nullptr != fbb);
-  auto attr = primitive->value_as_Permute();
-  if (attr == nullptr) {
-    MS_LOG(ERROR) << "value_as_Permute return nullptr";
-    return RET_ERROR;
-  }
-  std::vector<int64_t> order;
-  if (attr->order() != nullptr) {
-    for (int i = 0; i < static_cast<int>(attr->order()->size()); i++) {
-      order.push_back(attr->order()->data()[i]);
-    }
-  }
-  auto val_offset = schema::CreatePermuteDirect(*fbb, &order);
-  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Permute, val_offset.o);
-  fbb->Finish(prim_offset);
-  return RET_OK;
-}
-PrimitiveC *PermuteCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Permute>(primitive); }
-Registry PermuteRegistry(schema::PrimitiveType_Permute, PermuteCreator);
-#endif
-}  // namespace lite
-}  // namespace mindspore
@@ -1,45 +0,0 @@
-/**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef LITE_MINDSPORE_LITE_C_OPS_PERMUTE_H_
-#define LITE_MINDSPORE_LITE_C_OPS_PERMUTE_H_
-#include <vector>
-#include <set>
-#include <cmath>
-#include <memory>
-#include "src/ops/primitive_c.h"
-namespace mindspore {
-namespace lite {
-class Permute : public PrimitiveC {
- public:
-  Permute() = default;
-  ~Permute() = default;
-#ifdef PRIMITIVE_WRITEABLE
-  MS_DECLARE_PARENT(Permute, PrimitiveC);
-  explicit Permute(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
-#else
-  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
-#endif
-  std::vector<int64_t> GetOrder() const;
-  void SetOrder(const std::vector<int64_t> &order);
-};
-}  // namespace lite
-}  // namespace mindspore
-#endif  // LITE_MINDSPORE_LITE_C_OPS_PERMUTE_H_
@@ -110,24 +110,32 @@ Registry TransposeRegistry(schema::PrimitiveType_Transpose, TransposeCreator);
 #endif
 int Transpose::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
-  MS_ASSERT(this->primitive_ != nullptr);
   auto input = inputs_.front();
-  MS_ASSERT(input != nullptr);
   auto output = outputs_.front();
+  MS_ASSERT(input != nullptr);
   MS_ASSERT(output != nullptr);
+  std::vector<int> perm = GetPerm();
+  std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
+  std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
+  std::vector<int> in_shape = input->shape();
   output->set_data_type(input->data_type());
-  output->set_format(input->format());
+  if (input->format() == schema::Format::Format_NCHW && perm == nchw2nhwc_perm) {
+    output->set_format(schema::Format::Format_NHWC);
+  } else if (input->format() == schema::Format::Format_NHWC && perm == nhwc2nchw_perm) {
+    output->set_format(schema::Format::Format_NCHW);
+  } else {
+    output->set_format(input->format());
+  }
   if (!infer_flag()) {
     return RET_INFER_INVALID;
   }
-  MS_ASSERT(inputs_.size() == kSingleNum || inputs_.size() == kDoubleNum);
-  MS_ASSERT(outputs_.size() == kSingleNum);
-  std::vector<int> perm;
-  for (size_t i = 0; i < GetPerm().size(); i++) {
-    perm.push_back(GetPerm().at(i));
+  if (in_shape.size() != 4 && perm.size() == 4) {
+    output->set_shape(in_shape);
+    return RET_OK;
   }
-  std::vector<int> in_shape = input->shape();
   std::vector<int> out_shape;
   out_shape.resize(perm.size());
   for (size_t i = 0; i < perm.size(); ++i) {
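Note: the reworked InferShape derives the output format tag from the perm (only the two canonical 4-D permutations flip it) and passes the input shape through when a 4-element perm meets a non-4D tensor. The rule the trailing loop implements is out_shape[i] = in_shape[perm[i]]; a self-contained sketch:

```cpp
#include <vector>

// Self-contained sketch of transpose shape inference: output axis i takes
// input axis perm[i]. Example: in_shape {1, 3, 224, 224} (NCHW) with perm
// {0, 2, 3, 1} gives {1, 224, 224, 3} (NHWC).
std::vector<int> InferTransposedShape(const std::vector<int> &in_shape, const std::vector<int> &perm) {
  std::vector<int> out_shape(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    out_shape[i] = in_shape[perm[i]];
  }
  return out_shape;
}
```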
@@ -48,6 +48,14 @@ int TransposeFp16CPUKernel::Run() {
   }
   in_data_fp16_ = reinterpret_cast<float16_t *>(in_tensor->MutableData());
   out_data_fp16_ = reinterpret_cast<float16_t *>(out_tensor->MutableData());
+  MS_ASSERT(in_data_fp16_);
+  MS_ASSERT(out_data_fp16_);
+  TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
+  if (in_tensor->shape().size() != static_cast<size_t>(param->num_axes_)) {
+    memcpy(out_data_fp16_, in_data_fp16_, in_tensor->ElementsNum() * sizeof(float16_t));
+    return RET_OK;
+  }
   int dims = out_tensor->shape().size();
   if (dims > MAX_TRANSPOSE_DIM_SIZE) {
     dim_size_ = reinterpret_cast<int *>(context_->allocator->Malloc(dims * sizeof(int)));
@@ -63,10 +71,7 @@ int TransposeFp16CPUKernel::Run() {
       return RET_ERROR;
     }
   }
-  TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
-  MS_ASSERT(param);
-  MS_ASSERT(in_data_fp16_);
-  MS_ASSERT(out_data_fp16_);
   MS_ASSERT(out_shape_);
   auto ret = Fp16DoTranspose(in_data_fp16_, out_data_fp16_, out_shape_, param, dim_size_, position_);
   if (dims > MAX_TRANSPOSE_DIM_SIZE) {
@@ -1,72 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h"
-using mindspore::kernel::KERNEL_ARCH::kCPU;
-using mindspore::lite::KernelRegistrar;
-using mindspore::lite::RET_ERROR;
-using mindspore::lite::RET_OK;
-using mindspore::schema::PrimitiveType_Nchw2Nhwc;
-namespace mindspore::kernel {
-int Nchw2NhwcCPUKernel::Init() { return RET_OK; }
-int Nchw2NhwcCPUKernel::ReSize() { return RET_OK; }
-int Nchw2NhwcCPUKernel::Run() {
-  auto input = in_tensors_.at(0);
-  auto output = out_tensors_.at(0);
-  if (input->shape().size() == 4) {
-    if (input->data_type() == kNumberTypeFloat32) {
-      PackNCHWToNHWCFp32(input->MutableData(), output->MutableData(), output->Batch(),
-                         output->Height() * output->Width(), output->Channel());
-    } else if (input->data_type() == kNumberTypeInt8) {
-      PackNCHWToNHWCInt8(input->MutableData(), output->MutableData(), output->Batch(),
-                         output->Height() * output->Width(), output->Channel());
-    }
-  } else {
-    memcpy(output->MutableData(), input->MutableData(), input->ElementsNum() * sizeof(float));
-  }
-  return RET_OK;
-}
-kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
-                                                  const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
-                                                  const lite::InnerContext *ctx, const kernel::KernelKey &desc,
-                                                  const mindspore::lite::PrimitiveC *primitive) {
-  MS_ASSERT(opParameter != nullptr);
-  MS_ASSERT(desc.type == schema::PrimitiveType_Nchw2Nhwc);
-  auto *kernel = new (std::nothrow) Nchw2NhwcCPUKernel(opParameter, inputs, outputs, ctx, primitive);
-  if (kernel == nullptr) {
-    MS_LOG(ERROR) << "new Nchw2NhwcCPUKernel fail!";
-    free(opParameter);
-    return nullptr;
-  }
-  auto ret = kernel->Init();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
-                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    delete kernel;
-    return nullptr;
-  }
-  return kernel;
-}
-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator)
-REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator)
-}  // namespace mindspore::kernel
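Note: this deleted kernel's work is absorbed by the Transpose kernel's fast path later in the patch. For readers unfamiliar with the pack helper it called, PackNCHWToNHWCFp32(src, dst, batch, plane, channel) reorders a batch x channel x plane buffer into batch x plane x channel, where plane = H * W. An illustrative reimplementation of its semantics (the nnacl helper itself is optimized; this only shows what it computes):

```cpp
// Illustrative semantics of PackNCHWToNHWCFp32: for each batch, transpose the
// channel x plane matrix into plane x channel (plane = H * W).
void PackNchwToNhwcRef(const float *src, float *dst, int batch, int plane, int channel) {
  for (int b = 0; b < batch; ++b) {
    const float *src_b = src + b * channel * plane;
    float *dst_b = dst + b * plane * channel;
    for (int c = 0; c < channel; ++c) {
      for (int p = 0; p < plane; ++p) {
        dst_b[p * channel + c] = src_b[c * plane + p];
      }
    }
  }
}
```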
@@ -1,42 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_
-#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_
-#include <vector>
-#include "src/lite_kernel.h"
-#include "schema/model_generated.h"
-#include "src/kernel_registry.h"
-#include "include/errorcode.h"
-#include "nnacl/pack.h"
-namespace mindspore::kernel {
-class Nchw2NhwcCPUKernel : public LiteKernel {
- public:
-  Nchw2NhwcCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
-                     const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
-                     const mindspore::lite::PrimitiveC *primitive)
-      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
-  ~Nchw2NhwcCPUKernel() override = default;
-  int Init() override;
-  int ReSize() override;
-  int Run() override;
-};
-}  // namespace mindspore::kernel
-#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_
@@ -1,72 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "src/runtime/kernel/arm/fp32/nhwc2nchw_fp32.h"
-using mindspore::kernel::KERNEL_ARCH::kCPU;
-using mindspore::lite::KernelRegistrar;
-using mindspore::lite::RET_ERROR;
-using mindspore::lite::RET_OK;
-using mindspore::schema::PrimitiveType_Nhwc2Nchw;
-namespace mindspore::kernel {
-int Nhwc2NchwCPUKernel::Init() { return RET_OK; }
-int Nhwc2NchwCPUKernel::ReSize() { return RET_OK; }
-int Nhwc2NchwCPUKernel::Run() {
-  auto input = in_tensors_.at(0);
-  auto output = out_tensors_.at(0);
-  if (input->shape().size() == 4) {
-    if (input->data_type() == kNumberTypeFloat32) {
-      PackNHWCToNCHWFp32(input->MutableData(), output->MutableData(), output->Batch(),
-                         output->Height() * output->Width(), output->Channel());
-    } else if (input->data_type() == kNumberTypeInt8) {
-      PackNHWCToNCHWInt8(input->MutableData(), output->MutableData(), output->Batch(),
-                         output->Height() * output->Width(), output->Channel());
-    }
-  } else {
-    memcpy(output->MutableData(), input->MutableData(), input->ElementsNum() * sizeof(float));
-  }
-  return RET_OK;
-}
-kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
-                                                  const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
-                                                  const lite::InnerContext *ctx, const kernel::KernelKey &desc,
-                                                  const mindspore::lite::PrimitiveC *primitive) {
-  MS_ASSERT(opParameter != nullptr);
-  MS_ASSERT(desc.type == schema::PrimitiveType_Nhwc2Nchw);
-  auto *kernel = new (std::nothrow) Nhwc2NchwCPUKernel(opParameter, inputs, outputs, ctx, primitive);
-  if (kernel == nullptr) {
-    MS_LOG(ERROR) << "new Nhwc2NchwCPUKernel fail!";
-    free(opParameter);
-    return nullptr;
-  }
-  auto ret = kernel->Init();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
-                  << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
-    delete kernel;
-    return nullptr;
-  }
-  return kernel;
-}
-REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator)
-REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator)
-}  // namespace mindspore::kernel
@@ -1,42 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_
-#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_
-#include <vector>
-#include "src/lite_kernel.h"
-#include "schema/model_generated.h"
-#include "src/kernel_registry.h"
-#include "include/errorcode.h"
-#include "nnacl/pack.h"
-namespace mindspore::kernel {
-class Nhwc2NchwCPUKernel : public LiteKernel {
- public:
-  Nhwc2NchwCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
-                     const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
-                     const mindspore::lite::PrimitiveC *primitive)
-      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
-  ~Nhwc2NchwCPUKernel() override = default;
-  int Init() override;
-  int ReSize() override;
-  int Run() override;
-};
-}  // namespace mindspore::kernel
-#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_
@@ -18,11 +18,14 @@
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "src/runtime/runtime_api.h"
+#include "nnacl/pack.h"
 using mindspore::lite::KernelRegistrar;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
 using mindspore::lite::RET_OP_EXECUTE_FAILURE;
+using mindspore::schema::PrimitiveType_Nchw2Nhwc;
+using mindspore::schema::PrimitiveType_Nhwc2Nchw;
 using mindspore::schema::PrimitiveType_Transpose;
 namespace mindspore::kernel {
@@ -36,7 +39,9 @@ int TransposeCPUKernel::Init() {
 int TransposeCPUKernel::ReSize() {
   TransposeParameter *param = reinterpret_cast<TransposeParameter *>(op_parameter_);
+  if (in_tensors_.at(kInputIndex)->shape().size() != static_cast<size_t>(param->num_axes_)) {
+    return RET_OK;
+  }
   auto &inTensor = in_tensors_.front();
   auto &outTensor = out_tensors_.front();
   auto in_shape = inTensor->shape();
@@ -80,6 +85,41 @@ int TransposeCPUKernel::Run() {
   }
   in_data_ = reinterpret_cast<float *>(in_tensor->MutableData());
   out_data_ = reinterpret_cast<float *>(out_tensor->MutableData());
+  MS_ASSERT(in_data_);
+  MS_ASSERT(out_data_);
+  TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
+  if (in_tensor->shape().size() != static_cast<size_t>(param->num_axes_)) {
+    memcpy(out_data_, in_data_, in_tensor->ElementsNum() * sizeof(float));
+    return RET_OK;
+  }
+  if (in_tensor->shape().size() == 4 && param->perm_[0] == 0 && param->perm_[1] == 2 && param->perm_[2] == 3 &&
+      param->perm_[3] == 1) {
+    if (in_tensor->data_type() == kNumberTypeFloat32) {
+      PackNCHWToNHWCFp32(in_tensor->MutableData(), out_tensor->MutableData(), out_tensor->Batch(),
+                         out_tensor->Height() * out_tensor->Width(), out_tensor->Channel());
+    } else if (in_tensor->data_type() == kNumberTypeInt8) {
+      PackNCHWToNHWCInt8(in_tensor->MutableData(), out_tensor->MutableData(), out_tensor->Batch(),
+                         out_tensor->Height() * out_tensor->Width(), out_tensor->Channel());
+    }
+    return RET_OK;
+  }
+  if (in_tensor->shape().size() == 4 && param->perm_[0] == 0 && param->perm_[1] == 3 && param->perm_[2] == 1 &&
+      param->perm_[3] == 2) {
+    if (in_tensor->data_type() == kNumberTypeFloat32) {
+      PackNHWCToNCHWFp32(in_tensor->MutableData(), out_tensor->MutableData(), out_tensor->Batch(),
+                         out_tensor->Height() * out_tensor->Width(), out_tensor->Channel());
+    } else if (in_tensor->data_type() == kNumberTypeInt8) {
+      PackNHWCToNCHWInt8(in_tensor->MutableData(), out_tensor->MutableData(), out_tensor->Batch(),
+                         out_tensor->Height() * out_tensor->Width(), out_tensor->Channel());
+    }
+    return RET_OK;
+  }
+  if (in_tensor->data_type() == kNumberTypeInt8) {
+    MS_LOG(ERROR) << "not support now";
+    return RET_ERROR;
+  }
   int dims = out_tensor->shape().size();
   if (dims > MAX_TRANSPOSE_DIM_SIZE) {
     dim_size_ = reinterpret_cast<int *>(context_->allocator->Malloc(dims * sizeof(int)));
@@ -96,10 +136,6 @@ int TransposeCPUKernel::Run() {
     }
   }
-  TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
-  MS_ASSERT(param);
-  MS_ASSERT(in_data_);
-  MS_ASSERT(out_data_);
   MS_ASSERT(out_shape_);
   auto ret = DoTransposeFp32(in_data_, out_data_, out_shape_, param, dim_size_, position_);
   if (dims > MAX_TRANSPOSE_DIM_SIZE) {
@@ -143,4 +179,9 @@ kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector<lite::Tensor
 }
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Transpose, CpuTransposeFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Transpose, CpuTransposeFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, CpuTransposeFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nchw2Nhwc, CpuTransposeFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, CpuTransposeFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nhwc2Nchw, CpuTransposeFp32KernelCreator)
 }  // namespace mindspore::kernel
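Note: two things are happening in this kernel. First, a rank mismatch against num_axes_ degenerates to a flat copy, the same trick as the fp16 kernel above. Second, the deleted layout kernels survive as fast paths keyed on the perm, and the extra REG_KERNEL lines keep models that still carry the deprecated Nchw2Nhwc/Nhwc2Nchw primitives dispatching to this same Transpose kernel. The two perm tests, pulled out as a sketch:

```cpp
// Sketch of the fast-path tests in Run(): a 4-D tensor with one of the two
// canonical perms is handled by the pack helpers instead of the generic
// DoTransposeFp32 path. Callers must guarantee perm has 4 entries.
bool IsNchw2NhwcPerm(const int *perm) {
  return perm[0] == 0 && perm[1] == 2 && perm[2] == 3 && perm[3] == 1;
}
bool IsNhwc2NchwPerm(const int *perm) {
  return perm[0] == 0 && perm[1] == 3 && perm[2] == 1 && perm[3] == 2;
}
```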
@@ -3,7 +3,7 @@ mtk_emotions-d2012-75.8%.onnx
 mtk_face_features_v3.onnx
 emotion-ferplus-8.onnx
 rcnn-ilsvrc13-9.onnx
-efficientnet-lite4-11.onnx
+#efficientnet-lite4-11.onnx
 mobilenetv2-7.onnx
 shufflenet-v2-10.onnx
 squeezenet1.1-7.onnx
@@ -3,7 +3,7 @@ mtk_emotions-d2012-75.8%.onnx 20
 mtk_face_features_v3.onnx 20
 emotion-ferplus-8.onnx 1
 #rcnn-ilsvrc13-9.onnx 0.1
-efficientnet-lite4-11.onnx 2
+#efficientnet-lite4-11.onnx 2
 mobilenetv2-7.onnx 8
 shufflenet-v2-10.onnx 5
 squeezenet1.1-7.onnx 1
@@ -66,9 +66,7 @@ static const std::vector<schema::PrimitiveType> fp32FullOpList = {
 static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {};
-static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveType_Nchw2Nhwc,
-                                                              schema::PrimitiveType_Nhwc2Nchw,
-                                                              schema::PrimitiveType_Conv2D,
+static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveType_Conv2D,
                                                               schema::PrimitiveType_DepthwiseConv2D,
                                                               schema::PrimitiveType_Add,
                                                               schema::PrimitiveType_Transpose,
@@ -16,6 +16,7 @@
 #include <string>
 #include <unordered_map>
+#include <vector>
 #include <memory>
 #include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
 #include "src/common/log_adapter.h"
@@ -24,103 +25,59 @@
 #include "schema/inner/model_generated.h"
 namespace mindspore {
+namespace {
+std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
+std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
+}  // namespace
 namespace lite {
 #define kFormatTransMatchPathLen2 2
 #define kFormatTransMatchPathLen3 3
 STATUS FormatTransFusionPass::DefinePattern() {
-  // nchw2nhwc + nhwc2nchw
+  // nchw2nhwc + nhwc2nchw || nhwc2nchw + nchw2nhwc
   {
-    auto nc2nhOp = std::make_shared<PatternOp>();
-    nc2nhOp->id = kFormatTransNc2NhOp;
-    nc2nhOp->types = {PrimitiveType_Nchw2Nhwc};
-    auto nh2ncOp = std::make_shared<PatternOp>();
-    nh2ncOp->id = kFormatTransNh2NcOp;
-    nh2ncOp->types = {PrimitiveType_Nhwc2Nchw};
+    auto transpose1 = std::make_shared<PatternOp>();
+    transpose1->id = kFormatTransTranspose1;
+    transpose1->types = {PrimitiveType_Transpose};
+    auto transpose2 = std::make_shared<PatternOp>();
+    transpose2->id = kFormatTransTranspose2;
+    transpose2->types = {PrimitiveType_Transpose};
-    nh2ncOp->left = nc2nhOp;
-    std::unique_ptr<FusionPattern> nc2NhAndNh2NcFusionPattern(new (std::nothrow)
-                                                                FusionPattern(kNc2NhAndNh2NcFusionPattern));
-    if (nc2NhAndNh2NcFusionPattern == nullptr) {
+    transpose2->left = transpose1;
+    auto pattern = std::make_unique<FusionPattern>(kNc2NhAndNh2NcFusionPattern);
+    if (pattern == nullptr) {
       MS_LOG(ERROR) << "new " << kNc2NhAndNh2NcFusionPattern << "failed";
       return RET_ERROR;
     }
-    nc2NhAndNh2NcFusionPattern->AddPatternOp(nc2nhOp);
-    nc2NhAndNh2NcFusionPattern->AddPatternOp(nh2ncOp);
-    nc2NhAndNh2NcFusionPattern->Finish();
-    this->patterns.emplace_back(nc2NhAndNh2NcFusionPattern.release());
+    pattern->AddPatternOp(transpose1);
+    pattern->AddPatternOp(transpose2);
+    pattern->Finish();
+    this->patterns.emplace_back(pattern.release());
   }
+  // nhwc2nchw + QuantDtypeCast + nchw2nhwc || nchw2nhwc + QuantDtypeCast + nhwc2nchw
   {
-    auto nc2nhOp = std::make_shared<PatternOp>();
-    nc2nhOp->id = kFormatTransNc2NhOp;
-    nc2nhOp->types = {PrimitiveType_Nchw2Nhwc};
+    auto transpose1 = std::make_shared<PatternOp>();
+    transpose1->id = kFormatTransTranspose1;
+    transpose1->types = {PrimitiveType_Transpose};
     auto passOp = std::make_shared<PatternOp>();
     passOp->id = kFormatTransPassOp;
     passOp->types = {PrimitiveType_QuantDTypeCast};
-    auto nh2ncOp = std::make_shared<PatternOp>();
-    nh2ncOp->id = kFormatTransNh2NcOp;
-    nh2ncOp->types = {PrimitiveType_Nhwc2Nchw};
+    auto transpose2 = std::make_shared<PatternOp>();
+    transpose2->id = kFormatTransTranspose2;
+    transpose2->types = {PrimitiveType_Transpose};
-    passOp->left = nc2nhOp;
-    nh2ncOp->left = passOp;
-    std::unique_ptr<FusionPattern> nc2NhAndNh2NcPassFusionPattern(new (std::nothrow)
-                                                                    FusionPattern(kNc2NhAndNh2NcPassFusionPattern));
-    if (nc2NhAndNh2NcPassFusionPattern == nullptr) {
-      MS_LOG(ERROR) << "new " << kNc2NhAndNh2NcPassFusionPattern << "failed";
-      return RET_ERROR;
-    }
-    nc2NhAndNh2NcPassFusionPattern->AddPatternOp(nc2nhOp);
-    nc2NhAndNh2NcPassFusionPattern->AddPatternOp(passOp);
-    nc2NhAndNh2NcPassFusionPattern->AddPatternOp(nh2ncOp);
-    nc2NhAndNh2NcPassFusionPattern->Finish();
-    this->patterns.emplace_back(nc2NhAndNh2NcPassFusionPattern.release());
-  }
-  // nhwc2nchw + nchw2nhwc
-  {
-    auto nc2nhOp = std::make_shared<PatternOp>();
-    nc2nhOp->id = kFormatTransNc2NhOp;
-    nc2nhOp->types = {PrimitiveType_Nchw2Nhwc};
-    auto nh2ncOp = std::make_shared<PatternOp>();
-    nh2ncOp->id = kFormatTransNh2NcOp;
-    nh2ncOp->types = {PrimitiveType_Nhwc2Nchw};
-    nc2nhOp->left = nh2ncOp;
-    std::unique_ptr<FusionPattern> nh2NcAndNc2NhFusionPattern(new (std::nothrow)
-                                                                FusionPattern(kNh2NcAndNc2NhFusionPattern));
-    if (nh2NcAndNc2NhFusionPattern == nullptr) {
-      MS_LOG(ERROR) << "new " << kNh2NcAndNc2NhFusionPattern << "failed";
-      return RET_ERROR;
-    }
-    nh2NcAndNc2NhFusionPattern->AddPatternOp(nh2ncOp);
-    nh2NcAndNc2NhFusionPattern->AddPatternOp(nc2nhOp);
-    nh2NcAndNc2NhFusionPattern->Finish();
-    this->patterns.emplace_back(nh2NcAndNc2NhFusionPattern.release());
-  }
-  // nhwc2nchw + QuantDtypeCast + nchw2nhwc
-  {
-    auto nc2nhOp = std::make_shared<PatternOp>();
-    nc2nhOp->id = kFormatTransNc2NhOp;
-    nc2nhOp->types = {PrimitiveType_Nchw2Nhwc};
-    auto passOp = std::make_shared<PatternOp>();
-    passOp->id = kFormatTransPassOp;
-    passOp->types = {PrimitiveType_QuantDTypeCast};
-    auto nh2ncOp = std::make_shared<PatternOp>();
-    nh2ncOp->id = kFormatTransNh2NcOp;
-    nh2ncOp->types = {PrimitiveType_Nhwc2Nchw};
-    passOp->left = nh2ncOp;
-    nc2nhOp->left = passOp;
-    std::unique_ptr<FusionPattern> nh2NcAndNc2NhPassFusionPattern(new (std::nothrow)
-                                                                    FusionPattern(kNh2NcAndNc2NhPassFusionPattern));
-    if (nh2NcAndNc2NhPassFusionPattern == nullptr) {
+    passOp->left = transpose2;
+    transpose1->left = passOp;
+    auto pattern = std::make_unique<FusionPattern>(kNh2NcAndNc2NhPassFusionPattern);
+    if (pattern == nullptr) {
       MS_LOG(ERROR) << "new " << kNh2NcAndNc2NhPassFusionPattern << " failed";
       return RET_ERROR;
     }
-    nh2NcAndNc2NhPassFusionPattern->AddPatternOp(nh2ncOp);
-    nh2NcAndNc2NhPassFusionPattern->AddPatternOp(passOp);
-    nh2NcAndNc2NhPassFusionPattern->AddPatternOp(nc2nhOp);
-    nh2NcAndNc2NhPassFusionPattern->Finish();
-    this->patterns.emplace_back(nh2NcAndNc2NhPassFusionPattern.release());
+    pattern->AddPatternOp(transpose1);
+    pattern->AddPatternOp(passOp);
+    pattern->AddPatternOp(transpose2);
+    pattern->Finish();
+    this->patterns.emplace_back(pattern.release());
   }
   return RET_OK;
 }
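Note: with both pattern slots now generic Transpose nodes, the pattern alone can no longer guarantee that the pair cancels; DoFusion below checks the perms explicitly. The underlying identity, as a standalone sketch: two chained transposes are a no-op exactly when their permutations compose to the identity, and {0, 2, 3, 1} and {0, 3, 1, 2} are mutual inverses.

```cpp
#include <vector>

// Composing two transposes: the second reads axis second[i] of the first's
// output, which is axis first[second[i]] of the original input.
std::vector<int> ComposePerms(const std::vector<int> &first, const std::vector<int> &second) {
  std::vector<int> composed(second.size());
  for (size_t i = 0; i < second.size(); ++i) {
    composed[i] = first[second[i]];
  }
  return composed;
}

// The pair cancels iff the composition is the identity permutation.
bool TransposePairCancels(const std::vector<int> &first, const std::vector<int> &second) {
  auto composed = ComposePerms(first, second);
  for (size_t i = 0; i < composed.size(); ++i) {
    if (composed[i] != static_cast<int>(i)) return false;
  }
  return true;
}
```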
@@ -136,51 +93,32 @@ STATUS FormatTransFusionPass::DoFusion(schema::MetaGraphT *graph, const std::str
     return RET_PARAM_INVALID;
   }
-  std::shared_ptr<Path> srcPath;
-  std::shared_ptr<Path> dstPath;
-  if (patternName == kNc2NhAndNh2NcFusionPattern || patternName == kNc2NhAndNh2NcPassFusionPattern) {
-    srcPath = matchedPath[kFormatTransNc2NhOp];
-    dstPath = matchedPath[kFormatTransNh2NcOp];
-  } else if (patternName == kNh2NcAndNc2NhFusionPattern || patternName == kNh2NcAndNc2NhPassFusionPattern) {
-    srcPath = matchedPath[kFormatTransNh2NcOp];
-    dstPath = matchedPath[kFormatTransNc2NhOp];
-  } else {
-    MS_ASSERT(false);
-  }
-  if (srcPath == nullptr) {
-    MS_LOG(ERROR) << "srcPath is failed to get";
-    return RET_ERROR;
-  }
-  if (dstPath == nullptr) {
-    MS_LOG(ERROR) << "dstPath is failed to get";
+  std::shared_ptr<Path> srcPath = matchedPath[kFormatTransTranspose1];
+  std::shared_ptr<Path> dstPath = matchedPath[kFormatTransTranspose2];
+  if (srcPath == nullptr || dstPath == nullptr) {
+    MS_LOG(ERROR) << "srcPath or dstPath is failed to get";
     return RET_ERROR;
   }
   auto srcNode = graph->nodes.at(srcPath->nodeIdx).get();
   auto dstNode = graph->nodes.at(dstPath->nodeIdx).get();
   MS_ASSERT(srcNode != nullptr);
   MS_ASSERT(dstNode != nullptr);
-  if (patternName == kNc2NhAndNh2NcFusionPattern || patternName == kNc2NhAndNh2NcPassFusionPattern) {
-    MS_ASSERT(GetCNodeTType(*srcNode) == schema::PrimitiveType_Nchw2Nhwc);
-    MS_ASSERT(GetCNodeTType(*dstNode) == schema::PrimitiveType_Nhwc2Nchw);
-  } else if (patternName == kNh2NcAndNc2NhFusionPattern || patternName == kNh2NcAndNc2NhPassFusionPattern) {
-    MS_ASSERT(GetCNodeTType(*srcNode) == schema::PrimitiveType_Nhwc2Nchw);
-    MS_ASSERT(GetCNodeTType(*dstNode) == schema::PrimitiveType_Nchw2Nhwc);
-  } else {
-    MS_ASSERT(false);
-  }
-  auto status = IsolateOneWayNode(graph, srcPath->nodeIdx);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << srcNode->name << ", error: " << status;
-    return status;
-  }
-  status = IsolateOneWayNode(graph, dstPath->nodeIdx);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name << ", error: " << status;
-    return status;
+  bool isNc2NhAndNh2Nc = srcNode->primitive->value.AsTranspose()->perm == nchw2nhwc_perm &&
+                         dstNode->primitive->value.AsTranspose()->perm == nhwc2nchw_perm;
+  bool isNh2NcAndNc2Nh = srcNode->primitive->value.AsTranspose()->perm == nhwc2nchw_perm &&
+                         dstNode->primitive->value.AsTranspose()->perm == nchw2nhwc_perm;
+  if (isNc2NhAndNh2Nc || isNh2NcAndNc2Nh) {
+    auto status = IsolateOneWayNode(graph, srcPath->nodeIdx);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << srcNode->name << ", error: " << status;
+      return status;
+    }
+    status = IsolateOneWayNode(graph, dstPath->nodeIdx);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name << ", error: " << status;
+      return status;
+    }
   }
   return RET_OK;
 }
 }  // namespace lite
@@ -24,12 +24,10 @@
 namespace mindspore {
 namespace lite {
-constexpr const char *kFormatTransNc2NhOp = "FormatTransNc2NhOp";
-constexpr const char *kFormatTransNh2NcOp = "FormatTransNh2NcOp";
+constexpr const char *kFormatTransTranspose1 = "FormatTransTransposeOp1";
+constexpr const char *kFormatTransTranspose2 = "FormatTransTransposeOp2";
 constexpr const char *kFormatTransPassOp = "FormatTransPassOp";
 constexpr const char *kNc2NhAndNh2NcFusionPattern = "Nc2NhAndNh2NcFusionPattern";
-constexpr const char *kNc2NhAndNh2NcPassFusionPattern = "Nc2NhAndNh2NcPassFusionPattern";
-constexpr const char *kNh2NcAndNc2NhFusionPattern = "Nh2NcAndNc2NhFusionPattern";
 constexpr const char *kNh2NcAndNc2NhPassFusionPattern = "Nh2NcAndNc2NhPassFusionPattern";
 class FormatTransFusionPass : public FusionPass {
@@ -1,148 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include <string>
-#include <vector>
-#include <unordered_map>
-#include <memory>
-#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h"
-#include "src/common/log_adapter.h"
-#include "securec/include/securec.h"
-#include "tools/common/graph_util.h"
-#include "include/errorcode.h"
-#include "schema/inner/model_generated.h"
-namespace mindspore {
-namespace lite {
-#define kFormatTransTransposeMatchPathLen 2
-STATUS FormatTransPermuteFusionPass::DefinePattern() {
-  // format trans + permute
-  {
-    auto formatTransOp = std::make_shared<PatternOp>();
-    formatTransOp->id = kFormatTransformOp;
-    formatTransOp->types = {PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw};
-    auto transposeOp = std::make_shared<PatternOp>();
-    transposeOp->id = kPermuteOp;
-    transposeOp->types = {PrimitiveType_Transpose};
-    transposeOp->left = formatTransOp;
-    std::unique_ptr<FusionPattern> formatTransTransposeFusionPattern(
-      new (std::nothrow) FusionPattern(kFormatTrans2TransposeFusionPattern));
-    if (formatTransTransposeFusionPattern == nullptr) {
-      MS_LOG(ERROR) << "new " << kFormatTrans2TransposeFusionPattern << " failed";
-      return RET_ERROR;
-    }
-    formatTransTransposeFusionPattern->AddPatternOp(formatTransOp);
-    formatTransTransposeFusionPattern->AddPatternOp(transposeOp);
-    formatTransTransposeFusionPattern->Finish();
-    this->patterns.emplace_back(formatTransTransposeFusionPattern.release());
-  }
-  // permute + format trans
-  {
-    auto formatTransOp = std::make_shared<PatternOp>();
-    formatTransOp->id = kFormatTransformOp;
-    formatTransOp->types = {PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw};
-    auto transposeOp = std::make_shared<PatternOp>();
-    transposeOp->id = kPermuteOp;
-    transposeOp->types = {PrimitiveType_Transpose};
-    formatTransOp->left = transposeOp;
-    std::unique_ptr<FusionPattern> transposeFormatTransFusionPattern(
-      new (std::nothrow) FusionPattern(kTranspose2FormatTransFusionPattern));
-    if (transposeFormatTransFusionPattern == nullptr) {
-      MS_LOG(ERROR) << "new " << kTranspose2FormatTransFusionPattern << " failed";
-      return RET_ERROR;
-    }
-    transposeFormatTransFusionPattern->AddPatternOp(formatTransOp);
-    transposeFormatTransFusionPattern->AddPatternOp(transposeOp);
-    transposeFormatTransFusionPattern->Finish();
-    this->patterns.emplace_back(transposeFormatTransFusionPattern.release());
-  }
-  return RET_OK;
-}
-STATUS FormatTransPermuteFusionPass::Run(schema::MetaGraphT *graph) { return FusionPass::Run(graph); }
-STATUS FormatTransPermuteFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
-                                              std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
-  MS_ASSERT(graph != nullptr);
-  if (matchedPath.size() != kFormatTransTransposeMatchPathLen) {
-    MS_LOG(ERROR) << "schema::Format-Transform-Transpose-Fusion should have " << kFormatTransTransposeMatchPathLen
-                  << " NodeIndex in matchedPair";
-    return RET_PARAM_INVALID;
-  }
-  std::shared_ptr<Path> formatTransPath = matchedPath[kFormatTransformOp];
-  std::shared_ptr<Path> transposePath = matchedPath[kPermuteOp];
-  if (formatTransPath == nullptr) {
-    MS_LOG(ERROR) << "formatTransPath is failed to get";
-    return RET_ERROR;
-  }
-  if (transposePath == nullptr) {
-    MS_LOG(ERROR) << "permutePath is failed to get";
-    return RET_ERROR;
-  }
-  auto &formatTransNode = graph->nodes.at(formatTransPath->nodeIdx);
-  auto &transposeNode = graph->nodes.at(transposePath->nodeIdx);
-  MS_ASSERT(formatTransNode != nullptr);
-  MS_ASSERT(transposeNode != nullptr);
-  auto formatTransType = formatTransNode->primitive->value.type;
-  if (formatTransType != PrimitiveType_Nhwc2Nchw && formatTransType != PrimitiveType_Nchw2Nhwc) {
-    MS_LOG(ERROR) << "FormatTransNode should be " << EnumNamePrimitiveType(PrimitiveType_Nhwc2Nchw) << " or "
-                  << EnumNamePrimitiveType(PrimitiveType_Nchw2Nhwc) << ", but got "
-                  << EnumNamePrimitiveType(formatTransType);
-    return RET_ERROR;
-  }
-  MS_ASSERT(transposeNode->primitive != nullptr);
-  auto transposePrimitive = transposeNode->primitive->value.AsTranspose();
-  MS_ASSERT(transposePrimitive != nullptr);
-  auto perm = transposePrimitive->perm;
-  if (perm.size() != 4) {
-    return RET_OK;
-  }
-  std::vector<int32_t> nchw2nhwcPerm = {0, 2, 3, 1};
-  std::vector<int32_t> nhwc2nchwPerm = {0, 3, 1, 2};
-  if ((perm == nchw2nhwcPerm && formatTransType == PrimitiveType_Nhwc2Nchw) ||
-      (perm == nhwc2nchwPerm && formatTransType == PrimitiveType_Nchw2Nhwc)) {
-    if (formatTransPath->nodeIdx < transposePath->nodeIdx) {
-      if (graph->allTensors.at(formatTransNode->inputIndex[0])->format !=
-          graph->allTensors.at(transposeNode->outputIndex[0])->format) {
-        return RET_OK;
-      }
-    } else {
-      if (graph->allTensors.at(transposeNode->inputIndex[0])->format !=
-          graph->allTensors.at(formatTransNode->outputIndex[0])->format) {
-        return RET_OK;
-      }
-    }
-    auto status = IsolateOneWayNode(graph, formatTransPath->nodeIdx);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << formatTransNode->name << ", error: " << status;
-      return status;
-    }
-    status = IsolateOneWayNode(graph, transposePath->nodeIdx);
-    if (status != RET_OK) {
-      MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << transposeNode->name << ", error: " << status;
-      return status;
-    }
-  }
-  return RET_OK;
-}
-}  // namespace lite
-}  // namespace mindspore
@@ -1,48 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H
-#define MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"
-namespace mindspore {
-namespace lite {
-constexpr const char *kFormatTransformOp = "FormatTransOp";
-constexpr const char *kPermuteOp = "PermuteOp";
-constexpr const char *kFormatTrans2TransposeFusionPattern = "Nc2NhAndNh2NcFusionPattern";
-constexpr const char *kTranspose2FormatTransFusionPattern = "Nc2NhAndNh2NcPassFusionPattern";
-class FormatTransPermuteFusionPass : public FusionPass {
- public:
-  FormatTransPermuteFusionPass() = default;
-  ~FormatTransPermuteFusionPass() override = default;
-  STATUS DefinePattern() override;
-  STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
-                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;
-  STATUS Run(schema::MetaGraphT *graph) override;
-};
-}  // namespace lite
-}  // namespace mindspore
-#endif  // MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H
@@ -115,7 +115,7 @@ STATUS QuantCastFusionPass::DefinePattern() {
     srcOp->types = {schema::PrimitiveType_QuantDTypeCast};
     auto formatOp = std::make_shared<PatternOp>();
     formatOp->id = kFormatTransOp;
-    formatOp->types = {schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Nchw2Nhwc};
+    formatOp->types = {PrimitiveType_Transpose};
     formatOp->left = srcOp;
     auto dstOp = std::make_shared<PatternOp>();
     dstOp->id = kQuantCastDstOp;
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
+#include <algorithm>
 #include <string>
 #include <memory>
 #include <utility>
@@ -196,15 +197,47 @@ NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeI
   }
   auto transNode = std::make_unique<schema::CNodeT>();
   transNode->primitive = std::make_unique<schema::PrimitiveT>();
+  transNode->primitive->value.type = schema::PrimitiveType_Transpose;
+  auto attr = new (std::nothrow) schema::TransposeT();
   if (nodeType == kNCHW2NHWC) {
     transNode->name = "nchw2nhwc_" + tileName + std::to_string(id++);
-    transNode->primitive->value.type = schema::PrimitiveType_Nchw2Nhwc;
+    attr->perm = {0, 2, 3, 1};
   } else {
     transNode->name = "nhwc2nchw_" + tileName + std::to_string(id++);
-    transNode->primitive->value.type = schema::PrimitiveType_Nhwc2Nchw;
+    attr->perm = {0, 3, 1, 2};
   }
-  return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode);
+  transNode->primitive->value.value = attr;
+  OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> {
+    auto newOpDef = std::make_unique<schema::CNodeT>();
+    if (newOpDef == nullptr) {
+      MS_LOG(ERROR) << "new CNodeT failed";
+      return nullptr;
+    }
+    newOpDef->name = inOpDef->name;
+    newOpDef->quantType = inOpDef->quantType;
+    newOpDef->primitive = std::make_unique<schema::PrimitiveT>();
+    if (newOpDef->primitive == nullptr) {
+      MS_LOG(ERROR) << "new PrimitiveT failed";
+      return nullptr;
+    }
+    newOpDef->primitive->value.type = schema::PrimitiveType_Transpose;
+    auto transposeParam = new (std::nothrow) TransposeT;
+    if (transposeParam == nullptr) {
+      MS_LOG(ERROR) << "new transposeParam failed";
+      return nullptr;
+    }
+    auto inParam = inOpDef->primitive->value.AsTranspose();
+    MS_ASSERT(inParam != nullptr);
+    transposeParam->perm.resize(inParam->perm.size());
+    std::transform(inParam->perm.begin(), inParam->perm.end(), transposeParam->perm.begin(),
+                   [](const int32_t ele) { return ele; });
+    newOpDef->primitive->value.value = transposeParam;
+    return newOpDef;
+  };
+  return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, TransposeOpCopyer);
 }
 void FormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }
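Note: the custom TransposeOpCopyer matters here. When InsertNode has to duplicate the inserted transpose (for example one copy per fan-out edge), a default shallow copy would leave two CNodeTs sharing one TransposeT attribute object; the lambda therefore builds a fresh attribute and copies the perm (its std::transform is effectively a vector copy). The core of that deep copy, as a compact sketch assuming the generated schema::PrimitiveT and schema::TransposeT types:

```cpp
#include <memory>

// Compact sketch of the deep copy the lambda performs: the duplicated node
// must own its own attribute object, so perm is copied, never aliased.
std::unique_ptr<schema::PrimitiveT> CopyTransposePrimitive(const schema::PrimitiveT &src) {
  auto dst = std::make_unique<schema::PrimitiveT>();
  dst->value.type = schema::PrimitiveType_Transpose;
  auto *attr = new (std::nothrow) schema::TransposeT;
  if (attr == nullptr) {
    return nullptr;
  }
  attr->perm = src.value.AsTranspose()->perm;  // deep copy of the vector
  dst->value.value = attr;
  return dst;
}
```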
@@ -25,6 +25,10 @@
 #include "schema/inner/model_generated.h"
 namespace mindspore {
+namespace {
+std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
+std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
+}  // namespace
 namespace lite {
 STATUS GlobalFormatTransformPass::Run(MetaGraphT *graph) {
@@ -34,7 +38,10 @@ STATUS GlobalFormatTransformPass::Run(MetaGraphT *graph) {
   for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
     auto &node = *iter;
     auto type = node->primitive->value.type;
-    if (type != schema::PrimitiveType_Nchw2Nhwc) {
+    if (type != PrimitiveType_Transpose) {
+      continue;
+    }
+    if (node->primitive->value.AsTranspose()->perm != nchw2nhwc_perm) {
       continue;
     }
     std::vector<size_t> pre_nh2nc_nodes;
@@ -176,7 +183,8 @@ STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc
     auto &pre_node = graph->nodes.at(input_node_index);
     MS_ASSERT(pre_node != nullptr);
     auto node_type = pre_node->primitive->value.type;
-    if (node_type == schema::PrimitiveType_Nhwc2Nchw) {
+    if (node_type == schema::PrimitiveType_Transpose &&
+        pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) {
       if (!IsContain(*pre_nh2nc_nodes, input_node_index)) {
         pre_nh2nc_nodes->emplace_back(input_node_index);
       }
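Note: this is now the third pass in the patch that open-codes the same classification (a Transpose whose perm equals one of the two canonical vectors), and the insert pass below repeats it twice more with its kNONE/kNCHW2NHWC/kNHWC2NCHW values. A shared helper along these lines would remove the duplication (hypothetical, not part of this patch):

```cpp
#include <vector>

// Hypothetical shared helper: classify a transpose perm as one of the two
// canonical layout conversions, mirroring the kNONE/kNCHW2NHWC/kNHWC2NCHW
// values used by the insert pass below.
enum class FormatTransKind { kNone, kNchw2Nhwc, kNhwc2Nchw };

FormatTransKind ClassifyPerm(const std::vector<int> &perm) {
  static const std::vector<int> nchw2nhwc = {0, 2, 3, 1};
  static const std::vector<int> nhwc2nchw = {0, 3, 1, 2};
  if (perm == nchw2nhwc) return FormatTransKind::kNchw2Nhwc;
  if (perm == nhwc2nchw) return FormatTransKind::kNhwc2Nchw;
  return FormatTransKind::kNone;
}
```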
| @@ -24,12 +24,16 @@ | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace { | |||||
| std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1}; | |||||
| std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2}; | |||||
| } // namespace | |||||
| namespace lite { | namespace lite { | ||||
| bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) { | bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) { | ||||
| MS_ASSERT(graph != nullptr); | MS_ASSERT(graph != nullptr); | ||||
| MS_ASSERT(node != nullptr); | MS_ASSERT(node != nullptr); | ||||
| auto input_node_indexes = GetInputNodeIdx(*graph, *node); | auto input_node_indexes = GetInputNodeIdx(*graph, *node); | ||||
| pre_type_ = schema::PrimitiveType_NONE; | |||||
| pre_type_ = kNONE; | |||||
| size_t has_trans_count = 0; | size_t has_trans_count = 0; | ||||
| auto can_fusion = true; | auto can_fusion = true; | ||||
| for (auto input_node_index : input_node_indexes) { | for (auto input_node_index : input_node_indexes) { | ||||
| @@ -38,16 +42,28 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p | |||||
| MS_ASSERT(pre_node != nullptr); | MS_ASSERT(pre_node != nullptr); | ||||
| MS_ASSERT(pre_node->primitive != nullptr); | MS_ASSERT(pre_node->primitive != nullptr); | ||||
| MS_ASSERT(pre_node->primitive->value != nullptr); | MS_ASSERT(pre_node->primitive->value != nullptr); | ||||
| if (pre_type_ == schema::PrimitiveType_NONE) { | |||||
| if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || | |||||
| pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { | |||||
| pre_type_ = pre_node->primitive->value.type; | |||||
| if (pre_type_ == kNONE) { | |||||
| if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) { | |||||
| if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) { | |||||
| pre_type_ = kNCHW2NHWC; | |||||
| } else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) { | |||||
| pre_type_ = kNHWC2NCHW; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| has_trans_count++; | has_trans_count++; | ||||
| } | } | ||||
| } else { | } else { | ||||
| if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || | |||||
| pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { | |||||
| if (pre_type_ != pre_node->primitive->value.type) { | |||||
| if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) { | |||||
| auto cur_type = kNONE; | |||||
| if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) { | |||||
| cur_type = kNCHW2NHWC; | |||||
| } else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) { | |||||
| cur_type = kNHWC2NCHW; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| if (pre_type_ != cur_type) { | |||||
| can_fusion = false; | can_fusion = false; | ||||
| break; | break; | ||||
| } else { | } else { | ||||
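The same perm-to-type classification recurs for every producer here and, in the next hunk, for every consumer. A hedged refactoring sketch of the shared logic — GetFormatTransType is a hypothetical helper, not something this patch introduces, and the enum is redeclared locally so the sketch compiles on its own:

```cpp
#include <cassert>
#include <vector>

enum FormatTransNodeType { kNONE, kNCHW2NHWC, kNHWC2NCHW };

// Classify a Transpose by its permutation; any perm other than the two
// canonical layout swaps yields kNONE so the caller can bail out early.
FormatTransNodeType GetFormatTransType(const std::vector<int> &perm) {
  static const std::vector<int> nchw2nhwc = {0, 2, 3, 1};
  static const std::vector<int> nhwc2nchw = {0, 3, 1, 2};
  if (perm == nchw2nhwc) return kNCHW2NHWC;
  if (perm == nhwc2nchw) return kNHWC2NCHW;
  return kNONE;
}

int main() {
  assert(GetFormatTransType({0, 2, 3, 1}) == kNCHW2NHWC);
  assert(GetFormatTransType({0, 3, 1, 2}) == kNHWC2NCHW);
  assert(GetFormatTransType({1, 0}) == kNONE);  // not a layout swap
  return 0;
}
```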
| @@ -60,23 +76,35 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto output_node_indexes = GetOutputNodeIdx(*graph, *node); | auto output_node_indexes = GetOutputNodeIdx(*graph, *node); | ||||
| post_type_ = schema::PrimitiveType_NONE; | |||||
| post_type_ = kNONE; | |||||
| for (auto output_node_index : output_node_indexes) { | for (auto output_node_index : output_node_indexes) { | ||||
| MS_ASSERT(graph->nodes.size() > output_node_index); | MS_ASSERT(graph->nodes.size() > output_node_index); | ||||
| auto &post_node = graph->nodes.at(output_node_index); | auto &post_node = graph->nodes.at(output_node_index); | ||||
| MS_ASSERT(post_node != nullptr); | MS_ASSERT(post_node != nullptr); | ||||
| MS_ASSERT(post_node->primitive != nullptr); | MS_ASSERT(post_node->primitive != nullptr); | ||||
| MS_ASSERT(post_node->primitive->value != nullptr); | MS_ASSERT(post_node->primitive->value != nullptr); | ||||
| if (post_type_ == schema::PrimitiveType_NONE) { | |||||
| if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || | |||||
| post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { | |||||
| post_type_ = post_node->primitive->value.type; | |||||
| if (post_type_ == kNONE) { | |||||
| if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) { | |||||
| if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) { | |||||
| post_type_ = kNCHW2NHWC; | |||||
| } else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) { | |||||
| post_type_ = kNHWC2NCHW; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| has_trans_count++; | has_trans_count++; | ||||
| } | } | ||||
| } else { | } else { | ||||
| if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || | |||||
| post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { | |||||
| if (post_type_ != post_node->primitive->value.type) { | |||||
| if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) { | |||||
| auto cur_type = kNONE; | |||||
| if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) { | |||||
| cur_type = kNCHW2NHWC; | |||||
| } else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) { | |||||
| cur_type = kNHWC2NCHW; | |||||
| } else { | |||||
| return false; | |||||
| } | |||||
| if (post_type_ != cur_type) { | |||||
| can_fusion = false; | can_fusion = false; | ||||
| break; | break; | ||||
| } else { | } else { | ||||
| @@ -88,7 +116,7 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p | |||||
| if (!can_fusion) { | if (!can_fusion) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (pre_type_ == PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) { | |||||
| if (pre_type_ == kNONE && post_type_ == kNONE) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| auto output_size = output_node_indexes.empty() ? 1 : output_node_indexes.size(); | auto output_size = output_node_indexes.empty() ? 1 : output_node_indexes.size(); | ||||
| @@ -114,21 +142,21 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p | |||||
| STATUS TransOpInsertPass::FindOutTransType() { | STATUS TransOpInsertPass::FindOutTransType() { | ||||
| pre_insert_trans_type_ = kNHWC2NCHW; | pre_insert_trans_type_ = kNHWC2NCHW; | ||||
| post_insert_trans_type_ = kNHWC2NCHW; | post_insert_trans_type_ = kNHWC2NCHW; | ||||
| if (pre_type_ == PrimitiveType_NONE && post_type_ != PrimitiveType_NONE) { | |||||
| pre_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC; | |||||
| post_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; | |||||
| } else if (pre_type_ != PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) { | |||||
| pre_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; | |||||
| post_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC; | |||||
| } else if (pre_type_ == PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) { | |||||
| if (pre_type_ == kNONE && post_type_ != kNONE) { | |||||
| pre_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC; | |||||
| post_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||||
| } else if (pre_type_ != kNONE && post_type_ == kNONE) { | |||||
| pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||||
| post_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC; | |||||
| } else if (pre_type_ == kNONE && post_type_ == kNONE) { | |||||
| MS_ASSERT(false); | MS_ASSERT(false); | ||||
| } else { | } else { | ||||
| if (pre_type_ == post_type_) { | if (pre_type_ == post_type_) { | ||||
| MS_LOG(ERROR) << "Unknow error"; | MS_LOG(ERROR) << "Unknow error"; | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| pre_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; | |||||
| post_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; | |||||
| pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||||
| post_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||||
| } | } | ||||
| return RET_OK; | return RET_OK; | ||||
| } | } | ||||
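FindOutTransType now reasons entirely in FormatTransNodeType terms: on each side, the inserted transpose is chosen to cancel whichever neighbouring transpose exists. A self-contained reproduction of its decision table (PickInsertTypes is an illustrative stand-in; the pass writes the results into the two member fields instead):

```cpp
#include <cassert>
#include <utility>

enum FormatTransNodeType { kNONE, kNCHW2NHWC, kNHWC2NCHW };

// Mirror of the decision table in FindOutTransType: pick pre/post insert
// types so that each inserted transpose cancels its existing neighbour.
std::pair<FormatTransNodeType, FormatTransNodeType> PickInsertTypes(
    FormatTransNodeType pre, FormatTransNodeType post) {
  auto invert = [](FormatTransNodeType t) {
    return t == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
  };
  if (pre == kNONE) {   // only the consumer side has a transpose
    return {post == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC, invert(post)};
  }
  if (post == kNONE) {  // only the producer side has a transpose
    return {invert(pre), pre == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC};
  }
  return {invert(pre), invert(post)};  // both sides present and opposite
}

int main() {
  // Producers end in NHWC2NCHW: cancel them in front, restore behind.
  assert(PickInsertTypes(kNHWC2NCHW, kNONE).first == kNCHW2NHWC);
  assert(PickInsertTypes(kNHWC2NCHW, kNONE).second == kNHWC2NCHW);
  // Consumers start with NCHW2NHWC: mirror in front, cancel behind.
  assert(PickInsertTypes(kNONE, kNCHW2NHWC).first == kNCHW2NHWC);
  assert(PickInsertTypes(kNONE, kNCHW2NHWC).second == kNHWC2NCHW);
  return 0;
}
```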
| @@ -18,6 +18,7 @@ | |||||
| #define MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H | #define MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | |||||
| #include "tools/common/graph_util.h" | #include "tools/common/graph_util.h" | ||||
| #include "tools/converter/converter_flags.h" | #include "tools/converter/converter_flags.h" | ||||
| #include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" | #include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" | ||||
| @@ -44,8 +45,10 @@ class TransOpInsertPass : public FormatTransPass { | |||||
| private: | private: | ||||
| FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW; | FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW; | ||||
| FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW; | FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW; | ||||
| schema::PrimitiveType pre_type_ = schema::PrimitiveType_NONE; | |||||
| schema::PrimitiveType post_type_ = schema::PrimitiveType_NONE; | |||||
| FormatTransNodeType pre_type_ = kNONE; | |||||
| std::vector<int> pre_perm_; | |||||
| FormatTransNodeType post_type_ = kNONE; | |||||
| std::vector<int> post_perm_; | |||||
| }; | }; | ||||
| } // namespace lite | } // namespace lite | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -25,13 +25,18 @@ | |||||
| using mindspore::lite::PrimitiveC; | using mindspore::lite::PrimitiveC; | ||||
| using mindspore::lite::Tensor; | using mindspore::lite::Tensor; | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace { | |||||
| const std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1}; |||||
| const std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2}; |||||
| } // namespace | |||||
| namespace lite { | namespace lite { | ||||
| STATUS TransOpRemovePass::Run(MetaGraphT *graph) { | STATUS TransOpRemovePass::Run(MetaGraphT *graph) { | ||||
| MS_ASSERT(graph != nullptr); | MS_ASSERT(graph != nullptr); | ||||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | ||||
| auto &node = *iter; | auto &node = *iter; | ||||
| auto type = node->primitive->value.type; | auto type = node->primitive->value.type; | ||||
| if (type == schema::PrimitiveType_Nchw2Nhwc || type == schema::PrimitiveType_Nhwc2Nchw) { | |||||
| if (type == schema::PrimitiveType_Transpose && (node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm || | |||||
| node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm)) { | |||||
| auto &input_tensor = graph->allTensors.at(node->inputIndex.at(0)); | auto &input_tensor = graph->allTensors.at(node->inputIndex.at(0)); | ||||
| // a layout transpose on fewer than 4 dims can be deleted | // a layout transpose on fewer than 4 dims can be deleted | ||||
| if (!input_tensor->dims.empty() && input_tensor->dims.size() < 4) { | if (!input_tensor->dims.empty() && input_tensor->dims.size() < 4) { | ||||
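For reference, the removal condition that follows the new perm match: a 4-element layout permutation has no NCHW/NHWC meaning on a tensor of rank below 4, so such a Transpose can be dropped. A minimal sketch of that predicate under the same perm constants (CanRemoveLayoutTranspose is illustrative, not part of the patch):

```cpp
#include <cassert>
#include <vector>

// A layout-swap Transpose is removable when its input rank is below 4:
// the 4-element perm has no NCHW/NHWC interpretation on such a tensor.
bool CanRemoveLayoutTranspose(const std::vector<int> &dims,
                              const std::vector<int> &perm) {
  const std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
  const std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
  if (perm != nchw2nhwc_perm && perm != nhwc2nchw_perm) return false;
  return !dims.empty() && dims.size() < 4;
}

int main() {
  assert(CanRemoveLayoutTranspose({16, 32}, {0, 2, 3, 1}));       // rank 2: removable
  assert(!CanRemoveLayoutTranspose({1, 3, 8, 8}, {0, 2, 3, 1}));  // rank 4: keep
  assert(!CanRemoveLayoutTranspose({16, 32}, {1, 0}));            // not a layout swap
  return 0;
}
```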
| @@ -523,8 +523,6 @@ QuantParamCalcRegister::QuantParamCalcRegister() { | |||||
| _registerMap[schema::PrimitiveType_Transpose] = linearCalcer; | _registerMap[schema::PrimitiveType_Transpose] = linearCalcer; | ||||
| _registerMap[schema::PrimitiveType_MatMul] = std::make_shared<ConvCalcer>(); | _registerMap[schema::PrimitiveType_MatMul] = std::make_shared<ConvCalcer>(); | ||||
| _registerMap[schema::PrimitiveType_FullConnection] = std::make_shared<ConvCalcer>(); | _registerMap[schema::PrimitiveType_FullConnection] = std::make_shared<ConvCalcer>(); | ||||
| _registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer; | |||||
| _registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer; | |||||
| // detection_postprocess op's quant param will not be inferred, only fetched from preNode or postNode | // detection_postprocess op's quant param will not be inferred, only fetched from preNode or postNode | ||||
| // because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float. | // because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float. | ||||
| // if quantTransNode is inserted after detection_postprocess node, there will be some errors | // if quantTransNode is inserted after detection_postprocess node, there will be some errors | ||||
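These two entries can go because the registry already maps PrimitiveType_Transpose to the same linear calcer a few lines up; once the layout swaps are ordinary Transpose nodes, that single entry covers them. A toy sketch of the registry shape, with all names assumed for illustration rather than taken from the real QuantParamCalcRegister:

```cpp
#include <memory>
#include <unordered_map>

// Illustrative stand-ins for the calcer hierarchy.
struct QuantParamCalcer { virtual ~QuantParamCalcer() = default; };
struct LinearCalcer : QuantParamCalcer {};

enum PrimitiveType { PrimitiveType_Transpose };

int main() {
  std::unordered_map<PrimitiveType, std::shared_ptr<QuantParamCalcer>> registerMap;
  auto linearCalcer = std::make_shared<LinearCalcer>();
  // One entry now serves every layout swap, since they are all Transpose nodes.
  registerMap[PrimitiveType_Transpose] = linearCalcer;
  return 0;
}
```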
| @@ -89,13 +89,25 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { | |||||
| auto cnode = std::dynamic_pointer_cast<CNode>(node); | auto cnode = std::dynamic_pointer_cast<CNode>(node); | ||||
| auto type = NodePrimitiveType(cnode); | auto type = NodePrimitiveType(cnode); | ||||
| static const std::vector<schema::PrimitiveType> int8OpList = { | static const std::vector<schema::PrimitiveType> int8OpList = { | ||||
| schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D, | |||||
| schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Mul, | |||||
| schema::PrimitiveType_Pooling, schema::PrimitiveType_Concat, schema::PrimitiveType_Split, | |||||
| schema::PrimitiveType_TupleGetItem, schema::PrimitiveType_Reshape, schema::PrimitiveType_FullConnection, | |||||
| schema::PrimitiveType_MatMul, schema::PrimitiveType_Crop, schema::PrimitiveType_DeDepthwiseConv2D, | |||||
| schema::PrimitiveType_DeConv2D, schema::PrimitiveType_Activation, schema::PrimitiveType_Transpose, | |||||
| schema::PrimitiveType_Eltwise, schema::PrimitiveType_Gather, schema::PrimitiveType_LayerNorm, | |||||
| schema::PrimitiveType_Conv2D, | |||||
| schema::PrimitiveType_DepthwiseConv2D, | |||||
| schema::PrimitiveType_Add, | |||||
| schema::PrimitiveType_Mul, | |||||
| schema::PrimitiveType_Pooling, | |||||
| schema::PrimitiveType_Concat, | |||||
| schema::PrimitiveType_Split, | |||||
| schema::PrimitiveType_TupleGetItem, | |||||
| schema::PrimitiveType_Reshape, | |||||
| schema::PrimitiveType_FullConnection, | |||||
| schema::PrimitiveType_MatMul, | |||||
| schema::PrimitiveType_Crop, | |||||
| schema::PrimitiveType_DeDepthwiseConv2D, | |||||
| schema::PrimitiveType_DeConv2D, | |||||
| schema::PrimitiveType_Activation, | |||||
| schema::PrimitiveType_Transpose, | |||||
| schema::PrimitiveType_Eltwise, | |||||
| schema::PrimitiveType_Gather, | |||||
| schema::PrimitiveType_LayerNorm, | |||||
| }; | }; | ||||
| bool contain = IsContain(int8OpList, type); | bool contain = IsContain(int8OpList, type); | ||||
| if (!contain) { | if (!contain) { | ||||
| @@ -547,6 +547,8 @@ bool OnnxInputAdjustOpPass::Run(const FuncGraphPtr &func_graph) { | |||||
| status = ReplaceConstant(func_graph, cnode); | status = ReplaceConstant(func_graph, cnode); | ||||
| } else if (type == schema::PrimitiveType_Cast) { | } else if (type == schema::PrimitiveType_Cast) { | ||||
| status = AdjustCast(cnode); | status = AdjustCast(cnode); | ||||
| } else if (type == schema::PrimitiveType_Transpose) { | |||||
| status = ReplaceTransposeWithGraphInput(func_graph, cnode); | |||||
| } | } | ||||
| if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { | if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { | ||||
| MS_LOG(ERROR) << "adjust input pass is failed."; | MS_LOG(ERROR) << "adjust input pass is failed."; | ||||