@@ -31,6 +31,7 @@ namespace mindspore {
 namespace lite {
 #define MUL_ADD_MATCH_PATH_LEN 2
 #define ADD_OP_BIAS_INDEX 1
+#define MUL_OP_INPUT_INDEX 0
 #define MUL_OP_BIAS_INDEX 1
 #define MUL_OP_INPUT_NUM 2
 #define ADD_OP_INPUT_NUM 2
@@ -60,6 +61,23 @@ STATUS MulAddFusionPass::DefinePattern() {
   return RET_OK;
 }
 
+bool ScaleInputShapeValid(const std::vector<int> &input_shape, const std::vector<int> &scale_shape,
+                          const std::vector<int> &offset_shape) {
+  if (input_shape.size() < scale_shape.size() || scale_shape.size() == 0) {
+    return false;
+  }
+  size_t rank_diff = input_shape.size() - scale_shape.size();
+  for (size_t i = 0; i < scale_shape.size(); ++i) {
+    if (input_shape[i + rank_diff] != scale_shape[i]) {
+      return false;
+    }
+  }
+  if (scale_shape != offset_shape) {
+    return false;
+  }
+  return true;
+}
+
 STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName,
                                   std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
   MS_ASSERT(graph != nullptr);
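For reference, a minimal sketch (not part of the patch) of how the new ScaleInputShapeValid helper behaves, assuming it is in scope; the sample shapes are illustrative assumptions, not values from the source:

  // scale and offset must share a shape, and that shape must be a tail suffix of the input shape
  std::vector<int> input = {1, 64, 64, 32};
  bool ok1 = ScaleInputShapeValid(input, {64, 32}, {64, 32});  // true: {64, 32} matches the last two input dims
  bool ok2 = ScaleInputShapeValid(input, {64}, {64});          // false: 64 != 32, the last input dim
  bool ok3 = ScaleInputShapeValid(input, {64, 32}, {32});      // false: offset shape differs from scale shape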
@@ -79,7 +97,7 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN
   MS_ASSERT(graph->allTensors.size() > mulNodeInputIndex.at(MUL_OP_BIAS_INDEX));
   const auto &mulNodeBiasTensor = graph->allTensors.at(mulNodeInputIndex.at(MUL_OP_BIAS_INDEX));
   MS_ASSERT(mulNodeBiasTensor != nullptr);
-  if (mulNodeBiasTensor->refCount != schema::NodeType::NodeType_ValueNode || mulNodeBiasTensor->dims.size() == 4) {
+  if (mulNodeBiasTensor->refCount != schema::NodeType::NodeType_ValueNode) {
     // dont fusion, return
     return RET_OK;
   }
@@ -96,7 +114,11 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN
     // dont fusion, return
     return RET_OK;
   }
-
+  // scale requires scale shape tail sub of input shape, scale shape same as bias shape
+  const auto &mulNodeInputTensor = graph->allTensors.at(mulNodeInputIndex.at(MUL_OP_INPUT_INDEX));
+  if (!ScaleInputShapeValid(mulNodeInputTensor->dims, mulNodeBiasTensor->dims, addNodeBiasTensor->dims)) {
+    return RET_OK;
+  }
   // convert mul and add to scale
   auto status = AddNewScaleNode(graph, mulNode, addNode.get(), addNodeInputIndex.at(ADD_OP_BIAS_INDEX));
   if (RET_OK != status) {
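For reference, the elementwise identity the fusion relies on, as a minimal standalone sketch with hypothetical values (not part of the patch): a Mul by a constant tensor followed by an Add of a same-shaped constant tensor computes exactly what one Scale op computes, with the Mul constant as scale and the Add constant as offset.

  #include <cstdio>
  #include <vector>

  int main() {
    // per element: Add(Mul(x, s), o) == Scale(x, scale = s, offset = o) == x * s + o
    std::vector<float> x = {1.0f, 2.0f, 3.0f};  // activation values
    std::vector<float> s = {0.5f, 0.5f, 0.5f};  // former Mul bias, reused as the Scale's scale
    std::vector<float> o = {1.0f, 1.0f, 1.0f};  // former Add bias, reused as the Scale's offset
    for (size_t i = 0; i < x.size(); ++i) {
      std::printf("%.1f\n", x[i] * s[i] + o[i]);  // prints 1.5, 2.0, 2.5 either way
    }
    return 0;
  }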