Merge pull request !6349 from zhengjun10/mastertags/v1.1.0
| @@ -120,6 +120,8 @@ void ConvertConvWeight(const ParameterPtr ¶m_node) { | |||
| utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[1] = filter_k; | |||
| utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[2] = filter_h; | |||
| utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[3] = filter_w; | |||
| weight->set_tensor_shape({static_cast<int>(filter_c), static_cast<int>(filter_k), static_cast<int>(filter_h), | |||
| static_cast<int>(filter_w)}); | |||
| } | |||
| return; | |||
| } | |||
| @@ -250,7 +252,12 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp | |||
| MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; | |||
| return RET_ERROR; | |||
| } | |||
| int group = GetValue<int>(prim.GetAttr("group")); | |||
| auto groupAttr = prim.GetAttr("group"); | |||
| if (groupAttr == nullptr) { | |||
| MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model"; | |||
| return RET_NULL_PTR; | |||
| } | |||
| int group = GetValue<int>(groupAttr); | |||
| if (group > 1) { | |||
| PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); | |||
| } else { | |||
| @@ -205,6 +205,8 @@ if(BUILD_CONVERTER) | |||
| ${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc | |||
| ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc | |||
| ) | |||
| endif() | |||
| ### train | |||
| @@ -63,6 +63,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} | |||
| ../optimizer/fusion/conv_bn_fusion.cc | |||
| ../optimizer/fusion/constant_folding_fusion.cc | |||
| ../optimizer/fusion/quant_dtype_cast_fusion.cc | |||
| ../optimizer/graph/weight_format_transform_pass.cc | |||
| ../optimizer/graph/weight_format_hardcode_pass.cc | |||
| ) | |||
| add_subdirectory(../anf_importer anf_importer) | |||
| @@ -25,6 +25,8 @@ | |||
| #include "tools/optimizer/fusion/conv_bn_fusion.h" | |||
| #include "tools/optimizer/fusion/constant_folding_fusion.h" | |||
| #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" | |||
| #include "tools/optimizer/graph/weight_format_hardcode_pass.h" | |||
| #include "tools/optimizer/graph/weight_format_transform_pass.h" | |||
| #include "tools/converter/quantizer/post_training_quantizer.h" | |||
| #include "tools/converter/quantizer/quant_cast.h" | |||
| #include "tools/converter/quantizer/weight_quantizer.h" | |||
| @@ -41,6 +43,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| // fusion const_fold | |||
| auto optimizer = std::make_shared<opt::GraphOptimizer>(); | |||
| auto pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false); | |||
| auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true); | |||
| // for now - trainning is not supporting fuse operations | |||
| if (config != nullptr && config->trainModel == false) { | |||
| @@ -61,11 +64,20 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver | |||
| pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(true, "conv_tuple_relu6", | |||
| schema::PrimitiveType_Activation, | |||
| schema::ActivationType_RELU6)); | |||
| auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>(); | |||
| weight_format_hardcode_pass->SetFmkType(config->fmk); | |||
| weight_format_hardcode_pass->SetQuantType(config->quantType); | |||
| graph_pm->AddPass(weight_format_hardcode_pass); | |||
| auto weight_format_transform_pass = std::make_shared<opt::WeightFormatTransformPass>(); | |||
| weight_format_transform_pass->SetFmkType(config->fmk); | |||
| weight_format_transform_pass->SetQuantType(config->quantType); | |||
| graph_pm->AddPass(weight_format_transform_pass); | |||
| } | |||
| pm->AddPass(std::make_shared<opt::ConstFoldPass>()); | |||
| optimizer->AddPassManager(pm); | |||
| FuncGraphPtr new_graph = optimizer->Optimize(old_graph); | |||
| optimizer->AddPassManager(graph_pm); | |||
| auto new_graph = optimizer->Optimize(old_graph); | |||
| if (new_graph == nullptr) { | |||
| ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR); | |||
| return nullptr; | |||
| @@ -28,8 +28,6 @@ | |||
| #include "tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/infershape_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h" | |||
| #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" | |||
| @@ -62,23 +60,6 @@ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { | |||
| int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| STATUS status; | |||
| { | |||
| Optimizer weightFormatOptimizer; | |||
| auto weightHardCodePass = new WeightFormatHardCodePass(); | |||
| auto weightFormatPass = new WeightFormatTransformPass(); | |||
| weightHardCodePass->SetQuantType(ctx.quantType); | |||
| weightHardCodePass->SetFmkType(ctx.fmk); | |||
| weightFormatPass->SetQuantType(ctx.quantType); | |||
| weightFormatPass->SetFmkType(ctx.fmk); | |||
| weightFormatOptimizer.AddPass(weightHardCodePass); | |||
| weightFormatOptimizer.AddPass(weightFormatPass); | |||
| status = weightFormatOptimizer.Run(graphDefT); | |||
| if (status != RET_OK && status != RET_NO_CHANGE) { | |||
| MS_LOG(ERROR) << "Run weightFormatOptimizer graphPasses Failed"; | |||
| return status; | |||
| } | |||
| } | |||
| { | |||
| Optimizer unusedOpRemoveOptimizer; | |||
| unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); | |||
| @@ -149,6 +130,8 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { | |||
| formatTransOptimizer.AddPass(formatTransPass); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass()); | |||
| formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); | |||
| @@ -79,7 +79,7 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN | |||
| MS_ASSERT(graph->allTensors.size() > mulNodeInputIndex.at(MUL_OP_BIAS_INDEX)); | |||
| const auto &mulNodeBiasTensor = graph->allTensors.at(mulNodeInputIndex.at(MUL_OP_BIAS_INDEX)); | |||
| MS_ASSERT(mulNodeBiasTensor != nullptr); | |||
| if (mulNodeBiasTensor->refCount != schema::NodeType::NodeType_ValueNode) { | |||
| if (mulNodeBiasTensor->refCount != schema::NodeType::NodeType_ValueNode || mulNodeBiasTensor->dims.size() == 4) { | |||
| // dont fusion, return | |||
| return RET_OK; | |||
| } | |||
| @@ -4,8 +4,6 @@ add_library(graph_pass_mid OBJECT | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/weight_format_hardcode_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/weight_format_transform_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/unused_node_remove_pass.cc | |||
| ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_convert_scale_pass.cc | |||
| @@ -1,52 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H | |||
| #include <memory> | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "tools/converter/optimizer.h" | |||
| #include "tools/common/graph_util.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class WeightFormatHardCodePass : public GraphPass { | |||
| public: | |||
| WeightFormatHardCodePass() = default; | |||
| ~WeightFormatHardCodePass() override = default; | |||
| void SetQuantType(QuantType quantType); | |||
| void SetFmkType(converter::FmkType fmkType); | |||
| STATUS Run(MetaGraphT *graph) override; | |||
| private: | |||
| STATUS HardCodeCAFFE(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor); | |||
| STATUS HardCodeTFLITE(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor); | |||
| STATUS HardCodeONNX(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor); | |||
| STATUS HardCodeMS(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor); | |||
| private: | |||
| QuantType quantType = QuantType_QUANT_NONE; | |||
| converter::FmkType fmkType = converter::FmkType_TF; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H | |||
| @@ -1,142 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h" | |||
| #include <queue> | |||
| #include "tools/common/node_util.h" | |||
| #include "tools/common/converter_op_utils.h" | |||
| #include "utils/log_adapter.h" | |||
| #include "src/common/utils.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| void WeightFormatTransformPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } | |||
| void WeightFormatTransformPass::SetFmkType(converter::FmkType fmkType) { this->fmkType = fmkType; } | |||
| void WeightFormatTransformPass::SetDstFormat(schema::Format format) { this->dstFormat = format; } | |||
| STATUS WeightFormatTransformPass::Run(MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| if (this->quantType == QuantType_AwareTraining || this->quantType == QuantType_WeightQuant) { | |||
| auto status = QuantDataFormatTrans(graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status; | |||
| return status; | |||
| } | |||
| } else { | |||
| auto status = NonQuantDataFormatTrans(graph); | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "NonQuantDataFormatTrans failed: " << status; | |||
| return status; | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightFormatTransformPass::QuantDataFormatTrans(MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| for (auto &node : graph->nodes) { | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(node->primitive != nullptr); | |||
| auto opType = node->primitive->value.type; | |||
| if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D && | |||
| opType != PrimitiveType_DeConv2D && opType != PrimitiveType_DeDepthwiseConv2D) { | |||
| continue; | |||
| } | |||
| MS_ASSERT(node->inputIndex.size() >= 2); | |||
| auto weightIndex = node->inputIndex.at(1); | |||
| MS_ASSERT(subGraph->allTensors.size() > weightIndex); | |||
| auto &weightTensor = graph->allTensors[weightIndex]; | |||
| MS_ASSERT(weightTensor->dataType == DataType_DT_UINT8 || weightTensor->dataType == DataType_DT_FLOAT || | |||
| weightTensor->dataType == DataType_DT_INT8); | |||
| STATUS status; | |||
| if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D || | |||
| opType == PrimitiveType_DeConv2D || opType == PrimitiveType_DeDepthwiseConv2D) { // weight should be HWCK | |||
| Format curDstFormat; | |||
| if (this->dstFormat == Format_NUM_OF_FORMAT) { | |||
| curDstFormat = Format_KHWC; | |||
| } else { | |||
| curDstFormat = this->dstFormat; | |||
| } | |||
| status = TransFilterFormat(weightTensor.get(), curDstFormat); | |||
| if (status == RET_OK) { | |||
| weightTensor->format = curDstFormat; | |||
| } else { | |||
| MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" << EnumNameFormat(curDstFormat) | |||
| << " failed, node : " << node->name; | |||
| return ERROR; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS WeightFormatTransformPass::NonQuantDataFormatTrans(MetaGraphT *graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| for (auto &node : graph->nodes) { | |||
| MS_ASSERT(node != nullptr); | |||
| MS_ASSERT(node->primitive != nullptr); | |||
| auto opType = node->primitive->value.type; | |||
| if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D && opType != PrimitiveType_DeConv2D && | |||
| opType != PrimitiveType_DeDepthwiseConv2D) { | |||
| continue; | |||
| } | |||
| MS_ASSERT(node->inputIndex.size() >= 2); | |||
| auto weightIndex = node->inputIndex.at(1); | |||
| MS_ASSERT(subGraph->allTensors.size() > weightIndex); | |||
| auto &weightTensor = graph->allTensors[weightIndex]; | |||
| MS_ASSERT(weightTensor->dataType == DataType_DT_UINT8 || weightTensor->dataType == DataType_DT_FLOAT); | |||
| STATUS status; | |||
| if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D || | |||
| opType == schema::PrimitiveType_DeConv2D) { | |||
| schema::Format curDstFormat; | |||
| if (this->dstFormat == schema::Format::Format_NUM_OF_FORMAT) { | |||
| curDstFormat = schema::Format::Format_KHWC; | |||
| } else { | |||
| curDstFormat = this->dstFormat; | |||
| } | |||
| status = TransFilterFormat(weightTensor.get(), curDstFormat); | |||
| if (status == RET_OK) { | |||
| // node->attr.AsConv2D()->format = schema::Format::Format_NCHW; | |||
| weightTensor->format = curDstFormat; | |||
| } else { | |||
| MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" << EnumNameFormat(curDstFormat) | |||
| << " failed, node : " << node->name; | |||
| return ERROR; | |||
| } | |||
| } else { // weight should be CKHW | |||
| schema::Format curDstFormat; | |||
| if (this->dstFormat == schema::Format::Format_NUM_OF_FORMAT) { | |||
| curDstFormat = schema::Format::Format_KHWC; | |||
| } else { | |||
| curDstFormat = this->dstFormat; | |||
| } | |||
| status = TransFilterFormat(weightTensor.get(), curDstFormat); | |||
| if (status == RET_OK) { | |||
| // node->attr.AsDepthwiseConv2D()->format = schema::Format::Format_NCHW; | |||
| weightTensor->format = curDstFormat; | |||
| } else { | |||
| MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" << EnumNameFormat(curDstFormat) | |||
| << " failed, node : " << node->name; | |||
| return ERROR; | |||
| } | |||
| } | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -1,53 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H | |||
| #define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H | |||
| #include "tools/converter/optimizer.h" | |||
| #include "tools/common/graph_util.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class WeightFormatTransformPass : public GraphPass { | |||
| public: | |||
| WeightFormatTransformPass() = default; | |||
| ~WeightFormatTransformPass() override = default; | |||
| void SetQuantType(QuantType quantType); | |||
| void SetFmkType(converter::FmkType fmkType); | |||
| void SetDstFormat(schema::Format format); | |||
| STATUS Run(MetaGraphT *graph) override; | |||
| private: | |||
| STATUS QuantDataFormatTrans(MetaGraphT *graph); | |||
| STATUS NonQuantDataFormatTrans(MetaGraphT *graph); | |||
| private: | |||
| QuantType quantType = QuantType_QUANT_NONE; | |||
| converter::FmkType fmkType = converter::FmkType_TF; | |||
| schema::Format dstFormat = schema::Format::Format_NUM_OF_FORMAT; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H | |||
| @@ -18,6 +18,7 @@ | |||
| #include <algorithm> | |||
| #include <utility> | |||
| #include "src/ops/primitive_c.h" | |||
| #include "src/common/common.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "backend/optimizer/common/helper.h" | |||
| @@ -391,7 +392,17 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) { | |||
| } | |||
| return schema::PrimitiveType_NONE; | |||
| } | |||
| ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node) { | |||
| MS_ASSERT(node != nullptr); | |||
| if (!utils::isa<ParameterPtr>(node)) { | |||
| MS_LOG(ERROR) << "get lite param value node must paramter"; | |||
| return nullptr; | |||
| } | |||
| auto param = node->cast<ParameterPtr>(); | |||
| MS_ASSERT(param != nullptr); | |||
| auto param_value = std::dynamic_pointer_cast<ParamValueLite>(param->default_param()); | |||
| return param_value; | |||
| } | |||
| bool IsParamNode(const BaseRef &n) { | |||
| if (!utils::isa<ParameterPtr>(n)) { | |||
| return false; | |||
| @@ -542,5 +553,551 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOu | |||
| } | |||
| return output_node_list; | |||
| } | |||
| STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC, | |||
| int32_t *filterH, int32_t *filterW) { | |||
| MS_ASSERT(oriDims.size() == 4); | |||
| if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) { | |||
| *filterK = oriDims.at(lite::KCHW_K); | |||
| *filterC = oriDims.at(lite::KCHW_C); | |||
| *filterH = oriDims.at(lite::KCHW_H); | |||
| *filterW = oriDims.at(lite::KCHW_W); | |||
| } else if (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC) { | |||
| *filterC = oriDims.at(lite::CKHW_C); | |||
| *filterK = oriDims.at(lite::CKHW_K); | |||
| *filterH = oriDims.at(lite::CKHW_H); | |||
| *filterW = oriDims.at(lite::CKHW_W); | |||
| } else if (type == kHWCK2KCHW || type == kHWCK2CKHW) { | |||
| *filterH = oriDims.at(lite::HWCK_H); | |||
| *filterW = oriDims.at(lite::HWCK_W); | |||
| *filterC = oriDims.at(lite::HWCK_C); | |||
| *filterK = oriDims.at(lite::HWCK_K); | |||
| } else if (type == kHWKC2KCHW || type == kHWKC2CKHW) { | |||
| *filterH = oriDims.at(lite::HWKC_H); | |||
| *filterW = oriDims.at(lite::HWKC_W); | |||
| *filterK = oriDims.at(lite::HWKC_K); | |||
| *filterC = oriDims.at(lite::HWKC_C); | |||
| } else if (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW) { | |||
| *filterK = oriDims.at(lite::NHWC_N); | |||
| *filterH = oriDims.at(lite::NHWC_H); | |||
| *filterW = oriDims.at(lite::NHWC_W); | |||
| *filterC = oriDims.at(lite::NHWC_C); | |||
| } else if (type == kCHWK2HWCK || type == kCHWK2KHWC) { | |||
| *filterC = oriDims.at(lite::CHWK_C); | |||
| *filterH = oriDims.at(lite::CHWK_H); | |||
| *filterW = oriDims.at(lite::CHWK_W); | |||
| *filterK = oriDims.at(lite::CHWK_K); | |||
| } else if (type == kKHWC2HWCK || type == kKHWC2CHWK) { | |||
| *filterK = oriDims.at(lite::KHWC_K); | |||
| *filterH = oriDims.at(lite::KHWC_H); | |||
| *filterW = oriDims.at(lite::KHWC_W); | |||
| *filterC = oriDims.at(lite::KHWC_C); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported transFilterType: " << type; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, | |||
| int32_t filterH, int32_t filterW) { | |||
| MS_ASSERT(tensor != nullptr); | |||
| if (type == kKCHW2HWCK || type == kCKHW2HWCK || type == kNHWC2HWCK || type == kKHWC2HWCK || type == kCHWK2HWCK) { | |||
| tensor->set_tensor_shape({filterH, filterW, filterC, filterK}); | |||
| } else if (type == kKCHW2HWKC || type == kCKHW2HWKC) { | |||
| tensor->set_tensor_shape({filterH, filterW, filterK, filterC}); | |||
| } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) { | |||
| tensor->set_tensor_shape({filterK, filterC, filterH, filterW}); | |||
| } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) { | |||
| tensor->set_tensor_shape({filterC, filterK, filterH, filterW}); | |||
| } else if (type == kKHWC2CHWK) { | |||
| tensor->set_tensor_shape({filterC, filterH, filterW, filterK}); | |||
| } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) { | |||
| tensor->set_tensor_shape({filterK, filterH, filterW, filterC}); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported transFilterType: " << type; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| template<typename T> | |||
| static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, | |||
| int32_t filterH, int32_t filterW) { | |||
| MS_ASSERT(tensor != nullptr); | |||
| int count = filterH * filterW * filterC * filterK; | |||
| if (count <= 0) { | |||
| MS_LOG(ERROR) << "Dim size invalid"; | |||
| return RET_ERROR; | |||
| } | |||
| std::unique_ptr<T[]> buf(new(std::nothrow) T[count]); | |||
| if (buf == nullptr) { | |||
| MS_LOG(ERROR) << "new buf failed"; | |||
| return RET_ERROR; | |||
| } | |||
| void *originWeightData = tensor->tensor_addr(); | |||
| T *weightData = static_cast<T *>(originWeightData); | |||
| if (weightData == nullptr) { | |||
| MS_LOG(ERROR) << "weightData is nullptr"; | |||
| return RET_ERROR; | |||
| } | |||
| T *p1Buff = nullptr; | |||
| T *p2Buff = nullptr; | |||
| switch (type) { | |||
| case kCHWK2HWCK: | |||
| case kCHWK2KHWC: { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k)); | |||
| if (type == kCHWK2HWCK) { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else if (type == kCHWK2KHWC) { | |||
| p2Buff = | |||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| case kKHWC2HWCK: { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| case kKCHW2HWCK: | |||
| case kKCHW2CKHW: | |||
| case kKCHW2KHWC: | |||
| case kKCHW2HWKC: { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| if (type == kKCHW2HWCK) { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else if (type == kKCHW2KHWC) { | |||
| p2Buff = | |||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| } else if (type == kKCHW2CKHW) { | |||
| p2Buff = | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| case kCKHW2HWCK: | |||
| case kCKHW2KHWC: | |||
| case kCKHW2HWKC: { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| if (type == kCKHW2HWCK) { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else if (type == kCKHW2KHWC) { | |||
| p2Buff = | |||
| buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| case kHWCK2KCHW: | |||
| case kHWCK2CKHW: { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| if (type == kHWCK2KCHW) { | |||
| p2Buff = | |||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| case kHWKC2KCHW: | |||
| case kHWKC2CKHW: { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); | |||
| if (type == kHWKC2KCHW) { | |||
| p2Buff = | |||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| case kNHWC2HWCK: | |||
| case kNHWC2KCHW: | |||
| case kNHWC2CKHW: { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); | |||
| if (type == kNHWC2HWCK) { | |||
| p2Buff = | |||
| buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); | |||
| } else if (type == kNHWC2CKHW) { | |||
| p2Buff = | |||
| buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); | |||
| } else { | |||
| p2Buff = | |||
| buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); | |||
| } | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| case kKHWC2CHWK: { | |||
| for (int k = 0; k < filterK; ++k) { | |||
| for (int h = 0; h < filterH; ++h) { | |||
| for (int w = 0; w < filterW; ++w) { | |||
| for (int c = 0; c < filterC; ++c) { | |||
| p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); | |||
| p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k)); | |||
| *p2Buff = *p1Buff; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported transFilterType: " << type; | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| auto ret = ::memcpy_s(tensor->tensor_addr(), count * sizeof(T), buf.get(), count * sizeof(T)); | |||
| if (ret != EOK) { | |||
| MS_LOG(ERROR) << "memcpy_s failed: " << ret; | |||
| return RET_ERROR; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| template<typename T> | |||
| static STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type) { | |||
| MS_ASSERT(tensor != nullptr); | |||
| auto oriDims = tensor->tensor_shape(); | |||
| if (oriDims.size() != (size_t)lite::DIM_DEFAULT_SIZE) { | |||
| MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| int32_t filterH; | |||
| int32_t filterW; | |||
| int32_t filterC; | |||
| int32_t filterK; | |||
| auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW); | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "GetFilterDim failed: " << status; | |||
| return status; | |||
| } | |||
| status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW); | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "SetFilterDim failed: " << status; | |||
| return status; | |||
| } | |||
| status = TransFilterData<T>(tensor, type, filterK, filterC, filterH, filterW); | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "TransFilterData failed: " << status; | |||
| return status; | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format) { | |||
| if (tensor == nullptr) { | |||
| return lite::RET_NULL_PTR; | |||
| } | |||
| auto ori_dims = tensor->tensor_shape(); | |||
| if (ori_dims.size() != (size_t)lite::DIM_DEFAULT_SIZE) { | |||
| MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << ori_dims.size(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| auto src_format = tensor->format(); | |||
| auto data_type = tensor->tensor_type(); | |||
| lite::STATUS status; | |||
| switch (dst_format) { | |||
| case schema::Format::Format_KHWC: { | |||
| switch (src_format) { | |||
| case schema::Format::Format_KCHW: | |||
| if (data_type == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kKCHW2KHWC); | |||
| } else if (data_type == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kKCHW2KHWC); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kKCHW2KHWC); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_CKHW: | |||
| if (data_type == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kCKHW2KHWC); | |||
| } else if (data_type == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kCKHW2KHWC); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCKHW2KHWC); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_CHWK: | |||
| if (data_type == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kCHWK2KHWC); | |||
| } else if (data_type == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kCHWK2KHWC); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCHWK2KHWC); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_KHWC:return RET_OK; | |||
| default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " | |||
| << EnumNameFormat(dst_format); | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| break; | |||
| case schema::Format::Format_HWCK: { | |||
| switch (src_format) { | |||
| case schema::Format::Format_KCHW: | |||
| if (data_type == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kKCHW2HWCK); | |||
| } else if (data_type == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kKCHW2HWCK); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kKCHW2HWCK); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_KHWC: | |||
| if (data_type == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kKHWC2HWCK); | |||
| } else if (data_type == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kKHWC2HWCK); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kKHWC2HWCK); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_CKHW: | |||
| if (data_type == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kCKHW2HWCK); | |||
| } else if (data_type == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kCKHW2HWCK); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCKHW2HWCK); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_CHWK: | |||
| if (data_type == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kCHWK2HWCK); | |||
| } else if (data_type == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kCHWK2HWCK); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCHWK2HWCK); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return lite::RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_HWCK:return RET_OK; | |||
| default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " | |||
| << EnumNameFormat(dst_format); | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| break; | |||
| case schema::Format::Format_KCHW: { | |||
| switch (src_format) { | |||
| case schema::Format::Format_KCHW:return RET_OK; | |||
| case schema::Format::Format_HWCK: | |||
| if (data_type == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kHWCK2KCHW); | |||
| } else if (data_type == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kHWCK2KCHW); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kHWCK2KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_HWKC: | |||
| if (data_type == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kHWKC2KCHW); | |||
| } else if (data_type == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kHWKC2KCHW); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kHWKC2KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_KHWC: | |||
| if (data_type == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kKHWC2KCHW); | |||
| } else if (data_type == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kKHWC2KCHW); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kKHWC2KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_CKHW: | |||
| if (data_type == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kCKHW2KCHW); | |||
| } else if (data_type == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kCKHW2KCHW); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCKHW2KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_CHWK: | |||
| if (data_type == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kCHWK2KCHW); | |||
| } else if (data_type == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kCHWK2KCHW); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kCHWK2KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " | |||
| << EnumNameFormat(dst_format); | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| break; | |||
| case schema::Format::Format_CKHW: { | |||
| switch (src_format) { | |||
| case schema::Format::Format_HWCK: | |||
| if (data_type == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kHWCK2CKHW); | |||
| } else if (data_type == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kHWCK2CKHW); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kHWCK2CKHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_HWKC: | |||
| if (data_type == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kHWKC2CKHW); | |||
| } else if (data_type == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kHWKC2CKHW); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kHWKC2CKHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_KCHW: | |||
| if (data_type == kNumberTypeFloat32) { | |||
| status = TransFilterFormat<float>(tensor, kKCHW2CKHW); | |||
| } else if (data_type == kNumberTypeUInt8) { | |||
| status = TransFilterFormat<uint8_t>(tensor, kKCHW2CKHW); | |||
| } else if (data_type == kNumberTypeInt8) { | |||
| status = TransFilterFormat<int8_t>(tensor, kKCHW2CKHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported data_type: " << data_type; | |||
| return RET_ERROR; | |||
| } | |||
| break; | |||
| case schema::Format::Format_CKHW:return RET_OK; | |||
| default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " | |||
| << EnumNameFormat(dst_format); | |||
| return RET_ERROR; | |||
| } | |||
| } | |||
| break; | |||
| default:MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " | |||
| << EnumNameFormat(dst_format); | |||
| return RET_ERROR; | |||
| } | |||
| if (status != RET_OK) { | |||
| MS_LOG(ERROR) << "TransFilterData failed: " << status; | |||
| return status; | |||
| } | |||
| return RET_OK; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -18,6 +18,7 @@ | |||
| #define MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "src/ops//primitive_c.h" | |||
| #include "ir/anf.h" | |||
| #include "ir/func_graph.h" | |||
| @@ -28,6 +29,9 @@ | |||
| #include "tools/converter/return_code.h" | |||
| using PrimitiveCPtr = std::shared_ptr<mindspore::lite::PrimitiveC>; | |||
| using mindspore::lite::RET_ERROR; | |||
| using mindspore::lite::RET_OK; | |||
| using mindspore::lite::STATUS; | |||
| namespace mindspore { | |||
| namespace opt { | |||
| bool IsRealCNodeKernel(const AnfNodePtr &node); | |||
| @@ -68,6 +72,47 @@ size_t GetOutputTensorNum(const AnfNodePtr &node); | |||
| bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node); | |||
| size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item); | |||
| ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node); | |||
| enum kTransFilterType { | |||
| kKCHW2HWCK, // 0 | |||
| kKCHW2KHWC, | |||
| kCKHW2KHWC, | |||
| kCKHW2HWCK, | |||
| kKCHW2HWKC, | |||
| kCKHW2HWKC, | |||
| kHWCK2KCHW, | |||
| kHWCK2CKHW, | |||
| kHWKC2KCHW, | |||
| kHWKC2CKHW, | |||
| kNHWC2KCHW, // 10 | |||
| kNHWC2CKHW, | |||
| kNHWC2HWCK, | |||
| kKHWC2HWCK, | |||
| kCHWK2HWCK, | |||
| kKHWC2CHWK, | |||
| kCHWK2KHWC, | |||
| kKHWC2KCHW, | |||
| kCKHW2KCHW, | |||
| kCHWK2KCHW, | |||
| kKCHW2CKHW // 20 | |||
| }; | |||
| STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC, | |||
| int32_t *filterH, int32_t *filterW); | |||
| STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, | |||
| int32_t filterH, int32_t filterW); | |||
| template<typename T> | |||
| static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, | |||
| int32_t filterH, int32_t filterW); | |||
| template<typename T> | |||
| static lite::STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type); | |||
| STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format); | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ | |||
| @@ -195,7 +195,10 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An | |||
| } | |||
| changed = true; | |||
| auto output_nums = GetOutputTensorNum(input_cnode); | |||
| std::vector<Tensor *> output_tensors{output_nums, new Tensor()}; | |||
| std::vector<Tensor *> output_tensors; | |||
| for (size_t j = 0; j < output_nums; j++) { | |||
| output_tensors.push_back(new Tensor()); | |||
| } | |||
| auto lite_primitive = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0)); | |||
| if (lite_primitive == nullptr) { | |||
| MS_LOG(ERROR) << "lite_primitive is nullptr"; | |||
| @@ -0,0 +1,215 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/graph/weight_format_hardcode_pass.h" | |||
| #include <memory> | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| using mindspore::lite::converter::FmkType_CAFFE; | |||
| using mindspore::lite::converter::FmkType_TFLITE; | |||
| using mindspore::lite::converter::FmkType_ONNX; | |||
| using mindspore::lite::converter::FmkType_MS; | |||
| using mindspore::schema::QuantType_WeightQuant; | |||
| using mindspore::schema::QuantType_QUANT_NONE; | |||
| using mindspore::schema::QuantType_AwareTraining; | |||
| using mindspore::schema::QuantType_PostTraining; | |||
| namespace mindspore::opt { | |||
| namespace { | |||
| constexpr size_t kConvWeightIndex = 2; | |||
| } // namespace | |||
| void WeightFormatHardCodePass::SetQuantType(QuantType type) { | |||
| this->quant_type = type; | |||
| } | |||
| void WeightFormatHardCodePass::SetFmkType(FmkType type) { | |||
| this->fmk_type = type; | |||
| } | |||
| lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node, | |||
| const ParamValueLitePtr ¶m_value) const { | |||
| MS_ASSERT(conv_cnode != nullptr); | |||
| MS_ASSERT(param_value != nullptr); | |||
| switch (quant_type) { | |||
| case QuantType_WeightQuant: | |||
| case QuantType_QUANT_NONE:param_value->set_format(schema::Format::Format_KCHW); | |||
| break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: " | |||
| << conv_node->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node, | |||
| const ParamValueLitePtr ¶m_value) const { | |||
| MS_ASSERT(conv_cnode != nullptr); | |||
| MS_ASSERT(param_value != nullptr); | |||
| auto op_type = GetCNodeType(conv_node); | |||
| switch (this->quant_type) { | |||
| case QuantType_AwareTraining: { | |||
| // sum up from current onnx quant models | |||
| if (op_type == schema::PrimitiveType_Conv2D) { | |||
| param_value->set_format(schema::Format::Format_KHWC); | |||
| } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| param_value->set_format(schema::Format::Format_CHWK); | |||
| } else if (op_type == schema::PrimitiveType_DeConv2D) { | |||
| param_value->set_format(schema::Format::Format_KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " | |||
| << conv_node->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| break; | |||
| case QuantType_WeightQuant: | |||
| case QuantType_QUANT_NONE: { | |||
| // conv (K x C/group x kH x kW) group = 1 | |||
| // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W) | |||
| // deconv (C x K/group x kH x kW) group = 1 | |||
| // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W) | |||
| if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D | |||
| || op_type == schema::PrimitiveType_DeConv2D) { | |||
| param_value->set_format(schema::Format::Format_KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " | |||
| << conv_node->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: " | |||
| << conv_node->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node, | |||
| const ParamValueLitePtr ¶m_value) const { | |||
| MS_ASSERT(conv_cnode != nullptr); | |||
| MS_ASSERT(param_value != nullptr); | |||
| auto op_type = GetCNodeType(conv_node); | |||
| switch (this->quant_type) { | |||
| case QuantType_AwareTraining: { | |||
| if (op_type == schema::PrimitiveType_Conv2D) { | |||
| param_value->set_format(schema::Format::Format_KCHW); | |||
| } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| param_value->set_format(schema::Format::Format_CKHW); | |||
| } else { | |||
| param_value->set_format(schema::Format::Format_KCHW); | |||
| } | |||
| } | |||
| break; | |||
| case QuantType_WeightQuant: | |||
| case QuantType_QUANT_NONE: { | |||
| // sum up from current ms quant models | |||
| if (op_type == schema::PrimitiveType_Conv2D) { | |||
| param_value->set_format(schema::Format::Format_KCHW); | |||
| } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| param_value->set_format(schema::Format::Format_CKHW); | |||
| } else if (op_type == schema::PrimitiveType_DeConv2D) { | |||
| param_value->set_format(schema::Format::Format_KCHW); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " | |||
| << conv_node->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: " | |||
| << conv_node->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const AnfNodePtr &conv_node, | |||
| const ParamValueLitePtr ¶m_value) const { | |||
| MS_ASSERT(conv_cnode != nullptr); | |||
| MS_ASSERT(param_value != nullptr); | |||
| auto op_type = GetCNodeType(conv_node); | |||
| switch (this->quant_type) { | |||
| case QuantType_AwareTraining: | |||
| case QuantType_PostTraining: | |||
| case QuantType_WeightQuant: | |||
| case QuantType_QUANT_NONE: { | |||
| if (op_type == schema::PrimitiveType_Conv2D) { | |||
| param_value->set_format(schema::Format::Format_KHWC); | |||
| } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) { | |||
| param_value->set_format(schema::Format::Format_CHWK); | |||
| } else if (op_type == schema::PrimitiveType_DeConv2D) { | |||
| param_value->set_format(schema::Format::Format_CHWK); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " | |||
| << conv_node->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| break; | |||
| default: { | |||
| MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " | |||
| << conv_node->fullname_with_scope(); | |||
| return lite::RET_ERROR; | |||
| } | |||
| } | |||
| return lite::RET_OK; | |||
| } | |||
| bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| auto node_list = TopoSort(graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNode>(node)) { | |||
| continue; | |||
| } | |||
| auto conv_cnode = node->cast<CNodePtr>(); | |||
| auto type = opt::GetCNodeType(node); | |||
| if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D | |||
| && type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | |||
| continue; | |||
| } | |||
| MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex); | |||
| auto weight_node = conv_cnode->input(kConvWeightIndex); | |||
| MS_ASSERT(weight_node != nullptr); | |||
| auto param_value = GetLiteParamValue(weight_node); | |||
| if (param_value == nullptr) { | |||
| MS_LOG(ERROR) << "weight node must param value"; | |||
| return false; | |||
| } | |||
| lite::STATUS status; | |||
| switch (fmk_type) { | |||
| case FmkType_CAFFE:status = HardCodeCAFFE(node, param_value); | |||
| break; | |||
| case FmkType_TFLITE:status = HardCodeTFLITE(node, param_value); | |||
| break; | |||
| case FmkType_ONNX:status = HardCodeONNX(node, param_value); | |||
| break; | |||
| case FmkType_MS:status = HardCodeMS(node, param_value); | |||
| break; | |||
| default:MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope(); | |||
| return false; | |||
| } | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "schema::Format hardCode faild: " << status << ", node: " << node->fullname_with_scope(); | |||
| return false; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace mindspore::opt | |||
| @@ -0,0 +1,47 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_ | |||
| #define MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_ | |||
| #include <string> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "backend/optimizer/common/pass.h" | |||
| #include "src/param_value_lite.h" | |||
| using mindspore::lite::converter::FmkType; | |||
| using mindspore::schema::QuantType; | |||
| namespace mindspore::opt { | |||
| class WeightFormatHardCodePass : public Pass { | |||
| public: | |||
| WeightFormatHardCodePass() : Pass("weight_format_hardcode_pass") {} | |||
| ~WeightFormatHardCodePass() override = default; | |||
| void SetQuantType(QuantType type); | |||
| void SetFmkType(FmkType fmkType); | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| private: | |||
| lite::STATUS HardCodeCAFFE(const AnfNodePtr &node, const ParamValueLitePtr ¶m_value) const; | |||
| lite::STATUS HardCodeONNX(const AnfNodePtr &node, const ParamValueLitePtr ¶m_value) const; | |||
| lite::STATUS HardCodeMS(const AnfNodePtr &node, const ParamValueLitePtr ¶m_value) const; | |||
| lite::STATUS HardCodeTFLITE(const AnfNodePtr &node, const ParamValueLitePtr ¶m_value) const; | |||
| private: | |||
| QuantType quant_type = schema::QuantType_QUANT_NONE; | |||
| FmkType fmk_type = lite::converter::FmkType_TF; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_ | |||
| @@ -0,0 +1,96 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tools/optimizer/graph/weight_format_transform_pass.h" | |||
| #include <memory> | |||
| #include "tools/optimizer/common/gllo_utils.h" | |||
| using mindspore::lite::converter::FmkType_CAFFE; | |||
| using mindspore::lite::converter::FmkType_TFLITE; | |||
| using mindspore::lite::converter::FmkType_ONNX; | |||
| using mindspore::lite::converter::FmkType_MS; | |||
| using mindspore::schema::QuantType_WeightQuant; | |||
| using mindspore::schema::QuantType_QUANT_NONE; | |||
| using mindspore::schema::QuantType_AwareTraining; | |||
| using mindspore::schema::QuantType_PostTraining; | |||
| namespace mindspore::opt { | |||
| namespace { | |||
| constexpr size_t kConvWeightIndex = 2; | |||
| } // namespace | |||
| void WeightFormatTransformPass::SetQuantType(QuantType type) { | |||
| this->quant_type = type; | |||
| } | |||
| void WeightFormatTransformPass::SetFmkType(FmkType type) { | |||
| this->fmk_type = type; | |||
| } | |||
| void WeightFormatTransformPass::SetDstFormat(schema::Format format) { | |||
| this->dst_format = format; | |||
| } | |||
| lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr &graph) { | |||
| MS_ASSERT(graph != nullptr); | |||
| auto node_list = TopoSort(graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| if (!utils::isa<CNodePtr>(node)) { | |||
| continue; | |||
| } | |||
| auto type = opt::GetCNodeType(node); | |||
| if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D | |||
| && type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { | |||
| continue; | |||
| } | |||
| auto conv_cnode = node->cast<CNodePtr>(); | |||
| MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex); | |||
| auto weight_node = conv_cnode->input(kConvWeightIndex); | |||
| MS_ASSERT(weight_node != nullptr); | |||
| auto weight_value = GetLiteParamValue(weight_node); | |||
| if (weight_value == nullptr) { | |||
| MS_LOG(ERROR) << "weight node must param value"; | |||
| return false; | |||
| } | |||
| MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32 | |||
| || weight_value->tensor_type() == TypeId::kNumberTypeUInt8); | |||
| lite::STATUS status; | |||
| schema::Format weight_dst_format = schema::Format::Format_KHWC; | |||
| if (dst_format != schema::Format::Format_NUM_OF_FORMAT) { | |||
| weight_dst_format = dst_format; | |||
| } | |||
| status = TransFilterFormat(weight_value, weight_dst_format); | |||
| if (status == RET_OK) { | |||
| weight_value->set_format(weight_dst_format); | |||
| } else { | |||
| MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_value->format()]) << "To" | |||
| << EnumNameFormat(weight_dst_format) << " failed, node : " << node->fullname_with_scope() | |||
| << "quant type:" << quant_type; | |||
| return ERROR; | |||
| } | |||
| auto type_id = static_cast<TypeId>(weight_value->tensor_type()); | |||
| auto type_ptr = TypeIdToType(type_id); | |||
| auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, weight_value->tensor_shape()); | |||
| weight_node->set_abstract(abstract_tensor); | |||
| } | |||
| return RET_OK; | |||
| } | |||
| bool WeightFormatTransformPass::Run(const FuncGraphPtr &func_graph) { | |||
| MS_ASSERT(func_graph != nullptr); | |||
| auto status = ConvWeightFormatTrans(func_graph); | |||
| if (status != lite::RET_OK) { | |||
| MS_LOG(ERROR) << "Conv2D weight FormatTrans failed: " << status; | |||
| return status; | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace mindspore::opt | |||
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_ | |||
| #define MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_ | |||
| #include <string> | |||
| #include "schema/inner/model_generated.h" | |||
| #include "tools/converter/converter_flags.h" | |||
| #include "backend/optimizer/common/pass.h" | |||
| using mindspore::lite::converter::FmkType; | |||
| using mindspore::schema::QuantType; | |||
| namespace mindspore::opt { | |||
| class WeightFormatTransformPass : public Pass { | |||
| public: | |||
| WeightFormatTransformPass() : Pass("weight_format_transform_pass") {} | |||
| ~WeightFormatTransformPass() override = default; | |||
| void SetQuantType(QuantType type); | |||
| void SetFmkType(FmkType fmkType); | |||
| void SetDstFormat(schema::Format format); | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| private: | |||
| lite::STATUS ConvWeightFormatTrans(const FuncGraphPtr &graph); | |||
| private: | |||
| QuantType quant_type = schema::QuantType_QUANT_NONE; | |||
| FmkType fmk_type = lite::converter::FmkType_TF; | |||
| schema::Format dst_format = schema::Format::Format_NUM_OF_FORMAT; | |||
| }; | |||
| } // namespace mindspore::opt | |||
| #endif // MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_ | |||