Browse Source

!4548 rearch weight_format_pass

Merge pull request !4548 from hangq/master
tags/v0.7.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
a974a35486
31 changed files with 808 additions and 2238 deletions
  1. +0
    -1
      mindspore/lite/test/CMakeLists.txt
  2. +249
    -74
      mindspore/lite/tools/common/node_util.cc
  3. +15
    -6
      mindspore/lite/tools/common/node_util.h
  4. +0
    -1
      mindspore/lite/tools/converter/CMakeLists.txt
  5. +9
    -44
      mindspore/lite/tools/converter/graphdef_transform.cc
  6. +0
    -2
      mindspore/lite/tools/converter/legacy_optimizer/CMakeLists.txt
  7. +0
    -7
      mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt
  8. +0
    -101
      mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.cc
  9. +0
    -50
      mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h
  10. +0
    -295
      mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.cc
  11. +0
    -51
      mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h
  12. +0
    -224
      mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.cc
  13. +0
    -54
      mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h
  14. +0
    -41
      mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.cc
  15. +0
    -46
      mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.h
  16. +0
    -40
      mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.cc
  17. +0
    -45
      mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h
  18. +0
    -361
      mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.cc
  19. +0
    -67
      mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h
  20. +0
    -126
      mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.cc
  21. +0
    -46
      mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h
  22. +2
    -0
      mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt
  23. +99
    -112
      mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.cc
  24. +10
    -12
      mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h
  25. +6
    -4
      mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h
  26. +213
    -0
      mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc
  27. +52
    -0
      mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h
  28. +140
    -0
      mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc
  29. +13
    -18
      mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h
  30. +0
    -3
      mindspore/lite/tools/converter/legacy_optimizer/node/CMakeLists.txt
  31. +0
    -407
      mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc

+ 0
- 1
mindspore/lite/test/CMakeLists.txt View File

@@ -353,7 +353,6 @@ if (BUILD_CONVERTER)
tflite_parser_mid
caffe_parser_mid
onnx_parser_mid
node_mid
graph_pass_mid
fusion_mid
quantizer_mid


+ 249
- 74
mindspore/lite/tools/common/node_util.cc View File

@@ -24,74 +24,6 @@

namespace mindspore {
namespace lite {
// Broadcast this node's quantization parameters to its producer (pre) and
// consumer (post) nodes in the meta-graph.
// NOTE(review): every statement that would actually copy a quant param is
// commented out below, so at present this function only walks the graph,
// asserts index consistency in debug builds, and returns RET_OK.
// @param graphT the meta-graph containing the node; must be non-null.
// @param node   the node whose quant params would be propagated; non-null.
// @return RET_OK always (the error paths are all in commented-out code).
STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr<CNodeT> &node) {
MS_ASSERT(graphT != nullptr);
MS_ASSERT(node != nullptr);
// set quantParam to preNode
for (size_t i = 0; i < node->inputIndex.size(); i++) {
auto preNodeIdexes = GetInputNodeIdx(*graphT, *(node.get()), i);
for (auto preNodeIdx : preNodeIdexes) {
MS_ASSERT(graphT->nodes.size() > preNodeIdx);
auto &preNode = graphT->nodes.at(preNodeIdx);
MS_ASSERT(preNode != nullptr);
// if preNode is not init, it maybe not a quantNode, so skip
// if (preNode->inputIndex.size() + preNode->outputIndex.size() != preNode->quantParam.size()) {
// continue;
// }
// Find which of preNode's outputs feeds this node's i-th input.
auto preNodeOutputIndexes = preNode->outputIndex;
int32_t currentNodeIndexInPre = -1;
for (auto index : preNodeOutputIndexes) {
currentNodeIndexInPre++;
if (index == node->inputIndex.at(i)) {
break;
}
}
// NOTE(review): the counter is incremented before the comparison, so this
// assert can only fire when preNode has no outputs at all; a missing match
// leaves currentNodeIndexInPre at the last output's index — verify intent.
MS_ASSERT(currentNodeIndexInPre != -1);
MS_ASSERT(node->quantParam.size() > i);
MS_ASSERT(node->quantParam.at(i) != nullptr);
// auto quantParamArrayCopy = CopyQuantParamArrayT(node->quantParam.at(i));
// if (quantParamArrayCopy == nullptr) {
// //MS_LOG(ERROR)("CopyQuantParamArray return nullptr, node: %s", node->name.c_str());
// return RET_ERROR;
// }
// preNode->quantParam.at(preNode->inputIndex.size() + currentNodeIndexInPre) =
// std::move(CopyQuantParamArrayT(quantParamArrayCopy));
}
}

// set quantParam to postNode
for (size_t i = 0; i < node->outputIndex.size(); i++) {
auto postNodeIdexes = GetOutputNodeIdx(*graphT, *(node.get()), i);
for (auto postNodeIdx : postNodeIdexes) {
MS_ASSERT(graphT->nodes.size() > postNodeIdx);
auto &postNode = graphT->nodes.at(postNodeIdx);
MS_ASSERT(postNode != nullptr);
// if postNode is not init, it maybe not a quantNode, so skip
// if (postNode->inputIndex.size() + postNode->outputIndex.size() != postNode->quantParam.size()) {
// continue;
// }
// Find which of postNode's inputs is fed by this node's i-th output
// (mirror of the producer-side search above).
auto postNodeInputIndexes = postNode->inputIndex;
int32_t currentNodeIndexInPost = -1;
for (auto index : postNodeInputIndexes) {
currentNodeIndexInPost++;
if (index == node->outputIndex.at(i)) {
break;
}
}
MS_ASSERT(currentNodeIndexInPost != -1);
MS_ASSERT(node->quantParam.size() > node->inputIndex.size() + i);
MS_ASSERT(node->quantParam.at(node->inputIndex.size() + i) != nullptr);
// auto quantParamArrayCopy = CopyQuantParamArrayT(node->quantParam.at(node->inputIndex.size() + i));
// if (quantParamArrayCopy == nullptr) {
// //MS_LOG(ERROR)("CopyQuantParamArray return nullptr, node: %s", node->name.c_str());
// return RET_ERROR;
// }
// postNode->quantParam.at(currentNodeIndexInPost) = std::move(CopyQuantParamArrayT(quantParamArrayCopy));
}
}
return RET_OK;
}

static const std::vector<schema::PrimitiveType> nhwcOpList = {
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
@@ -121,8 +53,8 @@ std::vector<schema::PrimitiveType> GetUint8OpList() { return uint8OpList; }
STATUS NodeUtils::ConvertDims(mindspore::lite::Format src_format, const std::vector<int32_t> &src_dims,
mindspore::lite::Format dst_format, std::vector<int32_t> *dst_dims) {
if ((src_dims.size() != DIM_DEFAULT_SIZE && src_dims.size() != 3) || src_format == dst_format) {
// MS_LOG(ERROR)("Convert format , src size %lu <3 or src format is equal to dst format,not need convert",
// src_dims.size());
MS_LOG(ERROR) << "Convert format , src size " << src_dims.size()
<< " <3 or src format is equal to dst format,not need convert";
*dst_dims = src_dims;
return RET_PARAM_INVALID;
}
@@ -145,12 +77,12 @@ STATUS NodeUtils::ConvertDims(mindspore::lite::Format src_format, const std::vec
}
break;
default:
// MS_LOG(ERROR)("Not support src format: %d", src_format);
MS_LOG(ERROR) << "Not support src format: " << schema::EnumNameFormat(src_format);
return RET_ERROR;
}

if (nchw_dim.size() == 0) {
// MS_LOG(ERROR)("Param nchw_dim is empty!");
MS_LOG(ERROR) << "Param nchw_dim is empty!";
return RET_ERROR;
}

@@ -172,7 +104,250 @@ STATUS NodeUtils::ConvertDims(mindspore::lite::Format src_format, const std::vec
}
return RET_OK;
}

// Convert a 4-D convolution filter tensor from tensor->format to dstFormat,
// rewriting tensor->data and tensor->dims in place.
// Supported destinations: KHWC, HWCK, KCHW, CKHW; supported element types:
// float32, uint8, int8. Returns RET_OK (including the no-op case where the
// tensor is already in dstFormat), RET_NULL_PTR on a null tensor, and
// RET_ERROR for an unsupported rank, element type, or (src, dst) pair.
//
// Fixes vs. previous version: `status` is now initialized (it was declared
// indeterminate and relied on every branch assigning it), and the
// float32/uint8/int8 dispatch ladder — previously copy-pasted for all 13
// (src, dst) combinations — is factored into one generic-lambda helper.
STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat) {
  if (tensor == nullptr) {
    return RET_NULL_PTR;
  }
  std::vector<int32_t> oriDims = tensor->dims;
  if (oriDims.size() != static_cast<size_t>(DIM_DEFAULT_SIZE)) {
    MS_LOG(ERROR) << "Filter dim-num is not supported, dim-num: " << oriDims.size();
    return RET_ERROR;
  }
  auto srcFormat = tensor->format;
  auto dataType = tensor->dataType;
  // Dispatch on the tensor's element type once for any transform kind.
  // A generic lambda avoids spelling the project enum type here.
  auto transByElemType = [tensor, dataType](auto transType) -> STATUS {
    if (dataType == kNumberTypeFloat32) {
      return TransFilterFormat<float>(tensor, transType);
    }
    if (dataType == kNumberTypeUInt8) {
      return TransFilterFormat<uint8_t>(tensor, transType);
    }
    if (dataType == kNumberTypeInt8) {
      return TransFilterFormat<int8_t>(tensor, transType);
    }
    MS_LOG(ERROR) << "Unsupported dataType: " << dataType;
    return RET_ERROR;
  };
  // Shared error path for (src, dst) pairs with no known transform.
  auto unsupported = [srcFormat, dstFormat]() -> STATUS {
    MS_LOG(ERROR) << "Unsupported transform from " << schema::EnumNameFormat(srcFormat) << " to "
                  << schema::EnumNameFormat(dstFormat);
    return RET_ERROR;
  };
  STATUS status = RET_OK;
  switch (dstFormat) {
    case schema::Format_KHWC:
      switch (srcFormat) {
        case schema::Format_KCHW: status = transByElemType(kKCHW2KHWC); break;
        case schema::Format_CKHW: status = transByElemType(kCKHW2KHWC); break;
        case schema::Format_CHWK: status = transByElemType(kCHWK2KHWC); break;
        case schema::Format_KHWC: return RET_OK;  // already in the requested layout
        default: return unsupported();
      }
      break;
    case schema::Format_HWCK:
      switch (srcFormat) {
        case schema::Format_KCHW: status = transByElemType(kKCHW2HWCK); break;
        case schema::Format_KHWC: status = transByElemType(kKHWC2HWCK); break;
        case schema::Format_CKHW: status = transByElemType(kCKHW2HWCK); break;
        case schema::Format_CHWK: status = transByElemType(kCHWK2HWCK); break;
        case schema::Format_HWCK: return RET_OK;
        default: return unsupported();
      }
      break;
    case schema::Format_KCHW:
      switch (srcFormat) {
        case schema::Format_KCHW: return RET_OK;
        case schema::Format_HWCK: status = transByElemType(kHWCK2KCHW); break;
        case schema::Format_HWKC: status = transByElemType(kHWKC2KCHW); break;
        case schema::Format_KHWC: status = transByElemType(kKHWC2KCHW); break;
        case schema::Format_CKHW: status = transByElemType(kCKHW2KCHW); break;
        case schema::Format_CHWK: status = transByElemType(kCHWK2KCHW); break;
        default: return unsupported();
      }
      break;
    case schema::Format_CKHW:
      switch (srcFormat) {
        case schema::Format_HWCK: status = transByElemType(kHWCK2CKHW); break;
        case schema::Format_HWKC: status = transByElemType(kHWKC2CKHW); break;
        case schema::Format_KCHW: status = transByElemType(kKCHW2CKHW); break;
        case schema::Format_CKHW: return RET_OK;
        default: return unsupported();
      }
      break;
    default:
      return unsupported();
  }
  // Note: an unsupported dataType now also reports "TransFilterData failed"
  // here; the return value (RET_ERROR) is unchanged from the original.
  if (status != RET_OK) {
    MS_LOG(ERROR) << "TransFilterData failed: " << status;
    return status;
  }
  return RET_OK;
}
} // namespace lite
} // namespace mindspore



+ 15
- 6
mindspore/lite/tools/common/node_util.h View File

@@ -53,7 +53,7 @@ class NodeUtils {

// todo check this
enum kTransFilterType {
kKCHW2HWCK,
kKCHW2HWCK, // 0
kKCHW2KHWC,
kCKHW2KHWC,
kCKHW2HWCK,
@@ -63,19 +63,23 @@ enum kTransFilterType {
kHWCK2CKHW,
kHWKC2KCHW,
kHWKC2CKHW,
kNHWC2KCHW,
kNHWC2KCHW, // 10
kNHWC2CKHW,
kNHWC2HWCK,
kKHWC2HWCK,
kCHWK2HWCK,
kKHWC2CHWK,
kCHWK2KHWC
kCHWK2KHWC,
kKHWC2KCHW,
kCKHW2KCHW,
kCHWK2KCHW,
kKCHW2CKHW // 20
};

static STATUS GetFilterDim(std::vector<int32_t> &oriDims, kTransFilterType type, int32_t &filterK, int32_t &filterC,
int32_t &filterH, int32_t &filterW) {
MS_ASSERT(oriDims.size() == 4);
if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC) {
if (type == kKCHW2HWCK || type == kKCHW2HWKC || type == kKCHW2KHWC || type == kKCHW2CKHW) {
filterK = oriDims.at(KCHW_K);
filterC = oriDims.at(KCHW_C);
filterH = oriDims.at(KCHW_H);
@@ -126,7 +130,7 @@ static STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32
tensor->dims = {filterH, filterW, filterK, filterC};
} else if (type == kHWCK2KCHW || type == kHWKC2KCHW || type == kNHWC2KCHW) {
tensor->dims = {filterK, filterC, filterH, filterW};
} else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW) {
} else if (type == kHWCK2CKHW || type == kHWKC2CKHW || type == kNHWC2CKHW || type == kKCHW2CKHW) {
tensor->dims = {filterC, filterK, filterH, filterW};
} else if (type == kKHWC2CHWK) {
tensor->dims = {filterC, filterH, filterW, filterK};
@@ -198,6 +202,7 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in
}
} break;
case kKCHW2HWCK:
case kKCHW2CKHW:
case kKCHW2KHWC:
case kKCHW2HWKC: {
for (int k = 0; k < filterK; ++k) {
@@ -211,6 +216,9 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in
} else if (type == kKCHW2KHWC) {
p2Buff =
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
} else if (type == kKCHW2CKHW) {
p2Buff =
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff =
buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c));
@@ -367,7 +375,8 @@ static STATUS TransFilterFormat(schema::TensorT *tensor, kTransFilterType type)

return RET_OK;
}

STATUS TransFilterFormat(schema::TensorT *tensor, schema::Format dstFormat);
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_PREDICT_NODE_UTIL_H


+ 0
- 1
mindspore/lite/tools/converter/CMakeLists.txt View File

