diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc index 1dccb681a3..27bf01e8c0 100644 --- a/mindspore/lite/src/ops/conv2d.cc +++ b/mindspore/lite/src/ops/conv2d.cc @@ -120,6 +120,8 @@ void ConvertConvWeight(const ParameterPtr &param_node) { utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[1] = filter_k; utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[2] = filter_h; utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[3] = filter_w; + weight->set_tensor_shape({static_cast<int>(filter_c), static_cast<int>(filter_k), static_cast<int>(filter_h), + static_cast<int>(filter_w)}); } return; } @@ -250,7 +252,12 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type; return RET_ERROR; } - int group = GetValue<int>(prim.GetAttr("group")); + auto groupAttr = prim.GetAttr("group"); + if (groupAttr == nullptr) { + MS_LOG(ERROR) << "conv2d op has no group attr, please check the pb model"; + return RET_NULL_PTR; + } + int group = GetValue<int>(groupAttr); if (group > 1) { PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); } else { diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 039e597c13..597abb7555 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -205,6 +205,8 @@ if(BUILD_CONVERTER) ${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc ${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc + ${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc + ${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc ) endif() ### train diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 6adb8ab486..0f4dbca099 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -63,6 +63,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/fusion/conv_bn_fusion.cc ../optimizer/fusion/constant_folding_fusion.cc ../optimizer/fusion/quant_dtype_cast_fusion.cc + ../optimizer/graph/weight_format_transform_pass.cc + ../optimizer/graph/weight_format_hardcode_pass.cc ) add_subdirectory(../anf_importer anf_importer) diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index 6f52a97a41..ac8b1b2212 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -25,6 +25,8 @@ #include "tools/optimizer/fusion/conv_bn_fusion.h" #include "tools/optimizer/fusion/constant_folding_fusion.h" #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h" +#include "tools/optimizer/graph/weight_format_hardcode_pass.h" +#include "tools/optimizer/graph/weight_format_transform_pass.h" #include "tools/converter/quantizer/post_training_quantizer.h" #include "tools/converter/quantizer/quant_cast.h" #include "tools/converter/quantizer/weight_quantizer.h" @@ -41,6 +43,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver // fusion const_fold auto optimizer = std::make_shared<opt::GraphOptimizer>(); auto pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false); + auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true); // for now - training does not support fused operations if (config != nullptr && config->trainModel == false) { @@ -61,11 +64,20 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr 
&old_graph, const conver pm->AddPass(std::make_shared(true, "conv_tuple_relu6", schema::PrimitiveType_Activation, schema::ActivationType_RELU6)); + auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>(); + weight_format_hardcode_pass->SetFmkType(config->fmk); + weight_format_hardcode_pass->SetQuantType(config->quantType); + graph_pm->AddPass(weight_format_hardcode_pass); + auto weight_format_transform_pass = std::make_shared<opt::WeightFormatTransformPass>(); + weight_format_transform_pass->SetFmkType(config->fmk); + weight_format_transform_pass->SetQuantType(config->quantType); + graph_pm->AddPass(weight_format_transform_pass); } pm->AddPass(std::make_shared()); optimizer->AddPassManager(pm); - FuncGraphPtr new_graph = optimizer->Optimize(old_graph); + optimizer->AddPassManager(graph_pm); + auto new_graph = optimizer->Optimize(old_graph); if (new_graph == nullptr) { ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR); return nullptr; diff --git a/mindspore/lite/tools/converter/graphdef_transform.cc b/mindspore/lite/tools/converter/graphdef_transform.cc index 1f9ec8f689..97c1f71650 100644 --- a/mindspore/lite/tools/converter/graphdef_transform.cc +++ b/mindspore/lite/tools/converter/graphdef_transform.cc @@ -28,8 +28,6 @@ #include "tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h" #include "tools/converter/legacy_optimizer/graph/infershape_pass.h" #include "tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h" -#include "tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h" -#include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h" #include "tools/converter/legacy_optimizer/graph/format_trans_pass.h" #include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h" #include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h" @@ -62,23 +60,6 @@ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) { int GraphDefTransform::Transform(const converter::Flags &ctx) { STATUS status; - { - Optimizer weightFormatOptimizer; - auto weightHardCodePass = new WeightFormatHardCodePass(); - auto weightFormatPass = new WeightFormatTransformPass(); - weightHardCodePass->SetQuantType(ctx.quantType); - weightHardCodePass->SetFmkType(ctx.fmk); - weightFormatPass->SetQuantType(ctx.quantType); - weightFormatPass->SetFmkType(ctx.fmk); - weightFormatOptimizer.AddPass(weightHardCodePass); - weightFormatOptimizer.AddPass(weightFormatPass); - status = weightFormatOptimizer.Run(graphDefT); - if (status != RET_OK && status != RET_NO_CHANGE) { - MS_LOG(ERROR) << "Run weightFormatOptimizer graphPasses Failed"; - return status; - } - } - { Optimizer unusedOpRemoveOptimizer; unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass()); @@ -149,6 +130,8 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) { formatTransOptimizer.AddPass(formatTransPass); formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass()); formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass()); + formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); + formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass()); formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass()); formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass()); formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass()); diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc 
b/mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc index 8c3bfd495b..fe9ee557e3 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc @@ -79,7 +79,7 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN MS_ASSERT(graph->allTensors.size() > mulNodeInputIndex.at(MUL_OP_BIAS_INDEX)); const auto &mulNodeBiasTensor = graph->allTensors.at(mulNodeInputIndex.at(MUL_OP_BIAS_INDEX)); MS_ASSERT(mulNodeBiasTensor != nullptr); - if (mulNodeBiasTensor->refCount != schema::NodeType::NodeType_ValueNode) { + if (mulNodeBiasTensor->refCount != schema::NodeType::NodeType_ValueNode || mulNodeBiasTensor->dims.size() == 4) { // dont fusion, return return RET_OK; } diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt index 919b03265c..e084f6af14 100755 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt @@ -4,8 +4,6 @@ add_library(graph_pass_mid OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc - ${CMAKE_CURRENT_SOURCE_DIR}/weight_format_hardcode_pass.cc - ${CMAKE_CURRENT_SOURCE_DIR}/weight_format_transform_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/unused_node_remove_pass.cc ${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_convert_scale_pass.cc diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h deleted file mode 100644 index 9abdc79efb..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H -#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H - -#include -#include "tools/converter/converter_flags.h" -#include "tools/converter/optimizer.h" -#include "tools/common/graph_util.h" - -namespace mindspore { -namespace lite { -class WeightFormatHardCodePass : public GraphPass { - public: - WeightFormatHardCodePass() = default; - - ~WeightFormatHardCodePass() override = default; - - void SetQuantType(QuantType quantType); - - void SetFmkType(converter::FmkType fmkType); - - STATUS Run(MetaGraphT *graph) override; - - private: - STATUS HardCodeCAFFE(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor); - STATUS HardCodeTFLITE(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor); - STATUS HardCodeONNX(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor); - STATUS HardCodeMS(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor); - - private: - QuantType quantType = QuantType_QUANT_NONE; - converter::FmkType fmkType = converter::FmkType_TF; -}; -} // namespace lite -} // namespace mindspore - -#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc deleted file mode 100644 index 0cb13c5cd4..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc +++ /dev/null @@ -1,142 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h" -#include -#include "tools/common/node_util.h" -#include "tools/common/converter_op_utils.h" -#include "utils/log_adapter.h" -#include "src/common/utils.h" - -namespace mindspore { -namespace lite { -void WeightFormatTransformPass::SetQuantType(QuantType quantType) { this->quantType = quantType; } - -void WeightFormatTransformPass::SetFmkType(converter::FmkType fmkType) { this->fmkType = fmkType; } - -void WeightFormatTransformPass::SetDstFormat(schema::Format format) { this->dstFormat = format; } - -STATUS WeightFormatTransformPass::Run(MetaGraphT *graph) { - MS_ASSERT(graph != nullptr); - if (this->quantType == QuantType_AwareTraining || this->quantType == QuantType_WeightQuant) { - auto status = QuantDataFormatTrans(graph); - if (status != RET_OK) { - MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status; - return status; - } - } else { - auto status = NonQuantDataFormatTrans(graph); - if (status != RET_OK) { - MS_LOG(ERROR) << "NonQuantDataFormatTrans failed: " << status; - return status; - } - } - return RET_OK; -} - -STATUS WeightFormatTransformPass::QuantDataFormatTrans(MetaGraphT *graph) { - MS_ASSERT(graph != nullptr); - for (auto &node : graph->nodes) { - MS_ASSERT(node != nullptr); - MS_ASSERT(node->primitive != nullptr); - auto opType = node->primitive->value.type; - if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D && - opType != PrimitiveType_DeConv2D && opType != PrimitiveType_DeDepthwiseConv2D) { - continue; - } - MS_ASSERT(node->inputIndex.size() >= 2); - auto weightIndex = node->inputIndex.at(1); - MS_ASSERT(subGraph->allTensors.size() > weightIndex); - auto &weightTensor = graph->allTensors[weightIndex]; - MS_ASSERT(weightTensor->dataType == DataType_DT_UINT8 || weightTensor->dataType == DataType_DT_FLOAT || - weightTensor->dataType == DataType_DT_INT8); - STATUS status; - if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D || - opType == PrimitiveType_DeConv2D || opType == PrimitiveType_DeDepthwiseConv2D) { // weight should be HWCK - Format curDstFormat; - if (this->dstFormat == Format_NUM_OF_FORMAT) { - curDstFormat = Format_KHWC; - } else { - curDstFormat = this->dstFormat; - } - status = TransFilterFormat(weightTensor.get(), curDstFormat); - if (status == RET_OK) { - weightTensor->format = curDstFormat; - } else { - MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" << EnumNameFormat(curDstFormat) - << " failed, node : " << node->name; - return ERROR; - } - } - } - return RET_OK; -} - -STATUS WeightFormatTransformPass::NonQuantDataFormatTrans(MetaGraphT *graph) { - MS_ASSERT(graph != nullptr); - for (auto &node : graph->nodes) { - MS_ASSERT(node != nullptr); - MS_ASSERT(node->primitive != nullptr); - auto opType = node->primitive->value.type; - if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D && opType != PrimitiveType_DeConv2D && - opType != PrimitiveType_DeDepthwiseConv2D) { - continue; - } - MS_ASSERT(node->inputIndex.size() >= 2); - auto weightIndex = node->inputIndex.at(1); - MS_ASSERT(subGraph->allTensors.size() > weightIndex); - auto &weightTensor = graph->allTensors[weightIndex]; - MS_ASSERT(weightTensor->dataType == DataType_DT_UINT8 || weightTensor->dataType == DataType_DT_FLOAT); - STATUS status; - if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D || - opType == schema::PrimitiveType_DeConv2D) { - schema::Format curDstFormat; 
- if (this->dstFormat == schema::Format::Format_NUM_OF_FORMAT) { - curDstFormat = schema::Format::Format_KHWC; - } else { - curDstFormat = this->dstFormat; - } - status = TransFilterFormat(weightTensor.get(), curDstFormat); - if (status == RET_OK) { - // node->attr.AsConv2D()->format = schema::Format::Format_NCHW; - weightTensor->format = curDstFormat; - } else { - MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" << EnumNameFormat(curDstFormat) - << " failed, node : " << node->name; - return ERROR; - } - } else { // weight should be CKHW - schema::Format curDstFormat; - if (this->dstFormat == schema::Format::Format_NUM_OF_FORMAT) { - curDstFormat = schema::Format::Format_KHWC; - } else { - curDstFormat = this->dstFormat; - } - status = TransFilterFormat(weightTensor.get(), curDstFormat); - if (status == RET_OK) { - // node->attr.AsDepthwiseConv2D()->format = schema::Format::Format_NCHW; - weightTensor->format = curDstFormat; - } else { - MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" << EnumNameFormat(curDstFormat) - << " failed, node : " << node->name; - return ERROR; - } - } - } - return RET_OK; -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h deleted file mode 100644 index 5d582ca784..0000000000 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H -#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H - -#include "tools/converter/optimizer.h" -#include "tools/common/graph_util.h" -#include "tools/converter/converter_flags.h" - -namespace mindspore { -namespace lite { -class WeightFormatTransformPass : public GraphPass { - public: - WeightFormatTransformPass() = default; - - ~WeightFormatTransformPass() override = default; - - void SetQuantType(QuantType quantType); - - void SetFmkType(converter::FmkType fmkType); - - void SetDstFormat(schema::Format format); - - STATUS Run(MetaGraphT *graph) override; - - private: - STATUS QuantDataFormatTrans(MetaGraphT *graph); - - STATUS NonQuantDataFormatTrans(MetaGraphT *graph); - - private: - QuantType quantType = QuantType_QUANT_NONE; - converter::FmkType fmkType = converter::FmkType_TF; - schema::Format dstFormat = schema::Format::Format_NUM_OF_FORMAT; -}; -} // namespace lite -} // namespace mindspore - -#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.cc b/mindspore/lite/tools/optimizer/common/gllo_utils.cc index a2b9e44544..0c81e0c780 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.cc +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.cc @@ -18,6 +18,7 @@ #include #include #include "src/ops/primitive_c.h" +#include "src/common/common.h" #include "frontend/operator/ops.h" #include "backend/optimizer/common/helper.h" @@ -391,7 +392,17 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) { } return schema::PrimitiveType_NONE; } - +ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node) { + MS_ASSERT(node != nullptr); + if (!utils::isa<ParameterPtr>(node)) { + MS_LOG(ERROR) << "get lite param value: node must be a Parameter"; + return nullptr; + } + auto param = node->cast<ParameterPtr>(); + MS_ASSERT(param != nullptr); + auto param_value = std::dynamic_pointer_cast<ParamValueLite>(param->default_param()); + return param_value; +} bool IsParamNode(const BaseRef &n) { if (!utils::isa<ParameterPtr>(n)) { return false; } @@ -542,5 +553,551 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOu } return output_node_list; } +STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC, + int32_t *filterH, int32_t *filterW) { + MS_ASSERT(oriDims.size() == 4); + if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) { + *filterK = oriDims.at(lite::KCHW_K); + *filterC = oriDims.at(lite::KCHW_C); + *filterH = oriDims.at(lite::KCHW_H); + *filterW = oriDims.at(lite::KCHW_W); + } else if (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC) { + *filterC = oriDims.at(lite::CKHW_C); + *filterK = oriDims.at(lite::CKHW_K); + *filterH = oriDims.at(lite::CKHW_H); + *filterW = oriDims.at(lite::CKHW_W); + } else if (type == kHWCK2KCHW || type == kHWCK2CKHW) { + *filterH = oriDims.at(lite::HWCK_H); + *filterW = oriDims.at(lite::HWCK_W); + *filterC = oriDims.at(lite::HWCK_C); + *filterK = oriDims.at(lite::HWCK_K); + } else if (type == kHWKC2KCHW || type == kHWKC2CKHW) { + *filterH = oriDims.at(lite::HWKC_H); + *filterW = oriDims.at(lite::HWKC_W); + *filterK = oriDims.at(lite::HWKC_K); + *filterC = oriDims.at(lite::HWKC_C); + } else if (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW) { + *filterK = oriDims.at(lite::NHWC_N); + *filterH = oriDims.at(lite::NHWC_H); + *filterW = oriDims.at(lite::NHWC_W); + *filterC = 
oriDims.at(lite::NHWC_C); + } else if (type == kCHWK2HWCK || type == kCHWK2KHWC) { + *filterC = oriDims.at(lite::CHWK_C); + *filterH = oriDims.at(lite::CHWK_H); + *filterW = oriDims.at(lite::CHWK_W); + *filterK = oriDims.at(lite::CHWK_K); + } else if (type == kKHWC2HWCK || type == kKHWC2CHWK) { + *filterK = oriDims.at(lite::KHWC_K); + *filterH = oriDims.at(lite::KHWC_H); + *filterW = oriDims.at(lite::KHWC_W); + *filterC = oriDims.at(lite::KHWC_C); + } else { + MS_LOG(ERROR) << "Unsupported transFilterType: " << type; + return RET_ERROR; + } + return RET_OK; +} + +STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, + int32_t filterH, int32_t filterW) { + MS_ASSERT(tensor != nullptr); + if (type == kKCHW2HWCK || type == kCKHW2HWCK || type == kNHWC2HWCK || type == kKHWC2HWCK || type == kCHWK2HWCK) { + tensor->set_tensor_shape({filterH, filterW, filterC, filterK}); + } else if (type == kKCHW2HWKC || type == kCKHW2HWKC) { + tensor->set_tensor_shape({filterH, filterW, filterK, filterC}); + } else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) { + tensor->set_tensor_shape({filterK, filterC, filterH, filterW}); + } else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) { + tensor->set_tensor_shape({filterC, filterK, filterH, filterW}); + } else if (type == kKHWC2CHWK) { + tensor->set_tensor_shape({filterC, filterH, filterW, filterK}); + } else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) { + tensor->set_tensor_shape({filterK, filterH, filterW, filterC}); + } else { + MS_LOG(ERROR) << "Unsupported transFilterType: " << type; + return RET_ERROR; + } + return RET_OK; +} +template <typename T> +static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, + int32_t filterH, int32_t filterW) { + MS_ASSERT(tensor != nullptr); + int count = filterH * filterW * filterC * filterK; + if (count <= 0) { + MS_LOG(ERROR) << "Dim size invalid"; + return RET_ERROR; + } + std::unique_ptr<T[]> buf(new (std::nothrow) T[count]); + if (buf == nullptr) { + MS_LOG(ERROR) << "new buf failed"; + return RET_ERROR; + } + + void *originWeightData = tensor->tensor_addr(); + T *weightData = static_cast<T *>(originWeightData); + + if (weightData == nullptr) { + MS_LOG(ERROR) << "weightData is nullptr"; + return RET_ERROR; + } + T *p1Buff = nullptr; + T *p2Buff = nullptr; + switch (type) { + case kCHWK2HWCK: + case kCHWK2KHWC: { + for (int c = 0; c < filterC; ++c) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int k = 0; k < filterK; ++k) { + p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k)); + if (type == kCHWK2HWCK) { + p2Buff = + buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + } else if (type == kCHWK2KHWC) { + p2Buff = + buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + } + *p2Buff = *p1Buff; + } + } + } + } + } + break; + case kKHWC2HWCK: { + for (int k = 0; k < filterK; ++k) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + *p2Buff = *p1Buff; + } + } + } + } + } + break; + case 
kKCHW2HWCK: + case kKCHW2CKHW: + case kKCHW2KHWC: + case kKCHW2HWKC: { + for (int k = 0; k < filterK; ++k) { + for (int c = 0; c < filterC; ++c) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + if (type == kKCHW2HWCK) { + p2Buff = + buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + } else if (type == kKCHW2KHWC) { + p2Buff = + buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + } else if (type == kKCHW2CKHW) { + p2Buff = + buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + } else { + p2Buff = + buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); + } + *p2Buff = *p1Buff; + } + } + } + } + } + break; + case kCKHW2HWCK: + case kCKHW2KHWC: + case kCKHW2HWKC: { + for (int c = 0; c < filterC; ++c) { + for (int k = 0; k < filterK; ++k) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + if (type == kCKHW2HWCK) { + p2Buff = + buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + } else if (type == kCKHW2KHWC) { + p2Buff = + buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + } else { + p2Buff = + buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); + } + *p2Buff = *p1Buff; + } + } + } + } + } + break; + case kHWCK2KCHW: + case kHWCK2CKHW: { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + for (int k = 0; k < filterK; ++k) { + p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + if (type == kHWCK2KCHW) { + p2Buff = + buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + } else { + p2Buff = + buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + } + *p2Buff = *p1Buff; + } + } + } + } + } + break; + case kHWKC2KCHW: + case kHWKC2CKHW: { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + for (int k = 0; k < filterK; ++k) { + p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); + if (type == kHWKC2KCHW) { + p2Buff = + buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + } else { + p2Buff = + buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + } + *p2Buff = *p1Buff; + } + } + } + } + } + break; + case kNHWC2HWCK: + case kNHWC2KCHW: + case kNHWC2CKHW: { + for (int k = 0; k < filterK; ++k) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); + if (type == kNHWC2HWCK) { + p2Buff = + buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + } else if (type == kNHWC2CKHW) { + p2Buff = + buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + } else { 
+ p2Buff = + buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + } + *p2Buff = *p1Buff; + } + } + } + } + } + break; + case kKHWC2CHWK: { + for (int k = 0; k < filterK; ++k) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k)); + *p2Buff = *p1Buff; + } + } + } + } + } + break; + default: { + MS_LOG(ERROR) << "Unsupported transFilterType: " << type; + return RET_ERROR; + } + } + + auto ret = ::memcpy_s(tensor->tensor_addr(), count * sizeof(T), buf.get(), count * sizeof(T)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy_s failed: " << ret; + return RET_ERROR; + } + return RET_OK; +} + +template <typename T> +static STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type) { + MS_ASSERT(tensor != nullptr); + auto oriDims = tensor->tensor_shape(); + if (oriDims.size() != (size_t)lite::DIM_DEFAULT_SIZE) { + MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size(); + return lite::RET_ERROR; + } + + int32_t filterH; + int32_t filterW; + int32_t filterC; + int32_t filterK; + auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW); + if (status != lite::RET_OK) { + MS_LOG(ERROR) << "GetFilterDim failed: " << status; + return status; + } + status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW); + if (status != lite::RET_OK) { + MS_LOG(ERROR) << "SetFilterDim failed: " << status; + return status; + } + status = TransFilterData<T>(tensor, type, filterK, filterC, filterH, filterW); + if (status != lite::RET_OK) { + MS_LOG(ERROR) << "TransFilterData failed: " << status; + return status; + } + + return lite::RET_OK; +} + +STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format) { + if (tensor == nullptr) { + return lite::RET_NULL_PTR; + } + auto ori_dims = tensor->tensor_shape(); + if (ori_dims.size() != (size_t)lite::DIM_DEFAULT_SIZE) { + MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << ori_dims.size(); + return lite::RET_ERROR; + } + auto src_format = tensor->format(); + auto data_type = tensor->tensor_type(); + lite::STATUS status; + switch (dst_format) { + case schema::Format::Format_KHWC: { + switch (src_format) { + case schema::Format::Format_KCHW: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat<float>(tensor, kKCHW2KHWC); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat<uint8_t>(tensor, kKCHW2KHWC); + } else if (data_type == kNumberTypeInt8) { + status = TransFilterFormat<int8_t>(tensor, kKCHW2KHWC); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return RET_ERROR; + } + break; + case schema::Format::Format_CKHW: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat<float>(tensor, kCKHW2KHWC); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat<uint8_t>(tensor, kCKHW2KHWC); + } else if (data_type == kNumberTypeInt8) { + status = TransFilterFormat<int8_t>(tensor, kCKHW2KHWC); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return RET_ERROR; + } + break; + case schema::Format::Format_CHWK: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat<float>(tensor, kCHWK2KHWC); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat<uint8_t>(tensor, kCHWK2KHWC); + } 
else if (data_type == kNumberTypeInt8) { + status = TransFilterFormat<int8_t>(tensor, kCHWK2KHWC); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return RET_ERROR; + } + break; + case schema::Format::Format_KHWC: return RET_OK; + default: MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " + << EnumNameFormat(dst_format); + return RET_ERROR; + } + } + break; + case schema::Format::Format_HWCK: { + switch (src_format) { + case schema::Format::Format_KCHW: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat<float>(tensor, kKCHW2HWCK); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat<uint8_t>(tensor, kKCHW2HWCK); + } else if (data_type == kNumberTypeInt8) { + status = TransFilterFormat<int8_t>(tensor, kKCHW2HWCK); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return RET_ERROR; + } + break; + case schema::Format::Format_KHWC: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat<float>(tensor, kKHWC2HWCK); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat<uint8_t>(tensor, kKHWC2HWCK); + } else if (data_type == kNumberTypeInt8) { + status = TransFilterFormat<int8_t>(tensor, kKHWC2HWCK); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return RET_ERROR; + } + break; + case schema::Format::Format_CKHW: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat<float>(tensor, kCKHW2HWCK); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat<uint8_t>(tensor, kCKHW2HWCK); + } else if (data_type == kNumberTypeInt8) { + status = TransFilterFormat<int8_t>(tensor, kCKHW2HWCK); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return RET_ERROR; + } + break; + case schema::Format::Format_CHWK: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat<float>(tensor, kCHWK2HWCK); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat<uint8_t>(tensor, kCHWK2HWCK); + } else if (data_type == kNumberTypeInt8) { + status = TransFilterFormat<int8_t>(tensor, kCHWK2HWCK); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return lite::RET_ERROR; + } + break; + case schema::Format::Format_HWCK: return RET_OK; + default: MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " + << EnumNameFormat(dst_format); + return RET_ERROR; + } + } + break; + case schema::Format::Format_KCHW: { + switch (src_format) { + case schema::Format::Format_KCHW: return RET_OK; + case schema::Format::Format_HWCK: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat<float>(tensor, kHWCK2KCHW); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat<uint8_t>(tensor, kHWCK2KCHW); + } else if (data_type == kNumberTypeInt8) { + status = TransFilterFormat<int8_t>(tensor, kHWCK2KCHW); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return RET_ERROR; + } + break; + case schema::Format::Format_HWKC: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat<float>(tensor, kHWKC2KCHW); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat<uint8_t>(tensor, kHWKC2KCHW); + } else if (data_type == kNumberTypeInt8) { + status = TransFilterFormat<int8_t>(tensor, kHWKC2KCHW); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return RET_ERROR; + } + break; + case schema::Format::Format_KHWC: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat<float>(tensor, kKHWC2KCHW); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat<uint8_t>(tensor, kKHWC2KCHW); + } else if (data_type == 
kNumberTypeInt8) { + status = TransFilterFormat<int8_t>(tensor, kKHWC2KCHW); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return RET_ERROR; + } + break; + case schema::Format::Format_CKHW: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat<float>(tensor, kCKHW2KCHW); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat<uint8_t>(tensor, kCKHW2KCHW); + } else if (data_type == kNumberTypeInt8) { + status = TransFilterFormat<int8_t>(tensor, kCKHW2KCHW); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return RET_ERROR; + } + break; + case schema::Format::Format_CHWK: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat<float>(tensor, kCHWK2KCHW); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat<uint8_t>(tensor, kCHWK2KCHW); + } else if (data_type == kNumberTypeInt8) { + status = TransFilterFormat<int8_t>(tensor, kCHWK2KCHW); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return RET_ERROR; + } + break; + default: MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " + << EnumNameFormat(dst_format); + return RET_ERROR; + } + } + break; + case schema::Format::Format_CKHW: { + switch (src_format) { + case schema::Format::Format_HWCK: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat<float>(tensor, kHWCK2CKHW); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat<uint8_t>(tensor, kHWCK2CKHW); + } else if (data_type == kNumberTypeInt8) { + status = TransFilterFormat<int8_t>(tensor, kHWCK2CKHW); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return RET_ERROR; + } + break; + case schema::Format::Format_HWKC: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat<float>(tensor, kHWKC2CKHW); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat<uint8_t>(tensor, kHWKC2CKHW); + } else if (data_type == kNumberTypeInt8) { + status = TransFilterFormat<int8_t>(tensor, kHWKC2CKHW); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return RET_ERROR; + } + break; + case schema::Format::Format_KCHW: + if (data_type == kNumberTypeFloat32) { + status = TransFilterFormat<float>(tensor, kKCHW2CKHW); + } else if (data_type == kNumberTypeUInt8) { + status = TransFilterFormat<uint8_t>(tensor, kKCHW2CKHW); + } else if (data_type == kNumberTypeInt8) { + status = TransFilterFormat<int8_t>(tensor, kKCHW2CKHW); + } else { + MS_LOG(ERROR) << "Unsupported data_type: " << data_type; + return RET_ERROR; + } + break; + case schema::Format::Format_CKHW: return RET_OK; + default: MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " + << EnumNameFormat(dst_format); + return RET_ERROR; + } + } + break; + default: MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " + << EnumNameFormat(dst_format); + return RET_ERROR; + } + if (status != RET_OK) { + MS_LOG(ERROR) << "TransFilterData failed: " << status; + return status; + } + return RET_OK; +} } // namespace opt } // namespace mindspore diff --git a/mindspore/lite/tools/optimizer/common/gllo_utils.h b/mindspore/lite/tools/optimizer/common/gllo_utils.h index 4034a80c4f..1e6a2fbaf1 100644 --- a/mindspore/lite/tools/optimizer/common/gllo_utils.h +++ b/mindspore/lite/tools/optimizer/common/gllo_utils.h @@ -18,6 +18,7 @@ #define MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ #include +#include #include "src/ops/primitive_c.h" #include "ir/anf.h" #include "ir/func_graph.h" @@ -28,6 +29,9 @@ #include "tools/converter/return_code.h" using PrimitiveCPtr = std::shared_ptr<mindspore::lite::PrimitiveC>; +using 
mindspore::lite::RET_ERROR; +using mindspore::lite::RET_OK; +using mindspore::lite::STATUS; namespace mindspore { namespace opt { bool IsRealCNodeKernel(const AnfNodePtr &node); @@ -68,6 +72,47 @@ size_t GetOutputTensorNum(const AnfNodePtr &node); bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node); size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item); + +ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node); + +enum kTransFilterType { + kKCHW2HWCK, // 0 + kKCHW2KHWC, + kCKHW2KHWC, + kCKHW2HWCK, + kKCHW2HWKC, + kCKHW2HWKC, + kHWCK2KCHW, + kHWCK2CKHW, + kHWKC2KCHW, + kHWKC2CKHW, + kNHWC2KCHW, // 10 + kNHWC2CKHW, + kNHWC2HWCK, + kKHWC2HWCK, + kCHWK2HWCK, + kKHWC2CHWK, + kCHWK2KHWC, + kKHWC2KCHW, + kCKHW2KCHW, + kCHWK2KCHW, + kKCHW2CKHW // 20 +}; + +STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC, + int32_t *filterH, int32_t *filterW); + +STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, + int32_t filterH, int32_t filterW); + +template <typename T> +static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC, + int32_t filterH, int32_t filterW); + +template <typename T> +static lite::STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type); + +STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format); } // namespace opt } // namespace mindspore #endif // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_ diff --git a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc index ffb1273711..a116b2667b 100644 --- a/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc @@ -195,7 +195,10 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An } changed = true; auto output_nums = GetOutputTensorNum(input_cnode); - std::vector<Tensor *> output_tensors{output_nums, new Tensor()}; + std::vector<Tensor *> output_tensors; + for (size_t j = 0; j < output_nums; j++) { + output_tensors.push_back(new Tensor()); + } auto lite_primitive = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0)); if (lite_primitive == nullptr) { MS_LOG(ERROR) << "lite_primitive is nullptr"; diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc new file mode 100644 index 0000000000..30b47f4395 --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc @@ -0,0 +1,215 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "tools/optimizer/graph/weight_format_hardcode_pass.h" +#include +#include "tools/optimizer/common/gllo_utils.h" + +using mindspore::lite::converter::FmkType_CAFFE; +using mindspore::lite::converter::FmkType_TFLITE; +using mindspore::lite::converter::FmkType_ONNX; +using mindspore::lite::converter::FmkType_MS; +using mindspore::schema::QuantType_WeightQuant; +using mindspore::schema::QuantType_QUANT_NONE; +using mindspore::schema::QuantType_AwareTraining; +using mindspore::schema::QuantType_PostTraining; +namespace mindspore::opt { +namespace { +constexpr size_t kConvWeightIndex = 2; +} // namespace +void WeightFormatHardCodePass::SetQuantType(QuantType type) { + this->quant_type = type; +} +void WeightFormatHardCodePass::SetFmkType(FmkType type) { + this->fmk_type = type; +} +lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node, + const ParamValueLitePtr ¶m_value) const { + MS_ASSERT(conv_cnode != nullptr); + MS_ASSERT(param_value != nullptr); + switch (quant_type) { + case QuantType_WeightQuant: + case QuantType_QUANT_NONE:param_value->set_format(schema::Format::Format_KCHW); + break; + default: { + MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: " + << conv_node->fullname_with_scope(); + return lite::RET_ERROR; + } + } + return lite::RET_OK; +} + +lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node, + const ParamValueLitePtr ¶m_value) const { + MS_ASSERT(conv_cnode != nullptr); + MS_ASSERT(param_value != nullptr); + auto op_type = GetCNodeType(conv_node); + switch (this->quant_type) { + case QuantType_AwareTraining: { + // sum up from current onnx quant models + if (op_type == schema::PrimitiveType_Conv2D) { + param_value->set_format(schema::Format::Format_KHWC); + } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) { + param_value->set_format(schema::Format::Format_CHWK); + } else if (op_type == schema::PrimitiveType_DeConv2D) { + param_value->set_format(schema::Format::Format_KCHW); + } else { + MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " + << conv_node->fullname_with_scope(); + return lite::RET_ERROR; + } + } + break; + case QuantType_WeightQuant: + case QuantType_QUANT_NONE: { + // conv (K x C/group x kH x kW) group = 1 + // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W) + // deconv (C x K/group x kH x kW) group = 1 + // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W) + if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D + || op_type == schema::PrimitiveType_DeConv2D) { + param_value->set_format(schema::Format::Format_KCHW); + } else { + MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " + << conv_node->fullname_with_scope(); + return lite::RET_ERROR; + } + } + break; + default: { + MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: " + << conv_node->fullname_with_scope(); + return lite::RET_ERROR; + } + } + return lite::RET_OK; +} + +lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node, + const ParamValueLitePtr ¶m_value) const { + MS_ASSERT(conv_cnode != nullptr); + MS_ASSERT(param_value != nullptr); + auto op_type = GetCNodeType(conv_node); + switch (this->quant_type) { + case QuantType_AwareTraining: { + if (op_type == schema::PrimitiveType_Conv2D) { + param_value->set_format(schema::Format::Format_KCHW); + } else if 
(op_type == schema::PrimitiveType_DepthwiseConv2D) { + param_value->set_format(schema::Format::Format_CKHW); + } else { + param_value->set_format(schema::Format::Format_KCHW); + } + } + break; + case QuantType_WeightQuant: + case QuantType_QUANT_NONE: { + // sum up from current ms quant models + if (op_type == schema::PrimitiveType_Conv2D) { + param_value->set_format(schema::Format::Format_KCHW); + } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) { + param_value->set_format(schema::Format::Format_CKHW); + } else if (op_type == schema::PrimitiveType_DeConv2D) { + param_value->set_format(schema::Format::Format_KCHW); + } else { + MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " + << conv_node->fullname_with_scope(); + return lite::RET_ERROR; + } + } + break; + default: { + MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: " + << conv_node->fullname_with_scope(); + return lite::RET_ERROR; + } + } + return lite::RET_OK; +} + +lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const AnfNodePtr &conv_node, + const ParamValueLitePtr &param_value) const { + MS_ASSERT(conv_node != nullptr); + MS_ASSERT(param_value != nullptr); + auto op_type = GetCNodeType(conv_node); + switch (this->quant_type) { + case QuantType_AwareTraining: + case QuantType_PostTraining: + case QuantType_WeightQuant: + case QuantType_QUANT_NONE: { + if (op_type == schema::PrimitiveType_Conv2D) { + param_value->set_format(schema::Format::Format_KHWC); + } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) { + param_value->set_format(schema::Format::Format_CHWK); + } else if (op_type == schema::PrimitiveType_DeConv2D) { + param_value->set_format(schema::Format::Format_CHWK); + } else { + MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: " + << conv_node->fullname_with_scope(); + return lite::RET_ERROR; + } + } + break; + default: { + MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: " + << conv_node->fullname_with_scope(); + return lite::RET_ERROR; + } + } + return lite::RET_OK; +} + +bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) { + MS_ASSERT(graph != nullptr); + auto node_list = TopoSort(graph->get_return()); + for (auto &node : node_list) { + if (!utils::isa<CNodePtr>(node)) { + continue; + } + auto conv_cnode = node->cast<CNodePtr>(); + auto type = opt::GetCNodeType(node); + if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D + && type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { + continue; + } + MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex); + auto weight_node = conv_cnode->input(kConvWeightIndex); + MS_ASSERT(weight_node != nullptr); + auto param_value = GetLiteParamValue(weight_node); + if (param_value == nullptr) { + MS_LOG(ERROR) << "weight node must have a param value"; + return false; + } + lite::STATUS status; + switch (fmk_type) { + case FmkType_CAFFE: status = HardCodeCAFFE(node, param_value); + break; + case FmkType_TFLITE: status = HardCodeTFLITE(node, param_value); + break; + case FmkType_ONNX: status = HardCodeONNX(node, param_value); + break; + case FmkType_MS: status = HardCodeMS(node, param_value); + break; + default: MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope(); + return false; + } + if (status != lite::RET_OK) { + MS_LOG(ERROR) << "schema::Format hardcode failed: " << status << ", node: " << 
node->fullname_with_scope(); + return false; + } + } + return false; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h new file mode 100644 index 0000000000..f9d698dd7e --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_ +#define MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_ +#include +#include "schema/inner/model_generated.h" +#include "tools/converter/converter_flags.h" +#include "backend/optimizer/common/pass.h" +#include "src/param_value_lite.h" + +using mindspore::lite::converter::FmkType; +using mindspore::schema::QuantType; +namespace mindspore::opt { +class WeightFormatHardCodePass : public Pass { + public: + WeightFormatHardCodePass() : Pass("weight_format_hardcode_pass") {} + ~WeightFormatHardCodePass() override = default; + void SetQuantType(QuantType type); + void SetFmkType(FmkType fmkType); + bool Run(const FuncGraphPtr &graph) override; + + private: + lite::STATUS HardCodeCAFFE(const AnfNodePtr &node, const ParamValueLitePtr ¶m_value) const; + lite::STATUS HardCodeONNX(const AnfNodePtr &node, const ParamValueLitePtr ¶m_value) const; + lite::STATUS HardCodeMS(const AnfNodePtr &node, const ParamValueLitePtr ¶m_value) const; + lite::STATUS HardCodeTFLITE(const AnfNodePtr &node, const ParamValueLitePtr ¶m_value) const; + + private: + QuantType quant_type = schema::QuantType_QUANT_NONE; + FmkType fmk_type = lite::converter::FmkType_TF; +}; +} // namespace mindspore::opt +#endif // MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_ diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc new file mode 100644 index 0000000000..1f7df2ae4f --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "tools/optimizer/graph/weight_format_transform_pass.h" +#include +#include "tools/optimizer/common/gllo_utils.h" + +using mindspore::lite::converter::FmkType_CAFFE; +using mindspore::lite::converter::FmkType_TFLITE; +using mindspore::lite::converter::FmkType_ONNX; +using mindspore::lite::converter::FmkType_MS; +using mindspore::schema::QuantType_WeightQuant; +using mindspore::schema::QuantType_QUANT_NONE; +using mindspore::schema::QuantType_AwareTraining; +using mindspore::schema::QuantType_PostTraining; + +namespace mindspore::opt { +namespace { +constexpr size_t kConvWeightIndex = 2; +} // namespace +void WeightFormatTransformPass::SetQuantType(QuantType type) { + this->quant_type = type; +} +void WeightFormatTransformPass::SetFmkType(FmkType type) { + this->fmk_type = type; +} +void WeightFormatTransformPass::SetDstFormat(schema::Format format) { + this->dst_format = format; +} +lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr &graph) { + MS_ASSERT(graph != nullptr); + auto node_list = TopoSort(graph->get_return()); + for (auto &node : node_list) { + if (!utils::isa(node)) { + continue; + } + auto type = opt::GetCNodeType(node); + if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D + && type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) { + continue; + } + auto conv_cnode = node->cast(); + MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex); + auto weight_node = conv_cnode->input(kConvWeightIndex); + MS_ASSERT(weight_node != nullptr); + auto weight_value = GetLiteParamValue(weight_node); + if (weight_value == nullptr) { + MS_LOG(ERROR) << "weight node must param value"; + return false; + } + MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32 + || weight_value->tensor_type() == TypeId::kNumberTypeUInt8); + lite::STATUS status; + schema::Format weight_dst_format = schema::Format::Format_KHWC; + if (dst_format != schema::Format::Format_NUM_OF_FORMAT) { + weight_dst_format = dst_format; + } + status = TransFilterFormat(weight_value, weight_dst_format); + if (status == RET_OK) { + weight_value->set_format(weight_dst_format); + } else { + MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_value->format()]) << "To" + << EnumNameFormat(weight_dst_format) << " failed, node : " << node->fullname_with_scope() + << "quant type:" << quant_type; + return ERROR; + } + auto type_id = static_cast(weight_value->tensor_type()); + auto type_ptr = TypeIdToType(type_id); + auto abstract_tensor = std::make_shared(type_ptr, weight_value->tensor_shape()); + weight_node->set_abstract(abstract_tensor); + } + return RET_OK; +} + +bool WeightFormatTransformPass::Run(const FuncGraphPtr &func_graph) { + MS_ASSERT(func_graph != nullptr); + auto status = ConvWeightFormatTrans(func_graph); + if (status != lite::RET_OK) { + MS_LOG(ERROR) << "Conv2D weight FormatTrans failed: " << status; + return status; + } + return false; +} +} // namespace mindspore::opt diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.h b/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.h new file mode 100644 index 0000000000..78992f4ddd --- /dev/null +++ b/mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.h @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_ +#define MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_ +#include +#include "schema/inner/model_generated.h" +#include "tools/converter/converter_flags.h" +#include "backend/optimizer/common/pass.h" + +using mindspore::lite::converter::FmkType; +using mindspore::schema::QuantType; +namespace mindspore::opt { +class WeightFormatTransformPass : public Pass { + public: + WeightFormatTransformPass() : Pass("weight_format_transform_pass") {} + ~WeightFormatTransformPass() override = default; + void SetQuantType(QuantType type); + void SetFmkType(FmkType fmkType); + void SetDstFormat(schema::Format format); + bool Run(const FuncGraphPtr &graph) override; + + private: + lite::STATUS ConvWeightFormatTrans(const FuncGraphPtr &graph); + + private: + QuantType quant_type = schema::QuantType_QUANT_NONE; + FmkType fmk_type = lite::converter::FmkType_TF; + schema::Format dst_format = schema::Format::Format_NUM_OF_FORMAT; +}; +} // namespace mindspore::opt +#endif // MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_
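Reviewer note: the heart of this change is the filter-layout reindexing that TransFilterData performs. As a minimal standalone sketch (illustrative names and plain float buffers, not the converter's ParamValueLite API), the kKCHW2KHWC case reduces to:

#include <cstdio>
#include <vector>

// Reorder a dense KCHW filter buffer into KHWC, the default destination
// layout the new pass selects (Format_KHWC). The index arithmetic matches
// the p1Buff/p2Buff offsets in the kKCHW2KHWC branch of TransFilterData.
std::vector<float> KCHW2KHWC(const std::vector<float> &src, int K, int C, int H, int W) {
  std::vector<float> dst(src.size());
  for (int k = 0; k < K; ++k) {
    for (int c = 0; c < C; ++c) {
      for (int h = 0; h < H; ++h) {
        for (int w = 0; w < W; ++w) {
          // source offset in KCHW, destination offset in KHWC
          dst[((k * H + h) * W + w) * C + c] = src[((k * C + c) * H + h) * W + w];
        }
      }
    }
  }
  return dst;
}

int main() {
  // one 2x2 filter with two input channels: KCHW = {c0: 1 2 3 4, c1: 5 6 7 8}
  std::vector<float> kchw = {1, 2, 3, 4, 5, 6, 7, 8};
  for (float v : KCHW2KHWC(kchw, 1, 2, 2, 2)) printf("%g ", v);  // prints: 1 5 2 6 3 7 4 8
  return 0;
}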
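Moving the two passes from the legacy MetaGraphT optimizer into the anf optimizer also changes how they are wired up. A hypothetical helper summarizing the registration order from the anf_transform.cc hunk above (assuming that file's opt::GraphOptimizer / opt::PassManager context; RegisterWeightFormatPasses itself is not part of this patch):

// Hardcode must run first: it stamps the framework-specific source layout
// onto each conv weight, which the transform pass then converts to KHWC.
void RegisterWeightFormatPasses(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
                                const lite::converter::Flags *config) {
  auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
  auto hardcode = std::make_shared<opt::WeightFormatHardCodePass>();
  hardcode->SetFmkType(config->fmk);          // source framework fixes the assumed layout
  hardcode->SetQuantType(config->quantType);  // quant type can override that choice
  graph_pm->AddPass(hardcode);
  auto transform = std::make_shared<opt::WeightFormatTransformPass>();
  transform->SetFmkType(config->fmk);
  transform->SetQuantType(config->quantType);
  // transform->SetDstFormat(...) is optional; Format_NUM_OF_FORMAT means "use the default KHWC"
  graph_pm->AddPass(transform);
  optimizer->AddPassManager(graph_pm);  // runs after the fusion pass manager registered earlier
}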