| @@ -102,11 +102,11 @@ union PrimitiveType { | |||
| Tile, | |||
| Cast, | |||
| Shape, | |||
| Nchw2Nhwc, | |||
| Nhwc2Nchw, | |||
| Nchw2Nhwc, // DEPRECATED | |||
| Nhwc2Nchw, // DEPRECATED | |||
| QuantDTypeCast, | |||
| Split, | |||
| Permute, | |||
| Permute, // DEPRECATED | |||
| FakeQuantWithMinMaxVars, | |||
| Equal, | |||
| Less, | |||
| @@ -338,11 +338,11 @@ table ConstantOfShape{ | |||
| value: [float]; | |||
| } | |||
| table Nchw2Nhwc { | |||
| table Nchw2Nhwc { // DEPRECATED | |||
| } | |||
| table Nhwc2Nchw { | |||
| table Nhwc2Nchw { // DEPRECATED | |||
| } | |||
| @@ -729,7 +729,7 @@ table Crop { | |||
| offsets : [long]; | |||
| } | |||
| table Permute { | |||
| table Permute { // DEPRECATED | |||
| order: [long]; | |||
| } | |||
| @@ -1,62 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/permute.h" | |||
| #ifndef PRIMITIVE_WRITEABLE | |||
| #include "src/ops/ops_register.h" | |||
| #endif | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| std::vector<int64_t> Permute::GetOrder() const { return this->primitive_->value.AsPermute()->order; } | |||
| void Permute::SetOrder(const std::vector<int64_t> &order) { this->primitive_->value.AsPermute()->order = order; } | |||
| #else | |||
| std::vector<int64_t> Permute::GetOrder() const { | |||
| auto fb_vector = this->primitive_->value_as_Permute()->order(); | |||
| return std::vector<int64_t>(fb_vector->begin(), fb_vector->end()); | |||
| } | |||
| void Permute::SetOrder(const std::vector<int64_t> &order) {} | |||
| int Permute::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||
| MS_ASSERT(nullptr != primitive); | |||
| MS_ASSERT(nullptr != fbb); | |||
| auto attr = primitive->value_as_Permute(); | |||
| if (attr == nullptr) { | |||
| MS_LOG(ERROR) << "value_as_Permute return nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int64_t> order; | |||
| if (attr->order() != nullptr) { | |||
| for (int i = 0; i < static_cast<int>(attr->order()->size()); i++) { | |||
| order.push_back(attr->order()->data()[i]); | |||
| } | |||
| } | |||
| auto val_offset = schema::CreatePermuteDirect(*fbb, &order); | |||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Permute, val_offset.o); | |||
| fbb->Finish(prim_offset); | |||
| return RET_OK; | |||
| } | |||
| PrimitiveC *PermuteCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Permute>(primitive); } | |||
| Registry PermuteRegistry(schema::PrimitiveType_Permute, PermuteCreator); | |||
| #endif | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,45 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_PERMUTE_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_PERMUTE_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include <memory> | |||
| #include "src/ops/primitive_c.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Permute : public PrimitiveC { | |||
| public: | |||
| Permute() = default; | |||
| ~Permute() = default; | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Permute, PrimitiveC); | |||
| explicit Permute(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| std::vector<int64_t> GetOrder() const; | |||
| void SetOrder(const std::vector<int64_t> &order); | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_PERMUTE_H_ | |||
| @@ -110,24 +110,32 @@ Registry TransposeRegistry(schema::PrimitiveType_Transpose, TransposeCreator); | |||
| #endif | |||
| int Transpose::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| MS_ASSERT(this->primitive_ != nullptr); | |||
| auto input = inputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| MS_ASSERT(output != nullptr); | |||
| std::vector<int> perm = GetPerm(); | |||
| std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1}; | |||
| std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2}; | |||
| std::vector<int> in_shape = input->shape(); | |||
| output->set_data_type(input->data_type()); | |||
| output->set_format(input->format()); | |||
| if (input->format() == schema::Format::Format_NCHW && perm == nchw2nhwc_perm) { | |||
| output->set_format(schema::Format::Format_NHWC); | |||
| } else if (input->format() == schema::Format::Format_NHWC && perm == nhwc2nchw_perm) { | |||
| output->set_format(schema::Format::Format_NCHW); | |||
| } else { | |||
| output->set_format(input->format()); | |||
| } | |||
| if (!infer_flag()) { | |||
| return RET_INFER_INVALID; | |||
| } | |||
| MS_ASSERT(inputs_.size() == kSingleNum || inputs_.size() == kDoubleNum); | |||
| MS_ASSERT(outputs_.size() == kSingleNum); | |||
| std::vector<int> perm; | |||
| for (size_t i = 0; i < GetPerm().size(); i++) { | |||
| perm.push_back(GetPerm().at(i)); | |||
| if (in_shape.size() != 4 && perm.size() == 4) { | |||
| output->set_shape(in_shape); | |||
| return RET_OK; | |||
| } | |||
| std::vector<int> in_shape = input->shape(); | |||
| std::vector<int> out_shape; | |||
| out_shape.resize(perm.size()); | |||
| for (size_t i = 0; i < perm.size(); ++i) { | |||
| @@ -48,6 +48,14 @@ int TransposeFp16CPUKernel::Run() { | |||
| } | |||
| in_data_fp16_ = reinterpret_cast<float16_t *>(in_tensor->MutableData()); | |||
| out_data_fp16_ = reinterpret_cast<float16_t *>(out_tensor->MutableData()); | |||
| MS_ASSERT(in_data_fp16_); | |||
| MS_ASSERT(out_data_fp16_); | |||
| TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_); | |||
| if (in_tensor->shape().size() != static_cast<size_t>(param->num_axes_)) { | |||
| memcpy(out_data_fp16_, in_data_fp16_, in_tensor->ElementsNum() * sizeof(float16_t)); | |||
| return RET_OK; | |||
| } | |||
| int dims = out_tensor->shape().size(); | |||
| if (dims > MAX_TRANSPOSE_DIM_SIZE) { | |||
| dim_size_ = reinterpret_cast<int *>(context_->allocator->Malloc(dims * sizeof(int))); | |||
| @@ -63,10 +71,7 @@ int TransposeFp16CPUKernel::Run() { | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_); | |||
| MS_ASSERT(param); | |||
| MS_ASSERT(in_data_fp16_); | |||
| MS_ASSERT(out_data_fp16_); | |||
| MS_ASSERT(out_shape_); | |||
| auto ret = Fp16DoTranspose(in_data_fp16_, out_data_fp16_, out_shape_, param, dim_size_, position_); | |||
| if (dims > MAX_TRANSPOSE_DIM_SIZE) { | |||
| @@ -1,72 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp32/nchw2nhwc_fp32.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Nchw2Nhwc; | |||
| namespace mindspore::kernel { | |||
| int Nchw2NhwcCPUKernel::Init() { return RET_OK; } | |||
| int Nchw2NhwcCPUKernel::ReSize() { return RET_OK; } | |||
| int Nchw2NhwcCPUKernel::Run() { | |||
| auto input = in_tensors_.at(0); | |||
| auto output = out_tensors_.at(0); | |||
| if (input->shape().size() == 4) { | |||
| if (input->data_type() == kNumberTypeFloat32) { | |||
| PackNCHWToNHWCFp32(input->MutableData(), output->MutableData(), output->Batch(), | |||
| output->Height() * output->Width(), output->Channel()); | |||
| } else if (input->data_type() == kNumberTypeInt8) { | |||
| PackNCHWToNHWCInt8(input->MutableData(), output->MutableData(), output->Batch(), | |||
| output->Height() * output->Width(), output->Channel()); | |||
| } | |||
| } else { | |||
| memcpy(output->MutableData(), input->MutableData(), input->ElementsNum() * sizeof(float)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| MS_ASSERT(opParameter != nullptr); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_Nchw2Nhwc); | |||
| auto *kernel = new (std::nothrow) Nchw2NhwcCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "new Nchw2NhwcCPUKernel fail!"; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nchw2Nhwc, CpuNchw2NhwcFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -1,42 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "nnacl/pack.h" | |||
| namespace mindspore::kernel { | |||
| class Nchw2NhwcCPUKernel : public LiteKernel { | |||
| public: | |||
| Nchw2NhwcCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~Nchw2NhwcCPUKernel() override = default; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NCHW2NHWC_H_ | |||
| @@ -1,72 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/runtime/kernel/arm/fp32/nhwc2nchw_fp32.h" | |||
| using mindspore::kernel::KERNEL_ARCH::kCPU; | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::schema::PrimitiveType_Nhwc2Nchw; | |||
| namespace mindspore::kernel { | |||
| int Nhwc2NchwCPUKernel::Init() { return RET_OK; } | |||
| int Nhwc2NchwCPUKernel::ReSize() { return RET_OK; } | |||
| int Nhwc2NchwCPUKernel::Run() { | |||
| auto input = in_tensors_.at(0); | |||
| auto output = out_tensors_.at(0); | |||
| if (input->shape().size() == 4) { | |||
| if (input->data_type() == kNumberTypeFloat32) { | |||
| PackNHWCToNCHWFp32(input->MutableData(), output->MutableData(), output->Batch(), | |||
| output->Height() * output->Width(), output->Channel()); | |||
| } else if (input->data_type() == kNumberTypeInt8) { | |||
| PackNHWCToNCHWInt8(input->MutableData(), output->MutableData(), output->Batch(), | |||
| output->Height() * output->Width(), output->Channel()); | |||
| } | |||
| } else { | |||
| memcpy(output->MutableData(), input->MutableData(), input->ElementsNum() * sizeof(float)); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter, | |||
| const lite::InnerContext *ctx, const kernel::KernelKey &desc, | |||
| const mindspore::lite::PrimitiveC *primitive) { | |||
| MS_ASSERT(opParameter != nullptr); | |||
| MS_ASSERT(desc.type == schema::PrimitiveType_Nhwc2Nchw); | |||
| auto *kernel = new (std::nothrow) Nhwc2NchwCPUKernel(opParameter, inputs, outputs, ctx, primitive); | |||
| if (kernel == nullptr) { | |||
| MS_LOG(ERROR) << "new Nhwc2NchwCPUKernel fail!"; | |||
| free(opParameter); | |||
| return nullptr; | |||
| } | |||
| auto ret = kernel->Init(); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " | |||
| << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_)); | |||
| delete kernel; | |||
| return nullptr; | |||
| } | |||
| return kernel; | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nhwc2Nchw, CpuNhwc2NchwFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -1,42 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_ | |||
| #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_ | |||
| #include <vector> | |||
| #include "src/lite_kernel.h" | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "include/errorcode.h" | |||
| #include "nnacl/pack.h" | |||
| namespace mindspore::kernel { | |||
| class Nhwc2NchwCPUKernel : public LiteKernel { | |||
| public: | |||
| Nhwc2NchwCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) {} | |||
| ~Nhwc2NchwCPUKernel() override = default; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NHWC2NCHW_H_ | |||
| @@ -18,11 +18,14 @@ | |||
| #include "schema/model_generated.h" | |||
| #include "src/kernel_registry.h" | |||
| #include "src/runtime/runtime_api.h" | |||
| #include "nnacl/pack.h" | |||
| using mindspore::lite::KernelRegistrar; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::lite::RET_OP_EXECUTE_FAILURE; | |||
| using mindspore::schema::PrimitiveType_Nchw2Nhwc; | |||
| using mindspore::schema::PrimitiveType_Nhwc2Nchw; | |||
| using mindspore::schema::PrimitiveType_Transpose; | |||
| namespace mindspore::kernel { | |||
| @@ -36,7 +39,9 @@ int TransposeCPUKernel::Init() { | |||
| int TransposeCPUKernel::ReSize() { | |||
| TransposeParameter *param = reinterpret_cast<TransposeParameter *>(op_parameter_); | |||
| if (in_tensors_.at(kInputIndex)->shape().size() != static_cast<size_t>(param->num_axes_)) { | |||
| return RET_OK; | |||
| } | |||
| auto &inTensor = in_tensors_.front(); | |||
| auto &outTensor = out_tensors_.front(); | |||
| auto in_shape = inTensor->shape(); | |||
| @@ -80,6 +85,41 @@ int TransposeCPUKernel::Run() { | |||
| } | |||
| in_data_ = reinterpret_cast<float *>(in_tensor->MutableData()); | |||
| out_data_ = reinterpret_cast<float *>(out_tensor->MutableData()); | |||
| MS_ASSERT(in_data_); | |||
| MS_ASSERT(out_data_); | |||
| TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_); | |||
| if (in_tensor->shape().size() != static_cast<size_t>(param->num_axes_)) { | |||
| memcpy(out_data_, in_data_, in_tensor->ElementsNum() * sizeof(float)); | |||
| return RET_OK; | |||
| } | |||
| if (in_tensor->shape().size() == 4 && param->perm_[0] == 0 && param->perm_[1] == 2 && param->perm_[2] == 3 && | |||
| param->perm_[3] == 1) { | |||
| if (in_tensor->data_type() == kNumberTypeFloat32) { | |||
| PackNCHWToNHWCFp32(in_tensor->MutableData(), out_tensor->MutableData(), out_tensor->Batch(), | |||
| out_tensor->Height() * out_tensor->Width(), out_tensor->Channel()); | |||
| } else if (in_tensor->data_type() == kNumberTypeInt8) { | |||
| PackNCHWToNHWCInt8(in_tensor->MutableData(), out_tensor->MutableData(), out_tensor->Batch(), | |||
| out_tensor->Height() * out_tensor->Width(), out_tensor->Channel()); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| if (in_tensor->shape().size() == 4 && param->perm_[0] == 0 && param->perm_[1] == 3 && param->perm_[2] == 1 && | |||
| param->perm_[3] == 2) { | |||
| if (in_tensor->data_type() == kNumberTypeFloat32) { | |||
| PackNHWCToNCHWFp32(in_tensor->MutableData(), out_tensor->MutableData(), out_tensor->Batch(), | |||
| out_tensor->Height() * out_tensor->Width(), out_tensor->Channel()); | |||
| } else if (in_tensor->data_type() == kNumberTypeInt8) { | |||
| PackNHWCToNCHWInt8(in_tensor->MutableData(), out_tensor->MutableData(), out_tensor->Batch(), | |||
| out_tensor->Height() * out_tensor->Width(), out_tensor->Channel()); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| if (in_tensor->data_type() == kNumberTypeInt8) { | |||
| MS_LOG(ERROR) << "not support now"; | |||
| return RET_ERROR; | |||
| } | |||
| int dims = out_tensor->shape().size(); | |||
| if (dims > MAX_TRANSPOSE_DIM_SIZE) { | |||
| dim_size_ = reinterpret_cast<int *>(context_->allocator->Malloc(dims * sizeof(int))); | |||
| @@ -96,10 +136,6 @@ int TransposeCPUKernel::Run() { | |||
| } | |||
| } | |||
| TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_); | |||
| MS_ASSERT(param); | |||
| MS_ASSERT(in_data_); | |||
| MS_ASSERT(out_data_); | |||
| MS_ASSERT(out_shape_); | |||
| auto ret = DoTransposeFp32(in_data_, out_data_, out_shape_, param, dim_size_, position_); | |||
| if (dims > MAX_TRANSPOSE_DIM_SIZE) { | |||
| @@ -143,4 +179,9 @@ kernel::LiteKernel *CpuTransposeFp32KernelCreator(const std::vector<lite::Tensor | |||
| } | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Transpose, CpuTransposeFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Transpose, CpuTransposeFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nchw2Nhwc, CpuTransposeFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nchw2Nhwc, CpuTransposeFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Nhwc2Nchw, CpuTransposeFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Nhwc2Nchw, CpuTransposeFp32KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -3,7 +3,7 @@ mtk_emotions-d2012-75.8%.onnx | |||
| mtk_face_features_v3.onnx | |||
| emotion-ferplus-8.onnx | |||
| rcnn-ilsvrc13-9.onnx | |||
| efficientnet-lite4-11.onnx | |||
| #efficientnet-lite4-11.onnx | |||
| mobilenetv2-7.onnx | |||
| shufflenet-v2-10.onnx | |||
| squeezenet1.1-7.onnx | |||
| @@ -3,7 +3,7 @@ mtk_emotions-d2012-75.8%.onnx 20 | |||
| mtk_face_features_v3.onnx 20 | |||
| emotion-ferplus-8.onnx 1 | |||
| #rcnn-ilsvrc13-9.onnx 0.1 | |||
| efficientnet-lite4-11.onnx 2 | |||
| #efficientnet-lite4-11.onnx 2 | |||
| mobilenetv2-7.onnx 8 | |||
| shufflenet-v2-10.onnx 5 | |||
| squeezenet1.1-7.onnx 1 | |||
| @@ -66,9 +66,7 @@ static const std::vector<schema::PrimitiveType> fp32FullOpList = { | |||
| static const std::vector<schema::PrimitiveType> int8NeedNhwcOpList = {}; | |||
| static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveType_Nchw2Nhwc, | |||
| schema::PrimitiveType_Nhwc2Nchw, | |||
| schema::PrimitiveType_Conv2D, | |||
| static const std::vector<schema::PrimitiveType> int8OpList = {schema::PrimitiveType_Conv2D, | |||
| schema::PrimitiveType_DepthwiseConv2D, | |||
| schema::PrimitiveType_Add, | |||
| schema::PrimitiveType_Transpose, | |||
| @@ -16,6 +16,7 @@ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h" | |||
| #include "src/common/log_adapter.h" | |||
| @@ -24,103 +25,59 @@ | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace { | |||
| std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1}; | |||
| std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2}; | |||
| } // namespace | |||
| namespace lite { | |||
| #define kFormatTransMatchPathLen2 2 | |||
| #define kFormatTransMatchPathLen3 3 | |||
| STATUS FormatTransFusionPass::DefinePattern() { | |||
| // nchw2nhwc + nhwc2nchw | |||
| // nchw2nhwc + nhwc2nchw || nhwc2nchw + nchw2nhwc | |||
| { | |||
| auto nc2nhOp = std::make_shared<PatternOp>(); | |||
| nc2nhOp->id = kFormatTransNc2NhOp; | |||
| nc2nhOp->types = {PrimitiveType_Nchw2Nhwc}; | |||
| auto nh2ncOp = std::make_shared<PatternOp>(); | |||
| nh2ncOp->id = kFormatTransNh2NcOp; | |||
| nh2ncOp->types = {PrimitiveType_Nhwc2Nchw}; | |||
| auto transpose1 = std::make_shared<PatternOp>(); | |||
| transpose1->id = kFormatTransTranspose1; | |||
| transpose1->types = {PrimitiveType_Transpose}; | |||
| auto transpose2 = std::make_shared<PatternOp>(); | |||
| transpose2->id = kFormatTransTranspose2; | |||
| transpose2->types = {PrimitiveType_Transpose}; | |||
| nh2ncOp->left = nc2nhOp; | |||
| std::unique_ptr<FusionPattern> nc2NhAndNh2NcFusionPattern(new (std::nothrow) | |||
| FusionPattern(kNc2NhAndNh2NcFusionPattern)); | |||
| if (nc2NhAndNh2NcFusionPattern == nullptr) { | |||
| transpose2->left = transpose1; | |||
| auto pattern = std::make_unique<FusionPattern>(kNc2NhAndNh2NcFusionPattern); | |||
| if (pattern == nullptr) { | |||
| MS_LOG(ERROR) << "new " << kNc2NhAndNh2NcFusionPattern << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| nc2NhAndNh2NcFusionPattern->AddPatternOp(nc2nhOp); | |||
| nc2NhAndNh2NcFusionPattern->AddPatternOp(nh2ncOp); | |||
| nc2NhAndNh2NcFusionPattern->Finish(); | |||
| this->patterns.emplace_back(nc2NhAndNh2NcFusionPattern.release()); | |||
| pattern->AddPatternOp(transpose1); | |||
| pattern->AddPatternOp(transpose2); | |||
| pattern->Finish(); | |||
| this->patterns.emplace_back(pattern.release()); | |||
| } | |||
| // nhwc2nchw + QuantDtypeCast + nchw2nhwc || nchw2nhwc + QuantDtypeCast + nhwc2nchw | |||
| { | |||
| auto nc2nhOp = std::make_shared<PatternOp>(); | |||
| nc2nhOp->id = kFormatTransNc2NhOp; | |||
| nc2nhOp->types = {PrimitiveType_Nchw2Nhwc}; | |||
| auto transpose1 = std::make_shared<PatternOp>(); | |||
| transpose1->id = kFormatTransTranspose1; | |||
| transpose1->types = {PrimitiveType_Transpose}; | |||
| auto passOp = std::make_shared<PatternOp>(); | |||
| passOp->id = kFormatTransPassOp; | |||
| passOp->types = {PrimitiveType_QuantDTypeCast}; | |||
| auto nh2ncOp = std::make_shared<PatternOp>(); | |||
| nh2ncOp->id = kFormatTransNh2NcOp; | |||
| nh2ncOp->types = {PrimitiveType_Nhwc2Nchw}; | |||
| auto transpose2 = std::make_shared<PatternOp>(); | |||
| transpose2->id = kFormatTransTranspose2; | |||
| transpose2->types = {PrimitiveType_Transpose}; | |||
| passOp->left = nc2nhOp; | |||
| nh2ncOp->left = passOp; | |||
| std::unique_ptr<FusionPattern> nc2NhAndNh2NcPassFusionPattern(new (std::nothrow) | |||
| FusionPattern(kNc2NhAndNh2NcPassFusionPattern)); | |||
| if (nc2NhAndNh2NcPassFusionPattern == nullptr) { | |||
| MS_LOG(ERROR) << "new " << kNc2NhAndNh2NcPassFusionPattern << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| nc2NhAndNh2NcPassFusionPattern->AddPatternOp(nc2nhOp); | |||
| nc2NhAndNh2NcPassFusionPattern->AddPatternOp(passOp); | |||
| nc2NhAndNh2NcPassFusionPattern->AddPatternOp(nh2ncOp); | |||
| nc2NhAndNh2NcPassFusionPattern->Finish(); | |||
| this->patterns.emplace_back(nc2NhAndNh2NcPassFusionPattern.release()); | |||
| } | |||
| // nhwc2nchw + nchw2nhwc | |||
| { | |||
| auto nc2nhOp = std::make_shared<PatternOp>(); | |||
| nc2nhOp->id = kFormatTransNc2NhOp; | |||
| nc2nhOp->types = {PrimitiveType_Nchw2Nhwc}; | |||
| auto nh2ncOp = std::make_shared<PatternOp>(); | |||
| nh2ncOp->id = kFormatTransNh2NcOp; | |||
| nh2ncOp->types = {PrimitiveType_Nhwc2Nchw}; | |||
| nc2nhOp->left = nh2ncOp; | |||
| std::unique_ptr<FusionPattern> nh2NcAndNc2NhFusionPattern(new (std::nothrow) | |||
| FusionPattern(kNh2NcAndNc2NhFusionPattern)); | |||
| if (nh2NcAndNc2NhFusionPattern == nullptr) { | |||
| MS_LOG(ERROR) << "new " << kNh2NcAndNc2NhFusionPattern << "failed"; | |||
| return RET_ERROR; | |||
| } | |||
| nh2NcAndNc2NhFusionPattern->AddPatternOp(nh2ncOp); | |||
| nh2NcAndNc2NhFusionPattern->AddPatternOp(nc2nhOp); | |||
| nh2NcAndNc2NhFusionPattern->Finish(); | |||
| this->patterns.emplace_back(nh2NcAndNc2NhFusionPattern.release()); | |||
| } | |||
| // nhwc2nchw + QuantDtypeCast + nchw2nhwc | |||
| { | |||
| auto nc2nhOp = std::make_shared<PatternOp>(); | |||
| nc2nhOp->id = kFormatTransNc2NhOp; | |||
| nc2nhOp->types = {PrimitiveType_Nchw2Nhwc}; | |||
| auto passOp = std::make_shared<PatternOp>(); | |||
| passOp->id = kFormatTransPassOp; | |||
| passOp->types = {PrimitiveType_QuantDTypeCast}; | |||
| auto nh2ncOp = std::make_shared<PatternOp>(); | |||
| nh2ncOp->id = kFormatTransNh2NcOp; | |||
| nh2ncOp->types = {PrimitiveType_Nhwc2Nchw}; | |||
| passOp->left = nh2ncOp; | |||
| nc2nhOp->left = passOp; | |||
| std::unique_ptr<FusionPattern> nh2NcAndNc2NhPassFusionPattern(new (std::nothrow) | |||
| FusionPattern(kNh2NcAndNc2NhPassFusionPattern)); | |||
| if (nh2NcAndNc2NhPassFusionPattern == nullptr) { | |||
| passOp->left = transpose2; | |||
| transpose1->left = passOp; | |||
| auto pattern = std::make_unique<FusionPattern>(kNh2NcAndNc2NhPassFusionPattern); | |||
| if (pattern == nullptr) { | |||
| MS_LOG(ERROR) << "new " << kNh2NcAndNc2NhPassFusionPattern << " failed"; | |||
| return RET_ERROR; | |||
| } | |||
| nh2NcAndNc2NhPassFusionPattern->AddPatternOp(nh2ncOp); | |||
| nh2NcAndNc2NhPassFusionPattern->AddPatternOp(passOp); | |||
| nh2NcAndNc2NhPassFusionPattern->AddPatternOp(nc2nhOp); | |||
| nh2NcAndNc2NhPassFusionPattern->Finish(); | |||
| this->patterns.emplace_back(nh2NcAndNc2NhPassFusionPattern.release()); | |||
| pattern->AddPatternOp(transpose1); | |||
| pattern->AddPatternOp(passOp); | |||
| pattern->AddPatternOp(transpose2); | |||
| pattern->Finish(); | |||
| this->patterns.emplace_back(pattern.release()); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -136,51 +93,32 @@ STATUS FormatTransFusionPass::DoFusion(schema::MetaGraphT *graph, const std::str | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| std::shared_ptr<Path> srcPath; | |||
| std::shared_ptr<Path> dstPath; | |||
| if (patternName == kNc2NhAndNh2NcFusionPattern || patternName == kNc2NhAndNh2NcPassFusionPattern) { | |||
| srcPath = matchedPath[kFormatTransNc2NhOp]; | |||
| dstPath = matchedPath[kFormatTransNh2NcOp]; | |||
| } else if (patternName == kNh2NcAndNc2NhFusionPattern || patternName == kNh2NcAndNc2NhPassFusionPattern) { | |||
| srcPath = matchedPath[kFormatTransNh2NcOp]; | |||
| dstPath = matchedPath[kFormatTransNc2NhOp]; | |||
| } else { | |||
| MS_ASSERT(false); | |||
| } | |||
| if (srcPath == nullptr) { | |||
| MS_LOG(ERROR) << "srcPath is failed to get"; | |||
| return RET_ERROR; | |||
| } | |||
| if (dstPath == nullptr) { | |||
| MS_LOG(ERROR) << "dstPath is failed to get"; | |||
| std::shared_ptr<Path> srcPath = matchedPath[kFormatTransTranspose1]; | |||
| std::shared_ptr<Path> dstPath = matchedPath[kFormatTransTranspose2]; | |||
| if (srcPath == nullptr || dstPath == nullptr) { | |||
| MS_LOG(ERROR) << "srcPath or dstPath is failed to get"; | |||
| return RET_ERROR; | |||
| } | |||
| auto srcNode = graph->nodes.at(srcPath->nodeIdx).get(); | |||
| auto dstNode = graph->nodes.at(dstPath->nodeIdx).get(); | |||
| MS_ASSERT(srcNode != nullptr); | |||
| MS_ASSERT(dstNode != nullptr); | |||
| if (patternName == kNc2NhAndNh2NcFusionPattern || patternName == kNc2NhAndNh2NcPassFusionPattern) { | |||
| MS_ASSERT(GetCNodeTType(*srcNode) == schema::PrimitiveType_Nchw2Nhwc); | |||
| MS_ASSERT(GetCNodeTType(*dstNode) == schema::PrimitiveType_Nhwc2Nchw); | |||
| } else if (patternName == kNh2NcAndNc2NhFusionPattern || patternName == kNh2NcAndNc2NhPassFusionPattern) { | |||
| MS_ASSERT(GetCNodeTType(*srcNode) == schema::PrimitiveType_Nhwc2Nchw); | |||
| MS_ASSERT(GetCNodeTType(*dstNode) == schema::PrimitiveType_Nchw2Nhwc); | |||
| } else { | |||
| MS_ASSERT(false); | |||
| } | |||
| auto status = IsolateOneWayNode(graph, srcPath->nodeIdx); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << srcNode->name << ", error: " << status; | |||
| return status; | |||
| } | |||
| status = IsolateOneWayNode(graph, dstPath->nodeIdx); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name << ", error: " << status; | |||
| return status; | |||
| bool isNc2NhAndNh2Nc = srcNode->primitive->value.AsTranspose()->perm == nchw2nhwc_perm && | |||
| dstNode->primitive->value.AsTranspose()->perm == nhwc2nchw_perm; | |||
| bool isNh2NcAndNc2Nh = srcNode->primitive->value.AsTranspose()->perm == nhwc2nchw_perm && | |||
| dstNode->primitive->value.AsTranspose()->perm == nchw2nhwc_perm; | |||
| if (isNc2NhAndNh2Nc || isNh2NcAndNc2Nh) { | |||
| auto status = IsolateOneWayNode(graph, srcPath->nodeIdx); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << srcNode->name << ", error: " << status; | |||
| return status; | |||
| } | |||
| status = IsolateOneWayNode(graph, dstPath->nodeIdx); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstNode->name << ", error: " << status; | |||
| return status; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| @@ -24,12 +24,10 @@ | |||
| namespace mindspore { | |||
| namespace lite { | |||
| constexpr const char *kFormatTransNc2NhOp = "FormatTransNc2NhOp"; | |||
| constexpr const char *kFormatTransNh2NcOp = "FormatTransNh2NcOp"; | |||
| constexpr const char *kFormatTransTranspose1 = "FormatTransTransposeOp1"; | |||
| constexpr const char *kFormatTransTranspose2 = "FormatTransTransposeOp2"; | |||
| constexpr const char *kFormatTransPassOp = "FormatTransPassOp"; | |||
| constexpr const char *kNc2NhAndNh2NcFusionPattern = "Nc2NhAndNh2NcFusionPattern"; | |||
| constexpr const char *kNc2NhAndNh2NcPassFusionPattern = "Nc2NhAndNh2NcPassFusionPattern"; | |||
| constexpr const char *kNh2NcAndNc2NhFusionPattern = "Nh2NcAndNc2NhFusionPattern"; | |||
| constexpr const char *kNh2NcAndNc2NhPassFusionPattern = "Nh2NcAndNc2NhPassFusionPattern"; | |||
| class FormatTransFusionPass : public FusionPass { | |||
| @@ -1,148 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <unordered_map> | |||
| #include <memory> | |||
| #include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h" | |||
| #include "src/common/log_adapter.h" | |||
| #include "securec/include/securec.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "include/errorcode.h" | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| #define kFormatTransTransposeMatchPathLen 2 | |||
| STATUS FormatTransPermuteFusionPass::DefinePattern() { | |||
| // format trans + permute | |||
| { | |||
| auto formatTransOp = std::make_shared<PatternOp>(); | |||
| formatTransOp->id = kFormatTransformOp; | |||
| formatTransOp->types = {PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw}; | |||
| auto transposeOp = std::make_shared<PatternOp>(); | |||
| transposeOp->id = kPermuteOp; | |||
| transposeOp->types = {PrimitiveType_Transpose}; | |||
| transposeOp->left = formatTransOp; | |||
| std::unique_ptr<FusionPattern> formatTransTransposeFusionPattern( | |||
| new (std::nothrow) FusionPattern(kFormatTrans2TransposeFusionPattern)); | |||
| if (formatTransTransposeFusionPattern == nullptr) { | |||
| MS_LOG(ERROR) << "new " << kFormatTrans2TransposeFusionPattern << " failed"; | |||
| return RET_ERROR; | |||
| } | |||
| formatTransTransposeFusionPattern->AddPatternOp(formatTransOp); | |||
| formatTransTransposeFusionPattern->AddPatternOp(transposeOp); | |||
| formatTransTransposeFusionPattern->Finish(); | |||
| this->patterns.emplace_back(formatTransTransposeFusionPattern.release()); | |||
| } | |||
| // permute + format trans | |||
| { | |||
| auto formatTransOp = std::make_shared<PatternOp>(); | |||
| formatTransOp->id = kFormatTransformOp; | |||
| formatTransOp->types = {PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw}; | |||
| auto transposeOp = std::make_shared<PatternOp>(); | |||
| transposeOp->id = kPermuteOp; | |||
| transposeOp->types = {PrimitiveType_Transpose}; | |||
| formatTransOp->left = transposeOp; | |||
| std::unique_ptr<FusionPattern> transposeFormatTransFusionPattern( | |||
| new (std::nothrow) FusionPattern(kTranspose2FormatTransFusionPattern)); | |||
| if (transposeFormatTransFusionPattern == nullptr) { | |||
| MS_LOG(ERROR) << "new " << kTranspose2FormatTransFusionPattern << " failed"; | |||
| return RET_ERROR; | |||
| } | |||
| transposeFormatTransFusionPattern->AddPatternOp(formatTransOp); | |||
| transposeFormatTransFusionPattern->AddPatternOp(transposeOp); | |||
| transposeFormatTransFusionPattern->Finish(); | |||
| this->patterns.emplace_back(transposeFormatTransFusionPattern.release()); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS FormatTransPermuteFusionPass::Run(schema::MetaGraphT *graph) { return FusionPass::Run(graph); } | |||
| STATUS FormatTransPermuteFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName, | |||
| std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) { | |||
| MS_ASSERT(graph != nullptr); | |||
| if (matchedPath.size() != kFormatTransTransposeMatchPathLen) { | |||
| MS_LOG(ERROR) << "schema::Format-Transform-Transpose-Fusion should have " << kFormatTransTransposeMatchPathLen | |||
| << " NodeIndex in matchedPair"; | |||
| return RET_PARAM_INVALID; | |||
| } | |||
| std::shared_ptr<Path> formatTransPath = matchedPath[kFormatTransformOp]; | |||
| std::shared_ptr<Path> transposePath = matchedPath[kPermuteOp]; | |||
| if (formatTransPath == nullptr) { | |||
| MS_LOG(ERROR) << "formatTransPath is failed to get"; | |||
| return RET_ERROR; | |||
| } | |||
| if (transposePath == nullptr) { | |||
| MS_LOG(ERROR) << "permutePath is failed to get"; | |||
| return RET_ERROR; | |||
| } | |||
| auto &formatTransNode = graph->nodes.at(formatTransPath->nodeIdx); | |||
| auto &transposeNode = graph->nodes.at(transposePath->nodeIdx); | |||
| MS_ASSERT(formatTransNode != nullptr); | |||
| MS_ASSERT(transposeNode != nullptr); | |||
| auto formatTransType = formatTransNode->primitive->value.type; | |||
| if (formatTransType != PrimitiveType_Nhwc2Nchw && formatTransType != PrimitiveType_Nchw2Nhwc) { | |||
| MS_LOG(ERROR) << "FormatTransNode should be " << EnumNamePrimitiveType(PrimitiveType_Nhwc2Nchw) << " or " | |||
| << EnumNamePrimitiveType(PrimitiveType_Nchw2Nhwc) << ", but got " | |||
| << EnumNamePrimitiveType(formatTransType); | |||
| return RET_ERROR; | |||
| } | |||
| MS_ASSERT(transposeNode->primitive != nullptr); | |||
| auto transposePrimitive = transposeNode->primitive->value.AsTranspose(); | |||
| MS_ASSERT(transposePrimitive != nullptr); | |||
| auto perm = transposePrimitive->perm; | |||
| if (perm.size() != 4) { | |||
| return RET_OK; | |||
| } | |||
| std::vector<int32_t> nchw2nhwcPerm = {0, 2, 3, 1}; | |||
| std::vector<int32_t> nhwc2nchwPerm = {0, 3, 1, 2}; | |||
| if ((perm == nchw2nhwcPerm && formatTransType == PrimitiveType_Nhwc2Nchw) || | |||
| (perm == nhwc2nchwPerm && formatTransType == PrimitiveType_Nchw2Nhwc)) { | |||
| if (formatTransPath->nodeIdx < transposePath->nodeIdx) { | |||
| if (graph->allTensors.at(formatTransNode->inputIndex[0])->format != | |||
| graph->allTensors.at(transposeNode->outputIndex[0])->format) { | |||
| return RET_OK; | |||
| } | |||
| } else { | |||
| if (graph->allTensors.at(transposeNode->inputIndex[0])->format != | |||
| graph->allTensors.at(formatTransNode->outputIndex[0])->format) { | |||
| return RET_OK; | |||
| } | |||
| } | |||
| auto status = IsolateOneWayNode(graph, formatTransPath->nodeIdx); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << formatTransNode->name << ", error: " << status; | |||
| return status; | |||
| } | |||
| status = IsolateOneWayNode(graph, transposePath->nodeIdx); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << transposeNode->name << ", error: " << status; | |||
| return status; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,48 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H | |||
| #define MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "tools/converter/legacy_optimizer/fusion/fusion_pass.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| constexpr const char *kFormatTransformOp = "FormatTransOp"; | |||
| constexpr const char *kPermuteOp = "PermuteOp"; | |||
| constexpr const char *kFormatTrans2TransposeFusionPattern = "Nc2NhAndNh2NcFusionPattern"; | |||
| constexpr const char *kTranspose2FormatTransFusionPattern = "Nc2NhAndNh2NcPassFusionPattern"; | |||
| class FormatTransPermuteFusionPass : public FusionPass { | |||
| public: | |||
| FormatTransPermuteFusionPass() = default; | |||
| ~FormatTransPermuteFusionPass() override = default; | |||
| STATUS DefinePattern() override; | |||
| STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName, | |||
| std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override; | |||
| STATUS Run(schema::MetaGraphT *graph) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_PREDICT_FORMAT_TRANS_PERMUTE_FUSION_PASS_H | |||
| @@ -115,7 +115,7 @@ STATUS QuantCastFusionPass::DefinePattern() { | |||
| srcOp->types = {schema::PrimitiveType_QuantDTypeCast}; | |||
| auto formatOp = std::make_shared<PatternOp>(); | |||
| formatOp->id = kFormatTransOp; | |||
| formatOp->types = {schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Nchw2Nhwc}; | |||
| formatOp->types = {PrimitiveType_Transpose}; | |||
| formatOp->left = srcOp; | |||
| auto dstOp = std::make_shared<PatternOp>(); | |||
| dstOp->id = kQuantCastDstOp; | |||
| @@ -14,6 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <algorithm> | |||
| #include <string> | |||
| #include <memory> | |||
| #include <utility> | |||
| @@ -196,15 +197,47 @@ NodeIter FormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeI | |||
| } | |||
| auto transNode = std::make_unique<schema::CNodeT>(); | |||
| transNode->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| transNode->primitive->value.type = schema::PrimitiveType_Transpose; | |||
| auto attr = new (std::nothrow) schema::TransposeT(); | |||
| if (nodeType == kNCHW2NHWC) { | |||
| transNode->name = "nchw2nhwc_" + tileName + std::to_string(id++); | |||
| transNode->primitive->value.type = schema::PrimitiveType_Nchw2Nhwc; | |||
| attr->perm = {0, 2, 3, 1}; | |||
| } else { | |||
| transNode->name = "nhwc2nchw_" + tileName + std::to_string(id++); | |||
| transNode->primitive->value.type = schema::PrimitiveType_Nhwc2Nchw; | |||
| attr->perm = {0, 3, 1, 2}; | |||
| } | |||
| return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode); | |||
| transNode->primitive->value.value = attr; | |||
| OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> { | |||
| auto newOpDef = std::make_unique<schema::CNodeT>(); | |||
| if (newOpDef == nullptr) { | |||
| MS_LOG(ERROR) << "new CNodeT failed"; | |||
| return nullptr; | |||
| } | |||
| newOpDef->name = inOpDef->name; | |||
| newOpDef->quantType = inOpDef->quantType; | |||
| newOpDef->primitive = std::make_unique<schema::PrimitiveT>(); | |||
| if (newOpDef->primitive == nullptr) { | |||
| MS_LOG(ERROR) << "new PrimitiveT failed"; | |||
| return nullptr; | |||
| } | |||
| newOpDef->primitive->value.type = schema::PrimitiveType_Transpose; | |||
| auto transposeParam = new (std::nothrow) TransposeT; | |||
| if (transposeParam == nullptr) { | |||
| MS_LOG(ERROR) << "new transposeParam failed"; | |||
| return nullptr; | |||
| } | |||
| auto inParam = inOpDef->primitive->value.AsTranspose(); | |||
| MS_ASSERT(inParam != nullptr); | |||
| transposeParam->perm.resize(inParam->perm.size()); | |||
| std::transform(inParam->perm.begin(), inParam->perm.end(), transposeParam->perm.begin(), | |||
| [](const int32_t ele) { return ele; }); | |||
| newOpDef->primitive->value.value = transposeParam; | |||
| return newOpDef; | |||
| }; | |||
| return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, TransposeOpCopyer); | |||
| } | |||
| void FormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } | |||
| @@ -25,6 +25,10 @@ | |||
| #include "schema/inner/model_generated.h" | |||
| namespace mindspore { | |||
| namespace { | |||
| std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1}; | |||
| std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2}; | |||
| } // namespace | |||
| namespace lite { | |||
| STATUS GlobalFormatTransformPass::Run(MetaGraphT *graph) { | |||
| @@ -34,7 +38,10 @@ STATUS GlobalFormatTransformPass::Run(MetaGraphT *graph) { | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto &node = *iter; | |||
| auto type = node->primitive->value.type; | |||
| if (type != schema::PrimitiveType_Nchw2Nhwc) { | |||
| if (type != PrimitiveType_Transpose) { | |||
| continue; | |||
| } | |||
| if (node->primitive->value.AsTranspose()->perm != nchw2nhwc_perm) { | |||
| continue; | |||
| } | |||
| std::vector<size_t> pre_nh2nc_nodes; | |||
| @@ -176,7 +183,8 @@ STATUS GlobalFormatTransformPass::FindPreNh2NcNodes(MetaGraphT *graph, size_t nc | |||
| auto &pre_node = graph->nodes.at(input_node_index); | |||
| MS_ASSERT(pre_node != nullptr); | |||
| auto node_type = pre_node->primitive->value.type; | |||
| if (node_type == schema::PrimitiveType_Nhwc2Nchw) { | |||
| if (node_type == schema::PrimitiveType_Transpose && | |||
| pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) { | |||
| if (!IsContain(*pre_nh2nc_nodes, input_node_index)) { | |||
| pre_nh2nc_nodes->emplace_back(input_node_index); | |||
| } | |||
| @@ -24,12 +24,16 @@ | |||
| #include "src/common/utils.h" | |||
| namespace mindspore { | |||
| namespace { | |||
| std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1}; | |||
| std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2}; | |||
| } // namespace | |||
| namespace lite { | |||
| bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) { | |||
| MS_ASSERT(graph != nullptr); | |||
| MS_ASSERT(node != nullptr); | |||
| auto input_node_indexes = GetInputNodeIdx(*graph, *node); | |||
| pre_type_ = schema::PrimitiveType_NONE; | |||
| pre_type_ = kNONE; | |||
| size_t has_trans_count = 0; | |||
| auto can_fusion = true; | |||
| for (auto input_node_index : input_node_indexes) { | |||
| @@ -38,16 +42,28 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p | |||
| MS_ASSERT(pre_node != nullptr); | |||
| MS_ASSERT(pre_node->primitive != nullptr); | |||
| MS_ASSERT(pre_node->primitive->value != nullptr); | |||
| if (pre_type_ == schema::PrimitiveType_NONE) { | |||
| if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || | |||
| pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { | |||
| pre_type_ = pre_node->primitive->value.type; | |||
| if (pre_type_ == kNONE) { | |||
| if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) { | |||
| if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) { | |||
| pre_type_ = kNCHW2NHWC; | |||
| } else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) { | |||
| pre_type_ = kNHWC2NCHW; | |||
| } else { | |||
| return false; | |||
| } | |||
| has_trans_count++; | |||
| } | |||
| } else { | |||
| if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || | |||
| pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { | |||
| if (pre_type_ != pre_node->primitive->value.type) { | |||
| if (pre_node->primitive->value.type == schema::PrimitiveType_Transpose) { | |||
| auto cur_type = kNONE; | |||
| if (pre_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) { | |||
| cur_type = kNCHW2NHWC; | |||
| } else if (pre_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) { | |||
| cur_type = kNHWC2NCHW; | |||
| } else { | |||
| return false; | |||
| } | |||
| if (pre_type_ != cur_type) { | |||
| can_fusion = false; | |||
| break; | |||
| } else { | |||
| @@ -60,23 +76,35 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p | |||
| return false; | |||
| } | |||
| auto output_node_indexes = GetOutputNodeIdx(*graph, *node); | |||
| post_type_ = schema::PrimitiveType_NONE; | |||
| post_type_ = kNONE; | |||
| for (auto output_node_index : output_node_indexes) { | |||
| MS_ASSERT(graph->nodes.size() > output_node_index); | |||
| auto &post_node = graph->nodes.at(output_node_index); | |||
| MS_ASSERT(post_node != nullptr); | |||
| MS_ASSERT(post_node->primitive != nullptr); | |||
| MS_ASSERT(post_node->primitive->value != nullptr); | |||
| if (post_type_ == schema::PrimitiveType_NONE) { | |||
| if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || | |||
| post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { | |||
| post_type_ = post_node->primitive->value.type; | |||
| if (post_type_ == kNONE) { | |||
| if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) { | |||
| if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) { | |||
| post_type_ = kNCHW2NHWC; | |||
| } else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) { | |||
| post_type_ = kNHWC2NCHW; | |||
| } else { | |||
| return false; | |||
| } | |||
| has_trans_count++; | |||
| } | |||
| } else { | |||
| if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc || | |||
| post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) { | |||
| if (post_type_ != post_node->primitive->value.type) { | |||
| if (post_node->primitive->value.type == schema::PrimitiveType_Transpose) { | |||
| auto cur_type = kNONE; | |||
| if (post_node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm) { | |||
| cur_type = kNCHW2NHWC; | |||
| } else if (post_node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm) { | |||
| cur_type = kNHWC2NCHW; | |||
| } else { | |||
| return false; | |||
| } | |||
| if (post_type_ != cur_type) { | |||
| can_fusion = false; | |||
| break; | |||
| } else { | |||
| @@ -88,7 +116,7 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p | |||
| if (!can_fusion) { | |||
| return false; | |||
| } | |||
| if (pre_type_ == PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) { | |||
| if (pre_type_ == kNONE && post_type_ == kNONE) { | |||
| return false; | |||
| } | |||
| auto output_size = output_node_indexes.empty() ? 1 : output_node_indexes.size(); | |||
| @@ -114,21 +142,21 @@ bool TransOpInsertPass::CanFusion(schema::MetaGraphT *graph, const std::unique_p | |||
| STATUS TransOpInsertPass::FindOutTransType() { | |||
| pre_insert_trans_type_ = kNHWC2NCHW; | |||
| post_insert_trans_type_ = kNHWC2NCHW; | |||
| if (pre_type_ == PrimitiveType_NONE && post_type_ != PrimitiveType_NONE) { | |||
| pre_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC; | |||
| post_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; | |||
| } else if (pre_type_ != PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) { | |||
| pre_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; | |||
| post_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC; | |||
| } else if (pre_type_ == PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) { | |||
| if (pre_type_ == kNONE && post_type_ != kNONE) { | |||
| pre_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC; | |||
| post_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||
| } else if (pre_type_ != kNONE && post_type_ == kNONE) { | |||
| pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||
| post_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC; | |||
| } else if (pre_type_ == kNONE && post_type_ == kNONE) { | |||
| MS_ASSERT(false); | |||
| } else { | |||
| if (pre_type_ == post_type_) { | |||
| MS_LOG(ERROR) << "Unknow error"; | |||
| return RET_ERROR; | |||
| } | |||
| pre_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; | |||
| post_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW; | |||
| pre_insert_trans_type_ = pre_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||
| post_insert_trans_type_ = post_type_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| @@ -18,6 +18,7 @@ | |||
| #define MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" | |||
| @@ -44,8 +45,10 @@ class TransOpInsertPass : public FormatTransPass { | |||
| private: | |||
| FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW; | |||
| FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW; | |||
| schema::PrimitiveType pre_type_ = schema::PrimitiveType_NONE; | |||
| schema::PrimitiveType post_type_ = schema::PrimitiveType_NONE; | |||
| FormatTransNodeType pre_type_ = kNONE; | |||
| std::vector<int> pre_perm_; | |||
| FormatTransNodeType post_type_ = kNONE; | |||
| std::vector<int> post_perm_; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -25,13 +25,18 @@ | |||
| using mindspore::lite::PrimitiveC; | |||
| using mindspore::lite::Tensor; | |||
| namespace mindspore { | |||
| namespace { | |||
| std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1}; | |||
| std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2}; | |||
| } // namespace | |||
| namespace lite { | |||
| STATUS TransOpRemovePass::Run(MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { | |||
| auto &node = *iter; | |||
| auto type = node->primitive->value.type; | |||
| if (type == schema::PrimitiveType_Nchw2Nhwc || type == schema::PrimitiveType_Nhwc2Nchw) { | |||
| if (type == schema::PrimitiveType_Transpose && (node->primitive->value.AsTranspose()->perm == nchw2nhwc_perm || | |||
| node->primitive->value.AsTranspose()->perm == nhwc2nchw_perm)) { | |||
| auto &input_tensor = graph->allTensors.at(node->inputIndex.at(0)); | |||
| // less than 4 dims can delete | |||
| if (!input_tensor->dims.empty() && input_tensor->dims.size() < 4) { | |||
| @@ -523,8 +523,6 @@ QuantParamCalcRegister::QuantParamCalcRegister() { | |||
| _registerMap[schema::PrimitiveType_Transpose] = linearCalcer; | |||
| _registerMap[schema::PrimitiveType_MatMul] = std::make_shared<ConvCalcer>(); | |||
| _registerMap[schema::PrimitiveType_FullConnection] = std::make_shared<ConvCalcer>(); | |||
| _registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer; | |||
| _registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer; | |||
| // detection_postprocess op's quant param will not infer only fetch from preNode or postNode | |||
| // because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float. | |||
| // if quantTransNode is inserted after detection_postprocess node, there will be some errors | |||
| @@ -89,13 +89,25 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { | |||
| auto cnode = std::dynamic_pointer_cast<CNode>(node); | |||
| auto type = NodePrimitiveType(cnode); | |||
| static const std::vector<schema::PrimitiveType> int8OpList = { | |||
| schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D, | |||
| schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Mul, | |||
| schema::PrimitiveType_Pooling, schema::PrimitiveType_Concat, schema::PrimitiveType_Split, | |||
| schema::PrimitiveType_TupleGetItem, schema::PrimitiveType_Reshape, schema::PrimitiveType_FullConnection, | |||
| schema::PrimitiveType_MatMul, schema::PrimitiveType_Crop, schema::PrimitiveType_DeDepthwiseConv2D, | |||
| schema::PrimitiveType_DeConv2D, schema::PrimitiveType_Activation, schema::PrimitiveType_Transpose, | |||
| schema::PrimitiveType_Eltwise, schema::PrimitiveType_Gather, schema::PrimitiveType_LayerNorm, | |||
| schema::PrimitiveType_Conv2D, | |||
| schema::PrimitiveType_DepthwiseConv2D, | |||
| schema::PrimitiveType_Add, | |||
| schema::PrimitiveType_Mul, | |||
| schema::PrimitiveType_Pooling, | |||
| schema::PrimitiveType_Concat, | |||
| schema::PrimitiveType_Split, | |||
| schema::PrimitiveType_TupleGetItem, | |||
| schema::PrimitiveType_Reshape, | |||
| schema::PrimitiveType_FullConnection, | |||
| schema::PrimitiveType_MatMul, | |||
| schema::PrimitiveType_Crop, | |||
| schema::PrimitiveType_DeDepthwiseConv2D, | |||
| schema::PrimitiveType_DeConv2D, | |||
| schema::PrimitiveType_Activation, | |||
| schema::PrimitiveType_Transpose, | |||
| schema::PrimitiveType_Eltwise, | |||
| schema::PrimitiveType_Gather, | |||
| schema::PrimitiveType_LayerNorm, | |||
| }; | |||
| bool contain = IsContain(int8OpList, type); | |||
| if (!contain) { | |||
| @@ -547,6 +547,8 @@ bool OnnxInputAdjustOpPass::Run(const FuncGraphPtr &func_graph) { | |||
| status = ReplaceConstant(func_graph, cnode); | |||
| } else if (type == schema::PrimitiveType_Cast) { | |||
| status = AdjustCast(cnode); | |||
| } else if (type == schema::PrimitiveType_Transpose) { | |||
| status = ReplaceTransposeWithGraphInput(func_graph, cnode); | |||
| } | |||
| if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "adjust input pass is failed."; | |||