@@ -103,7 +103,6 @@ target_link_libraries(converter_lite PRIVATE
onnx_parser_mid
anf_importer_mid
anf_exporter_mid
node_mid
graph_pass_mid
fusion_mid
quantizer_mid


+ 9
- 44
mindspore/lite/tools/converter/graphdef_transform.cc View File

@@ -23,34 +23,12 @@
#include "src/common/op_utils.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
#include "tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h"
// #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/format_trans_transpose_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
// #include "tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h"
//
// #include "tools/converter/legacy_optimizer/const_fold/add_const_fold_pass.h"
// #include "tools/converter/legacy_optimizer/const_fold/cast_const_fold_pass.h"
// #include "tools/converter/legacy_optimizer/const_fold/concat_v2_const_fold_pass.h"
// #include "tools/converter/legacy_optimizer/const_fold/expand_dims_const_fold_pass.h"
// #include "tools/converter/legacy_optimizer/const_fold/mul_const_fold_pass.h"
// #include "tools/converter/legacy_optimizer/const_fold/range_const_fold_pass.h"
// #include "tools/converter/legacy_optimizer/const_fold/reshape_const_fold_pass.h"
// #include "tools/converter/legacy_optimizer/const_fold/rsqrt_const_fold_pass.h"
// #include "tools/converter/legacy_optimizer/const_fold/shape_const_fold_pass.h"
// #include "tools/converter/legacy_optimizer/const_fold/slice_const_fold_pass.h"
// #include "tools/converter/legacy_optimizer/const_fold/stack_const_fold_pass.h"
// #include "tools/converter/legacy_optimizer/const_fold/strided_slice_const_fold_pass.h"
// #include "tools/converter/legacy_optimizer/const_fold/sub_const_fold_pass.h"
// #include "tools/converter/legacy_optimizer/const_fold/tile_const_fold_pass.h"
// #include "tools/converter/legacy_optimizer/const_fold/transpose_const_fold_pass.h"
//
#include "tools/converter/legacy_optimizer/node/weight_format_pass.h"
#include "tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h"
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
#include "tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h"
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
@@ -76,19 +54,6 @@ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) {
std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean);
break;
}
// case QuantType::QuantType_WeightQuant: {
// MS_LOGI("create WeightQuantizer!");
// mQuantizer.reset(new WeightQuantizer(graphDefT, flags->quantSize));
// break;
// }
// case QuantType_PostTraining: {
// MS_LOGI("create PostTrainningQuantizer!");
// mQuantizer.reset(new PostTrainingQuantizer(graphDefT, flags->configFile));
// break;
// }
// case QuantType::QuantType_QUANT_NONE:
// MS_LOGD("Not do quantization for model!");
// break;
default:
// MS_LOGI("will support quantizer type %s in the future!", flags->quantTypeIn.c_str());
break;
@@ -97,16 +62,16 @@ void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) {

int GraphDefTransform::Transform(const converter::Flags &ctx) {
STATUS status;
// weight format trans
if (ctx.formatTrans) {
{
Optimizer weightFormatOptimizer;
auto weightFormatPass = new (std::nothrow) WeightFormatPass();
if (weightFormatPass == nullptr) {
MS_LOG(ERROR) << "new weightFormatPass failed";
return RET_ERROR;
}
auto weightHardCodePass = new WeightFormatHardCodePass();
auto weightFormatPass = new WeightFormatTransformPass();
weightHardCodePass->SetQuantType(ctx.quantType);
weightHardCodePass->SetFmkType(ctx.fmk);
weightFormatPass->SetQuantType(ctx.quantType);
weightFormatPass->SetFmkType(ctx.fmk);
// weightFormatPass->SetDstFormat(Format_KHWC);
weightFormatOptimizer.AddPass(weightHardCodePass);
weightFormatOptimizer.AddPass(weightFormatPass);
status = weightFormatOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {


+ 0
- 2
mindspore/lite/tools/converter/legacy_optimizer/CMakeLists.txt View File

@@ -1,6 +1,4 @@
include_directories(${CMAKE_CURRENT_SOURCE_DIR})

add_subdirectory(fusion)
#add_subdirectory(const_fold)
add_subdirectory(node)
add_subdirectory(graph)

+ 0
- 7
mindspore/lite/tools/converter/legacy_optimizer/fusion/CMakeLists.txt View File

@@ -1,13 +1,6 @@
add_library(fusion_mid OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/fusion_pattern.cc
${CMAKE_CURRENT_SOURCE_DIR}/fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/conv_scale_bias_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/conv_bn_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/conv_scale_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/conv_activation_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/conv_relu_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/conv_relu6_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/conv_biasadd_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/matmul_biasadd_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/quant_cast_fusion_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/batchnorm_fold_fusion_pass.cc


+ 0
- 101
mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.cc View File

@@ -1,101 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "utils/log_adapter.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
#include "tools/common/graph_util.h"
#include "src/common/op_utils.h"

namespace mindspore {
namespace lite {
#define CONV_ACTIVATION_MATCH_PATH_LEN 2

// Register the pattern this pass searches for:
//   Conv2D | DepthwiseConv2D  -->  Activation
STATUS ConvActivationFusionPass::DefinePattern() {
  auto convPatternOp = std::make_shared<PatternOp>();
  convPatternOp->id = kConvName;
  convPatternOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D};

  auto activationPatternOp = std::make_shared<PatternOp>();
  activationPatternOp->id = ACTIVATION_NAME;
  activationPatternOp->types = {schema::PrimitiveType_Activation};
  // The conv op is the (single) upstream input of the activation op.
  activationPatternOp->left = convPatternOp;

  std::unique_ptr<FusionPattern> pattern(new (std::nothrow) FusionPattern("ConvActivationFusion"));
  if (pattern == nullptr) {
    MS_LOG(ERROR) << "new fusionPattern failed";
    return RET_ERROR;
  }
  pattern->AddPatternOp(convPatternOp);
  pattern->AddPatternOp(activationPatternOp);
  pattern->Finish();
  // Ownership of the pattern moves to this->patterns.
  this->patterns.emplace_back(pattern.release());
  return RET_OK;
}

// 1. change attr of conv
// 2. delete Activation node
// Fold a matched Activation node into its preceding conv node:
//   1. set the conv's activationType attribute to this pass's activationType;
//   2. remove the Activation node from the graph.
// Returns RET_NO_CHANGE (no mutation) when the matched activation's type is
// not the one this subclass targets; RET_PARAM_INVALID / RET_ERROR on
// malformed matches; RET_OK on success.
STATUS ConvActivationFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
MS_ASSERT(graph != nullptr);
// A match must contain exactly the conv node and the activation node.
if (matchedPath.size() != CONV_ACTIVATION_MATCH_PATH_LEN) {
MS_LOG(ERROR) << "Conv-Activation-Fusion should have two NodeIndex in matchedPair";
return RET_PARAM_INVALID;
}

auto convPath = matchedPath[kConvName];
auto actPath = matchedPath[ACTIVATION_NAME];
auto &convNode = graph->nodes.at(convPath->nodeIdx);
auto &actNode = graph->nodes.at(actPath->nodeIdx);

// todo if combine conv_relu_fusion and conv_relu6_fusion to conv_activation_fusion
// Subclasses (relu/relu6) each target one activation type; skip others.
if (actNode->primitive->value.AsActivation()->type != this->activationType) {
return RET_NO_CHANGE;
}

// Step 1: record the fused activation on the conv primitive itself.
if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) {
convNode->primitive->value.AsConv2D()->activationType = this->activationType;
} else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
convNode->primitive->value.AsDepthwiseConv2D()->activationType = this->activationType;
} else {
MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type;
return RET_ERROR;
}

// remove activation node
// Step 2: merge any remaining attrs, then splice the activation node out of
// the graph (its input is reconnected to its consumers).
MergeNodeAttrFromPost(convNode, actNode);
auto status = IsolateOneWayNode(graph, actPath->nodeIdx);
if (status != RET_OK) {
MS_LOG(ERROR) << "IsolateOneWayNode failed, subGraph: " << actPath->subGraphIdx << ", node: " << actPath->nodeIdx
<< ", error: " << status;
return status;
}

return RET_OK;
}

// Let the concrete subclass fix its activation type, then delegate the
// pattern matching and fusion to the generic FusionPass driver.
STATUS ConvActivationFusionPass::Run(schema::MetaGraphT *graph) {
  this->SetActivationType();
  return FusionPass::Run(graph);
}

} // namespace lite
} // namespace mindspore

+ 0
- 50
mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h View File

@@ -1,50 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_PREDICT_CONV_ACTIVATION_FUSION_PASS_H
#define MINDSPORE_PREDICT_CONV_ACTIVATION_FUSION_PASS_H

#include <memory>
#include <string>
#include <unordered_map>
#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"

namespace mindspore {
namespace lite {
// Abstract fusion pass that folds an Activation node into the conv node that
// feeds it. Concrete subclasses choose which activation type they target by
// implementing SetActivationType() (e.g. relu / relu6 variants).
class ConvActivationFusionPass : public FusionPass {
public:
ConvActivationFusionPass() = default;

~ConvActivationFusionPass() override = default;

// Registers the Conv2D/DepthwiseConv2D -> Activation pattern.
STATUS DefinePattern() override;

// Subclass hook: assign the targeted type to `activationType` before Run.
virtual STATUS SetActivationType() = 0;

// 1. change attr of conv
// 2. delete Activation node
STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;

STATUS Run(schema::MetaGraphT *graph) override;

protected:
// Activation type this pass fuses; subclasses overwrite it in SetActivationType().
schema::ActivationType activationType = schema::ActivationType_RELU;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_PREDICT_CONV_ACTIVATION_FUSION_PASS_H

+ 0
- 295
mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.cc View File

@@ -1,295 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h"
#include <cfloat>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <memory>
#include "utils/log_adapter.h"
#include "securec/include/securec.h"
// #include "utils/log_adapter.h"
#include "tools/common/graph_util.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
#include "src/common/op_utils.h"

namespace mindspore {
namespace lite {
#define CONV_BIASADD_MATCH_PATH_LEN 2
#define BIASADD_OP_BIAS_INDEX_IN_WEIGHT 0
#define BIASADD_OP_INPUT_NUM 2
#define BIASADD_OP_CONST_TENSOR_INDEX 1

// No per-pass setup needed; delegate directly to the generic fusion driver.
STATUS ConvBiasAddFusionPass::Run(MetaGraphT *graph) {
  return FusionPass::Run(graph);
}

// Register the pattern this pass searches for:
//   Conv2D | DepthwiseConv2D | DeConv2D  -->  BiasAdd | Add
STATUS ConvBiasAddFusionPass::DefinePattern() {
  auto convPatternOp = std::make_shared<PatternOp>();
  convPatternOp->id = kConvName;
  convPatternOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeConv2D};

  auto biasAddPatternOp = std::make_shared<PatternOp>();
  biasAddPatternOp->id = BIASADD_NAME;
  biasAddPatternOp->types = {schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Add};
  // The conv op is the (single) upstream input of the bias-add op.
  biasAddPatternOp->left = convPatternOp;

  std::unique_ptr<FusionPattern> pattern(new (std::nothrow) FusionPattern("ConvBiasAddFusion"));
  if (pattern == nullptr) {
    MS_LOG(ERROR) << "new fusionPattern failed";
    return RET_ERROR;
  }
  pattern->AddPatternOp(convPatternOp);
  pattern->AddPatternOp(biasAddPatternOp);
  pattern->Finish();
  // Ownership of the pattern moves to this->patterns.
  this->patterns.emplace_back(pattern.release());
  return RET_OK;
}

// Folds a matched BiasAdd/Add node into the preceding convolution:
// 1. fold the bias values into the conv bias tensor (creating it if absent)
// 2. mark the conv primitive as biased
// 3. detach the BiasAdd node from the graph
// Returns RET_OK when fused or when the fusion is legitimately skipped.
STATUS ConvBiasAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName,
                                       std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
  MS_ASSERT(graph != nullptr);
  if (matchedPath.size() != CONV_BIASADD_MATCH_PATH_LEN) {
    MS_LOG(ERROR) << "Conv-BiasAdd-Fusion should have two NodeIndex in matchedPair";
    return RET_PARAM_INVALID;
  }

  auto convPath = matchedPath[kConvName];
  auto baPath = matchedPath[BIASADD_NAME];
  auto &convNode = graph->nodes.at(convPath->nodeIdx);
  auto &baNode = graph->nodes.at(baPath->nodeIdx);
  // if the second input of the add/biasadd node is not a constant tensor the
  // bias value is unknown at convert time, so skip the fusion (not an error)
  auto baNodeInputIndex = baNode->inputIndex;
  if (baNodeInputIndex.size() != BIASADD_OP_INPUT_NUM) {
    MS_LOG(ERROR) << baNode->name.c_str() << " node tensors number is invalid! ";
    return RET_ERROR;
  }
  auto baNodeBiasTensor = graph->allTensors.at(baNodeInputIndex[BIASADD_OP_CONST_TENSOR_INDEX]).get();
  MS_ASSERT(baNodeBiasTensor != nullptr);
  if (baNodeBiasTensor->nodeType != schema::NodeType_ValueNode) {
    // dont fusion, return
    return RET_OK;
  }

  // 1. generate newBiasTensor for conv (only set when conv had no bias before)
  auto status = GenConvBiasTensor(convPath, baPath, graph);
  if (RET_OK != status) {
    MS_LOG(ERROR) << "GenConvBiasTensor failed, " << status;
    return status;
  }
  if (this->newBiasTensor != nullptr) {
    status = AddTensor2Node(graph, convPath->nodeIdx, std::move(this->newBiasTensor));
    this->newBiasTensor = nullptr;
    if (status != RET_OK) {
      MS_LOG(ERROR) << "AddTensor2Node failed, node: " << convPath->nodeIdx << ", error: " << status;
      return status;
    }
    // todo add quantParam for the new bias tensor
  }

  // 2. change attr of conv
  if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) {
    convNode->primitive->value.AsConv2D()->hasBias = true;
  } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
    convNode->primitive->value.AsDepthwiseConv2D()->hasBias = true;
  } else if (convNode->primitive->value.type == schema::PrimitiveType_DeConv2D) {
    convNode->primitive->value.AsDeConv2D()->hasBias = true;
  } else {
    MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type;
    return RET_ERROR;
  }

  // 3. delete BiasAdd node
  MergeNodeAttrFromPost(convNode, baNode);
  status = IsolateOneWayNode(graph, baPath->nodeIdx);
  if (status != RET_OK) {
    // fix: the old message contained printf placeholders (%zu/%d) that were
    // never substituted; stream the actual values instead
    MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << baPath->nodeIdx << ", error: " << status;
    return status;
  }

  return RET_OK;
}

#define BIASADD_WEIGHT_SHAPE_SIZE 1
#define BIASADD_BIAS_DIM_INDEX 0

// Computes the bias values to fold into the conv.
// - If the conv already has a bias tensor, the BiasAdd values are accumulated
//   into it in place and this->newBiasTensor stays null.
// - Otherwise this->newBiasTensor is filled and the caller attaches it.
// The BiasAdd bias may be a true scalar (no dims), a 1-element tensor, or a
// kernelNum-element vector; scalars are broadcast over all output channels.
STATUS ConvBiasAddFusionPass::GenConvBiasTensor(std::shared_ptr<Path> convPath, std::shared_ptr<Path> baPath,
                                                MetaGraphT *graph) {
  MS_ASSERT(convPath != nullptr);
  MS_ASSERT(baPath != nullptr);
  MS_ASSERT(graph != nullptr);

  auto convNode = graph->nodes.at(convPath->nodeIdx).get();
  MS_ASSERT(convNode != nullptr);
  auto baNode = graph->nodes.at(baPath->nodeIdx).get();
  MS_ASSERT(baNode != nullptr);
  int32_t kernelNum = 0;
  if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) {
    kernelNum = convNode->primitive->value.AsConv2D()->channelOut;
  } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
    kernelNum = convNode->primitive->value.AsDepthwiseConv2D()->channelIn *
                convNode->primitive->value.AsDepthwiseConv2D()->channelMultiplier;
  } else if (convNode->primitive->value.type == schema::PrimitiveType_DeConv2D) {
    kernelNum = convNode->primitive->value.AsDeConv2D()->channelOut;
  }
  // guard: a non-positive kernelNum would make the allocations below invalid
  if (kernelNum <= 0) {
    MS_LOG(ERROR) << "kernelNum should be positive, " << kernelNum;
    return RET_ERROR;
  }
  auto convWeightTensorIdxes = convNode->inputIndex;
  if (convWeightTensorIdxes.size() < CONV_OP_NO_BIAS_INPUT_NUM) {
    MS_LOG(ERROR) << convNode->name.c_str() << " node tensors number is invalid! ";
    return RET_ERROR;
  }
  convWeightTensorIdxes.erase(convWeightTensorIdxes.begin());  // drop activation input
  auto baWeightTensorIdxes = baNode->inputIndex;
  if (baWeightTensorIdxes.size() != BIASADD_OP_INPUT_NUM) {
    MS_LOG(ERROR) << baNode->name.c_str() << " node tensors number is invalid! ";
    return RET_ERROR;
  }
  baWeightTensorIdxes.erase(baWeightTensorIdxes.begin());  // drop activation input

  if (convWeightTensorIdxes.empty()) {
    MS_LOG(ERROR) << "Conv2D should has one weight tensors at least, current number of weight tensors "
                  << convWeightTensorIdxes.size();
    return RET_ERROR;
  }

  if (baWeightTensorIdxes.empty()) {
    MS_LOG(ERROR) << "BiasAdd should has one weight tensors at least, current number of weight tensors "
                  << baWeightTensorIdxes.size();
    return RET_ERROR;
  }

  TensorT *oldBiasTensor = nullptr;
  TensorT *biasTensor = nullptr;

  if (convWeightTensorIdxes.size() == CONV_OP_HAS_BIAS_WEIGHT_NUM) {
    oldBiasTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_BIAS_INDEX_IN_WEIGHT]).get();
    MS_ASSERT(oldBiasTensor != nullptr);
  }
  biasTensor = graph->allTensors.at(baWeightTensorIdxes.at(BIASADD_OP_BIAS_INDEX_IN_WEIGHT)).get();
  MS_ASSERT(biasTensor != nullptr);
  auto biasDims = biasTensor->dims;
  // if biasTensor is a scaler
  if (biasDims.empty() && biasTensor->data.data() == nullptr) {
    MS_LOG(ERROR) << "BiasAdd node " << baNode->name.c_str() << " bias tensor is invalid";
    return RET_ERROR;
  }
  if (!biasDims.empty() && biasDims.size() != BIASADD_WEIGHT_SHAPE_SIZE) {
    MS_LOG(ERROR) << "BiasAdd bias tensor should has one dimension, current number of dimension " << biasDims.size()
                  << ". or bias tensor is a scaler";
    return RET_ERROR;
  }

  bool bias_const = !biasDims.empty() && biasDims.size() == 1 && biasDims[0] == 1;
  if (!biasDims.empty() && !bias_const && biasDims.at(BIASADD_BIAS_DIM_INDEX) != kernelNum) {
    MS_LOG(ERROR) << "Size(" << biasDims.at(BIASADD_BIAS_DIM_INDEX) << ") of BiasAdd(" << baNode->name.c_str()
                  << ") bias tensor should be equal to kernelNum(" << kernelNum << ")";
    return RET_ERROR;
  }

  // cal new biasData; release any stale buffer left by a previous failed run
  delete[] this->newBiasData;
  this->newBiasData = new (std::nothrow) float[kernelNum];
  if (newBiasData == nullptr) {
    MS_LOG(ERROR) << "new newBiasData failed";
    return RET_ERROR;
  }
  // helper to free the scratch buffer on every exit path of this function
  auto freeBiasData = [this]() {
    delete[] this->newBiasData;
    this->newBiasData = nullptr;
  };

  if ((biasDims.empty() && biasTensor->data.data() != nullptr) || bias_const) {
    // scalar bias (no dims, or shape {1}): broadcast the single value to every
    // kernel. Note: the old code used memset_s with a float fill value here,
    // which fills bytes and cannot replicate an arbitrary float.
    auto *biasData = reinterpret_cast<float *>(biasTensor->data.data());
    for (int32_t i = 0; i < kernelNum; i++) {
      newBiasData[i] = *biasData;
    }
  } else {
    if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), biasTensor->data.data(), kernelNum * sizeof(float))) {
      MS_LOG(ERROR) << "memcpy_s newBiasData failed";
      freeBiasData();
      return RET_ERROR;
    }
  }
  if (oldBiasTensor != nullptr) {
    // conv already has a bias: accumulate the folded values into it in place
    auto oldBiasDims = oldBiasTensor->dims;
    if (oldBiasDims.size() != 1) {
      MS_LOG(ERROR) << "Conv bias tensor should has one dimension, current number of dimension " << oldBiasDims.size();
      freeBiasData();
      return RET_ERROR;
    }
    if (oldBiasDims.at(0) != kernelNum) {
      MS_LOG(ERROR) << "Size(" << oldBiasDims.at(0) << ") of Conv bias tensor should be equal to kernelNum("
                    << kernelNum << ")";
      freeBiasData();
      return RET_ERROR;
    }
    auto *oldBiasData = reinterpret_cast<float *>(oldBiasTensor->data.data());
    for (int32_t i = 0; i < kernelNum; i++) {
      oldBiasData[i] += newBiasData[i];
    }
  } else {
    // conv has no bias: build a fresh bias tensor mirroring the filter's
    // dataType/format and hand it to the caller via this->newBiasTensor
    auto *newCharBiasData = reinterpret_cast<uint8_t *>(newBiasData);
    std::vector<uint8_t> tmpBiasVec(newCharBiasData, newCharBiasData + kernelNum * sizeof(float) / sizeof(uint8_t));

    auto weightTensor = graph->allTensors.at(convWeightTensorIdxes[CONV_OP_FILTER_INDEX_IN_WEIGHT]).get();
    this->newBiasTensor = std::unique_ptr<TensorT>(new (std::nothrow) TensorT);
    if (this->newBiasTensor == nullptr) {
      MS_LOG(ERROR) << "new TensorT failed";
      freeBiasData();
      return RET_ERROR;
    }
    this->newBiasTensor->dims = {kernelNum};
    this->newBiasTensor->dataType = weightTensor->dataType;
    this->newBiasTensor->format = weightTensor->format;
    this->newBiasTensor->refCount = weightTensor->refCount;
    this->newBiasTensor->data.swap(tmpBiasVec);
  }

  freeBiasData();

  return RET_OK;
}

