
!6349 [MSLITE]move weight format pass to anf

Merge pull request !6349 from zhengjun10/master
tags/v1.1.0
mindspore-ci-bot committed 5 years ago
parent commit 2d7b10da6b
17 changed files with 1038 additions and 273 deletions

  1. +8    -1    mindspore/lite/src/ops/conv2d.cc
  2. +2    -0    mindspore/lite/test/CMakeLists.txt
  3. +2    -0    mindspore/lite/tools/converter/CMakeLists.txt
  4. +13   -1    mindspore/lite/tools/converter/anf_transform.cc
  5. +2    -19   mindspore/lite/tools/converter/graphdef_transform.cc
  6. +1    -1    mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc
  7. +0    -2    mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt
  8. +0    -52   mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h
  9. +0    -142  mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc
  10. +0   -53   mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h
  11. +558 -1    mindspore/lite/tools/optimizer/common/gllo_utils.cc
  12. +45  -0    mindspore/lite/tools/optimizer/common/gllo_utils.h
  13. +4   -1    mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc
  14. +215 -0    mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc
  15. +47  -0    mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h
  16. +96  -0    mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc
  17. +45  -0    mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.h
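What "move to anf" means in practice: the weight-format logic leaves the legacy GraphPass objects that mutate the serialized MetaGraphT inside GraphDefTransform, and becomes opt::Pass objects that run on the ANF FuncGraph inside AnfTransform, with the filter-transpose helpers landing in gllo_utils. A minimal sketch of the two interfaces involved, using simplified stand-in types rather than the real MindSpore headers:

// Sketch only: MetaGraphT, FuncGraph and STATUS are stand-ins for the real types.
#include <memory>
#include <string>
#include <utility>

struct MetaGraphT {};                         // flatbuffer-backed graph (legacy path)
struct FuncGraph {};                          // ANF graph (new path)
using FuncGraphPtr = std::shared_ptr<FuncGraph>;
using STATUS = int;

// Legacy interface: a pass edits the serialized MetaGraphT and returns a status code.
class GraphPass {
 public:
  virtual ~GraphPass() = default;
  virtual STATUS Run(MetaGraphT *graph) = 0;
};

// ANF interface (backend/optimizer style): a pass edits the FuncGraph and
// reports whether it changed anything, so a manager can iterate to a fixpoint.
class Pass {
 public:
  explicit Pass(std::string name) : name_(std::move(name)) {}
  virtual ~Pass() = default;
  virtual bool Run(const FuncGraphPtr &graph) = 0;

 private:
  std::string name_;
};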

+8 -1   mindspore/lite/src/ops/conv2d.cc

@@ -120,6 +120,8 @@ void ConvertConvWeight(const ParameterPtr &param_node) {
utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[1] = filter_k;
utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[2] = filter_h;
utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[3] = filter_w;
weight->set_tensor_shape({static_cast<int>(filter_c), static_cast<int>(filter_k), static_cast<int>(filter_h),
static_cast<int>(filter_w)});
}
return;
}
@@ -250,7 +252,12 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
MS_LOG(ERROR) << "primitive_ type is error:" << this->primitive_->value.type;
return RET_ERROR;
}
int group = GetValue<int>(prim.GetAttr("group"));
auto groupAttr = prim.GetAttr("group");
if (groupAttr == nullptr) {
MS_LOG(ERROR) << "conv2d op has no group attr,please check pb model";
return RET_NULL_PTR;
}
int group = GetValue<int>(groupAttr);
if (group > 1) {
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
} else {
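
The second hunk is the functional fix in this file: prim.GetAttr("group") returns a null ValuePtr when the attribute is absent, and the old code fed that straight into GetValue<int>, which dereferences it. A standalone sketch of the same guard pattern; the Attrs map and GetInt helper here are illustrative, not the real Primitive API:

// Sketch only: Attrs and GetInt are illustrative, not the real Primitive API.
#include <iostream>
#include <map>
#include <memory>
#include <string>

using Attrs = std::map<std::string, std::shared_ptr<int>>;

// Fail loudly (instead of dereferencing null) when an attribute is missing.
bool GetInt(const Attrs &attrs, const std::string &key, int *out) {
  auto it = attrs.find(key);
  if (it == attrs.end() || it->second == nullptr) {
    std::cerr << "op has no '" << key << "' attr, please check the model\n";
    return false;
  }
  *out = *it->second;
  return true;
}

int main() {
  Attrs attrs{{"group", std::make_shared<int>(2)}};
  int group = 0;
  if (GetInt(attrs, "group", &group) && group > 1) {
    std::cout << "grouped conv, group = " << group << "\n";
  }
}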


+2 -0   mindspore/lite/test/CMakeLists.txt

@@ -205,6 +205,8 @@ if(BUILD_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_transform_pass.cc
${LITE_DIR}/tools/optimizer/graph/weight_format_hardcode_pass.cc
)
endif()
### train


+2 -0   mindspore/lite/tools/converter/CMakeLists.txt

@@ -63,6 +63,8 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/conv_bn_fusion.cc
../optimizer/fusion/constant_folding_fusion.cc
../optimizer/fusion/quant_dtype_cast_fusion.cc
../optimizer/graph/weight_format_transform_pass.cc
../optimizer/graph/weight_format_hardcode_pass.cc
)

add_subdirectory(../anf_importer anf_importer)


+13 -1   mindspore/lite/tools/converter/anf_transform.cc

@@ -25,6 +25,8 @@
#include "tools/optimizer/fusion/conv_bn_fusion.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/optimizer/graph/weight_format_transform_pass.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
@@ -41,6 +43,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
// fusion const_fold
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);

// for now - training does not support fused operations
if (config != nullptr && config->trainModel == false) {
@@ -61,11 +64,20 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(true, "conv_tuple_relu6",
schema::PrimitiveType_Activation,
schema::ActivationType_RELU6));
auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
weight_format_hardcode_pass->SetFmkType(config->fmk);
weight_format_hardcode_pass->SetQuantType(config->quantType);
graph_pm->AddPass(weight_format_hardcode_pass);
auto weight_format_transform_pass = std::make_shared<opt::WeightFormatTransformPass>();
weight_format_transform_pass->SetFmkType(config->fmk);
weight_format_transform_pass->SetQuantType(config->quantType);
graph_pm->AddPass(weight_format_transform_pass);
}

pm->AddPass(std::make_shared<opt::ConstFoldPass>());
optimizer->AddPassManager(pm);
FuncGraphPtr new_graph = optimizer->Optimize(old_graph);
optimizer->AddPassManager(graph_pm);
auto new_graph = optimizer->Optimize(old_graph);
if (new_graph == nullptr) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
return nullptr;
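
Note the two PassManager constructions: the fusion manager is created with its second argument false and the new graph manager with true. Assuming that argument is the backend optimizer's run_only_once flag, the fusion passes iterate until a sweep makes no change, while the two weight-format passes run in a single ordered sweep (hardcode first, then transform). A simplified stand-in for that contract, not the real backend/optimizer implementation:

// Sketch only, assuming the second constructor argument is a run_only_once flag.
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct FuncGraph {};
using FuncGraphPtr = std::shared_ptr<FuncGraph>;

class Pass {
 public:
  virtual ~Pass() = default;
  virtual bool Run(const FuncGraphPtr &graph) = 0;  // true if the graph changed
};

class PassManager {
 public:
  PassManager(std::string name, bool run_only_once)
      : name_(std::move(name)), run_only_once_(run_only_once) {}
  void AddPass(const std::shared_ptr<Pass> &pass) { passes_.push_back(pass); }

  // With run_only_once == false (the fusion manager), sweep until nothing
  // changes; with true (the new graph manager), make exactly one ordered sweep.
  bool Run(const FuncGraphPtr &graph) const {
    bool changed = false;
    bool again = true;
    while (again) {
      again = false;
      for (const auto &pass : passes_) {
        if (pass->Run(graph)) changed = again = true;
      }
      if (run_only_once_) break;
    }
    return changed;
  }

 private:
  std::string name_;
  bool run_only_once_;
  std::vector<std::shared_ptr<Pass>> passes_;
};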


+2 -19   mindspore/lite/tools/converter/graphdef_transform.cc

@@ -28,8 +28,6 @@
#include "tools/converter/legacy_optimizer/graph/trans_format_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/infershape_pass.h"
#include "tools/converter/legacy_optimizer/graph/batchnorm_convert_scale_pass.h"
#include "tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h"
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
#include "tools/converter/legacy_optimizer/graph/trans_format_insert_pass.h"
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
@@ -62,23 +60,6 @@ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) {

int GraphDefTransform::Transform(const converter::Flags &ctx) {
STATUS status;
{
Optimizer weightFormatOptimizer;
auto weightHardCodePass = new WeightFormatHardCodePass();
auto weightFormatPass = new WeightFormatTransformPass();
weightHardCodePass->SetQuantType(ctx.quantType);
weightHardCodePass->SetFmkType(ctx.fmk);
weightFormatPass->SetQuantType(ctx.quantType);
weightFormatPass->SetFmkType(ctx.fmk);
weightFormatOptimizer.AddPass(weightHardCodePass);
weightFormatOptimizer.AddPass(weightFormatPass);
status = weightFormatOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run weightFormatOptimizer graphPasses Failed";
return status;
}
}

{
Optimizer unusedOpRemoveOptimizer;
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
@@ -149,6 +130,8 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
formatTransOptimizer.AddPass(formatTransPass);
formatTransOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
formatTransOptimizer.AddPass(new (std::nothrow) InferShapePass());
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());
formatTransOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) TransOpRemovePass());
formatTransOptimizer.AddPass(new (std::nothrow) TransOpInsertPass());
formatTransOptimizer.AddPass(new (std::nothrow) FormatTransFusionPass());


+1 -1   mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc

@@ -79,7 +79,7 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN
MS_ASSERT(graph->allTensors.size() > mulNodeInputIndex.at(MUL_OP_BIAS_INDEX));
const auto &mulNodeBiasTensor = graph->allTensors.at(mulNodeInputIndex.at(MUL_OP_BIAS_INDEX));
MS_ASSERT(mulNodeBiasTensor != nullptr);
if (mulNodeBiasTensor->refCount != schema::NodeType::NodeType_ValueNode) {
if (mulNodeBiasTensor->refCount != schema::NodeType::NodeType_ValueNode || mulNodeBiasTensor->dims.size() == 4) {
// don't fuse, just return
return RET_OK;
}


+0 -2   mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt

@@ -4,8 +4,6 @@ add_library(graph_pass_mid OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/weight_format_hardcode_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/weight_format_transform_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/unused_node_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_convert_scale_pass.cc


+0 -52   mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h

@@ -1,52 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H

#include <memory>
#include "tools/converter/converter_flags.h"
#include "tools/converter/optimizer.h"
#include "tools/common/graph_util.h"

namespace mindspore {
namespace lite {
class WeightFormatHardCodePass : public GraphPass {
public:
WeightFormatHardCodePass() = default;

~WeightFormatHardCodePass() override = default;

void SetQuantType(QuantType quantType);

void SetFmkType(converter::FmkType fmkType);

STATUS Run(MetaGraphT *graph) override;

private:
STATUS HardCodeCAFFE(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor);
STATUS HardCodeTFLITE(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor);
STATUS HardCodeONNX(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor);
STATUS HardCodeMS(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor);

private:
QuantType quantType = QuantType_QUANT_NONE;
converter::FmkType fmkType = converter::FmkType_TF;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H

+0 -142   mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc

@@ -1,142 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h"
#include <queue>
#include "tools/common/node_util.h"
#include "tools/common/converter_op_utils.h"
#include "utils/log_adapter.h"
#include "src/common/utils.h"

namespace mindspore {
namespace lite {
void WeightFormatTransformPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }

void WeightFormatTransformPass::SetFmkType(converter::FmkType fmkType) { this->fmkType = fmkType; }

void WeightFormatTransformPass::SetDstFormat(schema::Format format) { this->dstFormat = format; }

STATUS WeightFormatTransformPass::Run(MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
if (this->quantType == QuantType_AwareTraining || this->quantType == QuantType_WeightQuant) {
auto status = QuantDataFormatTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status;
return status;
}
} else {
auto status = NonQuantDataFormatTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "NonQuantDataFormatTrans failed: " << status;
return status;
}
}
return RET_OK;
}

STATUS WeightFormatTransformPass::QuantDataFormatTrans(MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
for (auto &node : graph->nodes) {
MS_ASSERT(node != nullptr);
MS_ASSERT(node->primitive != nullptr);
auto opType = node->primitive->value.type;
if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D &&
opType != PrimitiveType_DeConv2D && opType != PrimitiveType_DeDepthwiseConv2D) {
continue;
}
MS_ASSERT(node->inputIndex.size() >= 2);
auto weightIndex = node->inputIndex.at(1);
MS_ASSERT(subGraph->allTensors.size() > weightIndex);
auto &weightTensor = graph->allTensors[weightIndex];
MS_ASSERT(weightTensor->dataType == DataType_DT_UINT8 || weightTensor->dataType == DataType_DT_FLOAT ||
weightTensor->dataType == DataType_DT_INT8);
STATUS status;
if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D ||
opType == PrimitiveType_DeConv2D || opType == PrimitiveType_DeDepthwiseConv2D) { // weight should be HWCK
Format curDstFormat;
if (this->dstFormat == Format_NUM_OF_FORMAT) {
curDstFormat = Format_KHWC;
} else {
curDstFormat = this->dstFormat;
}
status = TransFilterFormat(weightTensor.get(), curDstFormat);
if (status == RET_OK) {
weightTensor->format = curDstFormat;
} else {
MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" << EnumNameFormat(curDstFormat)
<< " failed, node : " << node->name;
return ERROR;
}
}
}
return RET_OK;
}

STATUS WeightFormatTransformPass::NonQuantDataFormatTrans(MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
for (auto &node : graph->nodes) {
MS_ASSERT(node != nullptr);
MS_ASSERT(node->primitive != nullptr);
auto opType = node->primitive->value.type;
if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D && opType != PrimitiveType_DeConv2D &&
opType != PrimitiveType_DeDepthwiseConv2D) {
continue;
}
MS_ASSERT(node->inputIndex.size() >= 2);
auto weightIndex = node->inputIndex.at(1);
MS_ASSERT(subGraph->allTensors.size() > weightIndex);
auto &weightTensor = graph->allTensors[weightIndex];
MS_ASSERT(weightTensor->dataType == DataType_DT_UINT8 || weightTensor->dataType == DataType_DT_FLOAT);
STATUS status;
if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D ||
opType == schema::PrimitiveType_DeConv2D) {
schema::Format curDstFormat;
if (this->dstFormat == schema::Format::Format_NUM_OF_FORMAT) {
curDstFormat = schema::Format::Format_KHWC;
} else {
curDstFormat = this->dstFormat;
}
status = TransFilterFormat(weightTensor.get(), curDstFormat);
if (status == RET_OK) {
// node->attr.AsConv2D()->format = schema::Format::Format_NCHW;
weightTensor->format = curDstFormat;
} else {
MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" << EnumNameFormat(curDstFormat)
<< " failed, node : " << node->name;
return ERROR;
}
} else { // weight should be CKHW
schema::Format curDstFormat;
if (this->dstFormat == schema::Format::Format_NUM_OF_FORMAT) {
curDstFormat = schema::Format::Format_KHWC;
} else {
curDstFormat = this->dstFormat;
}
status = TransFilterFormat(weightTensor.get(), curDstFormat);
if (status == RET_OK) {
// node->attr.AsDepthwiseConv2D()->format = schema::Format::Format_NCHW;
weightTensor->format = curDstFormat;
} else {
MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" << EnumNameFormat(curDstFormat)
<< " failed, node : " << node->name;
return ERROR;
}
}
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore

+0 -53   mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h

@@ -1,53 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H

#include "tools/converter/optimizer.h"
#include "tools/common/graph_util.h"
#include "tools/converter/converter_flags.h"

namespace mindspore {
namespace lite {
class WeightFormatTransformPass : public GraphPass {
public:
WeightFormatTransformPass() = default;

~WeightFormatTransformPass() override = default;

void SetQuantType(QuantType quantType);

void SetFmkType(converter::FmkType fmkType);

void SetDstFormat(schema::Format format);

STATUS Run(MetaGraphT *graph) override;

private:
STATUS QuantDataFormatTrans(MetaGraphT *graph);

STATUS NonQuantDataFormatTrans(MetaGraphT *graph);

private:
QuantType quantType = QuantType_QUANT_NONE;
converter::FmkType fmkType = converter::FmkType_TF;
schema::Format dstFormat = schema::Format::Format_NUM_OF_FORMAT;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H

+558 -1   mindspore/lite/tools/optimizer/common/gllo_utils.cc

@@ -18,6 +18,7 @@
#include <algorithm>
#include <utility>
#include "src/ops/primitive_c.h"
#include "src/common/common.h"
#include "frontend/operator/ops.h"
#include "backend/optimizer/common/helper.h"

@@ -391,7 +392,17 @@ schema::PrimitiveType GetCNodeType(const BaseRef &n) {
}
return schema::PrimitiveType_NONE;
}

ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node) {
MS_ASSERT(node != nullptr);
if (!utils::isa<ParameterPtr>(node)) {
MS_LOG(ERROR) << "get lite param value node must paramter";
return nullptr;
}
auto param = node->cast<ParameterPtr>();
MS_ASSERT(param != nullptr);
auto param_value = std::dynamic_pointer_cast<ParamValueLite>(param->default_param());
return param_value;
}
bool IsParamNode(const BaseRef &n) {
if (!utils::isa<ParameterPtr>(n)) {
return false;
@@ -542,5 +553,551 @@ std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>> GetRealNodeUsedListByOu
}
return output_node_list;
}
STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
int32_t *filterH, int32_t *filterW) {
MS_ASSERT(oriDims.size() == 4);
if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) {
*filterK = oriDims.at(lite::KCHW_K);
*filterC = oriDims.at(lite::KCHW_C);
*filterH = oriDims.at(lite::KCHW_H);
*filterW = oriDims.at(lite::KCHW_W);
} else if (type == kCKHW2HWCK || type == kCKHW2HWKC || type == kCKHW2KHWC) {
*filterC = oriDims.at(lite::CKHW_C);
*filterK = oriDims.at(lite::CKHW_K);
*filterH = oriDims.at(lite::CKHW_H);
*filterW = oriDims.at(lite::CKHW_W);
} else if (type == kHWCK2KCHW || type == kHWCK2CKHW) {
*filterH = oriDims.at(lite::HWCK_H);
*filterW = oriDims.at(lite::HWCK_W);
*filterC = oriDims.at(lite::HWCK_C);
*filterK = oriDims.at(lite::HWCK_K);
} else if (type == kHWKC2KCHW || type == kHWKC2CKHW) {
*filterH = oriDims.at(lite::HWKC_H);
*filterW = oriDims.at(lite::HWKC_W);
*filterK = oriDims.at(lite::HWKC_K);
*filterC = oriDims.at(lite::HWKC_C);
} else if (type == kNHWC2KCHW || type == kNHWC2HWCK || type == kNHWC2CKHW) {
*filterK = oriDims.at(lite::NHWC_N);
*filterH = oriDims.at(lite::NHWC_H);
*filterW = oriDims.at(lite::NHWC_W);
*filterC = oriDims.at(lite::NHWC_C);
} else if (type == kCHWK2HWCK || type == kCHWK2KHWC) {
*filterC = oriDims.at(lite::CHWK_C);
*filterH = oriDims.at(lite::CHWK_H);
*filterW = oriDims.at(lite::CHWK_W);
*filterK = oriDims.at(lite::CHWK_K);
} else if (type == kKHWC2HWCK || type == kKHWC2CHWK) {
*filterK = oriDims.at(lite::KHWC_K);
*filterH = oriDims.at(lite::KHWC_H);
*filterW = oriDims.at(lite::KHWC_W);
*filterC = oriDims.at(lite::KHWC_C);
} else {
MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
return RET_ERROR;
}
return RET_OK;
}

STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
int32_t filterH, int32_t filterW) {
MS_ASSERT(tensor != nullptr);
if (type == kKCHW2HWCK || type == kCKHW2HWCK || type == kNHWC2HWCK || type == kKHWC2HWCK || type == kCHWK2HWCK) {
tensor->set_tensor_shape({filterH, filterW, filterC, filterK});
} else if (type == kKCHW2HWKC || type == kCKHW2HWKC) {
tensor->set_tensor_shape({filterH, filterW, filterK, filterC});
} else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) {
tensor->set_tensor_shape({filterK, filterC, filterH, filterW});
} else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) {
tensor->set_tensor_shape({filterC, filterK, filterH, filterW});
} else if (type == kKHWC2CHWK) {
tensor->set_tensor_shape({filterC, filterH, filterW, filterK});
} else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) {
tensor->set_tensor_shape({filterK, filterH, filterW, filterC});
} else {
MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
return RET_ERROR;
}
return RET_OK;
}
template<typename T>
static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
int32_t filterH, int32_t filterW) {
MS_ASSERT(tensor != nullptr);
int count = filterH * filterW * filterC * filterK;
if (count <= 0) {
MS_LOG(ERROR) << "Dim size invalid";
return RET_ERROR;
}
std::unique_ptr<T[]> buf(new(std::nothrow) T[count]);
if (buf == nullptr) {
MS_LOG(ERROR) << "new buf failed";
return RET_ERROR;
}

void *originWeightData = tensor->tensor_addr();
T *weightData = static_cast<T *>(originWeightData);

if (weightData == nullptr) {
MS_LOG(ERROR) << "weightData is nullptr";
return RET_ERROR;
}
T *p1Buff = nullptr;
T *p2Buff = nullptr;
switch (type) {
case kCHWK2HWCK:
case kCHWK2KHWC: {
for (int c = 0; c < filterC; ++c) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int k = 0; k < filterK; ++k) {
p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k));
if (type == kCHWK2HWCK) {
p2Buff =
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kCHWK2KHWC) {
p2Buff =
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
}
*p2Buff = *p1Buff;
}
}
}
}
}
break;
case kKHWC2HWCK: {
for (int k = 0; k < filterK; ++k) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int c = 0; c < filterC; ++c) {
p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
*p2Buff = *p1Buff;
}
}
}
}
}
break;
case kKCHW2HWCK:
case kKCHW2CKHW:
case kKCHW2KHWC:
case kKCHW2HWKC: {
for (int k = 0; k < filterK; ++k) {
for (int c = 0; c < filterC; ++c) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
if (type == kKCHW2HWCK) {
p2Buff =
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kKCHW2KHWC) {
p2Buff =
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
} else if (type == kKCHW2CKHW) {
p2Buff =
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff =
buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
}
*p2Buff = *p1Buff;
}
}
}
}
}
break;
case kCKHW2HWCK:
case kCKHW2KHWC:
case kCKHW2HWKC: {
for (int c = 0; c < filterC; ++c) {
for (int k = 0; k < filterK; ++k) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
if (type == kCKHW2HWCK) {
p2Buff =
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kCKHW2KHWC) {
p2Buff =
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
} else {
p2Buff =
buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
}
*p2Buff = *p1Buff;
}
}
}
}
}
break;
case kHWCK2KCHW:
case kHWCK2CKHW: {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int c = 0; c < filterC; ++c) {
for (int k = 0; k < filterK; ++k) {
p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
if (type == kHWCK2KCHW) {
p2Buff =
buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff =
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
}
*p2Buff = *p1Buff;
}
}
}
}
}
break;
case kHWKC2KCHW:
case kHWKC2CKHW: {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int c = 0; c < filterC; ++c) {
for (int k = 0; k < filterK; ++k) {
p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
if (type == kHWKC2KCHW) {
p2Buff =
buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff =
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
}
*p2Buff = *p1Buff;
}
}
}
}
}
break;
case kNHWC2HWCK:
case kNHWC2KCHW:
case kNHWC2CKHW: {
for (int k = 0; k < filterK; ++k) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int c = 0; c < filterC; ++c) {
p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c));
if (type == kNHWC2HWCK) {
p2Buff =
buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k));
} else if (type == kNHWC2CKHW) {
p2Buff =
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff =
buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
}
*p2Buff = *p1Buff;
}
}
}
}
}
break;
case kKHWC2CHWK: {
for (int k = 0; k < filterK; ++k) {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int c = 0; c < filterC; ++c) {
p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k));
*p2Buff = *p1Buff;
}
}
}
}
}
break;
default: {
MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
return RET_ERROR;
}
}

auto ret = ::memcpy_s(tensor->tensor_addr(), count * sizeof(T), buf.get(), count * sizeof(T));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s failed: " << ret;
return RET_ERROR;
}
return RET_OK;
}

template<typename T>
static STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type) {
MS_ASSERT(tensor != nullptr);
auto oriDims = tensor->tensor_shape();
if (oriDims.size() != (size_t)lite::DIM_DEFAULT_SIZE) {
MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
return lite::RET_ERROR;
}

int32_t filterH;
int32_t filterW;
int32_t filterC;
int32_t filterK;
auto status = GetFilterDim(oriDims, type, &filterK, &filterC, &filterH, &filterW);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "GetFilterDim failed: " << status;
return status;
}
status = SetFilterDim(tensor, type, filterK, filterC, filterH, filterW);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "SetFilterDim failed: " << status;
return status;
}
status = TransFilterData<T>(tensor, type, filterK, filterC, filterH, filterW);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "TransFilterData failed: " << status;
return status;
}

return lite::RET_OK;
}

STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format) {
if (tensor == nullptr) {
return lite::RET_NULL_PTR;
}
auto ori_dims = tensor->tensor_shape();
if (ori_dims.size() != (size_t)lite::DIM_DEFAULT_SIZE) {
MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << ori_dims.size();
return lite::RET_ERROR;
}
auto src_format = tensor->format();
auto data_type = tensor->tensor_type();
lite::STATUS status;
switch (dst_format) {
case schema::Format::Format_KHWC: {
switch (src_format) {
case schema::Format::Format_KCHW:
if (data_type == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kKCHW2KHWC);
} else if (data_type == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kKCHW2KHWC);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kKCHW2KHWC);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
case schema::Format::Format_CKHW:
if (data_type == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kCKHW2KHWC);
} else if (data_type == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kCKHW2KHWC);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCKHW2KHWC);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
case schema::Format::Format_CHWK:
if (data_type == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kCHWK2KHWC);
} else if (data_type == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kCHWK2KHWC);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCHWK2KHWC);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
case schema::Format::Format_KHWC:
return RET_OK;
default:
MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
return RET_ERROR;
}
}
break;
case schema::Format::Format_HWCK: {
switch (src_format) {
case schema::Format::Format_KCHW:
if (data_type == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kKCHW2HWCK);
} else if (data_type == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kKCHW2HWCK);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kKCHW2HWCK);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
case schema::Format::Format_KHWC:
if (data_type == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kKHWC2HWCK);
} else if (data_type == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kKHWC2HWCK);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kKHWC2HWCK);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
case schema::Format::Format_CKHW:
if (data_type == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kCKHW2HWCK);
} else if (data_type == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kCKHW2HWCK);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCKHW2HWCK);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
case schema::Format::Format_CHWK:
if (data_type == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kCHWK2HWCK);
} else if (data_type == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kCHWK2HWCK);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCHWK2HWCK);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return lite::RET_ERROR;
}
break;
case schema::Format::Format_HWCK:
return RET_OK;
default:
MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
return RET_ERROR;
}
}
break;
case schema::Format::Format_KCHW: {
switch (src_format) {
case schema::Format::Format_KCHW:
return RET_OK;
case schema::Format::Format_HWCK:
if (data_type == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kHWCK2KCHW);
} else if (data_type == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kHWCK2KCHW);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kHWCK2KCHW);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
case schema::Format::Format_HWKC:
if (data_type == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kHWKC2KCHW);
} else if (data_type == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kHWKC2KCHW);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kHWKC2KCHW);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
case schema::Format::Format_KHWC:
if (data_type == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kKHWC2KCHW);
} else if (data_type == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kKHWC2KCHW);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kKHWC2KCHW);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
case schema::Format::Format_CKHW:
if (data_type == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kCKHW2KCHW);
} else if (data_type == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kCKHW2KCHW);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCKHW2KCHW);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
case schema::Format::Format_CHWK:
if (data_type == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kCHWK2KCHW);
} else if (data_type == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kCHWK2KCHW);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kCHWK2KCHW);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
default:
MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
return RET_ERROR;
}
}
break;
case schema::Format::Format_CKHW: {
switch (src_format) {
case schema::Format::Format_HWCK:
if (data_type == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kHWCK2CKHW);
} else if (data_type == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kHWCK2CKHW);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kHWCK2CKHW);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
case schema::Format::Format_HWKC:
if (data_type == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kHWKC2CKHW);
} else if (data_type == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kHWKC2CKHW);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kHWKC2CKHW);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
case schema::Format::Format_KCHW:
if (data_type == kNumberTypeFloat32) {
status = TransFilterFormat<float>(tensor, kKCHW2CKHW);
} else if (data_type == kNumberTypeUInt8) {
status = TransFilterFormat<uint8_t>(tensor, kKCHW2CKHW);
} else if (data_type == kNumberTypeInt8) {
status = TransFilterFormat<int8_t>(tensor, kKCHW2CKHW);
} else {
MS_LOG(ERROR) << "Unsupported data_type: " << data_type;
return RET_ERROR;
}
break;
case schema::Format::Format_CKHW:
return RET_OK;
default:
MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
return RET_ERROR;
}
}
break;
default:
MS_LOG(ERROR) << "Unsupported transform from " << src_format << " to " << EnumNameFormat(dst_format);
return RET_ERROR;
}
if (status != RET_OK) {
MS_LOG(ERROR) << "TransFilterData failed: " << status;
return status;
}
return RET_OK;
}
} // namespace opt
} // namespace mindspore
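
Every branch of TransFilterData follows the same pattern: compute the source flat offset from the source layout's strides and the destination offset from the destination layout's strides, then copy element by element. A self-contained illustration of the kKCHW2KHWC case on a tiny 1 x 2 x 2 x 2 filter, using the same offset arithmetic as the loop above:

// Self-contained: same offset arithmetic as the kKCHW2KHWC branch above.
#include <cstdio>
#include <vector>

// KCHW -> KHWC for a K x C x H x W filter.
std::vector<float> KCHW2KHWC(const std::vector<float> &src, int K, int C, int H, int W) {
  std::vector<float> dst(src.size());
  for (int k = 0; k < K; ++k)
    for (int c = 0; c < C; ++c)
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
          // source offset uses KCHW strides, destination offset uses KHWC strides
          dst[((k * H + h) * W + w) * C + c] = src[((k * C + c) * H + h) * W + w];
  return dst;
}

int main() {
  std::vector<float> src{0, 1, 2, 3, 4, 5, 6, 7};  // 1 x 2 x 2 x 2, KCHW order
  for (float v : KCHW2KHWC(src, 1, 2, 2, 2)) std::printf("%g ", v);
  std::printf("\n");  // prints: 0 4 1 5 2 6 3 7
}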

+45 -0   mindspore/lite/tools/optimizer/common/gllo_utils.h

@@ -18,6 +18,7 @@
#define MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_

#include <memory>
#include <vector>
#include "src/ops//primitive_c.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
@@ -28,6 +29,9 @@
#include "tools/converter/return_code.h"

using PrimitiveCPtr = std::shared_ptr<mindspore::lite::PrimitiveC>;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::lite::STATUS;
namespace mindspore {
namespace opt {
bool IsRealCNodeKernel(const AnfNodePtr &node);
@@ -68,6 +72,47 @@ size_t GetOutputTensorNum(const AnfNodePtr &node);
bool IsMultiOutputTensors(const FuncGraphPtr &graph, const AnfNodePtr &node);

size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);

ParamValueLitePtr GetLiteParamValue(const AnfNodePtr &node);

enum kTransFilterType {
kKCHW2HWCK, // 0
kKCHW2KHWC,
kCKHW2KHWC,
kCKHW2HWCK,
kKCHW2HWKC,
kCKHW2HWKC,
kHWCK2KCHW,
kHWCK2CKHW,
kHWKC2KCHW,
kHWKC2CKHW,
kNHWC2KCHW, // 10
kNHWC2CKHW,
kNHWC2HWCK,
kKHWC2HWCK,
kCHWK2HWCK,
kKHWC2CHWK,
kCHWK2KHWC,
kKHWC2KCHW,
kCKHW2KCHW,
kCHWK2KCHW,
kKCHW2CKHW // 20
};

STATUS GetFilterDim(const std::vector<int32_t> &oriDims, kTransFilterType type, int32_t *filterK, int32_t *filterC,
int32_t *filterH, int32_t *filterW);

STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
int32_t filterH, int32_t filterW);

template<typename T>
static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType type, int32_t filterK, int32_t filterC,
int32_t filterH, int32_t filterW);

template<typename T>
static lite::STATUS TransFilterFormat(const ParamValueLitePtr &tensor, kTransFilterType type);

STATUS TransFilterFormat(const ParamValueLitePtr &tensor, schema::Format dst_format);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_PASS_COMMON_GLLO_UTILS_H_
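
The non-template TransFilterFormat declared last is the public entry point: it switches on the destination format, then on the tensor's current format, to pick one of the kTransFilterType transposes (or a no-op when source and destination already match). A compact sketch of that dispatch shape, spelling out only the Format_KHWC destination:

// Sketch only: a compressed view of the (dst_format, src_format) dispatch.
#include <cstdio>

enum class Layout { KCHW, CKHW, CHWK, KHWC, HWCK, HWKC };

const char *PickTranspose(Layout src, Layout dst) {
  if (dst == Layout::KHWC) {
    switch (src) {
      case Layout::KCHW: return "kKCHW2KHWC";
      case Layout::CKHW: return "kCKHW2KHWC";
      case Layout::CHWK: return "kCHWK2KHWC";
      case Layout::KHWC: return "no-op";
      default: return nullptr;            // unsupported source for this destination
    }
  }
  return nullptr;                         // HWCK/KCHW/CKHW destinations elided here
}

int main() {
  const char *t = PickTranspose(Layout::KCHW, Layout::KHWC);
  std::printf("%s\n", t ? t : "unsupported");  // prints: kKCHW2KHWC
}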

+4 -1   mindspore/lite/tools/optimizer/fusion/constant_folding_fusion.cc

@@ -195,7 +195,10 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
}
changed = true;
auto output_nums = GetOutputTensorNum(input_cnode);
std::vector<Tensor *> output_tensors{output_nums, new Tensor()};
std::vector<Tensor *> output_tensors;
for (size_t j = 0; j < output_nums; j++) {
output_tensors.push_back(new Tensor());
}
auto lite_primitive = GetValueNode<std::shared_ptr<PrimitiveC>>(input_cnode->input(0));
if (lite_primitive == nullptr) {
MS_LOG(ERROR) << "lite_primitive is nullptr";


+215 -0   mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc

@@ -0,0 +1,215 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/weight_format_hardcode_pass.h"
#include <memory>
#include "tools/optimizer/common/gllo_utils.h"

using mindspore::lite::converter::FmkType_CAFFE;
using mindspore::lite::converter::FmkType_TFLITE;
using mindspore::lite::converter::FmkType_ONNX;
using mindspore::lite::converter::FmkType_MS;
using mindspore::schema::QuantType_WeightQuant;
using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::schema::QuantType_AwareTraining;
using mindspore::schema::QuantType_PostTraining;
namespace mindspore::opt {
namespace {
constexpr size_t kConvWeightIndex = 2;
} // namespace
void WeightFormatHardCodePass::SetQuantType(QuantType type) {
this->quant_type = type;
}
void WeightFormatHardCodePass::SetFmkType(FmkType type) {
this->fmk_type = type;
}
lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node,
const ParamValueLitePtr &param_value) const {
MS_ASSERT(conv_node != nullptr);
MS_ASSERT(param_value != nullptr);
switch (quant_type) {
case QuantType_WeightQuant:
case QuantType_QUANT_NONE:
param_value->set_format(schema::Format::Format_KCHW);
break;
default: {
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
<< conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}

lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node,
const ParamValueLitePtr &param_value) const {
MS_ASSERT(conv_node != nullptr);
MS_ASSERT(param_value != nullptr);
auto op_type = GetCNodeType(conv_node);
switch (this->quant_type) {
case QuantType_AwareTraining: {
// sum up from current onnx quant models
if (op_type == schema::PrimitiveType_Conv2D) {
param_value->set_format(schema::Format::Format_KHWC);
} else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
param_value->set_format(schema::Format::Format_CHWK);
} else if (op_type == schema::PrimitiveType_DeConv2D) {
param_value->set_format(schema::Format::Format_KCHW);
} else {
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
<< conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
break;
case QuantType_WeightQuant:
case QuantType_QUANT_NONE: {
// conv (K x C/group x kH x kW) group = 1
// depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W)
// deconv (C x K/group x kH x kW) group = 1
// dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W)
if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D
|| op_type == schema::PrimitiveType_DeConv2D) {
param_value->set_format(schema::Format::Format_KCHW);
} else {
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
<< conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
break;
default: {
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
<< conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}

lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
const ParamValueLitePtr &param_value) const {
MS_ASSERT(conv_node != nullptr);
MS_ASSERT(param_value != nullptr);
auto op_type = GetCNodeType(conv_node);
switch (this->quant_type) {
case QuantType_AwareTraining: {
if (op_type == schema::PrimitiveType_Conv2D) {
param_value->set_format(schema::Format::Format_KCHW);
} else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
param_value->set_format(schema::Format::Format_CKHW);
} else {
param_value->set_format(schema::Format::Format_KCHW);
}
}
break;
case QuantType_WeightQuant:
case QuantType_QUANT_NONE: {
// sum up from current ms quant models
if (op_type == schema::PrimitiveType_Conv2D) {
param_value->set_format(schema::Format::Format_KCHW);
} else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
param_value->set_format(schema::Format::Format_CKHW);
} else if (op_type == schema::PrimitiveType_DeConv2D) {
param_value->set_format(schema::Format::Format_KCHW);
} else {
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
<< conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
break;
default: {
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
<< conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}

lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const AnfNodePtr &conv_node,
const ParamValueLitePtr &param_value) const {
MS_ASSERT(conv_node != nullptr);
MS_ASSERT(param_value != nullptr);
auto op_type = GetCNodeType(conv_node);
switch (this->quant_type) {
case QuantType_AwareTraining:
case QuantType_PostTraining:
case QuantType_WeightQuant:
case QuantType_QUANT_NONE: {
if (op_type == schema::PrimitiveType_Conv2D) {
param_value->set_format(schema::Format::Format_KHWC);
} else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
param_value->set_format(schema::Format::Format_CHWK);
} else if (op_type == schema::PrimitiveType_DeConv2D) {
param_value->set_format(schema::Format::Format_CHWK);
} else {
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
<< conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
break;
default: {
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
<< conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}

bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto node_list = TopoSort(graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNode>(node)) {
continue;
}
auto conv_cnode = node->cast<CNodePtr>();
auto type = opt::GetCNodeType(node);
if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D
&& type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
continue;
}
MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
auto weight_node = conv_cnode->input(kConvWeightIndex);
MS_ASSERT(weight_node != nullptr);
auto param_value = GetLiteParamValue(weight_node);
if (param_value == nullptr) {
MS_LOG(ERROR) << "weight node must param value";
return false;
}
lite::STATUS status;
switch (fmk_type) {
case FmkType_CAFFE:
status = HardCodeCAFFE(node, param_value);
break;
case FmkType_TFLITE:
status = HardCodeTFLITE(node, param_value);
break;
case FmkType_ONNX:
status = HardCodeONNX(node, param_value);
break;
case FmkType_MS:
status = HardCodeMS(node, param_value);
break;
default:
MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope();
return false;
return false;
}
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "schema::Format hardCode faild: " << status << ", node: " << node->fullname_with_scope();
return false;
}
}
return false;
}
} // namespace mindspore::opt
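
Summarizing what the four HardCode* branches pin down for the common non-quant (and weight-quant) path, as read off the code above rather than any official table: Caffe and ONNX conv weights arrive as KCHW, MS as KCHW except depthwise (CKHW), and TFLite as KHWC for Conv2D with CHWK for depthwise and deconv. A tiny illustrative mapping of that summary:

// Illustrative summary only, read off the branches above; not a real converter API.
#include <cstdio>

enum class Fmk { CAFFE, TFLITE, ONNX, MS };
enum class Op { Conv2D, DepthwiseConv2D, DeConv2D };

// Source weight layout the pass hardcodes on the non-quant / weight-quant path.
const char *SourceWeightFormat(Fmk fmk, Op op) {
  switch (fmk) {
    case Fmk::CAFFE:  return "KCHW";                                  // all conv-like ops
    case Fmk::ONNX:   return "KCHW";                                  // conv, depthwise, deconv
    case Fmk::MS:     return op == Op::DepthwiseConv2D ? "CKHW" : "KCHW";
    case Fmk::TFLITE: return op == Op::Conv2D ? "KHWC" : "CHWK";      // depthwise/deconv: CHWK
  }
  return nullptr;
}

int main() {
  std::printf("tflite depthwise weights arrive as %s\n",
              SourceWeightFormat(Fmk::TFLITE, Op::DepthwiseConv2D));  // CHWK
}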

+47 -0   mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.h

@@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_
#include <string>
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
#include "backend/optimizer/common/pass.h"
#include "src/param_value_lite.h"

using mindspore::lite::converter::FmkType;
using mindspore::schema::QuantType;
namespace mindspore::opt {
class WeightFormatHardCodePass : public Pass {
public:
WeightFormatHardCodePass() : Pass("weight_format_hardcode_pass") {}
~WeightFormatHardCodePass() override = default;
void SetQuantType(QuantType type);
void SetFmkType(FmkType fmkType);
bool Run(const FuncGraphPtr &graph) override;

private:
lite::STATUS HardCodeCAFFE(const AnfNodePtr &node, const ParamValueLitePtr &param_value) const;
lite::STATUS HardCodeONNX(const AnfNodePtr &node, const ParamValueLitePtr &param_value) const;
lite::STATUS HardCodeMS(const AnfNodePtr &node, const ParamValueLitePtr &param_value) const;
lite::STATUS HardCodeTFLITE(const AnfNodePtr &node, const ParamValueLitePtr &param_value) const;

private:
QuantType quant_type = schema::QuantType_QUANT_NONE;
FmkType fmk_type = lite::converter::FmkType_TF;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_GRAPH_WEIGHT_FORMAT_HARDCODE_PASS_H_

+96 -0   mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.cc

@@ -0,0 +1,96 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/graph/weight_format_transform_pass.h"
#include <memory>
#include "tools/optimizer/common/gllo_utils.h"

using mindspore::lite::converter::FmkType_CAFFE;
using mindspore::lite::converter::FmkType_TFLITE;
using mindspore::lite::converter::FmkType_ONNX;
using mindspore::lite::converter::FmkType_MS;
using mindspore::schema::QuantType_WeightQuant;
using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::schema::QuantType_AwareTraining;
using mindspore::schema::QuantType_PostTraining;

namespace mindspore::opt {
namespace {
constexpr size_t kConvWeightIndex = 2;
} // namespace
void WeightFormatTransformPass::SetQuantType(QuantType type) {
this->quant_type = type;
}
void WeightFormatTransformPass::SetFmkType(FmkType type) {
this->fmk_type = type;
}
void WeightFormatTransformPass::SetDstFormat(schema::Format format) {
this->dst_format = format;
}
lite::STATUS WeightFormatTransformPass::ConvWeightFormatTrans(const FuncGraphPtr &graph) {
MS_ASSERT(graph != nullptr);
auto node_list = TopoSort(graph->get_return());
for (auto &node : node_list) {
if (!utils::isa<CNodePtr>(node)) {
continue;
}
auto type = opt::GetCNodeType(node);
if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D
&& type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
continue;
}
auto conv_cnode = node->cast<CNodePtr>();
MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
auto weight_node = conv_cnode->input(kConvWeightIndex);
MS_ASSERT(weight_node != nullptr);
auto weight_value = GetLiteParamValue(weight_node);
if (weight_value == nullptr) {
MS_LOG(ERROR) << "weight node must param value";
return false;
}
MS_ASSERT(weight_value->tensor_type() == TypeId::kNumberTypeFloat32
|| weight_value->tensor_type() == TypeId::kNumberTypeUInt8);
lite::STATUS status;
schema::Format weight_dst_format = schema::Format::Format_KHWC;
if (dst_format != schema::Format::Format_NUM_OF_FORMAT) {
weight_dst_format = dst_format;
}
status = TransFilterFormat(weight_value, weight_dst_format);
if (status == RET_OK) {
weight_value->set_format(weight_dst_format);
} else {
MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(schema::EnumValuesFormat()[weight_value->format()]) << "To"
<< EnumNameFormat(weight_dst_format) << " failed, node : " << node->fullname_with_scope()
<< "quant type:" << quant_type;
return ERROR;
}
auto type_id = static_cast<TypeId>(weight_value->tensor_type());
auto type_ptr = TypeIdToType(type_id);
auto abstract_tensor = std::make_shared<abstract::AbstractTensor>(type_ptr, weight_value->tensor_shape());
weight_node->set_abstract(abstract_tensor);
}
return RET_OK;
}

bool WeightFormatTransformPass::Run(const FuncGraphPtr &func_graph) {
MS_ASSERT(func_graph != nullptr);
auto status = ConvWeightFormatTrans(func_graph);
if (status != lite::RET_OK) {
MS_LOG(ERROR) << "Conv2D weight FormatTrans failed: " << status;
return status;
}
return false;
}
} // namespace mindspore::opt
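
For callers outside AnfTransform, the pass is configured through its three setters before being scheduled; SetDstFormat overrides the default KHWC target chosen when dst_format is left at Format_NUM_OF_FORMAT. A hedged wiring sketch that assumes the converter headers are on the include path:

// Hedged wiring sketch: mirrors how AnfTransform registers the pass; assumes the
// converter headers are available, this is not new API.
#include <memory>
#include "tools/optimizer/graph/weight_format_transform_pass.h"

std::shared_ptr<mindspore::opt::WeightFormatTransformPass> MakeWeightFormatTransformPass(
    mindspore::lite::converter::FmkType fmk, mindspore::schema::QuantType quant) {
  auto pass = std::make_shared<mindspore::opt::WeightFormatTransformPass>();
  pass->SetFmkType(fmk);
  pass->SetQuantType(quant);
  // Optional: the destination already defaults to Format_KHWC when unset.
  pass->SetDstFormat(mindspore::schema::Format::Format_KHWC);
  return pass;
}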

+45 -0   mindspore/lite/tools/optimizer/graph/weight_format_transform_pass.h

@@ -0,0 +1,45 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_
#include <string>
#include "schema/inner/model_generated.h"
#include "tools/converter/converter_flags.h"
#include "backend/optimizer/common/pass.h"

using mindspore::lite::converter::FmkType;
using mindspore::schema::QuantType;
namespace mindspore::opt {
class WeightFormatTransformPass : public Pass {
public:
WeightFormatTransformPass() : Pass("weight_format_transform_pass") {}
~WeightFormatTransformPass() override = default;
void SetQuantType(QuantType type);
void SetFmkType(FmkType fmkType);
void SetDstFormat(schema::Format format);
bool Run(const FuncGraphPtr &graph) override;

private:
lite::STATUS ConvWeightFormatTrans(const FuncGraphPtr &graph);

private:
QuantType quant_type = schema::QuantType_QUANT_NONE;
FmkType fmk_type = lite::converter::FmkType_TF;
schema::Format dst_format = schema::Format::Format_NUM_OF_FORMAT;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_WEIGHT_FORMAT_TRANSFORM_PASS_H_