ConvBiasAddFusionPass::~ConvBiasAddFusionPass() {
  // newBiasData is allocated with new[], so it must be released with delete[]
  // (the old code used plain delete, which is undefined behavior on arrays).
  // delete[] on nullptr is a no-op, so no null check is needed.
  delete[] this->newBiasData;
  this->newBiasData = nullptr;
}

} // namespace lite
} // namespace mindspore

+ 0
- 51
mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h View File

@@ -1,51 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_PREDICT_CONV_BIASADD_FUSION_PASS_H
#define MINDSPORE_PREDICT_CONV_BIASADD_FUSION_PASS_H

#include <string>
#include <unordered_map>
#include <memory>
#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"

namespace mindspore {
namespace lite {
// Fusion pass that folds a BiasAdd/Add node following a
// Conv2D/DepthwiseConv2D/DeConv2D into the convolution's bias input.
class ConvBiasAddFusionPass : public FusionPass {
 public:
ConvBiasAddFusionPass() = default;

// releases newBiasData if a fusion attempt left it allocated
~ConvBiasAddFusionPass() override;

// registers the conv -> biasadd pattern to match
STATUS DefinePattern() override;

// folds one matched bias node into its conv and removes it from the graph
STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;

STATUS Run(schema::MetaGraphT *graph) override;

 protected:
// gen this->newBiasTensor if conv has no bias before
STATUS GenConvBiasTensor(std::shared_ptr<Path> convPath, std::shared_ptr<Path> dstPath, schema::MetaGraphT *graph);

 protected:
// scratch buffer holding the folded bias values (owned, allocated with new[])
float *newBiasData = nullptr;
// bias tensor handed to the conv when it had none before fusion
std::unique_ptr<TensorT> newBiasTensor = nullptr;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_PREDICT_CONV_BIASADD_FUSION_PASS_H

+ 0
- 224
mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.cc View File

@@ -1,224 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <string>
#include <unordered_map>
#include <memory>
#include <cmath>
#include "tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h"
#include "securec/include/securec.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace lite {
#define CAFFE_BATCHNORM_OP_WEIGHT_NUM 2
#define TF_BATCHNORM_OP_WEIGHT_NUM 4
#define CAFFE_BATCHNORM_MEAN_INDEX 0
#define CAFFE_BATCHNORM_VARIANCE_INDEX 1
#define TF_BATCHNORM_SCALE_INDEX 0
#define TF_BATCHNORM_BIAS_INDEX 1
#define TF_BATCHNORM_MEAN_INDEX 2
#define TF_BATCHNORM_VARIANCE_INDEX 3

constexpr const float EPS = 1e-8;
constexpr const float EPS_DEFAULT_FLOAT = 1e-5;
constexpr const float POW_NUM = 0.5;

// Pure delegation: the actual fusion (fold scale/bias into conv, drop the BN
// node) is implemented in the ConvScaleBiasFusionPass base class; this pass
// only customizes GetTransParam() to derive scale/bias from BN statistics.
STATUS ConvBNFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
return ConvScaleBiasFusionPass::DoFusion(graph, patternName, matchedPath);
}

// Registers the pattern  conv(Conv2D | DepthwiseConv2D) -> (FusedBatchNorm | BatchNorm)
// under the name "ConvBatchNormFusion".
STATUS ConvBNFusionPass::DefinePattern() {
  auto convPatternOp = std::make_shared<PatternOp>();
  convPatternOp->id = kConvName;
  convPatternOp->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D};

  auto batchNormPatternOp = std::make_shared<PatternOp>();
  batchNormPatternOp->id = DST_NAME;
  batchNormPatternOp->types = {schema::PrimitiveType_FusedBatchNorm, schema::PrimitiveType_BatchNorm};
  // the conv feeds the left input of the batch-norm
  batchNormPatternOp->left = convPatternOp;

  std::unique_ptr<FusionPattern> pattern(new (std::nothrow) FusionPattern("ConvBatchNormFusion"));
  if (pattern == nullptr) {
    MS_LOG(ERROR) << "new fusionPattern failed";
    return RET_ERROR;
  }
  pattern->AddPatternOp(convPatternOp);
  pattern->AddPatternOp(batchNormPatternOp);
  pattern->Finish();

  // this->patterns owns the raw pointer (project convention)
  this->patterns.emplace_back(pattern.release());

  return RET_OK;
}

STATUS ConvBNFusionPass::Run(schema::MetaGraphT *graph) { return ConvScaleBiasFusionPass::Run(graph); }

// Fills this->transScale / this->transBias (length kernelNum, pre-zeroed by
// the caller) from the BN node's statistics:
//   transScale = scale / sqrt(variance + eps)   (caffe: 1 / sqrt(variance + eps))
//   transBias  = -mean * transScale + bias      (caffe: -mean * transScale)
STATUS ConvBNFusionPass::GetTransParam(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, int32_t kernelNum) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(bnPath != nullptr);

  BNWeightTensors bnWeightTensors;

  auto status = GetBnWeightTensors(graph, bnPath, kernelNum, bnWeightTensors);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "GetBnWeightTensors error " << status;
    return status;
  }
  schema::TensorT *meanTensor = bnWeightTensors.meanTensor;
  schema::TensorT *varianceTensor = bnWeightTensors.varianceTensor;
  schema::TensorT *scaleTensor = bnWeightTensors.scaleTensor;
  schema::TensorT *biasTensor = bnWeightTensors.biasTensor;

  auto *meanData = reinterpret_cast<float *>(meanTensor->data.data());
  auto *varianceData = reinterpret_cast<float *>(varianceTensor->data.data());

  float eps = EPS_DEFAULT_FLOAT;
  status = GetBnEpsilon(graph, bnPath, eps);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "GetBnEpsilon failed " << status;
    return status;
  }

  // transScale = 1/sqrt(variance + eps), computed straight from varianceData
  // (the old code first memcpy'd variance into transScale, then used
  // pow(x, 0.5); std::sqrt is faster and more precise for this fixed exponent)
  for (int32_t i = 0; i < kernelNum; i++) {
    transScale[i] = 1.0f / std::sqrt(varianceData[i] + eps);
  }

  if (scaleTensor != nullptr) {
    auto *scaleData = reinterpret_cast<float *>(scaleTensor->data.data());
    // tf: scale/sqrt(variance + eps)
    for (int32_t i = 0; i < kernelNum; i++) {
      transScale[i] *= scaleData[i];
    }
  }

  // cal transBias, tf : -scale*mean/sqrt(variance + eps) + bias; caffe : -mean/sqrt(variance + eps)
  for (int32_t i = 0; i < kernelNum; i++) {
    transBias[i] = -meanData[i] * transScale[i];
  }

  if (biasTensor != nullptr) {
    auto *biasData = reinterpret_cast<float *>(biasTensor->data.data());
    for (int32_t i = 0; i < kernelNum; i++) {
      transBias[i] += biasData[i];
    }
  }

  return RET_OK;
}

// BatchNorm weight Tensor definition:
// caffe
//   estimated_mean     --0
//   estimated_variance --1
// tensorflow
//   scale              --0
//   bias               --1
//   estimated_mean     --2
//   estimated_variance --3
// Resolves the BN node's weight tensors into bnWeightTensors and validates
// that each present tensor holds exactly kernelNum floats.
STATUS ConvBNFusionPass::GetBnWeightTensors(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, int32_t kernelNum,
                                            BNWeightTensors &bnWeightTensors) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(bnPath != nullptr);
  auto bnNode = graph->nodes.at(bnPath->nodeIdx).get();
  auto bnWeightTensorIdxes = bnNode->inputIndex;
  // guard: erase(begin()) on an empty vector is undefined behavior
  if (bnWeightTensorIdxes.empty()) {
    MS_LOG(ERROR) << "BatchNorm node " << bnNode->name.c_str() << " has no input tensors";
    return RET_ERROR;
  }
  // drop the activation input; the remainder are the weight tensors
  bnWeightTensorIdxes.erase(bnWeightTensorIdxes.begin());
  if (bnWeightTensorIdxes.size() == CAFFE_BATCHNORM_OP_WEIGHT_NUM) {
    bnWeightTensors.meanTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_MEAN_INDEX]).get();
    bnWeightTensors.varianceTensor = graph->allTensors.at(bnWeightTensorIdxes[CAFFE_BATCHNORM_VARIANCE_INDEX]).get();
  } else if (bnWeightTensorIdxes.size() == TF_BATCHNORM_OP_WEIGHT_NUM) {
    bnWeightTensors.scaleTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_SCALE_INDEX]).get();
    bnWeightTensors.biasTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_BIAS_INDEX]).get();
    bnWeightTensors.meanTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_MEAN_INDEX]).get();
    bnWeightTensors.varianceTensor = graph->allTensors.at(bnWeightTensorIdxes[TF_BATCHNORM_VARIANCE_INDEX]).get();
  } else {
    MS_LOG(ERROR) << "BatchNorm should has " << CAFFE_BATCHNORM_OP_WEIGHT_NUM << " or " << TF_BATCHNORM_OP_WEIGHT_NUM
                  << " weight tensors, current number of weight tensors " << bnWeightTensorIdxes.size();
    return RET_ERROR;
  }

  if (bnWeightTensors.meanTensor == nullptr) {
    MS_LOG(ERROR) << "BatchNorm's mean tensor is nullptr";
    return RET_ERROR;
  }

  if (bnWeightTensors.varianceTensor == nullptr) {
    MS_LOG(ERROR) << "BatchNorm's variance tensor is nullptr";
    return RET_ERROR;
  }

  // fix: the old checks were copy-pasted so variance/scale/bias mismatches all
  // reported "mean size" with the mean tensor's size; report the right tensor
  auto checkSize = [kernelNum](const schema::TensorT *tensor, const char *tensorName) -> STATUS {
    if (tensor == nullptr) {
      return RET_OK;  // scale/bias are optional (caffe form)
    }
    auto elemCount = tensor->data.size() / sizeof(float);
    if (static_cast<size_t>(kernelNum) != elemCount) {
      MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to " << tensorName << " size("
                    << elemCount << ")";
      return RET_ERROR;
    }
    return RET_OK;
  };
  if (checkSize(bnWeightTensors.meanTensor, "mean") != RET_OK ||
      checkSize(bnWeightTensors.varianceTensor, "variance") != RET_OK ||
      checkSize(bnWeightTensors.scaleTensor, "scale") != RET_OK ||
      checkSize(bnWeightTensors.biasTensor, "bias") != RET_OK) {
    return RET_ERROR;
  }
  return RET_OK;
}

// Reads the epsilon attribute from the matched (Fused)BatchNorm node into eps,
// substituting the default when the stored value is vanishingly small.
STATUS ConvBNFusionPass::GetBnEpsilon(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, float &eps) {
  MS_ASSERT(graph != nullptr);
  auto bnNode = graph->nodes.at(bnPath->nodeIdx).get();
  MS_ASSERT(bnNode != nullptr);
  const auto bnType = bnNode->primitive->value.type;
  switch (bnType) {
    case schema::PrimitiveType_FusedBatchNorm:
      eps = bnNode->primitive->value.AsFusedBatchNorm()->epsilon;
      break;
    case schema::PrimitiveType_BatchNorm:
      eps = bnNode->primitive->value.AsBatchNorm()->epsilon;
      break;
    default:
      MS_LOG(ERROR) << "match pattern has error, " << bnNode->name.c_str() << " not BatchNorm node";
      return RET_ERROR;
  }

  // fall back to the default epsilon if the stored one is effectively zero
  if (eps < EPS) {
    eps = EPS_DEFAULT_FLOAT;
  }
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 0
- 54
mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h View File

@@ -1,54 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CONV_BN_FUSION_PASS_H
#define MINDSPORE_CONV_BN_FUSION_PASS_H

// fix: the std includes previously sat *before* the include guard, and the
// header erroneously included itself; both corrected here.
#include <string>
#include <unordered_map>
#include <memory>
#include "tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h"

namespace mindspore {
namespace lite {
// Fuses Conv2D/DepthwiseConv2D followed by (Fused)BatchNorm by folding the BN
// statistics into the conv weights/bias; the generic folding lives in
// ConvScaleBiasFusionPass, this class only derives the per-channel transform.
class ConvBNFusionPass : public ConvScaleBiasFusionPass {
 public:
  ConvBNFusionPass() = default;

  ~ConvBNFusionPass() override = default;

  STATUS DefinePattern() override;

  STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;

  STATUS Run(schema::MetaGraphT *graph) override;

 protected:
  // computes transScale/transBias from the BN statistics
  STATUS GetTransParam(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, int32_t kernelNum) override;

  // Get and check BNNode weight tensor
  STATUS GetBnWeightTensors(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, int32_t kernelNum,
                            BNWeightTensors &bnWeightTensors);

  STATUS GetBnEpsilon(schema::MetaGraphT *graph, std::shared_ptr<Path> bnPath, float &eps);
};
}  // namespace lite
}  // namespace mindspore

#endif  // MINDSPORE_CONV_BN_FUSION_PASS_H


+ 0
- 41
mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.cc View File

@@ -1,41 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <memory>
#include <string>
#include <unordered_map>
#include "tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace lite {
STATUS ConvRelu6FusionPass::DefinePattern() { return ConvActivationFusionPass::DefinePattern(); }

// Tells the base ConvActivationFusionPass which activation this pass fuses.
// The enum is qualified with schema:: for consistency with the sibling
// ConvReluFusionPass::SetActivationType.
STATUS ConvRelu6FusionPass::SetActivationType() {
  this->activationType = schema::ActivationType_RELU6;
  return RET_OK;
}

// Pure delegation: fusion logic lives in ConvActivationFusionPass; this pass
// only selects RELU6 via SetActivationType().
STATUS ConvRelu6FusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName,
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
return ConvActivationFusionPass::DoFusion(graph, patternName, matchedPath);
}

STATUS ConvRelu6FusionPass::Run(MetaGraphT *graph) { return ConvActivationFusionPass::Run(graph); }

} // namespace lite
} // namespace mindspore

+ 0
- 46
mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu6_fusion_pass.h View File

@@ -1,46 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_PREDICT_CONV_RELU6_FUSION_PASS_H
#define MINDSPORE_PREDICT_CONV_RELU6_FUSION_PASS_H

#include "tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h"
#include <memory>
#include <string>
#include <unordered_map>

namespace mindspore {
namespace lite {
// Fuses a ReLU6 activation that follows a convolution into the conv node.
// All logic is inherited from ConvActivationFusionPass; this subclass only
// selects the RELU6 activation type.
class ConvRelu6FusionPass : public ConvActivationFusionPass {
 public:
ConvRelu6FusionPass() = default;

~ConvRelu6FusionPass() override = default;

STATUS DefinePattern() override;

// sets this->activationType to RELU6
STATUS SetActivationType() override;

STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;

STATUS Run(schema::MetaGraphT *graph) override;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_PREDICT_CONV_RELU6_FUSION_PASS_H


+ 0
- 40
mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.cc View File

@@ -1,40 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <memory>
#include <string>
#include <unordered_map>
#include "tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace lite {
// Pure delegation: fusion logic lives in ConvActivationFusionPass; this pass
// only selects RELU via SetActivationType().
STATUS ConvReluFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
return ConvActivationFusionPass::DoFusion(graph, patternName, matchedPath);
}

STATUS ConvReluFusionPass::Run(schema::MetaGraphT *graph) { return ConvActivationFusionPass::Run(graph); }

// Selects RELU as the activation folded into the convolution by the base pass.
STATUS ConvReluFusionPass::SetActivationType() {
  activationType = schema::ActivationType_RELU;
  return RET_OK;
}

STATUS ConvReluFusionPass::DefinePattern() { return ConvActivationFusionPass::DefinePattern(); }
} // namespace lite
} // namespace mindspore

+ 0
- 45
mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h View File

@@ -1,45 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_PREDICT_CONV_RELU_FUSION_PASS_H
#define MINDSPORE_PREDICT_CONV_RELU_FUSION_PASS_H

#include "tools/converter/legacy_optimizer/fusion/conv_activation_fusion_pass.h"
#include <memory>
#include <string>
#include <unordered_map>

namespace mindspore {
namespace lite {
// Fuses a ReLU activation that follows a convolution into the conv node.
// All logic is inherited from ConvActivationFusionPass; this subclass only
// selects the RELU activation type.
class ConvReluFusionPass : public ConvActivationFusionPass {
 public:
ConvReluFusionPass() = default;

~ConvReluFusionPass() override = default;

STATUS DefinePattern() override;

// sets this->activationType to RELU
STATUS SetActivationType() override;

STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;

STATUS Run(schema::MetaGraphT *graph) override;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_PREDICT_CONV_RELU_FUSION_PASS_H

+ 0
- 361
mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.cc View File

@@ -1,361 +0,0 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2018-2019. All rights reserved.
* Description: mslite
* Author: mslite
* Create: 2019-12-13
*/

#include <cfloat>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include <memory>
#include "tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h"
#include "securec/include/securec.h"
#include "utils/log_adapter.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"
#include "src/common/op_utils.h"
#include "tools/common/graph_util.h"
#include "tools/common/tensor_util.h"

namespace mindspore {
namespace lite {

#define CONV_SCALE_BIAS_MATCH_PATH_LEN 2

// 1. generate biasTensor according to BN weightTensor
// 2. change attr of conv
// 3. delete BN node
STATUS ConvScaleBiasFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
                                         std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
  MS_ASSERT(graph != nullptr);
  if (matchedPath.size() != CONV_SCALE_BIAS_MATCH_PATH_LEN) {
    MS_LOG(ERROR) << "Conv-Scale-Bias-Fusion should have two NodeIndex in matchedPair";
    return RET_PARAM_INVALID;
  }

  auto convPath = matchedPath[kConvName];
  MS_ASSERT(convPath != nullptr);
  auto dstPath = matchedPath[DST_NAME];
  MS_ASSERT(dstPath != nullptr);
  auto &convNode = graph->nodes.at(convPath->nodeIdx);
  MS_ASSERT(convNode != nullptr);
  auto &dstNode = graph->nodes.at(dstPath->nodeIdx);
  MS_ASSERT(dstNode != nullptr);

  // validate the conv input count up front, before touching the graph
  const auto convInputNum = convNode->inputIndex.size();
  if (convInputNum != CONV_OP_HAS_BIAS_INPUT_NUM && convInputNum != CONV_OP_NO_BIAS_INPUT_NUM) {
    MS_LOG(ERROR) << "Conv node should has 2 or 3 weight tensors rather than " << convInputNum;
    return RET_ERROR;
  }

  // 1. generate new weightTensor and biasTensor for conv
  auto status = GenConvWeightTensors(graph, convPath, dstPath);
  if (RET_OK != status) {
    MS_LOG(ERROR) << "GenConvWeightTensors failed, " << status;
    return status;
  }
  // the new filter replaces the old one regardless of whether a bias exists
  // (this replacement was previously duplicated in both branches)
  status = ReplaceTensorOfNode(graph, convPath->nodeIdx, convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT),
                               std::move(this->newWeightTensor));
  this->newWeightTensor = nullptr;
  if (status != RET_OK) {
    MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx
                  << ", node: " << convPath->nodeIdx << ", tensor "
                  << convNode->inputIndex.at(CONV_OP_FILTER_INDEX_IN_INPUT) << ", error: " << status;
    return status;
  }
  if (convInputNum == CONV_OP_HAS_BIAS_INPUT_NUM) {
    // conv already has a bias: replace it in place
    status = ReplaceTensorOfNode(graph, convPath->nodeIdx, convNode->inputIndex.at(CONV_OP_BIAS_INDEX_IN_INPUT),
                                 std::move(this->newBiasTensor));
    this->newBiasTensor = nullptr;
    if (status != RET_OK) {
      // fix: old message printed the FILTER tensor index while replacing the BIAS tensor
      MS_LOG(ERROR) << "ReplaceTensorOfNode failed, subGraph: " << convPath->subGraphIdx
                    << ", node: " << convPath->nodeIdx << ", tensor "
                    << convNode->inputIndex.at(CONV_OP_BIAS_INDEX_IN_INPUT) << ", error: " << status;
      return status;
    }
  } else {
    // conv has no bias yet: append the new bias tensor
    status = AddTensor2Node(graph, convPath->nodeIdx, std::move(this->newBiasTensor));
    this->newBiasTensor = nullptr;
    if (status != RET_OK) {
      // fix: old message said "ReplaceTensorOfNode failed" for an AddTensor2Node call
      MS_LOG(ERROR) << "AddTensor2Node failed, subGraph: " << convPath->subGraphIdx
                    << ", node: " << convPath->nodeIdx << ", error: " << status;
      return status;
    }
    // todo add quantParam for the new bias tensor
  }

  // 2. change attr of conv
  if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) {
    convNode->primitive->value.AsConv2D()->hasBias = true;
  } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
    convNode->primitive->value.AsDepthwiseConv2D()->hasBias = true;
  } else {
    MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type;
    return RET_ERROR;
  }

  // 3. delete DST node
  MergeNodeAttrFromPost(convNode, dstNode);
  status = IsolateOneWayNode(graph, dstPath->nodeIdx);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "IsolateOneWayNode failed, node: " << dstPath->nodeIdx << ", error: " << status;
    return status;
  }

  return RET_OK;
}

// Allocates and fills this->transScale / this->transBias (length = conv kernel
// count) via the subclass's GetTransParam(), then folds them into new conv
// weight/bias tensors via CalConvWeightTensors().
STATUS ConvScaleBiasFusionPass::GenConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr<Path> &convPath,
                                                     std::shared_ptr<Path> dstPath) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(convPath != nullptr);
  MS_ASSERT(dstPath != nullptr);
  auto &convNode = graph->nodes.at(convPath->nodeIdx);
  MS_ASSERT(convNode != nullptr);
  int32_t kernelNum = -1;
  if (convNode->primitive->value.type == schema::PrimitiveType_Conv2D) {
    kernelNum = convNode->primitive->value.AsConv2D()->channelOut;
  } else if (convNode->primitive->value.type == schema::PrimitiveType_DepthwiseConv2D) {
    kernelNum = convNode->primitive->value.AsDepthwiseConv2D()->channelMultiplier *
                convNode->primitive->value.AsDepthwiseConv2D()->channelIn;
  } else {
    MS_LOG(ERROR) << "Unsupported opType, " << convNode->primitive->value.type;
    return RET_ERROR;
  }
  if (kernelNum <= 0) {
    MS_LOG(ERROR) << "KernelNum should be positive, " << kernelNum;
    return RET_ERROR;
  }

  // release buffers possibly left over from a previous failed fusion attempt,
  // then allocate zero-filled arrays; the trailing () value-initializes the
  // elements to 0.0f, replacing the two memset_s calls of the old code
  delete[] this->transScale;
  delete[] this->transBias;
  this->transScale = new (std::nothrow) float[kernelNum]();
  this->transBias = new (std::nothrow) float[kernelNum]();

  if (transScale == nullptr) {
    MS_LOG(ERROR) << "new transScale failed";
    return RET_ERROR;
  }

  if (transBias == nullptr) {
    MS_LOG(ERROR) << "new transBias failed";
    return RET_ERROR;
  }

  auto status = GetTransParam(graph, dstPath, kernelNum);
  if (RET_OK != status) {
    MS_LOG(ERROR) << "GetTransParam failed, " << status;
    return status;
  }

  status = CalConvWeightTensors(graph, convPath, kernelNum);
  if (RET_OK != status) {
    MS_LOG(ERROR) << "GenConvWeightTensors failed, " << status;
    return status;
  }
  return RET_OK;
}

// Computes the fused filter: newWeight[i][j] = oldWeight[i][j] * transScale[i],
// where i indexes output channels (kernelNum) and j the per-kernel elements
// (kernelSize). On success fills this->newWeightTensor with the scaled data.
STATUS ConvScaleBiasFusionPass::CalNewWeightTensor(TensorT *oldWeightTensor, const int32_t kernelNum,
                                                   const size_t kernelSize) {
  MS_ASSERT(oldWeightTensor != nullptr);
  auto weightData = reinterpret_cast<float *>(oldWeightTensor->data.data());
  size_t kernelDataCount = kernelNum * kernelSize;
  if (kernelDataCount == 0) {
    MS_LOG(ERROR) << "KernelDataCount should be positive, " << kernelDataCount;
    return RET_ERROR;
  }
  // Guard against reading past the old tensor's data buffer.
  if (oldWeightTensor->data.size() < kernelDataCount * sizeof(float)) {
    MS_LOG(ERROR) << "Weight tensor data size " << oldWeightTensor->data.size() << " is smaller than expected "
                  << kernelDataCount * sizeof(float);
    return RET_ERROR;
  }
  this->newWeightData = new (std::nothrow) float[kernelDataCount];
  if (newWeightData == nullptr) {
    MS_LOG(ERROR) << "new newWeightData failed";
    return RET_ERROR;
  }

  if (0 != memset_s(newWeightData, kernelDataCount * sizeof(float), 0, kernelDataCount * sizeof(float))) {
    MS_LOG(ERROR) << "memset newWeightData failed";
    return RET_ERROR;
  }

  // Scale every weight element by its channel's transScale.
  // (size_t loop bound to avoid a signed/unsigned comparison.)
  for (size_t i = 0; i < static_cast<size_t>(kernelNum); i++) {
    for (size_t j = 0; j < kernelSize; j++) {
      newWeightData[i * kernelSize + j] = weightData[i * kernelSize + j] * transScale[i];
    }
  }
  auto newCharWeightData = reinterpret_cast<uint8_t *>(newWeightData);
  std::vector<uint8_t> tmpWeightVec(newCharWeightData,
                                    newCharWeightData + kernelDataCount * sizeof(float) / sizeof(uint8_t));

  this->newWeightTensor = std::unique_ptr<TensorT>(new (std::nothrow) TensorT);
  if (this->newWeightTensor == nullptr) {
    MS_LOG(ERROR) << "new newWeightTensor failed";
    return RET_ERROR;
  }
  // The new tensor keeps the old filter's shape/type/format; only data changes.
  this->newWeightTensor->dims.insert(this->newWeightTensor->dims.begin(), oldWeightTensor->dims.begin(),
                                     oldWeightTensor->dims.end());
  this->newWeightTensor->dataType = oldWeightTensor->dataType;
  this->newWeightTensor->format = oldWeightTensor->format;
  this->newWeightTensor->refCount = oldWeightTensor->refCount;
  this->newWeightTensor->data.swap(tmpWeightVec);
  // Allocated with new[]: plain delete was undefined behaviour.
  delete[] this->newWeightData;
  newWeightData = nullptr;

  return RET_OK;
}

// Computes the fused bias: newBias[i] = oldBias[i] * transScale[i] + transBias[i]
// when the conv already had a bias, or simply transBias when it had none.
// On success fills this->newBiasTensor (shape {kernelNum}).
STATUS ConvScaleBiasFusionPass::CalNewBiasTensor(TensorT *oldWeightTensor, TensorT *oldBiasTensor,
                                                 const int32_t kernelNum) {
  MS_ASSERT(oldWeightTensor != nullptr);
  this->newBiasData = new (std::nothrow) float[kernelNum];
  if (newBiasData == nullptr) {
    MS_LOG(ERROR) << "new newBiasData failed";
    return RET_ERROR;
  }
  if (0 != memset_s(newBiasData, kernelNum * sizeof(float), 0, kernelNum * sizeof(float))) {
    MS_LOG(ERROR) << "memset newBiasData failed";
    return RET_ERROR;
  }

  if (oldBiasTensor != nullptr) {
    // Guard against reading past an undersized bias tensor.
    if (oldBiasTensor->data.size() < static_cast<size_t>(kernelNum) * sizeof(float)) {
      MS_LOG(ERROR) << "Bias tensor data size " << oldBiasTensor->data.size() << " is smaller than expected "
                    << static_cast<size_t>(kernelNum) * sizeof(float);
      return RET_ERROR;
    }
    auto *biasData = reinterpret_cast<float *>(oldBiasTensor->data.data());

    // (size_t loop bound to avoid a signed/unsigned comparison.)
    for (size_t i = 0; i < static_cast<size_t>(kernelNum); i++) {
      this->newBiasData[i] = biasData[i] * transScale[i] + transBias[i];
    }
  } else {
    if (0 != memcpy_s(newBiasData, kernelNum * sizeof(float), transBias, kernelNum * sizeof(float))) {
      MS_LOG(ERROR) << "memcpy_s newBiasData failed";
      return RET_ERROR;
    }
  }
  auto *newCharBiasData = reinterpret_cast<uint8_t *>(newBiasData);
  std::vector<uint8_t> tmpBiasVec(newCharBiasData, newCharBiasData + kernelNum * sizeof(float) / sizeof(uint8_t));

  this->newBiasTensor = std::unique_ptr<TensorT>(new (std::nothrow) TensorT);
  if (this->newBiasTensor == nullptr) {
    MS_LOG(ERROR) << "new newBiasTensor failed";
    return RET_ERROR;
  }
  // todo biasShape
  this->newBiasTensor->dims = {kernelNum};
  this->newBiasTensor->dataType = oldWeightTensor->dataType;
  this->newBiasTensor->format = oldWeightTensor->format;
  this->newBiasTensor->refCount = oldWeightTensor->refCount;
  this->newBiasTensor->data.swap(tmpBiasVec);
  // Allocated with new[]: plain delete was undefined behaviour.
  delete[] this->newBiasData;
  newCharBiasData = nullptr;
  newBiasData = nullptr;
  return RET_OK;
}

// Applies the computed per-channel transform to the conv's filter (and
// optional bias): fills this->newWeightTensor and this->newBiasTensor.
STATUS ConvScaleBiasFusionPass::CalConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr<Path> &convPath,
                                                     int32_t kernelNum) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(convPath != nullptr);

  auto *convNode = graph->nodes.at(convPath->nodeIdx).get();
  MS_ASSERT(convNode != nullptr);
  // Drop the activation input; the remaining inputs are the weight tensors.
  auto weightIdxes = convNode->inputIndex;
  weightIdxes.erase(weightIdxes.begin());

  const size_t weightCount = weightIdxes.size();
  if (weightCount != CONV_OP_NO_BIAS_WEIGHT_NUM && weightCount != CONV_OP_HAS_BIAS_WEIGHT_NUM) {
    MS_LOG(ERROR) << "Conv2D should has " << CONV_OP_NO_BIAS_WEIGHT_NUM << " or " << CONV_OP_HAS_BIAS_WEIGHT_NUM
                  << " weight tensors, current number of weight tensors " << weightIdxes.size();
    return RET_ERROR;
  }
  TensorT *weightTensor = graph->allTensors.at(weightIdxes[CONV_OP_FILTER_INDEX_IN_WEIGHT]).get();
  TensorT *biasTensor =
    (weightCount == CONV_OP_HAS_BIAS_WEIGHT_NUM) ? graph->allTensors.at(weightIdxes[CONV_OP_BIAS_INDEX_IN_WEIGHT]).get()
                                                 : nullptr;
  if (weightTensor == nullptr) {
    MS_LOG(ERROR) << "Conv2D's weight tensor is nullptr";
    return RET_ERROR;
  }

  const auto &filterShape = weightTensor->dims;
  if (filterShape.size() != CONV_FILTER_SHAPE_SIZE) {
    MS_LOG(ERROR) << "Size of dims of weight tensor should be " << CONV_FILTER_SHAPE_SIZE << " rather than "
                  << filterShape.size();
    return RET_ERROR;
  }
  const size_t kernelSize = GetShapeSize(*weightTensor) / kernelNum;

  // Fold transScale into the filter data.
  auto status = CalNewWeightTensor(weightTensor, kernelNum, kernelSize);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "CalNewWeightTensor error " << status;
    return status;
  }
  // Fold transScale/transBias into the bias data.
  status = CalNewBiasTensor(weightTensor, biasTensor, kernelNum);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "CalNewBiasTensor error " << status;
    return status;
  }
  return RET_OK;
}

// Runs the generic pattern-matching fusion driver; pattern and fusion
// specifics come from the DefinePattern()/DoFusion() overrides.
STATUS ConvScaleBiasFusionPass::Run(schema::MetaGraphT *graph) {
  return FusionPass::Run(graph);
}

// Releases the transform/working buffers.
// All four buffers are allocated with new[], so they must be released with
// delete[]; the original scalar `delete` was undefined behaviour. delete[] on
// nullptr is a no-op, so the null checks were redundant as well.
ConvScaleBiasFusionPass::~ConvScaleBiasFusionPass() {
  delete[] this->transScale;
  delete[] this->transBias;
  delete[] this->newWeightData;
  delete[] this->newBiasData;
}

} // namespace lite
} // namespace mindspore

+ 0
- 67
mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h View File

@@ -1,67 +0,0 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2018-2019. All rights reserved.
* Description: mslite
* Author: mslite
* Create: 2019-12-13
*/

#ifndef MINDSPORE_PREDICT_CONV_SCALE_BIAS_FUSION_PASS_H
#define MINDSPORE_PREDICT_CONV_SCALE_BIAS_FUSION_PASS_H

#include <string>
#include <unordered_map>
#include <memory>
#include "tools/converter/legacy_optimizer/fusion/fusion_pass.h"

namespace mindspore {
namespace lite {
// Weight tensors of a BatchNorm node, gathered for fusion.
// NOTE(review): used by the BN fusion subclass; looks misplaced in the
// scale/bias base header — confirm before relocating.
struct BNWeightTensors {
  schema::TensorT *meanTensor = nullptr;
  schema::TensorT *varianceTensor = nullptr;
  schema::TensorT *scaleTensor = nullptr;
  schema::TensorT *biasTensor = nullptr;
};

// Base pass that folds a per-output-channel affine transform (from a Scale or
// BatchNorm node following a conv) into the conv's weight and bias tensors.
class ConvScaleBiasFusionPass : public FusionPass {
 public:
  ConvScaleBiasFusionPass() = default;

  ~ConvScaleBiasFusionPass() override;

  // Subclasses describe the conv + scale/bias pattern to match.
  STATUS DefinePattern() override = 0;

  // 1. generate biasTensor according to BN weightTensor
  // 2. change attr of conv
  // 3. delete BN node
  STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;

  STATUS Run(schema::MetaGraphT *graph) override;

 protected:
  // call GetTransParam() and CalConvWeightTensors()
  STATUS GenConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr<Path> &convPath,
                              std::shared_ptr<Path> dstPath);

  // fill this->transScale and this->transBias
  virtual STATUS GetTransParam(schema::MetaGraphT *graph, std::shared_ptr<Path> dstPath, int32_t kernelNum) = 0;

  // fill this->newWeightTensor and this->newBiasTensor according to this->transScale and this->transBias
  STATUS CalConvWeightTensors(schema::MetaGraphT *graph, const std::shared_ptr<Path> &convPath, int32_t kernelNum);

  // Scales the filter data by transScale into newWeightTensor.
  STATUS CalNewWeightTensor(schema::TensorT *oldWeightTensor, int32_t kernelNum, size_t kernelSize);

  // Builds newBiasTensor as oldBias * transScale + transBias.
  STATUS CalNewBiasTensor(schema::TensorT *oldWeightTensor, schema::TensorT *oldBiasTensor, int32_t kernelNum);

 protected:
  // Per-output-channel transform: new = old * transScale + transBias.
  // These raw buffers are new[]-allocated and released in the destructor.
  float *transScale = nullptr;
  float *transBias = nullptr;
  float *newWeightData = nullptr;
  float *newBiasData = nullptr;
  std::unique_ptr<schema::TensorT> newWeightTensor = nullptr;
  std::unique_ptr<schema::TensorT> newBiasTensor = nullptr;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_PREDICT_CONV_SCALE_BIAS_FUSION_PASS_H

+ 0
- 126
mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.cc View File

@@ -1,126 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <string>
#include <unordered_map>
#include <memory>
#include "tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h"
#include "securec/include/securec.h"
#include "utils/log_adapter.h"
#include "include/errorcode.h"
#include "schema/inner/model_generated.h"

namespace mindspore {
namespace lite {
#define SCALE_OP_NO_BIAS_WEIGHT_NUM 1
#define SCALE_OP_HAS_BIAS_WEIGHT_NUM 2

#define SCALE_OP_SCALE_INDEX_IN_WEIGHT 0
#define SCALE_OP_BIAS_INDEX_IN_WEIGHT 1

// Registers the pattern "Conv2D/DepthwiseConv2D followed by Scale".
STATUS ConvScaleFusionPass::DefinePattern() {
  auto convPattern = std::make_shared<PatternOp>();
  convPattern->id = kConvName;
  convPattern->types = {schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D};

  auto scalePattern = std::make_shared<PatternOp>();
  scalePattern->id = DST_NAME;
  scalePattern->types = {schema::PrimitiveType_Scale};
  scalePattern->left = convPattern;

  std::unique_ptr<FusionPattern> pattern(new (std::nothrow) FusionPattern("ConvScaleFusion"));
  if (pattern == nullptr) {
    MS_LOG(ERROR) << "new fusionPattern failed";
    return RET_ERROR;
  }
  pattern->AddPatternOp(convPattern);
  pattern->AddPatternOp(scalePattern);
  pattern->Finish();

  this->patterns.emplace_back(pattern.release());

  return RET_OK;
}

// The scale-specific pass reuses the generic conv + scale/bias fusion logic.
STATUS ConvScaleFusionPass::DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
                                     std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
  return ConvScaleBiasFusionPass::DoFusion(graph, patternName, matchedPath);
}

// Delegates to the shared scale/bias fusion driver.
STATUS ConvScaleFusionPass::Run(schema::MetaGraphT *graph) {
  return ConvScaleBiasFusionPass::Run(graph);
}

// Fills transScale/transBias from the matched Scale node's weight tensors.
// The Scale node carries a per-channel scale tensor and an optional bias
// tensor, both expected to hold exactly kernelNum floats.
STATUS ConvScaleFusionPass::GetTransParam(schema::MetaGraphT *graph, std::shared_ptr<Path> scalePath,
                                          int32_t kernelNum) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(scalePath != nullptr);

  auto scaleNode = graph->nodes.at(scalePath->nodeIdx).get();
  MS_ASSERT(scaleNode != nullptr);
  // Drop the activation input; the remaining inputs are the weight tensors.
  auto scaleWeightTensorIdxes = scaleNode->inputIndex;
  scaleWeightTensorIdxes.erase(scaleWeightTensorIdxes.begin());

  schema::TensorT *scaleTensor = nullptr;
  schema::TensorT *biasTensor = nullptr;

  if (scaleWeightTensorIdxes.size() == SCALE_OP_NO_BIAS_WEIGHT_NUM) {
    scaleTensor = graph->allTensors.at(scaleWeightTensorIdxes[SCALE_OP_SCALE_INDEX_IN_WEIGHT]).get();
  } else if (scaleWeightTensorIdxes.size() == SCALE_OP_HAS_BIAS_WEIGHT_NUM) {
    scaleTensor = graph->allTensors.at(scaleWeightTensorIdxes[SCALE_OP_SCALE_INDEX_IN_WEIGHT]).get();
    biasTensor = graph->allTensors.at(scaleWeightTensorIdxes[SCALE_OP_BIAS_INDEX_IN_WEIGHT]).get();
  } else {
    // Stream the actual values: the old message still contained raw
    // printf-style placeholders ("%d", "%zu") with the arguments commented
    // out, left over from a logging-macro migration.
    MS_LOG(ERROR) << "Scale should has " << SCALE_OP_NO_BIAS_WEIGHT_NUM << " or " << SCALE_OP_HAS_BIAS_WEIGHT_NUM
                  << " weight tensors, current number of weight tensors " << scaleWeightTensorIdxes.size();
    return RET_ERROR;
  }

  if (scaleTensor == nullptr) {
    MS_LOG(ERROR) << "Scale's scale tensor is nullptr";
    return RET_ERROR;
  }

  const size_t scaleSize = scaleTensor->data.size() / sizeof(float);
  if (static_cast<size_t>(kernelNum) != scaleSize) {
    MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to scale size(" << scaleSize << ")";
    return RET_ERROR;
  }

  const float *scaleData = reinterpret_cast<float *>(scaleTensor->data.data());

  if (0 != memcpy_s(transScale, kernelNum * sizeof(float), scaleData, kernelNum * sizeof(float))) {
    MS_LOG(ERROR) << "memcpy_s transScale failed";
    return RET_ERROR;
  }

  if (biasTensor != nullptr) {
    const size_t biasSize = biasTensor->data.size() / sizeof(float);
    if (static_cast<size_t>(kernelNum) != biasSize) {
      MS_LOG(ERROR) << "conv kernel num " << kernelNum << " is expected to be equal to bias size(" << biasSize << ")";
      return RET_ERROR;
    }

    const float *biasData = reinterpret_cast<float *>(biasTensor->data.data());

    if (0 != memcpy_s(transBias, kernelNum * sizeof(float), biasData, kernelNum * sizeof(float))) {
      MS_LOG(ERROR) << "memcpy_s transBias failed";
      return RET_ERROR;
    }
  }

  return RET_OK;
}
} // namespace lite
} // namespace mindspore



+ 0
- 46
mindspore/lite/tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h View File

@@ -1,46 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_PREDICT_CONV_SCALE_FUSION_PASS_H
#define MINDSPORE_PREDICT_CONV_SCALE_FUSION_PASS_H

#include "tools/converter/legacy_optimizer/fusion/conv_scale_bias_fusion_pass.h"
#include <string>
#include <unordered_map>
#include <memory>

namespace mindspore {
namespace lite {
// Fuses a Scale node that follows a Conv2D/DepthwiseConv2D into the conv's
// weight and bias tensors.
class ConvScaleFusionPass : public ConvScaleBiasFusionPass {
 public:
  ConvScaleFusionPass() = default;

  ~ConvScaleFusionPass() override = default;

  // Matches conv (Conv2D/DepthwiseConv2D) followed by Scale.
  STATUS DefinePattern() override;

  // Delegates to the base conv + scale/bias fusion.
  STATUS DoFusion(schema::MetaGraphT *graph, const std::string &patternName,
                  std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) override;

  STATUS Run(schema::MetaGraphT *graph) override;

 private:
  // Copies the Scale node's scale/bias weights into transScale/transBias.
  STATUS GetTransParam(schema::MetaGraphT *graph, std::shared_ptr<Path> scalePath, int32_t kernelNum) override;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_PREDICT_CONV_SCALE_FUSION_PASS_H

+ 2
- 0
mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt View File

@@ -4,6 +4,8 @@ add_library(graph_pass_mid OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/weight_format_hardcode_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/weight_format_transform_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/unused_node_remove_pass.cc
)

+ 99
- 112
mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.cc View File

@@ -25,147 +25,134 @@

namespace mindspore {
namespace lite {

STATUS EltwiseFormatTransPass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
if (node->primitive->value.type != PrimitiveType_Eltwise) {
continue;
}
auto node_name = node->name;
auto input_node_indexes = GetInputNodeIdx(*graph, *node);
auto pre_type = schema::PrimitiveType_NONE;
size_t has_trans_count = 0;
auto can_fusion = true;
for (auto input_node_index : input_node_indexes) {
MS_ASSERT(graph->nodes.size() > input_node_index);
auto &pre_node = graph->nodes.at(input_node_index);
MS_ASSERT(pre_node != nullptr);
if (pre_type == schema::PrimitiveType_NONE) {
if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
pre_type = pre_node->primitive->value.type;
bool EltwiseFormatTransPass::CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node) {
auto input_node_indexes = GetInputNodeIdx(*graph, *node);
pre_type_ = schema::PrimitiveType_NONE;
size_t has_trans_count = 0;
auto can_fusion = true;
for (auto input_node_index : input_node_indexes) {
MS_ASSERT(graph->nodes.size() > input_node_index);
auto &pre_node = graph->nodes.at(input_node_index);
MS_ASSERT(pre_node != nullptr);
if (pre_type_ == schema::PrimitiveType_NONE) {
if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
pre_type_ = pre_node->primitive->value.type;
has_trans_count++;
}
} else {
if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
if (pre_type_ != pre_node->primitive->value.type) {
can_fusion = false;
break;
} else {
has_trans_count++;
}
} else {
if (pre_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
pre_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
if (pre_type != pre_node->primitive->value.type) {
can_fusion = false;
break;
} else {
has_trans_count++;
}
}
}
}
if (!can_fusion) {
continue;
}
auto output_node_indexes = GetOutputNodeIdx(*graph, *node);
auto post_type = schema::PrimitiveType_NONE;
for (auto output_node_index : output_node_indexes) {
MS_ASSERT(graph->nodes.size() > output_node_index);
auto &post_node = graph->nodes.at(output_node_index);
MS_ASSERT(post_node != nullptr);
if (post_type == schema::PrimitiveType_NONE) {
if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
post_type = post_node->primitive->value.type;
}
if (!can_fusion) {
return false;
}
auto output_node_indexes = GetOutputNodeIdx(*graph, *node);
post_type_ = schema::PrimitiveType_NONE;
for (auto output_node_index : output_node_indexes) {
MS_ASSERT(graph->nodes.size() > output_node_index);
auto &post_node = graph->nodes.at(output_node_index);
MS_ASSERT(post_node != nullptr);
if (post_type_ == schema::PrimitiveType_NONE) {
if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
post_type_ = post_node->primitive->value.type;
has_trans_count++;
}
} else {
if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
if (post_type_ != post_node->primitive->value.type) {
can_fusion = false;
break;
} else {
has_trans_count++;
}
} else {
if (post_node->primitive->value.type == schema::PrimitiveType_Nchw2Nhwc ||
post_node->primitive->value.type == schema::PrimitiveType_Nhwc2Nchw) {
if (post_type != post_node->primitive->value.type) {
can_fusion = false;
break;
} else {
has_trans_count++;
}
}
}
}
if (!can_fusion) {
continue;
}
auto total_node_count = input_node_indexes.size() + output_node_indexes.size();
size_t half_count = total_node_count / 2;
if (total_node_count % 2 == 0) {
can_fusion = has_trans_count > half_count;
} else {
can_fusion = has_trans_count >= half_count;
}
if (!can_fusion) {
return false;
}
if (pre_type_ == PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) {
return false;
}
auto total_node_count = input_node_indexes.size() + output_node_indexes.size();
size_t half_count = total_node_count / 2;
if (total_node_count % 2 == 0) {
can_fusion = has_trans_count > half_count;
} else {
can_fusion = has_trans_count >= half_count;
}
return can_fusion;
}

STATUS EltwiseFormatTransPass::FindOutTransType() {
pre_insert_trans_type_ = kNHWC2NCHW;
post_insert_trans_type_ = kNHWC2NCHW;
if (pre_type_ == PrimitiveType_NONE && post_type_ != PrimitiveType_NONE) {
pre_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC;
post_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
} else if (pre_type_ != PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) {
pre_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
post_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC;
} else if (pre_type_ == PrimitiveType_NONE && post_type_ == PrimitiveType_NONE) {
MS_ASSERT(false);
} else {
if (pre_type_ == post_type_) {
MS_LOG(ERROR) << "Unknow error";
return RET_ERROR;
}
if (!can_fusion) {
pre_insert_trans_type_ = pre_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
post_insert_trans_type_ = post_type_ == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
}
return RET_OK;
}

STATUS EltwiseFormatTransPass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
if (node->primitive->value.type != PrimitiveType_Eltwise) {
continue;
}
FormatTransNodeType pre_insert_trans_type = kNHWC2NCHW;
FormatTransNodeType post_insert_trans_type = kNHWC2NCHW;
if (pre_type == PrimitiveType_NONE && post_type != PrimitiveType_NONE) {
pre_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC;
post_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
} else if (pre_type != PrimitiveType_NONE && post_type == PrimitiveType_NONE) {
pre_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
post_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNHWC2NCHW : kNCHW2NHWC;
} else if (pre_type == PrimitiveType_NONE && post_type == PrimitiveType_NONE) {
auto node_name = node->name;
if (!CanFusion(graph, node)) {
continue;
} else {
if (pre_type == post_type) {
MS_LOG(ERROR) << "Unknow error";
return RET_ERROR;
}
pre_insert_trans_type = pre_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
post_insert_trans_type = post_type == schema::PrimitiveType_Nhwc2Nchw ? kNCHW2NHWC : kNHWC2NCHW;
}
auto ret = FindOutTransType();
if (ret != RET_OK) {
MS_LOG(ERROR) << "FindOutTransType error";
return ret;
}

STATUS status = RET_OK;
auto input_tensor_size = (*iter)->inputIndex.size();
for (auto i = 0; i < input_tensor_size; i++) {
iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type, &status);
iter = InsertFormatTransNode(graph, iter, kBefore, i, pre_insert_trans_type_, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "Insert" << pre_insert_trans_type << "before " << (*iter)->name << " failed";
MS_LOG(ERROR) << "Insert" << pre_insert_trans_type_ << "before " << (*iter)->name << " failed";
return status;
}
}
auto output_tensor_size = (*iter)->outputIndex.size();
for (auto i = 0; i < output_tensor_size; i++) {
iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type, &status);
iter = InsertFormatTransNode(graph, iter, kAfter, i, post_insert_trans_type_, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "Insert" << post_insert_trans_type << "Node before " << (*iter)->name << " failed";
MS_LOG(ERROR) << "Insert" << post_insert_trans_type_ << "Node before " << (*iter)->name << " failed";
return status;
}
}
}
return RET_OK;
}

NodeIter EltwiseFormatTransPass::InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter,
InsertPlace place, size_t inoutIdx, FormatTransNodeType nodeType,
STATUS *errorCode) {
MS_ASSERT((*existNodeIter) != nullptr);
auto existNodeName = (*existNodeIter)->name;
std::string tileName;
if (place == kBefore) {
tileName = existNodeName + "_pre";
} else {
tileName = existNodeName + "_post";
}
auto transNode = std::make_unique<schema::CNodeT>();
transNode->primitive = std::make_unique<schema::PrimitiveT>();

if (nodeType == kNCHW2NHWC) {
transNode->name = "nchw2nhwc_" + tileName + std::to_string(id++);
transNode->primitive->value.type = schema::PrimitiveType_Nchw2Nhwc;
} else {
transNode->name = "nhwc2nchw_" + tileName + std::to_string(id++);
transNode->primitive->value.type = schema::PrimitiveType_Nhwc2Nchw;
}
return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode);
}

void EltwiseFormatTransPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }

void EltwiseFormatTransPass::SetFmk(converter::FmkType fmkType) { this->fmkType = fmkType; }
} // namespace lite
} // namespace mindspore

+ 10
- 12
mindspore/lite/tools/converter/legacy_optimizer/graph/eltwise_format_trans_pass.h View File

@@ -17,7 +17,7 @@
#ifndef MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H
#define MINDSPORE_PREDICT_ELTWISE_FORMAT_TRANS_PASS_H

#include "tools/converter/optimizer.h"
#include <memory>
#include "tools/common/graph_util.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/legacy_optimizer/graph/format_trans_pass.h"
@@ -25,26 +25,24 @@
namespace mindspore {
namespace lite {

class EltwiseFormatTransPass : public GraphPass {
class EltwiseFormatTransPass : public FormatTransPass {
public:
EltwiseFormatTransPass() : id(0) {}
EltwiseFormatTransPass() : FormatTransPass() {}

~EltwiseFormatTransPass() override = default;

STATUS Run(schema::MetaGraphT *graph) override;

void SetQuantType(QuantType quantType);

void SetFmk(converter::FmkType fmkType);

private:
NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
FormatTransNodeType nodeType, STATUS *errorCode);
bool CanFusion(schema::MetaGraphT *graph, const std::unique_ptr<CNodeT> &node);

STATUS FindOutTransType();

private:
size_t id;
QuantType quantType = QuantType_QUANT_NONE;
converter::FmkType fmkType = converter::FmkType_TF;
FormatTransNodeType pre_insert_trans_type_ = kNHWC2NCHW;
FormatTransNodeType post_insert_trans_type_ = kNHWC2NCHW;
schema::PrimitiveType pre_type_ = schema::PrimitiveType_NONE;
schema::PrimitiveType post_type_ = schema::PrimitiveType_NONE;
};
} // namespace lite
} // namespace mindspore


+ 6
- 4
mindspore/lite/tools/converter/legacy_optimizer/graph/format_trans_pass.h View File

@@ -37,16 +37,19 @@ class FormatTransPass : public GraphPass {

void SetFmk(converter::FmkType fmkType);

protected:
NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
FormatTransNodeType nodeType, STATUS *errorCode);

private:
STATUS DoModelInputFormatTrans(schema::MetaGraphT *graph);

STATUS DoNodeInoutFormatTrans(schema::MetaGraphT *graph);

NodeIter InsertFormatTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
FormatTransNodeType nodeType, STATUS *errorCode);
protected:
size_t id = 0;

private:
size_t id;
QuantType quantType = QuantType_QUANT_NONE;
converter::FmkType fmkType = converter::FmkType_TF;
};
@@ -54,4 +57,3 @@ class FormatTransPass : public GraphPass {
} // namespace mindspore

#endif // MINDSPORE_PREDICT_FORMAT_TRANS_PASS_H


+ 213
- 0
mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.cc View File

@@ -0,0 +1,213 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h"
#include "tools/common/converter_op_utils.h"
#include "utils/log_adapter.h"
#include "src/common/utils.h"

namespace mindspore {
namespace lite {
// Records the quantization type consulted when hard-coding filter formats.
void WeightFormatHardCodePass::SetQuantType(QuantType quantType) {
  this->quantType = quantType;
}

// Records the source framework the model was converted from.
void WeightFormatHardCodePass::SetFmkType(converter::FmkType fmkType) {
  this->fmkType = fmkType;
}

// pre set tensor format
// non quant, filterFormat:
// conv deconv depth dedepth
// caffe K(C/g)HW C(K/g)HW / / // todo with deconvOp
// tf HWCK HWKC HWCK HWKC
// onnx K(C/g)HW C(K/g)HW / /

// awareing quant, filterFormat:
// conv deconv depth dedepth
// onnx KHWC ? CHWK ?
// tf HWCK ? HWCK ?
// For every convolution-family node, hard-codes the filter tensor's format
// according to the source framework (fmkType) and quantization mode.
STATUS WeightFormatHardCodePass::Run(MetaGraphT *graph) {
  MS_ASSERT(graph != nullptr);
  for (auto &node : graph->nodes) {
    MS_ASSERT(node != nullptr);
    MS_ASSERT(node->primitive != nullptr);
    auto opType = node->primitive->value.type;
    // Only conv-family ops carry a filter whose layout must be hardcoded.
    if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D && opType != PrimitiveType_DeConv2D &&
        opType != PrimitiveType_DeDepthwiseConv2D) {
      continue;
    }
    MS_ASSERT(node->inputIndex.size() >= 2);
    auto weightIndex = node->inputIndex.at(1);
    // Fixed: the original asserted on `subGraph`, which is not defined in this
    // pass — the tensor pool lives on `graph`.
    MS_ASSERT(graph->allTensors.size() > weightIndex);
    auto &weightTensor = graph->allTensors[weightIndex];
    MS_ASSERT(weightTensor->dims.size() == 4 || weightTensor->dims.empty());  // for conv with fakqQuant before weight
    STATUS status;
    switch (fmkType) {
      case converter::FmkType_CAFFE:
        status = HardCodeCAFFE(node, weightTensor);
        break;
      case converter::FmkType_TFLITE:
        status = HardCodeTFLITE(node, weightTensor);
        break;
      case converter::FmkType_ONNX:
        status = HardCodeONNX(node, weightTensor);
        break;
      case converter::FmkType_MS:
        status = HardCodeMS(node, weightTensor);
        break;
      default:
        MS_LOG(ERROR) << "Unsupported fmkType: " << fmkType << ", node: " << node->name;
        return RET_ERROR;
    }
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Format hardCode faild: " << status << ", node: " << node->name;
      return RET_ERROR;
    }
  }
  return RET_OK;
}

// Hardcodes the filter format for models converted from Caffe.
// Caffe stores conv/deconv filters as K(C/g)HW / C(K/g)HW, i.e. KCHW here;
// only non-quantized models are supported for this framework.
STATUS WeightFormatHardCodePass::HardCodeCAFFE(const std::unique_ptr<CNodeT> &node,
                                               const std::unique_ptr<TensorT> &weightTensor) {
  MS_ASSERT(node != nullptr);
  MS_ASSERT(weightTensor != nullptr);
  MS_ASSERT(node->primitive != nullptr);
  auto opType = node->primitive->value.type;
  switch (this->quantType) {
    case QuantType_QUANT_NONE: {
      if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D ||
          opType == schema::PrimitiveType_DeConv2D || opType == schema::PrimitiveType_DeDepthwiseConv2D) {
        weightTensor->format = Format_KCHW;
      } else {
        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name;
        // Fixed: the original logged the error but fell through to RET_OK,
        // unlike the ONNX/MS siblings which report failure.
        return RET_ERROR;
      }
    } break;
    default: {
      // Log the value the switch actually dispatched on (this->quantType),
      // not node->quantType, which may differ.
      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(this->quantType) << ", node: " << node->name;
      return RET_ERROR;
    }
  }
  return RET_OK;
}

// Hardcodes the filter format for models converted from ONNX, which differs
// between aware-training quantized models and float models.
STATUS WeightFormatHardCodePass::HardCodeONNX(const std::unique_ptr<CNodeT> &node,
                                              const std::unique_ptr<TensorT> &weightTensor) {
  MS_ASSERT(node != nullptr);
  MS_ASSERT(weightTensor != nullptr);
  MS_ASSERT(node->primitive != nullptr);
  auto opType = node->primitive->value.type;
  switch (this->quantType) {
    case QuantType_AwareTraining: {
      // sum up from current onnx quant models
      if (opType == PrimitiveType_Conv2D) {
        weightTensor->format = Format_KHWC;
      } else if (opType == PrimitiveType_DepthwiseConv2D) {
        weightTensor->format = Format_CHWK;
      } else if (opType == PrimitiveType_DeConv2D) {
        weightTensor->format = Format_CKHW;
      } else {
        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name;
        return RET_ERROR;
      }
    } break;
    case QuantType_QUANT_NONE: {
      // conv (K x C/group x kH x kW) group = 1
      // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W)
      // deconv (C x K/group x kH x kW) group = 1
      // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W)
      if (opType == PrimitiveType_Conv2D || opType == PrimitiveType_DepthwiseConv2D) {
        weightTensor->format = Format_KCHW;
      } else if (opType == PrimitiveType_DeConv2D) {
        weightTensor->format = Format_CKHW;
      } else {
        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name;
        return RET_ERROR;
      }
    } break;
    default: {
      // Log the value the switch actually dispatched on (this->quantType),
      // not node->quantType, which may differ.
      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(this->quantType) << ", node: " << node->name;
      return RET_ERROR;
    }
  }
  return RET_OK;
}

STATUS WeightFormatHardCodePass::HardCodeMS(const std::unique_ptr<CNodeT> &node,
                                            const std::unique_ptr<TensorT> &weightTensor) {
  // Hardcode the filter layout for models imported from MindSpore (MS).
  MS_ASSERT(node != nullptr);  // was asserted twice; once is enough
  MS_ASSERT(weightTensor != nullptr);
  MS_ASSERT(node->primitive != nullptr);
  auto opType = node->primitive->value.type;
  switch (this->quantType) {
    case QuantType_AwareTraining: {
      if (opType == PrimitiveType_Conv2D) {
        weightTensor->format = Format_HWCK;
      } else if (opType == PrimitiveType_DepthwiseConv2D) {
        weightTensor->format = Format_CKHW;
      } else {
        // Remaining conv-like ops in aware-training models use HWKC.
        weightTensor->format = Format_HWKC;
      }
    } break;
    case QuantType_QUANT_NONE: {
      // sum up from current ms quant models
      if (opType == PrimitiveType_Conv2D) {
        weightTensor->format = Format_KCHW;
      } else if (opType == PrimitiveType_DepthwiseConv2D) {
        weightTensor->format = Format_CKHW;
      } else {
        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name;
        return RET_ERROR;
      }
    } break;
    default: {
      // Report the quantType that was actually tested (this->quantType), not node->quantType.
      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(this->quantType) << ", node: " << node->name;
      return RET_ERROR;
    }
  }
  return RET_OK;
}

STATUS WeightFormatHardCodePass::HardCodeTFLITE(const std::unique_ptr<CNodeT> &node,
                                                const std::unique_ptr<TensorT> &weightTensor) {
  // Hardcode the filter layout for models imported from TFLite.
  // TFLite uses the same filter layouts regardless of quantization type.
  MS_ASSERT(node != nullptr);  // was asserted twice; once is enough
  MS_ASSERT(weightTensor != nullptr);
  MS_ASSERT(node->primitive != nullptr);
  auto opType = node->primitive->value.type;
  switch (this->quantType) {
    case QuantType_AwareTraining:
    case QuantType_PostTraining:
    case QuantType_QUANT_NONE: {
      if (opType == PrimitiveType_Conv2D) {
        weightTensor->format = Format_KHWC;
      } else if (opType == PrimitiveType_DepthwiseConv2D) {
        weightTensor->format = Format_CHWK;
      } else if (opType == PrimitiveType_DeConv2D) {
        weightTensor->format = Format_CHWK;
      } else {
        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name;
        return RET_ERROR;
      }
    } break;
    default: {
      // Fixed: this branch is reached on an unsupported quantType, but the old
      // message claimed "Unsupported opType" and printed the opType instead.
      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(this->quantType) << ", node: " << node->name;
      return RET_ERROR;
    }
  }
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

+ 52
- 0
mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_hardcode_pass.h View File

@@ -0,0 +1,52 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H

#include <memory>
#include "tools/converter/converter_flags.h"
#include "tools/converter/optimizer.h"
#include "tools/common/graph_util.h"

namespace mindspore {
namespace lite {
// Graph pass that stamps the expected source filter layout (e.g. KCHW, KHWC,
// CHWK) onto the weight tensors of conv-like nodes, based on the framework the
// model was imported from (Caffe/TFLite/ONNX/MS) and its quantization type.
// It only sets TensorT::format; the actual data relayout is done by a later
// transform pass.
class WeightFormatHardCodePass : public GraphPass {
 public:
  WeightFormatHardCodePass() = default;

  ~WeightFormatHardCodePass() override = default;

  // Quantization type of the model (default: QuantType_QUANT_NONE).
  void SetQuantType(QuantType quantType);

  // Source framework of the model (default: FmkType_TF).
  void SetFmkType(converter::FmkType fmkType);

  // Iterates conv-like nodes of the graph and hardcodes their weight format.
  STATUS Run(MetaGraphT *graph) override;

 private:
  // One hardcoding routine per source framework.
  STATUS HardCodeCAFFE(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor);
  STATUS HardCodeTFLITE(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor);
  STATUS HardCodeONNX(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor);
  STATUS HardCodeMS(const std::unique_ptr<CNodeT> &node, const std::unique_ptr<TensorT> &weightTensor);

 private:
  QuantType quantType = QuantType_QUANT_NONE;
  converter::FmkType fmkType = converter::FmkType_TF;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_HARDCODE_PASS_H

+ 140
- 0
mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.cc View File

@@ -0,0 +1,140 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h"
#include <queue>
#include "tools/common/node_util.h"
#include "tools/common/converter_op_utils.h"
#include "utils/log_adapter.h"
#include "src/common/utils.h"

namespace mindspore {
namespace lite {
// Quantization type of the model; selects Quant vs NonQuant transform in Run().
void WeightFormatTransformPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }

// Source framework of the model.
void WeightFormatTransformPass::SetFmkType(converter::FmkType fmkType) { this->fmkType = fmkType; }

// Explicit destination filter format; Format_NUM_OF_FORMAT (default) means "use KHWC".
void WeightFormatTransformPass::SetDstFormat(Format format) { this->dstFormat = format; }

STATUS WeightFormatTransformPass::Run(MetaGraphT *graph) {
  // Dispatch to the quantized or non-quantized weight relayout depending on
  // the configured quantization type; forward any failure to the caller.
  MS_ASSERT(graph != nullptr);
  const bool isAwareQuant = (this->quantType == QuantType_AwareTraining);
  const auto status = isAwareQuant ? QuantDataFormatTrans(graph) : NonQuantDataFormatTrans(graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << (isAwareQuant ? "QuantDataFormatTrans" : "NonQuantDataFormatTrans") << " failed: " << status;
    return status;
  }
  return RET_OK;
}

STATUS WeightFormatTransformPass::QuantDataFormatTrans(MetaGraphT *graph) {
  // Relayout the weight tensors of quantized Conv2D/DepthwiseConv2D nodes to
  // the destination format (default KHWC when dstFormat is unset).
  MS_ASSERT(graph != nullptr);
  for (auto &node : graph->nodes) {
    MS_ASSERT(node != nullptr);
    MS_ASSERT(node->primitive != nullptr);
    auto opType = node->primitive->value.type;
    if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D) {
      continue;
    }
    MS_ASSERT(node->inputIndex.size() >= 2);
    auto weightIndex = node->inputIndex.at(1);
    // Fixed: asserted against `subGraph`, which does not exist in this scope;
    // the tensor pool lives on `graph`.
    MS_ASSERT(graph->allTensors.size() > weightIndex);
    auto &weightTensor = graph->allTensors[weightIndex];
    MS_ASSERT(weightTensor->dataType == DataType_DT_UINT8 || weightTensor->dataType == DataType_DT_FLOAT);
    // The inner op-type re-check duplicated the `continue` filter above and was removed.
    Format curDstFormat = (this->dstFormat == Format_NUM_OF_FORMAT) ? Format_KHWC : this->dstFormat;
    auto status = TransFilterFormat(weightTensor.get(), curDstFormat);
    if (status == RET_OK) {
      weightTensor->format = curDstFormat;
    } else {
      MS_LOG(WARNING) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" << EnumNameFormat(curDstFormat)
                      << " failed, node : " << node->name;
      // todo(00445839): consider varible weight condition
    }
  }
  return RET_OK;
}

STATUS WeightFormatTransformPass::NonQuantDataFormatTrans(MetaGraphT *graph) {
  // Relayout the weight tensors of non-quantized conv-like nodes
  // (Conv2D/DepthwiseConv2D/DeConv2D/DeDepthwiseConv2D) to the destination
  // format (default KHWC when dstFormat is unset).
  MS_ASSERT(graph != nullptr);
  for (auto &node : graph->nodes) {
    MS_ASSERT(node != nullptr);
    MS_ASSERT(node->primitive != nullptr);
    auto opType = node->primitive->value.type;
    if (opType != PrimitiveType_Conv2D && opType != PrimitiveType_DepthwiseConv2D && opType != PrimitiveType_DeConv2D &&
        opType != PrimitiveType_DeDepthwiseConv2D) {
      continue;
    }
    MS_ASSERT(node->inputIndex.size() >= 2);
    auto weightIndex = node->inputIndex.at(1);
    // Fixed: asserted against `subGraph`, which does not exist in this scope;
    // the tensor pool lives on `graph`.
    MS_ASSERT(graph->allTensors.size() > weightIndex);
    auto &weightTensor = graph->allTensors[weightIndex];
    MS_ASSERT(weightTensor->dataType == DataType_DT_UINT8 || weightTensor->dataType == DataType_DT_FLOAT);
    // The original if/else branched on op type but both branches were
    // token-identical (and the else-comment claiming CKHW was stale), so the
    // duplication is collapsed into one path.
    Format curDstFormat = (this->dstFormat == Format_NUM_OF_FORMAT) ? Format_KHWC : this->dstFormat;
    auto status = TransFilterFormat(weightTensor.get(), curDstFormat);
    if (status == RET_OK) {
      weightTensor->format = curDstFormat;
    } else {
      MS_LOG(WARNING) << "TransFilter " << EnumNameFormat(weightTensor->format) << "To" << EnumNameFormat(curDstFormat)
                      << " failed, node : " << node->name;
      // todo(00445839): consider varible weight condition
    }
  }
  return RET_OK;
}
} // namespace lite
} // namespace mindspore

mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.h → mindspore/lite/tools/converter/legacy_optimizer/graph/weight_format_transform_pass.h View File

@@ -14,45 +14,40 @@
* limitations under the License.
*/

#ifndef MINDSPORE_PREDICT_WEIGHT_FORMAT_PASS_H
#define MINDSPORE_PREDICT_WEIGHT_FORMAT_PASS_H
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H

#include "tools/converter/optimizer.h"
#include "tools/common/graph_util.h"
#include "tools/converter/converter_flags.h"
#include "utils/log_adapter.h"

namespace mindspore {
namespace lite {
class WeightFormatPass : public NodePass {
class WeightFormatTransformPass : public GraphPass {
public:
WeightFormatPass() = default;
WeightFormatTransformPass() = default;

~WeightFormatPass() override = default;
~WeightFormatTransformPass() override = default;

void SetQuantType(QuantType quantType);

void SetFmkType(converter::FmkType fmkType);

int Run(GraphNode *graphNode) override;
void SetDstFormat(Format format);

private:
// correct weightTensor->Format
int ShapeFormatTrans(GraphNode *graphNode);
STATUS Run(MetaGraphT *graph) override;

// transform weightTensor data and format
// if quant : conv transform dataFormat to NHWC, weight format to HWCK
// if quant : depth transform dataFormat to NCHW, weight format to CKHW
int QuantDataFormatTrans(GraphNode *graphNode);
private:
STATUS QuantDataFormatTrans(MetaGraphT *graph);

// if no quant : transform dataFormat to NCHW, weight format to KCHW/CKHW
int NonQuantDataFormatTrans(GraphNode *graphNode);
STATUS NonQuantDataFormatTrans(MetaGraphT *graph);

private:
QuantType quantType = QuantType_QUANT_NONE;
converter::FmkType fmkType = converter::FmkType_TF;
Format dstFormat = Format_NUM_OF_FORMAT;
};
} // namespace lite
} // namespace mindspore

#endif // MINDSPORE_PREDICT_WEIGHT_FORMAT_PASS_H

#endif // MINDSPORE_LITE_TOOLS_CONVERTER_LEGACY_OPTIMIZER_WEIGHT_FORMAT_TRANSFORM_PASS_H

+ 0
- 3
mindspore/lite/tools/converter/legacy_optimizer/node/CMakeLists.txt View File

@@ -1,3 +0,0 @@
# Object library for the legacy node-level optimizer passes.
add_library(node_mid OBJECT
        ${CMAKE_CURRENT_SOURCE_DIR}/weight_format_pass.cc
        )

+ 0
- 407
mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc View File

@@ -1,407 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "tools/converter/legacy_optimizer/node/weight_format_pass.h"
#include "tools/common/node_util.h"
#include "tools/common/tensor_util.h"

namespace mindspore {
namespace lite {
// Entry point of the legacy per-node weight format pass: first hardcodes the
// source layout on the weight tensor (ShapeFormatTrans), then relayouts the
// data along the quantized or non-quantized path. Returns 0 on success.
int WeightFormatPass::Run(GraphNode *graphNode) {
  MS_ASSERT(graphNode != nullptr);
  auto status = ShapeFormatTrans(graphNode);
  if (status != 0) {
    MS_LOG(ERROR) << "ShapeFormatTrans failed: " << status;
    return status;
  }
  // Aware-training and post-training quantized models take the quant path;
  // everything else takes the float path.
  if (this->quantType == QuantType_AwareTraining || this->quantType == QuantType_PostTraining) {
    status = QuantDataFormatTrans(graphNode);
    if (status != 0) {
      MS_LOG(ERROR) << "QuantDataFormatTrans failed: " << status;
      return status;
    }
  } else {
    status = NonQuantDataFormatTrans(graphNode);
    if (status != 0) {
      MS_LOG(ERROR) << "NonQuantDataFormatTrans failed: " << status;
      return status;
    }
  }
  return 0;
}

// Quantization type of the model; selects the data-transform path in Run().
void WeightFormatPass::SetQuantType(QuantType quantType) { this->quantType = quantType; }

// Source framework of the model; drives the layout hardcoded by ShapeFormatTrans().
void WeightFormatPass::SetFmkType(converter::FmkType fmkType) { this->fmkType = fmkType; }

// pre set tensor format
// non quant, filterFormat:
// conv deconv depth dedepth
// caffe K(C/g)HW C(K/g)HW / /
// tf HWCK HWKC HWCK HWKC
// onnx K(C/g)HW C(K/g)HW / /

// awareing quant, filterFormat:
// conv deconv depth dedepth
// onnx KHWC ? CHWK ?
// tf HWCK ? HWCK ?
// Hardcode the source filter layout onto the weight tensor of a conv-like
// node, according to the originating framework (fmkType) and the node's
// quantType. Only TensorT::format is set here; no data is moved.
// Returns 0 on success, -1 on unsupported op/quant/framework combinations.
int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
  MS_ASSERT(graphNode != nullptr);
  auto &subGraph = graphNode->subGraph;
  auto &node = graphNode->opDef;
  MS_ASSERT(subGraph != nullptr);
  MS_ASSERT(node != nullptr);
  auto opType = node->primitive->value.type;
  // Only conv-like ops carry a filter whose layout needs hardcoding.
  if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D &&
      opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) {
    return 0;
  }
  MS_ASSERT(node->inputIndex.size() >= 2);
  // Input 1 of a conv-like node is the weight tensor.
  auto weightIndex = node->inputIndex.at(1);
  MS_ASSERT(subGraph->allTensors.size() > weightIndex);
  auto &weightTensor = subGraph->allTensors[weightIndex];
  auto &shape = weightTensor->dims;
  MS_ASSERT(shape.size() == 4);
  if (fmkType == converter::FmkType_CAFFE) {
    switch (node->quantType) {
      case QuantType_QUANT_NONE: {
        // Caffe stores all conv-like filters as KCHW.
        if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D ||
            opType == schema::PrimitiveType_DeConv2D || opType == schema::PrimitiveType_DeDepthwiseConv2D) {
          weightTensor->format = schema::Format_KCHW;
        } else {
          MS_LOG(ERROR) << "Invalid opType: " << schema::EnumNamePrimitiveType(opType)
                        << ", node: " << node->name.c_str();
          return -1;
        }
      } break;
      default: {
        MS_LOG(ERROR) << "Invalid quantType: " << schema::EnumNameQuantType(node->quantType)
                      << ", node: " << node->name.c_str();
        return -1;
      }
    }
    return 0;
  } else if (fmkType == converter::FmkType_MS) {
    switch (node->quantType) {
      case QuantType_AwareTraining: {
        if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) {
          weightTensor->format = schema::Format_HWCK;
        } else {
          weightTensor->format = schema::Format_HWKC;
        }
      } break;
      case QuantType_QUANT_NONE: {
        // conv [filter_height, filter_width, in_channels, out_channels]
        // depthwise [filter_height, filter_width, in_channels, channel_multiplier]
        if (opType == schema::PrimitiveType_Conv2D) {
          weightTensor->format = schema::Format_KCHW;
        } else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
          weightTensor->format = schema::Format_KCHW;
        } else {
          MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(opType) << ", node: " << node->name;
          return -1;
        }
      } break;
      default: {
        // NOTE(review): the comma operator here drops the node name from the log,
        // and "%d" printf placeholders have no effect in stream-style logging.
        // Same pattern recurs in the default/error branches below.
        MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str();
        return -1;
      }
    }
    return 0;
  } else if (fmkType == converter::FmkType_TF) {
    switch (node->quantType) {
      case QuantType_AwareTraining: {
        if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) {
          weightTensor->format = schema::Format_HWCK;
        } else {
          weightTensor->format = schema::Format_HWKC;
        }
      } break;
      case QuantType_QUANT_NONE: {
        // conv [filter_height, filter_width, in_channels, out_channels]
        // depthwise [filter_height, filter_width, in_channels, channel_multiplier]
        if (opType == schema::PrimitiveType_Conv2D || opType == schema::PrimitiveType_DepthwiseConv2D) {
          weightTensor->format = schema::Format_HWCK;
        } else {
          weightTensor->format = schema::Format_HWKC;
        }
      } break;
      default: {
        MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str();
        return -1;
      }
    }
    return 0;
  } else if (fmkType == converter::FmkType_TFLITE) {
    // TFLite layouts are the same for all supported quant types.
    switch (node->quantType) {
      case QuantType_QUANT_NONE:
      case QuantType_AwareTraining:
      case QuantType_PostTraining: {
        if (opType == schema::PrimitiveType_Conv2D) {
          weightTensor->format = schema::Format_KHWC;
        } else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
          weightTensor->format = schema::Format_CHWK;
        } else if (opType == schema::PrimitiveType_DeConv2D) {
          weightTensor->format = schema::Format_CHWK;
        } else {
          MS_LOG(ERROR) << "Unsupported format";
          return -1;
        }
      } break;
      default: {
        MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str();
        return -1;
      }
    }
    MS_LOG(DEBUG) << "weight_tensor_format: " << weightTensor->format;
    return 0;
  } else if (fmkType == converter::FmkType_ONNX) {
    switch (node->quantType) {
      case QuantType_AwareTraining: {
        // sum up from current onnx quant models
        if (opType == schema::PrimitiveType_Conv2D) {
          weightTensor->format = schema::Format_KHWC;
        } else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
          weightTensor->format = schema::Format_CHWK;
        } else {
          MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str();
          return -1;
        }
      } break;
      case QuantType_QUANT_NONE: {
        // conv (K x C/group x kH x kW) group = 1
        // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W)
        // deconv (C x K/group x kH x kW) group = 1
        // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W)
        if (opType == schema::PrimitiveType_Conv2D) {
          weightTensor->format = schema::Format_KCHW;
        } else if (opType == schema::PrimitiveType_DepthwiseConv2D) {
          weightTensor->format = schema::Format_KCHW;
        } else if (opType == schema::PrimitiveType_DeConv2D) {
          weightTensor->format = schema::Format_CKHW;
        } else {
          MS_LOG(ERROR) << "Invalid opType: %d, node: " << opType, node->name.c_str();
          return -1;
        }
      } break;
      default: {
        MS_LOG(ERROR) << "Unsupported quantType: %d, node: " << node->quantType, node->name.c_str();
        return -1;
      }
    }
  } else {
    MS_LOG(ERROR) << "Invalid fmkType: %d, node: " << fmkType, node->name.c_str();
    return -1;
  }
  return 0;
}

// inference needed filterFormat:
// conv deconv depth dedepth
// uint8 KHWC KHWC KHWC KHWC
// Quantized path: relayout the weight data of a conv-like node to KHWC
// (dispatching on the source layout hardcoded by ShapeFormatTrans) and mark
// the op's activation format as NHWC. Returns 0 on success, -1 on an
// unsupported source layout.
int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
  MS_ASSERT(graphNode != nullptr);
  auto &subGraph = graphNode->subGraph;
  auto &node = graphNode->opDef;
  MS_ASSERT(subGraph != nullptr);
  MS_ASSERT(node != nullptr);
  auto opType = node->primitive->value.type;
  if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D &&
      opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) {
    return RET_OK;
  }

  MS_ASSERT(node->inputIndex.size() >= 2);
  // Input 1 is the weight tensor.
  auto weightIndex = node->inputIndex.at(1);
  MS_ASSERT(subGraph->allTensors.size() > weightIndex);
  auto &weightTensor = subGraph->allTensors[weightIndex];
  MS_ASSERT(weightTensor->dataType == kNumberTypeInt8);  // DataType_DT_FLOAT
  STATUS status = RET_OK;
  if (opType == schema::PrimitiveType_Conv2D) {         // weight should be KHWC
    if (weightTensor->format == schema::Format_KCHW) {  // from caffe
      // NOTE(review): int8 data goes KCHW->HWCK while float goes KCHW->KHWC;
      // the "%d" placeholders in the DEBUG logs are leftovers from printf-style
      // logging and are printed literally.
      if (weightTensor->dataType == kNumberTypeInt8) {  // DataType_DT_UINT8) {
        MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format
                      << weightTensor->dataType;
        status = TransFilterFormat<int8_t>(weightTensor.get(), kKCHW2HWCK);
      } else {
        MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format
                      << weightTensor->dataType;
        status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
      }
    } else if (weightTensor->format != schema::Format_KHWC) {
      MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
      return -1;
    }
    if (status == 0) {
      node->primitive->value.AsConv2D()->format = schema::Format_NHWC;
      weightTensor->format = schema::Format_KHWC;
    } else {
      MS_LOG(WARNING) << "TransFilter %sToKHWC failed, node : "
                      << (weightTensor->format == schema::Format_KHWC ? "KHWC" : "KCHW") << node->name.c_str();
      // todo(00445839): consider varible weight condition
    }
  } else if (opType == schema::PrimitiveType_DepthwiseConv2D) {  // weight should be KHWC
    if (weightTensor->format == schema::Format_CKHW) {           // from caffe
      // Dispatch on element type: int8 / uint8 / float all go CKHW->KHWC.
      if (weightTensor->dataType == kNumberTypeInt8) {  // DataType_DT_UINT8) {
        MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
                      << "datatype: " << weightTensor->dataType;
        status = TransFilterFormat<int8_t>(weightTensor.get(), kCKHW2KHWC);
      } else if (weightTensor->dataType == kNumberTypeUInt8) {
        MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
                      << "datatype: " << weightTensor->dataType;
        status = TransFilterFormat<uint8_t>(weightTensor.get(), kCKHW2KHWC);
      } else {
        MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
                      << "datatype: " << weightTensor->dataType;
        status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
      }

    } else if (weightTensor->format == schema::Format_CHWK) {  // from onnx
      // Same element-type dispatch for CHWK->KHWC.
      if (weightTensor->dataType == kNumberTypeInt8) {
        MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
                      << "datatype: " << weightTensor->dataType;
        status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC);
      } else if (weightTensor->dataType == kNumberTypeUInt8) {
        MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
                      << "datatype: " << weightTensor->dataType;
        status = TransFilterFormat<uint8_t>(weightTensor.get(), kCHWK2KHWC);
      } else {
        MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
                      << "datatype: " << weightTensor->dataType;
        status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
      }
    } else if (weightTensor->format != schema::Format_KHWC) {
      MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
      return -1;
    }
    if (status == 0) {
      node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC;
      weightTensor->format = schema::Format_KHWC;
    } else {
      MS_LOG(WARNING) << "TransFilter" << (weightTensor->format == schema::Format_KHWC ? "KHWC" : "CKHW")
                      << "To KHWC failed, node : " << node->name.c_str();
      // todo(00445839): consider varible weight condition
    }
  } else {  // weight should be HWCK
    // NOTE(review): de(conv) branch sets formats without transforming data —
    // presumably the data is already laid out as expected; verify upstream.
    node->primitive->value.AsDeConv2D()->format = schema::Format_NHWC;
    weightTensor->format = schema::Format_KHWC;
  }
  return 0;
}

// inference needed filterFormat:
// conv deconv depth dedepth
// fp32 KCHW CKHW CKHW CKHW
// Float (non-quantized) path: relayout the weight data of a conv-like node to
// the inference layout, dispatching on the source layout hardcoded by
// ShapeFormatTrans, and record the op's activation format.
// Returns 0 on success, -1 on an unsupported source layout.
int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
  MS_ASSERT(graphNode != nullptr);
  auto &subGraph = graphNode->subGraph;
  auto &node = graphNode->opDef;
  MS_ASSERT(subGraph != nullptr);
  MS_ASSERT(node != nullptr);
  auto opType = node->primitive->value.type;
  if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D &&
      opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) {
    return 0;
  }

  MS_ASSERT(node->inputIndex.size() >= 2);
  // Input 1 is the weight tensor.
  auto weightIndex = node->inputIndex.at(1);
  MS_ASSERT(subGraph->allTensors.size() > weightIndex);
  auto &weightTensor = subGraph->allTensors[weightIndex];
  if (weightTensor->dataType != TypeId::kNumberTypeFloat32) {
    // NOTE(review): deliberately logs but does not fail (return is commented out).
    MS_LOG(ERROR) << "weight tensor data should be float";
    // return -1;
  }
  STATUS status = RET_OK;
  if (opType == schema::PrimitiveType_Conv2D) {         // weight should be KCHW
    if (weightTensor->format == schema::Format_KCHW) {  // from caffe or onnx or ms
      status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
    } else if (weightTensor->format == schema::Format_KHWC) {
      // Already in the target layout; nothing to do.
      status = RET_OK;
    } else if (weightTensor->format == schema::Format_CHWK) {
      status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
    } else {
      MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
      return -1;
    }
    if (status == 0) {
      node->primitive->value.AsConv2D()->format = schema::Format_NHWC;
      weightTensor->format = schema::Format_KHWC;
    } else {
      MS_LOG(WARNING) << "TransFilter " << ((weightTensor->format == schema::Format_HWCK) ? "HWCK" : "NHWC")
                      << "ToKCHW failed, node : " << node->name.c_str();
      // todo(00445839): consider varible weight condition
    }
  } else if (opType == schema::PrimitiveType_DepthwiseConv2D) {  // weight should be CKHW
    if (fmkType == converter::FmkType_MS) {
      // MS depthwise weights are re-tagged as CKHW before the transform below.
      weightTensor->format = schema::Format_CKHW;
    }
    if (weightTensor->format == schema::Format_CKHW) {  // from caffe or onnx or ms
      status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
    } else if (weightTensor->format == schema::Format_KCHW) {
      status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
    } else if (weightTensor->format == schema::Format_CHWK) {  // from tflite
      status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
    } else {
      MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
      return -1;
    }
    if (status == 0) {
      node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC;
      weightTensor->format = schema::Format_KHWC;
    } else {
      MS_LOG(WARNING) << "TransFilter HWCKToCKHW failed, node : " << node->name.c_str();
      // todo(00445839): consider varible weight condition
    }
  } else if (opType == schema::PrimitiveType_DeConv2D) {  // weight should be KHWC
    if (weightTensor->format == schema::Format_KCHW) {    // from caffe or onnx or ms
      status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
    } else if (weightTensor->format == schema::Format_CHWK) {  // from tflite
      status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
    } else {
      MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
      return -1;
    }
    if (status == 0) {
      node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW;
      weightTensor->format = schema::Format_KHWC;
    } else {
      MS_LOG(WARNING) << "TransFilter HWKCToKCHW failed, node : " << node->name.c_str();
      // todo(00445839): consider varible weight condition
    }
  } else if (opType == schema::PrimitiveType_DeDepthwiseConv2D) {  // weight should be KHWC
    if (weightTensor->format == schema::Format_KHWC) {
      // Already in the target layout; nothing to do.
      return 0;
    } else if (weightTensor->format == schema::Format_KCHW) {  // from caffe
      status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
    } else if (weightTensor->format == schema::Format_HWKC) {  // from tf or onnx
      status = TransFilterFormat<float>(weightTensor.get(), kHWKC2CKHW);
    } else {
      MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
      return -1;
    }
    if (status == 0) {
      node->primitive->value.AsDeDepthwiseConv2D()->format = schema::Format_NHWC;
      weightTensor->format = schema::Format_CKHW;
    } else {
      MS_LOG(WARNING) << "TransFilter HWKCToCKHW failed, node : " << node->name.c_str();
      // todo(00445839): consider varible weight condition
    }
  }
  return 0;
}
} // namespace lite
} // namespace mindspore

Loading…
Cancel
Save